forked from Stability-AI/StableCascade
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
20 changed files
with
3,617 additions
and
227 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import torch | ||
import torchvision | ||
import numpy as np | ||
from torchtools.transforms import SmartCrop | ||
import math | ||
|
||
class Bucketeer():
    """Iterator that regroups dataloader samples into same-sized image batches.

    Every sample is assigned to the bucket whose aspect ratio (width / height)
    is closest to the image's own ratio, resized, cropped to that bucket's
    ``(height, width)`` size and stored. As soon as one bucket holds
    ``dataloader.batch_size`` samples, they are emitted as a batch.

    Args:
        dataloader: Object exposing ``batch_size`` and iteration. Each item it
            yields is iterated, and every entry must be a mapping containing
            an ``'images'`` entry whose last two dimensions are (H, W).
        density: Approximate number of pixels (height * width) per bucket.
        factor: Both sides of every bucket size are floored to a multiple of
            this value.
        ratios: Width / height ratios to create buckets for. The list is
            copied; it is never modified.
        reverse_list: If true, the reciprocal of every ratio is added as well.
        randomize_p, randomize_q: Forwarded to ``SmartCrop`` (only used when
            ``crop_mode == 'smart'``).
        crop_mode: One of ``'center'``, ``'random'`` or ``'smart'``.
        p_random_ratio: Probability of assigning a sample to a random bucket
            instead of the closest one.
        interpolate_nearest: Resize with nearest-neighbour interpolation
            instead of antialiased bilinear interpolation.
    """

    def __init__(self, dataloader, density=256*256, factor=8, ratios=[1/1, 1/2, 3/4, 3/5, 4/5, 6/9, 9/16], reverse_list=True, randomize_p=0.3, randomize_q=0.2, crop_mode='random', p_random_ratio=0.0, interpolate_nearest=False):
        assert crop_mode in ['center', 'random', 'smart']
        self.crop_mode = crop_mode
        # Work on a private copy: appending the reciprocals below must not
        # leak into the caller's list or into the shared default argument.
        self.ratios = list(ratios)
        if reverse_list:
            for r in list(ratios):
                if 1/r not in self.ratios:
                    self.ratios.append(1/r)
        # One (height, width) size per ratio, both floored to a multiple of
        # `factor`, so that height * width is roughly `density`.
        self.sizes = [
            (int(((density/r)**0.5//factor)*factor), int(((density*r)**0.5//factor)*factor))
            for r in self.ratios
        ]
        self.batch_size = dataloader.batch_size
        self.iterator = iter(dataloader)
        self.buckets = {s: [] for s in self.sizes}
        self.smartcrop = SmartCrop(int(density**0.5), randomize_p, randomize_q) if self.crop_mode == 'smart' else None
        self.p_random_ratio = p_random_ratio
        self.interpolate_nearest = interpolate_nearest

    def __iter__(self):
        """Return self so the bucketeer can be used directly in ``for`` loops."""
        return self

    def get_available_batch(self):
        """Pop and return ``batch_size`` samples from the first full bucket.

        Returns ``None`` when no bucket holds a complete batch yet.
        """
        for size, samples in self.buckets.items():
            if len(samples) >= self.batch_size:
                self.buckets[size] = samples[self.batch_size:]
                return samples[:self.batch_size]
        return None

    def get_closest_size(self, x):
        """Return the bucket size for image ``x``.

        With probability ``p_random_ratio`` a random bucket is picked;
        otherwise the bucket whose ratio is closest to width / height of ``x``.
        """
        if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
            best_size_idx = np.random.randint(len(self.ratios))
        else:
            w, h = x.size(-1), x.size(-2)
            best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
        return self.sizes[best_size_idx]

    def get_resize_size(self, orig_size, tgt_size):
        """Return the integer size to resize to before cropping to ``tgt_size``.

        Both arguments are ``(height, width)`` pairs. The result is passed as a
        single int to torchvision's ``resize``; it is chosen so that the
        resized image is at least as large as ``tgt_size`` on both sides.
        """
        # A non-negative product means source and target share the same
        # orientation (both landscape or both portrait, or one is square).
        if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
            alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
            resize_size = max(alt_min, min(tgt_size))
        else:
            alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
            resize_size = max(alt_max, max(tgt_size))
        return resize_size

    def __next__(self):
        """Return the next same-sized batch.

        Pulls items from the wrapped dataloader until some bucket is full.
        Propagates ``StopIteration`` from the dataloader when it is exhausted.

        Returns:
            A dict keyed like the samples; tensor values are stacked along a
            new first dimension, all other values are returned as lists.
        """
        batch = self.get_available_batch()
        while batch is None:
            elements = next(self.iterator)
            for dct in elements:
                img = dct['images']
                size = self.get_closest_size(img)
                resize_size = self.get_resize_size(img.shape[-2:], size)
                if self.interpolate_nearest:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.NEAREST)
                else:
                    img = torchvision.transforms.functional.resize(img, resize_size, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True)
                if self.crop_mode == 'center':
                    img = torchvision.transforms.functional.center_crop(img, size)
                elif self.crop_mode == 'random':
                    img = torchvision.transforms.RandomCrop(size)(img)
                elif self.crop_mode == 'smart':
                    self.smartcrop.output_size = size
                    img = self.smartcrop(img)
                # Keep every other field of the sample next to the new image.
                sample = {'images': img}
                sample.update({k: dct[k] for k in dct if k != 'images'})
                self.buckets[size].append(sample)
            batch = self.get_available_batch()

        out = {k: [sample[k] for sample in batch] for k in batch[0]}
        return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
Empty file.
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import torch | ||
import torchvision | ||
from torch import nn | ||
from PIL import Image | ||
import numpy as np | ||
import os | ||
|
||
|
||
# MICRO RESNET | ||
class ResBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 convolutions plus a skip.

    Each convolution is followed by an affine ``InstanceNorm2d``; a ReLU sits
    between the two convolutions. Reflection padding of 1 with a 3x3 kernel
    keeps the spatial size, and the channel count is unchanged, so the input
    can be added to the result.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()

        self.resblock = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels, affine=True),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3),
            nn.InstanceNorm2d(channels, affine=True),
        )

    def forward(self, x):
        """Return ``resblock(x) + x`` (the residual connection)."""
        out = self.resblock(x)
        return out + x
