Skip to content

Commit

Permalink
BUG: sparse: Fix 1D specialty hstack codes (scipy#21400)
Browse files Browse the repository at this point in the history
* add more hstack/vstack tests to cover coo and csr special cases

* change shape to _shape_as_2d where needed in specialty block() functions.
  • Loading branch information
dschult authored Aug 15, 2024
1 parent 4710732 commit ed184ff
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
16 changes: 8 additions & 8 deletions scipy/sparse/_construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,23 +602,23 @@ def _compressed_sparse_stack(blocks, axis, return_spmatrix):
"""
other_axis = 1 if axis == 0 else 0
data = np.concatenate([b.data for b in blocks])
constant_dim = blocks[0].shape[other_axis]
constant_dim = blocks[0]._shape_as_2d[other_axis]
idx_dtype = get_index_dtype(arrays=[b.indptr for b in blocks],
maxval=max(data.size, constant_dim))
indices = np.empty(data.size, dtype=idx_dtype)
indptr = np.empty(sum(b.shape[axis] for b in blocks) + 1, dtype=idx_dtype)
indptr = np.empty(sum(b._shape_as_2d[axis] for b in blocks) + 1, dtype=idx_dtype)
last_indptr = idx_dtype(0)
sum_dim = 0
sum_indices = 0
for b in blocks:
if b.shape[other_axis] != constant_dim:
if b._shape_as_2d[other_axis] != constant_dim:
raise ValueError(f'incompatible dimensions for axis {other_axis}')
indices[sum_indices:sum_indices+b.indices.size] = b.indices
sum_indices += b.indices.size
idxs = slice(sum_dim, sum_dim + b.shape[axis])
idxs = slice(sum_dim, sum_dim + b._shape_as_2d[axis])
indptr[idxs] = b.indptr[:-1]
indptr[idxs] += last_indptr
sum_dim += b.shape[axis]
sum_dim += b._shape_as_2d[axis]
last_indptr += b.indptr[-1]
indptr[-1] = last_indptr
# TODO remove this if-structure when sparse matrices removed
Expand Down Expand Up @@ -652,7 +652,7 @@ def _stack_along_minor_axis(blocks, axis):

# check for incompatible dimensions
other_axis = 1 if axis == 0 else 0
other_axis_dims = {b.shape[other_axis] for b in blocks}
other_axis_dims = {b._shape_as_2d[other_axis] for b in blocks}
if len(other_axis_dims) > 1:
raise ValueError(f'Mismatching dimensions along axis {other_axis}: '
f'{other_axis_dims}')
Expand All @@ -668,10 +668,10 @@ def _stack_along_minor_axis(blocks, axis):
# - The max value in indptr is the number of non-zero entries. This is
# exceedingly unlikely to require int64, but is checked out of an
# abundance of caution.
sum_dim = sum(b.shape[axis] for b in blocks)
sum_dim = sum(b._shape_as_2d[axis] for b in blocks)
nnz = sum(len(b.indices) for b in blocks)
idx_dtype = get_index_dtype(maxval=max(sum_dim - 1, nnz))
stack_dim_cat = np.array([b.shape[axis] for b in blocks], dtype=idx_dtype)
stack_dim_cat = np.array([b._shape_as_2d[axis] for b in blocks], dtype=idx_dtype)
if data_cat.size > 0:
indptr_cat = np.concatenate(indptr_list).astype(idx_dtype)
indices_cat = (np.concatenate([b.indices for b in blocks])
Expand Down
12 changes: 11 additions & 1 deletion scipy/sparse/tests/test_construct.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,21 @@ def test_vstack_matrix_or_array(self):
def test_vstack_1d_with_2d(self):
# fixes gh-21064
arr = csr_array([[1, 0, 0], [0, 1, 0]])
arr1d = arr[0]
arr1d = csr_array([1, 0, 0])
arr1dcoo = coo_array([1, 0, 0])
assert construct.vstack([arr, np.array([0, 0, 0])]).shape == (3, 3)
assert construct.hstack([arr1d, np.array([[0]])]).shape == (1, 4)
assert construct.hstack([arr1d, arr1d]).shape == (1, 6)
assert construct.vstack([arr1d, arr1d]).shape == (2, 3)

# check csr specialty stacking code like _stack_along_minor_axis
assert construct.hstack([arr, arr]).shape == (2, 6)
assert construct.hstack([arr1d, arr1d]).shape == (1, 6)

assert construct.hstack([arr1d, arr1dcoo]).shape == (1, 6)
assert construct.vstack([arr, arr1dcoo]).shape == (3, 3)
assert construct.vstack([arr1d, arr1dcoo]).shape == (2, 3)

with pytest.raises(ValueError, match="incompatible row dimensions"):
construct.hstack([arr, np.array([0, 0])])
with pytest.raises(ValueError, match="incompatible column dimensions"):
Expand Down

0 comments on commit ed184ff

Please sign in to comment.