Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 116 additions & 89 deletions src/mudata/_core/mudata.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
from .file_backing import MuDataFileManager
from .repr import MUDATA_CSS, block_matrix, details_block_table
from .utils import (
_classify_attr_columns,
_classify_prefixed_columns,
MetadataColumn,
_make_index_unique,
_maybe_coerce_to_bool,
_maybe_coerce_to_boolean,
Expand Down Expand Up @@ -1915,37 +1914,54 @@ def _pull_attr(
if mods is not None:
if isinstance(mods, str):
mods = [mods]
mods = list(dict.fromkeys(mods))
if not all(m in self.mod for m in mods):
raise ValueError("All mods should be present in mdata.mod")
elif len(mods) == self.n_mod:
mods = None
for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items():
assert v is None, f"Cannot use mods with {k}."

if only_drop:
drop = True

cols = _classify_attr_columns(
np.concatenate(
[
[f"{m}:{val}" for val in getattr(mod, attr).columns.values]
for m, mod in self.mod.items()
]
),
self.mod.keys(),
)
cols: dict[str, list[MetadataColumn]] = {}

# get all columns from all modalities and count how many times each column is present
derived_name_counts = Counter()
for prefix, mod in self.mod.items():
modcols = getattr(mod, attr).columns
ccols = []
for name in modcols:
ccols.append(
MetadataColumn(
allowed_prefixes=self.mod.keys(),
prefix=prefix,
name=name,
strip_prefix=False,
)
)
derived_name_counts[name] += 1
cols[prefix] = ccols

for prefix, modcols in cols.items():
for col in modcols:
count = derived_name_counts[col.derived_name]
col.count = count # this is important to classify columns

if columns is not None:
for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items():
assert v is None, f"Cannot use {k} with columns."

# - modname1:column -> [modname1:column]
# - column -> [modname1:column, modname2:column, ...]
cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns]
if v is not None:
warnings.warn(
f"Both columns and {k} given. Columns take precedence, {k} will be ignored",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would something like this improve readability? (I am not sure we have a consistent policy for formatting in such cases.)

Both `columns=...` and `{k}=True` were given. <...>

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If yes, this is also true for similar warnings in other parts of the PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that would be a bit misleading here, since the warning will also be emitted if k=False or any other value which is not the None default. Perhaps something like

Both `columns=...` and `{k}={locals()[k]}` were given...

? But I'm not sure if that brings the message across that it should just be not passed at all (as in leave the None default).

RuntimeWarning,
stacklevel=2,
)

if mods is not None:
cols = [col for col in cols if col["prefix"] in mods]
# keep only requested columns
cols = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i.e., with cols as a class this could be

prefix_to_cols = cols.filter_by_name_or_derived_name(colums)

(I would also advocate changing the name from cols to prefix_to_cols to avoid confusion with columns)

prefix: [
col for col in modcols if col.name in columns or col.derived_name in columns
]
for prefix, modcols in cols.items()
}

