Skip to content

Commit

Permalink
add a bunch of controlnet code
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloppp committed Jan 19, 2024
1 parent 9a21d5e commit a42a5ee
Show file tree
Hide file tree
Showing 20 changed files with 3,617 additions and 227 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 3,7 @@
dist_file_*
__pycache__/*
*/__pycache__/*
*/**/__pycache__/*
*/**/__pycache__/*
*_latest_output.jpg
*_sample.jpg
jobs/*.sh
510 changes: 510 additions & 0 deletions .ipynb_checkpoints/controlnet test-checkpoint.ipynb

Large diffs are not rendered by default.

72 changes: 72 additions & 0 deletions bucketeer.py
Original file line number Diff line number Diff line change
@@ -0,0 1,72 @@
import torch
import torchvision
import numpy as np
from torchtools.transforms import SmartCrop
import math

class Bucketeer():
def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
assert crop_mode in ['center', 'random', 'smart']
self.crop_mode = crop_mode
self.ratios = ratios
if reverse_list:
for r in list(ratios):
if 1/r not in self.ratios:
self.ratios.append(1/r)
self.sizes = [(int(((density/r)**0.5//factor)*factor), int(((density*r)**0.5//factor)*factor)) for r in ratios]
self.batch_size = dataloader.batch_size
self.iterator = iter(dataloader)
self.buckets = {s: [] for s in self.sizes}
self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode=='smart' else None
self.p_random_ratio = p_random_ratio
self.interpolate_nearest = interpolate_nearest

def get_available_batch(self):
for b in self.buckets:
if len(self.buckets[b]) >= self.batch_size:
batch = self.buckets[b][:self.batch_size]
self.buckets[b] = self.buckets[b][self.batch_size:]
return batch
return None

def get_closest_size(self, x):
if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
best_size_idx = np.random.randint(len(self.ratios))
else:
w, h = x.size(-1), x.size(-2)
best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
return self.sizes[best_size_idx]

def get_resize_size(self, orig_size, tgt_size):
if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
resize_size = max(alt_min, min(tgt_size))
else:
alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
resize_size = max(alt_max, max(tgt_size))
return resize_size

def __next__(self):
batch = self.get_available_batch()
while batch is None:
elements = next(self.iterator)
for dct in elements:
img = dct['images']
size = self.get_closest_size(img)
resize_size = self.get_resize_size(img.shape[-2:], size)
if self.interpolate_nearest:
img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
else:
img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
if self.crop_mode == 'center':
img = torchvision.transforms.functional.center_crop(img, size)
elif self.crop_mode == 'random':
img = torchvision.transforms.RandomCrop(size)(img)
elif self.crop_mode == 'smart':
self.smartcrop.output_size = size
img = self.smartcrop(img)
self.buckets[size].append({**{'images': img}, **{k:dct[k] for k in dct if k != 'images'}})
batch = self.get_available_batch()

out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
Empty file added configs/.gitkeep
Empty file.
510 changes: 510 additions & 0 deletions controlnet test.ipynb

Large diffs are not rendered by default.

Empty file added jobs/.gitkeep
Empty file.
Binary file not shown.
83 changes: 83 additions & 0 deletions modules/cnet_modules/inpainting/saliency_model.py
Original file line number Diff line number Diff line change
@@ -0,0 1,83 @@
import torch
import torchvision
from torch import nn
from PIL import Image
import numpy as np
import os


# MICRO RESNET
class ResBlock(nn.Module):
def __init__(self, channels):
super(ResBlock, self).__init__()

self.resblock = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(channels, channels, kernel_size=3),
nn.InstanceNorm2d(channels, affine=True),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(channels, channels, kernel_size=3),
nn.InstanceNorm2d(channels, affine=True),
)

def forward(self, x):
out = self.resblock(x)
return out x


class Upsample2d(nn.Module):
def __init__(self, scale_factor):
super(Upsample2d, self).__init__()

self.interp = nn.functional.interpolate
self.scale_factor = scale_factor

def forward(self, x):
x = self.interp(x, scale_factor=self.scale_factor, mode='nearest')
return x


class MicroResNet(nn.Module):
def __init__(self):
super(MicroResNet, self).__init__()

self.downsampler = nn.Sequential(
nn.ReflectionPad2d(4),
nn.Conv2d(3, 8, kernel_size=9, stride=4),
nn.InstanceNorm2d(8, affine=True),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(8, 16, kernel_size=3, stride=2),
nn.InstanceNorm2d(16, affine=True),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(16, 32, kernel_size=3, stride=2),
nn.InstanceNorm2d(32, affine=True),
nn.ReLU(),
)

self.residual = nn.Sequential(
ResBlock(32),
nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32),
ResBlock(64),
)

self.segmentator = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(64, 16, kernel_size=3),
nn.InstanceNorm2d(16, affine=True),
nn.ReLU(),
Upsample2d(scale_factor=2),
nn.ReflectionPad2d(4),
nn.Conv2d(16, 1, kernel_size=9),
nn.Sigmoid()
)

def forward(self, x):
out = self.downsampler(x)
out = self.residual(out)
out = self.segmentator(out)
return out


36 changes: 36 additions & 0 deletions modules/cnet_modules/pidinet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 1,36 @@
# Pidinet
# https://github.com/hellozhuo/pidinet

import os
import torch
import numpy as np
from einops import rearrange
from .model import pidinet
from .util import annotator_ckpts_path, safe_step


class PidiNetDetector:
def __init__(self, device):
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth"
modelpath = os.path.join(annotator_ckpts_path, "table5_pidinet.pth")
if not os.path.exists(modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
self.netNetwork = pidinet()
self.netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(modelpath)['state_dict'].items()})
self.netNetwork.to(device).eval().requires_grad_(False)

def __call__(self, input_image): # , safe=False):
return self.netNetwork(input_image)[-1]
# assert input_image.ndim == 3
# input_image = input_image[:, :, ::-1].copy()
# with torch.no_grad():
# image_pidi = torch.from_numpy(input_image).float().cuda()
# image_pidi = image_pidi / 255.0
# image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w')
# edge = self.netNetwork(image_pidi)[-1]

# if safe:
# edge = safe_step(edge)
# edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
# return edge[0][0]
Binary file not shown.
Loading

0 comments on commit a42a5ee

Please sign in to comment.