FIX-#7383: Avoid broadcast issue in partition manager with custom NPartitions (#7399)

Signed-off-by: Anatoly Myachev <[email protected]>
anmyachev authored Sep 20, 2024
1 parent 6cf3ca2 commit 7867400
Showing 2 changed files with 12 additions and 7 deletions.
14 changes: 9 additions & 5 deletions modin/core/dataframe/pandas/partitioning/partition_manager.py
@@ -915,18 +915,22 @@ def map_partitions_joined_by_column(
         # step cannot be less than 1
         step = max(partitions.shape[0] // column_splits, 1)
         preprocessed_map_func = cls.preprocess_func(map_func)
-        kw = {
-            "num_splits": step,
-        }
         result = np.empty(partitions.shape, dtype=object)
         for i in range(
             0,
             partitions.shape[0],
             step,
         ):
-            joined_column_partitions = cls.column_partitions(partitions[i : i + step])
+            partitions_subset = partitions[i : i + step]
+            # This is necessary when ``partitions.shape[0]`` is not divisible
+            # by `column_splits` without a remainder.
+            actual_step = len(partitions_subset)
+            kw = {
+                "num_splits": actual_step,
+            }
+            joined_column_partitions = cls.column_partitions(partitions_subset)
             for j in range(partitions.shape[1]):
-                result[i : i + step, j] = joined_column_partitions[j].apply(
+                result[i : i + actual_step, j] = joined_column_partitions[j].apply(
                     preprocessed_map_func,
                     *map_func_args if map_func_args is not None else (),
                     **kw,
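
For context, the broadcast issue arises because the final chunk of row partitions can be shorter than `step` whenever `partitions.shape[0]` is not divisible by `column_splits`. The sketch below is illustrative only (plain Python arithmetic, not Modin's partition API); the values 7 row partitions and 2 column splits are hypothetical and simply walk through the chunk sizes the patched loop would see.

    # Illustrative sketch only: mimics the chunking arithmetic from the patched
    # loop, not Modin's partition objects. 7 row partitions and 2 column splits
    # are hypothetical values chosen to leave a remainder.
    num_row_partitions = 7
    column_splits = 2

    # step cannot be less than 1 (same guard as in the patch)
    step = max(num_row_partitions // column_splits, 1)  # -> 3

    for i in range(0, num_row_partitions, step):
        chunk = list(range(num_row_partitions))[i : i + step]
        actual_step = len(chunk)  # 3, 3, then 1 for the trailing remainder
        # Before the fix, num_splits was always `step` (3); assigning 3 split
        # results into a result slice of length 1 is the kind of shape mismatch
        # that triggered the broadcast error this commit fixes.
        print(i, actual_step)
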
5 changes: 3 additions & 2 deletions modin/tests/core/storage_formats/pandas/test_internals.py
@@ -2677,8 +2677,9 @@ def test_dynamic_partitioning(partitioning_scheme, expected_map_approach):
         expected_method.assert_called()
 
 
-def test_map_partitions_joined_by_column():
-    with context(NPartitions=CpuCount.get() * 2):
+@pytest.mark.parametrize("npartitions", [7, CpuCount.get() * 2])
+def test_map_partitions_joined_by_column(npartitions):
+    with context(NPartitions=npartitions):
         ncols = MinColumnPartitionSize.get()
         nrows = MinRowPartitionSize.get() * CpuCount.get() * 2
         data = {f"col{i}": np.ones(nrows) for i in range(ncols)}
