-
Notifications
You must be signed in to change notification settings - Fork 239
/
models.py
91 lines (73 loc) · 3.51 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import numpy as np
import tensorflow as tf
# Alias for the TF-Slim layer helpers (tf.contrib.slim); every conv and
# fully connected layer in this module is built through it.
slim = tf.contrib.slim
def GeneratorCNN(z, hidden_num, output_num, repeat_num, data_format, reuse):
    """Build the generator network under the variable scope "G".

    A fully connected layer projects ``z`` to an 8x8 feature map with
    ``hidden_num`` channels. Then ``repeat_num`` stages of two 3x3 ELU
    convolutions follow, with 2x nearest-neighbour upscaling between
    consecutive stages. A final linear 3x3 convolution produces
    ``output_num`` channels.

    Args:
        z: Input tensor fed to the first fully connected layer.
        hidden_num: Number of filters in every hidden convolution (and the
            channel count of the initial 8x8 map).
        output_num: Number of channels of the generated output.
        repeat_num: Number of convolution stages; upscaling is applied after
            every stage except the last.
        data_format: 'NCHW' or 'NHWC'; forwarded to every convolution and to
            the reshape/upscale helpers.
        reuse: Forwarded to ``tf.variable_scope("G", reuse=...)`` so the
            generator's variables can be shared between calls.

    Returns:
        A tuple ``(out, variables)``: the output tensor (no activation) and
        the list of variables in scope "G".
    """
    with tf.variable_scope("G", reuse=reuse) as vs:
        # Project the input vector to an 8x8xhidden_num seed feature map.
        num_output = int(np.prod([8, 8, hidden_num]))
        x = slim.fully_connected(z, num_output, activation_fn=None)
        x = reshape(x, 8, 8, hidden_num, data_format)

        for idx in range(repeat_num):
            x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format)
            x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format)
            # No upscale after the last stage: that stage fixes the final resolution.
            if idx < repeat_num - 1:
                x = upscale(x, 2, data_format)

        # Emit the caller-requested channel count instead of a hard-coded 3,
        # so e.g. single-channel data yields single-channel output.
        out = slim.conv2d(x, output_num, 3, 1, activation_fn=None, data_format=data_format)

    variables = tf.contrib.framework.get_variables(vs)
    return out, variables
def DiscriminatorCNN(x, input_channel, z_num, repeat_num, hidden_num, data_format):
    """Build the auto-encoder discriminator under the variable scope "D".

    Encoder: a 3x3 ELU conv, then ``repeat_num`` stages of two 3x3 ELU convs
    whose width grows linearly (``hidden_num * (idx + 1)``). A stride-2 conv
    downsamples between consecutive stages. The result is flattened and
    projected linearly to a ``z_num``-dimensional code.

    Decoder: the code is projected to an 8x8x``hidden_num`` map, then passed
    through ``repeat_num`` stages of two 3x3 ELU convs with 2x
    nearest-neighbour upscaling between stages. It finishes with a linear
    3x3 conv producing ``input_channel`` channels.

    Args:
        x: Input batch laid out according to ``data_format``.
        input_channel: Number of channels of the reconstruction (presumably
            the channel count of ``x`` -- confirm against callers).
        z_num: Size of the bottleneck code.
        repeat_num: Number of encoder stages and of decoder stages.
        hidden_num: Base number of convolution filters.
        data_format: 'NCHW' or 'NHWC'.

    Returns:
        A tuple ``(out, z, variables)``: the reconstruction (no activation),
        the bottleneck code, and the list of variables in scope "D".
    """
    with tf.variable_scope("D") as vs:
        # Encoder
        x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format)
        # Width of the most recent encoder stage. It is initialised here so the
        # flatten below stays well-defined even when repeat_num == 0.
        channel_num = hidden_num
        for idx in range(repeat_num):
            channel_num = hidden_num * (idx + 1)
            x = slim.conv2d(x, channel_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format)
            x = slim.conv2d(x, channel_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format)
            if idx < repeat_num - 1:
                # Stride-2 convolution downsamples between stages.
                x = slim.conv2d(x, channel_num, 3, 2, activation_fn=tf.nn.elu, data_format=data_format)

        # NOTE(review): the flatten hard-codes an 8x8 spatial size, so inputs are
        # presumably 8 * 2**(repeat_num - 1) pixels per side -- confirm with callers.
        x = tf.reshape(x, [-1, int(np.prod([8, 8, channel_num]))])
        z = x = slim.fully_connected(x, z_num, activation_fn=None)

        # Decoder
        num_output = int(np.prod([8, 8, hidden_num]))
        x = slim.fully_connected(x, num_output, activation_fn=None)
        x = reshape(x, 8, 8, hidden_num, data_format)

        for idx in range(repeat_num):
            x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format)
            x = slim.conv2d(x, hidden_num, 3, 1, activation_fn=tf.nn.elu, data_format=data_format)
            if idx < repeat_num - 1:
                x = upscale(x, 2, data_format)

        out = slim.conv2d(x, input_channel, 3, 1, activation_fn=None, data_format=data_format)

    variables = tf.contrib.framework.get_variables(vs)
    return out, z, variables