# TODO: Counter for columns in order to track their usage
# and error out if some columns were not used
Expand All @@ -1958,28 +1974,37 @@ def _pull_attr(
if unique is None:
unique = True

# filter columns by class, keep only those that were requested
selector = {"common": common, "nonunique": nonunique, "unique": unique}
cols = {
prefix: [col for col in modcols if selector[col.klass]]
for prefix, modcols in cols.items()
}

cols = [col for col in cols if selector[col["class"]]]
# filter columns, keep only requested modalities
if mods is not None:
cols = {prefix: cols[prefix] for prefix in mods}

derived_name_count = Counter([col["derived_name"] for col in cols])
# count final filtered column names, required later to decide whether to prefix a column with its source modality
derived_name_count = Counter(
[col.derived_name for modcols in cols.values() for col in modcols]
)

# - axis == self.axis
# e.g. combine var from multiple modalities (with unique vars)
# e.g. combine obs from multiple modalities (with shared obs)
# - 1 - axis == self.axis
# . e.g. combine obs from multiple modalities (with shared obs)
axis = 0 if attr == "var" else 1
# e.g. combine var from multiple modalities (with unique vars)
axis = 0 if attr == "obs" else 1

if 1 - axis == self.axis or self.axis == -1:
if axis == self.axis or self.axis == -1:
if join_common or join_nonunique:
raise ValueError(f"Cannot join columns with the same name for shared {attr}_names.")

if join_common is None:
join_common = False
if attr == "var":
join_common = self.axis == 0
elif attr == "obs":
if attr == "obs":
join_common = self.axis == 1
else:
join_common = self.axis == 0

if join_nonunique is None:
join_nonunique = False
Expand All @@ -1995,44 +2020,39 @@ def _pull_attr(
n_attr = self.n_vars if attr == "var" else self.n_obs

dfs: list[pd.DataFrame] = []
for m, mod in self.mod.items():
if mods is not None and m not in mods:
continue
for m, modcols in cols.items():
mod = self.mod[m]
mod_map = attrmap[m].ravel()
mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs
mask = mod_map != 0

mod_df = getattr(mod, attr)
mod_columns = [
col["derived_name"] for col in cols if col["prefix"] == "" or col["prefix"] == m
]
mod_df = mod_df[mod_df.columns.intersection(mod_columns)]
mask = mod_map > 0

mod_df = getattr(mod, attr)[[col.derived_name for col in modcols]]
if drop:
getattr(mod, attr).drop(columns=mod_df.columns, inplace=True)

# Don't use modname: prefix if columns need to be joined
if join_common or join_nonunique or (not prefix_unique):
cols_special = [
col["derived_name"]
for col in cols
if (
(col["class"] == "common") & join_common
or (col["class"] == "nonunique") & join_nonunique
or (col["class"] == "unique") & (not prefix_unique)
# prepend modality prefix to column names if requested via arguments and there are no skipped modalities with
# the same column name (prefixing those columns may cause problems with future pulls or pushes)
mod_df.rename(
columns={
col.derived_name: col.name
for col in modcols
if not (
(
join_common
and col.klass == "common"
or join_nonunique
and col.klass == "nonunique"
or not prefix_unique
and col.klass == "unique"
)
and derived_name_count[col.derived_name] == col.count
)
and col["prefix"] == m
and derived_name_count[col["derived_name"]] == col["count"]
]
mod_df.columns = [
col if col in cols_special else f"{m}:{col}" for col in mod_df.columns
]
else:
mod_df.columns = [f"{m}:{col}" for col in mod_df.columns]
},
inplace=True,
)

# reorder modality DF to conform to global order
mod_df = (
_maybe_coerce_to_boolean(mod_df)
.set_index(np.arange(mod_n_attr))
.iloc[mod_map[mask] - 1]
.set_index(np.arange(n_attr)[mask])
.reindex(np.arange(n_attr))
Expand Down Expand Up @@ -2242,39 +2262,46 @@ def _push_attr(
raise ValueError("All mods should be present in mdata.mod")
elif len(mods) == self.n_mod:
mods = None
for k, v in {"common": common, "prefixed": prefixed}.items():
assert v is None, f"Cannot use mods with {k}."

if only_drop:
drop = True

cols = _classify_prefixed_columns(getattr(self, attr).columns.values, self.mod.keys())
# get all global columns
cols = [
MetadataColumn(allowed_prefixes=self.mod.keys(), name=name)
for name in getattr(self, attr).columns
]

if columns is not None:
for k, v in {"common": common, "prefixed": prefixed}.items():
assert v is None, f"Cannot use columns with {k}."

# - modname1:column -> [modname1:column]
# - column -> [modname1:column, modname2:column, ...]
cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns]
if v:
warnings.warn(
f"Both columns and {k} given. Columns take precedence, {k} will be ignored",
RuntimeWarning,
stacklevel=2,
)

# preemptively drop columns from other modalities
if mods is not None:
cols = [col for col in cols if col["prefix"] in mods or col["prefix"] == ""]
# keep only requested columns
cols = [
col
for col in cols
if (col.name in columns or col.derived_name in columns)
and (col.prefix is None or mods is not None and col.prefix in mods)
]
else:
if common is None:
common = True
if prefixed is None:
prefixed = True

selector = {"common": common, "prefixed": prefixed}

cols = [col for col in cols if selector[col["class"]]]
# filter columns by class, keep only those that were requested
selector = {"common": common, "unknown": prefixed}
cols = [col for col in cols if selector[col.klass]]

if len(cols) == 0:
return

derived_name_count = Counter([col["derived_name"] for col in cols])
derived_name_count = Counter([col.derived_name for col in cols])
for c, count in derived_name_count.items():
# if count > 1, there are both colname and modname:colname present
if count > 1 and c in getattr(self, attr).columns:
Expand All @@ -2286,25 +2313,25 @@ def _push_attr(
)

attrmap = getattr(self, f"{attr}map")
_n_attr = self.n_vars if attr == "var" else self.n_obs

for m, mod in self.mod.items():
if mods is not None and m not in mods:
continue

mod_map = attrmap[m]
mask = mod_map != 0
mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs
mod_map = attrmap[m].ravel()
mask = mod_map > 0
mod_n_attr = mod.n_obs if attr == "obs" else mod.n_vars

mod_cols = [col for col in cols if col["prefix"] == m or col["class"] == "common"]
df = getattr(self, attr)[mask].loc[:, [col["name"] for col in mod_cols]]
df.columns = [col["derived_name"] for col in mod_cols]
# get all common and modality-specific columns for the current modality
mod_cols = [col for col in cols if col.prefix == m or col.klass == "common"]
df = getattr(self, attr)[mask][[col.name for col in mod_cols]]

df = (
df.set_index(np.arange(mod_n_attr))
.iloc[mod_map[mask] - 1]
.set_index(np.arange(mod_n_attr))
)
# strip modality prefix where necessary
df.columns = [col.derived_name for col in mod_cols]

# reorder global DF to conform to modality order
idx = np.empty(mod_n_attr, dtype=mod_map.dtype)
idx[mod_map[mask] - 1] = np.arange(mod_n_attr)
df = df.iloc[idx].set_index(np.arange(mod_n_attr, dtype=mod_map.dtype))

if not only_drop:
# TODO: _maybe_coerce_to_bool
Expand All @@ -2317,7 +2344,7 @@ def _push_attr(

if drop:
for col in cols:
getattr(self, attr).drop(col["name"], axis=1, inplace=True)
getattr(self, attr).drop(col.name, axis=1, inplace=True)

def push_obs(
self,
Expand Down
Loading