Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version 0.8.6 #1212

Merged
merged 98 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
dfe08f3
support deepspeed
BootsofLagrangian Feb 3, 2024
64873c1
fix offload_optimizer_device typo
BootsofLagrangian Feb 5, 2024
2824312
fix vae type error during training sdxl
BootsofLagrangian Feb 5, 2024
4295f91
fix all trainer about vae
BootsofLagrangian Feb 5, 2024
3970bf4
maybe fix branch to run offloading
BootsofLagrangian Feb 5, 2024
7d2a926
apply offloading method runable for all trainer
BootsofLagrangian Feb 5, 2024
6255661
fix full_fp16 compatible and train_step
BootsofLagrangian Feb 7, 2024
2445a5b
remove test requirements
BootsofLagrangian Feb 7, 2024
a98feca
forgot setting mixed_precision for deepspeed. sorry
BootsofLagrangian Feb 7, 2024
03f0816
the reason not working grad accum steps found. it was becasue of my a…
BootsofLagrangian Feb 9, 2024
4d5186d
refactored codes, some function moved into train_utils.py
BootsofLagrangian Feb 22, 2024
577e991
add some new dataset settings
kohya-ss Feb 26, 2024
f2c727f
add minimal impl for masked loss
kohya-ss Feb 26, 2024
1751936
update readme
kohya-ss Feb 26, 2024
4a5546d
fix typo
kohya-ss Feb 26, 2024
074d32a
Merge branch 'main' into dev
kohya-ss Feb 27, 2024
eefb3cc
Merge branch 'deep-speed' into deepspeed
kohya-ss Feb 27, 2024
0e4a573
Merge pull request #1101 from BootsofLagrangian/deepspeed
kohya-ss Feb 27, 2024
e3ccf8f
make deepspeed_utils
kohya-ss Feb 27, 2024
a9b64ff
support masked loss in sdxl_train ref #589
kohya-ss Feb 27, 2024
14c9372
add doc about Colab/rich issue
kohya-ss Mar 3, 2024
124ec45
Add "encoding='utf-8'"
Horizon1704 Mar 10, 2024
095b803
save state on train end
gesen2egee Mar 10, 2024
d282c45
Update train_network.py
gesen2egee Mar 11, 2024
74c266a
Merge branch 'dev' into masked-loss
kohya-ss Mar 12, 2024
97524f1
Merge branch 'dev' into deep-speed
kohya-ss Mar 12, 2024
948029f
random ip_noise_gamma strength
KohakuBlueleaf Mar 12, 2024
8639940
random noise_offset strength
KohakuBlueleaf Mar 12, 2024
53954a1
use correct settings for parser
KohakuBlueleaf Mar 12, 2024
0a8ec52
Merge branch 'main' into dev
kohya-ss Mar 15, 2024
443f029
fix doc
kohya-ss Mar 15, 2024
0ef4fe7
Merge branch 'dev' into masked-loss
kohya-ss Mar 17, 2024
7081a0c
extension of src image could be different than target image
kohya-ss Mar 17, 2024
3419c3d
common masked loss func, apply to all training script
kohya-ss Mar 17, 2024
86e40fa
Merge branch 'dev' into deep-speed
kohya-ss Mar 17, 2024
a7dff59
Update tag_images_by_wd14_tagger.py
sdbds Mar 18, 2024
5410a8c
Update requirements.txt
sdbds Mar 18, 2024
a71c35c
Update requirements.txt
sdbds Mar 18, 2024
6c51c97
fix typo
sdbds Mar 20, 2024
e281e86
Merge branch 'main' into dev
kohya-ss Mar 20, 2024
7da41be
Merge pull request #1192 from sdbds/main
kohya-ss Mar 20, 2024
80dbbf5
tagger now stores model under repo_id subdir
kohya-ss Mar 20, 2024
cf09c6a
Merge pull request #1177 from KohakuBlueleaf/random-strength-noise
kohya-ss Mar 20, 2024
46331a9
English Translation of config_README-ja.md (#1175)
darkstorm2150 Mar 20, 2024
5f6196e
update readme
kohya-ss Mar 20, 2024
119cc99
Merge pull request #1167 from Horizon1704/patch-1
kohya-ss Mar 20, 2024
3b0db0f
update readme
kohya-ss Mar 20, 2024
bf6cd4b
Merge pull request #1168 from gesen2egee/save_state_on_train_end
kohya-ss Mar 20, 2024
855add0
update option help and readme
kohya-ss Mar 20, 2024
9b6b39f
Merge branch 'dev' into masked-loss
kohya-ss Mar 20, 2024
fbb98f1
Merge branch 'dev' into deep-speed
kohya-ss Mar 20, 2024
d945602
Fix most of ZeRO stage uses optimizer partitioning
BootsofLagrangian Mar 20, 2024
a35e7bd
Merge pull request #1200 from BootsofLagrangian/deep-speed
kohya-ss Mar 20, 2024
d17c0f5
update dataset config doc
kohya-ss Mar 20, 2024
863c7f7
format by black
kohya-ss Mar 23, 2024
f4a4c11
support multiline captions ref #1155
kohya-ss Mar 23, 2024
0c7baea
register reg images with correct subset
feffy380 Mar 23, 2024
79d1c12
disable sample_every_n_xxx if value less than 1 ref #1202
kohya-ss Mar 24, 2024
691f043
update readme
kohya-ss Mar 24, 2024
ad97410
Merge pull request #1205 from feffy380/patch-1
kohya-ss Mar 24, 2024
381c449
update readme and typing hint
kohya-ss Mar 24, 2024
ae97c8b
[Experimental] Add cache mechanism for dataset groups to avoid long w…
KohakuBlueleaf Mar 24, 2024
0253472
refactor metadata caching for DreamBooth dataset
kohya-ss Mar 24, 2024
8d58588
Merge branch 'dev' into masked-loss
kohya-ss Mar 24, 2024
993b2ab
Merge branch 'dev' into deep-speed
kohya-ss Mar 24, 2024
1648ade
format by black
kohya-ss Mar 24, 2024
9bbb28c
update PyTorch version and reorganize dependencies
kohya-ss Mar 24, 2024
9c4492b
fix pytorch version 2.1.1 to 2.1.2
kohya-ss Mar 24, 2024
c24422f
Merge branch 'dev' into deep-speed
kohya-ss Mar 25, 2024
a2b8531
make each script consistent, fix to work w/o DeepSpeed
kohya-ss Mar 25, 2024
ea05e3f
Merge pull request #1139 from kohya-ss/deep-speed
kohya-ss Mar 26, 2024
ab1e389
Merge branch 'dev' into masked-loss
kohya-ss Mar 26, 2024
5a2afb3
Merge pull request #1207 from kohya-ss/masked-loss
kohya-ss Mar 26, 2024
c86e356
Merge branch 'dev' into dataset-cache
kohya-ss Mar 26, 2024
78e0a76
Merge pull request #1206 from kohya-ss/dataset-cache
kohya-ss Mar 26, 2024
6c08e97
update readme
kohya-ss Mar 26, 2024
6f7e93d
Add OpenVINO and ROCm ONNX Runtime for WD14
Disty0 Mar 27, 2024
b86af67
Merge pull request #1213 from Disty0/dev
kohya-ss Mar 27, 2024
dd9763b
Rating support for WD Tagger
Disty0 Mar 27, 2024
954731d
fix typo
Disty0 Mar 27, 2024
4012fd2
IPEX fix pin_memory
Disty0 Mar 28, 2024
bc586ce
Add --use_rating_tags and --character_tags_first for WD Tagger
Disty0 Mar 29, 2024
f1f30ab
fix to work with num_beams>1 closes #1149
kohya-ss Mar 30, 2024
ae3f625
Merge branch 'dev' of https://github.com/kohya-ss/sd-scripts into dev
kohya-ss Mar 30, 2024
434dc40
update readme
kohya-ss Mar 30, 2024
6ba8428
Merge pull request #1216 from Disty0/dev
kohya-ss Mar 30, 2024
cae5aa0
update wd14 tagger and doc
kohya-ss Mar 30, 2024
f5323e3
update tagger doc
kohya-ss Mar 30, 2024
2c2ca9d
update tagger doc
kohya-ss Mar 30, 2024
059ee04
fix typo
kohya-ss Mar 30, 2024
2258a1b
add save/load hook to remove U-Net/TEs from state
kohya-ss Mar 31, 2024
b748b48
fix attention couple+deep shink cause error in some reso
kohya-ss Apr 3, 2024
cd587ce
verify command line args if wandb is enabled
kohya-ss Apr 4, 2024
921036d
Merge pull request #1240 from kohya-ss/verify-command-line-args
kohya-ss Apr 7, 2024
089727b
update readme
kohya-ss Apr 7, 2024
90b1879
Add option to use Scheduled Huber Loss in all training pipelines to i…
kabachuha Apr 7, 2024
d30ebb2
update readme, add metadata for network module
kohya-ss Apr 7, 2024
dfa3079
update readme
kohya-ss Apr 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add option to use Scheduled Huber Loss in all training pipelines to i…
…mprove resilience to data corruption (#1228)

* add huber loss and huber_c compute to train_util

* add reduction modes

* add huber_c retrieval from timestep getter

* move get timesteps and huber to own function

* add conditional loss to all training scripts

* add cond loss to train network

* add (scheduled) huber_loss to args

* fixup twice timesteps getting

* PHL-schedule should depend on noise scheduler"s num timesteps

* *2 multiplier to huber loss cause of 1/2 a^2 conv.

The Taylor expansion of sqrt near zero gives 1/2 a^2, which differs from a^2 of the standard MSE loss. This change scales them better against one another

* add option for smooth l1 (huber / delta)

* unify huber scheduling

* add snr huber scheduler

---------

Co-authored-by: Kohya S <[email protected]>
  • Loading branch information
kabachuha and kohya-ss authored Apr 7, 2024
commit 90b18795fce516cb00735dc43a6ee76ecae8ec83
6 changes: 3 additions & 3 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

# Predict the noise residual
with accelerator.autocast():
Expand All @@ -368,7 +368,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
Expand All @@ -380,7 +380,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
Expand Down
79 changes: 75 additions & 4 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3236,6 +3236,26 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
)
parser.add_argument(
"--loss_type",
type=str,
default="l2",
choices=["l2", "huber", "smooth_l1"],
help="The type of loss to use and whether it's scheduled based on the timestep"
)
parser.add_argument(
"--huber_schedule",
type=str,
default="exponential",
choices=["constant", "exponential", "snr"],
help="The type of loss to use and whether it's scheduled based on the timestep"
)
parser.add_argument(
"--huber_c",
type=float,
default=0.1,
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
)

parser.add_argument(
"--lowram",
Expand Down Expand Up @@ -4842,6 +4862,38 @@ def save_sd_model_on_train_end_common(
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)

def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device):

#TODO: if a huber loss is selected, it will use constant timesteps for each batch
# as. In the future there may be a smarter way

if args.loss_type == 'huber' or args.loss_type == 'smooth_l1':
timesteps = torch.randint(
min_timestep, max_timestep, (1,), device='cpu'
)
timestep = timesteps.item()

if args.huber_schedule == "exponential":
alpha = - math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
huber_c = math.exp(-alpha * timestep)
elif args.huber_schedule == "snr":
alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
huber_c = (1 - args.huber_c) / (1 + sigmas)**2 + args.huber_c
elif args.huber_schedule == "constant":
huber_c = args.huber_c
else:
raise NotImplementedError(f'Unknown Huber loss schedule {args.huber_schedule}!')

timesteps = timesteps.repeat(b_size).to(device)
elif args.loss_type == 'l2':
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
huber_c = 1 # may be anything, as it's not used
else:
raise NotImplementedError(f'Unknown loss type {args.loss_type}')
timesteps = timesteps.long()

return timesteps, huber_c

def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
# Sample noise that we'll add to the latents
Expand All @@ -4862,8 +4914,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep

timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device)
timesteps = timesteps.long()
timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, latents.device)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
Expand All @@ -4876,8 +4927,28 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

