# trainingv2.py
import os
import json
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, LlamaConfig
from safetensors.torch import load_file
from llama_model import LlamaModel
from tqdm import tqdm
from torch.utils.data import DataLoader
import argparse
import numpy as np
import math
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
def download_dataset(dataset_path):
    """Loads a dataset from the Hugging Face Hub, accepting either a dataset name or a full URL."""
    if dataset_path.startswith("https://huggingface.co/datasets/"):
        # Strip the URL prefix but keep the full repo id (e.g. "owner/name")
        dataset_path = dataset_path.split("https://huggingface.co/datasets/")[-1]
    dataset = load_dataset(dataset_path)
    return dataset
def preprocess_dataset(file_path, file_format):
    """Preprocesses the dataset based on its format."""
    data = []
    if file_format == "txt":
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                data.append(line.strip())
    elif file_format == "json":
        with open(file_path, "r", encoding="utf-8") as file:
            json_data = json.load(file)
            for item in json_data:
                data.append(item["text"])
    elif file_format == "jsonl":
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                data.append(json.loads(line)["text"])
    return data
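# Expected input formats (inferred from the loaders above; the example values are illustrative only):
#   .txt   - one training example per line of plain text
#   .json  - a JSON array of objects, each with a "text" field, e.g. [{"text": "example"}]
#   .jsonl - one JSON object per line, each with a "text" field, e.g. {"text": "example"}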
def loss(model, inputs, targets, lengths):
    """Computes the next-token cross-entropy loss and the number of tokens in the batch."""
    # Move the batch to the same device as the model
    inputs = inputs.to(device)
    lengths = lengths.to(device)
    targets = targets.to(device)
    # Build a (batch, seq, seq) attention mask that hides padding positions
    attention_mask = torch.arange(inputs.shape[1], device=device)[None, :] < lengths[:, None]
    attention_mask = attention_mask.unsqueeze(1).repeat(1, inputs.shape[1], 1)
    attention_mask = attention_mask.to(inputs.device)
    # Generate sinusoidal cos and sin position embeddings over the per-head dimension
    seq_length = inputs.shape[1]
    position_ids = torch.arange(seq_length, dtype=torch.long, device=inputs.device)
    head_dim = model.config.hidden_size // model.config.num_attention_heads
    cos = torch.zeros(seq_length, head_dim, device=inputs.device)
    sin = torch.zeros(seq_length, head_dim, device=inputs.device)
    div_term = torch.exp(
        torch.arange(0, head_dim, 2, device=inputs.device)
        * (-torch.log(torch.tensor(10000.0)) / head_dim)
    )
    cos[:, 0::2] = torch.cos(position_ids[:, None] * div_term)
    sin[:, 1::2] = torch.sin(position_ids[:, None] * div_term)
    # Run the model and compute the cross-entropy loss over all (token, vocab) pairs
    logits = model(inputs, attention_mask=attention_mask, cos=cos, sin=sin)
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
    loss_value = loss_fn(logits.view(-1, logits.size(-1)), targets.contiguous().view(-1))
    toks = logits.size(0) * logits.size(1)
    return loss_value, torch.tensor(toks, device=device)
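# Note on the cos/sin tensors built in loss(): they follow the standard sinusoidal scheme
# (base 10000) over the per-head dimension, with cosines written to the even channels and
# sines to the odd channels. How they are applied is up to LlamaModel, which receives them
# via its `cos` and `sin` arguments (presumably for rotary-style position embeddings).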
def iterate_batches(dset, tokenizer, batch_size, train=False, max_length=4096):
    # Shuffle indices (reshuffled on every pass over the data when training)
    while True:
        indices = np.arange(len(dset))
        if train:
            indices = np.random.permutation(indices)
        # Collect batches from the dataset
        for i in range(0, len(indices) - batch_size + 1, batch_size):
            # Encode the batch
            batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
            lengths = [len(x) for x in batch]
            # Warn if any sequence is longer than max_length tokens
            if max(lengths) > max_length:
                print(
                    f"[WARNING] Some sequences are longer than {max_length} tokens. "
                    "Consider pre-splitting your data to save memory."
                )
            # Pad every sequence to the length of the longest one in the batch
            batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
            for j in range(batch_size):
                batch_arr[j, : lengths[j]] = batch[j]
            batch = torch.tensor(batch_arr).to(device)
            lengths = torch.tensor(lengths).to(device)
            yield batch[:, :-1].to(device), batch[:, 1:].to(device), lengths.to(device)
        if not train:
            break
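# iterate_batches yields (inputs, targets, lengths): inputs are each padded sequence without
# its final token and targets are the same sequence shifted left by one, i.e. standard
# next-token prediction. With train=True the generator loops over the data indefinitely,
# so the caller is responsible for bounding the number of steps (as train() does via iters).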
def evaluate(model, dataset, loss_fn, tokenizer, batch_size, num_batches, max_length, device):
    model.eval()
    all_losses = []
    ntokens = 0
    with torch.no_grad():
        for it, batch in zip(
            range(num_batches),
            iterate_batches(dataset, tokenizer, batch_size, max_length=max_length),
        ):
            batch = tuple(t.to(device) for t in batch)
            inputs, targets, lengths = batch
            losses, toks = loss_fn(model, inputs, targets, lengths)
            all_losses.append((losses * toks).item())
            ntokens += toks.item()
    return np.sum(all_losses) / ntokens
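# evaluate() returns the token-weighted average cross-entropy over the sampled batches;
# math.exp() of this value gives the corresponding perplexity if you want to report one.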
def train(model, tokenizer, dataset, batch_size, num_epochs, learning_rate, iters, val_batches, steps_per_report, steps_per_eval, max_length, grad_accum_steps):
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    print("Learning rate:", optimizer.param_groups[0]['lr'])
    trainable_params = sum(v.numel() for _, v in model.named_parameters() if v.requires_grad) / 10**6
    total_params = sum(v.numel() for _, v in model.named_parameters()) / 10**6
    print(f"Total parameters: {total_params:.3f}M")
    print(f"Trainable parameters: {trainable_params:.3f}M")
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)
    torch.autograd.set_detect_anomaly(True)
    # Learning rate schedule: drop to 10% of the initial rate halfway through training
    def lr_scheduler(step):
        if step < iters // 2:
            return learning_rate
        else:
            return learning_rate * 0.1
    # Weight decay schedule: 0.1 for the first half of training, then 0
    def wd_scheduler(step):
        if step < iters // 2:
            return 0.1
        else:
            return 0.0
    step = 0
    while step < iters:
        for batch_idx, batch in enumerate(tqdm(iterate_batches(dataset, tokenizer, batch_size, train=True, max_length=max_length), desc=f"Iteration {step // steps_per_eval + 1}")):
            batch = tuple(t.to(device) for t in batch)
            loss_value, ntoks = loss(model, *batch)
            loss_value = loss_value / grad_accum_steps
            loss_value.backward(retain_graph=True)
            # Only step the optimizer every grad_accum_steps batches
            if (batch_idx + 1) % grad_accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                optimizer.zero_grad()
            # Update the learning rate and weight decay
            lr = lr_scheduler(step)
            wd = wd_scheduler(step)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
                param_group['weight_decay'] = wd
            if step % steps_per_report == 0:
                print(f"Step {step}: Loss = {loss_value.item():.4f}")
            if step % steps_per_eval == 0:
                val_loss = evaluate(model, dataset, loss, tokenizer, batch_size, val_batches, max_length, device)
                print(f"Validation Loss at Step {step}: {val_loss:.4f}")
                model.train()
            step += 1
            if step >= iters:
                break
    return model
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tuning script.")
    parser.add_argument("--dataset", type=str, help="Path to the dataset file.")
    parser.add_argument("--model_path", type=str, help="Path to the pre-trained model.")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for training.")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of training epochs.")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Learning rate for the optimizer.")
    parser.add_argument("--output_dir", type=str, default="fine_tuned_model", help="Output directory to save the fine-tuned model.")
    parser.add_argument("--iters", type=int, default=1000, help="Steps to train for.")
    parser.add_argument("--val_batches", type=int, default=25, help="Number of validation batches; -1 uses the entire validation set.")
    parser.add_argument("--steps_per_report", type=int, default=10, help="Number of training steps between loss reports.")
    parser.add_argument("--steps_per_eval", type=int, default=20, help="Number of training steps between validations.")
    parser.add_argument("--max_length", type=int, default=8192, help="Maximum sequence length for input tokens.")
    parser.add_argument("--grad_accum_steps", type=int, default=1, help="Number of steps for gradient accumulation.")
    args = parser.parse_args()
    file_format = args.dataset.split(".")[-1]
    dataset = preprocess_dataset(args.dataset, file_format)
    # Use --model_path if provided; otherwise fall back to an interactive prompt
    model_path = args.model_path or input("Please enter the path to the pre-trained model you want to fine-tune: ")
    # Load the model configuration from the pre-trained model directory
    config = LlamaConfig.from_pretrained(model_path)
    print(f"Loaded hidden size: {config.hidden_size}")
    print(f"Loaded number of attention heads: {config.num_attention_heads}")
    # Create a new model instance with the loaded configuration
    model = LlamaModel(config)
    # Move the model to the appropriate device
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = train(model, tokenizer, dataset, args.batch_size, args.num_epochs, args.learning_rate, args.iters, args.val_batches, args.steps_per_report, args.steps_per_eval, args.max_length, args.grad_accum_steps)
    model.save_pretrained(args.output_dir)
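# Example invocation (a sketch; the dataset file and model directory below are hypothetical,
# the flags are the ones defined above):
#   python trainingv2.py --dataset data/train.jsonl --model_path ./llama-base \
#       --batch_size 4 --learning_rate 1e-5 --iters 1000 --output_dir fine_tuned_model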