author     Katherine Wu <kathywu@google.com>                2018-08-07 10:39:05 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-07 10:42:49 -0700
commit     ff3de98d531c23a574a3531d13c47ec3c27543e1 (patch)
tree       5fd69464a6000e9fbc7a51dbfd4766704a99a2a6
parent     b8886649c75ae864f2532bca044e2f44fb138c95 (diff)
Fix bug causing TFOptimizer iterations variable to increment twice when gradients are applied.
PiperOrigin-RevId: 207740606
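
For context, a minimal sketch of how the bug shows up. This is an illustration only, assuming a TF 1.x setup in which a tf.train optimizer is wrapped in keras.optimizers.TFOptimizer, mirroring the tests changed below: after fitting, the optimizer's iterations counter should equal the number of batches processed, not twice that number.

# Illustration only (assumes TensorFlow 1.x; a tf.train optimizer is wrapped
# in keras.optimizers.TFOptimizer, as in the tests in this change).
import numpy as np
import tensorflow as tf
from tensorflow.python import keras

optimizer = keras.optimizers.TFOptimizer(tf.train.GradientDescentOptimizer(0.1))
model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(4,))])
model.compile(loss='mean_squared_error', optimizer=optimizer)

# 30 samples with batch_size=10 -> 3 calls to apply_gradients().
model.fit(np.random.random((30, 4)), np.random.random((30, 1)),
          epochs=1, batch_size=10, verbose=0)

# With this fix the counter reads 3 (one increment per batch); before the fix
# it read 6, because a separate assign_add also bumped the same variable.
print(keras.backend.get_value(model.optimizer.iterations))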
-rw-r--r--  tensorflow/python/keras/optimizers.py        5
-rw-r--r--  tensorflow/python/keras/optimizers_test.py  30
2 files changed, 34 insertions(+), 1 deletion(-)
diff --git a/tensorflow/python/keras/optimizers.py b/tensorflow/python/keras/optimizers.py
index 0b440185ca..4f97442e82 100644
--- a/tensorflow/python/keras/optimizers.py
+++ b/tensorflow/python/keras/optimizers.py
@@ -718,10 +718,13 @@ class TFOptimizer(Optimizer, checkpointable.CheckpointableBase):
       global_step = training_util.get_global_step()
       opt_update = self.optimizer.apply_gradients(grads, global_step)
     else:
-      self.updates = [state_ops.assign_add(self.iterations, 1)]
       if not params:
+        self.updates = [state_ops.assign_add(self.iterations, 1)]
         return self.updates
 
+      # Updates list starts out empty because the iterations variable is
+      # incremented in optimizer.apply_gradients()
+      self.updates = []
       grads = self.optimizer.compute_gradients(loss, params)
       opt_update = self.optimizer.apply_gradients(
           grads, global_step=self.iterations)
diff --git a/tensorflow/python/keras/optimizers_test.py b/tensorflow/python/keras/optimizers_test.py
index 55fc3fdcf4..4d295351f5 100644
--- a/tensorflow/python/keras/optimizers_test.py
+++ b/tensorflow/python/keras/optimizers_test.py
@@ -46,7 +46,11 @@ def _test_optimizer(optimizer, target=0.75):
   model.compile(loss='categorical_crossentropy',
                 optimizer=optimizer,
                 metrics=['accuracy'])
+  np.testing.assert_equal(keras.backend.get_value(model.optimizer.iterations),
+                          0)
   history = model.fit(x_train, y_train, epochs=2, batch_size=16, verbose=0)
+  np.testing.assert_equal(keras.backend.get_value(model.optimizer.iterations),
+                          126)  # 63 steps per epoch
   assert history.history['acc'][-1] >= target
   config = keras.optimizers.serialize(optimizer)
   optim = keras.optimizers.deserialize(config)
@@ -66,7 +70,11 @@ def _test_optimizer(optimizer, target=0.75):
   model.compile(loss='categorical_crossentropy',
                 optimizer=optimizer,
                 metrics=['accuracy'])
+  np.testing.assert_equal(keras.backend.get_value(model.optimizer.iterations),
+                          126)  # Using same optimizer from before
   model.train_on_batch(x_train[:10], y_train[:10])
+  np.testing.assert_equal(keras.backend.get_value(model.optimizer.iterations),
+                          127)
   kernel, bias = dense.get_weights()
   np.testing.assert_allclose(kernel, 1., atol=1e-3)
   np.testing.assert_allclose(bias, 2., atol=1e-3)
@@ -145,6 +153,28 @@ class KerasOptimizersTest(test.TestCase):
     with self.assertRaises(NotImplementedError):
       optimizer.from_config(None)
 
+  def test_tfoptimizer_iterations(self):
+    with self.test_session():
+      optimizer = keras.optimizers.TFOptimizer(AdamOptimizer(0.01))
+      model = keras.models.Sequential()
+      model.add(keras.layers.Dense(
+          2, input_shape=(3,), kernel_constraint=keras.constraints.MaxNorm(1)))
+      model.compile(loss='mean_squared_error', optimizer=optimizer)
+      self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 0)
+
+      model.fit(np.random.random((55, 3)),
+                np.random.random((55, 2)),
+                epochs=1,
+                batch_size=5,
+                verbose=0)
+      self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 11)
+
+      model.fit(np.random.random((20, 3)),
+                np.random.random((20, 2)),
+                steps_per_epoch=8,
+                verbose=0)
+      self.assertEqual(keras.backend.get_value(model.optimizer.iterations), 19)
+
   def test_negative_clipvalue_or_clipnorm(self):
     with self.assertRaises(ValueError):
       _ = keras.optimizers.SGD(lr=0.01, clipvalue=-0.5)