-
Notifications
You must be signed in to change notification settings - Fork 47
/
train.py
46 lines (37 loc) · 1.02 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import tensorflow as tf
import pprint
import random
import numpy as np
from models.adversarial_learner import AdversarialLearner
import os
import gflags
import sys
from common_flags import FLAGS
#####################################
# THIS FILE SHOULD REMAIN UNCHANGED #
#####################################
def _main(seed=8964):
    """Seed the RNGs, print the parsed flags and run adversarial training.

    Args:
        seed: value passed to ``tf.set_random_seed``, ``np.random.seed`` and
            ``random.seed``. Defaults to 8964, the value previously
            hard-coded here, so existing callers see identical behavior.
    """
    # Seed every RNG source the training stack may draw from.
    tf.set_random_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Print the effective configuration so each run's settings end up in its
    # log. NOTE(review): ``__flags`` is gflags' internal flag registry; it is
    # not name-mangled here because this is a module-level function, not a
    # class body.
    print_flags_dict = {key: getattr(FLAGS, key) for key in FLAGS.__flags}
    pprint.PrettyPrinter().pprint(print_flags_dict)

    # Create the checkpoint directory. Trying makedirs first avoids the
    # exists()/makedirs() race; an already-existing *directory* is fine, but
    # a non-directory at that path is re-raised instead of failing later.
    try:
        os.makedirs(FLAGS.checkpoint_dir)
    except OSError:
        if not os.path.isdir(FLAGS.checkpoint_dir):
            raise

    trl = AdversarialLearner()
    trl.train(FLAGS)
def main(argv):
    """Parse command-line flags into ``FLAGS`` and start training.

    Args:
        argv: full argument vector (typically ``sys.argv``); ``argv[0]`` is
            the program name.

    If flag parsing raises ``gflags.FlagsError``, prints the error plus a
    usage summary and exits with status 1.
    """
    try:
        FLAGS(argv)  # parse flags into the global FLAGS object
    except gflags.FlagsError as e:
        # Use a real newline ('\n', not the literal '\\n') so the flag help
        # starts on its own line, and report which flag was rejected.
        print('%s\nUsage: %s ARGS\n%s' % (e, sys.argv[0], FLAGS))
        sys.exit(1)
    _main()
if __name__ == "__main__":
    # Script entry point: hand the raw command line to main() for flag parsing.
    main(argv=sys.argv)