-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use L3 BLAS in LARFT #799
Use L3 BLAS in LARFT #799
Changes from 6 commits
ec8d369
a41fc7f
5551835
abf57d4
3bf4974
dc49736
c421dde
c20db67
01dc39e
e7555a0
b60c463
3c1cdb3
4d98734
7fa5b6b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,10 @@ | |
* Univ. of Tennessee, Univ. of California Berkeley, | ||
* Univ. of Colorado Denver and NAG Ltd.. | ||
* December 2016 | ||
* and | ||
* Joffrain, Low, Quintana-Ortí, et al. (2006). Accumulating householder | ||
* transformations, revisited. | ||
* ACM Transactions on Mathematical Software 32(2), p. 169–179. | ||
* Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved. | ||
* | ||
* Redistribution and use in source and binary forms, with or without | ||
|
@@ -51,7 +55,8 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n, | |
const rocblas_int ldf, | ||
const rocblas_stride strideF, | ||
const rocblas_direct direct, | ||
const rocblas_storev storev) | ||
const rocblas_storev storev, | ||
const bool inc) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would suggest renaming There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
{ | ||
const auto b = hipBlockIdx_z; | ||
const auto i = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; | ||
|
@@ -71,9 +76,29 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n, | |
if(j < i) | ||
{ | ||
if(storev == rocblas_column_wise) | ||
Fp[j + i * ldf] = -tp[i] * Vp[i + j * ldv]; | ||
{ | ||
if(!inc) | ||
{ | ||
Fp[j + i * ldf] = -tp[i] * Vp[i + j * ldv]; | ||
} | ||
else | ||
{ | ||
Fp[j + i * ldf] *= -tp[i]; | ||
Fp[j + i * ldf] += -tp[i] * Vp[i + j * ldv]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You've got two writes to global memory here, as well as two reads from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. changed to reduce redundant reads/writes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to perform the index calculation as "j + i * int64_t(ldf)" and "i + j * int64_t(ldv)" to avoid 32bit integer overflow? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed using |
||
} | ||
} | ||
else | ||
Fp[j + i * ldf] = -tp[i] * Vp[j + i * ldv]; | ||
{ | ||
if(!inc) | ||
{ | ||
Fp[j + i * ldf] = -tp[i] * Vp[j + i * ldv]; | ||
} | ||
else | ||
{ | ||
Fp[j + i * ldf] *= -tp[i]; | ||
Fp[j + i * ldf] += -tp[i] * Vp[j + i * ldv]; | ||
} | ||
} | ||
} | ||
else | ||
Fp[j + i * ldf] = 0; | ||
|
@@ -83,9 +108,29 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n, | |
if(j > i) | ||
{ | ||
if(storev == rocblas_column_wise) | ||
Fp[j + i * ldf] = -tp[i] * Vp[(n - k + i) + j * ldv]; | ||
{ | ||
if(!inc) | ||
{ | ||
Fp[j + i * ldf] = -tp[i] * Vp[(n - k + i) + j * ldv]; | ||
} | ||
else | ||
{ | ||
Fp[j + i * ldf] *= -tp[i]; | ||
Fp[j + i * ldf] += -tp[i] * Vp[(n - k + i) + j * ldv]; | ||
} | ||
} | ||
else | ||
Fp[j + i * ldf] = -tp[i] * Vp[j + (n - k + i) * ldv]; | ||
{ | ||
if(!inc) | ||
{ | ||
Fp[j + i * ldf] = -tp[i] * Vp[j + (n - k + i) * ldv]; | ||
} | ||
else | ||
{ | ||
Fp[j + i * ldf] *= -tp[i]; | ||
Fp[j + i * ldf] += -tp[i] * Vp[j + (n - k + i) * ldv]; | ||
} | ||
} | ||
} | ||
else | ||
Fp[j + i * ldf] = 0; | ||
|
@@ -106,7 +151,8 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n, | |
const rocblas_int ldf, | ||
const rocblas_stride strideF, | ||
const rocblas_direct direct, | ||
const rocblas_storev storev) | ||
const rocblas_storev storev, | ||
const bool inc) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
{ | ||
const auto b = hipBlockIdx_z; | ||
const auto i = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x; | ||
|
@@ -126,9 +172,29 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n, | |
if(j < i) | ||
{ | ||
if(storev == rocblas_column_wise) | ||
Fp[j + i * ldf] = -tp[i] * conj(Vp[i + j * ldv]); | ||
{ | ||
if(!inc) | ||
{ | ||
Fp[j + i * ldf] = -tp[i] * conj(Vp[i + j * ldv]); | ||
} | ||
else | ||
{ | ||
Fp[j + i * ldf] *= -tp[i]; | ||
Fp[j + i * ldf] += -tp[i] * conj(Vp[i + j * ldv]); | ||
} | ||
} | ||
else | ||
Fp[j + i * ldf] = -tp[i] * Vp[j + i * ldv]; | ||
{ | ||
if(!inc) | ||
{ | ||
Fp[j + i * ldf] = -tp[i] * Vp[j + i * ldv]; | ||
} | ||
else | ||
{ | ||
Fp[j + i * ldf] *= -tp[i]; | ||
Fp[j + i * ldf] += -tp[i] * Vp[j + i * ldv]; | ||
} | ||
} | ||
} | ||
else | ||
Fp[j + i * ldf] = 0; | ||
|
@@ -138,9 +204,29 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n, | |
if(j > i) | ||
{ | ||
if(storev == rocblas_column_wise) | ||
Fp[j + i * ldf] = -tp[i] * conj(Vp[(n - k + i) + j * ldv]); | ||
{ | ||
if(!inc) | ||
{ | ||
Fp[j + i * ldf] = -tp[i] * conj(Vp[(n - k + i) + j * ldv]); | ||
} | ||
else | ||
{ | ||
Fp[j + i * ldf] *= -tp[i]; | ||
Fp[j + i * ldf] += -tp[i] * conj(Vp[(n - k + i) + j * ldv]); | ||
} | ||
} | ||
else | ||
Fp[j + i * ldf] = -tp[i] * Vp[j + (n - k + i) * ldv]; | ||
{ | ||
if(!inc) | ||
{ | ||
Fp[j + i * ldf] = -tp[i] * Vp[j + (n - k + i) * ldv]; | ||
} | ||
else | ||
{ | ||
Fp[j + i * ldf] *= -tp[i]; | ||
Fp[j + i * ldf] += -tp[i] * Vp[j + (n - k + i) * ldv]; | ||
} | ||
} | ||
} | ||
else | ||
Fp[j + i * ldf] = 0; | ||
|
@@ -161,6 +247,41 @@ ROCSOLVER_KERNEL void set_tau(const rocblas_int k, T* tau, const rocblas_stride | |
} | ||
} | ||
|
||
template <typename T, typename I, std::enable_if_t<!std::is_same_v<T, rocblas_float_complex>, int> = 0> | ||
bool larft_do_l3(const I dim, | ||
const rocblas_direct direct, | ||
const rocblas_storev storev) | ||
{ | ||
I value[] = {LARFT_L3_DEFAULT}; | ||
I intervals[] = {LARFT_L3_INTERVALS_DEFAULT}; | ||
I max = LARFT_L3_NUM_INTERVALS_DEFAULT; | ||
|
||
return value[get_index(intervals, max, dim)]; | ||
} | ||
|
||
template <typename T, typename I, std::enable_if_t<std::is_same_v<T, rocblas_float_complex>, int> = 0> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A short comment describing why this exception for float complex exists would be appreciated. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added a comment |
||
bool larft_do_l3(const I dim, | ||
const rocblas_direct direct, | ||
const rocblas_storev storev) | ||
{ | ||
if(storev == rocblas_column_wise) | ||
{ | ||
I value[] = {LARFT_L3_C_COL}; | ||
I intervals[] = {LARFT_L3_INTERVALS_C_COL}; | ||
I max = LARFT_L3_NUM_INTERVALS_C_COL; | ||
|
||
return value[get_index(intervals, max, dim)]; | ||
} | ||
else | ||
{ | ||
I value[] = {LARFT_L3_DEFAULT}; | ||
I intervals[] = {LARFT_L3_INTERVALS_DEFAULT}; | ||
I max = LARFT_L3_NUM_INTERVALS_DEFAULT; | ||
|
||
return value[get_index(intervals, max, dim)]; | ||
} | ||
} | ||
|
||
template <bool BATCHED, typename T> | ||
void rocsolver_larft_getMemorySize(const rocblas_int n, | ||
const rocblas_int k, | ||
|
@@ -269,13 +390,50 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle, | |
rocblas_fill uplo; | ||
rocblas_operation trans; | ||
|
||
const bool call_l3 = larft_do_l3<T>(n, direct, storev) && n > k; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not a fan of these names. Maybe something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
const rocblas_int u1_n = call_l3 ? k : n; | ||
const rocblas_int u2_n = call_l3 ? n - k : 0; | ||
|
||
if(call_l3) | ||
{ | ||
if(direct == rocblas_forward_direction && storev == rocblas_column_wise) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A short comment explaining what the gemm is doing would be appreciated. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added a comment |
||
{ | ||
rocblasCall_gemm<T>(handle, rocblas_operation_conjugate_transpose, rocblas_operation_none, k, k, u2_n, scalars + 2, | ||
V, shiftV + idx2D(u1_n, 0, ldv), ldv, strideV, | ||
V, shiftV + idx2D(u1_n, 0, ldv), ldv, strideV, | ||
scalars + 1, F, idx2D(0, 0, ldf), ldf, strideF, batch_count, workArr); | ||
} | ||
else if(direct == rocblas_backward_direction && storev == rocblas_column_wise) | ||
{ | ||
rocblasCall_gemm<T>(handle, rocblas_operation_conjugate_transpose, rocblas_operation_none, k, k, u2_n, scalars + 2, | ||
V, shiftV + idx2D(0, 0, ldv), ldv, strideV, | ||
V, shiftV + idx2D(0, 0, ldv), ldv, strideV, | ||
scalars + 1, F, idx2D(0, 0, ldf), ldf, strideF, batch_count, workArr); | ||
} | ||
else if(direct == rocblas_forward_direction && storev == rocblas_row_wise) | ||
{ | ||
rocblasCall_gemm<T>(handle, rocblas_operation_none, rocblas_operation_conjugate_transpose, k, k, u2_n, scalars + 2, | ||
V, shiftV + idx2D(0, u1_n, ldv), ldv, strideV, | ||
V, shiftV + idx2D(0, u1_n, ldv), ldv, strideV, | ||
scalars + 1, F, idx2D(0, 0, ldf), ldf, strideF, batch_count, workArr); | ||
} | ||
else if(direct == rocblas_backward_direction && storev == rocblas_row_wise) | ||
{ | ||
rocblasCall_gemm<T>(handle, rocblas_operation_none, rocblas_operation_conjugate_transpose, k, k, u2_n, scalars + 2, | ||
V, shiftV + idx2D(0, 0, ldv), ldv, strideV, | ||
V, shiftV + idx2D(0, 0, ldv), ldv, strideV, | ||
scalars + 1, F, idx2D(0, 0, ldf), ldf, strideF, batch_count, workArr); | ||
} | ||
} | ||
|
||
// Fix diagonal of T, make zero the not used triangular part, | ||
// setup tau (changing signs) and account for the non-stored 1's on the | ||
// householder vectors | ||
rocblas_int blocks = (k - 1) / 32 + 1; | ||
ROCSOLVER_LAUNCH_KERNEL(set_triangular, dim3(blocks, blocks, batch_count), dim3(32, 32), 0, | ||
stream, n, k, V, shiftV, ldv, strideV, tau, strideT, F, ldf, strideF, | ||
direct, storev); | ||
direct, storev, call_l3); | ||
ROCSOLVER_LAUNCH_KERNEL(set_tau, dim3(blocks, batch_count), dim3(32, 1), 0, stream, k, tau, | ||
strideT); | ||
|
||
|
@@ -294,7 +452,7 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle, | |
if(storev == rocblas_column_wise) | ||
{ | ||
trans = rocblas_operation_conjugate_transpose; | ||
rocblasCall_gemv<T>(handle, trans, n - 1 - i, i, tau + i, strideT, V, | ||
rocblasCall_gemv<T>(handle, trans, u1_n - 1 - i, i, tau + i, strideT, V, | ||
shiftV + idx2D(i + 1, 0, ldv), ldv, strideV, V, | ||
shiftV + idx2D(i + 1, i, ldv), 1, strideV, scalars + 2, 0, F, | ||
idx2D(0, i, ldf), 1, strideF, batch_count, workArr); | ||
|
@@ -306,7 +464,7 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle, | |
ldv, strideV, batch_count); | ||
|
||
trans = rocblas_operation_none; | ||
rocblasCall_gemv<T>(handle, trans, i, n - 1 - i, tau + i, strideT, V, | ||
rocblasCall_gemv<T>(handle, trans, i, u1_n - 1 - i, tau + i, strideT, V, | ||
shiftV + idx2D(0, i + 1, ldv), ldv, strideV, V, | ||
shiftV + idx2D(i, i + 1, ldv), ldv, strideV, scalars + 2, 0, F, | ||
idx2D(0, i, ldf), 1, strideF, batch_count, workArr); | ||
|
@@ -337,9 +495,9 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle, | |
if(storev == rocblas_column_wise) | ||
{ | ||
trans = rocblas_operation_conjugate_transpose; | ||
rocblasCall_gemv<T>(handle, trans, n - k + i, k - i - 1, tau + i, strideT, V, | ||
shiftV + idx2D(0, i + 1, ldv), ldv, strideV, V, | ||
shiftV + idx2D(0, i, ldv), 1, strideV, scalars + 2, 0, F, | ||
rocblasCall_gemv<T>(handle, trans, u1_n - k + i, k - i - 1, tau + i, strideT, V, | ||
shiftV + idx2D(u2_n, i + 1, ldv), ldv, strideV, V, | ||
shiftV + idx2D(u2_n, i, ldv), 1, strideV, scalars + 2, 0, F, | ||
idx2D(i + 1, i, ldf), 1, strideF, batch_count, workArr); | ||
} | ||
else | ||
|
@@ -349,9 +507,9 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle, | |
ldv, strideV, batch_count); | ||
|
||
trans = rocblas_operation_none; | ||
rocblasCall_gemv<T>(handle, trans, k - i - 1, n - k + i, tau + i, strideT, V, | ||
shiftV + idx2D(i + 1, 0, ldv), ldv, strideV, V, | ||
shiftV + idx2D(i, 0, ldv), ldv, strideV, scalars + 2, 0, F, | ||
rocblasCall_gemv<T>(handle, trans, k - i - 1, u1_n - k + i, tau + i, strideT, V, | ||
shiftV + idx2D(i + 1, u2_n, ldv), ldv, strideV, V, | ||
shiftV + idx2D(i, u2_n, ldv), ldv, strideV, scalars + 2, 0, F, | ||
idx2D(i + 1, i, ldf), 1, strideF, batch_count, workArr); | ||
|
||
if(COMPLEX) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice if this line were indented.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, please check if I did it correctly.