Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable gene_symbols argument in mu.pl.embedding #124

Open
racng opened this issue Aug 30, 2023 · 1 comment
Open

Enable gene_symbols argument in mu.pl.embedding #124

racng opened this issue Aug 30, 2023 · 1 comment
Labels
enhancement New feature or request

Comments

@racng
Copy link

racng commented Aug 30, 2023

Is your feature request related to a problem? Please describe.
sc.pl.embedding takes an argument gene_symbols that specifies which column in adata.var to use when looking up the color keys. This argument does not work in mu.pl.embedding. Based on the source code, a new adata is made with the color values stored in adata.obs, so it has no way to access adata.var.

Describe the solution you'd like
Redesign the way mu.pl.embedding generates the intermediate adata passed onto sc.pl.embedding. Perhaps use the sc.get.obs_df() function that also takes layers and gene_symbols as arguments?

Describe alternatives you've considered
User could write their own function to consolidate the basis and color keys into one adata and use sc.pl.embedding.

@racng racng added the enhancement New feature or request label Aug 30, 2023
@racng
Copy link
Author

racng commented Aug 31, 2023

Here is a proposed solution that I tested:

def get_uns_colors(data: Union[AnnData, MuData], key: str):
    """Return the color palette stored for ``key`` in ``data.uns``, if any.

    The palette is looked up under the name ``key + '_colors'``.

    Parameters
    ----------
    data
        Object whose ``.uns`` mapping is searched.
    key
        Name of the variable the palette belongs to.

    Returns
    -------
    The value stored under ``key + '_colors'``, or ``None`` when ``data.uns``
    has no such entry.
    """
    palette_name = key + '_colors'
    return data.uns[palette_name] if palette_name in data.uns else None

def to_dtype_list(x, dtype, n, none=True):
    """Broadcast or validate a per-key argument so that it has ``n`` entries.

    Parameters
    ----------
    x
        Either a single value of an accepted type, which is repeated ``n``
        times, or an iterable holding ``n`` values of accepted types, which
        is returned unchanged.
    dtype
        The accepted type, or a non-string iterable of accepted types.
    n
        Number of entries the result must have.
    none
        When true, ``None`` is accepted in addition to ``dtype``.

    Returns
    -------
    ``[x] * n`` when ``x`` is a single accepted value, otherwise ``x`` itself.

    Raises
    ------
    TypeError
        If ``x`` is neither an accepted value nor an iterable of accepted
        values.
    ValueError
        If ``x`` is an iterable whose length differs from ``n``.
    """
    # A string is iterable but is never meant as a collection of types here.
    if isinstance(dtype, Iterable) and not isinstance(dtype, str):
        dtypes = list(dtype)
    else:
        dtypes = [dtype]
    if none:
        dtypes.append(type(None))
    accepted = tuple(dtypes)

    # A single accepted value (including a plain string) is broadcast.
    if isinstance(x, accepted):
        return [x] * n

    if not isinstance(x, Iterable):
        raise TypeError(
            f"Expected a value of type {accepted} or an iterable of such values, "
            f"got {type(x).__name__}."
        )
    if not all(isinstance(item, accepted) for item in x):
        raise TypeError(f"Expected every element to be of type {accepted}.")
    if len(x) != n:
        raise ValueError(f"Expected {n} values, got {len(x)}.")
    return x


# Rewrite muon.pl.embedding to use gene_symbols
def _resolve_embedding_basis(data, basis):
    """Find the object that holds ``basis`` and the matching ``.obsm`` key.

    ``basis`` is looked up in ``data.obsm`` first (also with an ``"X_"``
    prefix). Otherwise it is read as ``"<modality>:<basis>"`` and looked up
    in that modality's ``.obsm`` (again also with an ``"X_"`` prefix).

    Returns a ``(adata, basis_mod)`` pair: the object whose ``.obsm`` holds
    the embedding, and the key under which it is stored.

    Raises ``ValueError`` when the basis or the modality cannot be found.
    """
    if basis not in data.obsm and "X_" + basis in data.obsm:
        basis = "X_" + basis

    if basis in data.obsm:
        return data, basis

    try:
        mod, basis_mod = basis.split(":")
    except ValueError:
        raise ValueError(
            f"Basis {basis} is not present in the MuData object (.obsm)"
        ) from None

    if mod not in data.mod:
        raise ValueError(
            f"Modality {mod} is not present in the MuData object with modalities {', '.join(data.mod)}"
        )

    adata = data.mod[mod]
    if basis_mod not in adata.obsm:
        if "X_" + basis_mod in adata.obsm:
            basis_mod = "X_" + basis_mod
        elif len(adata.obsm) > 0:
            raise ValueError(
                f"Basis {basis_mod} is not present in the modality {mod} with embeddings {', '.join(adata.obsm)}"
            )
        else:
            raise ValueError(
                f"Basis {basis_mod} is not present in the modality {mod} with no embeddings"
            )
    return adata, basis_mod


