diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py
index 76721dd..c896592 100644
--- a/src/mudata/_core/mudata.py
+++ b/src/mudata/_core/mudata.py
@@ -27,8 +27,7 @@
 from .file_backing import MuDataFileManager
 from .repr import MUDATA_CSS, block_matrix, details_block_table
 from .utils import (
-    _classify_attr_columns,
-    _classify_prefixed_columns,
+    MetadataColumn,
     _make_index_unique,
     _maybe_coerce_to_bool,
     _maybe_coerce_to_boolean,
@@ -1915,37 +1914,54 @@ def _pull_attr(
         if mods is not None:
             if isinstance(mods, str):
                 mods = [mods]
-            mods = list(dict.fromkeys(mods))
             if not all(m in self.mod for m in mods):
                 raise ValueError("All mods should be present in mdata.mod")
             elif len(mods) == self.n_mod:
                 mods = None
-            for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items():
-                assert v is None, f"Cannot use mods with {k}."
 
         if only_drop:
             drop = True
 
-        cols = _classify_attr_columns(
-            np.concatenate(
-                [
-                    [f"{m}:{val}" for val in getattr(mod, attr).columns.values]
-                    for m, mod in self.mod.items()
-                ]
-            ),
-            self.mod.keys(),
-        )
+        cols: dict[str, list[MetadataColumn]] = {}
+
+        # get all columns from all modalities and count how many times each column is present
+        derived_name_counts = Counter()
+        for prefix, mod in self.mod.items():
+            modcols = getattr(mod, attr).columns
+            ccols = []
+            for name in modcols:
+                ccols.append(
+                    MetadataColumn(
+                        allowed_prefixes=self.mod.keys(),
+                        prefix=prefix,
+                        name=name,
+                        strip_prefix=False,
+                    )
+                )
+                derived_name_counts[name] += 1
+            cols[prefix] = ccols
+
+        for prefix, modcols in cols.items():
+            for col in modcols:
+                count = derived_name_counts[col.derived_name]
+                col.count = count  # this is important to classify columns
 
         if columns is not None:
             for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items():
-                assert v is None, f"Cannot use {k} with columns."
-
-            # - modname1:column -> [modname1:column]
-            # - column -> [modname1:column, modname2:column, ...]
-            cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns]
+                if v is not None:
+                    warnings.warn(
+                        f"Both columns and {k} given. Columns take precedence, {k} will be ignored",
+                        RuntimeWarning,
+                        stacklevel=2,
+                    )
 
-            if mods is not None:
-                cols = [col for col in cols if col["prefix"] in mods]
+            # keep only requested columns
+            cols = {
+                prefix: [
+                    col for col in modcols if col.name in columns or col.derived_name in columns
+                ]
+                for prefix, modcols in cols.items()
+            }
 
         # TODO: Counter for columns in order to track their usage
         # and error out if some columns were not used
@@ -1958,28 +1974,37 @@
             if unique is None:
                 unique = True
 
+            # filter columns by class, keep only those that were requested
             selector = {"common": common, "nonunique": nonunique, "unique": unique}
+            cols = {
+                prefix: [col for col in modcols if selector[col.klass]]
+                for prefix, modcols in cols.items()
+            }
 
-            cols = [col for col in cols if selector[col["class"]]]
+        # filter columns, keep only requested modalities
+        if mods is not None:
+            cols = {prefix: cols[prefix] for prefix in mods}
 
-        derived_name_count = Counter([col["derived_name"] for col in cols])
+        # count final filtered column names, required later to decide whether to prefix a column with its source modality
+        derived_name_count = Counter(
+            [col.derived_name for modcols in cols.values() for col in modcols]
+        )
 
         # - axis == self.axis
-        #   e.g. combine var from multiple modalities (with unique vars)
+        #   e.g. combine obs from multiple modalities (with shared obs)
         # - 1 - axis == self.axis
-        # . e.g. combine obs from multiple modalities (with shared obs)
-        axis = 0 if attr == "var" else 1
+        #   e.g. combine var from multiple modalities (with unique vars)
+        axis = 0 if attr == "obs" else 1
 
-        if 1 - axis == self.axis or self.axis == -1:
+        if axis == self.axis or self.axis == -1:
             if join_common or join_nonunique:
                 raise ValueError(f"Cannot join columns with the same name for shared {attr}_names.")
 
         if join_common is None:
-            join_common = False
-            if attr == "var":
-                join_common = self.axis == 0
-            elif attr == "obs":
+            if attr == "obs":
                 join_common = self.axis == 1
+            else:
+                join_common = self.axis == 0
 
         if join_nonunique is None:
             join_nonunique = False
@@ -1995,44 +2020,39 @@
         n_attr = self.n_vars if attr == "var" else self.n_obs
         dfs: list[pd.DataFrame] = []
 
-        for m, mod in self.mod.items():
-            if mods is not None and m not in mods:
-                continue
+        for m, modcols in cols.items():
+            mod = self.mod[m]
             mod_map = attrmap[m].ravel()
-            mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs
-            mask = mod_map != 0
-
-            mod_df = getattr(mod, attr)
-            mod_columns = [
-                col["derived_name"] for col in cols if col["prefix"] == "" or col["prefix"] == m
-            ]
-            mod_df = mod_df[mod_df.columns.intersection(mod_columns)]
+            mask = mod_map > 0
+            mod_df = getattr(mod, attr)[[col.derived_name for col in modcols]]
 
             if drop:
                 getattr(mod, attr).drop(columns=mod_df.columns, inplace=True)
 
-            # Don't use modname: prefix if columns need to be joined
-            if join_common or join_nonunique or (not prefix_unique):
-                cols_special = [
-                    col["derived_name"]
-                    for col in cols
-                    if (
-                        (col["class"] == "common") & join_common
-                        or (col["class"] == "nonunique") & join_nonunique
-                        or (col["class"] == "unique") & (not prefix_unique)
+            # add the modality prefix to column names unless unprefixed output was requested via the arguments;
+            # columns sharing a name with a column of a skipped modality keep the prefix, as joining them may cause problems with future pulls or pushes
+            mod_df.rename(
+                columns={
+                    col.derived_name: col.name
+                    for col in modcols
+                    if not (
+                        (
+                            join_common
+                            and col.klass == "common"
+                            or join_nonunique
+                            and col.klass == "nonunique"
+                            or not prefix_unique
+                            and col.klass == "unique"
+                        )
+                        and derived_name_count[col.derived_name] == col.count
                     )
-                    and col["prefix"] == m
-                    and derived_name_count[col["derived_name"]] == col["count"]
-                ]
-                mod_df.columns = [
-                    col if col in cols_special else f"{m}:{col}" for col in mod_df.columns
-                ]
-            else:
-                mod_df.columns = [f"{m}:{col}" for col in mod_df.columns]
+                },
+                inplace=True,
+            )
 
+            # reorder modality DF to conform to global order
             mod_df = (
                 _maybe_coerce_to_boolean(mod_df)
-                .set_index(np.arange(mod_n_attr))
                 .iloc[mod_map[mask] - 1]
                 .set_index(np.arange(n_attr)[mask])
                 .reindex(np.arange(n_attr))
@@ -2242,39 +2262,46 @@
                 raise ValueError("All mods should be present in mdata.mod")
             elif len(mods) == self.n_mod:
                 mods = None
-            for k, v in {"common": common, "prefixed": prefixed}.items():
-                assert v is None, f"Cannot use mods with {k}."
 
         if only_drop:
             drop = True
 
-        cols = _classify_prefixed_columns(getattr(self, attr).columns.values, self.mod.keys())
+        # get all global columns
+        cols = [
+            MetadataColumn(allowed_prefixes=self.mod.keys(), name=name)
+            for name in getattr(self, attr).columns
+        ]
 
         if columns is not None:
             for k, v in {"common": common, "prefixed": prefixed}.items():
-                assert v is None, f"Cannot use columns with {k}."
-
-            # - modname1:column -> [modname1:column]
-            # - column -> [modname1:column, modname2:column, ...]
-            cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns]
+                if v:
+                    warnings.warn(
+                        f"Both columns and {k} given. Columns take precedence, {k} will be ignored",
+                        RuntimeWarning,
+                        stacklevel=2,
+                    )
 
-            # preemptively drop columns from other modalities
-            if mods is not None:
-                cols = [col for col in cols if col["prefix"] in mods or col["prefix"] == ""]
+            # keep only requested columns
+            cols = [
+                col
+                for col in cols
+                if (col.name in columns or col.derived_name in columns)
+                and (col.prefix is None or mods is not None and col.prefix in mods)
+            ]
         else:
             if common is None:
                 common = True
             if prefixed is None:
                 prefixed = True
 
-            selector = {"common": common, "prefixed": prefixed}
-
-            cols = [col for col in cols if selector[col["class"]]]
+            # filter columns by class, keep only those that were requested
+            selector = {"common": common, "unknown": prefixed}
+            cols = [col for col in cols if selector[col.klass]]
 
         if len(cols) == 0:
             return
 
-        derived_name_count = Counter([col["derived_name"] for col in cols])
+        derived_name_count = Counter([col.derived_name for col in cols])
         for c, count in derived_name_count.items():
             # if count > 1, there are both colname and modname:colname present
             if count > 1 and c in getattr(self, attr).columns:
@@ -2286,25 +2313,25 @@
         attrmap = getattr(self, f"{attr}map")
-        _n_attr = self.n_vars if attr == "var" else self.n_obs
-
         for m, mod in self.mod.items():
             if mods is not None and m not in mods:
                 continue
-            mod_map = attrmap[m]
-            mask = mod_map != 0
-            mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs
+            mod_map = attrmap[m].ravel()
+            mask = mod_map > 0
+            mod_n_attr = mod.n_obs if attr == "obs" else mod.n_vars
 
-            mod_cols = [col for col in cols if col["prefix"] == m or col["class"] == "common"]
-            df = getattr(self, attr)[mask].loc[:, [col["name"] for col in mod_cols]]
-            df.columns = [col["derived_name"] for col in mod_cols]
+            # get all common and modality-specific columns for the current modality
+            mod_cols = [col for col in cols if col.prefix == m or col.klass == "common"]
+            df = getattr(self, attr)[mask][[col.name for col in mod_cols]]
 
-            df = (
-                df.set_index(np.arange(mod_n_attr))
-                .iloc[mod_map[mask] - 1]
-                .set_index(np.arange(mod_n_attr))
-            )
+            # strip modality prefix where necessary
+            df.columns = [col.derived_name for col in mod_cols]
+
+            # reorder global DF to conform to modality order
+            idx = np.empty(mod_n_attr, dtype=mod_map.dtype)
+            idx[mod_map[mask] - 1] = np.arange(mod_n_attr)
+            df = df.iloc[idx].set_index(np.arange(mod_n_attr, dtype=mod_map.dtype))
 
             if not only_drop:
                 # TODO: _maybe_coerce_to_bool
@@ -2317,7 +2344,7 @@
 
         if drop:
             for col in cols:
-                getattr(self, attr).drop(col["name"], axis=1, inplace=True)
+                getattr(self, attr).drop(col.name, axis=1, inplace=True)
 
     def push_obs(
         self,
diff --git a/src/mudata/_core/utils.py b/src/mudata/_core/utils.py
index 712dc36..520d182 100644
--- a/src/mudata/_core/utils.py
+++ b/src/mudata/_core/utils.py
@@ -1,6 +1,6 @@
 from collections import Counter
 from collections.abc import Sequence
-from typing import TypeVar
+from typing import Literal, TypeVar
 
 import numpy as np
 import pandas as pd
@@ -38,120 +38,61 @@ def _maybe_coerce_to_boolean(df: T) -> T:
     return df
 
 
-def _classify_attr_columns(
-    names: Sequence[str], prefixes: Sequence[str]
-) -> Sequence[dict[str, str]]:
-    """
-    Classify names into common, non-unique, and unique
-    w.r.t. to the list of prefixes.
-
-    - Common columns do not have modality prefixes.
-    - Non-unqiue columns have a modality prefix,
-      and there are multiple columns that differ
-      only by their modality prefix.
-    - Unique columns are prefixed by modality names,
-      and there is only one modality prefix
-      for a column with a certain name.
-
-    E.g. ["global", "mod1:annotation", "mod2:annotation", "mod1:unique"] will be classified
-    into [
-        {"name": "global", "prefix": "", "derived_name": "global", "count": 1, "class": "common"},
-        {"name": "mod1:annotation", "prefix": "mod1", "derived_name": "annotation", "count": 2, "class": "nonunique"},
-        {"name": "mod2:annotation", "prefix": "mod2", "derived_name": "annotation", "count": 2, "class": "nonunique"},
-        {"name": "mod1:unique", "prefix": "mod1", "derived_name": "annotation", "count": 2, "class": "unique"},
-    ]
-    """
-    n_mod = len(prefixes)
-    res: list[dict[str, str]] = []
-
-    for name in names:
-        name_common = {
-            "name": name,
-            "prefix": "",
-            "derived_name": name,
-        }
-        name_split = name.split(":", 1)
-
-        if len(name_split) < 2:
-            res.append(name_common)
+class MetadataColumn:
+    __slots__ = ("prefix", "derived_name", "count", "_allowed_prefixes", "_strip_prefix")
+
+    def __init__(
+        self,
+        *,
+        allowed_prefixes: Sequence[str],
+        prefix: str | None = None,
+        name: str | None = None,
+        count: int = 0,
+        strip_prefix: bool = True,
+    ):
+        self._strip_prefix = strip_prefix
+        self._allowed_prefixes = allowed_prefixes
+        self.prefix = prefix
+        if prefix is None and strip_prefix:
+            self.name = name
         else:
-            maybe_modname, derived_name = name_split
-
-            if maybe_modname in prefixes:
-                name_prefixed = {
-                    "name": name,
-                    "prefix": maybe_modname,
-                    "derived_name": derived_name,
-                }
-                res.append(name_prefixed)
-            else:
-                res.append(name_common)
-
-    derived_name_counts = Counter(name_res["derived_name"] for name_res in res)
-    for name_res in res:
-        name_res["count"] = derived_name_counts[name_res["derived_name"]]
-
-    for name_res in res:
-        name_res["class"] = (
-            "common"
-            if name_res["count"] == n_mod
-            else "unique" if name_res["count"] == 1 else "nonunique"
-        )
-
-    return res
-
-
-def _classify_prefixed_columns(
-    names: Sequence[str], prefixes: Sequence[str]
-) -> Sequence[dict[str, str]]:
-    """
-    Classify names into common and prefixed
-    w.r.t. to the list of prefixes.
-
-    - Common columns do not have modality prefixes.
-    - Prefixed columns are prefixed by modality names.
-
-    E.g. ["global", "mod1:annotation", "mod2:annotation", "mod1:unique"] will be classified
-    into [
-        {"name": "global", "prefix": "", "derived_name": "global", "class": "common"},
-        {"name": "mod1:annotation", "prefix": "mod1", "derived_name": "annotation", "class": "prefixed"},
-        {"name": "mod2:annotation", "prefix": "mod2", "derived_name": "annotation", "class": "prefixed"},
-        {"name": "mod1:unique", "prefix": "mod1", "derived_name": "annotation", "class": "prefixed"},
-    ]
-    """
-    res: list[dict[str, str]] = []
-
-    for name in names:
-        name_common = {
-            "name": name,
-            "prefix": "",
-            "derived_name": name,
-        }
-        name_split = name.split(":", 1)
-
-        if len(name_split) < 2:
-            res.append(name_common)
+            self.prefix = prefix
+            self.derived_name = name
+        self.count = count
+
+    @property
+    def name(self) -> str:
+        if self.prefix is not None:
+            return f"{self.prefix}:{self.derived_name}"
         else:
-            maybe_modname, derived_name = name_split
-
-            if maybe_modname in prefixes:
-                name_prefixed = {
-                    "name": name,
-                    "prefix": maybe_modname,
-                    "derived_name": derived_name,
-                }
-                res.append(name_prefixed)
-            else:
-                res.append(name_common)
-
-    for name_res in res:
-        name_res["class"] = "common" if name_res["prefix"] == "" else "prefixed"
-
-    return res
+            return self.derived_name
+
+    @name.setter
+    def name(self, new_name):
+        if (
+            not self._strip_prefix
+            or len(name_split := new_name.split(":", 1)) < 2
+            or name_split[0] not in self._allowed_prefixes
+        ):
+            self.prefix = None
+            self.derived_name = new_name
+        else:
+            self.prefix, self.derived_name = name_split
+
+    @property
+    def klass(self) -> Literal["common", "unique", "nonunique", "unknown"]:
+        if self.prefix is None or self.count == len(self._allowed_prefixes):
+            return "common"
+        elif self.count == 1:
+            return "unique"
+        elif self.count > 0:
+            return "nonunique"
+        else:
+            return "unknown"
 
 
 def _update_and_concat(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame:
-    df = df1.copy()
+    df = df1.copy(deep=False)
     # This converts boolean to object dtype, unfortunately
     # df.update(df2)
     common_cols = df1.columns.intersection(df2.columns)
diff --git a/tests/test_pull_push.py b/tests/test_pull_push.py
index ddb0520..7251966 100644
--- a/tests/test_pull_push.py
+++ b/tests/test_pull_push.py
@@ -5,7 +5,7 @@
 import pytest
 from anndata import AnnData
 
-from mudata import MuData
+from mudata import MuData, set_options
 
 
 @pytest.fixture()
@@ -21,7 +21,8 @@ def modalities(request, obs_n, var_unique):
         mods[m].var["mod"] = m
 
         # common column
-        mods[m].var["highly_variable"] = np.tile([False, True], mods[m].n_vars // 2)
+        mods[m].var["highly_variable"] = np.random.choice([False, True], size=mods[m].n_vars)
+        mods[m].obs["common_obs_col"] = np.random.randint(0, int(1e6), size=mods[m].n_obs)
 
         if var_unique:
             mods[m].var_names = [f"mod{m}_var{j}" for j in range(mods[m].n_vars)]
@@ -88,7 +89,6 @@ def test_pull_var(self, modalities):
         """
         mdata = MuData(modalities)
         mdata.update()
-
        mdata.pull_var()
 
         assert "mod" in mdata.var.columns
@@ -165,6 +165,15 @@ def test_pull_obs_simple(self, modalities):
 
         for m in mdata.mod.keys():
            assert f"{m}:mod" in mdata.obs.columns
+            assert f"{m}:common_obs_col" in mdata.obs.columns
+
+            modmap = mdata.obsmap[m].ravel()
+            mask = modmap > 0
+            assert (
+                mdata.obs[f"{m}:common_obs_col"][mask].to_numpy()
+                == mdata.mod[m].obs["common_obs_col"].to_numpy()[modmap[mask] - 1]
+            ).all()
+
         # join_common shouldn't work
         with pytest.raises(ValueError, match="shared obs_names"):
             mdata.pull_obs(join_common=True)
@@ -182,14 +191,24 @@ def test_push_var_simple(self, modalities):
         mdata = MuData(modalities)
         mdata.update()
 
-        mdata.var["pushed"] = True
-        mdata.var["mod2:mod2_pushed"] = True
+        mdata.var["pushed"] = np.random.randint(0, int(1e6), size=mdata.n_vars)
+        mdata.var["mod2:mod2_pushed"] = np.random.randint(0, int(1e6), size=mdata.n_vars)
         mdata.push_var()
 
         # pushing should work
-        for mod in mdata.mod.values():
+        for modname, mod in mdata.mod.items():
             assert "pushed" in mod.var.columns
+
+            map = mdata.varmap[modname].ravel()
+            mask = map > 0
+            assert (mdata.var["pushed"][mask] == mod.var["pushed"][map[mask] - 1]).all()
+
         assert "mod2_pushed" in mdata["mod2"].var.columns
+        map = mdata.varmap["mod2"].ravel()
+        mask = map > 0
+        assert (
+            mdata.var["mod2:mod2_pushed"][mask] == mdata["mod2"].var["mod2_pushed"][map[mask] - 1]
+        ).all()
 
     @pytest.mark.parametrize("var_unique", [True, False])
     @pytest.mark.parametrize("obs_n", ["joint", "disjoint"])
@@ -200,14 +219,24 @@ def test_push_obs_simple(self, modalities):
         mdata = MuData(modalities)
         mdata.update()
 
-        mdata.obs["pushed"] = True
-        mdata.obs["mod2:mod2_pushed"] = True
+        mdata.obs["pushed"] = np.random.randint(0, int(1e6), size=mdata.n_obs)
+        mdata.obs["mod2:mod2_pushed"] = np.random.randint(0, int(1e6), size=mdata.n_obs)
         mdata.push_obs()
 
         # pushing should work
-        for mod in mdata.mod.values():
+        for modname, mod in mdata.mod.items():
             assert "pushed" in mod.obs.columns
+
+            map = mdata.obsmap[modname].ravel()
+            mask = map > 0
+            assert (mdata.obs["pushed"][mask] == mod.obs["pushed"][map[mask] - 1]).all()
+
         assert "mod2_pushed" in mdata["mod2"].obs.columns
+        map = mdata.obsmap["mod2"].ravel()
+        mask = map > 0
+        assert (
+            mdata.obs["mod2:mod2_pushed"][mask] == mdata["mod2"].obs["mod2_pushed"][map[mask] - 1]
+        ).all()
 
 
 @pytest.mark.usefixtures("filepath_h5mu")
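For readers skimming the diff, the classification that `MetadataColumn.klass` performs during a pull (common vs. nonunique vs. unique, relative to the set of modality names) is easier to see in isolation. The snippet below is a minimal standalone mirror of that rule, written for illustration only; the `classify` helper and the example modalities and column names are made up and are not part of the patch.

```python
from collections import Counter


def classify(per_mod_columns: dict[str, list[str]]) -> dict[str, str]:
    """Label each "mod:column" by how many modalities share the column name."""
    counts = Counter(name for cols in per_mod_columns.values() for name in cols)
    n_mod = len(per_mod_columns)
    labels = {}
    for mod, cols in per_mod_columns.items():
        for name in cols:
            if counts[name] == n_mod:
                klass = "common"  # present in every modality
            elif counts[name] == 1:
                klass = "unique"  # present in exactly one modality
            else:
                klass = "nonunique"  # present in several, but not all, modalities
            labels[f"{mod}:{name}"] = klass
    return labels


print(classify({"rna": ["mod", "annotation"], "atac": ["mod", "annotation", "peak_width"], "prot": ["mod"]}))
# {'rna:mod': 'common', 'rna:annotation': 'nonunique', 'atac:mod': 'common',
#  'atac:annotation': 'nonunique', 'atac:peak_width': 'unique', 'prot:mod': 'common'}
```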
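The reworked `_push_attr` also changes how the global frame is brought into a modality's row order: instead of the old double `set_index`, it scatters `np.arange` through the non-zero entries of the modality's `obsmap`/`varmap` column, which builds the inverse permutation directly. A small numpy sketch of that indexing trick, with toy values invented for illustration:

```python
import numpy as np

# for each global row, the 1-based position of that row in the modality (0 = not present)
mod_map = np.array([2, 0, 3, 1])
mask = mod_map > 0                      # global rows that exist in this modality
mod_n_attr = int(mask.sum())

# scatter: idx[i] = position (among the masked global rows) of modality row i
idx = np.empty(mod_n_attr, dtype=mod_map.dtype)
idx[mod_map[mask] - 1] = np.arange(mod_n_attr)

global_rows = np.flatnonzero(mask)      # [0, 2, 3]
print(global_rows[idx])                 # [3 0 2]: modality rows 0, 1, 2 come from these global rows
```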