diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py index cbfaacf79..e700b48b8 100644 --- a/altair/utils/schemapi.py +++ b/altair/utils/schemapi.py @@ -1071,10 +1071,7 @@ def _default_wrapper_classes(cls) -> Iterator[type[SchemaBase]]: @classmethod def from_dict( - cls: type[TSchemaBase], - dct: dict[str, Any], - validate: bool = True, - _wrapper_classes: Iterable[type[SchemaBase]] | None = None, + cls: type[TSchemaBase], dct: dict[str, Any], validate: bool = True, **kwds: Any ) -> TSchemaBase: """Construct class from a dictionary representation @@ -1101,9 +1098,9 @@ def from_dict( """ if validate: cls.validate(dct) - if _wrapper_classes is None: - _wrapper_classes = cls._default_wrapper_classes() - converter = _FromDict(_wrapper_classes) + converter = _FromDict( + kwds.pop("_wrapper_classes", cls._default_wrapper_classes()) + ) return converter.from_dict(dct, cls) @classmethod diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 23138613f..5d23fe4d9 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -6,10 +6,9 @@ import json import jsonschema import itertools -from typing import Union, cast, Any, Iterable, Literal, IO, TYPE_CHECKING +from typing import Union, Any, Iterable, Literal, IO, TYPE_CHECKING from typing_extensions import TypeAlias import typing - from .schema import core, channels, mixins, Undefined, SCHEMA_URL from altair.utils import Optional @@ -26,9 +25,11 @@ from ...utils.data import DataType, is_data_type as _is_data_type from ...utils.deprecation import AltairDeprecationWarning + if TYPE_CHECKING: - from ...utils.core import DataFrameLike import sys + from ...utils.core import DataFrameLike + from pathlib import Path if sys.version_info >= (3, 13): @@ -74,7 +75,6 @@ AnyMark, Step, RepeatRef, - NonNormalizedSpec, UrlData, SequenceGenerator, GraticuleGenerator, @@ -91,6 +91,10 @@ SelectionResolution_T, SingleDefUnitChannel_T, StackOffset_T, + ProjectionType_T, + AggregateOp_T, + MultiTimeUnit_T, + SingleTimeUnit_T, ) ChartDataType: TypeAlias = Optional[Union[DataType, core.Data, str, core.Generator]] @@ -120,12 +124,12 @@ def _dataset_name(values: dict | list | InlineDataset) -> str: return "data-" + hsh -def _consolidate_data(data, context): +def _consolidate_data(data: Any, context: Any) -> Any: """If data is specified inline, then move it to context['datasets'] This function will modify context in-place, and return a new version of data """ - values = Undefined + values: Any = Undefined kwds = {} if isinstance(data, core.InlineData): @@ -271,11 +275,8 @@ def to_dict(self) -> dict[str, str | dict]: if self.param_type == "variable": return {"expr": self.name} elif self.param_type == "selection": - return { - "param": ( - self.name.to_dict() if hasattr(self.name, "to_dict") else self.name - ) - } + nm: Any = self.name + return {"param": nm.to_dict() if hasattr(nm, "to_dict") else nm} else: msg = f"Unrecognized parameter type: {self.param_type}" raise ValueError(msg) @@ -368,9 +369,11 @@ def _from_expr(self, expr) -> SelectionExpression: def check_fields_and_encodings(parameter: Parameter, field_name: str) -> bool: + if isinstance(parameter.param, (core.UndefinedType, core.VariableParameter)): + return False for prop in ["fields", "encodings"]: try: - if field_name in getattr(parameter.param.select, prop): # type: ignore[union-attr] + if field_name in getattr(parameter.param.select, prop): return True except (AttributeError, TypeError): pass @@ -936,10 +939,23 @@ def condition( # Top-level objects +def _top_schema_base(obj: Any, /): # -> + """Enforces an intersection type w/ `SchemaBase` & `TopLevelMixin` objects. + + Use for instance methods. + """ + if isinstance(obj, core.SchemaBase) and isinstance(obj, TopLevelMixin): + return obj + else: + msg = f"{type(obj).__name__!r} does not derive from {type(core.SchemaBase).__name__!r}" + raise TypeError(msg) + + class TopLevelMixin(mixins.ConfigMethodMixin): """Mixin for top-level chart objects such as Chart, LayeredChart, etc.""" _class_is_valid_at_instantiation: bool = False + data: Any def to_dict( self, @@ -1003,9 +1019,7 @@ def to_dict( context.setdefault("datasets", {}) is_top_level = context.get("top_level", True) - # TopLevelMixin instance does not necessarily have copy defined but due to how - # Altair is set up this should hold. Too complex to type hint right now - copy = self.copy(deep=False) # type: ignore[attr-defined] + copy = _top_schema_base(self).copy(deep=False) original_data = getattr(copy, "data", Undefined) copy.data = _prepare_data(original_data, context) @@ -1018,7 +1032,7 @@ def to_dict( # TopLevelMixin instance does not necessarily have to_dict defined # but due to how Altair is set up this should hold. # Too complex to type hint right now - vegalite_spec = super(TopLevelMixin, copy).to_dict( # type: ignore[misc] + vegalite_spec: Any = super(TopLevelMixin, copy).to_dict( # type: ignore[misc] validate=validate, ignore=ignore, context=dict(context, pre_transform=False) ) @@ -1029,11 +1043,14 @@ def to_dict( vegalite_spec["$schema"] = SCHEMA_URL # apply theme from theme registry - the_theme = themes.get() - # Use assert to tell type checkers that it is not None. Holds true - # as there is always a default theme set when importing Altair - assert the_theme is not None - vegalite_spec = utils.update_nested(the_theme(), vegalite_spec, copy=True) + if theme := themes.get(): + vegalite_spec = utils.update_nested(theme(), vegalite_spec, copy=True) + else: + msg = ( + f"Expected a theme to be set but got {None!r}.\n" + f"Call `themes.enable('default')` to reset the `ThemeRegistry`." + ) + raise TypeError(msg) # update datasets if context["datasets"]: @@ -1271,7 +1288,7 @@ def save( from ...utils.save import save - kwds = dict( + kwds: dict[str, Any] = dict( chart=self, fp=fp, format=format, @@ -1315,11 +1332,14 @@ def __and__(self, other) -> VConcatChart: # Too difficult to type check this return vconcat(self, other) - def __or__(self, other) -> HConcatChart: + def __or__(self, other) -> HConcatChart | ConcatChart: if not isinstance(other, TopLevelMixin): msg = "Only Chart objects can be concatenated." raise ValueError(msg) - return hconcat(self, other) + elif isinstance(self, ConcatChart): + return concat(self, other) + else: + return hconcat(self, other) def repeat( self, @@ -1385,8 +1405,7 @@ def properties(self, **kwargs) -> Self: Argument names and types are the same as class initialization. """ - # ignore type as copy comes from another class for subclasses of TopLevelMixin - copy = self.copy(deep=False) # type: ignore[attr-defined] + copy = _top_schema_base(self).copy(deep=False) for key, val in kwargs.items(): if key == "selection" and isinstance(val, Parameter): # TODO: Can this be removed @@ -1395,15 +1414,15 @@ def properties(self, **kwargs) -> Self: else: # Don't validate data, because it hasn't been processed. if key != "data": - # ignore type as validate_property comes from SchemaBase, - # not from TopLevelMixin - self.validate_property(key, val) # type: ignore[attr-defined] + _top_schema_base(self).validate_property(key, val) setattr(copy, key, val) - return copy + return typing.cast("Self", copy) def project( self, - type: Optional[str | ProjectionType | ExprRef | Parameter] = Undefined, + type: Optional[ + ProjectionType_T | ProjectionType | ExprRef | Parameter + ] = Undefined, center: Optional[list[float] | Vector2number | ExprRef | Parameter] = Undefined, clipAngle: Optional[float | ExprRef | Parameter] = Undefined, clipExtent: Optional[ @@ -1546,20 +1565,18 @@ def project( spacing=spacing, tilt=tilt, translate=translate, - # Ignore as we type here `type` as a str but in core.Projection - # it's a Literal with all options - type=type, # type: ignore[arg-type] + type=type, **kwds, ) return self.properties(projection=projection) def _add_transform(self, *transforms: Transform) -> Self: """Copy the chart and add specified transforms to chart.transform""" - copy = self.copy(deep=["transform"]) # type: ignore[attr-defined] + copy = _top_schema_base(self).copy(deep=["transform"]) if copy.transform is Undefined: copy.transform = [] copy.transform.extend(transforms) - return copy + return typing.cast("Self", copy) def transform_aggregate( self, @@ -1755,21 +1772,20 @@ def transform_calculate( -------- alt.CalculateTransform : underlying transform object """ + calc_as: Optional[str | FieldName | Expr | Expression] if as_ is Undefined: - # Ignoring assignment error as passing 'as' as a keyword argument is - # an edge case and it's not worth changing the type annotation - # in this function to account for it as it could be confusing to - # users. - as_ = kwargs.pop("as", Undefined) # type: ignore[assignment] + calc_as = kwargs.pop("as", Undefined) elif "as" in kwargs: msg = "transform_calculate: both 'as_' and 'as' passed as arguments." raise ValueError(msg) - if as_ is not Undefined or calculate is not Undefined: - dct = {"as": as_, "calculate": calculate} - self = self._add_transform(core.CalculateTransform(**dct)) # type: ignore[arg-type] - for as_, calculate in kwargs.items(): - dct = {"as": as_, "calculate": calculate} - self = self._add_transform(core.CalculateTransform(**dct)) # type: ignore[arg-type] + else: + calc_as = as_ + if calc_as is not Undefined or calculate is not Undefined: + dct: dict[str, Any] = {"as": calc_as, "calculate": calculate} + self = self._add_transform(core.CalculateTransform(**dct)) + for a, calculate in kwargs.items(): + dct = {"as": a, "calculate": calculate} + self = self._add_transform(core.CalculateTransform(**dct)) return self def transform_density( @@ -2018,13 +2034,13 @@ def transform_filter( returns chart to allow for chaining """ if isinstance(filter, Parameter): - new_filter: dict[str, bool | str] = {"param": filter.name} + new_filter: dict[str, Any] = {"param": filter.name} if "empty" in kwargs: new_filter["empty"] = kwargs.pop("empty") elif isinstance(filter.empty, bool): new_filter["empty"] = filter.empty - filter = new_filter # type: ignore[assignment] - return self._add_transform(core.FilterTransform(filter=filter, **kwargs)) # type: ignore[arg-type] + filter = new_filter + return self._add_transform(core.FilterTransform(filter=filter, **kwargs)) def transform_flatten( self, @@ -2186,7 +2202,7 @@ def transform_pivot( value: str | FieldName, groupby: Optional[list[str | FieldName]] = Undefined, limit: Optional[int] = Undefined, - op: Optional[str | AggregateOp] = Undefined, + op: Optional[AggregateOp_T | AggregateOp] = Undefined, ) -> Self: """Add a :class:`PivotTransform` to the chart. @@ -2206,7 +2222,7 @@ def transform_pivot( The default ( ``0`` ) applies no limit. The pivoted ``pivot`` names are sorted in ascending order prior to enforcing the limit. **Default value:** ``0`` - op : string + op : Literal['argmax', 'argmin', 'average', 'count', 'distinct', 'max', 'mean', 'median', 'min', 'missing', 'product', 'q1', 'q3', 'ci0', 'ci1', 'stderr', 'stdev', 'stdevp', 'sum', 'valid', 'values', 'variance', 'variancep', 'exponential', 'exponentialb'] The aggregation operation to apply to grouped ``value`` field values. **Default value:** ``sum`` @@ -2222,13 +2238,7 @@ def transform_pivot( """ return self._add_transform( core.PivotTransform( - # Ignore as we type here `op` as a str but in core.PivotTransform - # it's a Literal with all options - pivot=pivot, - value=value, - groupby=groupby, - limit=limit, - op=op, # type: ignore[arg-type] + pivot=pivot, value=value, groupby=groupby, limit=limit, op=op ) ) @@ -2411,7 +2421,7 @@ def transform_timeunit( self, as_: Optional[str | FieldName] = Undefined, field: Optional[str | FieldName] = Undefined, - timeUnit: Optional[str | TimeUnit] = Undefined, + timeUnit: Optional[MultiTimeUnit_T | SingleTimeUnit_T | TimeUnit] = Undefined, **kwargs: str, ) -> Self: """ @@ -2473,8 +2483,8 @@ def transform_timeunit( msg = "transform_timeunit: both 'as_' and 'as' passed as arguments." raise ValueError(msg) if as_ is not Undefined: - dct = {"as": as_, "timeUnit": timeUnit, "field": field} - self = self._add_transform(core.TimeUnitTransform(**dct)) # type: ignore[arg-type] + dct: dict[str, Any] = {"as": as_, "timeUnit": timeUnit, "field": field} + self = self._add_transform(core.TimeUnitTransform(**dct)) for as_, shorthand in kwargs.items(): dct = utils.parse_shorthand( shorthand, @@ -2487,7 +2497,7 @@ def transform_timeunit( if "timeUnit" not in dct: msg = f"'{shorthand}' must include a valid timeUnit" raise ValueError(msg) - self = self._add_transform(core.TimeUnitTransform(**dct)) # type: ignore[arg-type] + self = self._add_transform(core.TimeUnitTransform(**dct)) return self def transform_window( @@ -2566,11 +2576,10 @@ def transform_window( }) """ + w = window if isinstance(window, list) else [] if kwargs: - if window is Undefined: - window = [] for as_, shorthand in kwargs.items(): - kwds = {"as": as_} + kwds: dict[str, Any] = {"as": as_} kwds.update( utils.parse_shorthand( shorthand, @@ -2580,13 +2589,11 @@ def transform_window( parse_types=False, ) ) - assert isinstance(window, list) - # Ignore as core.WindowFieldDef has a Literal type hint with all options - window.append(core.WindowFieldDef(**kwds)) # type: ignore[arg-type] + w.append(core.WindowFieldDef(**kwds)) return self._add_transform( core.WindowTransform( - window=window, + window=w or Undefined, frame=frame, groupby=groupby, ignorePeers=ignorePeers, @@ -2606,7 +2613,8 @@ def _repr_mimebundle_(self, include=None, exclude=None): utils.display_traceback(in_ipython=True) return {} else: - return renderers.get()(dct) + if renderer := renderers.get(): + return renderer(dct) def display( self, @@ -2728,7 +2736,7 @@ def _set_resolve(self, **kwargs): if not hasattr(self, "resolve"): msg = f"{self.__class__} object has no attribute " "'resolve'" raise ValueError(msg) - copy = self.copy(deep=["resolve"]) + copy = _top_schema_base(self).copy(deep=["resolve"]) if copy.resolve is Undefined: copy.resolve = core.Resolve() for key, val in kwargs.items(): @@ -2736,19 +2744,27 @@ def _set_resolve(self, **kwargs): return copy @utils.use_signature(core.AxisResolveMap) - def resolve_axis(self, *args, **kwargs) -> Self: - return self._set_resolve(axis=core.AxisResolveMap(*args, **kwargs)) + def resolve_axis(self, *args, **kwargs): + return _top_schema_base(self)._set_resolve( + axis=core.AxisResolveMap(*args, **kwargs) + ) @utils.use_signature(core.LegendResolveMap) - def resolve_legend(self, *args, **kwargs) -> Self: - return self._set_resolve(legend=core.LegendResolveMap(*args, **kwargs)) + def resolve_legend(self, *args, **kwargs): + return _top_schema_base(self)._set_resolve( + legend=core.LegendResolveMap(*args, **kwargs) + ) @utils.use_signature(core.ScaleResolveMap) - def resolve_scale(self, *args, **kwargs) -> Self: - return self._set_resolve(scale=core.ScaleResolveMap(*args, **kwargs)) + def resolve_scale(self, *args, **kwargs): + return _top_schema_base(self)._set_resolve( + scale=core.ScaleResolveMap(*args, **kwargs) + ) class _EncodingMixin(channels._EncodingMixin): + data: Any + def facet( self, facet: Optional[str | Facet] = Undefined, @@ -2792,9 +2808,10 @@ def facet( if facet_specified and rowcol_specified: msg = "facet argument cannot be combined with row/column argument." raise ValueError(msg) + self = _top_schema_base(self) if data is Undefined: - if self.data is Undefined: # type: ignore[has-type] + if self.data is Undefined: msg = ( "Facet charts require data to be specified at the top level. " "If you are trying to facet layered or concatenated charts, " @@ -2802,17 +2819,16 @@ def facet( "or specify the data inside the facet method instead." ) raise ValueError(msg) - # ignore type as copy comes from another class - self = self.copy(deep=False) # type: ignore[attr-defined] - data, self.data = self.data, Undefined # type: ignore[has-type] + self = _top_schema_base(self).copy(deep=False) + data, self.data = self.data, Undefined if facet_specified: - if isinstance(facet, str): - facet = channels.Facet(facet) + f = channels.Facet(facet) if isinstance(facet, str) else facet else: - facet = FacetMapping(row=row, column=column) + r: Any = row + f = FacetMapping(row=r, column=column) - return FacetChart(spec=self, facet=facet, data=data, columns=columns, **kwargs) + return FacetChart(spec=self, facet=f, data=data, columns=columns, **kwargs) class Chart( @@ -2903,7 +2919,9 @@ def _get_name(cls) -> str: return f"view_{cls._counter}" @classmethod - def from_dict(cls, dct: dict, validate: bool = True) -> SchemaBase: # type: ignore[override] # Not the same signature as SchemaBase.from_dict. Would ideally be aligned in the future + def from_dict( + cls: type[_TSchemaBase], dct: dict[str, Any], validate: bool = True, **kwds: Any + ) -> _TSchemaBase: """Construct class from a dictionary representation Parameters @@ -2923,19 +2941,16 @@ def from_dict(cls, dct: dict, validate: bool = True) -> SchemaBase: # type: ign jsonschema.ValidationError : if validate=True and dct does not conform to the schema """ - for class_ in TopLevelMixin.__subclasses__(): - if class_ is Chart: - class_ = cast(Any, super()) + _tp: Any + for tp in TopLevelMixin.__subclasses__(): + _tp = super() if tp is Chart else tp try: - # TopLevelMixin classes don't necessarily have from_dict defined - # but all classes which are used here have due to how Altair is - # designed. Too complex to type check right now. - return class_.from_dict(dct, validate=validate) # type: ignore[attr-defined] + return _tp.from_dict(dct, validate=validate) except jsonschema.ValidationError: pass # As a last resort, try using the Root vegalite object - return core.Root.from_dict(dct, validate) + return typing.cast(_TSchemaBase, core.Root.from_dict(dct, validate)) def to_dict( self, @@ -2981,7 +2996,7 @@ def to_dict( # No data specified here or in parent: inject empty data # for easier specification of datum encodings. copy = self.copy(deep=False) - copy.data = core.InlineData(values=[{}]) # type: ignore[assignment] + copy.data = core.InlineData(values=[{}]) return super(Chart, copy).to_dict( validate=validate, format=format, ignore=ignore, context=context ) @@ -3061,7 +3076,7 @@ def interactive( return self.add_params(selection_interval(bind="scales", encodings=encodings)) -def _check_if_valid_subspec(spec: dict | SchemaBase, classname: str) -> None: +def _check_if_valid_subspec(spec: Any, classname: str) -> None: """Check if the spec is a valid sub-spec. If it is not, then raise a ValueError @@ -3092,11 +3107,16 @@ def _get(spec, attr): else: return spec.get(attr, Undefined) + def _get_any(spec: dict | SchemaBase, *attrs: str) -> bool: + return any(_get(spec, attr) is not Undefined for attr in attrs) + + base_msg = "charts cannot be layered. Instead, layer the charts before" + encoding = _get(spec, "encoding") if encoding is not Undefined: for channel in ["row", "column", "facet"]: if _get(encoding, channel) is not Undefined: - msg = "Faceted charts cannot be layered. Instead, layer the charts before faceting." + msg = f"Faceted {base_msg} faceting." raise ValueError(msg) if isinstance(spec, (Chart, LayerChart)): return @@ -3104,23 +3124,15 @@ def _get(spec, attr): if not isinstance(spec, (core.SchemaBase, dict)): msg = "Only chart objects can be layered." raise ValueError(msg) - if _get(spec, "facet") is not Undefined: - msg = "Faceted charts cannot be layered. Instead, layer the charts before faceting." - raise ValueError(msg) if isinstance(spec, FacetChart) or _get(spec, "facet") is not Undefined: - msg = "Faceted charts cannot be layered. Instead, layer the charts before faceting." + msg = f"Faceted {base_msg} faceting." raise ValueError(msg) if isinstance(spec, RepeatChart) or _get(spec, "repeat") is not Undefined: - msg = "Repeat charts cannot be layered. Instead, layer the charts before repeating." - raise ValueError(msg) - if isinstance(spec, ConcatChart) or _get(spec, "concat") is not Undefined: - msg = "Concatenated charts cannot be layered. Instead, layer the charts before concatenating." - raise ValueError(msg) - if isinstance(spec, HConcatChart) or _get(spec, "hconcat") is not Undefined: - msg = "Concatenated charts cannot be layered. Instead, layer the charts before concatenating." + msg = f"Repeat {base_msg} repeating." raise ValueError(msg) - if isinstance(spec, VConcatChart) or _get(spec, "vconcat") is not Undefined: - msg = "Concatenated charts cannot be layered. Instead, layer the charts before concatenating." + _concat = ConcatChart, HConcatChart, VConcatChart + if isinstance(spec, _concat) or _get_any(spec, "concat", "hconcat", "vconcat"): + msg = f"Concatenated {base_msg} concatenating." raise ValueError(msg) @@ -3284,16 +3296,14 @@ def __init__(self, data=Undefined, concat=(), columns=Undefined, **kwargs): self.data, self.concat = _combine_subchart_data(self.data, self.concat) self.params, self.concat = _combine_subchart_params(self.params, self.concat) - # Too difficult to fix override error - def __ior__(self, other: core.NonNormalizedSpec) -> Self: # type: ignore[override] + def __ior__(self, other) -> Self: _check_if_valid_subspec(other, "ConcatChart") self.concat.append(other) self.data, self.concat = _combine_subchart_data(self.data, self.concat) self.params, self.concat = _combine_subchart_params(self.params, self.concat) return self - # Too difficult to fix override error - def __or__(self, other: NonNormalizedSpec) -> Self: # type: ignore[override] + def __or__(self, other) -> Self: copy = self.copy(deep=["concat"]) copy |= other return copy @@ -3368,7 +3378,7 @@ def add_selection(self, *selections) -> Self: def concat(*charts, **kwargs) -> ConcatChart: """Concatenate charts horizontally""" - return ConcatChart(concat=charts, **kwargs) + return ConcatChart(concat=charts, **kwargs) # pyright: ignore class HConcatChart(TopLevelMixin, core.TopLevelHConcatSpec): @@ -3383,14 +3393,14 @@ def __init__(self, data=Undefined, hconcat=(), **kwargs): self.data, self.hconcat = _combine_subchart_data(self.data, self.hconcat) self.params, self.hconcat = _combine_subchart_params(self.params, self.hconcat) - def __ior__(self, other: NonNormalizedSpec) -> Self: + def __ior__(self, other) -> Self: _check_if_valid_subspec(other, "HConcatChart") self.hconcat.append(other) self.data, self.hconcat = _combine_subchart_data(self.data, self.hconcat) self.params, self.hconcat = _combine_subchart_params(self.params, self.hconcat) return self - def __or__(self, other: NonNormalizedSpec) -> Self: + def __or__(self, other) -> Self: copy = self.copy(deep=["hconcat"]) copy |= other return copy @@ -3465,7 +3475,7 @@ def add_selection(self, *selections) -> Self: def hconcat(*charts, **kwargs) -> HConcatChart: """Concatenate charts horizontally""" - return HConcatChart(hconcat=charts, **kwargs) + return HConcatChart(hconcat=charts, **kwargs) # pyright: ignore class VConcatChart(TopLevelMixin, core.TopLevelVConcatSpec): @@ -3480,14 +3490,14 @@ def __init__(self, data=Undefined, vconcat=(), **kwargs): self.data, self.vconcat = _combine_subchart_data(self.data, self.vconcat) self.params, self.vconcat = _combine_subchart_params(self.params, self.vconcat) - def __iand__(self, other: NonNormalizedSpec) -> Self: + def __iand__(self, other) -> Self: _check_if_valid_subspec(other, "VConcatChart") self.vconcat.append(other) self.data, self.vconcat = _combine_subchart_data(self.data, self.vconcat) self.params, self.vconcat = _combine_subchart_params(self.params, self.vconcat) return self - def __and__(self, other: NonNormalizedSpec) -> Self: + def __and__(self, other) -> Self: copy = self.copy(deep=["vconcat"]) copy &= other return copy @@ -3564,7 +3574,7 @@ def add_selection(self, *selections) -> Self: def vconcat(*charts, **kwargs) -> VConcatChart: """Concatenate charts vertically""" - return VConcatChart(vconcat=charts, **kwargs) + return VConcatChart(vconcat=charts, **kwargs) # pyright: ignore class LayerChart(TopLevelMixin, _EncodingMixin, core.TopLevelLayerSpec): @@ -3683,7 +3693,7 @@ def add_selection(self, *selections) -> Self: def layer(*charts, **kwargs) -> LayerChart: """layer multiple charts""" - return LayerChart(layer=charts, **kwargs) + return LayerChart(layer=charts, **kwargs) # pyright: ignore class FacetChart(TopLevelMixin, core.TopLevelFacetSpec): @@ -3832,7 +3842,7 @@ def _needs_name(subchart): # Convert SelectionParameters to TopLevelSelectionParameters with a views property. -def _prepare_to_lift(param): +def _prepare_to_lift(param: Any) -> Any: param = param.copy() if isinstance(param, core.VariableParameter): diff --git a/pyproject.toml b/pyproject.toml index e0b31c1ac..a4039604f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -339,6 +339,7 @@ addopts = ["--numprocesses=logical"] [tool.mypy] warn_unused_ignores = true +pretty = true [[tool.mypy.overrides]] module = [ diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py index 9c5081e2c..d62f8bb6d 100644 --- a/tools/schemapi/schemapi.py +++ b/tools/schemapi/schemapi.py @@ -1069,10 +1069,7 @@ def _default_wrapper_classes(cls) -> Iterator[type[SchemaBase]]: @classmethod def from_dict( - cls: type[TSchemaBase], - dct: dict[str, Any], - validate: bool = True, - _wrapper_classes: Iterable[type[SchemaBase]] | None = None, + cls: type[TSchemaBase], dct: dict[str, Any], validate: bool = True, **kwds: Any ) -> TSchemaBase: """Construct class from a dictionary representation @@ -1099,9 +1096,9 @@ def from_dict( """ if validate: cls.validate(dct) - if _wrapper_classes is None: - _wrapper_classes = cls._default_wrapper_classes() - converter = _FromDict(_wrapper_classes) + converter = _FromDict( + kwds.pop("_wrapper_classes", cls._default_wrapper_classes()) + ) return converter.from_dict(dct, cls) @classmethod