-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path7_Auto_Encoder_Beta.py
More file actions
109 lines (84 loc) · 3.09 KB
/
7_Auto_Encoder_Beta.py
File metadata and controls
109 lines (84 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
from layers import layers
import config
# Hyperparameters and filesystem paths, centralized in the project's config module.
model_name = config.MODEL_NAME            # base filename for saved checkpoints
train_dir = config.TRAIN_DIR              # TensorBoard event dir (training summaries)
test_dir = config.TEST_DIR                # TensorBoard event dir (test summaries)
checkpoints_dir = config.CHECKPOINTS_DIR  # where tf.train.Saver writes checkpoints
# NOTE(review): the configured value is squared here — confirm that N_TRAINS is
# meant to be the square root of the total number of training steps.
n_trains = config.N_TRAINS ** 2
batch_size = config.BATCH_SIZE
width = config.WIDTH                      # image width in pixels (MNIST: 28)
height = config.HEIGHT                    # image height in pixels (MNIST: 28)
channels = config.CHANNELS                # color channels (MNIST: 1)
flat = config.FLAT                        # flattened image size = width*height*channels
n_classes = config.N_CLASSES              # unused in this script — kept for config parity
k = 100                                   # size of the autoencoder bottleneck layer
num_imgs = 3                              # images shown per tf.summary.image entry
# Downloads/loads MNIST into ./data; labels are one-hot but never used below.
mnist = input_data.read_data_sets('data', one_hot=True)
def get_dict(train=True, batch=True):
    """Build a feed dict for the `x` placeholder.

    Args:
        train: use the MNIST training split when True, the test split otherwise.
        batch: feed a fresh mini-batch of `batch_size` images when True,
            otherwise feed the entire split at once.

    Returns:
        A feed dict mapping the `x` placeholder to image data (labels are
        discarded — the autoencoder is unsupervised).
    """
    split = mnist.train if train else mnist.test
    if not batch:
        return {x: split.images}
    images, _ = split.next_batch(batch_size)
    return {x: images}
# ---- Graph definition (TF1 static graph; name scopes group nodes in TensorBoard) ----
with tf.name_scope('InputLayer'):
    # Flattened input images; first dim is the (variable) batch size.
    x = tf.placeholder(tf.float32, shape=[None, flat], name='x')
with tf.name_scope('NetworkModel'):
    with tf.name_scope('Encoder'):
        # Encoder: flat -> k bottleneck. `layers.ae_layer` is project-local;
        # presumably a dense layer with a nonlinearity — confirm in layers.py.
        y1 = layers.ae_layer(x, flat, k)
    with tf.name_scope('Decoder'):
        # Decoder: k -> flat reconstruction of the input.
        y = layers.ae_layer(y1, k, flat)
with tf.name_scope('Train'):
    # Mean squared reconstruction error.
    loss = tf.reduce_mean(tf.pow(y - x, 2), name='loss')
    train = tf.train.AdamOptimizer().minimize(loss)
with tf.name_scope('Accuracy'):
    # NOTE(review): this is just 1 - MSE, a rough reconstruction-quality proxy,
    # not classification accuracy.
    accuracy = 1 - loss
# Add image summaries
x_img = tf.reshape(x, [-1, height, width, channels])  # input
y_img = tf.reshape(y, [-1, height, width, channels])  # Reconstruct
tf.summary.image('InputImage', x_img, max_outputs=num_imgs)
tf.summary.image('OutputImage', y_img, max_outputs=num_imgs)
# Add scalar summaries
tf.summary.scalar('Loss', loss)
tf.summary.scalar('Accuracy', accuracy)
init_op = tf.global_variables_initializer()
summary_op = tf.summary.merge_all()  # single op that evaluates every summary above
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
def init():
    '''
    Train the autoencoder from scratch, writing checkpoints and
    TensorBoard summaries every 100 steps.

    WARNING! This will override the trained model checkpoints.
    '''
    with tf.Session() as sess:
        # Open protocol for writing files (train writer also records the graph).
        train_writer = tf.summary.FileWriter(train_dir)
        train_writer.add_graph(sess.graph)
        test_writer = tf.summary.FileWriter(test_dir)
        try:
            # makedirs(exist_ok=True) replaces the racy exists()+mkdir pair and
            # also creates missing parent directories.
            os.makedirs(checkpoints_dir, exist_ok=True)
            sess.run(init_op)
            for n_train in range(1, n_trains + 1):
                print("Training {}...".format(n_train))
                # One optimization step on a fresh mini-batch.
                sess.run([train], feed_dict=get_dict(train=True, batch=True))
                if n_train % 100 == 0:
                    # Periodic checkpoint plus full-dataset summaries.
                    saver.save(sess, os.path.join(checkpoints_dir, model_name),
                               global_step=n_train)
                    # Train
                    s = sess.run(summary_op, feed_dict=get_dict(train=True, batch=False))
                    train_writer.add_summary(s, n_train)
                    # Test
                    s = sess.run(summary_op, feed_dict=get_dict(train=False, batch=False))
                    test_writer.add_summary(s, n_train)
        finally:
            # Original code leaked both writers; closing flushes pending events
            # to disk even when training aborts mid-run.
            train_writer.close()
            test_writer.close()
def load(sess):
    '''
    Restore the most recent checkpoint into `sess`, if one exists.

    Fixes the original behavior where an existing-but-empty checkpoint
    directory made tf.train.latest_checkpoint return None, causing
    saver.restore(sess, None) to raise. With no checkpoint available the
    session is left untouched.
    '''
    if os.path.exists(checkpoints_dir):
        ckpt = tf.train.latest_checkpoint(checkpoints_dir)
        if ckpt is not None:
            saver.restore(sess, ckpt)
init()