def embedding(
    data: MuData,
    basis: str,
    color: Optional[Union[str, Sequence[str]]] = None,
    layer: Optional[Union[str, Sequence[str]]] = None,
    gene_symbols: Optional[Union[str, Sequence[str]]] = None,
    use_raw: Optional[Union[bool, Sequence[bool]]] = False,
    **kwargs
):
    """Scatter plot of an embedding, colored by joint or per-modality keys.

    An intermediate ``AnnData`` is assembled from the joint ``.obs``, the
    requested per-modality features (fetched with ``sc.get.obs_df``) and the
    chosen embedding, and is then handed to ``sc.pl.embedding``.

    Parameters
    ----------
    data
        A ``MuData`` object. An ``AnnData`` is passed straight through to
        ``sc.pl.embedding`` with all arguments.
    basis
        Key in ``data.obsm``, or ``"<modality>:<basis>"`` for an embedding
        stored in a modality. The ``"X_"`` prefix may be omitted.
    color
        One key or an iterable of keys. A key is either a column of the joint
        ``.obs`` or ``"<modality>:<key>"``, which is resolved inside that
        modality. Panels are plotted in the order given.
    layer, gene_symbols, use_raw
        Forwarded to ``sc.get.obs_df`` for per-modality keys. Each may be a
        single value applied to every key, or one value per entry of
        ``color``.
    **kwargs
        Forwarded to ``sc.pl.embedding``.

    Returns
    -------
    Whatever ``sc.pl.embedding`` returns.

    Raises
    ------
    ValueError
        If the basis, a color key or a modality cannot be found.
    TypeError
        If ``color`` is neither a string nor an iterable.

    Notes
    -----
    Color palettes found in ``.uns`` are passed to the plot, and palettes
    present after plotting are written back to ``data.uns`` (joint keys) or
    ``data.mod[m].uns`` (modality keys).
    """
    if isinstance(data, AnnData):
        return sc.pl.embedding(
            data, basis=basis, color=color, use_raw=use_raw, layer=layer,
            gene_symbols=gene_symbols, **kwargs
        )

    adata, basis_mod = _resolve_embedding_basis(data, basis)

    # Subset joint obs to embedding observations
    obs = data.obs.loc[adata.obs.index.values].copy()

    if color is None:
        ad = AnnData(obs=obs, obsm=adata.obsm, obsp=adata.obsp)
        return sc.pl.embedding(ad, basis=basis_mod, **kwargs)

    # Some `color` has been provided
    if isinstance(color, str):
        keys = [color]
    elif isinstance(color, Iterable):
        # Materialise so that len() works for any iterable.
        keys = list(color)
    else:
        raise TypeError("Expected color to be a string or an iterable.")

    # Convert keyword args to lists with one entry per key
    n = len(keys)
    layers = to_dtype_list(layer, str, n, none=True)
    symbols = to_dtype_list(gene_symbols, str, n, none=True)
    raws = to_dtype_list(use_raw, bool, n, none=True)

    # Parse features. Modality keys are grouped by their (layer, gene_symbols,
    # use_raw) combination so each group needs a single obs_df call.
    mod2keys = {m: defaultdict(list) for m in data.mod.keys()}
    joint_keys = []
    # (is_modality_key, item) per requested key, in the order given, so the
    # panels can be plotted in that order.
    plot_order = []
    uns = dict()
    for key, key_layer, key_symbols, key_raw in zip(keys, layers, symbols, raws):
        if key is None:
            joint_keys.append(key)
            plot_order.append((False, key))
            continue

        # Key in joint obs
        if key in obs:
            joint_keys.append(key)
            plot_order.append((False, key))

            # Look for color palette
            palette = get_uns_colors(data, key)
            if palette is not None:
                uns[key + '_colors'] = palette
            continue

        # Key in modality
        try:
            mod, key_mod = key.split(":")
        except ValueError:
            raise ValueError(
                f"Key {key} is not present in the MuData object (.obs)"
            ) from None

        # An unknown modality would otherwise surface as a bare KeyError.
        if mod not in mod2keys:
            raise ValueError(
                f"Modality {mod} is not present in the MuData object with modalities {', '.join(data.mod)}"
            )

        args = (key_layer, key_symbols, key_raw)
        group = mod2keys[mod][args]
        # Fetch each feature once per group; a repeated key reuses the column.
        if key_mod not in group:
            group.append(key_mod)
        plot_order.append((True, (mod, args, key_mod)))

        # Look for color palette
        palette = get_uns_colors(data.mod[mod], key_mod)
        if palette is not None:
            uns[f"{mod}:{key_mod}_colors"] = palette

    # Add features for each modality to obs
    resolved = {}  # (modality, args, key) -> column name in `obs`
    for m, groups in mod2keys.items():
        mod_obs = data.mod[m].obs
        # Loop through unique combinations of args
        for args, group in groups.items():
            key_layer, key_symbols, key_raw = args
            # Get features as dataframe
            df = sc.get.obs_df(
                data.mod[m], keys=group, layer=key_layer,
                gene_symbols=key_symbols, use_raw=key_raw,
            )
            labels = []
            if key_raw:
                labels.append('use_raw')
            if key_layer is not None:
                labels.append(key_layer)
            cond = '_'.join(labels)

            # Keys that are not modality obs columns get the layer/raw label
            # on a second line; this also keeps the same feature from
            # different layers apart.
            # NOTE(review): gene_symbols is not part of the label, so the same
            # key requested with different gene_symbols but identical
            # layer/use_raw still maps to one column name.
            cols = []
            for name, source_key in zip(df.columns, group):
                col = f"{m}:{name}"
                if cond and source_key not in mod_obs:
                    col += '\n' + cond
                cols.append(col)
            df.columns = cols
            for source_key, col in zip(group, cols):
                resolved[(m, args, source_key)] = col
            # Merge with joint obs
            obs = obs.merge(df, left_index=True, right_index=True, how='left')

    # Plot, keeping the panels in the order of `color`
    color_keys = [resolved[item] if is_mod else item for is_mod, item in plot_order]
    ad = AnnData(obs=obs, obsm=adata.obsm, uns=uns)
    retval = sc.pl.embedding(ad, basis=basis_mod, color=color_keys, **kwargs)

    # Update color palettes for joint keys
    for key in joint_keys:
        try:
            data.uns[f"{key}_colors"] = ad.uns[f"{key}_colors"]
        except KeyError:
            # No palette exists for this key after plotting.
            pass

    # Update color palettes for modality keys
    for m, groups in mod2keys.items():
        for group in groups.values():
            for key in group:
                try:
                    data.mod[m].uns[f"{key}_colors"] = ad.uns[f"{m}:{key}_colors"]
                except KeyError:
                    # No palette exists for this key after plotting.
                    pass
    return retval

Example usage:

# Plot four color keys on the 'rna:umap' basis. gene_symbols and layer are
# given as lists with one entry per color key (None where no value applies).
# NOTE(review): `sw` is presumably the author's own package exposing the
# embedding function above — confirm.
sw.pl.embedding(mdata, 'rna:umap', [
	'prot:CD4', 'prot:CD4', 'rna:CD4', 'rna:sample'], 
	gene_symbols=['symbols', 'symbols', None, None], 
	layer=['raw', 'cellbender', None, None])

3077a023-96fe-4102-a263-d1489b13f9d1

Checking color palette updated:

mdata.uns
# Output:
# {'rna:sample_colors': ['#1f77b4', '#ff7f0e']}

There are some non-ideal behaviors that could be fixed if needed, but they don't affect the functionality:

  • The color keys are not plotted in the order given: they are reordered and grouped by modality.
  • The color palette for a modality-specific categorical variable mod:key is added to mdata.uns['mod:key_colors'] instead of mdata[mod].uns['key_colors']. This is because mod:key could be found in mdata.obs and so it was treated as a joint obs key.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

No branches or pull requests

1 participant