Skip to content

Commit

Permalink
Merge branch 'main' into reduce-api-type-ignore
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned committed Jul 15, 2024
2 parents 6adf564 + 2036b99 commit 1fb6f1f
Show file tree
Hide file tree
Showing 17 changed files with 365 additions and 233 deletions.
4 changes: 2 additions & 2 deletions altair/_magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import IPython
from IPython.core import magic_arguments
import pandas as pd
from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe

from altair.vegalite import v5 as vegalite_v5

Expand Down Expand Up @@ -39,7 +39,7 @@ def _prepare_data(data, data_transformers):
"""Convert input data to data for use within schema"""
if data is None or isinstance(data, dict):
return data
elif isinstance(data, pd.DataFrame):
elif _is_pandas_dataframe(data):
if func := data_transformers.get():
data = func(data)
return data
Expand Down
12 changes: 6 additions & 6 deletions altair/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .core import (
infer_vegalite_type,
infer_vegalite_type_for_pandas,
infer_encoding_types,
sanitize_dataframe,
sanitize_arrow_table,
sanitize_pandas_dataframe,
sanitize_narwhals_dataframe,
parse_shorthand,
use_signature,
update_nested,
Expand All @@ -23,10 +23,10 @@
"Undefined",
"display_traceback",
"infer_encoding_types",
"infer_vegalite_type",
"infer_vegalite_type_for_pandas",
"parse_shorthand",
"sanitize_arrow_table",
"sanitize_dataframe",
"sanitize_narwhals_dataframe",
"sanitize_pandas_dataframe",
"spec_to_html",
"update_nested",
"use_signature",
Expand Down
13 changes: 9 additions & 4 deletions altair/utils/_vegafusion_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@
Callable,
)

import narwhals.stable.v1 as nw

from altair.utils._importers import import_vegafusion
from altair.utils.core import DataFrameLike
from altair.utils.data import (
DataType,
ToValuesReturnType,
MaxRowsError,
SupportsGeoInterface,
)
from altair.utils.core import DataFrameLike
from altair.vegalite.data import default_data_transformer


if TYPE_CHECKING:
import pandas as pd
from narwhals.typing import IntoDataFrame
from vegafusion.runtime import ChartState # type: ignore

