-
Notifications
You must be signed in to change notification settings - Fork 191
/
train_text_generation.py
91 lines (81 loc) · 2.46 KB
/
train_text_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
from argparse import ArgumentParser
import yaml
from rl4lms.envs.text_generation.logging_utils import Tracker
from rl4lms.envs.text_generation.training_utils import (
OnPolicyTrainer,
SupervisedTrainer,
)
def main(
    config_path: str,
    project_name: str,
    experiment_name: str,
    base_path_to_store_results: str,
    entity_name: str,
    log_to_wandb: bool,
) -> None:
    """Build a trainer from a YAML config file and run ``train_and_eval``.

    The config must provide the ``tokenizer``, ``datapool``, ``alg`` and
    ``train_evaluation`` sections. When ``config["alg"]["id"]`` does not
    contain the substring ``"supervised"``, the ``reward_fn`` and ``env``
    sections are read as well.

    Args:
        config_path: Path to the YAML config file.
        project_name: Project name forwarded to the tracker.
        experiment_name: Experiment name forwarded to the tracker.
        base_path_to_store_results: Base path forwarded to the tracker.
        entity_name: Entity name forwarded to the tracker (the CLI passes
            ``None`` when the flag is omitted).
        log_to_wandb: Logging flag forwarded to the tracker.

    Raises:
        OSError: If the config file cannot be opened.
        KeyError: If a config section needed by the selected trainer is
            missing.
    """
    # Load the config file. The encoding is pinned because YAML streams are
    # UTF-8 by default, whereas open() would otherwise use the locale encoding.
    with open(config_path, "r", encoding="utf-8") as fp:
        config = yaml.safe_load(fp)

    # Create the tracker shared by whichever trainer is selected below.
    tracker = Tracker(
        base_path_to_store_results,
        config,
        project_name,
        experiment_name,
        entity_name,
        log_to_wandb,
    )

    # Instantiate the trainer: algorithm ids containing "supervised" use the
    # supervised trainer, everything else is treated as on-policy.
    if "supervised" in config["alg"]["id"]:
        trainer = SupervisedTrainer(
            tokenizer_config=config["tokenizer"],
            datapool_config=config["datapool"],
            alg_config=config["alg"],
            train_eval_config=config["train_evaluation"],
            tracker=tracker,
        )
    else:
        trainer = OnPolicyTrainer(
            tokenizer_config=config["tokenizer"],
            datapool_config=config["datapool"],
            reward_config=config["reward_fn"],
            env_config=config["env"],
            on_policy_alg_config=config["alg"],
            train_eval_config=config["train_evaluation"],
            tracker=tracker,
        )
    trainer.train_and_eval()
def build_arg_parser():
    """Return the command-line parser for this training script.

    ``--config_path`` is required; every other option has a default.
    """
    parser = ArgumentParser(description="Fine-tune LM to generate controlled text")
    # Required: without it main() would receive None and fail inside open()
    # with an unhelpful TypeError instead of a usage message.
    parser.add_argument(
        "--config_path", type=str, help="path to the config file", required=True
    )
    parser.add_argument(
        "--project_name", type=str, help="WANDB project name", default="rl4lm_exps"
    )
    parser.add_argument(
        "--experiment_name",
        type=str,
        help="WANDB experiment name",
        default="rl4lm_experiment",
    )
    parser.add_argument(
        "--entity_name", type=str, help="WANDB entity name", default=None
    )
    parser.add_argument(
        "--base_path_to_store_results",
        type=str,
        help="Base path to store experiment results",
        default=os.getcwd(),
    )
    parser.add_argument(
        "--log_to_wandb", action="store_true", help="Whether to use wandb logging"
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(
        args.config_path,
        args.project_name,
        args.experiment_name,
        args.base_path_to_store_results,
        args.entity_name,
        args.log_to_wandb,
    )