Skip to content

Commit

Permalink
ARROW-241 Allow list in Schema as an alias for pa.list_ (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
lazargugleta authored Jun 28, 2024
1 parent 1c5cdd2 commit ce4c1d1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
5 changes: 4 additions & 1 deletion bindings/python/pymongoarrow/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,10 @@ def _normalize_typeid(typeid, field_name):
fields.append((sub_field_name, _normalize_typeid(sub_typeid, sub_field_name)))
return struct(fields)
if isinstance(typeid, list):
return list_(_normalize_typeid(type(typeid[0]), "0"))
if len(typeid) != 1:
msg = f"list field in schema must contain exactly one element, not {len(typeid)}"
raise ValueError(msg)
return list_(_normalize_typeid(typeid[0], "0"))
if _is_typeid_supported(typeid):
normalizer = _TYPE_NORMALIZER_FACTORY[typeid]
return normalizer(typeid)
Expand Down
20 changes: 20 additions & 0 deletions bindings/python/test/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from datetime import datetime
from unittest import TestCase

import pytest
from bson import Binary, Code, Decimal128, Int64, ObjectId
from pyarrow import Table, field, float64, int64, list_, struct, timestamp
from pyarrow import schema as ArrowSchema
Expand Down Expand Up @@ -94,3 +95,22 @@ def test_list_of_list_projection(self):
}
)
self.assertEqual(schema._get_projection(), {"_id": True, "list": {"a": True, "b": True}})

def test_py_list_projection(self):
schema = Schema(
{"_id": ObjectId, "list": [(struct([field("a", int64()), field("b", float64())]))]}
)

self.assertEqual(schema._get_projection(), {"_id": True, "list": {"a": True, "b": True}})

def test_py_list_with_multiple_fields_raises(self):
with pytest.raises(
ValueError, match="list field in schema must contain exactly one element, not 2"
):
_ = Schema({"_id": ObjectId, "list": [([field("a", int64()), field("b", float64())])]})

def test_py_empty_list_raises(self):
with pytest.raises(
ValueError, match="list field in schema must contain exactly one element, not 0"
):
_ = Schema({"_id": ObjectId, "list": []})

0 comments on commit ce4c1d1

Please sign in to comment.