Diffstat (limited to 'tensorflow/contrib/eager/python/examples/revnet/revnet_test.py')
-rw-r--r--  tensorflow/contrib/eager/python/examples/revnet/revnet_test.py  27
1 file changed, 14 insertions(+), 13 deletions(-)
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index b2ac4b67c9..26b0847523 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -31,10 +31,11 @@ tfe = tf.contrib.eager
 
 def train_one_iter(model, inputs, labels, optimizer, global_step=None):
   """Train for one iteration."""
-  grads, vars_, loss = model.compute_gradients(inputs, labels, training=True)
+  grads, vars_, logits, loss = model.compute_gradients(
+      inputs, labels, training=True)
   optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
-  return loss
+  return logits, loss
 
 
 class RevNetTest(tf.test.TestCase):
@@ -42,6 +43,8 @@ class RevNetTest(tf.test.TestCase):
   def setUp(self):
     super(RevNetTest, self).setUp()
     config = config_.get_hparams_cifar_38()
+    config.add_hparam("n_classes", 10)
+    config.add_hparam("dataset", "cifar-10")
     # Reconstruction could cause numerical error, use double precision for tests
     config.dtype = tf.float64
     config.fused = False  # Fused batch norm does not support tf.float64
@@ -94,7 +97,7 @@ class RevNetTest(tf.test.TestCase):
   def test_compute_gradients(self):
     """Test `compute_gradients` function."""
     self.model(self.x, training=False)  # Initialize model
-    grads, vars_, loss = self.model.compute_gradients(
+    grads, vars_, logits, loss = self.model.compute_gradients(
         inputs=self.x, labels=self.t, training=True, l2_reg=True)
     self.assertTrue(isinstance(grads, list))
     self.assertTrue(isinstance(vars_, list))
@@ -119,7 +122,7 @@ class RevNetTest(tf.test.TestCase):
   def test_compute_gradients_defun(self):
     """Test `compute_gradients` function with defun."""
     compute_gradients = tfe.defun(self.model.compute_gradients)
-    grads, vars_, _ = compute_gradients(self.x, self.t, training=True)
+    grads, vars_, _, _ = compute_gradients(self.x, self.t, training=True)
     self.assertTrue(isinstance(grads, list))
     self.assertTrue(isinstance(vars_, list))
     self.assertEqual(len(grads), len(vars_))
@@ -131,6 +134,9 @@ class RevNetTest(tf.test.TestCase):
     """Test model training in graph mode."""
     with tf.Graph().as_default():
       config = config_.get_hparams_cifar_38()
+      config.add_hparam("n_classes", 10)
+      config.add_hparam("dataset", "cifar-10")
+
       x = tf.random_normal(
           shape=(self.config.batch_size,) + self.config.input_shape)
       t = tf.random_uniform(
@@ -138,17 +144,12 @@ class RevNetTest(tf.test.TestCase):
           minval=0,
           maxval=self.config.n_classes,
           dtype=tf.int32)
-      global_step = tfe.Variable(0., trainable=False)
+      global_step = tf.Variable(0., trainable=False)
       model = revnet.RevNet(config=config)
-      model(x)
-      updates = model.get_updates_for(x)
-
-      x_ = tf.identity(x)
-      grads_all, vars_all, _ = model.compute_gradients(x_, t, training=True)
+      grads_all, vars_all, _, _ = model.compute_gradients(x, t, training=True)
       optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
-      with tf.control_dependencies(updates):
-        train_op = optimizer.apply_gradients(
-            zip(grads_all, vars_all), global_step=global_step)
+      train_op = optimizer.apply_gradients(
+          zip(grads_all, vars_all), global_step=global_step)
       with tf.Session() as sess:
         sess.run(tf.global_variables_initializer())
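
Note (not part of the diff): after this change, compute_gradients returns the four-tuple (grads, vars_, logits, loss) instead of (grads, vars_, loss), so every caller unpacks one extra value. As a rough sketch of how a caller could use the extra logits output, for example to report accuracy alongside the loss, the hypothetical helper below assumes a model, integer labels, and an optimizer set up as in RevNetTest.setUp; it is an illustration, not code from this file.

import tensorflow as tf


def train_one_iter_with_accuracy(model, inputs, labels, optimizer,
                                 global_step=None):
  """Hypothetical helper: one training step that also reports accuracy."""
  # compute_gradients now returns (grads, vars_, logits, loss).
  grads, vars_, logits, loss = model.compute_gradients(
      inputs, labels, training=True)
  optimizer.apply_gradients(zip(grads, vars_), global_step=global_step)
  # The returned logits allow deriving metrics without a second forward pass;
  # labels are assumed to be integer class ids, as in the test above.
  preds = tf.cast(tf.argmax(logits, axis=1), labels.dtype)
  accuracy = tf.reduce_mean(tf.cast(tf.equal(preds, labels), tf.float32))
  return loss, accuracy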