From 29a6b1cb465f6fd6f0391852912b51d6479176f9 Mon Sep 17 00:00:00 2001
From: Erik Schultheis
Date: Tue, 13 Aug 2024 13:48:23 +0300
Subject: [PATCH] initial curand implementation for model init

---
 Makefile            |  2 +-
 llmc/cuda_utils.cuh | 14 ++++++++++++++
 train_gpt2.cu       | 31 +++++++++++++++++--------------
 3 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/Makefile b/Makefile
index 6fa511db4..8528ca499 100644
--- a/Makefile
+++ b/Makefile
@@ -17,7 +17,7 @@ FORCE_NVCC_O ?= 3
 # NVCC flags
 # -t=0 is short for --threads, 0 = number of CPUs on the machine
 NVCC_FLAGS = --threads=0 -t=0 --use_fast_math -std=c++17 -O$(FORCE_NVCC_O)
-NVCC_LDFLAGS = -lcublas -lcublasLt
+NVCC_LDFLAGS = -lcublas -lcublasLt -lcurand
 NVCC_INCLUDES =
 NVCC_LDLIBS =
 NCLL_INCUDES =
diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh
index 0ce728ee1..855b4bf12 100644
--- a/llmc/cuda_utils.cuh
+++ b/llmc/cuda_utils.cuh
@@ -131,6 +131,11 @@ __device__ float cast_value<float, __nv_bfloat16>(__nv_bfloat16 val) {
     return __bfloat162float(val);
 }
 
+template<>
+__device__ __nv_bfloat16 cast_value<__nv_bfloat16, float>(float val) {
+    return __float2bfloat16(val);
+}
+
 template<typename Td, typename Ts>
 __global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n, ptrdiff_t stride_dst, ptrdiff_t stride_src) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -140,6 +145,15 @@ __global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n, ptrdiff_t
     }
 }
 
+template<typename T>
+__global__ void fill_kernel(T* dst, T value, size_t n) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    // need to try grid stride looping for more perf later
+    if (idx < n) {
+        dst[idx] = value;
+    }
+}
+
 // ----------------------------------------------------------------------------
 // Warp/Block communication primitives
 
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 16f801387..b9115008d 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -70,6 +70,8 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 // defines: multi_gpu_get_shard_offset, multi_gpu_async_reduce_gradient
 #include "llmc/zero.cuh"
+#include "curand.h"
+
 // ----------------------------------------------------------------------------
 // global vars for I/O
 char filename_buffer[512];
 
@@ -577,22 +579,25 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) {
     // NOTE: assuming all parameters are of the type floatX, could be relaxed later
     mt19937_state init_rng;
     manual_seed(&init_rng, 42);
-    floatX* params_memory_cpu = (floatX*)mallocCheck(model->num_parameters_bytes);
-    memset(params_memory_cpu, 0, model->num_parameters_bytes);
     // fill in all the weights with random values
     float residual_scale = 1.0f / sqrtf(2.0f * model->config.num_layers);
     // we have to init all these tensors exactly in the order that PyTorch initializes them
     // so that we can match them up and get correctness and exactly the same initial conditions
     size_t L = model->config.num_layers;
     size_t offset = 0;
+
+    curandGenerator_t rng;
+    curandCreateGenerator(&rng, curandRngType::CURAND_RNG_PSEUDO_MT19937);
+    curandSetPseudoRandomGeneratorSeed(rng, 42);
+    curandSetGeneratorOrdering(rng, CURAND_ORDERING_PSEUDO_LEGACY);  // less performant; numbers are the same on all GPUs
+    curandSetStream(rng, main_stream);
+
     for (int l = 0; l < L; l++) {
         offset = 0;
         for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
             // the layernorm parameters are all initialized to 1
             if (l == 0 && (i == 2 || i == 8 || i == 14)) { // only at l = 0 to init these just once
-                for (size_t j = 0; j < model->param_elements[i]; j++) {
-                    params_memory_cpu[offset + j] = 1.0f;
-                }
+                fill_kernel<<<dim3(CEIL_DIV(model->param_elements[i], 512)), 512, 0, main_stream>>>((floatX*)model->params_memory + offset, (floatX)1.0, model->param_elements[i]);
             }
             // weights tensors are handled here
             if ((l == 0 && (i == 0 || i == 1)) // only at l = 0, init the wte and wpe tensors
@@ -613,20 +618,18 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) {
                 // scaled by 1/sqrt(2*L) for training stability
                 float scale = (i == 6 || i == 12) ? 0.02f * residual_scale : 0.02f;
                 // okay let's draw the random numbers and write them
-                float *fp32_buffer = (float*)mallocCheck(n * sizeof(float));
-                normal_(fp32_buffer, n, 0.0f, scale, &init_rng);
-                for (size_t j = 0; j < n; j++) {
-                    params_memory_cpu[offset + layer_offset + j] = (floatX)fp32_buffer[j];
-                }
-                free(fp32_buffer);
+                float *fp32_buffer;
+                cudaCheck(cudaMallocAsync(&fp32_buffer, n*sizeof(float), main_stream));
+                curandGenerateNormal(rng, fp32_buffer, n, 0.f, scale);
+                copy_and_cast_kernel<<<dim3(CEIL_DIV(n, 512)), 512, 0, main_stream>>>((floatX*)model->params_memory + offset + layer_offset,
+                                                                                      fp32_buffer, n, 0, 0);
+                cudaCheck(cudaFreeAsync(fp32_buffer, main_stream));
             }
             offset += model->param_elements[i];
         }
     }
 
-    // copy them to GPU
-    cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice));
-    free(params_memory_cpu);
+    cudaCheck(cudaDeviceSynchronize());
 }
 
 // propagate inputs through the network to produce logits.