diff --git a/muon/_core/plot.py b/muon/_core/plot.py index da8cab3..7d72f3e 100644 --- a/muon/_core/plot.py +++ b/muon/_core/plot.py @@ -22,7 +22,7 @@ def scatter( data: Union[AnnData, MuData], x: Optional[str] = None, y: Optional[str] = None, - color: Optional[Union[str, Sequence[str]]] = None, + color: Optional[str] = None, use_raw: Optional[bool] = None, layers: Optional[Union[str, Sequence[str]]] = None, **kwargs, @@ -42,8 +42,8 @@ def scatter( x coordinate y : Optional[str] y coordinate - color : Optional[Union[str, Sequence[str]]], optional (default: None) - Keys for variables or annotations of observations (.obs columns), + color : Optional[str], optional (default: None) + Key for variables or annotations of observations (.obs columns), or a hex colour specification. use_raw : Optional[bool], optional (default: None) Use `.raw` attribute of the modality where a feature (from `color`) is derived from. @@ -72,10 +72,10 @@ def scatter( if isinstance(color, str): color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2]) color_obs = pd.DataFrame({color: color_obs}) - color = [color] else: - # scanpy#311 / scanpy#1497 has to be fixed for this to work - color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2]) + raise TypeError("Expected color to be a string.") + + color_obs.index = data.obs_names obs = pd.concat([obs, color_obs], axis=1, ignore_index=False) @@ -86,14 +86,14 @@ def scatter( # and are now stored in .obs retval = sc.pl.scatter(ad, x=x, y=y, color=color, **kwargs) if color is not None: - for col in color: - try: - data.uns[f"{col}_colors"] = ad.uns[f"{col}_colors"] - except KeyError: - pass + try: + data.uns[f"{color}_colors"] = ad.uns[f"{color}_colors"] + except KeyError: + pass return retval + # # Embedding #