def int_shape(tensor):
    """Return the static shape of ``tensor`` as a list of ints.

    Dimensions that are unknown at graph-construction time (reported as
    ``None`` by ``get_shape().as_list()``) are replaced with -1.
    """
    static_dims = tensor.get_shape().as_list()
    resolved = []
    for dim in static_dims:
        resolved.append(-1 if dim is None else dim)
    return resolved
def get_conv_shape(tensor, data_format):
    """Return the static shape of a conv feature map in [N, H, W, C] order.

    Unknown dimensions are reported as -1 (see ``int_shape``). For 'NHWC'
    the shape is returned unchanged, so the tensor is assumed to be 4-D.

    Args:
        tensor: Tensor laid out according to ``data_format``.
        data_format: 'NCHW' or 'NHWC'.

    Returns:
        A list ``[N, H, W, C]``.

    Raises:
        ValueError: if ``data_format`` is neither 'NCHW' nor 'NHWC'.
    """
    # Validate up front. Returning None here would only fail later, at the
    # caller's unpacking, with an unhelpful TypeError.
    if data_format not in ('NCHW', 'NHWC'):
        raise ValueError(
            "Unsupported data_format: %r (expected 'NCHW' or 'NHWC')" % (data_format,))
    shape = int_shape(tensor)
    if data_format == 'NCHW':
        # Reorder [N, C, H, W] -> [N, H, W, C].
        return [shape[0], shape[2], shape[3], shape[1]]
    return shape
def nchw_to_nhwc(x):
    """Transpose a 4-D tensor from [N, C, H, W] to [N, H, W, C] layout."""
    # Output axis i takes input axis perm[i]: N<-0, H<-2, W<-3, C<-1.
    to_channels_last = [0, 2, 3, 1]
    return tf.transpose(x, perm=to_channels_last)
def nhwc_to_nchw(x):
    """Transpose a 4-D tensor from [N, H, W, C] to [N, C, H, W] layout."""
    # Output axis i takes input axis perm[i]: N<-0, C<-3, H<-1, W<-2.
    to_channels_first = [0, 3, 1, 2]
    return tf.transpose(x, perm=to_channels_first)
def reshape(x, h, w, c, data_format):
    """Reshape ``x`` to a batch of h x w x c feature maps in ``data_format`` order.

    'NCHW' yields shape [-1, c, h, w]; any other value yields [-1, h, w, c].
    """
    target_shape = [-1, c, h, w] if data_format == 'NCHW' else [-1, h, w, c]
    return tf.reshape(x, target_shape)
def resize_nearest_neighbor(x, new_size, data_format):
    """Nearest-neighbour resize of ``x`` to ``new_size`` (height, width).

    ``tf.image.resize_nearest_neighbor`` is called on a channels-last tensor.
    'NCHW' inputs are therefore transposed to NHWC, resized, and transposed
    back. Any other ``data_format`` is resized directly.
    """
    if data_format != 'NCHW':
        return tf.image.resize_nearest_neighbor(x, new_size)
    channels_last = nchw_to_nhwc(x)
    resized = tf.image.resize_nearest_neighbor(channels_last, new_size)
    return nhwc_to_nchw(resized)
def upscale(x, scale, data_format):
    """Enlarge the spatial size of ``x`` by an integer ``scale`` (nearest neighbour)."""
    _batch, height, width, _channels = get_conv_shape(x, data_format)
    target_size = (height * scale, width * scale)
    return resize_nearest_neighbor(x, target_size, data_format)