[WIP] initial curand implementation for model init #741

Open · wants to merge 1 commit into master
Makefile: 2 changes (1 addition & 1 deletion)
```diff
@@ -17,7 +17,7 @@ FORCE_NVCC_O ?= 3
 # NVCC flags
 # -t=0 is short for --threads, 0 = number of CPUs on the machine
 NVCC_FLAGS = --threads=0 -t=0 --use_fast_math -std=c++17 -O$(FORCE_NVCC_O)
-NVCC_LDFLAGS = -lcublas -lcublasLt
+NVCC_LDFLAGS = -lcublas -lcublasLt -lcurand
 NVCC_INCLUDES =
 NVCC_LDLIBS =
 NCLL_INCUDES =
```
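The only build change is the extra link flag. A throwaway host program can confirm that `-lcurand` resolves before building the full trainer; the file name `check_curand.cu` and the snippet below are illustrative, not part of the PR.

```cuda
// check_curand.cu: hypothetical link sanity check, not part of this PR.
// Build and run: nvcc check_curand.cu -lcurand && ./a.out
#include <curand.h>
#include <stdio.h>

int main(void) {
    curandGenerator_t gen;
    // creating a host-API generator touches the libcurand symbols we now link against
    curandStatus_t status = curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_MT19937);
    printf("curandCreateGenerator -> %d (0 == CURAND_STATUS_SUCCESS)\n", (int)status);
    if (status == CURAND_STATUS_SUCCESS) curandDestroyGenerator(gen);
    return status == CURAND_STATUS_SUCCESS ? 0 : 1;
}
```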
llmc/cuda_utils.cuh: 14 changes (14 additions & 0 deletions)
```diff
@@ -131,6 +131,11 @@ __device__ float cast_value<float, __nv_bfloat16>(__nv_bfloat16 val) {
     return __bfloat162float(val);
 }
 
+template<>
+__device__ __nv_bfloat16 cast_value<__nv_bfloat16, float>(float val) {
+    return __float2bfloat16(val);
+}
+
 template<typename Td, typename Ts>
 __global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n, ptrdiff_t stride_dst, ptrdiff_t stride_src) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -140,6 +145,15 @@ __global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n, ptrdiff_t stride_dst, ptrdiff_t stride_src) {
     }
 }
 
+template<class T>
+__global__ void fill_kernel(T* dst, T value, size_t n) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    // need to try grid stride looping for more perf later
+    if (idx < n) {
+        dst[idx] = value;
+    }
+}
+
 // ----------------------------------------------------------------------------
 // Warp/Block communication primitives
```
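Two notes on this hunk: the new `cast_value<__nv_bfloat16, float>` specialization is what lets `copy_and_cast_kernel` write fp32 random draws into `floatX` parameters in `train_gpt2.cu` below, and `fill_kernel` replaces the CPU loop that set the layernorm weights to 1. The in-code TODO asks about grid-stride looping; a minimal sketch of that variant (the name `fill_kernel_grid_stride` is hypothetical, not in this PR) could be:

```cuda
// Hypothetical grid-stride variant of fill_kernel, not part of this PR:
// each thread strides across dst, so a fixed-size grid can cover any n
// instead of needing CEIL_DIV(n, block_size) blocks per launch.
template<class T>
__global__ void fill_kernel_grid_stride(T* dst, T value, size_t n) {
    size_t stride = (size_t)blockDim.x * gridDim.x;
    for (size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x; idx < n; idx += stride) {
        dst[idx] = value;
    }
}
```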
train_gpt2.cu: 31 changes (17 additions & 14 deletions)
```diff
@@ -70,6 +70,8 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
 // defines: multi_gpu_get_shard_offset, multi_gpu_async_reduce_gradient
 #include "llmc/zero.cuh"
 
+#include "curand.h"
+
 // ----------------------------------------------------------------------------
 // global vars for I/O
 char filename_buffer[512];
```
```diff
@@ -577,22 +579,25 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) {
     // NOTE: assuming all parameters are of the type floatX, could be relaxed later
     mt19937_state init_rng;
     manual_seed(&init_rng, 42);
-    floatX* params_memory_cpu = (floatX*)mallocCheck(model->num_parameters_bytes);
-    memset(params_memory_cpu, 0, model->num_parameters_bytes);
     // fill in all the weights with random values
     float residual_scale = 1.0f / sqrtf(2.0f * model->config.num_layers);
     // we have to init all these tensors exactly in the order that PyTorch initializes them
     // so that we can match them up and get correctness and exactly the same initial conditions
     size_t L = model->config.num_layers;
     size_t offset = 0;
+
+    curandGenerator_t rng;
+    curandCreateGenerator(&rng, curandRngType::CURAND_RNG_PSEUDO_MT19937);
+    curandSetPseudoRandomGeneratorSeed(rng, 42);
+    curandSetGeneratorOrdering(rng, CURAND_ORDERING_PSEUDO_LEGACY); // less performant; numbers are the same on all GPUs
+    curandSetStream(rng, main_stream);
+
     for (int l = 0; l < L; l++) {
         offset = 0;
         for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
             // the layernorm parameters are all initialized to 1
             if (l == 0 && (i == 2 || i == 8 || i == 14)) { // only at l = 0 to init these just once
-                for (size_t j = 0; j < model->param_elements[i]; j++) {
-                    params_memory_cpu[offset + j] = 1.0f;
-                }
+                fill_kernel<<<dim3(CEIL_DIV(model->param_elements[i], 512)), 512, 0, main_stream>>>((floatX*)model->params_memory + offset, (floatX)1.0, model->param_elements[i]);
             }
             // weights tensors are handled here
             if ((l == 0 && (i == 0 || i == 1)) // only at l = 0, init the wte and wpe tensors
```
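None of the four `curand*` calls above check their `curandStatus_t` result, and `rng` is not destroyed anywhere in the lines shown. A checking wrapper in the spirit of llm.c's existing `cudaCheck` macro might look like this (hypothetical sketch, not in this PR):

```cuda
// Hypothetical curandCheck in the style of llm.c's cudaCheck; not part of this PR.
// cuRAND has no public error-string helper, so we print the raw status code.
void curandCheck_(curandStatus_t status, const char *file, int line) {
    if (status != CURAND_STATUS_SUCCESS) {
        fprintf(stderr, "[cuRAND ERROR] status %d at file %s:%d\n", (int)status, file, line);
        exit(EXIT_FAILURE);
    }
}
#define curandCheck(status) (curandCheck_((status), __FILE__, __LINE__))
```

Usage would mirror the existing pattern, e.g. `curandCheck(curandSetPseudoRandomGeneratorSeed(rng, 42));`, with a matching `curandCheck(curandDestroyGenerator(rng));` once initialization is done.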
```diff
@@ -613,20 +618,18 @@ void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) {
                 // scaled by 1/sqrt(2*L) for training stability
                 float scale = (i == 6 || i == 12) ? 0.02f * residual_scale : 0.02f;
                 // okay let's draw the random numbers and write them
-                float *fp32_buffer = (float*)mallocCheck(n * sizeof(float));
-                normal_(fp32_buffer, n, 0.0f, scale, &init_rng);
-                for (size_t j = 0; j < n; j++) {
-                    params_memory_cpu[offset + layer_offset + j] = (floatX)fp32_buffer[j];
-                }
-                free(fp32_buffer);
+                float *fp32_buffer;
+                cudaCheck(cudaMallocAsync(&fp32_buffer, n*sizeof(float), main_stream));
+                curandGenerateNormal(rng, fp32_buffer, n, 0.f, scale);
+                copy_and_cast_kernel<<<dim3(CEIL_DIV(n, 512)), 512, 0, main_stream>>>((floatX*)model->params_memory + offset + layer_offset,
+                                                                                      fp32_buffer, n, 0, 0);
+                cudaCheck(cudaFreeAsync(fp32_buffer, main_stream));
             }
             offset += model->param_elements[i];
         }
     }
 
-    // copy them to GPU
-    cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice));
-    free(params_memory_cpu);
+    cudaCheck(cudaDeviceSynchronize());
 }
 
 // propagate inputs through the network to produce logits.
```
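The final `cudaDeviceSynchronize` keeps initialization blocking, so every stream-ordered fill, draw, and cast has completed before the function returns. One caveat on `curandGenerateNormal`: per the cuRAND documentation, pseudorandom generators produce normals in pairs and return `CURAND_STATUS_LENGTH_NOT_MULTIPLE` when asked for an odd count. If any `param_elements[i]` can make `n` odd, padding the request is one workaround; the sketch below (with the hypothetical `n_padded`) is not part of the PR:

```cuda
// Hypothetical padding fix, not in this PR: round n up to an even count so
// curandGenerateNormal accepts it. The extra element is allocated and drawn
// but never copied out, since copy_and_cast_kernel reads only the first n values.
size_t n_padded = n + (n & 1);
float *fp32_buffer;
cudaCheck(cudaMallocAsync(&fp32_buffer, n_padded * sizeof(float), main_stream));
curandGenerateNormal(rng, fp32_buffer, n_padded, 0.f, scale);
```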