Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use L3 BLAS in LARFT #799

Closed
wants to merge 14 commits into from
196 changes: 177 additions & 19 deletions library/src/auxiliary/rocauxiliary_larft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
* Univ. of Tennessee, Univ. of California Berkeley,
* Univ. of Colorado Denver and NAG Ltd..
* December 2016
* and
* Joffrain, Low, Quintana-Ortí, et al. (2006). Accumulating Householder
* transformations, revisited.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if this line were indented.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, please check if I did it correctly.

* ACM Transactions on Mathematical Software 32(2), p. 169–179.
* Copyright (C) 2019-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -51,7 +55,8 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n,
const rocblas_int ldf,
const rocblas_stride strideF,
const rocblas_direct direct,
const rocblas_storev storev)
const rocblas_storev storev,
const bool inc)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest renaming inc, since when I see it I think of "increment", which doesn't seem to be what it's doing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

{
const auto b = hipBlockIdx_z;
const auto i = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;
Expand All @@ -71,9 +76,29 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n,
if(j < i)
{
if(storev == rocblas_column_wise)
Fp[j + i * ldf] = -tp[i] * Vp[i + j * ldv];
{
if(!inc)
{
Fp[j + i * ldf] = -tp[i] * Vp[i + j * ldv];
}
else
{
Fp[j + i * ldf] *= -tp[i];
Fp[j + i * ldf] += -tp[i] * Vp[i + j * ldv];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You've got two writes to global memory here, as well as two reads from tp[i]. You may want to cache tp[i] and Fp[j + i * ldf] * -tp[i] in separate local variables, which will hopefully speed up the kernel.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to reduce redundant reads/writes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to perform the index calculation as "j + i * int64_t(ldf)" and "i + j * int64_t(ldv)" to avoid 32-bit integer overflow?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed using idx2D

}
}
else
Fp[j + i * ldf] = -tp[i] * Vp[j + i * ldv];
{
if(!inc)
{
Fp[j + i * ldf] = -tp[i] * Vp[j + i * ldv];
}
else
{
Fp[j + i * ldf] *= -tp[i];
Fp[j + i * ldf] += -tp[i] * Vp[j + i * ldv];
}
}
}
else
Fp[j + i * ldf] = 0;
Expand All @@ -83,9 +108,29 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n,
if(j > i)
{
if(storev == rocblas_column_wise)
Fp[j + i * ldf] = -tp[i] * Vp[(n - k + i) + j * ldv];
{
if(!inc)
{
Fp[j + i * ldf] = -tp[i] * Vp[(n - k + i) + j * ldv];
}
else
{
Fp[j + i * ldf] *= -tp[i];
Fp[j + i * ldf] += -tp[i] * Vp[(n - k + i) + j * ldv];
}
}
else
Fp[j + i * ldf] = -tp[i] * Vp[j + (n - k + i) * ldv];
{
if(!inc)
{
Fp[j + i * ldf] = -tp[i] * Vp[j + (n - k + i) * ldv];
}
else
{
Fp[j + i * ldf] *= -tp[i];
Fp[j + i * ldf] += -tp[i] * Vp[j + (n - k + i) * ldv];
}
}
}
else
Fp[j + i * ldf] = 0;
Expand All @@ -106,7 +151,8 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n,
const rocblas_int ldf,
const rocblas_stride strideF,
const rocblas_direct direct,
const rocblas_storev storev)
const rocblas_storev storev,
const bool inc)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