# Temporary storage for dataframes that have been extracted
Expand Down Expand Up @@ -60,14 +61,18 @@ def vegafusion_data_transformer(

@overload
def vegafusion_data_transformer(
data: dict | pd.DataFrame | SupportsGeoInterface, max_rows: int = ...
data: dict | IntoDataFrame | SupportsGeoInterface, max_rows: int = ...
) -> _VegaFusionReturnType: ...


def vegafusion_data_transformer(
data: DataType | None = None, max_rows: int = 100000
) -> Callable[..., Any] | _VegaFusionReturnType:
"""VegaFusion Data Transformer"""
# Vegafusion does not support Narwhals, so if `data` is a Narwhals
# object, we make sure to extract the native object and let Vegafusion handle it.
# `strict=False` passes `data` through as-is if it is not a Narwhals object.
data = nw.to_native(data, strict=False)
if data is None:
return vegafusion_data_transformer
elif isinstance(data, DataFrameLike) and not isinstance(data, SupportsGeoInterface):
Expand Down
180 changes: 90 additions & 90 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,11 @@
from operator import itemgetter

import jsonschema
import pandas as pd
import numpy as np
from pandas.api.types import infer_dtype
import narwhals.stable.v1 as nw
from narwhals.dependencies import is_pandas_dataframe, get_polars
from narwhals.typing import IntoDataFrame

from altair.utils.schemapi import SchemaBase, Undefined
from altair.utils._dfi_types import Column, DtypeKind, DataFrame as DfiDataFrame

if sys.version_info >= (3, 10):
from typing import ParamSpec
Expand All @@ -43,11 +42,14 @@
if TYPE_CHECKING:
from types import ModuleType
import typing as t
from pandas.core.interchange.dataframe_protocol import Column as PandasColumn
import pyarrow as pa
from altair.vegalite.v5.schema._typing import StandardType_T as InferredVegaLiteType
from altair.utils._dfi_types import DataFrame as DfiDataFrame
from narwhals.typing import IntoExpr
import pandas as pd

V = TypeVar("V")
P = ParamSpec("P")
TIntoDataFrame = TypeVar("TIntoDataFrame", bound=IntoDataFrame)


@runtime_checkable
Expand Down Expand Up @@ -198,10 +200,7 @@ def __dataframe__(
]


InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"]


def infer_vegalite_type(
def infer_vegalite_type_for_pandas(
data: object,
) -> InferredVegaLiteType | tuple[InferredVegaLiteType, list[Any]]:
"""
Expand All @@ -212,6 +211,9 @@ def infer_vegalite_type(
----------
data: object
"""
# This is safe to import here, as this function is only called on pandas input.
from pandas.api.types import infer_dtype

typ = infer_dtype(data, skipna=False)

if typ in {
Expand Down Expand Up @@ -297,13 +299,16 @@ def sanitize_geo_interface(geo: t.MutableMapping[Any, Any]) -> dict[str, Any]:


def numpy_is_subtype(dtype: Any, subtype: Any) -> bool:
# This is only called on `numpy` inputs, so it's safe to import it here.
import numpy as np

try:
return np.issubdtype(dtype, subtype)
except (NotImplementedError, TypeError):
return False


def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame:
def sanitize_pandas_dataframe(df: pd.DataFrame) -> pd.DataFrame:
"""Sanitize a DataFrame to prepare it for serialization.
* Make a copy
Expand All @@ -320,6 +325,11 @@ def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame:
* convert dedicated string column to objects and replace NaN with None
* Raise a ValueError for TimeDelta dtypes
"""
# This is safe to import here, as this function is only called on pandas input.
# NumPy is a required dependency of pandas so is also safe to import.
import pandas as pd
import numpy as np

df = df.copy()

if isinstance(df.columns, pd.RangeIndex):
Expand Down Expand Up @@ -429,30 +439,54 @@ def to_list_if_array(val):
return df


def sanitize_arrow_table(pa_table: pa.Table) -> pa.Table:
"""Sanitize arrow table for JSON serialization"""
import pyarrow as pa
import pyarrow.compute as pc

arrays = []
schema = pa_table.schema
for name in schema.names:
array = pa_table[name]
dtype_name = str(schema.field(name).type)
if dtype_name.startswith(("timestamp", "date")):
arrays.append(pc.strftime(array))
elif dtype_name.startswith("duration"):
def sanitize_narwhals_dataframe(
data: nw.DataFrame[TIntoDataFrame],
) -> nw.DataFrame[TIntoDataFrame]:
"""Sanitize narwhals.DataFrame for JSON serialization"""
schema = data.schema
columns: list[IntoExpr] = []
# See https://github.com/vega/altair/issues/1027 for why this is necessary.
local_iso_fmt_string = "%Y-%m-%dT%H:%M:%S"
for name, dtype in schema.items():
if dtype == nw.Date and nw.get_native_namespace(data) is get_polars():
# Polars doesn't allow formatting `Date` with time directives.
# The date -> datetime cast is extremely fast compared with `to_string`
columns.append(
nw.col(name).cast(nw.Datetime).dt.to_string(local_iso_fmt_string)
)
elif dtype == nw.Date:
columns.append(nw.col(name).dt.to_string(local_iso_fmt_string))
elif dtype == nw.Datetime:
columns.append(nw.col(name).dt.to_string(f"{local_iso_fmt_string}%.f"))
elif dtype == nw.Duration:
msg = (
f'Field "{name}" has type "{dtype_name}" which is '
f'Field "{name}" has type "{dtype}" which is '
"not supported by Altair. Please convert to "
"either a timestamp or a numerical value."
""
)
raise ValueError(msg)
else:
arrays.append(array)
columns.append(name)
return data.select(columns)


return pa.Table.from_arrays(arrays, names=schema.names)
def to_eager_narwhals_dataframe(data: IntoDataFrame) -> nw.DataFrame[Any]:
"""Wrap `data` in `narwhals.DataFrame`.
If `data` is not supported by Narwhals, but it is convertible
to a PyArrow table, then first convert to a PyArrow Table,
and then wrap in `narwhals.DataFrame`.
"""
data_nw = nw.from_native(data, eager_or_interchange_only=True)
if nw.get_level(data_nw) == "interchange":
# If Narwhals' support for `data`'s class is only metadata-level, then we
# use the interchange protocol to convert to a PyArrow Table.
from altair.utils.data import arrow_table_from_dfi_dataframe

pa_table = arrow_table_from_dfi_dataframe(data) # type: ignore[arg-type]
data_nw = nw.from_native(pa_table, eager_only=True)
return data_nw


def parse_shorthand(
Expand Down Expand Up @@ -498,6 +532,7 @@ def parse_shorthand(
Examples
--------
>>> import pandas as pd
>>> data = pd.DataFrame({'foo': ['A', 'B', 'A', 'B'],
... 'bar': [1, 2, 3, 4]})
Expand Down Expand Up @@ -537,7 +572,7 @@ def parse_shorthand(
>>> parse_shorthand('count()', data) == {'aggregate': 'count', 'type': 'quantitative'}
True
"""
from altair.utils._importers import pyarrow_available
from altair.utils.data import is_data_type

if not shorthand:
return {}
Expand Down Expand Up @@ -597,39 +632,22 @@ def parse_shorthand(
attrs["type"] = "temporal"

# if data is specified and type is not, infer type from data
if "type" not in attrs:
if pyarrow_available() and data is not None and isinstance(data, DataFrameLike):
dfi = data.__dataframe__()
if "field" in attrs:
unescaped_field = attrs["field"].replace("\\", "")
if unescaped_field in dfi.column_names():
column = dfi.get_column_by_name(unescaped_field)
try:
attrs["type"] = infer_vegalite_type_for_dfi_column(column)
except (NotImplementedError, AttributeError, ValueError):
# Fall back to pandas-based inference.
# Note: The AttributeError catch is a workaround for
# https://github.com/pandas-dev/pandas/issues/55332
if isinstance(data, pd.DataFrame):
attrs["type"] = infer_vegalite_type(data[unescaped_field])
else:
raise

if isinstance(attrs["type"], tuple):
attrs["sort"] = attrs["type"][1]
attrs["type"] = attrs["type"][0]
elif isinstance(data, pd.DataFrame):
# Fallback if pyarrow is not installed or if pandas is older than 1.5
#
# Remove escape sequences so that types can be inferred for columns with special characters
if "field" in attrs and attrs["field"].replace("\\", "") in data.columns:
attrs["type"] = infer_vegalite_type(
data[attrs["field"].replace("\\", "")]
)
# ordered categorical dataframe columns return the type and sort order as a tuple
if isinstance(attrs["type"], tuple):
attrs["sort"] = attrs["type"][1]
attrs["type"] = attrs["type"][0]
if "type" not in attrs and is_data_type(data):
unescaped_field = attrs["field"].replace("\\", "")
data_nw = nw.from_native(data, eager_or_interchange_only=True)
schema = data_nw.schema
if unescaped_field in schema:
column = data_nw[unescaped_field]
if schema[unescaped_field] in {
nw.Object,
nw.Unknown,
} and is_pandas_dataframe(nw.to_native(data_nw)):
attrs["type"] = infer_vegalite_type_for_pandas(nw.to_native(column))
else:
attrs["type"] = infer_vegalite_type_for_narwhals(column)
if isinstance(attrs["type"], tuple):
attrs["sort"] = attrs["type"][1]
attrs["type"] = attrs["type"][0]

# If an unescaped colon is still present, it's often due to an incorrect data type specification
# but could also be due to using a column name with ":" in it.
Expand All @@ -650,41 +668,23 @@ def parse_shorthand(
return attrs


def infer_vegalite_type_for_dfi_column(
column: Column | PandasColumn,
) -> InferredVegaLiteType | tuple[InferredVegaLiteType, list[Any]]:
from pyarrow.interchange.from_dataframe import column_to_array

try:
kind = column.dtype[0]
except NotImplementedError as e:
# Edge case hack:
# dtype access fails for pandas column with datetime64[ns, UTC] type,
# but all we need to know is that its temporal, so check the
# error message for the presence of datetime64.
#
# See https://github.com/pandas-dev/pandas/issues/54239
if "datetime64" in e.args[0] or "timestamp" in e.args[0]:
return "temporal"
raise e

def infer_vegalite_type_for_narwhals(
column: nw.Series,
) -> InferredVegaLiteType | tuple[InferredVegaLiteType, list]:
dtype = column.dtype
if (
kind == DtypeKind.CATEGORICAL
and column.describe_categorical["is_ordered"]
and column.describe_categorical["categories"] is not None
nw.is_ordered_categorical(column)
and not (categories := column.cat.get_categories()).is_empty()
):
# Treat ordered categorical column as Vega-Lite ordinal
categories_column = column.describe_categorical["categories"]
categories_array = column_to_array(categories_column)
return "ordinal", categories_array.to_pylist()
if kind in {DtypeKind.STRING, DtypeKind.CATEGORICAL, DtypeKind.BOOL}:
return "ordinal", categories.to_list()
if dtype in {nw.String, nw.Categorical, nw.Boolean}:
return "nominal"
elif kind in {DtypeKind.INT, DtypeKind.UINT, DtypeKind.FLOAT}:
elif dtype.is_numeric():
return "quantitative"
elif kind == DtypeKind.DATETIME:
elif dtype in {nw.Datetime, nw.Date}:
return "temporal"
else:
msg = f"Unexpected DtypeKind: {kind}"
msg = f"Unexpected DtypeKind: {dtype}"
raise ValueError(msg)


Expand Down
Loading

0 comments on commit 1fb6f1f

Please sign in to comment.