Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-238 Add Support for nested ObjectIDs in polars conversion #220

Merged
merged 6 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 25 additions & 23 deletions bindings/python/pymongoarrow/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,36 +295,38 @@ def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs):
)


def _cast_away_extension_types_on_array(array: pa.Array) -> pa.Array:
    """Return an Array where ExtensionTypes have been cast to their base pyarrow types.

    Extension types are replaced by their storage types both at the top level
    and when they appear inside struct fields or list values. An array whose
    type contains no extension types is returned unchanged (same object).

    :param array: The Arrow array to convert.
    :return: The array cast to an extension-free type, or ``array`` itself
        when no cast is required.
    """

    def _storage_type(data_type: pa.DataType) -> pa.DataType:
        # Resolve ``data_type`` to an equivalent type that is free of ExtensionTypes.
        if isinstance(data_type, pa.ExtensionType):
            # Recurse in case the storage type itself nests further extension types.
            return _storage_type(data_type.storage_type)
        if pa.types.is_struct(data_type):
            # with_type keeps each child's name, nullability and metadata.
            return pa.struct(
                [child.with_type(_storage_type(child.type)) for child in data_type]
            )
        if pa.types.is_list(data_type):
            value_field = data_type.value_field
            return pa.list_(value_field.with_type(_storage_type(value_field.type)))
        return data_type

    target_type = _storage_type(array.type)
    if target_type == array.type:
        # Nothing to strip; avoid a needless cast.
        return array
    return array.cast(target_type)


def _cast_away_extension_types_on_table(table: pa.Table) -> pa.Table:
    """Given an Arrow table that may contain ExtensionTypes, cast these to the base pyarrow types.

    Every column is passed through ``_cast_away_extension_types_on_array`` and
    a new table is assembled from the results.

    :param table: The Arrow table to convert.
    :return: A new table with the same column names, in the same order, built
        from the converted columns.
    """
    # Convert each column directly instead of indexing by position.
    converted_columns = [
        _cast_away_extension_types_on_array(column) for column in table.columns
    ]
    # Reconstruct the Arrow table under the original column names.
    return pa.Table.from_arrays(converted_columns, names=table.column_names)


def _arrow_to_polars(arrow_table):
def _cast_away_extension_type(field: pa.Field) -> pa.Field:
    """Return a copy of ``field`` whose type contains no ExtensionTypes.

    An extension type is replaced by its storage type. Struct and list types
    are rebuilt recursively so that extension types nested inside them are
    replaced as well. Any other field is returned unchanged.

    The field's name, nullability and metadata are carried over to the result.

    :param field: The Arrow field to convert.
    :return: The equivalent field without extension types.
    """
    field_type = field.type

    if isinstance(field_type, pa.ExtensionType):
        return field.with_type(field_type.storage_type)

    if isinstance(field_type, pa.StructType):
        # Iterating a StructType yields its child fields.
        children = [_cast_away_extension_type(nested_field) for nested_field in field_type]
        return field.with_type(pa.struct(children))

    if isinstance(field_type, pa.ListType):
        # Convert the list's value field so element-level extension types are removed too.
        return field.with_type(pa.list_(_cast_away_extension_type(field_type.value_field)))

    return field


def _arrow_to_polars(arrow_table: pa.Table):
    """Helper function that converts an Arrow Table to a Polars DataFrame.

    Note: Polars lacks ExtensionTypes. We cast them to their base arrow classes.

    :param arrow_table: The Arrow table to convert.
    :return: The result of ``pl.from_arrow`` on the extension-free table.
    :raises ValueError: If polars is not installed (``pl`` is ``None``).
    """
    if pl is None:
        msg = "polars is not installed. Try pip install polars."
        raise ValueError(msg)

    # Derive an extension-free schema (including nested struct/list fields),
    # then cast the whole table once. The table is cast only here; no separate
    # per-column pre-cast is performed, since its result would be discarded.
    schema_without_extensions = pa.schema(
        [_cast_away_extension_type(field) for field in arrow_table.schema]
    )
    arrow_table_without_extensions = arrow_table.cast(schema_without_extensions)

    return pl.from_arrow(arrow_table_without_extensions)


Expand Down
8 changes: 5 additions & 3 deletions bindings/python/test/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def test_arrow_to_polars(self):
"str": [str(i) for i in range(2)],
"int": [i for i in range(2)],
"bool": [True, False],
"struct": [{"objId": bson.ObjectId().binary, "str1": str(i)} for i in range(2)],
"list": [[str(i), str(i + 1)] for i in range(2)],
"Binary": [b"1", b"23"],
"ObjectId": [bson.ObjectId().binary, bson.ObjectId().binary],
"Decimal128": [bson.Decimal128(str(i)).bid for i in range(2)],
Expand All @@ -241,9 +243,9 @@ def test_arrow_to_polars(self):
self.assertEqual(len(arrow_table_in), res.raw_result["insertedCount"])
df_out = find_polars_all(self.coll, query={}, schema=Schema(arrow_schema))

# Sanity check: compare with cast_away_extension_types_on_table
arrow_cast = api._cast_away_extension_types_on_table(arrow_table_in)
assert_frame_equal(df_out, pl.from_arrow(arrow_cast))
# Sanity check: compare with _arrow_to_polars
df_actual_output = api._arrow_to_polars(arrow_table_in)
assert_frame_equal(df_out, df_actual_output)

def test_exceptions_for_unsupported_polar_types(self):
"""Confirm exceptions thrown are expected.
Expand Down
Loading