Skip to content

Commit

Permalink
Fix refactchol for 6.0 (#639)
Browse files Browse the repository at this point in the history
* add HIP_CHECK

* remove beta in add_QAQ, add hipMemsetAsync

* remove beta from add_PAQ, use hipMemsetAsync

* modification for refactlu and minor bug fix

* Cleanup (#3)

* Changes to HIP_CHECK and added status to rocsolver_trsm

---------

Co-authored-by: Ed D'Azevedo <[email protected]>
Co-authored-by: Eduardo D'Azevedo <[email protected]>
  • Loading branch information
3 people committed Dec 6, 2023
1 parent 702cf34 commit 447a52f
Show file tree
Hide file tree
Showing 11 changed files with 261 additions and 290 deletions.
27 changes: 17 additions & 10 deletions library/src/include/rocblas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,24 @@ constexpr auto rocblas2string_status(rocblas_status status)
}
}

#define ROCBLAS_CHECK(fcn) \
{ \
rocblas_status _status = (fcn); \
if(_status != rocblas_status_success) \
return _status; \
// HIP_CHECK(expr): evaluates expr exactly once, stores the result in a local
// hipError_t, and if it is not hipSuccess returns
// get_rocblas_status_for_hip_status(result) from the ENCLOSING function.
// On success, execution simply continues after the macro.
// The parameter list is variadic so that an expression containing top-level
// commas (e.g. a call with explicit template arguments) is still forwarded
// as one expression.
// NOTE: the expansion is a bare brace block containing a `return`; use it as
// a statement inside a function whose return type accepts that returned value.
#define HIP_CHECK(...) \
{ \
hipError_t _status = (__VA_ARGS__); \
if(_status != hipSuccess) \
return get_rocblas_status_for_hip_status(_status); \
}
#define THROW_IF_ROCBLAS_ERROR(fcn) \
{ \
rocblas_status _status = (fcn); \
if(_status != rocblas_status_success) \
throw _status; \

// ROCBLAS_CHECK(expr): evaluates expr exactly once as a rocblas_status and,
// if it is not rocblas_status_success, returns that status unchanged from the
// ENCLOSING function. On success, execution continues after the macro.
// Variadic so that an expression containing top-level commas is accepted as
// a single expression.
// NOTE: expands to a bare brace block containing a `return`; use it as a
// statement inside a function that can return a rocblas_status.
#define ROCBLAS_CHECK(...) \
{ \
rocblas_status _status = (__VA_ARGS__); \
if(_status != rocblas_status_success) \
return _status; \
}
// THROW_IF_ROCBLAS_ERROR(expr): evaluates expr exactly once as a
// rocblas_status and, if it is not rocblas_status_success, throws the
// rocblas_status value itself (not an exception class), so a handler must
// catch rocblas_status. On success, execution continues after the macro.
// Variadic so that an expression containing top-level commas is accepted as
// a single expression.
#define THROW_IF_ROCBLAS_ERROR(...) \
{ \
rocblas_status _status = (__VA_ARGS__); \
if(_status != rocblas_status_success) \
throw _status; \
}

template <typename T>
Expand Down
168 changes: 84 additions & 84 deletions library/src/include/rocsolver_run_specialized_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,96 +55,96 @@ void rocsolver_trsm_mem(const rocblas_side side,
const rocblas_int incb = 1);

template <bool BATCHED, bool STRIDED, typename T, typename U>
void rocsolver_trsm_lower(rocblas_handle handle,
const rocblas_side side,
const rocblas_operation trans,
const rocblas_diagonal diag,
const rocblas_int m,
const rocblas_int n,
U A,
const rocblas_int shiftA,
const rocblas_int lda,
const rocblas_stride strideA,
U B,
const rocblas_int shiftB,
const rocblas_int ldb,
const rocblas_stride strideB,
const rocblas_int batch_count,
const bool optim_mem,
void* work1,
void* work2,
void* work3,
void* work4);
// Declaration of the rocsolver_trsm_lower overload that takes no inca/incb
// arguments. It returns a rocblas_status, so callers can check or propagate
// the outcome of the call.
// NOTE(review): argument semantics are not visible in this declaration;
// presumably a triangular solve with a lower-triangular A applied to B using
// the work1..work4 scratch buffers -- confirm against the definition.
rocblas_status rocsolver_trsm_lower(rocblas_handle handle,
const rocblas_side side,
const rocblas_operation trans,
const rocblas_diagonal diag,
const rocblas_int m,
const rocblas_int n,
U A,
const rocblas_int shiftA,
const rocblas_int lda,
const rocblas_stride strideA,
U B,
const rocblas_int shiftB,
const rocblas_int ldb,
const rocblas_stride strideB,
const rocblas_int batch_count,
const bool optim_mem,
void* work1,
void* work2,
void* work3,
void* work4);

template <bool BATCHED, bool STRIDED, typename T, typename U>
void rocsolver_trsm_lower(rocblas_handle handle,
const rocblas_side side,
const rocblas_operation trans,
const rocblas_diagonal diag,
const rocblas_int m,
const rocblas_int n,
U A,
const rocblas_int shiftA,
const rocblas_int inca,
const rocblas_int lda,
const rocblas_stride strideA,
U B,
const rocblas_int shiftB,
const rocblas_int incb,
const rocblas_int ldb,
const rocblas_stride strideB,
const rocblas_int batch_count,
const bool optim_mem,
void* work1,
void* work2,
void* work3,
void* work4);
// Declaration of the rocsolver_trsm_lower overload that additionally takes
// inca (after shiftA) and incb (after shiftB). It returns a rocblas_status,
// so callers can check or propagate the outcome of the call.
// NOTE(review): inca/incb are presumably element increments for A and B;
// their meaning is not visible in this declaration -- confirm against the
// definition.
rocblas_status rocsolver_trsm_lower(rocblas_handle handle,
const rocblas_side side,
const rocblas_operation trans,
const rocblas_diagonal diag,
const rocblas_int m,
const rocblas_int n,
U A,
const rocblas_int shiftA,
const rocblas_int inca,
const rocblas_int lda,
const rocblas_stride strideA,
U B,
const rocblas_int shiftB,
const rocblas_int incb,
const rocblas_int ldb,
const rocblas_stride strideB,
const rocblas_int batch_count,
const bool optim_mem,
void* work1,
void* work2,
void* work3,
void* work4);

template <bool BATCHED, bool STRIDED, typename T, typename U>
void rocsolver_trsm_upper(rocblas_handle handle,
const rocblas_side side,
const rocblas_operation trans,
const rocblas_diagonal diag,
const rocblas_int m,
const rocblas_int n,
U A,
const rocblas_int shiftA,
const rocblas_int lda,
const rocblas_stride strideA,
U B,
const rocblas_int shiftB,
const rocblas_int ldb,
const rocblas_stride strideB,
const rocblas_int batch_count,
const bool optim_mem,
void* work1,
void* work2,
void* work3,
void* work4);
// Declaration of the rocsolver_trsm_upper overload that takes no inca/incb
// arguments. It returns a rocblas_status, so callers can check or propagate
// the outcome of the call.
// NOTE(review): argument semantics are not visible in this declaration;
// presumably a triangular solve with an upper-triangular A applied to B using
// the work1..work4 scratch buffers -- confirm against the definition.
rocblas_status rocsolver_trsm_upper(rocblas_handle handle,
const rocblas_side side,
const rocblas_operation trans,
const rocblas_diagonal diag,
const rocblas_int m,
const rocblas_int n,
U A,
const rocblas_int shiftA,
const rocblas_int lda,
const rocblas_stride strideA,
U B,
const rocblas_int shiftB,
const rocblas_int ldb,
const rocblas_stride strideB,
const rocblas_int batch_count,
const bool optim_mem,
void* work1,
void* work2,
void* work3,
void* work4);

template <bool BATCHED, bool STRIDED, typename T, typename U>
void rocsolver_trsm_upper(rocblas_handle handle,
const rocblas_side side,
const rocblas_operation trans,
const rocblas_diagonal diag,
const rocblas_int m,
const rocblas_int n,
U A,
const rocblas_int shiftA,
const rocblas_int inca,
const rocblas_int lda,
const rocblas_stride strideA,
U B,
const rocblas_int shiftB,
const rocblas_int incb,
const rocblas_int ldb,
const rocblas_stride strideB,
const rocblas_int batch_count,
const bool optim_mem,
void* work1,
void* work2,
void* work3,
void* work4);
// Declaration of the rocsolver_trsm_upper overload that additionally takes
// inca (after shiftA) and incb (after shiftB). It returns a rocblas_status,
// so callers can check or propagate the outcome of the call.
// NOTE(review): inca/incb are presumably element increments for A and B;
// their meaning is not visible in this declaration -- confirm against the
// definition.
rocblas_status rocsolver_trsm_upper(rocblas_handle handle,
const rocblas_side side,
const rocblas_operation trans,
const rocblas_diagonal diag,
const rocblas_int m,
const rocblas_int n,
U A,
const rocblas_int shiftA,
const rocblas_int inca,
const rocblas_int lda,
const rocblas_stride strideA,
U B,
const rocblas_int shiftB,
const rocblas_int incb,
const rocblas_int ldb,
const rocblas_stride strideB,
const rocblas_int batch_count,
const bool optim_mem,
void* work1,
void* work2,
void* work3,
void* work4);

// gemm
template <bool BATCHED, bool STRIDED, typename T, typename U>
Expand Down
8 changes: 4 additions & 4 deletions library/src/include/rocsparse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ constexpr auto rocsparse2rocblas_status(rocsparse_status status)
}
}

