Skip to content

Commit ca406d7

Browse files
committed
push_attr fixes
- don't raise exception if mods is used together with common, nonunique, or unique: there is nothing in the logic preventing it - don't raise if columns is used together with common, nonunique, or unique, warn instead - fix ordering of pushed column - minor code cleanup
1 parent 8422877 commit ca406d7

File tree

3 files changed

+115
-107
lines changed

3 files changed

+115
-107
lines changed

src/mudata/_core/mudata.py

Lines changed: 47 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1920,32 +1920,33 @@ def _pull_attr(
19201920
raise ValueError("All mods should be present in mdata.mod")
19211921
elif len(mods) == self.n_mod:
19221922
mods = None
1923-
for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items():
1924-
assert v is None, f"Cannot use mods with {k}."
19251923

19261924
if only_drop:
19271925
drop = True
19281926

19291927
cols = _classify_attr_columns(
1930-
np.concatenate(
1931-
[
1932-
[f"{m}:{val}" for val in getattr(mod, attr).columns.values]
1933-
for m, mod in self.mod.items()
1934-
]
1935-
),
1936-
self.mod.keys(),
1928+
{modname: getattr(mod, attr).columns for modname, mod in self.mod.items()}
19371929
)
19381930

19391931
if columns is not None:
19401932
for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items():
1941-
assert v is None, f"Cannot use {k} with columns."
1933+
if v is not None:
1934+
warnings.warn(
1935+
f"Both columns and {k} given. Columns take precedence, {k} will be ignored",
1936+
RuntimeWarning,
1937+
stacklevel=2,
1938+
)
19421939

19431940
# - modname1:column -> [modname1:column]
19441941
# - column -> [modname1:column, modname2:column, ...]
1945-
cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns]
1946-
1947-
if mods is not None:
1948-
cols = [col for col in cols if col["prefix"] in mods]
1942+
cols = {
1943+
prefix: [
1944+
col
1945+
for col in modcols
1946+
if col["name"] in columns or col["derived_name"] in columns
1947+
]
1948+
for prefix, modcols in cols.items()
1949+
}
19491950

