Skip to content

Commit

Permalink
Remove local args for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
gordicaleksa committed Jul 1, 2024
1 parent de8454d commit 78e7627
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1404,16 +1404,16 @@ void error_usage() {
// main training loop
int main(int argc, char *argv[]) {
// read in the (optional) command line arguments
const char* train_data_pattern = "/hdd/llmc/fineweb/bin/fineweb_train_*.bin";
const char* val_data_pattern = "/hdd/llmc/fineweb/bin/fineweb_val_*.bin";
const char* load_filename = "d12"; // bf16 weights of the model
const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin";
const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin";
const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights of the model
const char* lr_scheduler_type = "cosine";
const char* output_log_dir = "/hdd/llmc/log_gpt2_124M/";
const char* output_log_dir = NULL;
int checkpoint_every = 0; // write checkpoints every how many steps?
int checkpoints_keep = 0; // how long checkpoint history do we keep? (in units of checkpoints)
int major_checkpoint_every = 0; // major checkpoints never get deleted when maintaining history
int resume = 0; // resume the optimization, if one is found inside output_log_dir?
int B = 16; // batch size
int B = 4; // batch size
int T = 1024; // sequence length max
int total_batch_size = -1; // will be calculated down below later, if not provided
float learning_rate = 3e-4f;
Expand All @@ -1422,15 +1422,15 @@ int main(int argc, char *argv[]) {
float weight_decay = 0.0f;
float skip_update_lossz = 0.0f; // skip update if loss goes above this in zscore
float skip_update_gradz = 0.0f; // skip update if grad_norm goes above this in zscore
int val_loss_every = 50; // every how many steps do we eval validation loss?
int val_loss_every = 20; // every how many steps do we eval validation loss?
int val_max_steps = 20; // how many batches max do we eval for validation loss?
int sample_every = 25000; // every how many steps to do inference?
int sample_every = 20; // every how many steps to do inference?
int genT = 64; // number of steps of inference we will do
int overfit_single_batch = 0; // useful for debugging, 1 = only load a single data batch once
int max_steps = -1;
int override_enable_tf32 = 1;
int use_master_weights = 1;
int recompute = 0; // recompute during backward setting, 0 = none, 1 = recompute gelu
int recompute = 1; // recompute during backward setting, 0 = none, 1 = recompute gelu
int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training
int hellaswag_eval = 0;
// multi-node settings
Expand Down

0 comments on commit 78e7627

Please sign in to comment.