From def0e2ac334b2885b77ad10c3abf9904ae03b1b7 Mon Sep 17 00:00:00 2001 From: "m.dvizov" Date: Mon, 22 Aug 2022 17:11:17 +0300 Subject: [PATCH] Move sterf to CPU; add experimental parallelism for sterf --- CMakeLists.txt | 2 + library/src/CMakeLists.txt | 9 + library/src/auxiliary/rocauxiliary_sterf.cpp | 21 + library/src/auxiliary/rocauxiliary_sterf.hpp | 703 +++++++++++++++++++ library/src/include/lib_host_helpers.hpp | 76 ++ library/src/lapack/roclapack_syevd_heevd.hpp | 9 + 6 files changed, 820 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2fad03547..718897566 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,6 +33,8 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) option(ROCSOLVER_EMBED_FMT "Hide libfmt symbols" OFF) option(OPTIMAL "Build specialized kernels for small matrix sizes" ON) +option(HYBRID_CPU "Build hybrid schema with CPU using" ON) +option(EXPERIMENTAL "Experimental parallelization" OFF) option(ROCSOLVER_FIND_PACKAGE_LAPACK_CONFIG "Skip module mode search for LAPACK" ON) # Add our CMake helper files to the lookup path diff --git a/library/src/CMakeLists.txt b/library/src/CMakeLists.txt index 0e48d6744..623dbb21f 100755 --- a/library/src/CMakeLists.txt +++ b/library/src/CMakeLists.txt @@ -290,6 +290,15 @@ if(OPTIMAL) target_compile_definitions(rocsolver PRIVATE OPTIMAL) endif() +if(HYBRID_CPU) + target_compile_definitions(rocsolver PRIVATE HYBRID_CPU) +endif() + +if(EXPERIMENTAL) + target_compile_definitions(rocsolver PRIVATE EXPERIMENTAL) +endif() + + target_compile_definitions(rocsolver PRIVATE ROCM_USE_FLOAT16 ROCBLAS_INTERNAL_API diff --git a/library/src/auxiliary/rocauxiliary_sterf.cpp b/library/src/auxiliary/rocauxiliary_sterf.cpp index 06fa89873..ca2e95953 100644 --- a/library/src/auxiliary/rocauxiliary_sterf.cpp +++ b/library/src/auxiliary/rocauxiliary_sterf.cpp @@ -27,6 +27,26 @@ rocblas_status rocblas_stride strideE = 0; rocblas_int batch_count = 1; +#ifdef EXPERIMENTAL + // additional memory for internal kernels (parallel 
sterf) + size_t size_ranges; + rocsolver_sterf_parallel_getMemorySize(n, &size_ranges); + + if(rocblas_is_device_memory_size_query(handle)) + return rocblas_set_optimal_device_memory_size(handle, size_ranges); + + // memory workspace allocation + void* ranges; + rocblas_device_malloc mem_range(handle, size_ranges); + if(!mem_range) + return rocblas_status_memory_error; + + ranges = mem_range[0]; + + // execution + return rocsolver_sterf_template(handle, n, D, shiftD, strideD, E, shiftE, strideE, info, + batch_count, (rocblas_int*)ranges); +#else // memory workspace sizes: // size for lasrt stack size_t size_stack; @@ -46,6 +66,7 @@ rocblas_status // execution return rocsolver_sterf_template(handle, n, D, shiftD, strideD, E, shiftE, strideE, info, batch_count, (rocblas_int*)stack); +#endif } /* diff --git a/library/src/auxiliary/rocauxiliary_sterf.hpp b/library/src/auxiliary/rocauxiliary_sterf.hpp index 16a6203e9..c70661345 100644 --- a/library/src/auxiliary/rocauxiliary_sterf.hpp +++ b/library/src/auxiliary/rocauxiliary_sterf.hpp @@ -13,6 +13,9 @@ #include "rocblas.hpp" #include "rocsolver/rocsolver.h" +#include +#include + /**************************************************************************** (TODO:THIS IS BASIC IMPLEMENTATION. 
THE ONLY PARALLELISM INTRODUCED HERE IS FOR THE BATCHED VERSIONS (A DIFFERENT THREAD WORKS ON EACH INSTANCE OF THE @@ -27,6 +30,268 @@ __device__ void sterf_sq_e(const rocblas_int start, const rocblas_int end, T* E) E[i] = E[i] * E[i]; } +template +void host_sterf_sq_e(const rocblas_int start, const rocblas_int end, T* E) +{ + for(int i = start; i < end; i++) + E[i] = E[i] * E[i]; +} + +/** STERF_CPU implements the main loop of the sterf algorithm + to compute the eigenvalues of a symmetric tridiagonal matrix given by D + and E on CPUi, non batched version**/ +template +void sterf_cpu(const rocblas_int n, + T* D, + T* E, + rocblas_int& info, + const rocblas_int max_iters, + const T eps, + const T ssfmin, + const T ssfmax) +{ + rocblas_int m, l, lsv, lend, lendsv; + rocblas_int l1 = 0; + rocblas_int iters = 0; + T anorm, p; + + while(l1 < n && iters < max_iters) + { + if(l1 > 0) + E[l1 - 1] = 0; + + for(m = l1; m < n - 1; m++) + { + if(abs(E[m]) <= sqrt(abs(D[m])) * sqrt(abs(D[m + 1])) * eps) + { + E[m] = 0; + break; + } + } + + lsv = l = l1; + lendsv = lend = m; + + l1 = m + 1; + + if(lend == l) + continue; + + // Scale submatrix + anorm = host_find_max_tridiag(l, lend, D, E); + + if(anorm == 0) + continue; + else if(anorm > ssfmax) + host_scale_tridiag(l, lend, D, E, anorm / ssfmax); + else if(anorm < ssfmin) + host_scale_tridiag(l, lend, D, E, anorm / ssfmin); + host_sterf_sq_e(l, lend, E); + + // Choose iteration type (QL or QR) + if(abs(D[lend]) < abs(D[l])) + { + lend = lsv; + l = lendsv; + } + + if(lend >= l) + { + // QL iteration + while(l <= lend && iters < max_iters) + { + // Find small subdiagonal element + for(m = l; m <= lend - 1; m++) + { + if(abs(E[m]) <= eps * eps * abs(D[m] * D[m + 1])) + { + break; + } + } + + if(m < lend) + E[m] = 0; + p = D[l]; + if(m == l) + { + D[l] = p; + l++; + } + else if(m == l + 1) + { + T rte, rt1, rt2; + rte = sqrt(E[l]); + host_lae2(D[l], rte, D[l + 1], rt1, rt2); + D[l] = rt1; + D[l + 1] = rt2; + E[l] = 0; + l = l + 2; 
+ } + else + { + if(iters == max_iters) + break; + iters++; + + T sigma, gamma, r, rte, c, s; + + // Form shift + rte = sqrt(E[l]); + sigma = (D[l + 1] - p) / (2 * rte); + if(sigma >= 0) + r = abs(sqrt(1 + sigma * sigma)); + else + r = -abs(sqrt(1 + sigma * sigma)); + sigma = p - (rte / (sigma + r)); + + c = 1; + s = 0; + gamma = D[m] - sigma; + p = gamma * gamma; + + for(int i = m - 1; i >= l; i--) + { + T bb = E[i]; + r = p + bb; + if(i != m - 1) + E[i + 1] = s * r; + + T oldc = c; + c = p / r; + s = bb / r; + T oldgam = gamma; + gamma = c * (D[i] - sigma) - s * oldgam; + D[i + 1] = oldgam + (D[i] - gamma); + if(c != 0) + p = (gamma * gamma) / c; + else + p = oldc * bb; + } + + E[l] = s * p; + D[l] = sigma + gamma; + } + } + } + + else + { + // QR iteration + while(l >= lend && iters < max_iters) + { + // Find small subdiagonal element + for(m = l; m >= lend + 1; m--) + { + if(abs(E[m - 1]) <= eps * eps * abs(D[m] * D[m - 1])) + { + break; + } + } + + if(m > lend) + E[m - 1] = 0; + p = D[l]; + if(m == l) + { + D[l] = p; + l--; + } + else if(m == l - 1) + { + // Use lae2 to compute 2x2 eigenvalues. Using rte, rt1, rt2. + T rte, rt1, rt2; + rte = sqrt(E[l - 1]); + host_lae2(D[l], rte, D[l - 1], rt1, rt2); + D[l] = rt1; + D[l - 1] = rt2; + E[l - 1] = 0; + l = l - 2; + } + else + { + if(iters == max_iters) + break; + iters++; + + T sigma, gamma, r, rte, c, s; + + // Form shift. Using rte, r, c, s. 
+ rte = sqrt(E[l - 1]); + sigma = (D[l - 1] - p) / (2 * rte); + if(sigma >= 0) + r = abs(sqrt(1 + sigma * sigma)); + else + r = -abs(sqrt(1 + sigma * sigma)); + sigma = p - (rte / (sigma + r)); + + c = 1; + s = 0; + gamma = D[m] - sigma; + p = gamma * gamma; + + for(int i = m; i <= l - 1; i++) + { + T bb = E[i]; + r = p + bb; + if(i != m) + E[i - 1] = s * r; + + T oldc = c; + c = p / r; + s = bb / r; + T oldgam = gamma; + gamma = c * (D[i + 1] - sigma) - s * oldgam; + D[i] = oldgam + (D[i + 1] - gamma); + if(c != 0) + p = (gamma * gamma) / c; + else + p = oldc * bb; + } + + E[l - 1] = s * p; + D[l] = sigma + gamma; + } + } + } + + // Undo scaling + if(anorm > ssfmax) + host_scale_tridiag(lsv, lendsv, D, E, ssfmax / anorm); + if(anorm < ssfmin) + host_scale_tridiag(lsv, lendsv, D, E, ssfmin / anorm); + } + + // Check for convergence + for(int i = 0; i < n - 1; i++) + if(E[i] != 0) + info++; + + // Sort eigenvalues + /** (TODO: the quick-sort method implemented in lasrt_increasing fails for some cases. + Substituting it here with a simple sorting algorithm. 
If more performance is required in + the future, lasrt_increasing should be debugged or another quick-sort method + could be implemented) **/ + for(int ii = 1; ii < n; ii++) + { + l = ii - 1; + m = l; + p = D[l]; + for(int j = ii; j < n; j++) + { + if(D[j] < p) + { + m = j; + p = D[j]; + } + } + if(m != l) + { + D[m] = D[l]; + D[l] = p; + } + } +} + /** STERF_KERNEL implements the main loop of the sterf algorithm to compute the eigenvalues of a symmetric tridiagonal matrix given by D and E **/ @@ -58,6 +323,7 @@ ROCSOLVER_KERNEL void sterf_kernel(const rocblas_int n, // Determine submatrix indices if(l1 > 0) E[l1 - 1] = 0; + for(m = l1; m < n - 1; m++) { if(abs(E[m]) <= sqrt(abs(D[m])) * sqrt(abs(D[m + 1])) * eps) @@ -69,12 +335,15 @@ ROCSOLVER_KERNEL void sterf_kernel(const rocblas_int n, lsv = l = l1; lendsv = lend = m; + l1 = m + 1; + if(lend == l) continue; // Scale submatrix anorm = find_max_tridiag(l, lend, D, E); + if(anorm == 0) continue; else if(anorm > ssfmax) @@ -97,8 +366,12 @@ ROCSOLVER_KERNEL void sterf_kernel(const rocblas_int n, { // Find small subdiagonal element for(m = l; m <= lend - 1; m++) + { if(abs(E[m]) <= eps * eps * abs(D[m] * D[m + 1])) + { break; + } + } if(m < lend) E[m] = 0; @@ -173,9 +446,14 @@ ROCSOLVER_KERNEL void sterf_kernel(const rocblas_int n, { // Find small subdiagonal element for(m = l; m >= lend + 1; m--) + { if(abs(E[m - 1]) <= eps * eps * abs(D[m] * D[m - 1])) + { break; + } + } + //printf("Finished check subranges\n"); if(m > lend) E[m - 1] = 0; p = D[l]; @@ -254,6 +532,7 @@ ROCSOLVER_KERNEL void sterf_kernel(const rocblas_int n, if(E[i] != 0) info[bid]++; + /// into another kernel, check time // Sort eigenvalues /** (TODO: the quick-sort method implemented in lasrt_increasing fails for some cases. Substituting it here with a simple sorting algorithm. 
If more performance is required in
@@ -282,6 +561,356 @@ ROCSOLVER_KERNEL void sterf_kernel(const rocblas_int n,
     }
 }
 