19501951
# TODO: Counter for columns in order to track their usage
19511952
# and error out if some columns were not used
@@ -1959,10 +1960,17 @@ def _pull_attr(
19591960
unique = True
19601961

19611962
selector = {"common": common, "nonunique": nonunique, "unique": unique}
1963+
cols = {
1964+
prefix: [col for col in modcols if selector[col["class"]]]
1965+
for prefix, modcols in cols.items()
1966+
}
19621967

1963-
cols = [col for col in cols if selector[col["class"]]]
1968+
if mods is not None:
1969+
cols = {prefix: cols[prefix] for prefix in mods}
19641970

1965-
derived_name_count = Counter([col["derived_name"] for col in cols])
1971+
derived_name_count = Counter(
1972+
[col["derived_name"] for modcols in cols.values() for col in modcols]
1973+
)
19661974

19671975
# - axis == self.axis
19681976
# e.g. combine var from multiple modalities (with unique vars)
@@ -1995,44 +2003,36 @@ def _pull_attr(
19952003
n_attr = self.n_vars if attr == "var" else self.n_obs
19962004

19972005
dfs: list[pd.DataFrame] = []
1998-
for m, mod in self.mod.items():
1999-
if mods is not None and m not in mods:
2000-
continue
2006+
for m, modcols in cols.items():
2007+
mod = self.mod[m]
20012008
mod_map = attrmap[m].ravel()
2002-
mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs
2003-
mask = mod_map != 0
2004-
2005-
mod_df = getattr(mod, attr)
2006-
mod_columns = [
2007-
col["derived_name"] for col in cols if col["prefix"] == "" or col["prefix"] == m
2008-
]
2009-
mod_df = mod_df[mod_df.columns.intersection(mod_columns)]
2009+
mask = mod_map > 0
20102010

2011+
mod_df = getattr(mod, attr)[[col["derived_name"] for col in modcols]]
20112012
if drop:
20122013
getattr(mod, attr).drop(columns=mod_df.columns, inplace=True)
20132014

2014-
# Don't use modname: prefix if columns need to be joined
2015-
if join_common or join_nonunique or (not prefix_unique):
2016-
cols_special = [
2017-
col["derived_name"]
2018-
for col in cols
2019-
if (
2020-
(col["class"] == "common") & join_common
2021-
or (col["class"] == "nonunique") & join_nonunique
2022-
or (col["class"] == "unique") & (not prefix_unique)
2015+
mod_df.rename(
2016+
columns={
2017+
col["derived_name"]: col["name"]
2018+
for col in modcols
2019+
if not (
2020+
(
2021+
join_common
2022+
and col["class"] == "common"
2023+
or join_nonunique
2024+
and col["class"] == "nonunique"
2025+
or not prefix_unique
2026+
and col["class"] == "unique"
2027+
)
2028+
and derived_name_count[col["derived_name"]] == col["count"]
20232029
)
2024-
and col["prefix"] == m
2025-
and derived_name_count[col["derived_name"]] == col["count"]
2026-
]
2027-
mod_df.columns = [
2028-
col if col in cols_special else f"{m}:{col}" for col in mod_df.columns
2029-
]
2030-
else:
2031-
mod_df.columns = [f"{m}:{col}" for col in mod_df.columns]
2030+
},
2031+
inplace=True,
2032+
)
20322033

20332034
mod_df = (
20342035
_maybe_coerce_to_boolean(mod_df)
2035-
.set_index(np.arange(mod_n_attr))
20362036
.iloc[mod_map[mask] - 1]
20372037
.set_index(np.arange(n_attr)[mask])
20382038
.reindex(np.arange(n_attr))
@@ -2297,19 +2297,15 @@ def _push_attr(
22972297
if mods is not None and m not in mods:
22982298
continue
22992299

2300-
mod_map = attrmap[m]
2300+
mod_map = attrmap[m].ravel()
23012301
mask = mod_map != 0
23022302
mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs
23032303

23042304
mod_cols = [col for col in cols if col["prefix"] == m or col["class"] == "common"]
23052305
df = getattr(self, attr)[mask].loc[:, [col["name"] for col in mod_cols]]
23062306
df.columns = [col["derived_name"] for col in mod_cols]
23072307

2308-
df = (
2309-
df.set_index(np.arange(mod_n_attr))
2310-
.iloc[mod_map[mask] - 1]
2311-
.set_index(np.arange(mod_n_attr))
2312-
)
2308+
df = df.iloc[np.argsort(mod_map[mask])].set_index(np.arange(mod_n_attr))
23132309

23142310
if not only_drop:
23152311
# TODO: _maybe_coerce_to_bool

src/mudata/_core/utils.py

Lines changed: 30 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import Counter
2-
from collections.abc import Sequence
2+
from collections.abc import Mapping, Sequence
33
from typing import TypeVar
44

55
import numpy as np
@@ -38,9 +38,7 @@ def _maybe_coerce_to_boolean(df: T) -> T:
3838
return df
3939

4040

41-
def _classify_attr_columns(
42-
names: Sequence[str], prefixes: Sequence[str]
43-
) -> Sequence[dict[str, str]]:
41+
def _classify_attr_columns(names: Mapping[str, Sequence[str]]) -> dict[str, list[dict[str, str]]]:
4442
"""
4543
Classify names into common, non-unique, and unique
4644
w.r.t. to the list of prefixes.
@@ -53,50 +51,35 @@ def _classify_attr_columns(
5351
and there is only one modality prefix
5452
for a column with a certain name.
5553
56-
E.g. ["global", "mod1:annotation", "mod2:annotation", "mod1:unique"] will be classified
57-
into [
58-
{"name": "global", "prefix": "", "derived_name": "global", "count": 1, "class": "common"},
59-
{"name": "mod1:annotation", "prefix": "mod1", "derived_name": "annotation", "count": 2, "class": "nonunique"},
60-
{"name": "mod2:annotation", "prefix": "mod2", "derived_name": "annotation", "count": 2, "class": "nonunique"},
61-
{"name": "mod1:unique", "prefix": "mod1", "derived_name": "annotation", "count": 2, "class": "unique"},
62-
]
54+
E.g. {"mod1": ["annotation", "unique"], "mod2": ["annotation"]} will be classified
55+
into {"mod1": [{"name": "mod1:annotation", "derived_name": "annotation", "count": 2, "class": "nonunique"},
56+
{"name": "mod1:unique", "derived_name": "unique", "count": 1, "class": "unique"}}],
57+
"mod2": [{"name": "mod2:annotation", "derived_name": "annotation", "count": 2, "class": "nonunique"}],
58+
}
6359
"""
64-
n_mod = len(prefixes)
65-
res: list[dict[str, str]] = []
66-
67-
for name in names:
68-
name_common = {
69-
"name": name,
70-
"prefix": "",
71-
"derived_name": name,
72-
}
73-
name_split = name.split(":", 1)
74-
75-
if len(name_split) < 2:
76-
res.append(name_common)
77-
else:
78-
maybe_modname, derived_name = name_split
79-
80-
if maybe_modname in prefixes:
81-
name_prefixed = {
82-
"name": name,
83-
"prefix": maybe_modname,
84-
"derived_name": derived_name,
60+
n_mod = len(names)
61+
res: dict[str, list[dict[str, str]]] = {}
62+
63+
derived_name_counts = Counter()
64+
for prefix, names in names.items():
65+
cres = []
66+
for name in names:
67+
cres.append(
68+
{
69+
"name": f"{prefix}:{name}",
70+
"derived_name": name,
8571
}
86-
res.append(name_prefixed)
87-
else:
88-
res.append(name_common)
89-
90-
derived_name_counts = Counter(name_res["derived_name"] for name_res in res)
91-
for name_res in res:
92-
name_res["count"] = derived_name_counts[name_res["derived_name"]]
93-
94-
for name_res in res:
95-
name_res["class"] = (
96-
"common"
97-
if name_res["count"] == n_mod
98-
else "unique" if name_res["count"] == 1 else "nonunique"
99-
)
72+
)
73+
derived_name_counts[name] += 1
74+
res[prefix] = cres
75+
76+
for prefix, names in res.items():
77+
for name_res in names:
78+
count = derived_name_counts[name_res["derived_name"]]
79+
name_res["count"] = count
80+
name_res["class"] = (
81+
"common" if count == n_mod else "unique" if count == 1 else "nonunique"
82+
)
10083

10184
return res
10285

@@ -138,7 +121,7 @@ def _classify_prefixed_columns(
138121

139122

140123
def _update_and_concat(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame:
141-
df = df1.copy()
124+
df = df1.copy(deep=False)
142125
# This converts boolean to object dtype, unfortunately
143126
# df.update(df2)
144127
common_cols = df1.columns.intersection(df2.columns)

tests/test_pull_push.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66
from anndata import AnnData
77

8-
from mudata import MuData
8+
from mudata import MuData, set_options
99

1010

1111
@pytest.fixture()
@@ -21,7 +21,8 @@ def modalities(request, obs_n, var_unique):
2121
mods[m].var["mod"] = m
2222

2323
# common column
24-
mods[m].var["highly_variable"] = np.tile([False, True], mods[m].n_vars // 2)
24+
mods[m].var["highly_variable"] = np.random.choice([False, True], size=mods[m].n_vars)
25+
mods[m].obs["common_obs_col"] = np.random.randint(0, int(1e6), size=mods[m].n_obs)
2526

2627
if var_unique:
2728
mods[m].var_names = [f"mod{m}_var{j}" for j in range(mods[m].n_vars)]
@@ -88,7 +89,6 @@ def test_pull_var(self, modalities):
8889
"""
8990
mdata = MuData(modalities)
9091
mdata.update()
91-
9292
mdata.pull_var()
9393

9494
assert "mod" in mdata.var.columns
@@ -165,6 +165,15 @@ def test_pull_obs_simple(self, modalities):
165165
for m in mdata.mod.keys():
166166
assert f"{m}:mod" in mdata.obs.columns
167167

168+
assert f"{m}:common_obs_col" in mdata.obs.columns
169+
170+
modmap = mdata.obsmap[m].ravel()
171+
mask = modmap > 0
172+
assert (
173+
mdata.obs[f"{m}:common_obs_col"][mask].to_numpy()
174+
== mdata.mod[m].obs["common_obs_col"].to_numpy()[modmap[mask] - 1]
175+
).all()
176+
168177
# join_common shouldn't work
169178
with pytest.raises(ValueError, match="shared obs_names"):
170179
mdata.pull_obs(join_common=True)
@@ -182,14 +191,24 @@ def test_push_var_simple(self, modalities):
182191
mdata = MuData(modalities)
183192
mdata.update()
184193

185-
mdata.var["pushed"] = True
186-
mdata.var["mod2:mod2_pushed"] = True
194+
mdata.var["pushed"] = np.random.randint(0, int(1e6), size=mdata.n_var)
195+
mdata.var["mod2:mod2_pushed"] = np.random.randint(0, int(1e6), size=mdata.n_var)
187196
mdata.push_var()
188197

189198
# pushing should work
190-
for mod in mdata.mod.values():
199+
for modname, mod in mdata.mod.items():
191200
assert "pushed" in mod.var.columns
201+
202+
map = mdata.varmap[modname].ravel()
203+
mask = map > 0
204+
assert (mdata.var["pushed"][mask] == mod.var["pushed"][map[mask] - 1]).all()
205+
192206
assert "mod2_pushed" in mdata["mod2"].var.columns
207+
map = mdata.varmap["mod2"].ravel()
208+
mask = map > 0
209+
assert (
210+
mdata.var["mod2:mod2_pushed"][mask] == mdata["mod2"].var["mod2_pushed"][map[mask] - 1]
211+
).all()
193212

194213
@pytest.mark.parametrize("var_unique", [True, False])
195214
@pytest.mark.parametrize("obs_n", ["joint", "disjoint"])
@@ -200,14 +219,24 @@ def test_push_obs_simple(self, modalities):
200219
mdata = MuData(modalities)
201220
mdata.update()
202221

203-
mdata.obs["pushed"] = True
204-
mdata.obs["mod2:mod2_pushed"] = True
222+
mdata.obs["pushed"] = np.random.randint(0, int(1e6), size=mdata.n_obs)
223+
mdata.obs["mod2:mod2_pushed"] = np.random.randint(0, int(1e6), size=mdata.n_obs)
205224
mdata.push_obs()
206225

207226
# pushing should work
208-
for mod in mdata.mod.values():
227+
for modname, mod in mdata.mod.items():
209228
assert "pushed" in mod.obs.columns
229+
230+
map = mdata.obsmap[modname].ravel()
231+
mask = map > 0
232+
assert (mdata.obs["pushed"][mask] == mod.obs["pushed"][map[mask] - 1]).all()
233+
210234
assert "mod2_pushed" in mdata["mod2"].obs.columns
235+
map = mdata.obsmap["mod2"].ravel()
236+
mask = map > 0
237+
assert (
238+
mdata.obs["mod2:mod2_pushed"][mask] == mdata["mod2"].obs["mod2_pushed"][map[mask] - 1]
239+
).all()
211240

212241

213242
@pytest.mark.usefixtures("filepath_h5mu")

0 commit comments

Comments
 (0)