Hotfix: adopt hipDataType and hipblasComputeType_t (#462)
* Adopt hipDataType and deprecate hipblasltDatatype_t

* replace hipblasLtComputeType_t with hipblasComputeType_t

* update changelog

---------

Co-authored-by: Jeff Daily <[email protected]>
jichangjichang and jeffdaily authored Nov 30, 2023
1 parent a9c5cc7 commit 592518e
Showing 54 changed files with 1,496 additions and 1,502 deletions.
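
Concretely, every hipblasltDatatype_t value (HIPBLASLT_R_32F, HIPBLASLT_R_16F, ...) becomes the corresponding hipDataType value (HIP_R_32F, HIP_R_16F, ...), hipblasLtComputeType_t values such as HIPBLASLT_COMPUTE_F32 become hipblasComputeType_t values such as HIPBLAS_COMPUTE_32F, and the client's fp8 helper types gain a _fnuz suffix (hipblaslt_f8 -> hipblaslt_f8_fnuz). A minimal before/after sketch of caller code, assuming the post-change hipblasLtMatmulDescCreate(descriptor, compute type, scale type) signature:

#include <hipblaslt/hipblaslt.h>

hipblasLtMatmulDesc_t make_f32_desc()
{
    hipblasLtMatmulDesc_t desc = nullptr;
    // Before this commit:
    //   hipblasLtMatmulDescCreate(&desc, HIPBLASLT_COMPUTE_F32, HIPBLASLT_R_32F);
    // After it, the compute type comes from hipBLAS and the scale type from HIP:
    hipblasLtMatmulDescCreate(&desc, HIPBLAS_COMPUTE_32F, HIP_R_32F);
    return desc;
}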
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -10,7 +10,8 @@
- Support fp8/bf8 datatype (only for gfx94x platform)
- Support Scalar A,B,C,D for fp8/bf8 datatype
### Changed
-- Replace hipblasDatatype_t with hipblasltDatatype_t
+- Replace hipblasDatatype_t with hipDataType
+- Replace hipblasLtComputeType_t with hipblasComputeType_t
- Deprecate HIPBLASLT_MATMUL_DESC_D_SCALE_VECTOR_POINTER

## (Unreleased) hipBLASLt 0.3.0
62 changes: 31 additions & 31 deletions clients/benchmarks/client.cpp
@@ -68,8 +68,7 @@ void run_function(const func_map& map, const Arguments& arg, const std::string&
auto match = map.find(arg.function);
if(match == map.end())
throw std::invalid_argument("Invalid combination --function "s + arg.function
+ " --a_type "s + hipblaslt_datatype_to_string(arg.a_type)
+ msg);
+ " --a_type "s + hip_datatype_to_string(arg.a_type) + msg);
match->second(arg);
}

@@ -92,16 +92,17 @@ struct perf_matmul<
To,
Tc,
Tci,
-    std::enable_if_t<(std::is_same<TiA, hipblasLtHalf>{} && std::is_same<TiB, hipblasLtHalf>{})
-                     || (std::is_same<TiA, hip_bfloat16>{} && std::is_same<TiB, hip_bfloat16>{})
-                     || (std::is_same<TiA, float>{} && std::is_same<TiB, float>{})
-                     || (std::is_same<TiA, hipblaslt_f8>{} && std::is_same<TiB, hipblaslt_f8>{})
-                     || (std::is_same<TiA, hipblaslt_f8>{} && std::is_same<TiB, hipblaslt_bf8>{})
-                     || (std::is_same<TiA, hipblaslt_bf8>{} && std::is_same<TiB, hipblaslt_f8>{})
-                     || (std::is_same<TiA, double>{} && std::is_same<TiB, double>{})
-                     || (std::is_same<TiA, hipblasLtInt8>{} && std::is_same<TiB, hipblasLtInt8>{})
-                     || (std::is_same<TiA, hipblaslt_f8>{} && std::is_same<TiB, hipblasLtHalf>{})
-                     || (std::is_same<TiA, hipblasLtHalf>{} && std::is_same<TiB, hipblaslt_f8>{})>>
+    std::enable_if_t<
+        (std::is_same<TiA, hipblasLtHalf>{} && std::is_same<TiB, hipblasLtHalf>{})
+        || (std::is_same<TiA, hip_bfloat16>{} && std::is_same<TiB, hip_bfloat16>{})
+        || (std::is_same<TiA, float>{} && std::is_same<TiB, float>{})
+        || (std::is_same<TiA, hipblaslt_f8_fnuz>{} && std::is_same<TiB, hipblaslt_f8_fnuz>{})
+        || (std::is_same<TiA, hipblaslt_f8_fnuz>{} && std::is_same<TiB, hipblaslt_bf8_fnuz>{})
+        || (std::is_same<TiA, hipblaslt_bf8_fnuz>{} && std::is_same<TiB, hipblaslt_f8_fnuz>{})
+        || (std::is_same<TiA, double>{} && std::is_same<TiB, double>{})
+        || (std::is_same<TiA, hipblasLtInt8>{} && std::is_same<TiB, hipblasLtInt8>{})
+        || (std::is_same<TiA, hipblaslt_f8_fnuz>{} && std::is_same<TiB, hipblasLtHalf>{})
+        || (std::is_same<TiA, hipblasLtHalf>{} && std::is_same<TiB, hipblaslt_f8_fnuz>{})>>
: hipblaslt_test_valid
{
void operator()(const Arguments& arg)
@@ -247,7 +247,7 @@ void fix_batch(int argc, char* argv[])
void hipblaslt_print_version(void)
{
int version;
-    char git_version[128];
+    char git_version[128];
hipblaslt_local_handle handle;
hipblasLtGetVersion(handle, &version);
hipblasLtGetGitRevision(handle, &git_version[0]);
@@ -606,38 +606,38 @@ try
}

std::transform(precision.begin(), precision.end(), precision.begin(), ::tolower);
-    auto prec = string_to_hipblaslt_datatype(precision);
-    if(prec == static_cast<hipblasltDatatype_t>(0))
+    auto prec = string_to_hip_datatype(precision);
+    if(prec == HIPBLASLT_DATATYPE_INVALID)
throw std::invalid_argument("Invalid value for --precision " + precision);

-    arg.a_type = a_type == "" ? prec : string_to_hipblaslt_datatype(a_type);
-    if(arg.a_type == static_cast<hipblasltDatatype_t>(0))
+    arg.a_type = a_type == "" ? prec : string_to_hip_datatype(a_type);
+    if(arg.a_type == HIPBLASLT_DATATYPE_INVALID)
throw std::invalid_argument("Invalid value for --a_type " + a_type);

-    arg.b_type = b_type == "" ? prec : string_to_hipblaslt_datatype(b_type);
-    if(arg.b_type == static_cast<hipblasltDatatype_t>(0))
+    arg.b_type = b_type == "" ? prec : string_to_hip_datatype(b_type);
+    if(arg.b_type == HIPBLASLT_DATATYPE_INVALID)
throw std::invalid_argument("Invalid value for --b_type " + b_type);

-    arg.c_type = c_type == "" ? prec : string_to_hipblaslt_datatype(c_type);
-    if(arg.c_type == static_cast<hipblasltDatatype_t>(0))
+    arg.c_type = c_type == "" ? prec : string_to_hip_datatype(c_type);
+    if(arg.c_type == HIPBLASLT_DATATYPE_INVALID)
throw std::invalid_argument("Invalid value for --c_type " + c_type);

-    arg.d_type = d_type == "" ? prec : string_to_hipblaslt_datatype(d_type);
-    if(arg.d_type == static_cast<hipblasltDatatype_t>(0))
+    arg.d_type = d_type == "" ? prec : string_to_hip_datatype(d_type);
+    if(arg.d_type == HIPBLASLT_DATATYPE_INVALID)
throw std::invalid_argument("Invalid value for --d_type " + d_type);

-    bool is_f16 = arg.a_type == HIPBLASLT_R_16F || arg.a_type == HIPBLASLT_R_16B;
-    bool is_f32 = arg.a_type == HIPBLASLT_R_32F;
-    arg.compute_type = compute_type == "" ? (HIPBLASLT_COMPUTE_F32)
-                                          : string_to_hipblaslt_computetype(compute_type);
-    if(arg.compute_type == static_cast<hipblasLtComputeType_t>(0))
+    bool is_f16 = arg.a_type == HIP_R_16F || arg.a_type == HIP_R_16BF;
+    bool is_f32 = arg.a_type == HIP_R_32F;
+    arg.compute_type
+        = compute_type == "" ? (HIPBLAS_COMPUTE_32F) : string_to_hipblas_computetype(compute_type);
+    if(arg.compute_type == static_cast<hipblasComputeType_t>(0))
throw std::invalid_argument("Invalid value for --compute_type " + compute_type);

-    if(string_to_hipblaslt_datatype(bias_type) == static_cast<hipblasltDatatype_t>(0)
-       && bias_type != "" && bias_type != "default")
+    if(string_to_hip_datatype(bias_type) == HIPBLASLT_DATATYPE_INVALID && bias_type != ""
+       && bias_type != "default")
throw std::invalid_argument("Invalid value for --bias_type " + bias_type);
else
-        arg.bias_type = string_to_hipblaslt_datatype(bias_type);
+        arg.bias_type = string_to_hip_datatype(bias_type);

arg.initialization = string2hipblaslt_initialization(initialization);
if(arg.initialization == static_cast<hipblaslt_initialization>(0))
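
The client now rejects unknown type strings with the named sentinel HIPBLASLT_DATATYPE_INVALID instead of the old zero-cast static_cast<hipblasltDatatype_t>(0). A minimal sketch of what string_to_hip_datatype could look like; the function name and sentinel come from the diff, but the exact string spellings here are assumptions:

#include <hipblaslt/hipblaslt.h> // hipDataType and HIPBLASLT_DATATYPE_INVALID
#include <string>

hipDataType string_to_hip_datatype(const std::string& s)
{
    if(s == "f32_r")  return HIP_R_32F;  // string spellings assumed, not from the diff
    if(s == "f16_r")  return HIP_R_16F;
    if(s == "bf16_r") return HIP_R_16BF;
    if(s == "i8_r")   return HIP_R_8I;
    return HIPBLASLT_DATATYPE_INVALID;   // callers above treat this as a parse error
}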
15 changes: 3 additions & 12 deletions clients/benchmarks/client_extop_layernorm.cpp
@@ -233,17 +233,8 @@ int main(int argc, char** argv)
hipStream_t stream{};
hipErr = hipStreamCreate(&stream);
//warmup
-    auto hipblasltErr = hipblasltExtLayerNorm(HIPBLASLT_R_32F,
-                                              gpuOutput,
-                                              gpuMean,
-                                              gpuInvvar,
-                                              gpuInput,
-                                              m,
-                                              n,
-                                              1e-05,
-                                              gpuGamma,
-                                              gpuBeta,
-                                              stream);
+    auto hipblasltErr = hipblasltExtLayerNorm(
+        HIP_R_32F, gpuOutput, gpuMean, gpuInvvar, gpuInput, m, n, 1e-05, gpuGamma, gpuBeta, stream);

hipErr = hipMemcpyDtoH(cpuOutput.data(), gpuOutput, numElements * elementNumBytes);
hipErr = hipMemcpyDtoH(cpuMean.data(), gpuMean, m * elementNumBytes);
@@ -274,7 +265,7 @@

for(int i = 0; i < numRuns; ++i)
{
-        hipblasltErr = hipblasltExtLayerNorm(HIPBLASLT_R_32F,
+        hipblasltErr = hipblasltExtLayerNorm(HIP_R_32F,
gpuOutput,
gpuMean,
gpuInvvar,
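
Beyond the enum spelling, the layer-norm entry point is unchanged. A usage sketch, assuming the argument order visible above (datatype, output, mean, invvar, input, m, n, epsilon, gamma, beta, stream) and that the Ext helpers live in hipblaslt-ext-op.h:

#include <cstdint>
#include <hip/hip_runtime.h>
#include <hipblaslt/hipblaslt-ext-op.h> // assumed location of hipblasltExtLayerNorm

void layernorm_f32(float* out, float* mean, float* invvar, float* in,
                   float* gamma, float* beta, uint32_t m, uint32_t n,
                   hipStream_t stream)
{
    // Same call as the benchmark warmup above, with the renamed HIP_R_32F enum.
    hipblasltExtLayerNorm(HIP_R_32F, out, mean, invvar, in, m, n, 1e-05f,
                          gamma, beta, stream);
}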
105 changes: 52 additions & 53 deletions clients/benchmarks/client_extop_matrixtransform.cpp
@@ -96,73 +96,72 @@ struct TypedMatrixTransformIO : public MatrixTransformIO
};

using MatrixTransformIOPtr = std::unique_ptr<MatrixTransformIO>;
-MatrixTransformIOPtr
-    makeMatrixTransformIOPtr(hipblasltDatatype_t datatype, int64_t m, int64_t n, int64_t b)
+MatrixTransformIOPtr makeMatrixTransformIOPtr(hipDataType datatype, int64_t m, int64_t n, int64_t b)
{
-    if(datatype == HIPBLASLT_R_32F)
+    if(datatype == HIP_R_32F)
{
return std::make_unique<TypedMatrixTransformIO<hipblasLtFloat>>(m, n, b);
}
-    else if(datatype == HIPBLASLT_R_16F)
+    else if(datatype == HIP_R_16F)
{
return std::make_unique<TypedMatrixTransformIO<hipblasLtHalf>>(m, n, b);
}
-    else if(datatype == HIPBLASLT_R_16B)
+    else if(datatype == HIP_R_16BF)
{
return std::make_unique<TypedMatrixTransformIO<hipblasLtBfloat16>>(m, n, b);
}
-    else if(datatype == HIPBLASLT_R_8I)
+    else if(datatype == HIP_R_8I)
{
return std::make_unique<TypedMatrixTransformIO<int8_t>>(m, n, b);
}
return nullptr;
}

-hipblasltDatatype_t str2Datatype(const std::string& typeStr)
+hipDataType str2Datatype(const std::string& typeStr)
{
if(typeStr == "fp32")
{
-        return HIPBLASLT_R_32F;
+        return HIP_R_32F;
}
else if(typeStr == "fp16")
{
-        return HIPBLASLT_R_16F;
+        return HIP_R_16F;
}
else if(typeStr == "bf16")
{
-        return HIPBLASLT_R_16B;
+        return HIP_R_16BF;
}
else if(typeStr == "i8")
{
-        return HIPBLASLT_R_8I;
+        return HIP_R_8I;
}
else if(typeStr == "i32")
{
-        return HIPBLASLT_R_32I;
+        return HIP_R_32I;
}

return HIPBLASLT_DATATYPE_INVALID;
}

-static int parseArguments(int argc,
-                          char* argv[],
-                          hipblasltDatatype_t& datatype,
-                          hipblasltDatatype_t& scaleDatatype,
-                          int64_t& m,
-                          int64_t& n,
-                          float& alpha,
-                          float& beta,
-                          bool& transA,
-                          bool& transB,
-                          uint32_t& ldA,
-                          uint32_t& ldB,
-                          uint32_t& ldC,
-                          bool& rowMajA,
-                          bool& rowMajB,
-                          bool& rowMajC,
-                          int32_t& batchSize,
-                          int64_t& batchStride,
-                          bool& runValidation)
+static int parseArguments(int argc,
+                          char* argv[],
+                          hipDataType& datatype,
+                          hipDataType& scaleDatatype,
+                          int64_t& m,
+                          int64_t& n,
+                          float& alpha,
+                          float& beta,
+                          bool& transA,
+                          bool& transB,
+                          uint32_t& ldA,
+                          uint32_t& ldB,
+                          uint32_t& ldC,
+                          bool& rowMajA,
+                          bool& rowMajB,
+                          bool& rowMajC,
+                          int32_t& batchSize,
+                          int64_t& batchStride,
+                          bool& runValidation)
{
if(argc >= 2)
{
@@ -495,23 +494,23 @@ void validation(void* c,

int main(int argc, char** argv)
{
-    int64_t m = 2048;
-    int64_t n = 2048;
-    int32_t batchSize = 1;
-    float alpha = 1;
-    float beta = 1;
-    auto transA = false;
-    auto transB = false;
-    auto rowMajA = false;
-    auto rowMajB = false;
-    auto rowMajC = false;
-    int64_t batchStride{};
-    uint32_t ldA{};
-    uint32_t ldB{};
-    uint32_t ldC{};
-    bool runValidation{};
-    hipblasltDatatype_t datatype{HIPBLASLT_R_32F};
-    hipblasltDatatype_t scaleDatatype{HIPBLASLT_R_32F};
+    int64_t m = 2048;
+    int64_t n = 2048;
+    int32_t batchSize = 1;
+    float alpha = 1;
+    float beta = 1;
+    auto transA = false;
+    auto transB = false;
+    auto rowMajA = false;
+    auto rowMajB = false;
+    auto rowMajC = false;
+    int64_t batchStride{};
+    uint32_t ldA{};
+    uint32_t ldB{};
+    uint32_t ldC{};
+    bool runValidation{};
+    hipDataType datatype{HIP_R_32F};
+    hipDataType scaleDatatype{HIP_R_32F};
parseArguments(argc,
argv,
datatype,
@@ -671,7 +670,7 @@ int main(int argc, char** argv)

if(runValidation)
{
-        if(datatype == HIPBLASLT_R_32F)
+        if(datatype == HIP_R_32F)
{
validation<float>(dC,
dA,
@@ -691,7 +690,7 @@ int main(int argc, char** argv)
transA,
transB);
}
-        else if(datatype == HIPBLASLT_R_16F)
+        else if(datatype == HIP_R_16F)
{
validation<hipblasLtHalf>(dC,
dA,
@@ -711,7 +710,7 @@ int main(int argc, char** argv)
transA,
transB);
}
-        else if(datatype == HIPBLASLT_R_16B)
+        else if(datatype == HIP_R_16BF)
{
validation<hipblasLtBfloat16>(dC,
dA,
@@ -731,7 +730,7 @@ int main(int argc, char** argv)
transA,
transB);
}
-        else if(datatype == HIPBLASLT_R_8I)
+        else if(datatype == HIP_R_8I)
{
validation<int8_t>(dC,
dA,
@@ -751,7 +750,7 @@ int main(int argc, char** argv)
transA,
transB);
}
-        else if(datatype == HIPBLASLT_R_32I)
+        else if(datatype == HIP_R_32I)
{
validation<int32_t>(dC,
dA,
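
makeMatrixTransformIOPtr and str2Datatype above now dispatch on hipDataType end to end. The same factory pattern condensed into a switch, where IOBase and Typed<> are hypothetical stand-ins for MatrixTransformIO and TypedMatrixTransformIO:

#include <cstdint>
#include <memory>
#include <hip/library_types.h> // hipDataType

struct IOBase { virtual ~IOBase() = default; };
template <typename T>
struct Typed : IOBase { /* would own the input/output buffers as elements of T */ };

std::unique_ptr<IOBase> makeIO(hipDataType dt)
{
    switch(dt)
    {
    case HIP_R_32F: return std::make_unique<Typed<float>>();
    case HIP_R_8I:  return std::make_unique<Typed<int8_t>>();
    default:        return nullptr; // fp16/bf16 elided; unknown types rejected
    }
}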
4 changes: 2 additions & 2 deletions clients/benchmarks/client_extop_softmax.cpp
@@ -76,7 +76,7 @@ int main(int argc, char** argv)
hipStream_t stream{};
hipErr = hipStreamCreate(&stream);
//warmup
-    auto hipblasltErr = hipblasltExtSoftmax(HIPBLASLT_R_32F, m, n, 1, output, input, stream);
+    auto hipblasltErr = hipblasltExtSoftmax(HIP_R_32F, m, n, 1, output, input, stream);

if(hipblasltErr)
{
@@ -95,7 +95,7 @@

for(int i = 0; i < numRuns; ++i)
{
-        hipblasltErr = hipblasltExtSoftmax(HIPBLASLT_R_32F, m, n, 1, output, input, stream);
+        hipblasltErr = hipblasltExtSoftmax(HIP_R_32F, m, n, 1, output, input, stream);
}

hipErr = hipEventRecord(end, stream);
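
The softmax benchmark needs only this one-token change. A usage sketch, assuming the argument order from the diff (datatype, m, n, dim, output, input, stream), where the benchmark passes 1 for dim:

#include <cstdint>
#include <hip/hip_runtime.h>
#include <hipblaslt/hipblaslt-ext-op.h> // assumed location of hipblasltExtSoftmax

void softmax_f32(float* out, float* in, uint32_t m, uint32_t n, hipStream_t stream)
{
    hipblasltExtSoftmax(HIP_R_32F, m, n, 1 /*dim*/, out, in, stream);
}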
