Skip to content

Commit aad3562

Browse files
committed
push_attr fixes
- Don't raise an exception if mods is used together with common, nonunique, or unique: there is nothing in the logic preventing it.
- Don't raise if columns is used together with common, nonunique, or unique; warn instead.
- Fix the ordering of pushed columns.
- Minor code cleanup.
1 parent 8422877 commit aad3562

File tree

3 files changed

+109
-99
lines changed

3 files changed

+109
-99
lines changed

src/mudata/_core/mudata.py

Lines changed: 47 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1920,32 +1920,33 @@ def _pull_attr(
19201920
raise ValueError("All mods should be present in mdata.mod")
19211921
elif len(mods) == self.n_mod:
19221922
mods = None
1923-
for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items():
1924-
assert v is None, f"Cannot use mods with {k}."
19251923

19261924
if only_drop:
19271925
drop = True
19281926

19291927
cols = _classify_attr_columns(
1930-
np.concatenate(
1931-
[
1932-
[f"{m}:{val}" for val in getattr(mod, attr).columns.values]
1933-
for m, mod in self.mod.items()
1934-
]
1935-
),
1936-
self.mod.keys(),
1928+
{modname: getattr(mod, attr).columns for modname, mod in self.mod.items()}
19371929
)
19381930

19391931
if columns is not None:
19401932
for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items():
1941-
assert v is None, f"Cannot use {k} with columns."
1933+
if v is not None:
1934+
warnings.warn(
1935+
f"Both columns and {k} given. Columns take precedence, {k} will be ignored",
1936+
RuntimeWarning,
1937+
stacklevel=2,
1938+
)
19421939

19431940
# - modname1:column -> [modname1:column]
19441941
# - column -> [modname1:column, modname2:column, ...]
1945-
cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns]
1946-
1947-
if mods is not None:
1948-
cols = [col for col in cols if col["prefix"] in mods]
1942+
cols = {
1943+
prefix: [
1944+
col
1945+
for col in modcols
1946+
if col["name"] in columns or col["derived_name"] in columns
1947+
]
1948+
for prefix, modcols in cols.items()
1949+
}
19491950