{
const auto b = hipBlockIdx_z;
const auto i = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;
Expand All @@ -126,9 +172,29 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n,
if(j < i)
{
if(storev == rocblas_column_wise)
Fp[j + i * ldf] = -tp[i] * conj(Vp[i + j * ldv]);
{
if(!inc)
{
Fp[j + i * ldf] = -tp[i] * conj(Vp[i + j * ldv]);
}
else
{
Fp[j + i * ldf] *= -tp[i];
Fp[j + i * ldf] += -tp[i] * conj(Vp[i + j * ldv]);
}
}
else
Fp[j + i * ldf] = -tp[i] * Vp[j + i * ldv];
{
if(!inc)
{
Fp[j + i * ldf] = -tp[i] * Vp[j + i * ldv];
}
else
{
Fp[j + i * ldf] *= -tp[i];
Fp[j + i * ldf] += -tp[i] * Vp[j + i * ldv];
}
}
}
else
Fp[j + i * ldf] = 0;
Expand All @@ -138,9 +204,29 @@ ROCSOLVER_KERNEL void set_triangular(const rocblas_int n,
if(j > i)
{
if(storev == rocblas_column_wise)
Fp[j + i * ldf] = -tp[i] * conj(Vp[(n - k + i) + j * ldv]);
{
if(!inc)
{
Fp[j + i * ldf] = -tp[i] * conj(Vp[(n - k + i) + j * ldv]);
}
else
{
Fp[j + i * ldf] *= -tp[i];
Fp[j + i * ldf] += -tp[i] * conj(Vp[(n - k + i) + j * ldv]);
}
}
else
Fp[j + i * ldf] = -tp[i] * Vp[j + (n - k + i) * ldv];
{
if(!inc)
{
Fp[j + i * ldf] = -tp[i] * Vp[j + (n - k + i) * ldv];
}
else
{
Fp[j + i * ldf] *= -tp[i];
Fp[j + i * ldf] += -tp[i] * Vp[j + (n - k + i) * ldv];
}
}
}
else
Fp[j + i * ldf] = 0;
Expand All @@ -161,6 +247,41 @@ ROCSOLVER_KERNEL void set_tau(const rocblas_int k, T* tau, const rocblas_stride
}
}

/** LARFT_DO_L3 (every T except rocblas_float_complex) looks up, for the given
    dimension, the entry of the default decision table. The caller converts
    the result to a flag that enables the gemm-based (level-3 BLAS) path.
    The direct and storev arguments are accepted for a uniform signature but
    are not consulted by this overload. **/
template <typename T, typename I, std::enable_if_t<!std::is_same_v<T, rocblas_float_complex>, int> = 0>
bool larft_do_l3(const I dim, const rocblas_direct direct, const rocblas_storev storev)
{
    // default tuning tables: interval boundaries and one decision per interval
    I num_intervals = LARFT_L3_NUM_INTERVALS_DEFAULT;
    I interval_bounds[] = {LARFT_L3_INTERVALS_DEFAULT};
    I decisions[] = {LARFT_L3_DEFAULT};

    // get_index chooses which table slot applies to dim
    const auto slot = get_index(interval_bounds, num_intervals, dim);
    return decisions[slot];
}

template <typename T, typename I, std::enable_if_t<std::is_same_v<T, rocblas_float_complex>, int> = 0>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A short comment describing why this exception for float complex exists would be appreciated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a comment

bool larft_do_l3(const I dim,
const rocblas_direct direct,
const rocblas_storev storev)
{
// Overload chosen when T is rocblas_float_complex (see the enable_if on the
// template header). It returns the decision-table entry that get_index picks
// for dim; the caller treats a nonzero result as "use the gemm-based path".
// Only storev influences the table selection here; direct is not read.
// NOTE(review): the dedicated column-wise table is presumably the outcome of
// performance tuning for this type -- confirm and record the rationale.
if(storev == rocblas_column_wise)
{
// column-wise storage: dedicated single-precision complex table
I value[] = {LARFT_L3_C_COL};
I intervals[] = {LARFT_L3_INTERVALS_C_COL};
I max = LARFT_L3_NUM_INTERVALS_C_COL;

return value[get_index(intervals, max, dim)];
}
else
{
// any other storage: same default table used by the generic overload
I value[] = {LARFT_L3_DEFAULT};
I intervals[] = {LARFT_L3_INTERVALS_DEFAULT};
I max = LARFT_L3_NUM_INTERVALS_DEFAULT;

return value[get_index(intervals, max, dim)];
}
}

template <bool BATCHED, typename T>
void rocsolver_larft_getMemorySize(const rocblas_int n,
const rocblas_int k,
Expand Down Expand Up @@ -269,13 +390,50 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle,
rocblas_fill uplo;
rocblas_operation trans;

const bool call_l3 = larft_do_l3<T>(n, direct, storev) && n > k;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a fan of these names. Maybe something like use_gemm and larft_use_gemm.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


const rocblas_int u1_n = call_l3 ? k : n;
const rocblas_int u2_n = call_l3 ? n - k : 0;

if(call_l3)
{
if(direct == rocblas_forward_direction && storev == rocblas_column_wise)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A short comment explaining what the gemm is doing would be appreciated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a comment

{
rocblasCall_gemm<T>(handle, rocblas_operation_conjugate_transpose, rocblas_operation_none, k, k, u2_n, scalars + 2,
V, shiftV + idx2D(u1_n, 0, ldv), ldv, strideV,
V, shiftV + idx2D(u1_n, 0, ldv), ldv, strideV,
scalars + 1, F, idx2D(0, 0, ldf), ldf, strideF, batch_count, workArr);
}
else if(direct == rocblas_backward_direction && storev == rocblas_column_wise)
{
rocblasCall_gemm<T>(handle, rocblas_operation_conjugate_transpose, rocblas_operation_none, k, k, u2_n, scalars + 2,
V, shiftV + idx2D(0, 0, ldv), ldv, strideV,
V, shiftV + idx2D(0, 0, ldv), ldv, strideV,
scalars + 1, F, idx2D(0, 0, ldf), ldf, strideF, batch_count, workArr);
}
else if(direct == rocblas_forward_direction && storev == rocblas_row_wise)
{
rocblasCall_gemm<T>(handle, rocblas_operation_none, rocblas_operation_conjugate_transpose, k, k, u2_n, scalars + 2,
V, shiftV + idx2D(0, u1_n, ldv), ldv, strideV,
V, shiftV + idx2D(0, u1_n, ldv), ldv, strideV,
scalars + 1, F, idx2D(0, 0, ldf), ldf, strideF, batch_count, workArr);
}
else if(direct == rocblas_backward_direction && storev == rocblas_row_wise)
{
rocblasCall_gemm<T>(handle, rocblas_operation_none, rocblas_operation_conjugate_transpose, k, k, u2_n, scalars + 2,
V, shiftV + idx2D(0, 0, ldv), ldv, strideV,
V, shiftV + idx2D(0, 0, ldv), ldv, strideV,
scalars + 1, F, idx2D(0, 0, ldf), ldf, strideF, batch_count, workArr);
}
}

// Fix the diagonal of T, zero out the unused triangular part,
// set up tau (changing signs) and account for the non-stored 1's of the
// Householder vectors
rocblas_int blocks = (k - 1) / 32 + 1;
ROCSOLVER_LAUNCH_KERNEL(set_triangular, dim3(blocks, blocks, batch_count), dim3(32, 32), 0,
stream, n, k, V, shiftV, ldv, strideV, tau, strideT, F, ldf, strideF,
direct, storev);
direct, storev, call_l3);
ROCSOLVER_LAUNCH_KERNEL(set_tau, dim3(blocks, batch_count), dim3(32, 1), 0, stream, k, tau,
strideT);

Expand All @@ -294,7 +452,7 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle,
if(storev == rocblas_column_wise)
{
trans = rocblas_operation_conjugate_transpose;
rocblasCall_gemv<T>(handle, trans, n - 1 - i, i, tau + i, strideT, V,
rocblasCall_gemv<T>(handle, trans, u1_n - 1 - i, i, tau + i, strideT, V,
shiftV + idx2D(i + 1, 0, ldv), ldv, strideV, V,
shiftV + idx2D(i + 1, i, ldv), 1, strideV, scalars + 2, 0, F,
idx2D(0, i, ldf), 1, strideF, batch_count, workArr);
Expand All @@ -306,7 +464,7 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle,
ldv, strideV, batch_count);

trans = rocblas_operation_none;
rocblasCall_gemv<T>(handle, trans, i, n - 1 - i, tau + i, strideT, V,
rocblasCall_gemv<T>(handle, trans, i, u1_n - 1 - i, tau + i, strideT, V,
shiftV + idx2D(0, i + 1, ldv), ldv, strideV, V,
shiftV + idx2D(i, i + 1, ldv), ldv, strideV, scalars + 2, 0, F,
idx2D(0, i, ldf), 1, strideF, batch_count, workArr);
Expand Down Expand Up @@ -337,9 +495,9 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle,
if(storev == rocblas_column_wise)
{
trans = rocblas_operation_conjugate_transpose;
rocblasCall_gemv<T>(handle, trans, n - k + i, k - i - 1, tau + i, strideT, V,
shiftV + idx2D(0, i + 1, ldv), ldv, strideV, V,
shiftV + idx2D(0, i, ldv), 1, strideV, scalars + 2, 0, F,
rocblasCall_gemv<T>(handle, trans, u1_n - k + i, k - i - 1, tau + i, strideT, V,
shiftV + idx2D(u2_n, i + 1, ldv), ldv, strideV, V,
shiftV + idx2D(u2_n, i, ldv), 1, strideV, scalars + 2, 0, F,
idx2D(i + 1, i, ldf), 1, strideF, batch_count, workArr);
}
else
Expand All @@ -349,9 +507,9 @@ rocblas_status rocsolver_larft_template(rocblas_handle handle,
ldv, strideV, batch_count);

trans = rocblas_operation_none;
rocblasCall_gemv<T>(handle, trans, k - i - 1, n - k + i, tau + i, strideT, V,
shiftV + idx2D(i + 1, 0, ldv), ldv, strideV, V,
shiftV + idx2D(i, 0, ldv), ldv, strideV, scalars + 2, 0, F,
rocblasCall_gemv<T>(handle, trans, k - i - 1, u1_n - k + i, tau + i, strideT, V,
shiftV + idx2D(i + 1, u2_n, ldv), ldv, strideV, V,
shiftV + idx2D(i, u2_n, ldv), ldv, strideV, scalars + 2, 0, F,
idx2D(i + 1, i, ldf), 1, strideF, batch_count, workArr);

if(COMPLEX)
Expand Down
21 changes: 21 additions & 0 deletions library/src/include/ideal_sizes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,24 @@
#ifndef SPLITLU_SWITCH_SIZE
#define SPLITLU_SWITCH_SIZE 64
#endif

/******************************* larft ****************************************
*******************************************************************************/
/* Tuning tables consumed by larft_do_l3 to decide whether LARFT takes its
   gemm-based (level-3 BLAS) path. Each table is a triple:
     *_NUM_INTERVALS_* : count passed to get_index together with the bounds,
     *_INTERVALS_*     : comma-separated initializer list of interval bounds,
     LARFT_L3_*        : comma-separated initializer list of decisions; the
                         selected entry is returned as a bool (nonzero = use gemm).
   NOTE(review): presumably get_index returns the position of the interval that
   the dimension falls into, so a table with N bounds needs N + 1 decisions --
   confirm against get_index. */

/* Default table: no interval bounds are counted and the single decision is 1. */
#ifndef LARFT_L3_NUM_INTERVALS_DEFAULT
#define LARFT_L3_NUM_INTERVALS_DEFAULT 0
#endif
#ifndef LARFT_L3_INTERVALS_DEFAULT
#define LARFT_L3_INTERVALS_DEFAULT 0
#endif
#ifndef LARFT_L3_DEFAULT
#define LARFT_L3_DEFAULT 1
#endif

/* Table used for rocblas_float_complex with column-wise storage: two bounds
   (1176, 2144) and three decisions (0, 1, 0).
   NOTE(review): this looks like "use gemm only between the two bounds" --
   verify the boundary handling in get_index. */
#ifndef LARFT_L3_NUM_INTERVALS_C_COL
#define LARFT_L3_NUM_INTERVALS_C_COL 2
#endif
#ifndef LARFT_L3_INTERVALS_C_COL
#define LARFT_L3_INTERVALS_C_COL 1176, 2144
#endif
#ifndef LARFT_L3_C_COL
#define LARFT_L3_C_COL 0, 1, 0
#endif