Skip to content

Commit

Permalink
Add CompressedSR task
Browse files Browse the repository at this point in the history
  • Loading branch information
Choiuijin1125 committed Sep 29, 2022
1 parent 212e0ef commit 8f9aa49
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions main_test_swin2sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 72,11 @@ def main():
img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old h_pad, :]
img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old w_pad]
output = test(img_lq, model, args, window_size)
output = output[..., :h_old * args.scale, :w_old * args.scale]

if args.task == 'compressed_sr':
output = output[0][..., :h_old * args.scale, :w_old * args.scale]
else:
output = output[..., :h_old * args.scale, :w_old * args.scale]

# save image
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
Expand Down Expand Up @@ -140,6 144,12 @@ def define_model(args):
img_range=1., depths=[6, 6, 6, 6], embed_dim=60, num_heads=[6, 6, 6, 6],
mlp_ratio=2, upsampler='pixelshuffledirect', resi_connection='1conv')
param_key_g = 'params'

elif args.task == 'compressed_sr':
model = net(upscale=args.scale, in_chans=3, img_size=args.training_patch_size, window_size=8,
img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2, upsampler='pixelshuffle_aux', resi_connection='1conv')
param_key_g = 'params'

# 003 real-world image sr
elif args.task == 'real_sr':
Expand Down Expand Up @@ -180,8 190,8 @@ def define_model(args):

def setup(args):
# 001 classical image sr/ 002 lightweight image sr
if args.task in ['classical_sr', 'lightweight_sr']:
save_dir = f'results/swin2sr_{args.task}_x{args.scale}'
if args.task in ['classical_sr', 'lightweight_sr', 'compressed_sr']:
save_dir = f'results/swinir_{args.task}_x{args.scale}'
folder = args.folder_gt
border = args.scale
window_size = 8
Expand Down Expand Up @@ -213,6 223,11 @@ def get_image_pair(args, path):
img_gt = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
img_lq = cv2.imread(f'{args.folder_lq}/{imgname}x{args.scale}{imgext}', cv2.IMREAD_COLOR).astype(
np.float32) / 255.

elif args.task in ['compressed_sr']:
img_gt = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
img_lq = cv2.imread(f'{args.folder_lq}/{imgname}.jpg', cv2.IMREAD_COLOR).astype(
np.float32) / 255.

# 003 real-world image sr (load lq image only)
elif args.task in ['real_sr']:
Expand Down

0 comments on commit 8f9aa49

Please sign in to comment.