Skip to content

Commit

Permalink
Precalculate .safetensors model hashes after training
Browse files Browse the repository at this point in the history
  • Loading branch information
space-nuko committed Jan 24, 2023
1 parent 93df55d commit f7fbdc4
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
45 changes: 45 additions & 0 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 12,7 @@
import os
import random
import hashlib
from io import BytesIO

from tqdm import tqdm
import torch
Expand All @@ -25,6 26,7 @@
import cv2
from einops import rearrange
from torch import einsum
import safetensors.torch

import library.model_util as model_util

Expand Down Expand Up @@ -790,6 792,49 @@ def calculate_sha256(filename):
return hash_sha256.hexdigest()


def precalculate_safetensors_hashes(tensors, metadata):
"""Precalculate the model hashes needed by sd-webui-additional-networks to
save time on indexing the model later."""

# Because writing user metadata to the file can change the result of
# sd_models.model_hash(), only retain the training metadata for purposes of
# calculating the hash, as they are meant to be immutable
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

bytes = safetensors.torch.save(tensors, metadata)
b = BytesIO(bytes)

model_hash = addnet_hash_safetensors(b)
legacy_hash = addnet_hash_legacy(b)
return model_hash, legacy_hash


def addnet_hash_legacy(b):
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
m = hashlib.sha256()

b.seek(0x100000)
m.update(b.read(0x10000))
return m.hexdigest()[0:8]


def addnet_hash_safetensors(b):
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
hash_sha256 = hashlib.sha256()
blksize = 1024 * 1024

b.seek(0)
header = b.read(8)
n = int.from_bytes(header, "little")

offset = n 8
b.seek(offset)
for chunk in iter(lambda: b.read(blksize), b""):
hash_sha256.update(chunk)

return hash_sha256.hexdigest()


# flash attention forwards and backwards

# https://arxiv.org/abs/2205.14135
Expand Down
10 changes: 10 additions & 0 deletions networks/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 7,8 @@
import os
import torch

from library import train_util


class LoRAModule(torch.nn.Module):
"""
Expand Down Expand Up @@ -221,6 223,14 @@ def save_weights(self, file, dtype, metadata):

if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import save_file

# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash

save_file(state_dict, file, metadata)
else:
torch.save(state_dict, file)

0 comments on commit f7fbdc4

Please sign in to comment.