+template
+ROCSOLVER_KERNEL void sterf_find_subranges_gre(const rocblas_int n,
+                                               T* DD,
+                                               const rocblas_int offsetD,
+                                               T* EE,
+                                               const rocblas_int offsetE,
+                                               const T eps,
+                                               rocblas_int* split_ranges)
+{
+    T* D = DD;
+    T* E = EE;
+
+    int m = 0, l = 0;
+
+    rocblas_int range_count = 0;
+
+    T Eold = 0;
+    T Emax = 0;
+    T GL = (n > 0) ? D[0] : T(0), GU = GL; // seed both bounds; min/max below previously read them uninitialized
+    T tnrm = 0;
+
+    for(int i = 0; i < n; ++i)
+    {
+        T Eabs = (i < n - 1) ? abs(E[i]) : T(0); // E[n - 1] is never read; presumably E holds n - 1 entries, as the m < n - 1 loops below assume
+        if(Eabs >= Emax)
+            Emax = Eabs;
+
+        T tmp = Eabs + Eold;
+        GL = min(GL, D[i] - tmp);
+        GU = max(GU, D[i] + tmp);
+        Eold = Eabs;
+    }
+
+    /// spectral diameter
+    tnrm = GU - GL;
+
+    while(l < n)
+    {
+        if(l > 0)
+            E[l - 1] = 0;
+
+        for(m = l; m < n - 1; m++)
+        {
+            if(abs(E[m]) <= tnrm * eps)
+            {
+                E[m] = 0;
+                break;
+            }
+        }
+
+        if(l != m)
+        {
+            split_ranges[range_count] = l;
+            ++range_count;
+            split_ranges[range_count] = m;
+            ++range_count;
+        }
+
+        l = m + 1;
+    }
+
+    for(int i = range_count; i < n; ++i)
+        split_ranges[i] = -1;
+}
+
+template
+ROCSOLVER_KERNEL void sterf_find_subranges_default(const rocblas_int n,
+                                                   T* DD,
+                                                   const rocblas_int offsetD,
+                                                   T* EE,
+                                                   const rocblas_int offsetE,
+                                                   const T eps,
+                                                   rocblas_int* split_ranges)
+{
+    T* D = DD;
+    T* E = EE;
+
+    int m = 0, l = 0;
+
+    rocblas_int range_count = 0;
+
+    while(l < n)
+    {
+        if(l > 0)
+            E[l - 1] = 0;
+
+        for(m = l; m < n - 1; m++)
+        {
+            if(abs(E[m]) <= sqrt(abs(D[m])) * sqrt(abs(D[m + 1])) * eps)
+            {
+                E[m] = 0;
+                break;
+            }
+        }
+
+        if(l != m)
+        {
+            split_ranges[range_count] = l;
+            ++range_count;
+            split_ranges[range_count] = m;
+            ++range_count;
+        }
+
+        l = m + 1;
+    }
+
+    for(int i = range_count; i < n; ++i)
+        split_ranges[i] = -1;
+}
+
+/// default parallel kernel
+template
+ROCSOLVER_KERNEL void sterf_parallelize(T* D,
+                                        T* E,
+                                        const rocblas_int max_iter,
+                                        const T eps,
+                                        const T ssfmax,
+                                        const T ssfmin,
+                                        rocblas_int* split_ranges,
+                                        rocblas_int* info)
+{
+    rocblas_int m = 0;
+    rocblas_int count =
0, l = -1, lend = -1; + rocblas_int l0, lend0; + T p, anorm; + + const rocblas_int tid = hipThreadIdx_x; + + l0 = l = split_ranges[2 * tid]; + lend0 = lend = split_ranges[2 * tid + 1]; + + if(l == -1 || lend == -1) + return; + + anorm = find_max_tridiag(l, lend, D, E); + + if(anorm == 0) + return; + else if(anorm > ssfmax) + scale_tridiag(l, lend, D, E, anorm / ssfmax); + else if(anorm < ssfmin) + scale_tridiag(l, lend, D, E, anorm / ssfmin); + sterf_sq_e(l, lend, E); + + // Choose iteration type (QL or QR) + if(abs(D[lend]) < abs(D[l])) + { + lend = l; + l = lend0; + } + + rocblas_int iters = 0; + if(lend >= l) + { + // for QL + while(l <= lend && iters < max_iter) + { + // Find small subdiagonal element (QL) + for(m = l; m <= lend - 1; m++) + { + if(abs(E[m]) <= eps * eps * abs(D[m] * D[m + 1])) + { + break; + } + } + + if(m < lend) + E[m] = 0; + + p = D[l]; + if(m == l) + { + ++l; + continue; + } + else if(m == l + 1) + { + T rte, rt1, rt2; + rte = sqrt(E[l]); + lae2(D[l], rte, D[l + 1], rt1, rt2); + D[l] = rt1; + D[l + 1] = rt2; + E[l] = 0; + l = l + 2; + continue; + } + else + { + if(iters == max_iter) + break; + ++iters; + T sigma, gamma, r, rte, c, s; + + // Form shift + rte = sqrt(E[l]); + sigma = (D[l + 1] - p) / (2 * rte); + + if(sigma >= 0) + r = abs(sqrt(1 + sigma * sigma)); + else + r = -abs(sqrt(1 + sigma * sigma)); + sigma = p - (rte / (sigma + r)); + + c = 1; + s = 0; + gamma = D[m] - sigma; + p = gamma * gamma; + + for(int i = m - 1; i >= l; i--) + { + T bb = E[i]; + r = p + bb; + if(i != m - 1) + E[i + 1] = s * r; + + T oldc = c; + c = p / r; + s = bb / r; + T oldgam = gamma; + gamma = c * (D[i] - sigma) - s * oldgam; + D[i + 1] = oldgam + (D[i] - gamma); + + if(c != 0) + p = (gamma * gamma) / c; + else + p = oldc * bb; + } + + E[l] = s * p; + D[l] = sigma + gamma; + } + } + } + else + { + //for QR + while(l >= lend && iters < max_iter) + { + // Find small subdiagonal element + for(m = l; m >= lend + 1; m--) + { + if(abs(E[m - 1]) <= eps * eps * 
abs(D[m] * D[m - 1]))
+                {
+                    break;
+                }
+            }
+
+            if(m > lend)
+                E[m - 1] = 0;
+
+            p = D[l];
+            if(m == l)
+            {
+                --l;
+                continue;
+            }
+            else if(m == l - 1)
+            {
+                T rte, rt1, rt2;
+                rte = sqrt(E[l - 1]);
+                lae2(D[l], rte, D[l - 1], rt1, rt2);
+                D[l] = rt1;
+                D[l - 1] = rt2;
+                E[l - 1] = 0;
+                l = l - 2;
+                continue;
+            }
+            else
+            {
+                if(iters == max_iter)
+                    break;
+                ++iters;
+                T sigma, gamma, r, rte, c, s;
+
+                // Form shift. Using rte, r, c, s.
+                rte = sqrt(E[l - 1]);
+                sigma = (D[l - 1] - p) / (2 * rte);
+                if(sigma >= 0)
+                    r = abs(sqrt(1 + sigma * sigma));
+                else
+                    r = -abs(sqrt(1 + sigma * sigma));
+                sigma = p - (rte / (sigma + r));
+
+                c = 1;
+                s = 0;
+                gamma = D[m] - sigma;
+                p = gamma * gamma;
+
+                for(int i = m; i <= l - 1; i++)
+                {
+                    T bb = E[i];
+                    r = p + bb;
+                    if(i != m)
+                        E[i - 1] = s * r;
+
+                    T oldc = c;
+                    c = p / r;
+                    s = bb / r;
+                    T oldgam = gamma;
+                    gamma = c * (D[i + 1] - sigma) - s * oldgam;
+                    D[i] = oldgam + (D[i + 1] - gamma);
+                    if(c != 0)
+                        p = (gamma * gamma) / c;
+                    else
+                        p = oldc * bb;
+                }
+
+                E[l - 1] = s * p;
+                D[l] = sigma + gamma;
+            }
+        }
+    }
+
+    if(anorm > ssfmax) // undo scaling over the saved range; l and lend were moved by the iteration above
+        scale_tridiag(l0, lend0, D, E, ssfmax / anorm);
+    if(anorm < ssfmin)
+        scale_tridiag(l0, lend0, D, E, ssfmin / anorm);
+
+    for(int i = l0; i < lend0; i++) // off-diagonals of this thread's own range only
+        if(E[i] != 0)
+            atomicAdd(info, 1); // every thread reports into the same counter
+}
+
+template
+ROCSOLVER_KERNEL void sterf_sorting(const rocblas_int n, T* D)
+{
+    rocblas_int l, m;
+    T p;
+
+    for(int ii = 1; ii < n; ii++)
+    {
+        l = ii - 1;
+        m = l;
+        p = D[l];
+        for(int j = ii; j < n; j++)
+        {
+            if(D[j] < p)
+            {
+                m = j;
+                p = D[j];
+            }
+        }
+        if(m != l)
+        {
+            D[m] = D[l];
+            D[l] = p;
+        }
+    }
+}
+
 template
 void rocsolver_sterf_getMemorySize(const rocblas_int n,
                                    const rocblas_int batch_count,
@@ -298,6 +927,19 @@ void rocsolver_sterf_getMemorySize(const rocblas_int n,
     *size_stack = sizeof(rocblas_int) * (2 * 32) * batch_count;
 }
 
+template
+void rocsolver_sterf_parallel_getMemorySize(const rocblas_int n, size_t* size_ranges)
+{
+    // if quick return no workspace needed
+    if(n == 0)
+    {
+
*size_ranges = 0;
+        return;
+    }
+
+    *size_ranges = sizeof(rocblas_int) * n; // the subrange kernels write split_ranges[0 .. n - 1]
+}
+
 template
 rocblas_status
     rocsolver_sterf_argCheck(rocblas_handle handle, const rocblas_int n, T D, T E, rocblas_int* info)
@@ -360,8 +1002,69 @@ rocblas_status rocsolver_sterf_template(rocblas_handle handle,
     ssfmin = sqrt(ssfmin) / (eps * eps);
     ssfmax = sqrt(ssfmax) / T(3.0);
 
+#ifdef HYBRID_CPU
+
+    for(int i = 0; i < batch_count; ++i)
+    {
+        T* h_D = new T[n];
+        T* h_E = new T[n];
+        rocblas_int h_info = 0;
+
+        hipDeviceSynchronize();
+
+        T* shD = D + i * strideD + shiftD;
+        T* shE = E + i * strideE + shiftE;
+
+        // copy to CPU; sterf_cpu only touches E[0 .. n - 2], so only those entries are transferred
+        hipMemcpy(h_D, shD, sizeof(T) * n, hipMemcpyDeviceToHost);
+        hipMemcpy(h_E, shE, sizeof(T) * (n > 1 ? n - 1 : 0), hipMemcpyDeviceToHost);
+
+        sterf_cpu(n, h_D, h_E, h_info, 30 * n, eps, ssfmin, ssfmax);
+
+        hipMemcpy(shD, h_D, sizeof(T) * n, hipMemcpyHostToDevice);
+        hipMemcpy(shE, h_E, sizeof(T) * (n > 1 ? n - 1 : 0), hipMemcpyHostToDevice); // writing n entries would clobber the next instance when strideE == n - 1
+        hipMemcpy(info + i, &h_info, sizeof(rocblas_int), hipMemcpyHostToDevice);
+
+        delete[] h_D;
+        delete[] h_E;
+    }
+
+#elif defined(EXPERIMENTAL)
+
+    int max_threads = 1024;
+    int CU_count = (n - 1) / max_threads + 1;
+    int thread_count = max_threads;
+
+    dim3 grid(CU_count, 1, 1);
+    dim3 block(thread_count, 1, 1);
+
+    int offsetD = thread_count;
+    int offsetE = thread_count;
+    rocblas_int* split_ranges = stack;
+
+    size_t lmemsize = n * sizeof(int) + sizeof(T);
+
+    /// find ranges for sterf
+    ROCSOLVER_LAUNCH_KERNEL(sterf_find_subranges_default, dim3(1), dim3(1), 0, stream, n,
+                            D + shiftD, offsetD, E + shiftE, offsetE, eps, split_ranges);
+
+    /// execute parallel sterf (NOTE(review): the kernel indexes ranges by hipThreadIdx_x only, so extra blocks repeat work — confirm)
+    CU_count = (n - 2) / (2 * max_threads) + 1;
+    thread_count = n / 2 > max_threads ? max_threads : n / 2; // one thread per (start, end) pair; never more pairs than split_ranges holds
+
+    dim3 sterf_grid(CU_count, 1, 1);
+    dim3 sterf_block(thread_count, 1, 1);
+
+    ROCSOLVER_LAUNCH_KERNEL(sterf_parallelize, dim3(CU_count), dim3(thread_count), 0, stream, D + shiftD,
+                            E + shiftE, 30 * n, eps, ssfmax, ssfmin, split_ranges, info);
+
+    ROCSOLVER_LAUNCH_KERNEL(sterf_sorting, dim3(1), dim3(1), 0, stream, n, D + shiftD);
+
+#else
+
     ROCSOLVER_LAUNCH_KERNEL(sterf_kernel, dim3(batch_count), dim3(1), 0, stream, n, D + shiftD,
                             strideD, E + shiftE, strideE, info, stack, 30 * n, eps, ssfmin, ssfmax);
+#endif
 
     return rocblas_status_success;
 }
diff --git a/library/src/include/lib_host_helpers.hpp b/library/src/include/lib_host_helpers.hpp
index 0bc2ecb89..8967c9890 100644
--- a/library/src/include/lib_host_helpers.hpp
+++ b/library/src/include/lib_host_helpers.hpp
@@ -60,6 +60,82 @@ inline rocblas_int get_index(rocblas_int* intervals, rocblas_int max, rocblas_in
     return i;
 }
 
+/** FIND_MAX_TRIDIAG finds the element with the largest magnitude in the
+    tridiagonal matrix **/
+template
+T host_find_max_tridiag(const rocblas_int start, const rocblas_int end, T* D, T* E)
+{
+    T anorm = abs(D[end]);
+    for(int i = start; i < end; i++)
+        anorm = max(anorm, max(abs(D[i]), abs(E[i])));
+    return anorm;
+}
+
+/** SCALE_TRIDIAG scales the elements of the tridiagonal matrix by a given
+    scale factor **/
+template
+void host_scale_tridiag(const rocblas_int start, const rocblas_int end, T* D, T* E, T scale)
+{
+    D[end] *= scale;
+    for(int i = start; i < end; i++)
+    {
+        D[i] *= scale;
+        E[i] *= scale;
+    }
+}
+
+/** LAE2 computes the eigenvalues of a 2x2 symmetric matrix
+    [ a b ]
+    [ b c ] **/
+template , int> = 0>
+void host_lae2(T& a, T& b, T& c, T& rt1, T& rt2)
+{
+    T sm = a + c;
+    T adf = abs(a - c);
+    T ab = abs(b + b);
+
+    T rt, acmx, acmn;
+    if(adf > ab)
+    {
+        rt = ab / adf;
+        rt = adf * sqrt(1 + rt * rt);
+    }
+    else if(adf < ab)
+    {
+        rt = adf / ab;
+        rt = ab * sqrt(1 + rt * rt);
+    }
+    else
+        rt = ab * sqrt(2);
+
+    // Compute the eigenvalues
+    if(abs(a) > abs(c))
+    {
+        acmx = a;
+        acmn =
c; + } + else + { + acmx = c; + acmn = a; + } + if(sm < 0) + { + rt1 = T(0.5) * (sm - rt); + rt2 = T((acmx / (double)rt1) * acmn - (b / (double)rt1) * b); + } + else if(sm > 0) + { + rt1 = T(0.5) * (sm + rt); + rt2 = T((acmx / (double)rt1) * acmn - (b / (double)rt1) * b); + } + else + { + rt1 = T(0.5) * rt; + rt2 = T(-0.5) * rt; + } +} + #ifdef ROCSOLVER_VERIFY_ASSUMPTIONS // Ensure __assert_fail is declared. #if !__is_identifier(__assert_fail) diff --git a/library/src/lapack/roclapack_syevd_heevd.hpp b/library/src/lapack/roclapack_syevd_heevd.hpp index dfe738e37..147e07e69 100644 --- a/library/src/lapack/roclapack_syevd_heevd.hpp +++ b/library/src/lapack/roclapack_syevd_heevd.hpp @@ -71,7 +71,11 @@ void rocsolver_syevd_heevd_getMemorySize(const rocblas_evect evect, // extra requirements for computing only the eigenvalues (sterf) rocsolver_sterf_getMemorySize(n, batch_count, &w12); +#ifdef EXPERIMENTAL + *size_work3 = sizeof(rocblas_int) * (n); +#else *size_work3 = 0; +#endif } // size of array for temporary matrix products @@ -154,8 +158,13 @@ rocblas_status rocsolver_syevd_heevd_template(rocblas_handle handle, if(evect != rocblas_evect_original) { // only compute eigenvalues +#ifdef EXPERIMENTAL + rocsolver_sterf_template(handle, n, D, 0, strideD, E, 0, strideE, info, batch_count, + (rocblas_int*)work3); +#else rocsolver_sterf_template(handle, n, D, 0, strideD, E, 0, strideE, info, batch_count, (rocblas_int*)work1); +#endif } else {