Version 0.8.6 #1212

Merged · 98 commits · Apr 7, 2024
Changes from 1 commit

Commits (98):
dfe08f3
support deepspeed
BootsofLagrangian Feb 3, 2024
64873c1
fix offload_optimizer_device typo
BootsofLagrangian Feb 5, 2024
2824312
fix vae type error during training sdxl
BootsofLagrangian Feb 5, 2024
4295f91
fix all trainer about vae
BootsofLagrangian Feb 5, 2024
3970bf4
maybe fix branch to run offloading
BootsofLagrangian Feb 5, 2024
7d2a926
apply offloading method runable for all trainer
BootsofLagrangian Feb 5, 2024
6255661
fix full_fp16 compatible and train_step
BootsofLagrangian Feb 7, 2024
2445a5b
remove test requirements
BootsofLagrangian Feb 7, 2024
a98feca
forgot setting mixed_precision for deepspeed. sorry
BootsofLagrangian Feb 7, 2024
03f0816
the reason not working grad accum steps found. it was becasue of my a…
BootsofLagrangian Feb 9, 2024
4d5186d
refactored codes, some function moved into train_utils.py
BootsofLagrangian Feb 22, 2024
577e991
add some new dataset settings
kohya-ss Feb 26, 2024
f2c727f
add minimal impl for masked loss
kohya-ss Feb 26, 2024
1751936
update readme
kohya-ss Feb 26, 2024
4a5546d
fix typo
kohya-ss Feb 26, 2024
074d32a
Merge branch 'main' into dev
kohya-ss Feb 27, 2024
eefb3cc
Merge branch 'deep-speed' into deepspeed
kohya-ss Feb 27, 2024
0e4a573
Merge pull request #1101 from BootsofLagrangian/deepspeed
kohya-ss Feb 27, 2024
e3ccf8f
make deepspeed_utils
kohya-ss Feb 27, 2024
a9b64ff
support masked loss in sdxl_train ref #589
kohya-ss Feb 27, 2024
14c9372
add doc about Colab/rich issue
kohya-ss Mar 3, 2024
124ec45
Add "encoding='utf-8'"
Horizon1704 Mar 10, 2024
095b803
save state on train end
gesen2egee Mar 10, 2024
d282c45
Update train_network.py
gesen2egee Mar 11, 2024
74c266a
Merge branch 'dev' into masked-loss
kohya-ss Mar 12, 2024
97524f1
Merge branch 'dev' into deep-speed
kohya-ss Mar 12, 2024
948029f
random ip_noise_gamma strength
KohakuBlueleaf Mar 12, 2024
8639940
random noise_offset strength
KohakuBlueleaf Mar 12, 2024
53954a1
use correct settings for parser
KohakuBlueleaf Mar 12, 2024
0a8ec52
Merge branch 'main' into dev
kohya-ss Mar 15, 2024
443f029
fix doc
kohya-ss Mar 15, 2024
0ef4fe7
Merge branch 'dev' into masked-loss
kohya-ss Mar 17, 2024
7081a0c
extension of src image could be different than target image
kohya-ss Mar 17, 2024
3419c3d
common masked loss func, apply to all training script
kohya-ss Mar 17, 2024
86e40fa
Merge branch 'dev' into deep-speed
kohya-ss Mar 17, 2024
a7dff59
Update tag_images_by_wd14_tagger.py
sdbds Mar 18, 2024
5410a8c
Update requirements.txt
sdbds Mar 18, 2024
a71c35c
Update requirements.txt
sdbds Mar 18, 2024
6c51c97
fix typo
sdbds Mar 20, 2024
e281e86
Merge branch 'main' into dev
kohya-ss Mar 20, 2024
7da41be
Merge pull request #1192 from sdbds/main
kohya-ss Mar 20, 2024
80dbbf5
tagger now stores model under repo_id subdir
kohya-ss Mar 20, 2024
cf09c6a
Merge pull request #1177 from KohakuBlueleaf/random-strength-noise
kohya-ss Mar 20, 2024
46331a9
English Translation of config_README-ja.md (#1175)
darkstorm2150 Mar 20, 2024
5f6196e
update readme
kohya-ss Mar 20, 2024
119cc99
Merge pull request #1167 from Horizon1704/patch-1
kohya-ss Mar 20, 2024
3b0db0f
update readme
kohya-ss Mar 20, 2024
bf6cd4b
Merge pull request #1168 from gesen2egee/save_state_on_train_end
kohya-ss Mar 20, 2024
855add0
update option help and readme
kohya-ss Mar 20, 2024
9b6b39f
Merge branch 'dev' into masked-loss
kohya-ss Mar 20, 2024
fbb98f1
Merge branch 'dev' into deep-speed
kohya-ss Mar 20, 2024
d945602
Fix most of ZeRO stage uses optimizer partitioning
BootsofLagrangian Mar 20, 2024
a35e7bd
Merge pull request #1200 from BootsofLagrangian/deep-speed
kohya-ss Mar 20, 2024
d17c0f5
update dataset config doc
kohya-ss Mar 20, 2024
863c7f7
format by black
kohya-ss Mar 23, 2024
f4a4c11
support multiline captions ref #1155
kohya-ss Mar 23, 2024
0c7baea
register reg images with correct subset
feffy380 Mar 23, 2024
79d1c12
disable sample_every_n_xxx if value less than 1 ref #1202
kohya-ss Mar 24, 2024
691f043
update readme
kohya-ss Mar 24, 2024
ad97410
Merge pull request #1205 from feffy380/patch-1
kohya-ss Mar 24, 2024
381c449
update readme and typing hint
kohya-ss Mar 24, 2024
ae97c8b
[Experimental] Add cache mechanism for dataset groups to avoid long w…
KohakuBlueleaf Mar 24, 2024
0253472
refactor metadata caching for DreamBooth dataset
kohya-ss Mar 24, 2024
8d58588
Merge branch 'dev' into masked-loss
kohya-ss Mar 24, 2024
993b2ab
Merge branch 'dev' into deep-speed
kohya-ss Mar 24, 2024
1648ade
format by black
kohya-ss Mar 24, 2024
9bbb28c
update PyTorch version and reorganize dependencies
kohya-ss Mar 24, 2024
9c4492b
fix pytorch version 2.1.1 to 2.1.2
kohya-ss Mar 24, 2024
c24422f
Merge branch 'dev' into deep-speed
kohya-ss Mar 25, 2024
a2b8531
make each script consistent, fix to work w/o DeepSpeed
kohya-ss Mar 25, 2024
ea05e3f
Merge pull request #1139 from kohya-ss/deep-speed
kohya-ss Mar 26, 2024
ab1e389
Merge branch 'dev' into masked-loss
kohya-ss Mar 26, 2024
5a2afb3
Merge pull request #1207 from kohya-ss/masked-loss
kohya-ss Mar 26, 2024
c86e356
Merge branch 'dev' into dataset-cache
kohya-ss Mar 26, 2024
78e0a76
Merge pull request #1206 from kohya-ss/dataset-cache
kohya-ss Mar 26, 2024
6c08e97
update readme
kohya-ss Mar 26, 2024
6f7e93d
Add OpenVINO and ROCm ONNX Runtime for WD14
Disty0 Mar 27, 2024
b86af67
Merge pull request #1213 from Disty0/dev
kohya-ss Mar 27, 2024
dd9763b
Rating support for WD Tagger
Disty0 Mar 27, 2024
954731d
fix typo
Disty0 Mar 27, 2024
4012fd2
IPEX fix pin_memory
Disty0 Mar 28, 2024
bc586ce
Add --use_rating_tags and --character_tags_first for WD Tagger
Disty0 Mar 29, 2024
f1f30ab
fix to work with num_beams>1 closes #1149
kohya-ss Mar 30, 2024
ae3f625
Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev
kohya-ss Mar 30, 2024
434dc40
update readme
kohya-ss Mar 30, 2024
6ba8428
Merge pull request #1216 from Disty0/dev
kohya-ss Mar 30, 2024
cae5aa0
update wd14 tagger and doc
kohya-ss Mar 30, 2024
f5323e3
update tagger doc
kohya-ss Mar 30, 2024
2c2ca9d
update tagger doc
kohya-ss Mar 30, 2024
059ee04
fix typo
kohya-ss Mar 30, 2024
2258a1b
add save/load hook to remove U-Net/TEs from state
kohya-ss Mar 31, 2024
b748b48
fix attention couple+deep shink cause error in some reso
kohya-ss Apr 3, 2024
cd587ce
verify command line args if wandb is enabled
kohya-ss Apr 4, 2024
921036d
Merge pull request #1240 from kohya-ss/verify-command-line-args
kohya-ss Apr 7, 2024
089727b
update readme
kohya-ss Apr 7, 2024
90b1879
Add option to use Scheduled Huber Loss in all training pipelines to i…
kabachuha Apr 7, 2024
d30ebb2
update readme, add metadata for network module
kohya-ss Apr 7, 2024
dfa3079
update readme
kohya-ss Apr 7, 2024
make deepspeed_utils
kohya-ss committed Feb 27, 2024
commit e3ccf8fbf73a0f728fc167a20b1e0648a3604f41
35 changes: 17 additions & 18 deletions fine_tune.py
@@ -10,7 +10,9 @@
from tqdm import tqdm

import torch
from library import deepspeed_utils
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from accelerate.utils import set_seed
@@ -42,6 +44,7 @@
def train(args):
train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True)
deepspeed_utils.prepare_deepspeed_args(args)
setup_logging(args, reset=True)

cache_latents = args.cache_latents
@@ -219,7 +222,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
batch_size=1,
shuffle=True,
collate_fn=collator,
num_workers=n_workers if not args.deepspeed else 1, # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)

@@ -231,7 +234,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.print(
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
)

# データセット側にも学習ステップを送信 (also send the training step count to the dataset side)
train_dataset_group.set_max_train_steps(args.max_train_steps)

@@ -248,21 +251,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder.to(weight_dtype)

if args.deepspeed:
training_models_dict = {}
training_models_dict["unet"] = unet
if args.train_text_encoder: training_models_dict["text_encoder"] = text_encoder

ds_model = train_util.prepare_deepspeed_model(args, **training_models_dict)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(ds_model, optimizer, train_dataloader, lr_scheduler)

training_models = []
unet = ds_model.models["unet"]
training_models.append(unet)
if args.train_text_encoder:
text_encoder = ds_model.models["text_encoder"]
training_models.append(text_encoder)

else: # acceleratorがなんかよろしくやってくれるらしい
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
# acceleratorがなんかよろしくやってくれるらしい (accelerator apparently handles this for us)
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
@@ -327,13 +325,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
with accelerator.accumulate(*training_models):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) # .to(dtype=weight_dtype)
else:
# latentに変換 (convert images to latents)
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(weight_dtype)
latents = latents * 0.18215
b_size = latents.shape[0]

@@ -493,6 +491,7 @@ def setup_parser() -> argparse.ArgumentParser:
train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)
train_util.add_sd_saving_arguments(parser)
train_util.add_optimizer_arguments(parser)
config_util.add_config_arguments(parser)
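The change from accumulating on training_models[0] to `with accelerator.accumulate(*training_models)` relies on Accelerate accepting several models in one accumulation context, so the U-Net and the text encoder now share a single gradient-accumulation scope. A minimal, self-contained sketch of that pattern (my illustration, not part of the commit; the two Linear layers stand in for the U-Net and the text encoder):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(gradient_accumulation_steps=4)
    model_a, model_b = torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)  # stand-ins for U-Net / text encoder
    optimizer = torch.optim.AdamW(list(model_a.parameters()) + list(model_b.parameters()), lr=1e-4)
    model_a, model_b, optimizer = accelerator.prepare(model_a, model_b, optimizer)
    training_models = [model_a, model_b]

    for _ in range(8):
        with accelerator.accumulate(*training_models):       # both models share one accumulation context
            x = torch.randn(2, 8, device=accelerator.device)
            loss = model_b(model_a(x)).pow(2).mean()          # toy loss for illustration
            accelerator.backward(loss)
            optimizer.step()                                  # Accelerate defers the real step to accumulation boundaries
            optimizer.zero_grad(set_to_none=True)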
139 changes: 139 additions & 0 deletions library/deepspeed_utils.py
@@ -0,0 +1,139 @@
import os
import argparse
import torch
from accelerate import DeepSpeedPlugin, Accelerator

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


def add_deepspeed_arguments(parser: argparse.ArgumentParser):
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
parser.add_argument(
"--offload_optimizer_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
)
parser.add_argument(
"--offload_optimizer_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_device",
type=str,
default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--offload_param_nvme_path",
type=str,
default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
)
parser.add_argument(
"--zero3_init_flag",
action="store_true",
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
"Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--zero3_save_16bit_model",
action="store_true",
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
)
parser.add_argument(
"--fp16_master_weights_and_gradients",
action="store_true",
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
)


def prepare_deepspeed_args(args: argparse.Namespace):
if not args.deepspeed:
return

# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
args.max_data_loader_n_workers = 1


def prepare_deepspeed_plugin(args: argparse.Namespace):
if not args.deepspeed:
return None

try:
import deepspeed
except ImportError as e:
logger.error(
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
)
exit(1)

deepspeed_plugin = DeepSpeedPlugin(
zero_stage=args.zero_stage,
gradient_accumulation_steps=args.gradient_accumulation_steps,
gradient_clipping=args.max_grad_norm,
offload_optimizer_device=args.offload_optimizer_device,
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
offload_param_device=args.offload_param_device,
offload_param_nvme_path=args.offload_param_nvme_path,
zero3_init_flag=args.zero3_init_flag,
zero3_save_16bit_model=args.zero3_save_16bit_model,
)
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
)
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
if args.full_fp16 or args.fp16_master_weights_and_gradients:
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
logger.info("[DeepSpeed] full fp16 enable.")
else:
logger.info(
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
)

if args.offload_optimizer_device is not None:
logger.info("[DeepSpeed] start to manually build cpu_adam.")
deepspeed.ops.op_builder.CPUAdamBuilder().load()
logger.info("[DeepSpeed] building cpu_adam done.")

return deepspeed_plugin


# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
def prepare_deepspeed_model(args: argparse.Namespace, **models):
# remove None from models
models = {k: v for k, v in models.items() if v is not None}

class DeepSpeedWrapper(torch.nn.Module):
def __init__(self, **kw_models) -> None:
super().__init__()
self.models = torch.nn.ModuleDict()

for key, model in kw_models.items():
if isinstance(model, list):
model = torch.nn.ModuleList(model)
assert isinstance(
model, torch.nn.Module
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
self.models.update(torch.nn.ModuleDict({key: model}))

def get_models(self):
return self.models

ds_model = DeepSpeedWrapper(**models)
return ds_model
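For orientation, a minimal sketch of how a training script is expected to wire these helpers together (my illustration, not part of the commit; `args`, `unet`, `text_encoder`, `optimizer`, `train_dataloader`, and `lr_scheduler` are assumed to be built earlier in the script, as in fine_tune.py above):

    from library import deepspeed_utils, train_util

    deepspeed_utils.prepare_deepspeed_args(args)        # caps data-loader workers at 1 when --deepspeed is set
    accelerator = train_util.prepare_accelerator(args)  # builds the DeepSpeedPlugin internally (see train_util.py below)
    # DeepSpeed's effective batch is train_batch_size * gradient_accumulation_steps * WORLD_SIZE,
    # e.g. a micro-batch of 4 with 2 accumulation steps on 2 GPUs gives train_batch_size = 16.

    if args.deepspeed:
        # Accelerate's DeepSpeed integration prepares a single model, so every trainable
        # module is folded into one wrapper before accelerator.prepare().
        ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
        # ds_model.models is a ModuleDict, so each sub-module stays reachable by name,
        # e.g. ds_model.models["unet"].
        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            ds_model, optimizer, train_dataloader, lr_scheduler
        )
        training_models = [ds_model]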
110 changes: 6 additions & 104 deletions library/train_util.py
@@ -21,7 +21,6 @@
Union,
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
from accelerate import DeepSpeedPlugin
import glob
import math
import os
@@ -70,6 +69,7 @@
import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging

setup_logging()
@@ -3243,52 +3243,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
)

# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
parser.add_argument(
"--zero_stage",
type=int, default=2,
choices=[0, 1, 2, 3],
help="Possible options are 0,1,2,3."
)
parser.add_argument(
"--offload_optimizer_device",
type=str, default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."
)
parser.add_argument(
"--offload_optimizer_nvme_path",
type=str, default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--offload_param_device",
type=str, default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--offload_param_nvme_path",
type=str, default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--zero3_init_flag",
action="store_true",
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
"Only applicable with ZeRO Stage-3."
)
parser.add_argument(
"--zero3_save_16bit_model",
action="store_true",
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."
)
parser.add_argument(
"--fp16_master_weights_and_gradients",
action="store_true",
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32."
)

def verify_training_args(args: argparse.Namespace):
r"""
@@ -4090,6 +4044,10 @@ def load_tokenizer(args: argparse.Namespace):


def prepare_accelerator(args: argparse.Namespace):
"""
this function also prepares deepspeed plugin
"""

if args.logging_dir is None:
logging_dir = None
else:
@@ -4135,7 +4093,7 @@ def prepare_accelerator(args: argparse.Namespace):
),
)
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
deepspeed_plugin = prepare_deepspeed_plugin(args)
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -4149,62 +4107,6 @@
print("accelerator device:", accelerator.device)
return accelerator

def prepare_deepspeed_plugin(args: argparse.Namespace):
if args.deepspeed is None: return None
try:
import deepspeed
except ImportError as e:
print("deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed")
exit(1)

deepspeed_plugin = DeepSpeedPlugin(
zero_stage=args.zero_stage,
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_clipping=args.max_grad_norm,
offload_optimizer_device=args.offload_optimizer_device, offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
offload_param_device=args.offload_param_device, offload_param_nvme_path=args.offload_param_nvme_path,
zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model,
)
deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
deepspeed_plugin.deepspeed_config['train_batch_size'] = \
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ['WORLD_SIZE'])
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config['fp16']['initial_scale_power'] = 0 # preventing overflow.
if args.full_fp16 or args.fp16_master_weights_and_gradients:
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
deepspeed_plugin.deepspeed_config['fp16']['fp16_master_weights_and_grads'] = True
print("[DeepSpeed] full fp16 enable.")
else:
print("[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage.")

if args.offload_optimizer_device is not None:
print('[DeepSpeed] start to manually build cpu_adam.')
deepspeed.ops.op_builder.CPUAdamBuilder().load()
print('[DeepSpeed] building cpu_adam done.')

return deepspeed_plugin

def prepare_deepspeed_model(args: argparse.Namespace, **models):
class DeepSpeedWrapper(torch.nn.Module):
def __init__(self, **kw_models) -> None:
super().__init__()
self.models = torch.nn.ModuleDict()

for key, model in kw_models.items():
if isinstance(model, list):
model = torch.nn.ModuleList(model)
assert isinstance(model, torch.nn.Module), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
self.models.update(
torch.nn.ModuleDict(
{key: model}
)
)

def get_models(self):
return self.models

ds_model = DeepSpeedWrapper(**models)
return ds_model

def prepare_dtype(args: argparse.Namespace):
weight_dtype = torch.float32
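To make the refactor concrete, here is a compact sketch of the resulting flow in `prepare_accelerator()` (my illustration; the hunk above is truncated, so passing the plugin to `Accelerator` via its `deepspeed_plugin` keyword is an assumption, as is the exact set of other keywords):

    from accelerate import Accelerator
    from library import deepspeed_utils

    def prepare_accelerator_sketch(args):
        # None unless --deepspeed was given; otherwise a fully configured DeepSpeedPlugin.
        deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
        return Accelerator(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            mixed_precision=args.mixed_precision,
            deepspeed_plugin=deepspeed_plugin,  # assumption: the plugin is handed through here
        )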