-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
make_cifar10_gcn_whitened.py
70 lines (51 loc) · 2.29 KB
/
make_cifar10_gcn_whitened.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
This script makes a dataset of 32x32 contrast normalized, approximately
whitened CIFAR-10 images.
"""
from __future__ import print_function
from pylearn2.utils import serial
from pylearn2.datasets import preprocessing
from pylearn2.utils import string_utils
from pylearn2.datasets.cifar10 import CIFAR10
import textwrap
def main():
    """
    Build and save the contrast normalized, approximately whitened CIFAR-10
    train and test sets.

    All output goes to ``${PYLEARN2_DATA_PATH}/cifar10/pylearn2_gcn_whitened``
    (the environment variable is expanded by ``string_utils.preprocess``):

    - ``README``: a short description of the generated files.
    - ``train.pkl`` / ``test.pkl``: the preprocessed dataset objects, saved
      after ``use_design_loc`` has been pointed at ``train.npy`` /
      ``test.npy`` respectively.
    - ``preprocessor.pkl``: the ZCA object that was applied to both sets.

    The ZCA preprocessor is applied to the train set with ``can_fit=True``
    and to the test set with ``can_fit=False``, so the test set is
    transformed by the same preprocessor object without refitting it.
    """
    data_dir = string_utils.preprocess('${PYLEARN2_DATA_PATH}/cifar10')

    print('Loading CIFAR-10 train dataset...')
    # gcn=55. is forwarded to the CIFAR10 constructor; presumably the global
    # contrast normalization scale -- confirm against the CIFAR10 class.
    train = CIFAR10(which_set='train', gcn=55.)

    print("Preparing output directory...")
    output_dir = data_dir + '/pylearn2_gcn_whitened'
    serial.mkdir(output_dir)

    # dedent strips the indentation this literal carries only because it sits
    # inside the function body; the README itself is written flush left.
    readme_text = textwrap.dedent("""
    The .pkl files in this directory may be opened in python using cPickle,
    pickle, or pylearn2.serial.load.
    train.pkl, and test.pkl each contain a pylearn2 Dataset object defining a
    labeled dataset of a 32x32 contrast normalized, approximately whitened
    version of the CIFAR-10 dataset. train.pkl contains labeled train examples.
    test.pkl contains labeled test examples.
    preprocessor.pkl contains a pylearn2 ZCA object that was used to
    approximately whiten the images. You may want to use this object later to
    preprocess other images.
    They were created with the pylearn2 script make_cifar10_gcn_whitened.py.
    All other files in this directory, including this README, were created
    by the same script and are necessary for the other files to function
    correctly.
    """)
    # Context manager guarantees the handle is closed even if write() raises.
    with open(output_dir + '/README', 'w') as readme:
        readme.write(readme_text)

    print("Learning the preprocessor and "
          "preprocessing the unsupervised train data...")
    preprocessor = preprocessing.ZCA()
    train.apply_preprocessor(preprocessor=preprocessor, can_fit=True)

    print('Saving the unsupervised data')
    # use_design_loc is called before pickling; presumably so the design
    # matrix lives in the .npy file rather than inside the pickle -- verify.
    train.use_design_loc(output_dir + '/train.npy')
    serial.save(output_dir + '/train.pkl', train)

    print("Loading the test data")
    test = CIFAR10(which_set='test', gcn=55.)

    print("Preprocessing the test data")
    # can_fit=False: reuse the preprocessor already applied to the train set.
    test.apply_preprocessor(preprocessor=preprocessor, can_fit=False)

    print("Saving the test data")
    test.use_design_loc(output_dir + '/test.npy')
    serial.save(output_dir + '/test.pkl', test)

    serial.save(output_dir + '/preprocessor.pkl', preprocessor)
# Run the dataset build only when executed as a script, not on import.
if __name__ == "__main__":
    main()