diff options
Diffstat (limited to 'tensorflow/contrib/eager/python/examples/revnet/revnet_test.py')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/revnet/revnet_test.py | 27 |
1 files changed, 14 insertions, 13 deletions
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py index b2ac4b67c9..26b0847523 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -31,10 +31,11 @@ tfe = tf.contrib.eager def train_one_iter(model, inputs, labels, optimizer, global_step=None): """Train for one iteration.""" - grads, vars_, loss = model.compute_gradients(inputs, labels, training=True) + grads, vars_, logits, loss = model.compute_gradients( + inputs, labels, training=True) optimizer.apply_gradients(zip(grads, vars_), global_step=global_step) - return loss + return logits, loss class RevNetTest(tf.test.TestCase): @@ -42,6 +43,8 @@ class RevNetTest(tf.test.TestCase): def setUp(self): super(RevNetTest, self).setUp() config = config_.get_hparams_cifar_38() + config.add_hparam("n_classes", 10) + config.add_hparam("dataset", "cifar-10") # Reconstruction could cause numerical error, use double precision for tests config.dtype = tf.float64 config.fused = False # Fused batch norm does not support tf.float64 @@ -94,7 +97,7 @@ class RevNetTest(tf.test.TestCase): def test_compute_gradients(self): """Test `compute_gradients` function.""" self.model(self.x, training=False) # Initialize model - grads, vars_, loss = self.model.compute_gradients( + grads, vars_, logits, loss = self.model.compute_gradients( inputs=self.x, labels=self.t, training=True, l2_reg=True) self.assertTrue(isinstance(grads, list)) self.assertTrue(isinstance(vars_, list)) @@ -119,7 +122,7 @@ class RevNetTest(tf.test.TestCase): def test_compute_gradients_defun(self): """Test `compute_gradients` function with defun.""" compute_gradients = tfe.defun(self.model.compute_gradients) - grads, vars_, _ = compute_gradients(self.x, self.t, training=True) + grads, vars_, _, _ = compute_gradients(self.x, self.t, training=True) self.assertTrue(isinstance(grads, list)) self.assertTrue(isinstance(vars_, list)) self.assertEqual(len(grads), len(vars_)) @@ -131,6 +134,9 @@ class RevNetTest(tf.test.TestCase): """Test model training in graph mode.""" with tf.Graph().as_default(): config = config_.get_hparams_cifar_38() + config.add_hparam("n_classes", 10) + config.add_hparam("dataset", "cifar-10") + x = tf.random_normal( shape=(self.config.batch_size,) + self.config.input_shape) t = tf.random_uniform( @@ -138,17 +144,12 @@ class RevNetTest(tf.test.TestCase): minval=0, maxval=self.config.n_classes, dtype=tf.int32) - global_step = tfe.Variable(0., trainable=False) + global_step = tf.Variable(0., trainable=False) model = revnet.RevNet(config=config) - model(x) - updates = model.get_updates_for(x) - - x_ = tf.identity(x) - grads_all, vars_all, _ = model.compute_gradients(x_, t, training=True) + grads_all, vars_all, _, _ = model.compute_gradients(x, t, training=True) optimizer = tf.train.AdamOptimizer(learning_rate=1e-3) - with tf.control_dependencies(updates): - train_op = optimizer.apply_gradients( - zip(grads_all, vars_all), global_step=global_step) + train_op = optimizer.apply_gradients( + zip(grads_all, vars_all), global_step=global_step) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) |