return noise, noisy_latents, timesteps

return noise, noisy_latents, timesteps, huber_c

# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str="mean", loss_type:str="l2", huber_c:float=0.1):

if loss_type == 'l2':
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == 'huber':
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
elif loss_type == 'smooth_l1':
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
else:
raise NotImplementedError(f'Unsupported Loss Type {loss_type}')
return loss

def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
names = []
Expand Down
6 changes: 3 additions & 3 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

Expand All @@ -600,7 +600,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
or args.masked_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
Expand All @@ -616,7 +616,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

loss = loss.mean() # mean over batch dimension
else:
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
Expand Down
4 changes: 2 additions & 2 deletions sdxl_train_control_net_lllite.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

Expand All @@ -458,7 +458,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
Expand Down
4 changes: 2 additions & 2 deletions sdxl_train_control_net_lllite_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

Expand All @@ -426,7 +426,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
Expand Down
11 changes: 3 additions & 8 deletions train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,13 +420,8 @@ def remove_model(old_ckpt_name):
)

# Sample a random timestep for each image
timesteps = torch.randint(
0,
noise_scheduler.config.num_train_timesteps,
(b_size,),
device=latents.device,
)
timesteps = timesteps.long()
timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
Expand Down Expand Up @@ -457,7 +452,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # 各sampleごとのweight
Expand Down
4 changes: 2 additions & 2 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def train(args):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

# Predict the noise residual
with accelerator.autocast():
Expand All @@ -358,7 +358,7 @@ def train(args):
else:
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
Expand Down
4 changes: 2 additions & 2 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

Expand Down Expand Up @@ -873,7 +873,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
Expand Down
4 changes: 2 additions & 2 deletions train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

Expand All @@ -588,7 +588,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
Expand Down
4 changes: 2 additions & 2 deletions train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

# Predict the noise residual
with accelerator.autocast():
Expand All @@ -473,7 +473,7 @@ def remove_model(old_ckpt_name):
else:
target = noise

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
Expand Down