Skip to content

Commit 1072cf6

Browse files
committed
push/pull: replace dict holding column information with custom class
1 parent 1204186 commit 1072cf6

File tree

2 files changed

+80
-81
lines changed

2 files changed

+80
-81
lines changed

src/mudata/_core/mudata.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from .file_backing import MuDataFileManager
2828
from .repr import MUDATA_CSS, block_matrix, details_block_table
2929
from .utils import (
30+
MetadataColumn,
3031
_classify_attr_columns,
31-
_classify_prefixed_columns,
3232
_make_index_unique,
3333
_maybe_coerce_to_bool,
3434
_maybe_coerce_to_boolean,
@@ -1940,9 +1940,7 @@ def _pull_attr(
19401940
# - column -> [modname1:column, modname2:column, ...]
19411941
cols = {
19421942
prefix: [
1943-
col
1944-
for col in modcols
1945-
if col["name"] in columns or col["derived_name"] in columns
1943+
col for col in modcols if col.name in columns or col.derived_name in columns
19461944
]
19471945
for prefix, modcols in cols.items()
19481946
}
@@ -1960,15 +1958,15 @@ def _pull_attr(
19601958

19611959
selector = {"common": common, "nonunique": nonunique, "unique": unique}
19621960
cols = {
1963-
prefix: [col for col in modcols if selector[col["class"]]]
1961+
prefix: [col for col in modcols if selector[col.klass]]
19641962
for prefix, modcols in cols.items()
19651963
}
19661964

19671965
if mods is not None:
19681966
cols = {prefix: cols[prefix] for prefix in mods}
19691967

19701968
derived_name_count = Counter(
1971-
[col["derived_name"] for modcols in cols.values() for col in modcols]
1969+
[col.derived_name for modcols in cols.values() for col in modcols]
19721970
)
19731971

19741972
# - axis == self.axis
@@ -2007,24 +2005,24 @@ def _pull_attr(
20072005
mod_map = attrmap[m].ravel()
20082006
mask = mod_map > 0
20092007

2010-
mod_df = getattr(mod, attr)[[col["derived_name"] for col in modcols]]
2008+
mod_df = getattr(mod, attr)[[col.derived_name for col in modcols]]
20112009
if drop:
20122010
getattr(mod, attr).drop(columns=mod_df.columns, inplace=True)
20132011

20142012
mod_df.rename(
20152013
columns={
2016-
col["derived_name"]: col["name"]
2014+
col.derived_name: col.name
20172015
for col in modcols
20182016
if not (
20192017
(
20202018
join_common
2021-
and col["class"] == "common"
2019+
and col.klass == "common"
20222020
or join_nonunique
2023-
and col["class"] == "nonunique"
2021+
and col.klass == "nonunique"
20242022
or not prefix_unique
2025-
and col["class"] == "unique"
2023+
and col.klass == "unique"
20262024
)
2027-
and derived_name_count[col["derived_name"]] == col["count"]
2025+
and derived_name_count[col.derived_name] == col.count
20282026
)
20292027
},
20302028
inplace=True,
@@ -2245,7 +2243,10 @@ def _push_attr(
22452243
if only_drop:
22462244
drop = True
22472245

2248-
cols = _classify_prefixed_columns(getattr(self, attr).columns.values, self.mod.keys())
2246+
cols = [
2247+
MetadataColumn(allowed_prefixes=self.mod.keys(), name=name)
2248+
for name in getattr(self, attr).columns
2249+
]
22492250

22502251
if columns is not None:
22512252
for k, v in {"common": common, "prefixed": prefixed}.items():
@@ -2262,23 +2263,23 @@ def _push_attr(
22622263
cols = [
22632264
col
22642265
for col in cols
2265-
if (col["name"] in columns or col["derived_name"] in columns)
2266-
and (col["prefix"] == "" or mods is not None and col["prefix"] in mods)
2266+
if (col.name in columns or col.derived_name in columns)
2267+
and (col.prefix is None or mods is not None and col.prefix in mods)
22672268
]
22682269
else:
22692270
if common is None:
22702271
common = True
22712272
if prefixed is None:
22722273
prefixed = True
22732274

2274-
selector = {"common": common, "prefixed": prefixed}
2275+
selector = {"common": common, "unknown": prefixed}
22752276

2276-
cols = [col for col in cols if selector[col["class"]]]
2277+
cols = [col for col in cols if selector[col.klass]]
22772278

22782279
if len(cols) == 0:
22792280
return
22802281

2281-
derived_name_count = Counter([col["derived_name"] for col in cols])
2282+
derived_name_count = Counter([col.derived_name for col in cols])
22822283
for c, count in derived_name_count.items():
22832284
# if count > 1, there are both colname and modname:colname present
22842285
if count > 1 and c in getattr(self, attr).columns:
@@ -2300,9 +2301,9 @@ def _push_attr(
23002301
mask = mod_map != 0
23012302
mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs
23022303

2303-
mod_cols = [col for col in cols if col["prefix"] == m or col["class"] == "common"]
2304-
df = getattr(self, attr)[mask].loc[:, [col["name"] for col in mod_cols]]
2305-
df.columns = [col["derived_name"] for col in mod_cols]
2304+
mod_cols = [col for col in cols if col.prefix == m or col.klass == "common"]
2305+
df = getattr(self, attr)[mask].loc[:, [col.name for col in mod_cols]]
2306+
df.columns = [col.derived_name for col in mod_cols]
23062307

23072308
df = df.iloc[np.argsort(mod_map[mask])].set_index(np.arange(mod_n_attr))
23082309

@@ -2317,7 +2318,7 @@ def _push_attr(
23172318

23182319
if drop:
23192320
for col in cols:
2320-
getattr(self, attr).drop(col["name"], axis=1, inplace=True)
2321+
getattr(self, attr).drop(col.name, axis=1, inplace=True)
23212322

23222323
def push_obs(
23232324
self,

src/mudata/_core/utils.py

Lines changed: 57 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import Counter
22
from collections.abc import Mapping, Sequence
3-
from typing import TypeVar
3+
from typing import Literal, TypeVar
44

55
import numpy as np
66
import pandas as pd
@@ -38,7 +38,56 @@ def _maybe_coerce_to_boolean(df: T) -> T:
3838
return df
3939

4040

41-
def _classify_attr_columns(names: Mapping[str, Sequence[str]]) -> dict[str, list[dict[str, str]]]:
41+
class MetadataColumn:
42+
__slots__ = ("prefix", "derived_name", "count", "_allowed_prefixes")
43+
44+
def __init__(
45+
self,
46+
*,
47+
allowed_prefixes: Sequence[str],
48+
prefix: str | None = None,
49+
name: str | None = None,
50+
count: int = 0,
51+
):
52+
self._allowed_prefixes = allowed_prefixes
53+
if prefix is None:
54+
self.name = name
55+
else:
56+
self.prefix = prefix
57+
self.derived_name = name
58+
self.count = count
59+
60+
@property
61+
def name(self) -> str:
62+
if self.prefix is not None:
63+
return f"{self.prefix}:{self.derived_name}"
64+
else:
65+
return self.derived_name
66+
67+
@name.setter
68+
def name(self, new_name):
69+
if (
70+
len(name_split := new_name.split(":", 1)) < 2
71+
or name_split[0] not in self._allowed_prefixes
72+
):
73+
self.prefix = None
74+
self.derived_name = new_name
75+
else:
76+
self.prefix, self.derived_name = name_split
77+
78+
@property
79+
def klass(self) -> Literal["common", "unique", "nonunique", "unknown"]:
80+
if self.prefix is None or self.count == len(self._allowed_prefixes):
81+
return "common"
82+
elif self.count == 1:
83+
return "unique"
84+
elif self.count > 0:
85+
return "nonunique"
86+
else:
87+
return "unknown"
88+
89+
90+
def _classify_attr_columns(names: Mapping[str, Sequence[str]]) -> dict[str, list[MetadataColumn]]:
4291
"""
4392
Classify names into common, non-unique, and unique
4493
w.r.t. to the list of prefixes.
@@ -50,72 +99,21 @@ def _classify_attr_columns(names: Mapping[str, Sequence[str]]) -> dict[str, list
5099
- Unique columns are prefixed by modality names,
51100
and there is only one modality prefix
52101
for a column with a certain name.
53-
54-
E.g. {"mod1": ["annotation", "unique"], "mod2": ["annotation"]} will be classified
55-
into {"mod1": [{"name": "mod1:annotation", "derived_name": "annotation", "count": 2, "class": "nonunique"},
56-
{"name": "mod1:unique", "derived_name": "unique", "count": 1, "class": "unique"}}],
57-
"mod2": [{"name": "mod2:annotation", "derived_name": "annotation", "count": 2, "class": "nonunique"}],
58-
}
59102
"""
60-
n_mod = len(names)
61-
res: dict[str, list[dict[str, str]]] = {}
103+
res: dict[str, list[MetadataColumn]] = {}
62104

63105
derived_name_counts = Counter()
64-
for prefix, names in names.items():
106+
for prefix, pnames in names.items():
65107
cres = []
66-
for name in names:
67-
cres.append(
68-
{
69-
"name": f"{prefix}:{name}",
70-
"derived_name": name,
71-
}
72-
)
108+
for name in pnames:
109+
cres.append(MetadataColumn(allowed_prefixes=names.keys(), prefix=prefix, name=name))
73110
derived_name_counts[name] += 1
74111
res[prefix] = cres
75112

76113
for prefix, names in res.items():
77114
for name_res in names:
78-
count = derived_name_counts[name_res["derived_name"]]
79-
name_res["count"] = count
80-
name_res["class"] = (
81-
"common" if count == n_mod else "unique" if count == 1 else "nonunique"
82-
)
83-
84-
return res
85-
86-
87-
def _classify_prefixed_columns(
88-
names: Sequence[str], prefixes: Sequence[str]
89-
) -> Sequence[dict[str, str]]:
90-
"""
91-
Classify names into common and prefixed
92-
w.r.t. to the list of prefixes.
93-
94-
- Common columns do not have modality prefixes.
95-
- Prefixed columns are prefixed by modality names.
96-
97-
E.g. ["global", "mod1:annotation", "mod2:annotation", "mod1:unique"] will be classified
98-
into [
99-
{"name": "global", "prefix": "", "derived_name": "global", "class": "common"},
100-
{"name": "mod1:annotation", "prefix": "mod1", "derived_name": "annotation", "class": "prefixed"},
101-
{"name": "mod2:annotation", "prefix": "mod2", "derived_name": "annotation", "class": "prefixed"},
102-
{"name": "mod1:unique", "prefix": "mod1", "derived_name": "annotation", "class": "prefixed"},
103-
]
104-
"""
105-
res: list[dict[str, str]] = []
106-
107-
for name in names:
108-
if len(name_split := name.split(":", 1)) < 2 or name_split[0] not in prefixes:
109-
res.append({"name": name, "prefix": "", "derived_name": name, "class": "common"})
110-
else:
111-
res.append(
112-
{
113-
"name": name,
114-
"prefix": name_split[0],
115-
"derived_name": name_split[1],
116-
"class": "prefixed",
117-
}
118-
)
115+
count = derived_name_counts[name_res.derived_name]
116+
name_res.count = count
119117

120118
return res
121119

0 commit comments

Comments
 (0)