#define ROCSPARSE_CHECK(fcn) \
#define ROCSPARSE_CHECK(...) \
{ \
rocsparse_status _status = (fcn); \
rocsparse_status _status = (__VA_ARGS__); \
if(_status != rocsparse_status_success) \
return rocsparse2rocblas_status(_status); \
}
#define THROW_IF_ROCSPARSE_ERROR(fcn) \
#define THROW_IF_ROCSPARSE_ERROR(...) \
{ \
rocsparse_status _status = (fcn); \
rocsparse_status _status = (__VA_ARGS__); \
if(_status != rocsparse_status_success) \
throw rocsparse2rocblas_status(_status); \
}
Expand Down
13 changes: 4 additions & 9 deletions library/src/lapack/roclapack_syevdx_heevdx_inplace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
* Univ. of Tennessee, Univ. of California Berkeley,
* Univ. of Colorado Denver and NAG Ltd..
* December 2016
* Copyright (C) 2021-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2021-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -290,14 +290,9 @@ rocblas_status rocsolver_syevdx_heevdx_inplace_template(rocblas_handle handle,
// copy nev from device to host
if(h_nev)
{
hipError_t status = hipMemcpyAsync(h_nev, d_nev, sizeof(rocblas_int) * batch_count,
hipMemcpyDeviceToHost, stream);
if(status != hipSuccess)
return get_rocblas_status_for_hip_status(status);

status = hipStreamSynchronize(stream);
if(status != hipSuccess)
return get_rocblas_status_for_hip_status(status);
HIP_CHECK(hipMemcpyAsync(h_nev, d_nev, sizeof(rocblas_int) * batch_count,
hipMemcpyDeviceToHost, stream));
HIP_CHECK(hipStreamSynchronize(stream));
}

return rocblas_status_success;
Expand Down
13 changes: 4 additions & 9 deletions library/src/lapack/roclapack_syevj_heevj.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* and
* Hari & Kovac (2019). On the Convergence of Complex Jacobi Methods.
* Linear and Multilinear Algebra 69(3), p. 489-514.
* Copyright (c) 2021-2023 Advanced Micro Devices, Inc.
* Copyright (C) 2021-2023 Advanced Micro Devices, Inc.
* ***********************************************************************/

#pragma once
Expand Down Expand Up @@ -1493,14 +1493,9 @@ rocblas_status rocsolver_syevj_heevj_template(rocblas_handle handle,
while(h_sweeps < max_sweeps)
{
// if all instances in the batch have finished, exit the loop
hipError_t status = hipMemcpyAsync(&h_completed, completed, sizeof(rocblas_int),
hipMemcpyDeviceToHost, stream);
if(status != hipSuccess)
return get_rocblas_status_for_hip_status(status);

status = hipStreamSynchronize(stream);
if(status != hipSuccess)
return get_rocblas_status_for_hip_status(status);
HIP_CHECK(hipMemcpyAsync(&h_completed, completed, sizeof(rocblas_int),
hipMemcpyDeviceToHost, stream));
HIP_CHECK(hipStreamSynchronize(stream));

if(h_completed == batch_count)
break;
Expand Down
13 changes: 4 additions & 9 deletions library/src/lapack/roclapack_sygvdx_hegvdx_inplace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
* Univ. of Tennessee, Univ. of California Berkeley,
* Univ. of Colorado Denver and NAG Ltd..
* December 2016
* Copyright (C) 2021-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2021-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -320,14 +320,9 @@ rocblas_status rocsolver_sygvdx_hegvdx_inplace_template(rocblas_handle handle,
// copy nev from device to host
if(h_nev)
{
hipError_t status = hipMemcpyAsync(h_nev, d_nev, sizeof(rocblas_int) * batch_count,
hipMemcpyDeviceToHost, stream);
if(status != hipSuccess)
return get_rocblas_status_for_hip_status(status);

status = hipStreamSynchronize(stream);
if(status != hipSuccess)
return get_rocblas_status_for_hip_status(status);
HIP_CHECK(hipMemcpyAsync(h_nev, d_nev, sizeof(rocblas_int) * batch_count,
hipMemcpyDeviceToHost, stream));
HIP_CHECK(hipStreamSynchronize(stream));
}

rocblas_set_pointer_mode(handle, old_mode);
Expand Down
4 changes: 2 additions & 2 deletions library/src/refact/rocrefact_csrrf_refactchol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ rocblas_status rocsolver_csrrf_refactchol_impl(rocblas_handle handle,
work = mem[0];

// execution
return rocsolver_csrrf_refactchol_template<T>(handle, n, nnzA, ptrA, indA, valA, nnzT, ptrT,
indT, valT, pivQ, rfinfo, work);
return rocsolver_csrrf_refactchol_template<T, U>(handle, n, nnzA, ptrA, indA, valA, nnzT, ptrT,
indT, valT, pivQ, rfinfo, work);
#else
return rocblas_status_not_implemented;
#endif
Expand Down
20 changes: 4 additions & 16 deletions library/src/refact/rocrefact_csrrf_refactchol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ ROCSOLVER_KERNEL void rf_add_QAQ_kernel(const rocblas_int n,
rocblas_int* Ap,
rocblas_int* Ai,
T* Ax,
const T beta,
rocblas_int* LUp,
rocblas_int* LUi,
T* LUx)
Expand All @@ -76,18 +75,6 @@ ROCSOLVER_KERNEL void rf_add_QAQ_kernel(const rocblas_int n,

T aij;

// ----------------
// scale B by beta
// ----------------
for(i = istart + tiy; i < iend; i += hipBlockDim_y)
{
// only access lower triangle of B
if(irow < LUi[i])
break;
LUx[i] *= beta;
}
__syncthreads();

// ------------------------------
// scale A by alpha and add to B
// ------------------------------
Expand Down Expand Up @@ -216,16 +203,17 @@ rocblas_status rocsolver_csrrf_refactchol_template(rocblas_handle handle,
ROCSOLVER_LAUNCH_KERNEL(rf_ipvec_kernel<T>, dim3(nblocks), dim3(BS2), 0, stream, n, pivQ,
(rocblas_int*)work);

// set T to zero
HIP_CHECK(hipMemsetAsync((void*)valT, 0, sizeof(T) * nnzT, stream));

// --------------------------------------------------------------
// copy Q'*A*Q into T
//
// Note: assume A and B are symmetric and ONLY the LOWER triangular parts of A and T are touched
// --------------------------------------------------------------
T const alpha = static_cast<T>(1);
T const beta = static_cast<T>(0);
ROCSOLVER_LAUNCH_KERNEL(rf_add_QAQ_kernel<T>, dim3(nblocks, 1), dim3(BS2, BS2), 0, stream, n,
pivQ, (rocblas_int*)work, alpha, ptrA, indA, valA, beta, ptrT, indT,
valT);
pivQ, (rocblas_int*)work, alpha, ptrA, indA, valA, ptrT, indT, valT);

// perform incomplete factorization of T
ROCSPARSE_CHECK(rocsparseCall_csric0(rfinfo->sphandle, n, nnzT, rfinfo->descrT, valT, ptrT,
Expand Down
4 changes: 2 additions & 2 deletions library/src/refact/rocrefact_csrrf_refactlu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ rocblas_status rocsolver_csrrf_refactlu_impl(rocblas_handle handle,
work = mem[0];

// execution
return rocsolver_csrrf_refactlu_template<T>(handle, n, nnzA, ptrA, indA, valA, nnzT, ptrT, indT,
valT, pivP, pivQ, rfinfo, work);
return rocsolver_csrrf_refactlu_template<T, U>(handle, n, nnzA, ptrA, indA, valA, nnzT, ptrT,
indT, valT, pivP, pivQ, rfinfo, work);
#else
return rocblas_status_not_implemented;
#endif
Expand Down
Loading

0 comments on commit 447a52f

Please sign in to comment.