How to train multi-node? #67
Replies: 36 comments
-
hello
-
GPU servers and a Slurm cluster.
-
Can you please test this code?

import jax
from EasyDel import TrainArguments, CausalLMTrainer

num_processes = 6
process_id = 0  # number between 0 and num_processes - 1 that says which node is the current node
coordinator_address = "ip:port"  # for example 192.168.1.12:8600 (make sure this port is not blocked by a firewall)

jax.distributed.initialize(
    coordinator_address=coordinator_address,
    num_processes=num_processes,
    process_id=process_id,
)

train_args = TrainArguments(
    backend="gpu",
    sharding_array=(num_processes, -1, 1),
    use_wandb=True,
    use_pjit_attention_force=False,
)

trainer = CausalLMTrainer(
    arguments=train_args,
    dataset_train=...,  # to be passed
    ckpt_path=...,      # to be passed: path to a checkpoint, or None
)

# If you want to fine-tune a model, you can pass parameters to the trainer;
# they should look like frozen({"params": ...}).
parameters = None

trainer.train(model_parameters=parameters or None)
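Since the thread mentions Slurm, here is a minimal sketch of deriving those values from standard Slurm environment variables instead of hard-coding them (SLURM_NTASKS and SLURM_PROCID are standard Slurm variables; COORDINATOR_ADDRESS is a hypothetical variable you would export yourself before launching):

import os
import jax

# One JAX process per Slurm task: total task count and this task's rank.
num_processes = int(os.environ["SLURM_NTASKS"])
process_id = int(os.environ["SLURM_PROCID"])

# Hypothetical variable exported before srun, e.g. "192.168.1.12:8600".
coordinator_address = os.environ["COORDINATOR_ADDRESS"]

jax.distributed.initialize(
    coordinator_address=coordinator_address,
    num_processes=num_processes,
    process_id=process_id,
)

# After initialization every process should see all devices in the cluster.
print(jax.process_index(), jax.local_device_count(), jax.device_count())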
-
I encountered some issues when I tried to run EasyDeL/examples/training/causal-lm/llama.py. I use llama-13b. I first convert hf-llama to flax using this code:

model = AutoModelForCausalLM.from_pretrained(path)

Is my conversion code correct? Then I run llama.py; here is the error:
-
When you try to run it, pass fully_fsdp=False in config.get_partition_rules; that will fix this problem.
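A minimal sketch of that workaround, assuming config is the model config object created in llama.py and that the other trainer arguments stay as in the multi-node example above:

# Hedged sketch: build non-fully-FSDP partition rules and hand them to the trainer.
partition_rules = config.get_partition_rules(fully_fsdp=False)

train_args = TrainArguments(
    custom_rule=partition_rules,
    # ... other arguments unchanged ...
)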
-
This seems to disable fully_fsdp, but I want to use fully_fsdp.
-
Which model are you trying to use?
-
llama-13b v2. But the model's vocabulary size has changed. Does this have any impact?
-
Yes, that's the reason you get this error. Can you tell me your vocab size? Is it the version with EOS/BOS added that has 32002 tokens?
-
Use this:

from jax.sharding import PartitionSpec as PS
from EasyDel import TrainArguments

partition_rules = (
    ("transformer/wte/embedding", PS("dp", "fsdp")),
    ("attention/(wq|wk|wv)/kernel", PS("fsdp")),
    ("attention/wo/kernel", PS("fsdp")),
    ("feed_forward/w1/kernel", PS("fsdp")),
    ("feed_forward/w2/kernel", PS("fsdp")),
    ("feed_forward/w3/kernel", PS("fsdp")),
    ("attention_norm/kernel", PS("fsdp")),
    ("ffn_norm/kernel", PS("fsdp")),
    ("transformer/ln_f/kernel", PS("fsdp")),
    ("lm_head/kernel", PS("fsdp", "dp")),
    (".*", PS("fsdp")),
)

train_args = TrainArguments(
    custom_rule=partition_rules,
    ...
)

This should work fine if you have only changed the vocab size.
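For readers new to these rules: each entry pairs a regex over the flattened parameter path with a PartitionSpec, and the first matching pattern is used, with ".*" as the catch-all (this first-match reading follows EasyLM-style rule matching and is an assumption, not something stated in this thread). A tiny illustration using the partition_rules defined above, with a hypothetical match_rule helper:

import re

def match_rule(rules, param_path):
    # Hypothetical helper: return the PartitionSpec of the first rule whose
    # regex matches the parameter path (not EasyDel's actual function).
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no rule matched {param_path}")

print(match_rule(partition_rules, "params/lm_head/kernel"))   # PartitionSpec('fsdp', 'dp')
print(match_rule(partition_rules, "params/ffn_norm/kernel"))  # PartitionSpec('fsdp')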
-
Yes, it worked!
-
If you have any other issues, please let me know <3
-
The loss calculation has run into an issue. The error is: My vocab size is 55296 and the sequence_length is 1024.
-
Set loss_remat to "":

train_args = TrainArguments(
    custom_rule=partition_rules,
    loss_remat=""
)

This will work. The current error you are getting is because you are trying to use blockwise cross-entropy and your vocab size (55296) is not divisible by 1024, so you can either change loss_remat to "" or change your loss_chunk.
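For context on blockwise cross-entropy: it trades a little extra compute for much lower peak memory by computing the loss over fixed-size chunks of tokens instead of materializing the full (tokens, vocab) log-softmax at once, which is why chunk settings such as loss_chunk have to line up with the tensor shapes. A minimal, illustrative sketch of the idea, chunked over the token axis (this is not EasyDel's actual implementation):

import jax
import jax.numpy as jnp

def chunked_cross_entropy(logits, labels, chunk_size):
    # Illustrative chunked cross-entropy: process chunk_size tokens at a time.
    # For this simple reshape-based version, the token count must be divisible
    # by chunk_size.
    tokens, vocab = logits.shape
    logits = logits.reshape(tokens // chunk_size, chunk_size, vocab)
    labels = labels.reshape(tokens // chunk_size, chunk_size)

    def one_chunk(total_nll, chunk):
        chunk_logits, chunk_labels = chunk
        logp = jax.nn.log_softmax(chunk_logits, axis=-1)
        nll = -jnp.take_along_axis(logp, chunk_labels[:, None], axis=-1).sum()
        return total_nll + nll, None

    total_nll, _ = jax.lax.scan(one_chunk, jnp.zeros((), logits.dtype), (logits, labels))
    return total_nll / tokens  # mean negative log-likelihood

# Toy example: 8 tokens, vocab of 16, chunks of 4 tokens.
logits = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
labels = jnp.arange(8) % 16
print(chunked_cross_entropy(logits, labels, chunk_size=4))

A plain cross-entropy computes the log-softmax over the whole (tokens, vocab) matrix at once; the chunked version only ever holds (chunk_size, vocab) activations at a time.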
-
What is the difference between blockwise_cross and cross_entropy_loss_and_accuracy, i.e. between blockwise cross-entropy and plain cross-entropy? Are there any advantages to using blockwise_cross?
-
Sorry, I explained a part of it wrong:

from flax.traverse_util import unflatten_dict, flatten_dict

flax_params = llama_convert_hf_to_flax(state_dict, num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, device=device)
flax_params = flatten_dict(flax_params)
pt_params = llama_convert_flax_to_pt(flax_params, n_layers=40, dim=5120, num_attention_heads=40)

Use this instead:

from flax.traverse_util import unflatten_dict, flatten_dict

flax_params = llama_convert_hf_to_flax(state_dict, num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, device=device)
flax_params = flatten_dict(flax_params, sep=".")
pt_params = llama_convert_flax_to_pt(flax_params, n_layers=40, dim=5120, num_attention_heads=40)
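For anyone unsure what the sep="." change does: flatten_dict collapses the nested parameter dict into a flat one, and sep controls whether the keys come out as tuples or as dot-joined strings (presumably what llama_convert_flax_to_pt expects, given the fix above). A minimal illustration with a toy dict:

from flax.traverse_util import flatten_dict

nested = {"transformer": {"wte": {"embedding": 0}}}

print(flatten_dict(nested))            # {('transformer', 'wte', 'embedding'): 0}
print(flatten_dict(nested, sep="."))   # {'transformer.wte.embedding': 0}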
-
Docs are available at https://erfanzar.github.io/EasyDeL/docs/
-
Thanks, I will keep testing when I have time.
-
Import error: Traceback (most recent call last):
-
Fixed. I'm sorry for such an error :)
-
If training uses multi-host, does the dataset need any additional processing?
-
Yes. To use EasyDel you should preprocess your dataset: pass a tokenized dataset that contains input_ids and an attention_mask. As for batch size, the batch size you pass per step is multiplied by the number of gradient accumulation steps. For example, if you pass a batch size of 8 to the trainer with gradient accumulation of 8, the total batch size for the data loader becomes 64; if you have 2 hosts, that becomes a batch size of 32 per host, and if you have 8 GPUs per machine, that becomes 4 per GPU.
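A small worked example of that arithmetic (the numbers mirror the description above; the variable names are illustrative, not EasyDel parameters):

batch_size = 8                   # batch size passed to the trainer, per step
gradient_accumulation_steps = 8
num_hosts = 2
gpus_per_host = 8

total_batch = batch_size * gradient_accumulation_steps  # 64 samples per optimizer step
per_host = total_batch // num_hosts                      # 32 samples handled by each host
per_gpu = per_host // gpus_per_host                      # 4 samples per GPU
print(total_batch, per_host, per_gpu)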
-
I get this warning: warning: Linking two modules of different target triples: "LLVMDialectModule" is "nvptx64-nvidia-gpulibs" whereas "" is "nvptx64-nvidia-cuda"

What is use_pjit_attention_force?
-
Also, use_flash_attention seems not to work; I found that the speed is the same whether it is set to true or false.
-
I set max_sequence_length = 10240 and had this error: Also, when using a large sequence_length, the loss becomes NaN.
-
Yes, you are right; you should change your model's max length.
-
I"m really looking forward to Mojo version. Is Mojo a replacement for Jax, or can they work together? |
-
Actually, Mojo is more native, at least the version that I'm creating, and it works without any imported libraries. Since Mojo is a fast, native, compiled language, I coded everything from scratch; the only Python library in use is os, and only to read the size of checkpoints (Mojo doesn't have any built-in I/O library).
-
That is awesome. You want to make a framework like TensorFlow or PyTorch based on Mojo; this requires a significant amount of work.
-
I don't know how to train multi-host. Can you give an example? Thank you.