-
Notifications
You must be signed in to change notification settings - Fork 21
fixes for push_attr/pull_attr #105
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
3a4b2e1
pull_attr fixes
ilia-kats 8647ddd
push_attr fixes
ilia-kats 2b26ebf
push/pull: replace dict holding column information with custom class
ilia-kats 21b2e83
get rid of _classify_attr_columns, sprinkle more comments throughout
ilia-kats 8b8d1c1
Apply suggestion from @ilan-gold
ilia-kats f4d5eca
improve push performance scaling
ilia-kats 0d646ed
fixup! improve push performance scaling
ilia-kats 717c362
fixup! improve push performance scaling
ilia-kats File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,8 +27,7 @@ | |
| from .file_backing import MuDataFileManager | ||
| from .repr import MUDATA_CSS, block_matrix, details_block_table | ||
| from .utils import ( | ||
| _classify_attr_columns, | ||
| _classify_prefixed_columns, | ||
| MetadataColumn, | ||
| _make_index_unique, | ||
| _maybe_coerce_to_bool, | ||
| _maybe_coerce_to_boolean, | ||
|
|
@@ -1915,37 +1914,54 @@ def _pull_attr( | |
| if mods is not None: | ||
| if isinstance(mods, str): | ||
| mods = [mods] | ||
| mods = list(dict.fromkeys(mods)) | ||
| if not all(m in self.mod for m in mods): | ||
| raise ValueError("All mods should be present in mdata.mod") | ||
| elif len(mods) == self.n_mod: | ||
| mods = None | ||
| for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items(): | ||
| assert v is None, f"Cannot use mods with {k}." | ||
|
|
||
| if only_drop: | ||
| drop = True | ||
|
|
||
| cols = _classify_attr_columns( | ||
| np.concatenate( | ||
| [ | ||
| [f"{m}:{val}" for val in getattr(mod, attr).columns.values] | ||
| for m, mod in self.mod.items() | ||
| ] | ||
| ), | ||
| self.mod.keys(), | ||
| ) | ||
| cols: dict[str, list[MetadataColumn]] = {} | ||
|
|
||
| # get all columns from all modalities and count how many times each column is present | ||
| derived_name_counts = Counter() | ||
| for prefix, mod in self.mod.items(): | ||
| modcols = getattr(mod, attr).columns | ||
| ccols = [] | ||
| for name in modcols: | ||
| ccols.append( | ||
| MetadataColumn( | ||
| allowed_prefixes=self.mod.keys(), | ||
| prefix=prefix, | ||
| name=name, | ||
| strip_prefix=False, | ||
| ) | ||
| ) | ||
| derived_name_counts[name] += 1 | ||
| cols[prefix] = ccols | ||
|
|
||
| for prefix, modcols in cols.items(): | ||
| for col in modcols: | ||
| count = derived_name_counts[col.derived_name] | ||
| col.count = count # this is important to classify columns | ||
|
|
||
| if columns is not None: | ||
| for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items(): | ||
| assert v is None, f"Cannot use {k} with columns." | ||
|
|
||
| # - modname1:column -> [modname1:column] | ||
| # - column -> [modname1:column, modname2:column, ...] | ||
| cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns] | ||
| if v is not None: | ||
| warnings.warn( | ||
| f"Both columns and {k} given. Columns take precedence, {k} will be ignored", | ||
| RuntimeWarning, | ||
| stacklevel=2, | ||
| ) | ||
|
|
||
| if mods is not None: | ||
| cols = [col for col in cols if col["prefix"] in mods] | ||
| # keep only requested columns | ||
| cols = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i.e., with prefix_to_cols = cols.filter_by_name_or_derived_name(colums)(I would also advocate changing the name from |
||
| prefix: [ | ||
| col for col in modcols if col.name in columns or col.derived_name in columns | ||
| ] | ||
| for prefix, modcols in cols.items() | ||
| } | ||
|
|
||
| # TODO: Counter for columns in order to track their usage | ||
| # and error out if some columns were not used | ||
|
|
@@ -1958,28 +1974,37 @@ def _pull_attr( | |
| if unique is None: | ||
| unique = True | ||
|
|
||
| # filter columns by class, keep only those that were requested | ||
| selector = {"common": common, "nonunique": nonunique, "unique": unique} | ||
| cols = { | ||
| prefix: [col for col in modcols if selector[col.klass]] | ||
| for prefix, modcols in cols.items() | ||
| } | ||
|
|
||
| cols = [col for col in cols if selector[col["class"]]] | ||
| # filter columns, keep only requested modalities | ||
| if mods is not None: | ||
| cols = {prefix: cols[prefix] for prefix in mods} | ||
|
|
||
| derived_name_count = Counter([col["derived_name"] for col in cols]) | ||
| # count final filtered column names, required later to decide whether to prefix a column with its source modality | ||
| derived_name_count = Counter( | ||
| [col.derived_name for modcols in cols.values() for col in modcols] | ||
| ) | ||
|
|
||
| # - axis == self.axis | ||
| # e.g. combine var from multiple modalities (with unique vars) | ||
| # e.g. combine obs from multiple modalities (with shared obs) | ||
| # - 1 - axis == self.axis | ||
| # . e.g. combine obs from multiple modalities (with shared obs) | ||
| axis = 0 if attr == "var" else 1 | ||
| # e.g. combine var from multiple modalities (with unique vars) | ||
| axis = 0 if attr == "obs" else 1 | ||
|
|
||
| if 1 - axis == self.axis or self.axis == -1: | ||
| if axis == self.axis or self.axis == -1: | ||
| if join_common or join_nonunique: | ||
| raise ValueError(f"Cannot join columns with the same name for shared {attr}_names.") | ||
|
|
||
| if join_common is None: | ||
| join_common = False | ||
| if attr == "var": | ||
| join_common = self.axis == 0 | ||
| elif attr == "obs": | ||
| if attr == "obs": | ||
| join_common = self.axis == 1 | ||
| else: | ||
| join_common = self.axis == 0 | ||
|
|
||
| if join_nonunique is None: | ||
| join_nonunique = False | ||
|
|
@@ -1995,44 +2020,39 @@ def _pull_attr( | |
| n_attr = self.n_vars if attr == "var" else self.n_obs | ||
|
|
||
| dfs: list[pd.DataFrame] = [] | ||
| for m, mod in self.mod.items(): | ||
| if mods is not None and m not in mods: | ||
| continue | ||
| for m, modcols in cols.items(): | ||
| mod = self.mod[m] | ||
| mod_map = attrmap[m].ravel() | ||
| mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs | ||
| mask = mod_map != 0 | ||
|
|
||
| mod_df = getattr(mod, attr) | ||
| mod_columns = [ | ||
| col["derived_name"] for col in cols if col["prefix"] == "" or col["prefix"] == m | ||
| ] | ||
| mod_df = mod_df[mod_df.columns.intersection(mod_columns)] | ||
| mask = mod_map > 0 | ||
|
|
||
| mod_df = getattr(mod, attr)[[col.derived_name for col in modcols]] | ||
| if drop: | ||
| getattr(mod, attr).drop(columns=mod_df.columns, inplace=True) | ||
|
|
||
| # Don't use modname: prefix if columns need to be joined | ||
| if join_common or join_nonunique or (not prefix_unique): | ||
| cols_special = [ | ||
| col["derived_name"] | ||
| for col in cols | ||
| if ( | ||
| (col["class"] == "common") & join_common | ||
| or (col["class"] == "nonunique") & join_nonunique | ||
| or (col["class"] == "unique") & (not prefix_unique) | ||
| # prepend modality prefix to column names if requested via arguments and there are no skipped modalities with | ||
| # the same column name (prefixing those columns may cause problems with future pulls or pushes) | ||
| mod_df.rename( | ||
| columns={ | ||
| col.derived_name: col.name | ||
| for col in modcols | ||
| if not ( | ||
| ( | ||
| join_common | ||
| and col.klass == "common" | ||
| or join_nonunique | ||
| and col.klass == "nonunique" | ||
| or not prefix_unique | ||
| and col.klass == "unique" | ||
| ) | ||
| and derived_name_count[col.derived_name] == col.count | ||
ilia-kats marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ) | ||
| and col["prefix"] == m | ||
| and derived_name_count[col["derived_name"]] == col["count"] | ||
| ] | ||
| mod_df.columns = [ | ||
| col if col in cols_special else f"{m}:{col}" for col in mod_df.columns | ||
| ] | ||
| else: | ||
| mod_df.columns = [f"{m}:{col}" for col in mod_df.columns] | ||
| }, | ||
| inplace=True, | ||
| ) | ||
|
|
||
| # reorder modality DF to conform to global order | ||
| mod_df = ( | ||
| _maybe_coerce_to_boolean(mod_df) | ||
| .set_index(np.arange(mod_n_attr)) | ||
| .iloc[mod_map[mask] - 1] | ||
| .set_index(np.arange(n_attr)[mask]) | ||
| .reindex(np.arange(n_attr)) | ||
|
|
@@ -2242,39 +2262,46 @@ def _push_attr( | |
| raise ValueError("All mods should be present in mdata.mod") | ||
| elif len(mods) == self.n_mod: | ||
| mods = None | ||
| for k, v in {"common": common, "prefixed": prefixed}.items(): | ||
| assert v is None, f"Cannot use mods with {k}." | ||
|
|
||
| if only_drop: | ||
| drop = True | ||
|
|
||
| cols = _classify_prefixed_columns(getattr(self, attr).columns.values, self.mod.keys()) | ||
| # get all global columns | ||
| cols = [ | ||
| MetadataColumn(allowed_prefixes=self.mod.keys(), name=name) | ||
| for name in getattr(self, attr).columns | ||
| ] | ||
|
|
||
| if columns is not None: | ||
| for k, v in {"common": common, "prefixed": prefixed}.items(): | ||
| assert v is None, f"Cannot use columns with {k}." | ||
|
|
||
| # - modname1:column -> [modname1:column] | ||
| # - column -> [modname1:column, modname2:column, ...] | ||
| cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns] | ||
| if v: | ||
| warnings.warn( | ||
| f"Both columns and {k} given. Columns take precedence, {k} will be ignored", | ||
| RuntimeWarning, | ||
| stacklevel=2, | ||
| ) | ||
|
|
||
| # preemptively drop columns from other modalities | ||
| if mods is not None: | ||
| cols = [col for col in cols if col["prefix"] in mods or col["prefix"] == ""] | ||
| # keep only requested columns | ||
| cols = [ | ||
| col | ||
| for col in cols | ||
| if (col.name in columns or col.derived_name in columns) | ||
| and (col.prefix is None or mods is not None and col.prefix in mods) | ||
| ] | ||
| else: | ||
| if common is None: | ||
| common = True | ||
| if prefixed is None: | ||
| prefixed = True | ||
|
|
||
| selector = {"common": common, "prefixed": prefixed} | ||
|
|
||
| cols = [col for col in cols if selector[col["class"]]] | ||
| # filter columns by class, keep only those that were requested | ||
| selector = {"common": common, "unknown": prefixed} | ||
| cols = [col for col in cols if selector[col.klass]] | ||
|
|
||
| if len(cols) == 0: | ||
| return | ||
|
|
||
| derived_name_count = Counter([col["derived_name"] for col in cols]) | ||
| derived_name_count = Counter([col.derived_name for col in cols]) | ||
| for c, count in derived_name_count.items(): | ||
| # if count > 1, there are both colname and modname:colname present | ||
| if count > 1 and c in getattr(self, attr).columns: | ||
|
|
@@ -2286,25 +2313,25 @@ def _push_attr( | |
| ) | ||
|
|
||
| attrmap = getattr(self, f"{attr}map") | ||
| _n_attr = self.n_vars if attr == "var" else self.n_obs | ||
|
|
||
| for m, mod in self.mod.items(): | ||
| if mods is not None and m not in mods: | ||
| continue | ||
|
|
||
| mod_map = attrmap[m] | ||
| mask = mod_map != 0 | ||
| mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs | ||
| mod_map = attrmap[m].ravel() | ||
| mask = mod_map > 0 | ||
| mod_n_attr = mod.n_obs if attr == "obs" else mod.n_vars | ||
|
|
||
| mod_cols = [col for col in cols if col["prefix"] == m or col["class"] == "common"] | ||
| df = getattr(self, attr)[mask].loc[:, [col["name"] for col in mod_cols]] | ||
| df.columns = [col["derived_name"] for col in mod_cols] | ||
| # get all common and modality-specific columns for the current modality | ||
| mod_cols = [col for col in cols if col.prefix == m or col.klass == "common"] | ||
| df = getattr(self, attr)[mask][[col.name for col in mod_cols]] | ||
|
|
||
| df = ( | ||
| df.set_index(np.arange(mod_n_attr)) | ||
| .iloc[mod_map[mask] - 1] | ||
| .set_index(np.arange(mod_n_attr)) | ||
| ) | ||
| # strip modality prefix where necessary | ||
| df.columns = [col.derived_name for col in mod_cols] | ||
|
|
||
| # reorder global DF to conform to modality order | ||
| idx = np.empty(mod_n_attr, dtype=mod_map.dtype) | ||
| idx[mod_map[mask] - 1] = np.arange(mod_n_attr) | ||
| df = df.iloc[idx].set_index(np.arange(mod_n_attr, dtype=mod_map.dtype)) | ||
|
|
||
| if not only_drop: | ||
| # TODO: _maybe_coerce_to_bool | ||
|
|
@@ -2317,7 +2344,7 @@ def _push_attr( | |
|
|
||
| if drop: | ||
| for col in cols: | ||
| getattr(self, attr).drop(col["name"], axis=1, inplace=True) | ||
| getattr(self, attr).drop(col.name, axis=1, inplace=True) | ||
|
|
||
| def push_obs( | ||
| self, | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would something like this improve readability? (I am not sure we have a consistent policy for formatting in such cases.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If yes, this is also true for similar warnings in other parts of the PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that would be a bit misleading here, since the warning will also be emitted if
k=Falseor any other value which is not theNonedefault. Perhaps something like? But I'm not sure if that brings the message across that it should just be not passed at all (as in leave the None default).