From f7fbdc4b2aa52986cdab2e5482ba840457c6428f Mon Sep 17 00:00:00 2001
From: space-nuko <24979496+space-nuko@users.noreply.github.com>
Date: Mon, 23 Jan 2023 17:21:04 -0800
Subject: [PATCH] Precalculate .safetensors model hashes after training

---
 library/train_util.py | 45 +++++++++++++++++++++++++++++++++++++++++++
 networks/lora.py      | 10 ++++++++++
 2 files changed, 55 insertions(+)

diff --git a/library/train_util.py b/library/train_util.py
index 0fdbadc10..bbc68aaea 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -12,6 +12,7 @@ import os
 import random
 import hashlib
+from io import BytesIO
 
 from tqdm import tqdm
 import torch
 
@@ -25,6 +26,7 @@ import cv2
 from einops import rearrange
 from torch import einsum
+import safetensors.torch
 
 import library.model_util as model_util
 
@@ -790,6 +792,49 @@ def calculate_sha256(filename):
   return hash_sha256.hexdigest()
 
 
+def precalculate_safetensors_hashes(tensors, metadata):
+  """Precalculate the model hashes needed by sd-webui-additional-networks to
+  save time on indexing the model later."""
+
+  # Because writing user metadata to the file can change the result of
+  # sd_models.model_hash(), only retain the training metadata for purposes of
+  # calculating the hash, as they are meant to be immutable
+  metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
+
+  bytes = safetensors.torch.save(tensors, metadata)
+  b = BytesIO(bytes)
+
+  model_hash = addnet_hash_safetensors(b)
+  legacy_hash = addnet_hash_legacy(b)
+  return model_hash, legacy_hash
+
+
+def addnet_hash_legacy(b):
+  """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
+  m = hashlib.sha256()
+
+  b.seek(0x100000)
+  m.update(b.read(0x10000))
+  return m.hexdigest()[0:8]
+
+
+def addnet_hash_safetensors(b):
+  """New model hash used by sd-webui-additional-networks for .safetensors format files"""
+  hash_sha256 = hashlib.sha256()
+  blksize = 1024 * 1024
+
+  b.seek(0)
+  header = b.read(8)
+  n = int.from_bytes(header, "little")
+
+  offset = n + 8
+  b.seek(offset)
+  for chunk in iter(lambda: b.read(blksize), b""):
+    hash_sha256.update(chunk)
+
+  return hash_sha256.hexdigest()
+
+
 # flash attention forwards and backwards
 # https://arxiv.org/abs/2205.14135
 
diff --git a/networks/lora.py b/networks/lora.py
index 9243f1e1b..bbc65164d 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -7,6 +7,8 @@
 import os
 
 import torch
 
+from library import train_util
+
 class LoRAModule(torch.nn.Module):
   """
@@ -221,6 +223,14 @@ def save_weights(self, file, dtype, metadata):
 
     if os.path.splitext(file)[1] == '.safetensors':
       from safetensors.torch import save_file
+
+      # Precalculate model hashes to save time on indexing
+      if metadata is None:
+        metadata = {}
+      model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
+      metadata["sshs_model_hash"] = model_hash
+      metadata["sshs_legacy_hash"] = legacy_hash
+
       save_file(state_dict, file, metadata)
     else:
       torch.save(state_dict, file)