
Commit

Add micro step guard
gordicaleksa committed Jul 1, 2024
1 parent 78e7627 commit 74e5ba7
Showing 3 changed files with 10 additions and 8 deletions.
profile_gpt2.cu (2 changes: 1 addition & 1 deletion)
@@ -61,7 +61,7 @@ int main(int argc, char *argv[]) {
// do a training step
gpt2_forward(&model, x, B, T);
gpt2_zero_grad(&model);
- gpt2_backward_and_reduce(&model, x, y, 1, true);
+ gpt2_backward_and_reduce(&model, x, y, 1, 0);
float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config);
float grad_scale = (grad_norm > 1.0f) ? 1.0f / grad_norm : 1.0f;
gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, grad_scale, 1, &multi_gpu_config);
test_gpt2.cu (6 changes: 3 additions & 3 deletions)
@@ -218,7 +218,7 @@ int main(int argc, char *argv[]) {
clock_gettime(CLOCK_MONOTONIC, &start);
gpt2_forward(&model, x, B, T);
gpt2_zero_grad(&model);
- gpt2_backward_and_reduce(&model, x, y, 1, true);
+ gpt2_backward_and_reduce(&model, x, y, 1, 0);
clock_gettime(CLOCK_MONOTONIC, &end);
double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;

@@ -336,7 +336,7 @@ int main(int argc, char *argv[]) {
dataloader_next_batch(&loader);
gpt2_forward(&model, loader.inputs, B, T);
gpt2_zero_grad(&model);
- gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, true);
+ gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0);
gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config);
losses[step] = model.mean_loss;
tokens[step] = loader.inputs[0];
@@ -351,7 +351,7 @@ int main(int argc, char *argv[]) {
dataloader_next_batch(&loader);
gpt2_forward(&model, loader.inputs, B, T);
gpt2_zero_grad(&model);
- gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, true);
+ gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, 0);
gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config);

if(loader.inputs[0] != tokens[step]) {
train_gpt2.cu (10 changes: 6 additions & 4 deletions)
@@ -751,11 +751,13 @@ void gpt2_zero_grad(GPT2 *model) {
cudaCheck(cudaDeviceSynchronize());
}

- void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, bool last_step) {
+ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) {
NVTX_RANGE_FN();
+ bool last_step = micro_step == grad_accum_steps - 1;

// init gradients of parameters and activations to zero
- gpt2_zero_grad(model);
+ if (micro_step == 0) {
+     gpt2_zero_grad(model); // set correct grad state on first micro-step
+ }

// lazily allocate the memory for gradients of the weights and activations, if needed
if (model->grads_memory == NULL) {
@@ -1795,7 +1797,7 @@ int main(int argc, char *argv[]) {
// forward pass. note that we pass in grad_accum_steps, which scales down the loss
gpt2_forward(&model, train_loader.inputs, B, T);
// backward pass. all model params accumulate gradients with += inside this inner loop
- gpt2_backward_and_reduce(&model, train_loader.inputs, train_loader.targets, grad_accum_steps, micro_step == grad_accum_steps - 1);
+ gpt2_backward_and_reduce(&model, train_loader.inputs, train_loader.targets, grad_accum_steps, micro_step);
}
float zloss = (float)(update_detector(&loss_outlier_detector, (double)model.mean_loss)); // loss z-score
// fetch the next learning rate
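Below is a minimal, self-contained sketch of the micro-step guard pattern this commit introduces. It is not llm.c code: the Model struct and the zero_grad / backward_and_reduce helpers are hypothetical stand-ins for GPT2, gpt2_zero_grad, and gpt2_backward_and_reduce. The point it illustrates is that the caller only passes the micro-step index; the backward call zeroes gradients on the first micro-step, accumulates on every micro-step, and detects the last micro-step itself (where a multi-GPU build would reduce gradients across ranks).

#include <stdbool.h>
#include <stdio.h>

typedef struct { float grad; } Model;              // stand-in for GPT2

void zero_grad(Model *m) { m->grad = 0.0f; }       // stand-in for gpt2_zero_grad

// stand-in for gpt2_backward_and_reduce: guard the zeroing, accumulate, detect last step
void backward_and_reduce(Model *m, float microbatch_grad, int grad_accum_steps, int micro_step) {
    bool last_step = micro_step == grad_accum_steps - 1;
    if (micro_step == 0) {
        zero_grad(m);  // set correct grad state on first micro-step only
    }
    m->grad += microbatch_grad / grad_accum_steps; // accumulate loss-scaled gradient
    if (last_step) {
        // a multi-GPU build would all-reduce m->grad across ranks here
        printf("last micro-step: accumulated grad = %f\n", m->grad);
    }
}

int main(void) {
    Model model = {0.0f};
    int grad_accum_steps = 4;
    for (int micro_step = 0; micro_step < grad_accum_steps; micro_step++) {
        backward_and_reduce(&model, 1.0f, grad_accum_steps, micro_step);
    }
    // the optimizer update (cf. gpt2_update in the diff) would follow here
    return 0;
}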
