From 20c9487e92bae3b1acbced3f4cf11ec1b0304ff7 Mon Sep 17 00:00:00 2001 From: "m.dvizov" Date: Thu, 1 Sep 2022 11:20:17 +0300 Subject: [PATCH] Add dsterf calling, minor fixes --- CMakeLists.txt | 1 + library/src/auxiliary/rocauxiliary_sterf.cpp | 26 ++++++++++++++++++++ library/src/auxiliary/rocauxiliary_sterf.hpp | 26 ++++++++++++++------ 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 718897566..6eed65857 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) option(ROCSOLVER_EMBED_FMT "Hide libfmt symbols" OFF) option(OPTIMAL "Build specialized kernels for small matrix sizes" ON) option(HYBRID_CPU "Build hybrid schema with CPU using" ON) +option(LAPACK_FUNCTIONS "Build hybrid with lapack routine functions" ON) option(EXPERIMENTAL "Experimental parallelization" OFF) option(ROCSOLVER_FIND_PACKAGE_LAPACK_CONFIG "Skip module mode search for LAPACK" ON) diff --git a/library/src/auxiliary/rocauxiliary_sterf.cpp b/library/src/auxiliary/rocauxiliary_sterf.cpp index ca2e95953..1f5fa41fd 100644 --- a/library/src/auxiliary/rocauxiliary_sterf.cpp +++ b/library/src/auxiliary/rocauxiliary_sterf.cpp @@ -4,6 +4,32 @@ #include "rocauxiliary_sterf.hpp" +#ifdef LAPACK_FUNCTIONS + +#ifdef __cplusplus +extern "C" { +#endif + void dsterf(int* n, double* D, double* E, int* info); + void ssterf(int* n, float* D, float* E, int* info); +#ifdef __cplusplus +} +#endif + +template <> +void lapack_sterf(rocblas_int n, double* D, double* E, int &info) +{ + dsterf(&n, D, E, &info); +} + +template <> +void lapack_sterf(rocblas_int n, float* D, float* E, int &info) +{ + ssterf(&n, D, E, &info); +} + +#endif + + template rocblas_status rocsolver_sterf_impl(rocblas_handle handle, const rocblas_int n, T* D, T* E, rocblas_int* info) diff --git a/library/src/auxiliary/rocauxiliary_sterf.hpp b/library/src/auxiliary/rocauxiliary_sterf.hpp index c70661345..2abfb7759 100644 --- a/library/src/auxiliary/rocauxiliary_sterf.hpp +++ b/library/src/auxiliary/rocauxiliary_sterf.hpp @@ -22,6 +22,12 @@ BATCH).) ***************************************************************************/ +#ifdef LAPACK_FUNCTIONS +/** direct call sterf from LAPACK **/ +template +void lapack_sterf(rocblas_int n, T* D, T* E, int &info); +#endif + /** STERF_SQ_E squares the elements of E **/ template __device__ void sterf_sq_e(const rocblas_int start, const rocblas_int end, T* E) @@ -684,13 +690,13 @@ ROCSOLVER_KERNEL void sterf_parallelize(T* D, { rocblas_int m = 0; rocblas_int count = 0, l = -1, lend = -1; - rocblas_int l0, lend0; + rocblas_int l_orig, lend_orig; T p, anorm; const rocblas_int tid = hipThreadIdx_x; - l0 = l = split_ranges[2 * tid]; - lend0 = lend = split_ranges[2 * tid + 1]; + l_orig = l = split_ranges[2 * tid]; + lend_orig = lend = split_ranges[2 * tid + 1]; if(l == -1 || lend == -1) return; @@ -709,7 +715,7 @@ ROCSOLVER_KERNEL void sterf_parallelize(T* D, if(abs(D[lend]) < abs(D[l])) { lend = l; - l = lend0; + l = lend_orig; } rocblas_int iters = 0; @@ -875,11 +881,11 @@ ROCSOLVER_KERNEL void sterf_parallelize(T* D, } if(anorm > ssfmax) - scale_tridiag(l, lend, D, E, ssfmax / anorm); + scale_tridiag(l_orig, lend_orig, D, E, ssfmax / anorm); if(anorm < ssfmin) - scale_tridiag(l, lend, D, E, ssfmin / anorm); + scale_tridiag(l_orig, lend_orig, D, E, ssfmin / anorm); - for(int i = l; i <= lend; i++) + for(int i = l_orig; i <= lend_orig; i++) if(E[i] != 0) info[0]++; } @@ -1010,7 +1016,7 @@ rocblas_status rocsolver_sterf_template(rocblas_handle handle, T* h_E = new T[n]; rocblas_int h_info = 0; - hipDeviceSynchronize(); + hipStreamSynchronize(stream); T* shD = D + i * strideD + shiftD; T* shE = E + i * strideE + shiftE; @@ -1019,7 +1025,11 @@ rocblas_status rocsolver_sterf_template(rocblas_handle handle, hipMemcpy(h_D, shD, sizeof(T) * n, hipMemcpyDeviceToHost); hipMemcpy(h_E, shE, sizeof(T) * n, hipMemcpyDeviceToHost); +#ifdef LAPACK_FUNCTIONS + lapack_sterf(n, h_D, h_E, h_info); +#else sterf_cpu(n, h_D, h_E, h_info, 30 * n, eps, ssfmin, ssfmax); +#endif hipMemcpy(shD, h_D, sizeof(T) * n, hipMemcpyHostToDevice); hipMemcpy(shE, h_E, sizeof(T) * n, hipMemcpyHostToDevice);