diff --git a/CHANGELOG.md b/CHANGELOG.md index ffb715859..cbe62c27d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,9 @@ Full documentation for rocSOLVER is available at the [rocSOLVER documentation](h ## (Unreleased) rocSOLVER ### Added ### Optimized + +* Improved the performance of LARFT and downstream functions such as GEQR2 and GEQRF + ### Changed ### Deprecated ### Removed diff --git a/library/src/auxiliary/rocauxiliary_larft.hpp b/library/src/auxiliary/rocauxiliary_larft.hpp index 0b6bdaaec..ba3f7dbff 100644 --- a/library/src/auxiliary/rocauxiliary_larft.hpp +++ b/library/src/auxiliary/rocauxiliary_larft.hpp @@ -4,6 +4,10 @@ * Univ. of Tennessee, Univ. of California Berkeley, * Univ. of Colorado Denver and NAG Ltd.. * December 2016 + * and + * Joffrain, Low, Quintana-Orti, et al. (2006). Accumulating householder + * transformations, revisited. + * ACM Transactions on Mathematical Software 32(2), p. 169-179. * Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. * * Redistribution and use in source and binary forms, with or without @@ -35,6 +39,7 @@ #include "rocauxiliary_lacgv.hpp" #include "rocblas.hpp" #include "rocsolver/rocsolver.h" +#include "rocsolver_run_specialized_kernels.hpp" ROCSOLVER_BEGIN_NAMESPACE @@ -51,7 +56,8 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n, const rocblas_int ldf, const rocblas_stride strideF, const rocblas_direct direct, - const rocblas_storev storev) + const rocblas_storev storev, + const bool add_fp) { const auto b = hipBlockIdx_z; const auto i = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; @@ -65,30 +71,68 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n, Fp = F + b * strideF; if(j == i) - Fp[j + i * ldf] = tp[i]; + Fp[idx2D(j, i, ldf)] = tp[i]; else if(direct == rocblas_forward_direction) { if(j < i) { if(storev == rocblas_column_wise) - Fp[j + i * ldf] = -tp[i] * Vp[i + j * ldv]; + { + if(!add_fp) + { + Fp[idx2D(j, i, ldf)] = -tp[i] * Vp[idx2D(i, j, ldv)]; + } + else + { + Fp[idx2D(j, i, ldf)] = -tp[i] * (Fp[idx2D(j, i, ldf)] + Vp[idx2D(i, j, ldv)]); + } + } else - Fp[j + i * ldf] = -tp[i] * Vp[j + i * ldv]; + { + if(!add_fp) + { + Fp[idx2D(j, i, ldf)] = -tp[i] * Vp[idx2D(j, i, ldv)]; + } + else + { + Fp[idx2D(j, i, ldf)] = -tp[i] * (Fp[idx2D(j, i, ldf)] + Vp[idx2D(j, i, ldv)]); + } + } } else - Fp[j + i * ldf] = 0; + Fp[idx2D(j, i, ldf)] = 0; } else { if(j > i) { if(storev == rocblas_column_wise) - Fp[j + i * ldf] = -tp[i] * Vp[(n - k + i) + j * ldv]; + { + if(!add_fp) + { + Fp[idx2D(j, i, ldf)] = -tp[i] * Vp[idx2D((n - k + i), j, ldv)]; + } + else + { + Fp[idx2D(j, i, ldf)] + = -tp[i] * (Fp[idx2D(j, i, ldf)] + Vp[idx2D((n - k + i), j, ldv)]); + } + } else - Fp[j + i * ldf] = -tp[i] * Vp[j + (n - k + i) * ldv]; + { + if(!add_fp) + { + Fp[idx2D(j, i, ldf)] = -tp[i] * Vp[idx2D(j, (n - k + i), ldv)]; + } + else + { + Fp[idx2D(j, i, ldf)] + = -tp[i] * (Fp[idx2D(j, i, ldf)] + Vp[idx2D(j, (n - k + i), ldv)]); + } + } } else - Fp[j + i * ldf] = 0; + Fp[idx2D(j, i, ldf)] = 0; } } } @@ -106,7 +150,8 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n, const rocblas_int ldf, const rocblas_stride strideF, const rocblas_direct direct, - const rocblas_storev storev) + const rocblas_storev storev, + const bool add_fp) { const auto b = hipBlockIdx_z; const auto i = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; @@ -120,30 +165,69 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n, Fp = F + b * strideF; if(j == i) - Fp[j + i * ldf] = tp[i]; + Fp[idx2D(j, i, ldf)] = tp[i]; else if(direct 
== rocblas_forward_direction) { if(j < i) { if(storev == rocblas_column_wise) - Fp[j + i * ldf] = -tp[i] * conj(Vp[i + j * ldv]); + { + if(!add_fp) + { + Fp[idx2D(j, i, ldf)] = -tp[i] * conj(Vp[idx2D(i, j, ldv)]); + } + else + { + Fp[idx2D(j, i, ldf)] + = -tp[i] * (Fp[idx2D(j, i, ldf)] + conj(Vp[idx2D(i, j, ldv)])); + } + } else - Fp[j + i * ldf] = -tp[i] * Vp[j + i * ldv]; + { + if(!add_fp) + { + Fp[idx2D(j, i, ldf)] = -tp[i] * Vp[idx2D(j, i, ldv)]; + } + else + { + Fp[idx2D(j, i, ldf)] = -tp[i] * (Fp[idx2D(j, i, ldf)] + Vp[idx2D(j, i, ldv)]); + } + } } else - Fp[j + i * ldf] = 0; + Fp[idx2D(j, i, ldf)] = 0; } else { if(j > i) { if(storev == rocblas_column_wise) - Fp[j + i * ldf] = -tp[i] * conj(Vp[(n - k + i) + j * ldv]); + { + if(!add_fp) + { + Fp[idx2D(j, i, ldf)] = -tp[i] * conj(Vp[idx2D((n - k + i), j, ldv)]); + } + else + { + Fp[idx2D(j, i, ldf)] + = -tp[i] * (Fp[idx2D(j, i, ldf)] + conj(Vp[idx2D((n - k + i), j, ldv)])); + } + } else - Fp[j + i * ldf] = -tp[i] * Vp[j + (n - k + i) * ldv]; + { + if(!add_fp) + { + Fp[idx2D(j, i, ldf)] = -tp[i] * Vp[idx2D(j, (n - k + i), ldv)]; + } + else + { + Fp[idx2D(j, i, ldf)] + = -tp[i] * (Fp[idx2D(j, i, ldf)] + Vp[idx2D(j, (n - k + i), ldv)]); + } + } } else - Fp[j + i * ldf] = 0; + Fp[idx2D(j, i, ldf)] = 0; } } } @@ -161,6 +245,40 @@ ROCSOLVER_KERNEL void set_tau(const rocblas_int k, T* tau, const rocblas_stride } } +template , int> = 0> +bool larft_use_gemm(const I dim, const rocblas_direct direct, const rocblas_storev storev) +{ + I value[] = {LARFT_L3_DEFAULT}; + I intervals[] = {LARFT_L3_INTERVALS_DEFAULT}; + I max = LARFT_L3_NUM_INTERVALS_DEFAULT; + + return value[get_index(intervals, max, dim)]; +} + +/** In most cases, LARFT finds more performance when a subset of the computation is done using GEMM. + The configuration of rocblas_float_complex and rocblas_column_wise is unique in that there's + only a narrow band where using GEMM is more performant. **/ +template , int> = 0> +bool larft_use_gemm(const I dim, const rocblas_direct direct, const rocblas_storev storev) +{ + if(storev == rocblas_column_wise) + { + I value[] = {LARFT_L3_C_COL}; + I intervals[] = {LARFT_L3_INTERVALS_C_COL}; + I max = LARFT_L3_NUM_INTERVALS_C_COL; + + return value[get_index(intervals, max, dim)]; + } + else + { + I value[] = {LARFT_L3_DEFAULT}; + I intervals[] = {LARFT_L3_INTERVALS_DEFAULT}; + I max = LARFT_L3_NUM_INTERVALS_DEFAULT; + + return value[get_index(intervals, max, dim)]; + } +} + template void rocsolver_larft_getMemorySize(const rocblas_int n, const rocblas_int k, @@ -269,13 +387,52 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle, rocblas_fill uplo; rocblas_operation trans; + const bool use_gemm = larft_use_gemm(n, direct, storev) && n > k; + + const rocblas_int u1_n = use_gemm ? k : n; + const rocblas_int u2_n = use_gemm ? n - k : 0; + + // Compute T=V2'*V2 or V2*V2' (V'=[V1' V2'] where V1 is triangular and V is trapezoidal) + // SYRK/HERK can be used alternatively, but GEMM is currently more performant. 
+ if(use_gemm) + { + if(direct == rocblas_forward_direction && storev == rocblas_column_wise) + { + rocsolver_gemm(handle, rocblas_operation_conjugate_transpose, rocblas_operation_none, k, + k, u2_n, scalars + 2, V, shiftV + idx2D(u1_n, 0, ldv), ldv, strideV, V, + shiftV + idx2D(u1_n, 0, ldv), ldv, strideV, scalars + 1, F, + idx2D(0, 0, ldf), ldf, strideF, batch_count, workArr); + } + else if(direct == rocblas_backward_direction && storev == rocblas_column_wise) + { + rocsolver_gemm(handle, rocblas_operation_conjugate_transpose, rocblas_operation_none, k, + k, u2_n, scalars + 2, V, shiftV + idx2D(0, 0, ldv), ldv, strideV, V, + shiftV + idx2D(0, 0, ldv), ldv, strideV, scalars + 1, F, + idx2D(0, 0, ldf), ldf, strideF, batch_count, workArr); + } + else if(direct == rocblas_forward_direction && storev == rocblas_row_wise) + { + rocsolver_gemm(handle, rocblas_operation_none, rocblas_operation_conjugate_transpose, k, + k, u2_n, scalars + 2, V, shiftV + idx2D(0, u1_n, ldv), ldv, strideV, V, + shiftV + idx2D(0, u1_n, ldv), ldv, strideV, scalars + 1, F, + idx2D(0, 0, ldf), ldf, strideF, batch_count, workArr); + } + else if(direct == rocblas_backward_direction && storev == rocblas_row_wise) + { + rocsolver_gemm(handle, rocblas_operation_none, rocblas_operation_conjugate_transpose, k, + k, u2_n, scalars + 2, V, shiftV + idx2D(0, 0, ldv), ldv, strideV, V, + shiftV + idx2D(0, 0, ldv), ldv, strideV, scalars + 1, F, + idx2D(0, 0, ldf), ldf, strideF, batch_count, workArr); + } + } + // Fix diagonal of T, make zero the not used triangular part, // setup tau (changing signs) and account for the non-stored 1's on the // householder vectors rocblas_int blocks = (k - 1) / 32 + 1; ROCSOLVER_LAUNCH_KERNEL(set_triangular, dim3(blocks, blocks, batch_count), dim3(32, 32), 0, stream, n, k, V, shiftV, ldv, strideV, tau, strideT, F, ldf, strideF, - direct, storev); + direct, storev, use_gemm); ROCSOLVER_LAUNCH_KERNEL(set_tau, dim3(blocks, batch_count), dim3(32, 1), 0, stream, k, tau, strideT); @@ -294,7 +451,7 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle, if(storev == rocblas_column_wise) { trans = rocblas_operation_conjugate_transpose; - rocblasCall_gemv(handle, trans, n - 1 - i, i, tau + i, strideT, V, + rocblasCall_gemv(handle, trans, u1_n - 1 - i, i, tau + i, strideT, V, shiftV + idx2D(i + 1, 0, ldv), ldv, strideV, V, shiftV + idx2D(i + 1, i, ldv), 1, strideV, scalars + 2, 0, F, idx2D(0, i, ldf), 1, strideF, batch_count, workArr); @@ -306,7 +463,7 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle, ldv, strideV, batch_count); trans = rocblas_operation_none; - rocblasCall_gemv(handle, trans, i, n - 1 - i, tau + i, strideT, V, + rocblasCall_gemv(handle, trans, i, u1_n - 1 - i, tau + i, strideT, V, shiftV + idx2D(0, i + 1, ldv), ldv, strideV, V, shiftV + idx2D(i, i + 1, ldv), ldv, strideV, scalars + 2, 0, F, idx2D(0, i, ldf), 1, strideF, batch_count, workArr); @@ -337,9 +494,9 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle, if(storev == rocblas_column_wise) { trans = rocblas_operation_conjugate_transpose; - rocblasCall_gemv(handle, trans, n - k + i, k - i - 1, tau + i, strideT, V, - shiftV + idx2D(0, i + 1, ldv), ldv, strideV, V, - shiftV + idx2D(0, i, ldv), 1, strideV, scalars + 2, 0, F, + rocblasCall_gemv(handle, trans, u1_n - k + i, k - i - 1, tau + i, strideT, V, + shiftV + idx2D(u2_n, i + 1, ldv), ldv, strideV, V, + shiftV + idx2D(u2_n, i, ldv), 1, strideV, scalars + 2, 0, F, idx2D(i + 1, i, ldf), 1, strideF, batch_count, workArr); } else @@ -349,9 +506,9 @@ 
rocblas_status rocsolver_larft_template(rocblas_handle handle, ldv, strideV, batch_count); trans = rocblas_operation_none; - rocblasCall_gemv(handle, trans, k - i - 1, n - k + i, tau + i, strideT, V, - shiftV + idx2D(i + 1, 0, ldv), ldv, strideV, V, - shiftV + idx2D(i, 0, ldv), ldv, strideV, scalars + 2, 0, F, + rocblasCall_gemv(handle, trans, k - i - 1, u1_n - k + i, tau + i, strideT, V, + shiftV + idx2D(i + 1, u2_n, ldv), ldv, strideV, V, + shiftV + idx2D(i, u2_n, ldv), ldv, strideV, scalars + 2, 0, F, idx2D(i + 1, i, ldf), 1, strideF, batch_count, workArr); if(COMPLEX) diff --git a/library/src/include/ideal_sizes.hpp b/library/src/include/ideal_sizes.hpp index 4220f10e0..33a88d06e 100644 --- a/library/src/include/ideal_sizes.hpp +++ b/library/src/include/ideal_sizes.hpp @@ -501,3 +501,24 @@ #ifndef SPLITLU_SWITCH_SIZE #define SPLITLU_SWITCH_SIZE 64 #endif + +/******************************* larft **************************************** +*******************************************************************************/ +#ifndef LARFT_L3_NUM_INTERVALS_DEFAULT +#define LARFT_L3_NUM_INTERVALS_DEFAULT 0 +#endif +#ifndef LARFT_L3_INTERVALS_DEFAULT +#define LARFT_L3_INTERVALS_DEFAULT 0 +#endif +#ifndef LARFT_L3_DEFAULT +#define LARFT_L3_DEFAULT 1 +#endif +#ifndef LARFT_L3_NUM_INTERVALS_C_COL +#define LARFT_L3_NUM_INTERVALS_C_COL 2 +#endif +#ifndef LARFT_L3_INTERVALS_C_COL +#define LARFT_L3_INTERVALS_C_COL 1176, 2144 +#endif +#ifndef LARFT_L3_C_COL +#define LARFT_L3_C_COL 0, 1, 0 +#endif diff --git a/library/src/include/lib_host_helpers.hpp b/library/src/include/lib_host_helpers.hpp index c0fd206be..bca4aae41 100644 --- a/library/src/include/lib_host_helpers.hpp +++ b/library/src/include/lib_host_helpers.hpp @@ -43,12 +43,13 @@ ROCSOLVER_BEGIN_NAMESPACE * =========================================================================== */ -inline int64_t idx2D(const int64_t i, const int64_t j, const int64_t lda) +__device__ __host__ inline int64_t idx2D(const int64_t i, const int64_t j, const int64_t lda) { return j * lda + i; } -inline int64_t idx2D(const int64_t i, const int64_t j, const int64_t inca, const int64_t lda) +__device__ __host__ inline int64_t + idx2D(const int64_t i, const int64_t j, const int64_t inca, const int64_t lda) { return j * lda + i * inca; }
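Reviewer note: the sketch below is only meant to make the accumulation strategy of this patch easier to follow; it is not rocSOLVER code. Following the in-source comment ("Compute T=V2'*V2 ... where V1 is triangular and V is trapezoidal") and the cited Joffrain et al. (2006) paper, the trailing rows of V contribute to T through a single matrix-matrix product S = V2'*V2, and only the small k x k unit-triangular block V1 is handled per column (which is why the GEMV bounds change from n-1-i to u1_n-1-i above). The example is plain, non-batched, real double precision C++; the names larft_forward_columnwise_ref and at() are illustrative only.

#include <cstddef>
#include <vector>

// Column-major indexing helper, mirroring the idx2D() helper that this
// patch marks __device__ __host__.
static inline std::size_t at(std::size_t i, std::size_t j, std::size_t ld)
{
    return j * ld + i;
}

// Builds the k x k upper triangular factor T of a forward, column-wise
// Householder block H = I - V*T*V'. V is n x k unit lower trapezoidal
// (the unit diagonal is implicit; entries above it are ignored).
void larft_forward_columnwise_ref(int n, int k,
                                  const std::vector<double>& V, int ldv,
                                  const std::vector<double>& tau,
                                  std::vector<double>& T, int ldt)
{
    // One matrix-matrix product for the trailing rows: S = V2' * V2 with
    // V2 = V(k:n, 0:k). This is the piece the patch offloads to GEMM.
    std::vector<double> S(std::size_t(k) * k, 0.0);
    for(int j = 0; j < k; ++j)
        for(int i = 0; i < k; ++i)
            for(int l = k; l < n; ++l)
                S[at(i, j, k)] += V[at(l, i, ldv)] * V[at(l, j, ldv)];

    for(int i = 0; i < k; ++i)
    {
        // Column i of -tau(i) * V' * v_i, split into the small unit-triangular
        // block V1 = V(0:k, 0:k) plus the precomputed trailing part S.
        for(int j = 0; j < i; ++j)
        {
            double c = V[at(i, j, ldv)]; // V1(i, j) * 1 from the implicit unit diagonal
            for(int l = i + 1; l < k; ++l)
                c += V[at(l, j, ldv)] * V[at(l, i, ldv)];
            c += S[at(j, i, k)];
            T[at(j, i, ldt)] = -tau[i] * c;
        }

        // T(0:i, i) = T(0:i, 0:i) * T(0:i, i): in-place upper triangular
        // matrix-vector product (the step LAPACK's LARFT performs with TRMV).
        for(int j = 0; j < i; ++j)
        {
            double s = T[at(j, j, ldt)] * T[at(j, i, ldt)];
            for(int l = j + 1; l < i; ++l)
                s += T[at(j, l, ldt)] * T[at(l, i, ldt)];
            T[at(j, i, ldt)] = s;
        }

        T[at(i, i, ldt)] = tau[i];
        for(int j = i + 1; j < k; ++j)
            T[at(j, i, ldt)] = 0.0; // unused lower triangle, zeroed as set_triangular does
    }
}

For k = n the S term is empty and the routine degenerates to the classic column-by-column recurrence, which is consistent with the patch only enabling the GEMM path when n > k.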
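The new LARFT_L3_* macros in ideal_sizes.hpp drive the use_gemm decision through interval tables consumed by larft_use_gemm/get_index. The toy program below shows how the single-precision-complex, column-wise table (intervals 1176 and 2144; values 0, 1, 0) yields the "narrow band" behavior mentioned in the source comment. pick_index is a hypothetical stand-in for rocSOLVER's get_index, and the exact boundary handling (inclusive vs. exclusive at 1176 and 2144) is an assumption.

#include <cstdio>

// Assumed get_index-style lookup: return how many interval bounds dim exceeds.
static int pick_index(const int* intervals, int num, int dim)
{
    int i = 0;
    while(i < num && dim > intervals[i])
        ++i;
    return i;
}

int main()
{
    // Values copied from LARFT_L3_INTERVALS_C_COL and LARFT_L3_C_COL:
    // GEMM is selected only for sizes roughly between 1176 and 2144.
    const int intervals[] = {1176, 2144};
    const bool use_gemm[] = {false, true, false};

    const int sizes[] = {512, 1500, 4096};
    for(int n : sizes)
        std::printf("n = %4d -> use_gemm = %d\n", n, use_gemm[pick_index(intervals, 2, n)]);
    return 0;
}

The default tables (LARFT_L3_NUM_INTERVALS_DEFAULT 0, LARFT_L3_DEFAULT 1) encode the unconditional case: with no interval bounds, every size maps to value 1, i.e. the GEMM path is always preferred for the other type/storage combinations.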