19501951
# TODO: Counter for columns in order to track their usage
19511952
# and error out if some columns were not used
@@ -1959,10 +1960,17 @@ def _pull_attr(
19591960
unique = True
19601961

19611962
selector = {"common": common, "nonunique": nonunique, "unique": unique}
1963+
cols = {
1964+
prefix: [col for col in modcols if selector[col["class"]]]
1965+
for prefix, modcols in cols.items()
1966+
}
19621967

1963-
cols = [col for col in cols if selector[col["class"]]]
1968+
if mods is not None:
1969+
cols = {prefix: cols[prefix] for prefix in mods}
19641970

1965-
derived_name_count = Counter([col["derived_name"] for col in cols])
1971+
derived_name_count = Counter(
1972+
[col["derived_name"] for modcols in cols.values() for col in modcols]
1973+
)
19661974

19671975
# - axis == self.axis
19681976
# e.g. combine var from multiple modalities (with unique vars)
@@ -1995,44 +2003,36 @@ def _pull_attr(
19952003
n_attr = self.n_vars if attr == "var" else self.n_obs
19962004

19972005
dfs: list[pd.DataFrame] = []
1998-
for m, mod in self.mod.items():
1999-
if mods is not None and m not in mods:
2000-
continue
2006+
for m, modcols in cols.items():
2007+
mod = self.mod[m]
20012008
mod_map = attrmap[m].ravel()
2002-
mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs
2003-
mask = mod_map != 0
2004-
2005-
mod_df = getattr(mod, attr)
2006-
mod_columns = [
2007-
col["derived_name"] for col in cols if col["prefix"] == "" or col["prefix"] == m
2008-
]
2009-
mod_df = mod_df[mod_df.columns.intersection(mod_columns)]
2009+
mask = mod_map > 0
20102010

2011+
mod_df = getattr(mod, attr)[[col["derived_name"] for col in modcols]]
20112012
if drop:
20122013
getattr(mod, attr).drop(columns=mod_df.columns, inplace=True)
20132014

2014-
# Don't use modname: prefix if columns need to be joined
2015-
if join_common or join_nonunique or (not prefix_unique):
2016-
cols_special = [
2017-
col["derived_name"]
2018-
for col in cols
2019-
if (
2020-
(col["class"] == "common") & join_common
2021-
or (col["class"] == "nonunique") & join_nonunique
2022-
or (col["class"] == "unique") & (not prefix_unique)
2015+
mod_df.rename(
2016+
columns={
2017+
col["derived_name"]: col["name"]
2018+
for col in modcols
2019+
if not (
2020+
(
2021+
join_common
2022+
and col["class"] == "common"
2023+
or join_nonunique
2024+
and col["class"] == "nonunique"
2025+
or not prefix_unique
2026+
and col["class"] == "unique"
2027+
)
2028+
and derived_name_count[col["derived_name"]] == col["count"]
20232029
)
2024-
and col["prefix"] == m
2025-
and derived_name_count[col["derived_name"]] == col["count"]
2026-
]
2027-
mod_df.columns = [
2028-
col if col in cols_special else f"{m}:{col}" for col in mod_df.columns
2029-
]
2030-
else:
2031-
mod_df.columns = [f"{m}:{col}" for col in mod_df.columns]
2030+
},
2031+
inplace=True,
2032+
)
20322033

20332034
mod_df = (
20342035
_maybe_coerce_to_boolean(mod_df)
2035-
.set_index(np.arange(mod_n_attr))
20362036
.iloc[mod_map[mask] - 1]
20372037
.set_index(np.arange(n_attr)[mask])
20382038
.reindex(np.arange(n_attr))
@@ -2297,19 +2297,15 @@ def _push_attr(
22972297
if mods is not None and m not in mods:
22982298
continue
22992299

2300-
mod_map = attrmap[m]
2300+
mod_map = attrmap[m].ravel()
23012301
mask = mod_map != 0
23022302
mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs
23032303

23042304
mod_cols = [col for col in cols if col["prefix"] == m or col["class"] == "common"]
23052305
df = getattr(self, attr)[mask].loc[:, [col["name"] for col in mod_cols]]
23062306
df.columns = [col["derived_name"] for col in mod_cols]
23072307

2308-
df = (
2309-
df.set_index(np.arange(mod_n_attr))
2310-
.iloc[mod_map[mask] - 1]
2311-
.set_index(np.arange(mod_n_attr))
2312-
)
2308+
df = df.iloc[np.argsort(mod_map[mask])].set_index(np.arange(mod_n_attr))
23132309

23142310
if not only_drop:
23152311
# TODO: _maybe_coerce_to_bool

src/mudata/_core/utils.py

Lines changed: 24 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import Counter
2-
from collections.abc import Sequence
2+
from collections.abc import Mapping, Sequence
33
from typing import TypeVar
44

55
import numpy as np
@@ -38,9 +38,7 @@ def _maybe_coerce_to_boolean(df: T) -> T:
3838
return df
3939

4040

41-
def _classify_attr_columns(
42-
names: Sequence[str], prefixes: Sequence[str]
43-
) -> Sequence[dict[str, str]]:
41+
def _classify_attr_columns(names: Mapping[str, Sequence[str]]) -> dict[str, dict[str, str]]:
4442
"""
4543
Classify names into common, non-unique, and unique
4644
with respect to the list of prefixes.
@@ -61,42 +59,29 @@ def _classify_attr_columns(
6159
{"name": "mod1:unique", "prefix": "mod1", "derived_name": "annotation", "count": 2, "class": "unique"},
6260
]
6361
"""
64-
n_mod = len(prefixes)
65-
res: list[dict[str, str]] = []
66-
67-
for name in names:
68-
name_common = {
69-
"name": name,
70-
"prefix": "",
71-
"derived_name": name,
72-
}
73-
name_split = name.split(":", 1)
74-
75-
if len(name_split) < 2:
76-
res.append(name_common)
77-
else:
78-
maybe_modname, derived_name = name_split
79-
80-
if maybe_modname in prefixes:
81-
name_prefixed = {
82-
"name": name,
83-
"prefix": maybe_modname,
84-
"derived_name": derived_name,
62+
n_mod = len(names)
63+
res: dict[str, dict[str, str]] = {}
64+
65+
derived_name_counts = Counter()
66+
for prefix, names in names.items():
67+
cres = []
68+
for name in names:
69+
cres.append(
70+
{
71+
"name": f"{prefix}:{name}",
72+
"derived_name": name,
8573
}
86-
res.append(name_prefixed)
87-
else:
88-
res.append(name_common)
89-
90-
derived_name_counts = Counter(name_res["derived_name"] for name_res in res)
91-
for name_res in res:
92-
name_res["count"] = derived_name_counts[name_res["derived_name"]]
93-
94-
for name_res in res:
95-
name_res["class"] = (
96-
"common"
97-
if name_res["count"] == n_mod
98-
else "unique" if name_res["count"] == 1 else "nonunique"
99-
)
74+
)
75+
derived_name_counts[name] += 1
76+
res[prefix] = cres
77+
78+
for prefix, names in res.items():
79+
for name_res in names:
80+
count = derived_name_counts[name_res["derived_name"]]
81+
name_res["count"] = count
82+
name_res["class"] = (
83+
"common" if count == n_mod else "unique" if count == 1 else "nonunique"
84+
)
10085

10186
return res
10287

tests/test_pull_push.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from anndata import AnnData
77

8-
from mudata import MuData
8+
from mudata import MuData, set_options
99

1010

1111
@pytest.fixture()
@@ -21,7 +21,8 @@ def modalities(request, obs_n, var_unique):
2121
mods[m].var["mod"] = m
2222

2323
# common column
24-
mods[m].var["highly_variable"] = np.tile([False, True], mods[m].n_vars // 2)
24+
mods[m].var["highly_variable"] = np.random.choice([False, True], size=mods[m].n_vars)
25+
mods[m].obs["common_obs_col"] = np.random.randint(0, int(1e6), size=mods[m].n_obs)
2526

2627
if var_unique:
2728
mods[m].var_names = [f"mod{m}_var{j}" for j in range(mods[m].n_vars)]
@@ -88,7 +89,6 @@ def test_pull_var(self, modalities):
8889
"""
8990
mdata = MuData(modalities)
9091
mdata.update()
91-
9292
mdata.pull_var()
9393

9494
assert "mod" in mdata.var.columns
@@ -165,6 +165,15 @@ def test_pull_obs_simple(self, modalities):
165165
for m in mdata.mod.keys():
166166
assert f"{m}:mod" in mdata.obs.columns
167167

168+
assert f"{m}:common_obs_col" in mdata.obs.columns
169+
170+
modmap = mdata.obsmap[m].ravel()
171+
mask = modmap > 0
172+
assert (
173+
mdata.obs[f"{m}:common_obs_col"][mask].to_numpy()
174+
== mdata.mod[m].obs["common_obs_col"].to_numpy()[modmap[mask] - 1]
175+
).all()
176+
168177
# join_common shouldn't work
169178
with pytest.raises(ValueError, match="shared obs_names"):
170179
mdata.pull_obs(join_common=True)
@@ -182,14 +191,24 @@ def test_push_var_simple(self, modalities):
182191
mdata = MuData(modalities)
183192
mdata.update()
184193

185-
mdata.var["pushed"] = True
186-
mdata.var["mod2:mod2_pushed"] = True
194+
mdata.var["pushed"] = np.random.randint(0, int(1e6), size=mdata.n_var)
195+
mdata.var["mod2:mod2_pushed"] = np.random.randint(0, int(1e6), size=mdata.n_var)
187196
mdata.push_var()
188197

189198
# pushing should work
190-
for mod in mdata.mod.values():
199+
for modname, mod in mdata.mod.items():
191200
assert "pushed" in mod.var.columns
201+
202+
map = mdata.varmap[modname].ravel()
203+
mask = map > 0
204+
assert (mdata.var["pushed"][mask] == mod.var["pushed"][map[mask] - 1]).all()
205+
192206
assert "mod2_pushed" in mdata["mod2"].var.columns
207+
map = mdata.varmap["mod2"].ravel()
208+
mask = map > 0
209+
assert (
210+
mdata.var["mod2:mod2_pushed"][mask] == mdata["mod2"].var["mod2_pushed"][map[mask] - 1]
211+
).all()
193212

194213
@pytest.mark.parametrize("var_unique", [True, False])
195214
@pytest.mark.parametrize("obs_n", ["joint", "disjoint"])
@@ -200,14 +219,24 @@ def test_push_obs_simple(self, modalities):
200219
mdata = MuData(modalities)
201220
mdata.update()
202221

203-
mdata.obs["pushed"] = True
204-
mdata.obs["mod2:mod2_pushed"] = True
222+
mdata.obs["pushed"] = np.random.randint(0, int(1e6), size=mdata.n_obs)
223+
mdata.obs["mod2:mod2_pushed"] = np.random.randint(0, int(1e6), size=mdata.n_obs)
205224
mdata.push_obs()
206225

207226
# pushing should work
208-
for mod in mdata.mod.values():
227+
for modname, mod in mdata.mod.items():
209228
assert "pushed" in mod.obs.columns
229+
230+
map = mdata.obsmap[modname].ravel()
231+
mask = map > 0
232+
assert (mdata.obs["pushed"][mask] == mod.obs["pushed"][map[mask] - 1]).all()
233+
210234
assert "mod2_pushed" in mdata["mod2"].obs.columns
235+
map = mdata.obsmap["mod2"].ravel()
236+
mask = map > 0
237+
assert (
238+
mdata.obs["mod2:mod2_pushed"][mask] == mdata["mod2"].obs["mod2_pushed"][map[mask] - 1]
239+
).all()
211240

212241

213242
@pytest.mark.usefixtures("filepath_h5mu")

0 commit comments

Comments
 (0)