Help train on TPU v3-32 #74
Replies: 12 comments
-
Can you please pull and re-run the script?
-
Still the same.
-
Actually, there's a funny bug in your code that I just noticed: you haven't used the tokenized dataset_train.

import jax.numpy
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

huggingface_repo_id_or_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path)

max_length = 2048
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token

configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="my_first_model_to_train_using_easydel",
    num_train_epochs=3,
    configs_to_init_model_class=configs_to_init_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion" and "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear", "cosine", "none", "warm_up_cosine" and "warm_up_linear" are supported
    weight_decay=0.01,
    total_batch_size=64,
    max_steps=None,  # None lets the trainer decide
    do_train=True,
    do_eval=False,  # optional, but supported
    backend="tpu",  # the default backend is CPU, so you must specify tpu, cpu or gpu
    max_length=max_length,  # note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how to shard the model across CPUs, GPUs or TPUs;
    # with (1, -1, 1, 1) training is fully automatic FSDP and data is shared between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)

def ultra_chat_prompting_process(
        data_chunk
):
    user_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "user"
    ]
    assistant_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "assistant"
    ]
    prompt = ""
    for uc, ac in zip(user_part, assistant_part):
        prompt += f"<|user|>\n{uc}</s>\n<|assistant|>\n{ac}</s>\n"
    return {"prompt": prompt}

tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    padding="max_length"
)

dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
dataset_train = dataset["train_gen"].map(ultra_chat_prompting_process, num_proc=12)
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_train.column_names
)
# you can do the same for the evaluation dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset,  # bug: this is the raw DatasetDict, not the tokenized dataset_train
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey! Here's where your model was saved: {output.last_save_file_name}")
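For clarity, the fix (the only change in the corrected script in the reply below) is to hand the trainer the tokenized train split rather than the raw DatasetDict:

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,  # the tokenized train split, not the raw dataset
    checkpoint_path=None
)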
-
The issue from before is solved, but now this occurs.
Thanks,
-
A funny bug in my code again. Here's the corrected script:

import jax.numpy
from EasyDel import (
    TrainArguments,
    CausalLanguageModelTrainer,
    AutoEasyDelModelForCausalLM,
    EasyDelOptimizers,
    EasyDelSchedulers,
    EasyDelGradientCheckPointers
)
from datasets import load_dataset
import flax
from jax import numpy as jnp
from transformers import AutoTokenizer

huggingface_repo_id_or_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path)

max_length = 2048
tokenizer = AutoTokenizer.from_pretrained(
    huggingface_repo_id_or_path,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token

configs_to_init_model_class = {
    "config": model.config,
    "dtype": jnp.bfloat16,
    "param_dtype": jnp.bfloat16,
    "input_shape": (1, 1)
}

train_arguments = TrainArguments(
    model_class=type(model),
    model_name="my_first_model_to_train_using_easydel",
    num_train_epochs=3,
    configs_to_init_model_class=configs_to_init_model_class,
    learning_rate=5e-5,
    learning_rate_end=1e-6,
    optimizer=EasyDelOptimizers.ADAMW,  # "adamw", "lion" and "adafactor" are supported
    scheduler=EasyDelSchedulers.LINEAR,
    # "linear", "cosine", "none", "warm_up_cosine" and "warm_up_linear" are supported
    weight_decay=0.01,
    total_batch_size=64,
    max_steps=None,  # None lets the trainer decide
    do_train=True,
    do_eval=False,  # optional, but supported
    backend="tpu",  # the default backend is CPU, so you must specify tpu, cpu or gpu
    max_length=max_length,  # note that you have to change this in the model config too
    gradient_checkpointing=EasyDelGradientCheckPointers.NOTHING_SAVEABLE,
    sharding_array=(1, -1, 1, 1),  # how to shard the model across CPUs, GPUs or TPUs;
    # with (1, -1, 1, 1) training is fully automatic FSDP and data is shared between devices
    use_pjit_attention_force=False,
    remove_ckpt_after_load=True,
    gradient_accumulation_steps=8,
    loss_re_mat="",
    dtype=jnp.bfloat16
)

def ultra_chat_prompting_process(
        data_chunk
):
    user_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "user"
    ]
    assistant_part = [
        chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "assistant"
    ]
    prompt = ""
    for uc, ac in zip(user_part, assistant_part):
        prompt += f"<|user|>\n{uc}</s>\n<|assistant|>\n{ac}</s>\n"
    return {"prompt": prompt}

tokenization_process = lambda data_chunk: tokenizer(
    data_chunk["prompt"],
    add_special_tokens=False,
    max_length=max_length,
    padding="max_length"
)

dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
dataset_train = dataset["train_gen"].map(ultra_chat_prompting_process, num_proc=12)
dataset_train = dataset_train.map(
    tokenization_process,
    num_proc=12,
    remove_columns=dataset_train.column_names
)
# you can do the same for the evaluation dataset

trainer = CausalLanguageModelTrainer(
    train_arguments,
    dataset_train,  # fixed: pass the tokenized train split
    checkpoint_path=None
)

output = trainer.train(flax.core.FrozenDict({"params": params}))
print(f"Hey! Here's where your model was saved: {output.last_save_file_name}")
-
Now it's working, like this. Does this look good? Because it feels a bit slow. Thank you,
-
That's not normal, and by default the library will only use one wandb run. And the second tip that I can give you is
-
Okay, I'll try increasing the batch size.
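For reference, a quick illustrative check (a sketch, not from the thread, and assuming total_batch_size is the global batch that gets split across devices) of how total_batch_size divides across a v3-32's devices:

import jax

total_batch_size = 64
n_devices = jax.device_count()  # 32 on a TPU v3-32
print(total_batch_size // n_devices)  # per-device batch of 2; a larger total_batch_size raises this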
-
How are you using EasyDel?
-
Use the GitHub method; that's fixed after the last version.
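Assuming "the GitHub method" means installing straight from the repository, that would look something like the following (the exact repo URL is an assumption, not confirmed in the thread):

pip install git+https://github.com/erfanzar/EasyDeL.git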
-
Ah okay, thank you.
-
#34
I read this issue and tried it, but couldn't make it work :(
Hi, thank you for your amazing work.
I've been trying for a few days to get a TPU v3-32 to work.
I used the TPU VM image "tpu-ubuntu2204-base" and, after installing jax etc. on each TPU worker, tried the following code (train.py).
Then I sent it to the TPUs with
sudo gcloud compute tpus tpu-vm scp train.py node-1: --worker=all --zone=europe-west4-a
and ran it with
sudo gcloud compute tpus tpu-vm ssh node-1 --zone=europe-west4-a --worker=all --command="python3 train.py"
and got an error.
The issue is
Thank you,
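As a minimal multi-host sanity check (an illustrative sketch, not part of the original post), each of the v3-32's four workers should see 8 local devices and 32 global devices:

import jax

# Run on every worker, e.g. via the same gcloud ssh --worker=all command:
print(jax.process_index(), jax.local_device_count(), jax.device_count())
# expected on a healthy v3-32: process indices 0-3, 8 local devices, 32 global devices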