Fully deterministic encoder backward kernels for train_gpt2.cu
ademeure committed May 21, 2024
1 parent 6c8bc17 commit b5e75dd
Showing 3 changed files with 189 additions and 60 deletions.
2 changes: 1 addition & 1 deletion profile_gpt2.cu
@@ -54,7 +54,7 @@ int main(int argc, char *argv[]) {
// do a training step
gpt2_forward(&model, x, y, B, T);
gpt2_zero_grad(&model);
gpt2_backward(&model);
gpt2_backward(&model, x);
gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, 1.f, 1, &multi_gpu_config);
cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings

2 changes: 1 addition & 1 deletion test_gpt2.cu
@@ -203,7 +203,7 @@ int main(int argc, char *argv[]) {
clock_gettime(CLOCK_MONOTONIC, &start);
gpt2_forward(&model, x, y, B, T);
gpt2_zero_grad(&model);
gpt2_backward(&model);
gpt2_backward(&model, x);
clock_gettime(CLOCK_MONOTONIC, &end);
double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;

245 changes: 187 additions & 58 deletions train_gpt2.cu
@@ -38,6 +38,9 @@ This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample never (200),
#include <stdio.h>
#include <stdarg.h>
#include <string>
#include <vector>
#include <functional>
#include <unordered_map>
// GPU / CUDA related
#include <cuda_runtime.h>
#include <cublas_v2.h>
@@ -532,50 +535,108 @@ __global__ void encoder_forward_kernel3(floatX* out,
store128(out_btc, packed_out);
}

template <typename T>
__device__ void atomicStochasticAdd(T* address, float val0, float val1, unsigned int seed) {
static_assert(sizeof(T) == 2, "Only 16-bit atomicStochasticAdd supported.");
float2 val = make_float2(val0, val1);
unsigned int* address_as_uint = (unsigned int*)address;
unsigned int old = *address_as_uint, assumed;
unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed);
do {
assumed = old;
float2 new_fp32 = make_float2((float)(reinterpret_cast<T*>(&old)[0]) + val.x,
(float)(reinterpret_cast<T*>(&old)[1]) + val.y);
T new_rounded[2];
stochastic_rounding(new_fp32.x, &new_rounded[0], random);
stochastic_rounding(new_fp32.y, &new_rounded[1], random >> 16);
old = atomicCAS(address_as_uint, assumed, *(unsigned int*)&new_rounded);
} while (assumed != old);
}
__device__ void atomicStochasticAdd(float* address, float val0, float val1, unsigned int seed) {
atomicAdd(address, val0);
atomicAdd(address + 1, val1);
}
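// Aside (illustrative sketch, not part of this file): floating-point addition is not
// associative, which is why the atomicAdd ordering in the kernel below made results
// run-to-run non-deterministic even before stochastic rounding is considered.
static void fp_add_order_demo() {
    float a = (1e8f + 1.0f) - 1e8f;   // the +1.0f is rounded away first: a == 0.0f
    float b = (1e8f - 1e8f) + 1.0f;   // same operands, different order:  b == 1.0f
    (void)a; (void)b;                 // a != b in fp32, purely due to evaluation order
}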

__global__ void encoder_backward_kernel(floatX* dwte, floatX* dwpe,
const floatX* dout, const int* inp,
int B, int T, int C, unsigned int seed) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int N = B * T * C;
idx *= 2; // 2 elements per thread
if (idx >= N) { return; }
template <int BLOCK_SIZE=256>
__global__ void wte_backward_kernel(floatX* dwte,
const int4* bucket_info, const int* workload_indices, const floatX* dout, const int* inp,
unsigned int seed, int B, int T, int C) {
// In order to be deterministic, we preprocess the inputs on the cpu into "buckets"
// Each bucket corresponds to (WARP_SIZE * x128::size) channels for a single vocabulary token
// Each thread handles x128::size channels, e.g. 256 per warp for BF16
// Each block handles (BLOCK_SIZE / WARP_SIZE) elements in a single bucket in parallel
// If a bucket has less than 8 elements, some warps will return immediately
// If a bucket has more than 8 elements, we will loop over all of them
// The buckets are sorted on the CPU so the largest buckets start 1st
int bucket = blockIdx.x;
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
int c_per_warp = WARP_SIZE * x128::size;

int bucket_start_idx = bucket_info[bucket].x;
int bucket_size = bucket_info[bucket].y;
int bucket_ix = bucket_info[bucket].z;
int c = bucket_info[bucket].w * c_per_warp + (lane_id * x128::size);

// Each thread handles "x128::size" channels, so at fp8, each warp would handle 512 channels
// If C is not a multiple of this (e.g. 768), some buckets/c_groups cannot use the entire warp
if (c >= C) { return; }
// Exit early if this is a small bucket and this warp doesn't have any items to process
if (warp_id >= bucket_size) { return; }

float accum[x128::size] = {0.0f};
__shared__ float accum_shared[x128::size * BLOCK_SIZE];

for(int item = warp_id; item < bucket_size; item += BLOCK_SIZE/WARP_SIZE) {
int bt = workload_indices[bucket_start_idx + item];
int b = bt / T;
int t = bt % T;

const floatX* dout_btc = dout + b * T * C + t * C + c;
x128 packed_inp1 = load128cs(dout_btc);
for (int k = 0; k < packed_inp1.size; k++) {
accum[k] += (float)packed_inp1[k];
}
}

int bt = idx / C;
int b = bt / T;
int t = bt % T;
int c = idx % C;
if (warp_id != 0) {
// we accumulate into warp 0, so only the other warps need to write to shared memory
for (int k = 0; k < x128::size; k++) {
accum_shared[threadIdx.x + k * BLOCK_SIZE] = accum[k];
}
return; // only warp 0 is needed after writing to shared memory
}

int ix = inp[b * T + t];
// Read dwte for warp 0 even if other warps are not finished yet to maximise latency tolerance
floatX* dwte_ix = dwte + bucket_ix * C + c;
x128 packed_in_out = load128(dwte_ix);

const floatX* dout_btc = dout + b * T * C + t * C + c;
floatX* dwte_ix = dwte + ix * C + c;
floatX* dwpe_tc = dwpe + t * C + c;
// note: threads which have returned are considered synchronised by CUDA so no risk of deadlock
__syncthreads();

float2 dout_data = make_float2(dout_btc[0], dout_btc[1]);
atomicStochasticAdd(dwte_ix, dout_data.x, dout_data.y, seed);
atomicStochasticAdd(dwpe_tc, dout_data.x, dout_data.y, seed ^ 0xFFFFFFFF);
// Accumulate into warp 0's registers by reading the values of the other warps in shared memory
for (int i = threadIdx.x + WARP_SIZE; i < min(BLOCK_SIZE, bucket_size*WARP_SIZE); i += WARP_SIZE) {
for (int k = 0; k < x128::size; k++) {
accum[k] += accum_shared[i + k * BLOCK_SIZE];
}
}

// Add the result to dwte and write back to global memory (read-modify-write)
for (unsigned int k = 0; k < x128::size; k++) {
// We use stochastic rounding to go from FP32 to BF16 but the seed should be deterministic
stochastic_rounding(accum[k] + (float)packed_in_out[k], &packed_in_out[k], seed + k);
}
store128(dwte_ix, packed_in_out);
}
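// Reference semantics (illustrative fp32 sketch, not part of the commit): the bucketed kernel
// above produces the same dwte gradient as this single-threaded loop; the GPU version gets
// determinism by accumulating each bucket in fp32 and writing every dwte element exactly once.
void wte_backward_reference(float* dwte, const float* dout, const int* inp, int B, int T, int C) {
    for (int bt = 0; bt < B * T; bt++) {   // all (b, t) positions in a fixed order
        int ix = inp[bt];                  // vocabulary index used at this position
        for (int c = 0; c < C; c++) {
            dwte[ix * C + c] += dout[bt * C + c];
        }
    }
}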

__global__ void wpe_backward_kernel(floatX* dwpe,
const floatX* dout, const int* inp,
int B, int T, int C, unsigned int seed) {
// Each thread handles x128::size "channel positions", e.g. 256 per warp for BF16
// For gpt2-124M BF16, C=768 and T=1024, so 3 warps per channel and 3072 warps in total
// For each "channel position" we sum the gradients for every batch at that C/T element
// This way each dwpe element is only updated once, and the kernel is fully deterministic!
// The previous kernel was not deterministic, as batches were aggregated with atomicAdd
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
if (idx >= T * C) { return; }

// if C is not a multiple of WARP_SIZE*x128::size, it's OK for some warps to handle multiple t
int t = idx / C;
int c = idx % C;
float accum[x128::size] = {0.0f};

for (int b = 0; b < B; b++) {
x128 packed_dout = load128cs(dout + (b * T * C) + (t * C) + c); // will never be read again
for (int k = 0; k < x128::size; k++) {
accum[k] += (float)packed_dout[k];
}
}

floatX* dwpe_tc = dwpe + (t * C) + c;
x128 packed_dwpe = load128(dwpe_tc);
for (unsigned int k = 0; k < x128::size; k++) {
// We use stochastic rounding to go from FP32 to BF16 but the seed should be deterministic
stochastic_rounding(accum[k] + (float)packed_dwpe[k], &packed_dwpe[k], seed + k);
}
store128(dwpe_tc, packed_dwpe);
}
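// Reference semantics (illustrative fp32 sketch, not part of the commit): for each (t, c) the
// batch dimension is reduced in a fixed order, then dwpe gets a single read-modify-write, so
// there is no atomic contention and no ordering ambiguity.
void wpe_backward_reference(float* dwpe, const float* dout, int B, int T, int C) {
    for (int t = 0; t < T; t++) {
        for (int c = 0; c < C; c++) {
            float accum = 0.0f;
            for (int b = 0; b < B; b++) {
                accum += dout[b * T * C + t * C + c];   // sum over the batch in a fixed order
            }
            dwpe[t * C + c] += accum;                   // one deterministic update per element
        }
    }
}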

__global__ void layernorm_forward_kernel3(floatX* __restrict__ out, floatX* __restrict__ mean, floatX* __restrict__ rstd,
@@ -783,10 +844,9 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons
// directly autoregressive, so we only compute the lower triangular part
// uses the online softmax algorithm
assert(T % 4 == 0);
const int warp_size = 32;
int lane_id = threadIdx.x % warp_size;
int warp_id = threadIdx.x / warp_size;
int num_warps = blockDim.x / warp_size;
int lane_id = threadIdx.x % WARP_SIZE;
int warp_id = threadIdx.x / WARP_SIZE;
int num_warps = blockDim.x / WARP_SIZE;

// micro-optimization: we iterate backwards so that
// after the softmax backward operation completes, the cache retains the
@@ -809,7 +869,7 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons
float sumval = 0.0f;

const floatX* x_aligned = reinterpret_cast<const floatX*>(__builtin_assume_aligned(x, 16));
for (int i = lane_id; i < pos_by_4; i += warp_size) {
for (int i = lane_id; i < pos_by_4; i += WARP_SIZE) {
float regarray[4];
for (int k = 0; k < 4; ++k) {
regarray[k] = (float)x_aligned[4*i + k];
@@ -838,7 +898,7 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons
float norm = 1.f / sum;

// divide the whole row by the sum
for (int i = lane_id; i <= own_pos; i += warp_size) {
for (int i = lane_id; i <= own_pos; i += WARP_SIZE) {
// recalculation is faster than doing the round-trip through memory.
float ev = expf(inv_temperature * ((float)__ldcs(x + i) - global_maxval));
__stcs(out + idx * T + i, (floatX)(ev * norm));
@@ -1354,14 +1414,70 @@ void encoder_forward(floatX* out,
cudaCheck(cudaGetLastError());
}

void encoder_backward(floatX* dwte, floatX* dwpe,
const floatX* dout, const int* inp,
int B, int T, int C, unsigned int seed) {
// Fully deterministic (see comments in wte_backward_kernel and wpe_backward_kernel for more details)
void encoder_backward(floatX* dwte, floatX* dwpe, floatX* scratch, // gpu outputs & scratch
int* workload_indices, int4* bucket_info, // cpu scratch buffers
const floatX* dout, const int* inp, const int* inputs_cpu, // cpu/gpu inputs
int B, int T, int C, unsigned int seed) {
NVTX_RANGE_FN();
const int N = B * T * C;

// Launch wpe kernel first (so it runs on the GPU in parallel with the CPU pre-processing for wte)
const int block_size = 256;
const int grid_size = CEIL_DIV(N, block_size * 2); // each thread handles 2 elements
encoder_backward_kernel<<<grid_size, block_size>>>(dwte, dwpe, dout, inp, B, T, C, seed);
const int N = T * C / x128::size;
const int grid_size = CEIL_DIV(N, block_size);
wpe_backward_kernel<<<grid_size, block_size, 0>>>(dwpe, dout, inp, B, T, C, seed);
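// (Illustrative numbers, assuming GPT-2 124M in BF16: T*C/x128::size = 1024*768/8 = 98304
//  threads, so grid_size = CEIL_DIV(98304, 256) = 384 blocks = 3072 warps, matching the
//  "3072 warps in total" comment in wpe_backward_kernel.)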

// check the GPU scratch buffer is large enough to hold the bucket info and workload indices
// todo - this is trivially true given hardcoded scratch buffer size here, is this useful?
int num_c_groups = CEIL_DIV(C, x128::size * WARP_SIZE);
assert(B*T*num_c_groups * (sizeof(int4)+sizeof(int)) <= B*T*3*C * sizeof(floatX));

// Step 1: Sort inputs into buckets
int total_items = 0;
std::unordered_map<uint64_t, std::vector<uint64_t>> buckets;
for (uint64_t bt = 0; bt < B * T; bt++) {
for (uint64_t c_group = 0; c_group < num_c_groups; c_group++) {
// todo - passing c_group/inputs_cpu[bt] in data to avoid a second hash lookup is a bit hacky
uint64_t data = bt + (c_group<<32ULL) + ((uint64_t)inputs_cpu[bt]<<42ULL);
buckets[c_group + num_c_groups * inputs_cpu[bt]].push_back(data);
total_items++;
}
}

// Step 2: Sort buckets by size in descending order
// this is so the largest buckets are processed first by the GPU
// otherwise, if they started late, they would still be running with the rest of the GPU idle
std::vector<std::pair<uint64_t, std::vector<uint64_t>>> sortedBuckets(buckets.begin(), buckets.end());
std::sort(sortedBuckets.begin(), sortedBuckets.end(), // ugly because we don't have a typedef for the std::pair
[](const std::pair<uint64_t, std::vector<uint64_t>>& a, const std::pair<uint64_t, std::vector<uint64_t>>& b) {
return a.second.size() > b.second.size();
});

int num_buckets = buckets.size();
int bucket_index = 0;
int workload_index = 0;
for (const auto& bucket : sortedBuckets) {
bucket_info[bucket_index].x = workload_index; // bucket start
bucket_info[bucket_index].y = bucket.second.size(); // bucket size
bucket_info[bucket_index].z = (bucket.second[0] >> 42ULL) & ((1ULL<<20ULL)-1); // bucket ix
bucket_info[bucket_index].w = (bucket.second[0] >> 32ULL) & ((1ULL<<10ULL)-1); // bucket c

for (uint64_t idx : bucket.second) {
workload_indices[workload_index++] = (int)(idx & ((1ULL<<31ULL)-1ULL));
}
bucket_index++;
}

// Step 3: Copy data from host to device (async until the last one to avoid synchronising CPU/GPU twice)
// todo - could use CUDA events (even without streams) to avoid CPU/GPU synchronisation completely
int4* d_bucket_info = (int4*)scratch;
int* d_workload_indices = (int*)(scratch + B*T*num_c_groups * sizeof(int4));
cudaMemcpyAsync(d_bucket_info, bucket_info, num_buckets * sizeof(int4), cudaMemcpyHostToDevice);
cudaMemcpy(d_workload_indices, workload_indices, total_items * sizeof(int), cudaMemcpyHostToDevice);

// Launch wte kernel
// todo - profile block sizes on more content (depends on number of buckets and on GPU?)
wte_backward_kernel<256><<<num_buckets, 256>>>(dwte, d_bucket_info, d_workload_indices, dout, inp, seed, B, T, C);
cudaCheck(cudaGetLastError());
}
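// Round-trip check of the 64-bit workload encoding used above (hypothetical helper, not part
// of the commit): bits 0-30 hold the flattened (b, t) index, bits 32-41 the channel group,
// and bits 42-61 the token index, matching the masks applied when filling bucket_info.
#include <cassert>
#include <cstdint>
void check_workload_encoding(uint64_t bt, uint64_t c_group, uint64_t token) {
    // assumes bt < 2^31, c_group < 2^10, token < 2^20
    uint64_t data = bt + (c_group << 32ULL) + (token << 42ULL);
    assert((data & ((1ULL << 31ULL) - 1ULL)) == bt);                   // -> workload_indices value
    assert(((data >> 32ULL) & ((1ULL << 10ULL) - 1ULL)) == c_group);   // -> bucket_info[...].w
    assert(((data >> 42ULL) & ((1ULL << 20ULL) - 1ULL)) == token);     // -> bucket_info[...].z
}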

@@ -1947,6 +2063,9 @@ typedef struct {
unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc.
int use_master_weights;
int recompute;
// todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch?
int* workload_indices; // encoder_backward, B*T*num_c_groups (int)
int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case
} GPT2;

void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
@@ -2022,6 +2141,8 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
model->inputs = NULL;
model->targets = NULL;
model->cpu_losses = NULL;
model->workload_indices = NULL;
model->bucket_info = NULL;
model->batch_size = 0;
model->seq_len = 0;
model->mean_loss = -1.0f; // -1.0f will designate no loss
@@ -2195,7 +2316,7 @@ void gpt2_zero_grad(GPT2 *model) {
}
}

void gpt2_backward(GPT2 *model) {
void gpt2_backward(GPT2 *model, int* inputs) {
NVTX_RANGE_FN();
// double check we forwarded previously, with targets
if (model->mean_loss == -1.0f) {
@@ -2221,6 +2342,11 @@ void gpt2_backward(GPT2 *model) {
model->grads_acts_memory = malloc_and_point_backward(&model->grads_acts, bw_act_sizes);
// init gradients of parameters and activations to zero
gpt2_zero_grad(model);
// initialise cpu scratch buffers for encoder backward
size_t num_c_groups = model->config.channels / (WARP_SIZE * x128::size);
assert((size_t)(model->batch_size * model->seq_len) * num_c_groups < (1ULL<<31ULL)); // todo - maybe an issue for llama3-400B(?)
model->workload_indices = (int*)mallocCheck(sizeof(int) * model->batch_size * model->seq_len * num_c_groups);
model->bucket_info = (int4*)mallocCheck(sizeof(int4) * model->batch_size * model->seq_len * num_c_groups);
}
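// (Illustrative sizing, assuming GPT-2 124M in BF16: C=768, x128::size=8, WARP_SIZE=32,
//  so num_c_groups = 768/(32*8) = 3; with B=4, T=1024 this allocates 4*1024*3 ints = 48 KiB
//  for workload_indices and 4*1024*3 int4s = 192 KiB for bucket_info in the worst case.)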

// convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
@@ -2241,7 +2367,8 @@
cudaCheck(cudaMemset(model->grads_acts.residual3, 0, B * T * C * sizeof(floatX)));

// re-use the output buffer of the forward pass as a scratchpad during backward pass
float* scratchF = (float*)acts.output;
float* scratchF = (float*)acts.output;
floatX* scratchX = (floatX*)acts.output;

// we kick off the chain rule by filling in dlosses with 1.0f/(B*T)
// this was done in the fused classifier kernel as last step of forward pass
@@ -2323,7 +2450,6 @@ void gpt2_backward(GPT2 *model) {
floatX* buffer_a = l_atty;
floatX* buffer_b = l_fch; // this is B x T x 4C, so even larger than what we need
floatX* dl_preatt = (floatX*)grads_acts.preatt; // dedicated scratchpad allocation
floatX* scratchX = (floatX*)acts.output;
attention_backward(dl_bt4c, buffer_b, dl_preatt, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH);
#endif

@@ -2332,7 +2458,8 @@
// layernorm backward does = to dresidual, so it correctly accumulates gradient for the Attention block above
layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C);
}
encoder_backward(grads.wte, grads.wpe, dresidual, model->inputs, B, T, C, random_u32(&model->rng_state));
encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info,
dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state));
}

// Compute a mean of a single CPU value across all GPU processes. No-op when multi-GPU is disabled.
@@ -2448,6 +2575,8 @@ void gpt2_free(GPT2 *model) {
cudaCheck(cudaFree(model->inputs));
cudaCheck(cudaFree(model->targets));
cudaFreeHost(model->cpu_losses);
free(model->workload_indices);
free(model->bucket_info);
}

// ----------------------------------------------------------------------------
@@ -2477,7 +2606,7 @@ void common_free(GPT2 &model) {
cudaCheck(cudaFree(cublaslt_workspace));
cublasCheck(cublasDestroy(cublas_handle));
cublasCheck(cublasLtDestroy(cublaslt_handle));
create_cudnn();
destroy_cudnn();
}

#ifndef TESTING
@@ -2880,7 +3009,7 @@ int main(int argc, char *argv[]) {
gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T, grad_accum_steps);
lossf = model.mean_loss; // the mean_loss was normalized by grad_accum_steps inside gpt2_forward
// backward pass. all model params accumulate gradients with += inside this inner loop
gpt2_backward(&model);
gpt2_backward(&model, train_loader.inputs);
}
// override the mean loss, accounting for the gradient accumulation loop
// this is esp important to do here in multigpu update below, where model.mean_loss gets allreduced