|
||
|
||
class Upsample2d(nn.Module):
    """Module wrapper around nearest-neighbour ``interpolate``.

    Args:
        scale_factor: Multiplier forwarded to ``nn.functional.interpolate``.
    """

    def __init__(self, scale_factor):
        super().__init__()

        self.scale_factor = scale_factor
        self.interp = nn.functional.interpolate

    def forward(self, x):
        """Return ``x`` upsampled by ``scale_factor`` in ``'nearest'`` mode."""
        return self.interp(x, scale_factor=self.scale_factor, mode='nearest')
|
||
|
||
class MicroResNet(nn.Module):
    """Small convolutional network mapping a 3-channel image to a 1-channel map.

    Three stages are applied in order:

    * ``downsampler`` - three reflection-padded convolutions with strides
      4, 2 and 2 (3 -> 8 -> 16 -> 32 channels), each followed by an affine
      instance norm and a ReLU.
    * ``residual`` - a 32-channel ``ResBlock``, a grouped 1x1 convolution that
      widens to 64 channels, and a 64-channel ``ResBlock``.
    * ``segmentator`` - a 3x3 convolution down to 16 channels, a 2x
      nearest-neighbour upsample, and a 9x9 convolution to a single channel
      squashed by a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Layer order is significant: the Sequential indices are part of the
        # state-dict keys.
        self.downsampler = nn.Sequential(
            nn.ReflectionPad2d(4),
            nn.Conv2d(3, 8, kernel_size=9, stride=4),
            nn.InstanceNorm2d(8, affine=True),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 16, kernel_size=3, stride=2),
            nn.InstanceNorm2d(16, affine=True),
            nn.ReLU(),
            nn.ReflectionPad2d(1),
            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.InstanceNorm2d(32, affine=True),
            nn.ReLU(),
        )

        self.residual = nn.Sequential(
            ResBlock(32),
            nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32),
            ResBlock(64),
        )

        self.segmentator = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(64, 16, kernel_size=3),
            nn.InstanceNorm2d(16, affine=True),
            nn.ReLU(),
            Upsample2d(scale_factor=2),
            nn.ReflectionPad2d(4),
            nn.Conv2d(16, 1, kernel_size=9),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Run ``x`` through downsampler, residual and segmentator in turn."""
        features = x
        for stage in (self.downsampler, self.residual, self.segmentator):
            features = stage(features)
        return features
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Pidinet | ||
# https://github.com/hellozhuo/pidinet | ||
|
||
import os | ||
import torch | ||
import numpy as np | ||
from einops import rearrange | ||
from .model import pidinet | ||
from .util import annotator_ckpts_path, safe_step | ||
|
||
|
||
class PidiNetDetector:
    """Wrapper around a pretrained PiDiNet network in frozen inference mode.

    On construction the ``table5_pidinet.pth`` checkpoint is downloaded into
    ``annotator_ckpts_path`` if it is not already there, loaded into a
    ``pidinet()`` model, moved to ``device``, switched to eval mode and frozen
    (``requires_grad_(False)``).

    Args:
        device: Target device passed to ``Module.to``.
    """

    def __init__(self, device):
        remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/table5_pidinet.pth"
        modelpath = os.path.join(annotator_ckpts_path, "table5_pidinet.pth")
        if not os.path.exists(modelpath):
            from basicsr.utils.download_util import load_file_from_url
            load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
        self.netNetwork = pidinet()
        # NOTE(security): torch.load unpickles the file; the checkpoint is
        # fetched from a remote URL, so only use trusted sources.
        # Load onto the CPU first so a checkpoint saved from a CUDA device can
        # be read on CPU-only hosts; the model is moved to `device` below.
        checkpoint = torch.load(modelpath, map_location='cpu')
        # Strip the 'module.' prefix from the checkpoint's parameter names.
        state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
        self.netNetwork.load_state_dict(state_dict)
        self.netNetwork.to(device).eval().requires_grad_(False)

    def __call__(self, input_image):
        """Run the network on ``input_image`` and return its last output.

        The input is passed to the network unchanged. An earlier numpy-based
        path reversed the channel order of an ``h w c`` array, scaled it by
        1/255, rearranged it to ``1 c h w`` and converted the result back to a
        uint8 edge map (optionally via ``safe_step``); none of that happens
        here any more, so callers must supply network-ready input.
        """
        return self.netNetwork(input_image)[-1]
Binary file not shown.
Oops, something went wrong.