aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training')
-rwxr-xr-xtensorflow/python/training/__init__.py0
-rw-r--r--tensorflow/python/training/adagrad.py58
-rw-r--r--tensorflow/python/training/adagrad_test.py144
-rw-r--r--tensorflow/python/training/adam.py142
-rw-r--r--tensorflow/python/training/adam_test.py174
-rw-r--r--tensorflow/python/training/checkpoint_state.proto18
-rw-r--r--tensorflow/python/training/coordinator.py186
-rw-r--r--tensorflow/python/training/coordinator_test.py98
-rw-r--r--tensorflow/python/training/ftrl.py283
-rw-r--r--tensorflow/python/training/ftrl_test.py234
-rw-r--r--tensorflow/python/training/gradient_descent.py44
-rw-r--r--tensorflow/python/training/gradient_descent_test.py105
-rw-r--r--tensorflow/python/training/input.py501
-rw-r--r--tensorflow/python/training/input_test.py477
-rw-r--r--tensorflow/python/training/learning_rate_decay.py65
-rw-r--r--tensorflow/python/training/learning_rate_decay_test.py60
-rw-r--r--tensorflow/python/training/momentum.py51
-rw-r--r--tensorflow/python/training/momentum_test.py258
-rw-r--r--tensorflow/python/training/moving_averages.py247
-rw-r--r--tensorflow/python/training/moving_averages_test.py130
-rw-r--r--tensorflow/python/training/optimizer.py426
-rw-r--r--tensorflow/python/training/queue_runner.py233
-rw-r--r--tensorflow/python/training/queue_runner_test.py186
-rw-r--r--tensorflow/python/training/rmsprop.py81
-rw-r--r--tensorflow/python/training/rmsprop_test.py158
-rw-r--r--tensorflow/python/training/saver.proto30
-rw-r--r--tensorflow/python/training/saver.py887
-rw-r--r--tensorflow/python/training/saver_test.py563
-rw-r--r--tensorflow/python/training/summary_io.py226
-rw-r--r--tensorflow/python/training/summary_writer_test.py151
-rw-r--r--tensorflow/python/training/training.py138
-rw-r--r--tensorflow/python/training/training_ops.py115
-rw-r--r--tensorflow/python/training/training_ops_test.py159
-rw-r--r--tensorflow/python/training/training_util.py57
34 files changed, 6685 insertions, 0 deletions
diff --git a/tensorflow/python/training/__init__.py b/tensorflow/python/training/__init__.py
new file mode 100755
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/training/__init__.py
diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
new file mode 100644
index 0000000000..41cf2e00f4
--- /dev/null
+++ b/tensorflow/python/training/adagrad.py
@@ -0,0 +1,58 @@
+"""Adagrad for TensorFlow."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class AdagradOptimizer(optimizer.Optimizer):
+ """Optimizer that implements the Adagrad algorithm.
+
+ @@__init__
+ """
+
+ def __init__(self, learning_rate, initial_accumulator_value=0.1,
+ use_locking=False, name="Adagrad"):
+ """Construct a new Adagrad optimizer.
+
+ Args:
+ learning_rate: A `Tensor` or a floating point value. The learning rate.
+ initial_accumulator_value: A floating point value.
+ Starting value for the accumulators, must be positive.
+ use_locking: If `True` use locks for update operations.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "Adagrad".
+
+ Raises:
+ ValueError: If the initial_accumulator_value is invalid.
+ """
+ if initial_accumulator_value <= 0.0:
+ raise ValueError("initial_accumulator_value must be positive: %s" %
+ initial_accumulator_value)
+ super(AdagradOptimizer, self).__init__(use_locking, name)
+ self._learning_rate = learning_rate
+ self._initial_accumulator_value = initial_accumulator_value
+ # Created in Initialize.
+ self._learning_rate_tensor = None
+
+ def _create_slots(self, var_list):
+ for v in var_list:
+ val = constant_op.constant(self._initial_accumulator_value,
+ shape=v.get_shape())
+ self._get_or_make_slot(v, val, "accumulator", self._name)
+
+ def _prepare(self):
+ self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
+ name="learning_rate")
+
+ def _apply_dense(self, grad, var):
+ acc = self.get_slot(var, "accumulator")
+ return training_ops.apply_adagrad(
+ var, acc, self._learning_rate_tensor, grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var):
+ acc = self.get_slot(var, "accumulator")
+ return training_ops.sparse_apply_adagrad(
+ var, acc, self._learning_rate_tensor, grad.values, grad.indices,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py
new file mode 100644
index 0000000000..ee83791eb5
--- /dev/null
+++ b/tensorflow/python/training/adagrad_test.py
@@ -0,0 +1,144 @@
+"""Functional tests for the Adagrad optimizer."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class AdagradOptimizerTest(tf.test.TestCase):
+  """Checks tf.train.AdagradOptimizer against precomputed variable values."""
+
+  def testBasic(self):
+    """Three dense updates move both variables to the expected values."""
+    with self.test_session():
+      var0 = tf.Variable([1.0, 2.0])
+      var1 = tf.Variable([3.0, 4.0])
+      grads0 = tf.constant([0.1, 0.1])
+      grads1 = tf.constant([0.01, 0.01])
+      ada_opt = tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
+      ada_update = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      tf.initialize_all_variables().run()
+      # Fetch params to validate initial values
+      self.assertAllClose([1.0, 2.0], var0.eval())
+      self.assertAllClose([3.0, 4.0], var1.eval())
+      # Run 3 steps of adagrad
+      for _ in range(3):
+        ada_update.run()
+      # Validate updated params
+      self.assertAllClose(np.array([-1.6026098728179932, -0.6026098728179932]),
+                          var0.eval())
+      self.assertAllClose(np.array([2.715679168701172, 3.715679168701172]),
+                          var1.eval())
+
+  def testFloat64(self):
+    """float64 losses, variables and gradients are rejected with ValueError."""
+    with self.test_session():
+      opt = tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
+
+      # compute_gradients.
+      values = [1.0, 3.0]
+      good_vars = [tf.Variable([v]) for v in values]
+      bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
+      self.assertRaisesRegexp(
+          ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
+          opt.compute_gradients, bad_loss, good_vars)
+      bad_vars = [
+          tf.Variable(np.array([v], np.float64), name="bad_var")
+          for v in values]
+      self.assertRaisesRegexp(
+          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+          opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
+          bad_vars)
+      # The all-float32 case must be accepted.
+      opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
+
+      # apply_gradients.
+      bad_grads = [
+          tf.constant([0.1], dtype=np.float64, name="bad_grad"),
+          tf.constant([0.01])]
+      self.assertRaisesRegexp(
+          ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
+          opt.apply_gradients, zip(bad_grads, good_vars))
+      good_grads = [tf.constant([0.01]), tf.constant([0.02])]
+      self.assertRaisesRegexp(
+          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+          opt.apply_gradients, zip(good_grads, bad_vars))
+      # The all-float32 case must be accepted.
+      opt.apply_gradients(zip(good_grads, good_vars))
+
+  def testSparseBasic(self):
+    """Sparse (IndexedSlices) gradients only change the indexed rows."""
+    with self.test_session():
+      var0 = tf.Variable([[1.0], [2.0]])
+      var1 = tf.Variable([[3.0], [4.0]])
+      # grads0 touches row 0 of var0; grads1 touches row 1 of var1.
+      grads0 = tf.IndexedSlices(tf.constant([0.1], shape=[1, 1]),
+                                tf.constant([0]),
+                                tf.constant([2, 1]))
+      grads1 = tf.IndexedSlices(tf.constant([0.01], shape=[1, 1]),
+                                tf.constant([1]),
+                                tf.constant([2, 1]))
+      ada_opt = tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
+      ada_update = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      tf.initialize_all_variables().run()
+      # Fetch params to validate initial values
+      self.assertAllClose([[1.0], [2.0]], var0.eval())
+      self.assertAllClose([[3.0], [4.0]], var1.eval())
+      # Run 3 steps of adagrad
+      for _ in range(3):
+        ada_update.run()
+      # Validate updated params
+      self.assertAllClose([[-1.6026098728179932], [2.0]], var0.eval())
+      self.assertAllClose([[3.0], [3.715679168701172]], var1.eval())
+
+  def testSparseStability(self):
+    """Re-initializing and applying one tiny sparse update is repeatable."""
+    with self.test_session():
+      shape = [1, 6]
+      var0 = tf.Variable([[0.00872496, -0.106952, 0.110467, 0.226505,
+                           -0.0147257, -0.0105945]])
+      grads0 = tf.IndexedSlices(
+          tf.constant(
+              [[-5.91278e-05, 5.31673e-05, -2.5779e-06, 4.29153e-05,
+                -8.4877e-05, -9.48906e-05]],
+              shape=shape),
+          tf.constant([0]),
+          tf.constant(shape))
+      ada_opt = tf.train.AdagradOptimizer(1.0, initial_accumulator_value=0.1)
+      ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
+      self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+      slot0 = ada_opt.get_slot(var0, "accumulator")
+      init = tf.initialize_all_variables()
+      # Each iteration starts from freshly initialized state, so every pass
+      # should produce the same accumulator and variable values.
+      for _ in range(100):
+        init.run()
+        ada_update.run()
+        self.assertAllClose([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], slot0.eval())
+        self.assertAllClose(
+            [[0.00891194, -0.10712013, 0.11047515, 0.22636929,
+              -0.0144573, -0.01029443]], var0.eval())
+
+  def testSharing(self):
+    """Two apply_gradients() calls on one optimizer share the same slots."""
+    with self.test_session():
+      var0 = tf.Variable([1.0, 2.0])
+      var1 = tf.Variable([3.0, 4.0])
+      grads0 = tf.constant([0.1, 0.1])
+      grads1 = tf.constant([0.01, 0.01])
+      ada_opt = tf.train.AdagradOptimizer(3.0)
+      # Apply the optimizer twice. Both applications will use the same accums.
+      ada_update1 = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      ada_update2 = ada_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+      slot0 = ada_opt.get_slot(var0, "accumulator")
+      self.assertEquals(slot0.get_shape(), var0.get_shape())
+      slot1 = ada_opt.get_slot(var1, "accumulator")
+      self.assertEquals(slot1.get_shape(), var1.get_shape())
+      tf.initialize_all_variables().run()
+
+      # Fetch params to validate initial values.
+      self.assertAllClose([1.0, 2.0], var0.eval())
+      self.assertAllClose([3.0, 4.0], var1.eval())
+      # Mix the first and the second adagrad for 3 steps.
+      ada_update1.run()
+      ada_update2.run()
+      ada_update1.run()
+      # Validate updated params (the same as with only 1 Adagrad).
+      self.assertAllClose(np.array([-1.6026098728179932, -0.6026098728179932]),
+                          var0.eval())
+      self.assertAllClose(np.array([2.715679168701172, 3.715679168701172]),
+                          var1.eval())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
new file mode 100644
index 0000000000..266430bb13
--- /dev/null
+++ b/tensorflow/python/training/adam.py
@@ -0,0 +1,142 @@
+"""Adam for TensorFlow."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class AdamOptimizer(optimizer.Optimizer):
+  """Optimizer that implements the Adam algorithm.
+
+  @@__init__
+  """
+
+  def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
+               use_locking=False, name="Adam"):
+    """Construct a new Adam optimizer.
+
+    Implementation is based on: http://arxiv.org/pdf/1412.6980v7.pdf
+
+    Initialization:
+
+    ```
+    m_0 <- 0 (Initialize initial 1st moment vector)
+    v_0 <- 0 (Initialize initial 2nd moment vector)
+    t <- 0 (Initialize timestep)
+    ```
+
+    The update rule for `variable` with gradient `g` uses an optimization
+    described at the end of section 2 of the paper:
+
+    ```
+    t <- t + 1
+    lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
+
+    m_t <- beta1 * m_{t-1} + (1 - beta1) * g
+    v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g
+    variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
+    ```
+
+    The default value of 1e-8 for epsilon might not be a good default in
+    general. For example, when training an Inception network on ImageNet a
+    current good choice is 1.0 or 0.1.
+
+    Args:
+      learning_rate: A Tensor or a floating point value. The learning rate.
+      beta1: A float value or a constant float tensor.
+        The exponential decay rate for the 1st moment estimates.
+      beta2: A float value or a constant float tensor.
+        The exponential decay rate for the 2nd moment estimates.
+      epsilon: A small constant for numerical stability.
+      use_locking: If True use locks for update operations.
+      name: Optional name for the operations created when applying gradients.
+        Defaults to "Adam".
+    """
+    super(AdamOptimizer, self).__init__(use_locking, name)
+    self._lr = learning_rate
+    self._beta1 = beta1
+    self._beta2 = beta2
+    self._epsilon = epsilon
+
+    # Tensor versions of the constructor arguments, created in _prepare().
+    self._lr_t = None
+    self._beta1_t = None
+    self._beta2_t = None
+    self._epsilon_t = None
+
+    # Variables to accumulate the powers of the beta parameters.
+    # Created in _create_slots when we know the variables to optimize.
+    self._beta1_power = None
+    self._beta2_power = None
+
+    # Created in SparseApply if needed.
+    # NOTE(review): nothing in this class assigns this attribute after
+    # construction -- confirm whether it is still needed.
+    self._updated_lr = None
+
+  def _get_beta_accumulators(self):
+    """Return the (beta1_power, beta2_power) accumulator variables."""
+    return self._beta1_power, self._beta2_power
+
+  def _create_slots(self, var_list):
+    """Create the beta power accumulators and the "m"/"v" slots.
+
+    The accumulators are created only once per optimizer and start at beta1
+    and beta2, i.e. beta^1.
+
+    Args:
+      var_list: The variables to optimize. The first one decides the device
+        of the accumulators, so the list is indexed at position 0.
+    """
+    # Create the beta1 and beta2 accumulators on the same device as the first
+    # variable.
+    if self._beta1_power is None:
+      with ops.device(var_list[0].device):
+        self._beta1_power = variables.Variable(self._beta1, name="beta1_power")
+        self._beta2_power = variables.Variable(self._beta2, name="beta2_power")
+    # Create slots for the first and second moments.
+    for v in var_list:
+      self._zeros_slot(v, "m", self._name)
+      self._zeros_slot(v, "v", self._name)
+
+  def _prepare(self):
+    """Convert the constructor hyperparameters to tensors."""
+    self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
+    self._beta1_t = ops.convert_to_tensor(self._beta1, name="beta1")
+    self._beta2_t = ops.convert_to_tensor(self._beta2, name="beta2")
+    self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon")
+
+  def _apply_dense(self, grad, var):
+    """Return the op of `training_ops.apply_adam` for a dense gradient."""
+    m = self.get_slot(var, "m")
+    v = self.get_slot(var, "v")
+    return training_ops.apply_adam(
+        var, m, v, self._beta1_power, self._beta2_power,
+        self._lr_t, self._beta1_t, self._beta2_t,
+        self._epsilon_t, grad, use_locking=self._use_locking).op
+
+  def _apply_sparse(self, grad, var):
+    """Build the Adam update for a sparse gradient out of individual ops.
+
+    Both moment slots are first decayed in full (assign), then the scaled
+    gradient values are scatter-added at `grad.indices`.
+
+    Args:
+      grad: Gradient exposing `values` and `indices`.
+      var: The variable to update.
+
+    Returns:
+      An op grouping the updates of `var`, "m" and "v".
+    """
+    # The accumulators already hold beta^t for the current step, so this is
+    # the lr_t of the update rule in the constructor docstring.
+    lr = (self._lr_t *
+          math_ops.sqrt(1 - self._beta2_power)
+          / (1 - self._beta1_power))
+    # m_t = beta1 * m + (1 - beta1) * g_t
+    m = self.get_slot(var, "m")
+    m_scaled_g_values = grad.values * (1 - self._beta1_t)
+    m_t = state_ops.assign(m, m * self._beta1_t,
+                           use_locking=self._use_locking)
+    m_t = state_ops.scatter_add(m_t, grad.indices, m_scaled_g_values,
+                                use_locking=self._use_locking)
+    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
+    v = self.get_slot(var, "v")
+    v_scaled_g_values = (grad.values * grad.values) * (1 - self._beta2_t)
+    v_t = state_ops.assign(v, v * self._beta2_t, use_locking=self._use_locking)
+    v_t = state_ops.scatter_add(v_t, grad.indices, v_scaled_g_values,
+                                use_locking=self._use_locking)
+    v_sqrt = math_ops.sqrt(v_t)
+    var_update = state_ops.assign_sub(var,
+                                      lr * m_t / (v_sqrt + self._epsilon_t),
+                                      use_locking=self._use_locking)
+    return control_flow_ops.group(*[var_update, m_t, v_t])
+
+  def _finish(self, update_ops, name_scope):
+    """Advance the beta power accumulators after all variable updates.
+
+    Args:
+      update_ops: The per-variable update ops; the accumulator updates are
+        created under a control dependency on them.
+      name_scope: Name given to the returned grouped op.
+
+    Returns:
+      An op grouping `update_ops` with the two accumulator updates.
+    """
+    # Update the power accumulators.
+    with ops.control_dependencies(update_ops):
+      with ops.device(self._beta1_power.device):
+        update_beta1 = self._beta1_power.assign(
+            self._beta1_power * self._beta1_t,
+            use_locking=self._use_locking)
+        update_beta2 = self._beta2_power.assign(
+            self._beta2_power * self._beta2_t,
+            use_locking=self._use_locking)
+    return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
+                                  name=name_scope)
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
new file mode 100644
index 0000000000..f92728d0c7
--- /dev/null
+++ b/tensorflow/python/training/adam_test.py
@@ -0,0 +1,174 @@
+"""Tests for Adam."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+def adam_update_numpy(param, g_t, t, m, v, alpha=0.001, beta1=0.9, beta2=0.999,
+ epsilon=1e-8):
+ alpha_t = alpha * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
+
+ m_t = beta1 * m + (1 - beta1) * g_t
+ v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+ param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon)
+ return param_t, m_t, v_t
+
+
+class AdamOptimizerTest(tf.test.TestCase):
+  """Compares tf.train.AdamOptimizer with the NumPy reference update."""
+
+  def testSparse(self):
+    """Sparse gradients covering every index match the NumPy reference."""
+    with self.test_session():
+      # Initialize variables for numpy implementation.
+      m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+      var0_np = np.array([1.0, 2.0], dtype=np.float32)
+      grads0_np = np.array([0.1, 0.1], dtype=np.float32)
+      var1_np = np.array([3.0, 4.0], dtype=np.float32)
+      grads1_np = np.array([0.01, 0.01], dtype=np.float32)
+
+      var0 = tf.Variable(var0_np)
+      var1 = tf.Variable(var1_np)
+      grads0_np_indices = np.array([0, 1], dtype=np.int32)
+      grads0 = tf.IndexedSlices(tf.constant(grads0_np),
+                                tf.constant(grads0_np_indices),
+                                tf.constant([2]))
+      grads1_np_indices = np.array([0, 1], dtype=np.int32)
+      grads1 = tf.IndexedSlices(tf.constant(grads1_np),
+                                tf.constant(grads1_np_indices),
+                                tf.constant([2]))
+      opt = tf.train.AdamOptimizer()
+      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      tf.initialize_all_variables().run()
+
+      # Fetch params to validate initial values
+      self.assertAllClose([1.0, 2.0], var0.eval())
+      self.assertAllClose([3.0, 4.0], var1.eval())
+
+      beta1_power, beta2_power = opt._get_beta_accumulators()
+
+      # Run 3 steps of Adam
+      for t in range(1, 4):
+        # Before step t the accumulators must hold beta^t.
+        self.assertAllClose(0.9 ** t, beta1_power.eval())
+        self.assertAllClose(0.999 ** t, beta2_power.eval())
+        update.run()
+
+        var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+        var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+        # Validate updated params
+        self.assertAllClose(var0_np, var0.eval())
+        self.assertAllClose(var1_np, var1.eval())
+
+  def testBasic(self):
+    """Dense gradients match the NumPy reference over three steps."""
+    with self.test_session():
+      # Initialize variables for numpy implementation.
+      m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+      var0_np = np.array([1.0, 2.0], dtype=np.float32)
+      grads0_np = np.array([0.1, 0.1], dtype=np.float32)
+      var1_np = np.array([3.0, 4.0], dtype=np.float32)
+      grads1_np = np.array([0.01, 0.01], dtype=np.float32)
+
+      var0 = tf.Variable(var0_np)
+      var1 = tf.Variable(var1_np)
+      grads0 = tf.constant(grads0_np)
+      grads1 = tf.constant(grads1_np)
+      opt = tf.train.AdamOptimizer()
+      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      tf.initialize_all_variables().run()
+
+      # Fetch params to validate initial values
+      self.assertAllClose([1.0, 2.0], var0.eval())
+      self.assertAllClose([3.0, 4.0], var1.eval())
+
+      beta1_power, beta2_power = opt._get_beta_accumulators()
+
+      # Run 3 steps of Adam
+      for t in range(1, 4):
+        # Before step t the accumulators must hold beta^t.
+        self.assertAllClose(0.9 ** t, beta1_power.eval())
+        self.assertAllClose(0.999 ** t, beta2_power.eval())
+        update.run()
+
+        var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+        var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+        # Validate updated params
+        self.assertAllClose(var0_np, var0.eval())
+        self.assertAllClose(var1_np, var1.eval())
+
+  def testFloat64(self):
+    """float64 losses, variables and gradients are rejected with ValueError."""
+    with self.test_session():
+      opt = tf.train.AdamOptimizer()
+
+      # compute_gradients.
+      values = [1.0, 3.0]
+      good_vars = [tf.Variable([v]) for v in values]
+      bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
+      self.assertRaisesRegexp(
+          ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
+          opt.compute_gradients, bad_loss, good_vars)
+      bad_vars = [
+          tf.Variable(np.array([v], np.float64), name="bad_var")
+          for v in values]
+      self.assertRaisesRegexp(
+          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+          opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
+          bad_vars)
+      # The all-float32 case must be accepted.
+      opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
+
+      # apply_gradients.
+      bad_grads = [
+          tf.constant([0.1], dtype=np.float64, name="bad_grad"),
+          tf.constant([0.01])]
+      self.assertRaisesRegexp(
+          ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
+          opt.apply_gradients, zip(bad_grads, good_vars))
+      good_grads = [tf.constant([0.01]), tf.constant([0.02])]
+      self.assertRaisesRegexp(
+          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+          opt.apply_gradients, zip(good_grads, bad_vars))
+      # The all-float32 case must be accepted.
+      opt.apply_gradients(zip(good_grads, good_vars))
+
+  def testSharing(self):
+    """Two apply_gradients() ops from one optimizer share their state."""
+    with self.test_session():
+      # Initialize variables for numpy implementation.
+      m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+      var0_np = np.array([1.0, 2.0], dtype=np.float32)
+      grads0_np = np.array([0.1, 0.1], dtype=np.float32)
+      var1_np = np.array([3.0, 4.0], dtype=np.float32)
+      grads1_np = np.array([0.01, 0.01], dtype=np.float32)
+
+      var0 = tf.Variable(var0_np)
+      var1 = tf.Variable(var1_np)
+      grads0 = tf.constant(grads0_np)
+      grads1 = tf.constant(grads1_np)
+      opt = tf.train.AdamOptimizer()
+      update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      tf.initialize_all_variables().run()
+
+      beta1_power, beta2_power = opt._get_beta_accumulators()
+
+      # Fetch params to validate initial values
+      self.assertAllClose([1.0, 2.0], var0.eval())
+      self.assertAllClose([3.0, 4.0], var1.eval())
+
+      # Run 3 steps of intertwined Adam1 and Adam2.
+      for t in range(1, 4):
+        self.assertAllClose(0.9 ** t, beta1_power.eval())
+        self.assertAllClose(0.999 ** t, beta2_power.eval())
+        # Alternate between the two update ops; the result must equal a
+        # single optimizer stepping three times.
+        if t % 2 == 0:
+          update1.run()
+        else:
+          update2.run()
+
+        var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+        var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+        # Validate updated params
+        self.assertAllClose(var0_np, var0.eval())
+        self.assertAllClose(var1_np, var1.eval())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/checkpoint_state.proto b/tensorflow/python/training/checkpoint_state.proto
new file mode 100644
index 0000000000..1f521341f1
--- /dev/null
+++ b/tensorflow/python/training/checkpoint_state.proto
@@ -0,0 +1,18 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+// Protocol buffer representing the checkpoint state.
+//
+// TODO(mdevin): Add other attributes as needed.
+message CheckpointState {
+  // Path to the most-recent model checkpoint.
+  string model_checkpoint_path = 1;
+
+  // Paths to all not-yet-deleted model checkpoints, sorted from oldest to
+  // newest.
+  // Note that the value of model_checkpoint_path should be the last item in
+  // this list.
+  repeated string all_model_checkpoint_paths = 2;
+}
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py
new file mode 100644
index 0000000000..f090e6d222
--- /dev/null
+++ b/tensorflow/python/training/coordinator.py
@@ -0,0 +1,186 @@
+"""Coordinator to help multiple threads stop when requested."""
+import sys
+import threading
+import time
+
+from tensorflow.python.platform import logging
+
+
+class Coordinator(object):
+ """A coordinator for threads.
+
+ This class implements a simple mechanism to coordinate the termination of a
+ set of threads.
+
+ #### Usage:
+
+ ```python
+ # Create a coordinator.
+ coord = Coordinator()
+ # Start a number of threads, passing the coordinator to each of them.
+ ...start thread 1...(coord, ...)
+ ...start thread N...(coord, ...)
+ # Wait for all the threads to terminate.
+ coord.join(threads)
+ ```
+
+ Any of the threads can call `coord.request_stop()` to ask for all the threads
+ to stop. To cooperate with the requests, each thread must check for
+ `coord.should_stop()` on a regular basis. `coord.should_stop()` returns
+ `True` as soon as `coord.request_stop()` has been called.
+
+ A typical thread running with a Coordinator will do something like:
+
+ ```python
+ while not coord.should_stop():
+ ...do some work...
+ ```
+
+ #### Exception handling:
+
+ A thread can report an exception to the Coordinator as part of the
+ `should_stop()` call. The exception will be re-raised from the
+ `coord.join()` call.
+
+ Thread code:
+
+ ```python
+ try:
+ while not coord.should_stop():
+ ...do some work...
+ except Exception, e:
+ coord.request_stop(e)
+ ```
+
+ Main code:
+
+ ```python
+ try:
+ ...
+ coord = Coordinator()
+ # Start a number of threads, passing the coordinator to each of them.
+ ...start thread 1...(coord, ...)
+ ...start thread N...(coord, ...)
+ # Wait for all the threads to terminate.
+ coord.join(threads)
+ except Exception, e:
+ ...exception that was passed to coord.request_stop()
+ ```
+
+ #### Grace period for stopping:
+
+ After a thread has called `coord.request_stop()` the other threads have a
+ fixed time to stop, this is called the 'stop grace period' and defaults to 2
+ minutes. If any of the threads is still alive after the grace period expires
+ `coord.join()` raises a RuntimeException reporting the laggards.
+
+ ```
+ try:
+ ...
+ coord = Coordinator()
+ # Start a number of threads, passing the coordinator to each of them.
+ ...start thread 1...(coord, ...)
+ ...start thread N...(coord, ...)
+ # Wait for all the threads to terminate, give them 10s grace period
+ coord.join(threads, stop_grace_period_secs=10)
+ except RuntimeException:
+ ...one of the threads took more than 10s to stop after request_stop()
+ ...was called.
+ except Exception:
+ ...exception that was passed to coord.request_stop()
+ ```
+ """
+
+ def __init__(self):
+ """Create a new Coordinator."""
+ # Protects all attributes.
+ self._lock = threading.Lock()
+ # Event set when threads must stop.
+ self._stop_event = threading.Event()
+ # Python exc_info to report.
+ self._exc_info_to_raise = None
+
+ def request_stop(self, ex=None):
+ """Request that the threads stop.
+
+ After this is called, calls to should_stop() will return True.
+
+ Args:
+ ex: Optional Exception, or Python 'exc_info' tuple as returned by
+ sys.exc_info(). If this is the first call to request_stop() the
+ corresponding exception is recorded and re-raised from join().
+ """
+ with self._lock:
+ if not self._stop_event.is_set():
+ if ex and self._exc_info_to_raise is None:
+ if isinstance(ex, tuple):
+ logging.info("Error reported to Coordinator: %s", str(ex[1]))
+ self._exc_info_to_raise = ex
+ else:
+ logging.info("Error reported to Coordinator: %s", str(ex))
+ self._exc_info_to_raise = sys.exc_info()
+ self._stop_event.set()
+
+ def should_stop(self):
+ """Check if stop was requested.
+
+ Returns:
+ True if a stop was requested.
+ """
+ return self._stop_event.is_set()
+
+ def wait_for_stop(self, timeout=None):
+ """Wait till the Coordinator is told to stop.
+
+ Args:
+ timeout: float. Sleep for up to that many seconds waiting for
+ should_stop() to become True.
+
+ Returns:
+ True if the Coordinator is told stop, False if the timeout expired.
+ """
+ return self._stop_event.wait(timeout)
+
+ def join(self, threads, stop_grace_period_secs=120):
+ """Wait for threads to terminate.
+
+ Blocks until all 'threads' have terminated or request_stop() is called.
+
+ After the threads stop, if an 'exc_info' was passed to request_stop, that
+ exception is re-reaised.
+
+ Grace period handling: When request_stop() is called, threads are given
+ 'stop_grace_period_secs' seconds to terminate. If any of them is still
+ alive after that period expires, a RuntimeError is raised. Note that if
+ an 'exc_info' was passed to request_stop() then it is raised instead of
+ that RuntimeError.
+
+ Args:
+ threads: List threading.Threads. The started threads to join.
+ stop_grace_period_secs: Number of seconds given to threads to stop after
+ request_stop() has been called.
+
+ Raises:
+ RuntimeError: If any thread is still alive after request_stop()
+ is called and the grace period expires.
+ """
+ # Wait for all threads to stop or for request_stop() to be called.
+ while any(t.is_alive() for t in threads) and not self.wait_for_stop(1.0):
+ pass
+
+ # If any thread is still alive, wait for the grace period to expire.
+ while any(t.is_alive() for t in threads) and stop_grace_period_secs >= 0.0:
+ stop_grace_period_secs -= 1.0
+ time.sleep(1.0)
+
+ # List the threads still alive after the grace period.
+ stragglers = [t.name for t in threads if t.is_alive()]
+
+ # Terminate with an exception if appropriate.
+ with self._lock:
+ if self._exc_info_to_raise:
+ exc_info = self._exc_info_to_raise
+ raise exc_info[0], exc_info[1], exc_info[2]
+ elif stragglers:
+ raise RuntimeError("Coordinator stopped with threads still running: %s",
+ " ".join(stragglers))
diff --git a/tensorflow/python/training/coordinator_test.py b/tensorflow/python/training/coordinator_test.py
new file mode 100644
index 0000000000..ce9126caf4
--- /dev/null
+++ b/tensorflow/python/training/coordinator_test.py
@@ -0,0 +1,98 @@
+"""Tests for Coordinator."""
+import sys
+import threading
+import time
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+def StopInN(coord, n_secs):
+  """Sleep for `n_secs` seconds, then call `coord.request_stop()`."""
+  time.sleep(n_secs)
+  coord.request_stop()
+
+
+def RaiseInN(coord, n_secs, ex, report_exception):
+  """Sleep, raise `ex`, and report it to the coordinator.
+
+  Only RuntimeError is caught; any other exception type propagates.
+
+  Args:
+    coord: The coordinator to report to.
+    n_secs: Seconds to sleep before raising.
+    ex: The exception to raise.
+    report_exception: If True the caught exception object is passed to
+      request_stop(); otherwise the sys.exc_info() tuple is passed.
+  """
+  try:
+    time.sleep(n_secs)
+    raise ex
+  except RuntimeError, e:
+    if report_exception:
+      coord.request_stop(e)
+    else:
+      coord.request_stop(sys.exc_info())
+
+
+def SleepABit(n_secs):
+  """Sleep for `n_secs` seconds; used as a thread body that just runs."""
+  time.sleep(n_secs)
+
+
+class CoordinatorTest(tf.test.TestCase):
+  """Tests the stop, join and exception reporting of tf.train.Coordinator."""
+
+  def testStopAPI(self):
+    """should_stop() and wait_for_stop() flip once request_stop() is called."""
+    coord = tf.train.Coordinator()
+    self.assertFalse(coord.should_stop())
+    self.assertFalse(coord.wait_for_stop(0.01))
+    coord.request_stop()
+    self.assertTrue(coord.should_stop())
+    self.assertTrue(coord.wait_for_stop(0.01))
+
+  def testStopAsync(self):
+    """A stop requested from another thread is seen by the waiting thread."""
+    coord = tf.train.Coordinator()
+    self.assertFalse(coord.should_stop())
+    self.assertFalse(coord.wait_for_stop(0.1))
+    # The helper thread requests the stop after 0.02s.
+    threading.Thread(target=StopInN, args=(coord, 0.02)).start()
+    self.assertFalse(coord.should_stop())
+    self.assertFalse(coord.wait_for_stop(0.01))
+    self.assertTrue(coord.wait_for_stop(0.03))
+    self.assertTrue(coord.should_stop())
+
+  def testJoin(self):
+    """join() returns without raising once all threads have finished."""
+    coord = tf.train.Coordinator()
+    threads = [
+        threading.Thread(target=SleepABit, args=(0.01,)),
+        threading.Thread(target=SleepABit, args=(0.02,)),
+        threading.Thread(target=SleepABit, args=(0.01,))]
+    for t in threads:
+      t.start()
+    coord.join(threads)
+
+  def testJoinGraceExpires(self):
+    """join() raises RuntimeError when a thread outlives the grace period."""
+    coord = tf.train.Coordinator()
+    threads = [
+        threading.Thread(target=StopInN, args=(coord, 0.01)),
+        threading.Thread(target=SleepABit, args=(10.0,))]
+    for t in threads:
+      # Daemon threads, so the 10s sleeper cannot keep the process alive.
+      t.daemon = True
+      t.start()
+    with self.assertRaisesRegexp(RuntimeError, "threads still running"):
+      coord.join(threads, stop_grace_period_secs=0.02)
+
+  def testJoinRaiseReportExcInfo(self):
+    """The first exc_info tuple reported is the one re-raised by join()."""
+    coord = tf.train.Coordinator()
+    threads = [
+        threading.Thread(target=RaiseInN,
+                         args=(coord, 0.01, RuntimeError("First"), False)),
+        threading.Thread(target=RaiseInN,
+                         args=(coord, 0.02, RuntimeError("Too late"), False))]
+    for t in threads:
+      t.start()
+    with self.assertRaisesRegexp(RuntimeError, "First"):
+      coord.join(threads)
+
+  def testJoinRaiseReportException(self):
+    """The first exception object reported is the one re-raised by join()."""
+    coord = tf.train.Coordinator()
+    threads = [
+        threading.Thread(target=RaiseInN,
+                         args=(coord, 0.01, RuntimeError("First"), True)),
+        threading.Thread(target=RaiseInN,
+                         args=(coord, 0.02, RuntimeError("Too late"), True))]
+    for t in threads:
+      t.start()
+    with self.assertRaisesRegexp(RuntimeError, "First"):
+      coord.join(threads)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/ftrl.py b/tensorflow/python/training/ftrl.py
new file mode 100644
index 0000000000..6b9471a5ed
--- /dev/null
+++ b/tensorflow/python/training/ftrl.py
@@ -0,0 +1,283 @@
+"""FTRL-Proximal for Tensor Flow."""
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import optimizer
+
+
+def _Solve(a, b, c):
+  """Solve the per-coordinate L1-regularized quadratic minimization.
+
+  The problem being solved is:
+    argmin_w{1/2 * a * w^2 + b * w + c * |w|}
+  whose closed-form solution is:
+    w* = -(b - sign(b)*c)/a if |b| > c else w* = 0
+
+  REQUIRES: Dimensionality of a and b must be same
+
+  Args:
+    a: A Tensor, the quadratic coefficient.
+    b: A Tensor, the linear coefficient.
+    c: A Tensor with one element, the coefficient of |w|.
+
+  Returns:
+    A Tensor w with the shape of b holding the minimizer.
+  """
+  scope_name = "solve_" + b.op.name
+  with ops.name_scope(scope_name):
+    abs_coeff = ops.convert_to_tensor(c)
+    # Broadcast the scalar threshold to the shape of b for the comparison.
+    threshold = array_ops.fill(array_ops.shape(b), abs_coeff)
+    zeros = array_ops.zeros(array_ops.shape(b), dtype=b.dtype)
+    unclipped = (abs_coeff * math_ops.sign(b) - b) / a
+    # Entries with |b| below the threshold are set to exactly zero.
+    below_threshold = math_ops.less(math_ops.abs(b), threshold)
+    return math_ops.select(below_threshold, zeros, unclipped)
+
+
+def _Compute(accum, linear, base_lr, lr_power, l1, l2):
+  """Compute "variable" given current "accum" and "linear".
+
+  The per-coordinate learning rate is `pow(accum, lr_power) * base_lr`, and the
+  new value is the minimizer returned by `_Solve` for the quadratic coefficient
+  `1 / learning_rate + 2 * l2`, linear coefficient `linear` and L1 strength
+  `l1`.
+
+  REQUIRES: Dimensionality of accum and linear must be same.
+
+  Args:
+    accum: A Tensor which is accumulated gradient square.
+    linear: A Tensor with same size of accum.
+    base_lr: A Tensor which is base learning rate
+    lr_power: A Tensor which is learning rate power
+    l1: A Tensor which is l1_regularization strength
+    l2: A Tensor which is l2_regularization strength
+  Returns:
+    A Tensor which is "variable" after update
+  """
+  with ops.name_scope("compute_" + accum.op.name):
+    # Build the constants in the accumulator's dtype rather than hard-coding
+    # float32, so they match the hyper-parameter tensors that the callers
+    # convert to the variable's dtype.
+    dtype = accum.dtype.base_dtype
+    one_t = constant_op.constant(1.0, dtype=dtype)
+    two_t = constant_op.constant(2.0, dtype=dtype)
+    learning_rate = math_ops.pow(accum, lr_power) * base_lr
+    quadratic = one_t / learning_rate + two_t * l2
+    w = _Solve(quadratic, linear, l1)
+    return w
+
+
+def _Update(variable, gradients, accum, linear, base_lr, lr_power, l1, l2):
+  """Update "variable", "accum", "linear" based on "gradients".
+
+  Some notations here: "variable" as W, "accum" as N, "linear" as Z,
+                       "gradients" as G, N(t) means "accum" at t-step.
+  Assuming lr_power = -0.5 which means using adagrad learning rate.
+  "accum" updates as: N = N + G^2
+  "linear" updates as: Z = Z + G - W * (sqrt(N(t)) - sqrt(N(t-1)))/base_lr
+  REQUIRES: Dimensionality of variable, gradients, accum and linear
+            must be same.
+
+  Args:
+    variable: A Variable.
+    gradients: A Tensor of same shape as 'variable'.
+    accum: A Variable containing the sum of the squares of gradients.
+    linear: A Variable containing approximation info.
+    base_lr: A constant represents base learning rate.
+    lr_power: A constant is used to adjust learning rate.
+    l1: A constant represents l1 regularization strength.
+    l2: A constant represents l2 regularization strength.
+
+  Returns:
+    A group op including three Assign ops:
+      1. Assign for "accum"
+      2. Assign for "linear"
+      3. Assign for "variable"
+  """
+  # NOTE(review): unlike _SparseUpdate, these ops are not created inside a
+  # dedicated name scope.
+  # Convert the hyper-parameters to tensors of the variable's base dtype.
+  dtype = variable.dtype.base_dtype
+  base_lr = ops.convert_to_tensor(base_lr, dtype=dtype)
+  lr_power = ops.convert_to_tensor(lr_power, dtype=dtype)
+  l1 = ops.convert_to_tensor(l1, dtype=dtype)
+  l2 = ops.convert_to_tensor(l2, dtype=dtype)
+  # Compute the new accumulator
+  sqr_grad = math_ops.square(gradients)
+  accum_updated = sqr_grad + accum
+  # Compute the new linear
+  # sigma = (N(t)^(-lr_power) - N(t-1)^(-lr_power)) / base_lr
+  neg_lr_power = math_ops.neg(lr_power)
+  sigma = math_ops.pow(accum_updated, neg_lr_power) - math_ops.pow(
+      accum, neg_lr_power)
+  sigma /= base_lr
+  proximal_adjust = sigma * variable
+  linear_updated = linear + gradients - proximal_adjust
+  # Compute the "variable"
+  variable_updated = _Compute(accum_updated, linear_updated, base_lr,
+                              lr_power, l1, l2)
+
+  # The assign ops are created under a control dependency on sigma, which is
+  # computed from the pre-update value of accum.
+  with ops.control_dependencies([sigma]):
+    accum_update_op = state_ops.assign(accum, accum_updated)
+    linear_update_op = state_ops.assign(linear, linear_updated)
+    variable_update_op = state_ops.assign(variable, variable_updated)
+  group_op = control_flow_ops.group(linear_update_op, accum_update_op,
+                                    variable_update_op)
+  return group_op
+
+
+# TODO(xbing): Refactor code to make _SparseUpdate and _Update share
+# common routines.
+def _SparseUpdate(variable, gradients, accum, linear, base_lr,
+                  lr_power, l1, l2):
+  """Sparse Update "variable", "accum", "linear" based on sparse "gradients".
+
+  Only the rows selected by `gradients.indices` are gathered, recomputed and
+  scattered back. See the description in _Update for the update equations.
+
+  Args:
+    variable: A Variable.
+    gradients: A Sparse Tensor (must be an `ops.IndexedSlices`).
+    accum: A Variable containing the sum of the squares of gradients.
+    linear: A Variable containing approximation info.
+    base_lr: A constant represents base learning rate.
+    lr_power: A constant is used to adjust learning rate.
+    l1: A constant represents l1 regularization strength.
+    l2: A constant represents l2 regularization strength.
+
+  Returns:
+    A group op including three ScatterUpdate ops:
+      1. ScatterUpdate for "accum"
+      2. ScatterUpdate for "linear"
+      3. ScatterUpdate for "variable"
+  """
+  # NOTE(review): this type check is an assert, so it is skipped when Python
+  # runs with -O.
+  assert isinstance(gradients, ops.IndexedSlices)
+  with ops.name_scope("sparse_update_" + variable.op.name) as scope:
+    # Convert the hyper-parameters to tensors of the variable's base dtype.
+    dtype = variable.dtype.base_dtype
+    base_lr = ops.convert_to_tensor(base_lr, dtype=dtype)
+    lr_power = ops.convert_to_tensor(lr_power, dtype=dtype)
+    l1 = ops.convert_to_tensor(l1, dtype=dtype)
+    l2 = ops.convert_to_tensor(l2, dtype=dtype)
+
+    # Compute the new value for the accumulator
+    previous_accum = array_ops.gather(accum, gradients.indices)
+    sqr_grad = gradients.values * gradients.values
+    accum_updated = sqr_grad + previous_accum
+
+    # Compute the new linear
+    # sigma = (N(t)^(-lr_power) - N(t-1)^(-lr_power)) / base_lr, on the
+    # gathered rows only.
+    neg_lr_power = math_ops.neg(lr_power)
+    sigma = math_ops.pow(accum_updated, neg_lr_power) - math_ops.pow(
+        previous_accum, neg_lr_power)
+    sigma /= base_lr
+    variable_slice = array_ops.gather(variable, gradients.indices)
+    proximal_adjust = sigma * variable_slice
+    linear_slice = array_ops.gather(linear, gradients.indices)
+    linear_updated = linear_slice + gradients.values - proximal_adjust
+
+    # Compute the new "variable"
+    variable_updated = _Compute(accum_updated, linear_updated, base_lr,
+                                lr_power, l1, l2)
+
+    # The scatter ops are created under a control dependency on sigma, which
+    # is computed from the gathered pre-update accumulator rows.
+    # NOTE(review): rows are written with scatter_update; behavior when
+    # gradients.indices contains repeated indices is not handled here --
+    # confirm whether callers guarantee unique indices.
+    with ops.control_dependencies([sigma]):
+      accum_update_op = state_ops.scatter_update(accum, gradients.indices,
+                                                 accum_updated)
+      linear_update_op = state_ops.scatter_update(linear, gradients.indices,
+                                                  linear_updated)
+      variable_update_op = state_ops.scatter_update(variable, gradients.indices,
+                                                    variable_updated)
+    group_op = control_flow_ops.group(linear_update_op, accum_update_op,
+                                      variable_update_op, name=scope)
+    return group_op
+
+
+class FtrlOptimizer(optimizer.Optimizer):
+  """Optimizer that implements the FTRL algorithm.
+
+  @@__init__
+  """
+
+  def __init__(self, learning_rate,
+               learning_rate_power=-0.5,
+               initial_accumulator_value=0.1,
+               l1_regularization_strength=0.0,
+               l2_regularization_strength=0.0,
+               use_locking=False, name="Ftrl"):
+    """Construct a new FTRL optimizer.
+
+    The Ftrl-proximal algorithm, abbreviated for Follow-the-regularized-leader,
+    is described in the paper [Ad Click Prediction: a View from the Trenches](
+    https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf).
+
+    It can give a good performance vs. sparsity tradeoff.
+
+    Ftrl-proximal uses its own global base learning rate and can behave like
+    Adagrad with `learning_rate_power=-0.5`, or like gradient descent with
+    `learning_rate_power=0.0`.
+
+    The effective learning rate is adjusted per parameter, relative to this
+    base learning rate as:
+
+    ```
+    effective_learning_rate_i = (learning_rate *
+        pow(k + summed_squared_gradients_for_i, learning_rate_power));
+    ```
+
+    where k is the small constant `initial_accumulator_value`. Because
+    `learning_rate_power` is less than or equal to zero, the effective learning
+    rate does not grow as squared gradients accumulate.
+
+    Note that, as implemented in `_Compute`, the quadratic coefficient handed
+    to the solver is `1 / effective_learning_rate + 2 * l2`, and the solver
+    minimizes `1/2 * a * w^2 + ...`; specifying `l2 = lambda_2` therefore
+    contributes the term `lambda_2 * |w|^2` to the per-step objective.
+
+    Args:
+      learning_rate: A float value or a constant float `Tensor`.
+      learning_rate_power: A float value, must be less or equal to zero.
+      initial_accumulator_value: The starting value for accumulators.
+        Only positive values are allowed.
+      l1_regularization_strength: A float value, must be greater than or
+        equal to zero.
+      l2_regularization_strength: A float value, must be greater than or
+        equal to zero.
+      use_locking: If `True` use locks for update operations.
+      name: Optional name prefix for the operations created when applying
+        gradients.  Defaults to "Ftrl".
+
+    Raises:
+      ValueError: if one of the arguments is invalid.
+    """
+    super(FtrlOptimizer, self).__init__(use_locking, name)
+
+    # Validate the hyper-parameters eagerly so that a bad configuration fails
+    # at construction time rather than when the update ops are built.
+    if initial_accumulator_value <= 0.0:
+      raise ValueError("initial_accumulator_value %f needs to be positive" %
+                       initial_accumulator_value)
+    if learning_rate_power > 0.0:
+      raise ValueError("learning_rate_power %f needs to be negative or zero" %
+                       learning_rate_power)
+    if l1_regularization_strength < 0.0:
+      raise ValueError(
+          "l1_regularization_strength %f needs to be positive or zero" %
+          l1_regularization_strength)
+    if l2_regularization_strength < 0.0:
+      raise ValueError(
+          "l2_regularization_strength %f needs to be positive or zero" %
+          l2_regularization_strength)
+
+    self._learning_rate = learning_rate
+    self._learning_rate_power = learning_rate_power
+    self._initial_accumulator_value = initial_accumulator_value
+    self._l1_regularization_strength = l1_regularization_strength
+    self._l2_regularization_strength = l2_regularization_strength
+
+  def _create_slots(self, var_list):
+    # Create the "accum" and "linear" slots.
+    # "accum" starts at initial_accumulator_value with the variable's shape
+    # and dtype; "linear" starts at zero.
+    for v in var_list:
+      self._get_or_make_slot(
+          v,
+          constant_op.constant(self._initial_accumulator_value,
+                               dtype=v.dtype, shape=v.get_shape()),
+          "accum",
+          self._name)
+      self._zeros_slot(v, "linear", self._name)
+
+  def _apply_dense(self, grad, var):
+    # Dense gradients are applied with the module-level _Update helper.
+    # NOTE(review): self._use_locking is not forwarded to _Update.
+    accum = self.get_slot(var, "accum")
+    linear = self.get_slot(var, "linear")
+    return _Update(var, grad, accum, linear,
+                   self._learning_rate, self._learning_rate_power,
+                   self._l1_regularization_strength,
+                   self._l2_regularization_strength)
+
+  def _apply_sparse(self, grad, var):
+    # IndexedSlices gradients are applied with the module-level _SparseUpdate
+    # helper.
+    # NOTE(review): self._use_locking is not forwarded to _SparseUpdate.
+    accum = self.get_slot(var, "accum")
+    linear = self.get_slot(var, "linear")
+    return _SparseUpdate(var, grad, accum, linear,
+                         self._learning_rate, self._learning_rate_power,
+                         self._l1_regularization_strength,
+                         self._l2_regularization_strength)
diff --git a/tensorflow/python/training/ftrl_test.py b/tensorflow/python/training/ftrl_test.py
new file mode 100644
index 0000000000..eb581048f1
--- /dev/null
+++ b/tensorflow/python/training/ftrl_test.py
@@ -0,0 +1,234 @@
+"""Functional tests for Ftrl operations."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class FtrlOptimizerTest(tf.test.TestCase):
+  # Functional tests for tf.train.FtrlOptimizer: fixed-gradient runs checked
+  # against recorded values, plus equivalence checks against Adagrad and
+  # GradientDescent.
+
+  def testFtrlwithoutRegularization(self):
+    # Zero-initialized variables, no L1/L2, three steps with constant
+    # gradients.
+    with self.test_session() as sess:
+      var0 = tf.Variable([0.0, 0.0])
+      var1 = tf.Variable([0.0, 0.0])
+      grads0 = tf.constant([0.1, 0.2])
+      grads1 = tf.constant([0.01, 0.02])
+      opt = tf.train.FtrlOptimizer(3.0,
+                                   initial_accumulator_value=0.1,
+                                   l1_regularization_strength=0.0,
+                                   l2_regularization_strength=0.0)
+      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      tf.initialize_all_variables().run()
+
+      v0_val, v1_val = sess.run([var0, var1])
+      self.assertAllClose([0.0, 0.0], v0_val)
+      self.assertAllClose([0.0, 0.0], v1_val)
+
+      # Run 3 steps FTRL
+      for _ in range(3):
+        update.run()
+
+      v0_val, v1_val = sess.run([var0, var1])
+      self.assertAllClose(np.array([-2.60260963, -4.29698515]),
+                          v0_val)
+      self.assertAllClose(np.array([-0.28432083, -0.56694895]),
+                          v1_val)
+
+  def testFtrlwithoutRegularization2(self):
+    # Same as testFtrlwithoutRegularization but with non-zero initial values.
+    with self.test_session() as sess:
+      var0 = tf.Variable([1.0, 2.0])
+      var1 = tf.Variable([4.0, 3.0])
+      grads0 = tf.constant([0.1, 0.2])
+      grads1 = tf.constant([0.01, 0.02])
+
+      opt = tf.train.FtrlOptimizer(3.0,
+                                   initial_accumulator_value=0.1,
+                                   l1_regularization_strength=0.0,
+                                   l2_regularization_strength=0.0)
+      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      tf.initialize_all_variables().run()
+
+      v0_val, v1_val = sess.run([var0, var1])
+      self.assertAllClose([1.0, 2.0], v0_val)
+      self.assertAllClose([4.0, 3.0], v1_val)
+
+      # Run 3 steps FTRL
+      for _ in range(3):
+        update.run()
+      v0_val, v1_val = sess.run([var0, var1])
+      self.assertAllClose(np.array([-2.55607247, -3.98729396]),
+                          v0_val)
+      self.assertAllClose(np.array([-0.28232238, -0.56096673]),
+                          v1_val)
+
+  def testFtrlWithL1(self):
+    # L1 regularization only (strength 0.001), ten steps.
+    with self.test_session() as sess:
+      var0 = tf.Variable([1.0, 2.0])
+      var1 = tf.Variable([4.0, 3.0])
+      grads0 = tf.constant([0.1, 0.2])
+      grads1 = tf.constant([0.01, 0.02])
+
+      opt = tf.train.FtrlOptimizer(3.0,
+                                   initial_accumulator_value=0.1,
+                                   l1_regularization_strength=0.001,
+                                   l2_regularization_strength=0.0)
+      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      tf.initialize_all_variables().run()
+
+      v0_val, v1_val = sess.run([var0, var1])
+      self.assertAllClose([1.0, 2.0], v0_val)
+      self.assertAllClose([4.0, 3.0], v1_val)
+
+      # Run 10 steps FTRL
+      for _ in range(10):
+        update.run()
+      v0_val, v1_val = sess.run([var0, var1])
+      self.assertAllClose(np.array([-7.66718769, -10.91273689]),
+                          v0_val)
+      self.assertAllClose(np.array([-0.93460727, -1.86147261]),
+                          v1_val)
+
+  def testFtrlWithL1_L2(self):
+    # Both L1 (0.001) and L2 (2.0) regularization, ten steps.
+    with self.test_session() as sess:
+      var0 = tf.Variable([1.0, 2.0])
+      var1 = tf.Variable([4.0, 3.0])
+      grads0 = tf.constant([0.1, 0.2])
+      grads1 = tf.constant([0.01, 0.02])
+
+      opt = tf.train.FtrlOptimizer(3.0,
+                                   initial_accumulator_value=0.1,
+                                   l1_regularization_strength=0.001,
+                                   l2_regularization_strength=2.0)
+      update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+      tf.initialize_all_variables().run()
+
+      v0_val, v1_val = sess.run([var0, var1])
+      self.assertAllClose([1.0, 2.0], v0_val)
+      self.assertAllClose([4.0, 3.0], v1_val)
+
+      # Run 10 steps FTRL
+      for _ in range(10):
+        update.run()
+
+      v0_val, v1_val = sess.run([var0, var1])
+      self.assertAllClose(np.array([-0.24059935, -0.46829352]),
+                          v0_val)
+      self.assertAllClose(np.array([-0.02406147, -0.04830509]),
+                          v1_val)
+
+  def applyOptimizer(self, opt, steps=5, is_sparse=False):
+    # Helper: applies `opt` for `steps` steps to two zero-initialized
+    # variables with fixed gradients (IndexedSlices when is_sparse is True)
+    # and returns the final values. Must be called inside a session context,
+    # since it relies on the default session.
+    if is_sparse:
+      var0 = tf.Variable([[0.0], [0.0]])
+      var1 = tf.Variable([[0.0], [0.0]])
+      grads0 = tf.IndexedSlices(tf.constant([0.1], shape=[1, 1]),
+                                tf.constant([0]),
+                                tf.constant([2, 1]))
+      grads1 = tf.IndexedSlices(tf.constant([0.02], shape=[1, 1]),
+                                tf.constant([1]),
+                                tf.constant([2, 1]))
+    else:
+      var0 = tf.Variable([0.0, 0.0])
+      var1 = tf.Variable([0.0, 0.0])
+      grads0 = tf.constant([0.1, 0.2])
+      grads1 = tf.constant([0.01, 0.02])
+
+    update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+    tf.initialize_all_variables().run()
+
+    sess = tf.get_default_session()
+    v0_val, v1_val = sess.run([var0, var1])
+    if is_sparse:
+      self.assertAllClose([[0.0], [0.0]], v0_val)
+      self.assertAllClose([[0.0], [0.0]], v1_val)
+    else:
+      self.assertAllClose([0.0, 0.0], v0_val)
+      self.assertAllClose([0.0, 0.0], v1_val)
+
+    # Run Ftrl for a few steps
+    for _ in range(steps):
+      update.run()
+
+    v0_val, v1_val = sess.run([var0, var1])
+    return v0_val, v1_val
+
+  # When variables are initialized with zero, FTRL-Proximal has two properties:
+  # 1. Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical
+  # with GradientDescent.
+  # 2. Without L1&L2 but with adaptive learning rate, FTRL-Proximal is identical
+  # with Adagrad.
+  # So, based on these two properties, we test if our implementation of
+  # FTRL-Proximal performs same updates as Adagrad or GradientDescent.
+  def testEquivAdagradwithoutRegularization(self):
+    with self.test_session():
+      val0, val1 = self.applyOptimizer(
+          tf.train.FtrlOptimizer(3.0,
+                                 # Adagrad learning rate
+                                 learning_rate_power=-0.5,
+                                 initial_accumulator_value=0.1,
+                                 l1_regularization_strength=0.0,
+                                 l2_regularization_strength=0.0))
+
+    with self.test_session():
+      val2, val3 = self.applyOptimizer(
+          tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1))
+
+    self.assertAllClose(val0, val2)
+    self.assertAllClose(val1, val3)
+
+  def testEquivSparseAdagradwithoutRegularization(self):
+    # Sparse (IndexedSlices) variant of the Adagrad equivalence check.
+    with self.test_session():
+      val0, val1 = self.applyOptimizer(
+          tf.train.FtrlOptimizer(3.0,
+                                 # Adagrad learning rate
+                                 learning_rate_power=-0.5,
+                                 initial_accumulator_value=0.1,
+                                 l1_regularization_strength=0.0,
+                                 l2_regularization_strength=0.0),
+          is_sparse=True)
+
+    with self.test_session():
+      val2, val3 = self.applyOptimizer(
+          tf.train.AdagradOptimizer(3.0, initial_accumulator_value=0.1),
+          is_sparse=True)
+
+    self.assertAllClose(val0, val2)
+    self.assertAllClose(val1, val3)
+
+  def testEquivSparseGradientDescentwithoutRegularizaion(self):
+    # Sparse (IndexedSlices) variant of the GradientDescent equivalence check.
+    with self.test_session():
+      val0, val1 = self.applyOptimizer(
+          tf.train.FtrlOptimizer(3.0,
+                                 # Fixed learning rate
+                                 learning_rate_power=-0.0,
+                                 initial_accumulator_value=0.1,
+                                 l1_regularization_strength=0.0,
+                                 l2_regularization_strength=0.0),
+          is_sparse=True)
+
+    with self.test_session():
+      val2, val3 = self.applyOptimizer(
+          tf.train.GradientDescentOptimizer(3.0), is_sparse=True)
+
+    self.assertAllClose(val0, val2)
+    self.assertAllClose(val1, val3)
+
+  def testEquivGradientDescentwithoutRegularizaion(self):
+    # Dense variant of the GradientDescent equivalence check.
+    with self.test_session():
+      val0, val1 = self.applyOptimizer(
+          tf.train.FtrlOptimizer(3.0,
+                                 # Fixed learning rate
+                                 learning_rate_power=-0.0,
+                                 initial_accumulator_value=0.1,
+                                 l1_regularization_strength=0.0,
+                                 l2_regularization_strength=0.0))
+
+    with self.test_session():
+      val2, val3 = self.applyOptimizer(
+          tf.train.GradientDescentOptimizer(3.0))
+
+    self.assertAllClose(val0, val2)
+    self.assertAllClose(val1, val3)
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/gradient_descent.py b/tensorflow/python/training/gradient_descent.py
new file mode 100644
index 0000000000..21247aacf3
--- /dev/null
+++ b/tensorflow/python/training/gradient_descent.py
@@ -0,0 +1,44 @@
+"""GradientDescent for TensorFlow."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import constant_op
+# pylint: disable=unused-import
+from tensorflow.python.ops import math_ops
+# pylint: enable=unused-import
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class GradientDescentOptimizer(optimizer.Optimizer):
+  """Optimizer that implements the gradient descent algorithm.
+
+  @@__init__
+  """
+
+  def __init__(self, learning_rate, use_locking=False, name="GradientDescent"):
+    """Construct a new gradient descent optimizer.
+
+    Args:
+      learning_rate: A Tensor or a floating point value.  The learning
+        rate to use.
+      use_locking: If True use locks for update operations.
+      name: Optional name prefix for the operations created when applying
+        gradients. Defaults to "GradientDescent".
+    """
+    super(GradientDescentOptimizer, self).__init__(use_locking, name)
+    self._learning_rate = learning_rate
+
+  def _prepare(self):
+    # Materialize the learning rate as a tensor; the _apply_* methods read it
+    # from self._learning_rate_tensor.
+    self._learning_rate_tensor = ops.convert_to_tensor(
+        self._learning_rate, name="learning_rate")
+
+  def _apply_dense(self, grad, var):
+    # Dense update via the gradient-descent training op; return its Operation.
+    update = training_ops.apply_gradient_descent(
+        var,
+        self._learning_rate_tensor,
+        grad,
+        use_locking=self._use_locking)
+    return update.op
+
+  def _apply_sparse(self, grad, var):
+    # Scale only the gradient rows that are present, then subtract them from
+    # the matching rows of the variable.
+    scaled_values = grad.values * self._learning_rate_tensor
+    delta = ops.IndexedSlices(scaled_values, grad.indices, grad.dense_shape)
+    return var.scatter_sub(delta, use_locking=self._use_locking)
diff --git a/tensorflow/python/training/gradient_descent_test.py b/tensorflow/python/training/gradient_descent_test.py
new file mode 100644
index 0000000000..d5b0cae401
--- /dev/null
+++ b/tensorflow/python/training/gradient_descent_test.py
@@ -0,0 +1,105 @@
+"""Functional test for GradientDescent."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class GradientDescentOptimizerTest(tf.test.TestCase):
+  # Functional tests for tf.train.GradientDescentOptimizer: dense and sparse
+  # updates, global_step handling and dtype validation.
+
+  def testBasic(self):
+    # One dense step with learning rate 3.0: var <- var - 3.0 * grad.
+    with self.test_session():
+      var0 = tf.Variable([1.0, 2.0])
+      var1 = tf.Variable([3.0, 4.0])
+      grads0 = tf.constant([0.1, 0.1])
+      grads1 = tf.constant([0.01, 0.01])
+      sgd_op = tf.train.GradientDescentOptimizer(3.0).apply_gradients(
+          zip([grads0, grads1], [var0, var1]))
+      tf.initialize_all_variables().run()
+      # Fetch params to validate initial values
+      self.assertAllClose([1.0, 2.0], var0.eval())
+      self.assertAllClose([3.0, 4.0], var1.eval())
+      # Run 1 step of sgd
+      sgd_op.run()
+      # Validate updated params
+      self.assertAllClose([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], var0.eval())
+      self.assertAllClose([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval())
+
+  def testFloat64(self):
+    # float64 losses, variables and gradients are expected to be rejected with
+    # a ValueError naming the offending tensor; float32 inputs are accepted.
+    with self.test_session():
+      opt = tf.train.GradientDescentOptimizer(3.0)
+
+      # compute_gradients.
+      values = [1.0, 3.0]
+      good_vars = [tf.Variable([v]) for v in values]
+      bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
+      self.assertRaisesRegexp(
+          ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
+          opt.compute_gradients, bad_loss, good_vars)
+      bad_vars = [
+          tf.Variable(np.array([v], np.float64), name="bad_var")
+          for v in values]
+      # The loss is cast to float32 here so that only the variables have the
+      # unsupported dtype.
+      self.assertRaisesRegexp(
+          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+          opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
+          bad_vars)
+      opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
+
+      # apply_gradients.
+      bad_grads = [
+          tf.constant([0.1], dtype=np.float64, name="bad_grad"),
+          tf.constant([0.01])]
+      self.assertRaisesRegexp(
+          ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
+          opt.apply_gradients, zip(bad_grads, good_vars))
+      good_grads = [tf.constant([0.01]), tf.constant([0.02])]
+      self.assertRaisesRegexp(
+          ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+          opt.apply_gradients, zip(good_grads, bad_vars))
+      opt.apply_gradients(zip(good_grads, good_vars))
+
+  def testWithGlobalStep(self):
+    # Same as testBasic, and additionally checks that global_step is
+    # incremented by one step of the update op.
+    with self.test_session():
+      global_step = tf.Variable(0, trainable=False)
+      var0 = tf.Variable([1.0, 2.0])
+      var1 = tf.Variable([3.0, 4.0])
+      grads0 = tf.constant([0.1, 0.1])
+      grads1 = tf.constant([0.01, 0.01])
+      sgd_op = tf.train.GradientDescentOptimizer(3.0).apply_gradients(
+          zip([grads0, grads1], [var0, var1]), global_step=global_step)
+      tf.initialize_all_variables().run()
+      # Fetch params to validate initial values
+      self.assertAllClose([1.0, 2.0], var0.eval())
+      self.assertAllClose([3.0, 4.0], var1.eval())
+      # Run 1 step of sgd
+      sgd_op.run()
+      # Validate updated params and global_step
+      self.assertAllClose([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1], var0.eval())
+      self.assertAllClose([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], var1.eval())
+      self.assertAllClose(1, global_step.eval())
+
+  def testSparseBasic(self):
+    # IndexedSlices gradients: only the indexed row of each variable changes.
+    with self.test_session():
+      var0 = tf.Variable([[1.0], [2.0]])
+      var1 = tf.Variable([[3.0], [4.0]])
+      grads0 = tf.IndexedSlices(tf.constant([0.1], shape=[1, 1]),
+                                tf.constant([0]),
+                                tf.constant([2, 1]))
+      grads1 = tf.IndexedSlices(tf.constant([0.01], shape=[1, 1]),
+                                tf.constant([1]),
+                                tf.constant([2, 1]))
+      sgd_op = tf.train.GradientDescentOptimizer(3.0).apply_gradients(
+          zip([grads0, grads1], [var0, var1]))
+      tf.initialize_all_variables().run()
+      # Fetch params to validate initial values
+      self.assertAllClose([[1.0], [2.0]], var0.eval())
+      self.assertAllClose([[3.0], [4.0]], var1.eval())
+      # Run 1 step of sgd
+      sgd_op.run()
+      # Validate updated params
+      self.assertAllClose([[1.0 - 3.0 * 0.1], [2.0]], var0.eval())
+      self.assertAllClose([[3.0], [4.0 - 3.0 * 0.01]], var1.eval())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
new file mode 100644
index 0000000000..413fc044f7
--- /dev/null
+++ b/tensorflow/python/training/input.py
@@ -0,0 +1,501 @@
+"""## Input pipeline
+
+TensorFlow functions for setting up an input-prefetching pipeline.
+Please see the [reading data how-to](../../how_tos/reading_data.md)
+for context.
+
+### Beginning of an input pipeline
+
+The "producer" functions add a queue to the graph and a corresponding
+QueueRunner for running the subgraph that fills that queue.
+
+@@match_filenames_once
+@@limit_epochs
+@@range_input_producer
+@@slice_input_producer
+@@string_input_producer
+
+### Batching at the end of an input pipeline
+
+These functions add a queue to the graph to assemble a batch of
+examples, with possible shuffling. They also add a QueueRunner for
+running the subgraph that fills that queue.
+
+Use [batch](#batch) or [batch_join](#batch_join) for batching examples that have
+already been well shuffled. Use [shuffle_batch](#shuffle_batch) or
+[shuffle_batch_join](#shuffle_batch_join) for examples that
+would benefit from additional shuffling.
+
+Use [batch](#batch) or [shuffle_batch](#shuffle_batch) if you want a
+single thread producing examples to batch, or if you have a
+single subgraph producing examples but you want to run it in N threads
+(where you increase N until it can keep the queue full). Use
+[batch_join](#batch_join) or [shuffle_batch_join](#shuffle_batch_join)
+if you have N different subgraphs producing examples to batch and you
+want them run by N threads.
+
+@@batch
+@@batch_join
+@@shuffle_batch
+@@shuffle_batch_join
+
+"""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import summary_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import queue_runner
+
+
+def match_filenames_once(pattern, name=None):
+  """Store the files matching a glob pattern in a variable.
+
+  The match is used as the variable's initial value, so the list of files is
+  only computed once.
+
+  Args:
+    pattern: A file pattern (glob).
+    name: A name for the operations (optional).
+
+  Returns:
+    A variable that is initialized to the list of files matching pattern.
+  """
+  with ops.op_scope([pattern], name, "matching_filenames") as scope:
+    matches = io_ops.matching_files(pattern)
+    # validate_shape=False: the number of matching files is not known when
+    # the graph is built.
+    return variables.Variable(matches, trainable=False, name=scope,
+                              validate_shape=False)
+
+
+def limit_epochs(tensor, num_epochs=None, name=None):
+  """Returns tensor num_epochs times and then raises an OutOfRange error.
+
+  When `num_epochs` is None, `tensor` is returned unchanged and no ops are
+  added to the graph.
+
+  Args:
+    tensor: Any Tensor.
+    num_epochs: An integer (optional).  If specified, limits the number
+      of steps the output tensor may be evaluated.
+    name: A name for the operations (optional).
+
+  Returns:
+    tensor or OutOfRange.
+
+  Raises:
+    ValueError: if `num_epochs` is specified but is not positive.
+  """
+  if num_epochs is None:
+    return tensor
+  if num_epochs <= 0:
+    raise ValueError("num_epochs must be > 0 not %d." % num_epochs)
+  with ops.op_scope([tensor], name, "limit_epochs") as scope:
+    initial_count = constant_op.constant(0, dtype=types.int64)
+    epoch_counter = variables.Variable(initial_count, name="epochs")
+    increment = epoch_counter.count_up_to(num_epochs)
+    # Every evaluation of the returned tensor first bumps the counter.
+    with ops.control_dependencies([increment]):
+      return array_ops.identity(tensor, name=scope)
+
+
+def _input_producer(input_tensor, dtype, num_epochs, shuffle, seed, capacity,
+                    name, summary_name):
+  # Shared implementation of the *_input_producer functions: optionally
+  # shuffles `input_tensor`, applies the epoch limit, and feeds its elements
+  # into a FIFOQueue of scalar `dtype` elements. Registers a QueueRunner for
+  # the enqueue op and a scalar summary of how full the queue is, then returns
+  # the queue.
+  if shuffle:
+    input_tensor = random_ops.random_shuffle(input_tensor, seed=seed)
+  input_tensor = limit_epochs(input_tensor, num_epochs)
+
+  # shapes=[[]]: each queue element is a scalar.
+  q = data_flow_ops.FIFOQueue(capacity=capacity, dtypes=[dtype], shapes=[[]],
+                              name=name)
+  enq = q.enqueue_many([input_tensor])
+  queue_runner.add_queue_runner(queue_runner.QueueRunner(q, [enq]))
+  # Summary value is the current queue size as a fraction of `capacity`.
+  summary_ops.scalar_summary("queue/%s/%s" % (q.name, summary_name),
+                             math_ops.cast(q.size(), types.float32) *
+                             (1. / capacity))
+  return q
+
+
+def string_input_producer(string_tensor, num_epochs=None, shuffle=True,
+                          seed=None, capacity=32, name=None):
+  """Output strings (e.g. filenames) to a queue for an input pipeline.
+
+  Args:
+    string_tensor: A 1-D string tensor with the strings to produce.
+    num_epochs: An integer (optional). If specified, `string_input_producer`
+      produces each string from `string_tensor` `num_epochs` times before
+      generating an OutOfRange error. If not specified, `string_input_producer`
+      can cycle through the strings in `string_tensor` an unlimited number of
+      times.
+    shuffle: Boolean. If true, the strings are randomly shuffled within each
+      epoch.
+    seed: An integer (optional). Seed used if shuffle == True.
+    capacity: An integer. Sets the queue capacity.
+    name: A name for the operations (optional).
+
+  Returns:
+    A queue with the output strings.  A QueueRunner for the Queue
+    is added to the current Graph's QUEUE_RUNNER collection.
+  """
+  summary_name = "fraction_of_%d_full" % capacity
+  with ops.op_scope([string_tensor], name, "input_producer") as scope:
+    return _input_producer(input_tensor=string_tensor,
+                           dtype=types.string,
+                           num_epochs=num_epochs,
+                           shuffle=shuffle,
+                           seed=seed,
+                           capacity=capacity,
+                           name=scope,
+                           summary_name=summary_name)
+
+
+def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
+                         capacity=32, name=None):
+  """Produces the integers from 0 to limit-1 in a queue.
+
+  Args:
+    limit: An int32 scalar tensor.
+    num_epochs: An integer (optional). If specified, `range_input_producer`
+      produces each integer `num_epochs` times before generating an
+      OutOfRange error. If not specified, `range_input_producer` can cycle
+      through the integers an unlimited number of times.
+    shuffle: Boolean. If true, the integers are randomly shuffled within each
+      epoch.
+    seed: An integer (optional). Seed used if shuffle == True.
+    capacity: An integer. Sets the queue capacity.
+    name: A name for the operations (optional).
+
+  Returns:
+    A Queue with the output integers.  A QueueRunner for the Queue
+    is added to the current Graph's QUEUE_RUNNER collection.
+  """
+  with ops.op_scope([limit], name, "input_producer") as scope:
+    indices = math_ops.range(0, limit)
+    summary_name = "fraction_of_%d_full" % capacity
+    return _input_producer(input_tensor=indices,
+                           dtype=types.int32,
+                           num_epochs=num_epochs,
+                           shuffle=shuffle,
+                           seed=seed,
+                           capacity=capacity,
+                           name=scope,
+                           summary_name=summary_name)
+
+
+def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None,
+                         capacity=32, name=None):
+  """Produces a slice of each Tensor in tensor_list.
+
+  Implemented using a Queue -- a QueueRunner for the Queue
+  is added to the current Graph's QUEUE_RUNNER collection.
+
+  Args:
+    tensor_list: A list of Tensors. Every Tensor in tensor_list must
+      have the same size in the first dimension.
+    num_epochs: An integer (optional). If specified, `slice_input_producer`
+      produces each slice `num_epochs` times before generating
+      an OutOfRange error. If not specified, `slice_input_producer` can cycle
+      through the slices an unlimited number of times.
+    shuffle: Boolean. Passed through to `range_input_producer`, which
+      produces the slice indices.
+    seed: An integer (optional). Seed used if shuffle == True.
+    capacity: An integer. Sets the queue capacity.
+    name: A name for the operations (optional).
+
+  Returns:
+    A list of tensors, one for each element of tensor_list. If the tensor
+    in tensor_list has shape [N, a, b, .., z], then the corresponding output
+    tensor will have shape [a, b, ..., z].
+
+  Raises:
+    ValueError: if `tensor_list` is empty.
+  """
+  with ops.op_scope(tensor_list, name, "input_producer"):
+    tensor_list = ops.convert_n_to_tensor_or_indexed_slices(tensor_list)
+    if not tensor_list:
+      raise ValueError(
+          "Expected at least one tensor in slice_input_producer().")
+    # The size of the first dimension of the first tensor determines the
+    # range of indices to produce.
+    range_size = array_ops.shape(tensor_list[0])[0]
+    # TODO(josh11b): Add an assertion that the first dimension of
+    # everything in TensorList matches. Maybe just check the inferred shapes?
+    queue = range_input_producer(range_size, num_epochs=num_epochs,
+                                 shuffle=shuffle, seed=seed, capacity=capacity)
+    # A single dequeued index is used to gather the same slice from every
+    # tensor in the list.
+    index = queue.dequeue()
+    output = [array_ops.gather(t, index) for t in tensor_list]
+    return output
+
+
+# Helpers for the batching functions ------------------------------------------
+
+def _flatten(tensor_list_list):
+  # Concatenates a list of lists into one flat list, preserving order.
+  flat = []
+  for tensor_list in tensor_list_list:
+    flat.extend(tensor_list)
+  return flat
+
+
+def _validate(tensor_list):
+  # Converts the inputs to tensors (or IndexedSlices) and rejects an empty
+  # list.
+  converted = ops.convert_n_to_tensor_or_indexed_slices(tensor_list)
+  if converted:
+    return converted
+  raise ValueError("Expected at least one tensor in batch().")
+
+
+def _validate_join(tensor_list_list):
+  # Converts every inner list to tensors (or IndexedSlices) and rejects an
+  # empty outer list.
+  converted = []
+  for tensor_list in tensor_list_list:
+    converted.append(ops.convert_n_to_tensor_or_indexed_slices(tensor_list))
+  if converted:
+    return converted
+  raise ValueError("Expected at least one input in batch_join().")
+
+
+def _dtypes(tensor_list_list):
+  """Returns the dtypes shared by every list in `tensor_list_list`.
+
+  Args:
+    tensor_list_list: A non-empty list of lists of objects that have a
+      `dtype` attribute (whose value has a `name`).
+
+  Returns:
+    The list of dtypes of the first inner list.
+
+  Raises:
+    TypeError: If any other inner list has a different list of dtypes.
+  """
+  all_dtypes = [[t.dtype for t in tl] for tl in tensor_list_list]
+  dtypes = all_dtypes[0]
+  for other_dtypes in all_dtypes[1:]:
+    if other_dtypes != dtypes:
+      # Both values must be passed to `%` as one tuple; otherwise only the
+      # first is formatted and the formatting itself fails.
+      raise TypeError("Expected types to be consistent: %s vs. %s." %
+                      (", ".join(x.name for x in dtypes),
+                       ", ".join(x.name for x in other_dtypes)))
+  return dtypes
+
+
+def _merge_shapes(shape_list, enqueue_many):
+  """Merges the shapes in `shape_list` into a single shape.
+
+  Args:
+    shape_list: A non-empty list of shapes, each convertible with
+      `tensor_shape.as_shape`.
+    enqueue_many: If True, the leading (batch) dimension of every shape is
+      dropped before merging.
+
+  Returns:
+    The merged shape, as a list.
+  """
+  shape_list = [tensor_shape.as_shape(s) for s in shape_list]
+  if enqueue_many:
+    # We want the shapes without the leading batch dimension.
+    shape_list = [s.WithRankAtLeast(1)[1:] for s in shape_list]
+  merged_shape = shape_list[0]
+  for s in shape_list[1:]:
+    # merge_with() returns the merged shape rather than updating in place, so
+    # the result has to be kept.
+    # NOTE(review): presumably merge_with() raises on incompatible shapes --
+    # confirm against tensor_shape.
+    merged_shape = merged_shape.merge_with(s)
+  return merged_shape.as_list()
+
+
+def _shapes(tensor_list_list, shapes, enqueue_many):
+  """Returns `shapes`, or shapes inferred from the tensors when it is None.
+
+  Args:
+    tensor_list_list: A non-empty list of tensor lists of equal length.
+    shapes: Explicit shapes, or None to infer them.
+    enqueue_many: Forwarded to `_merge_shapes` when inferring.
+
+  Returns:
+    `shapes` unchanged if it is not None; otherwise one merged shape per
+    position of the inner lists.
+  """
+  if shapes is not None:
+    return shapes
+  num_components = len(tensor_list_list[0])
+  inferred = []
+  for i in range(num_components):
+    # Collect the static shape of the i-th tensor of every inner list.
+    component_shapes = [tl[i].get_shape().as_list()
+                        for tl in tensor_list_list]
+    inferred.append(_merge_shapes(component_shapes, enqueue_many))
+  return inferred
+
+
+def _enqueue_join(queue, tensor_list_list, enqueue_many):
+  """Registers a QueueRunner with one enqueue op per inner tensor list.
+
+  Args:
+    queue: The queue to enqueue into.
+    tensor_list_list: A list of tensor lists; each gets its own enqueue op.
+    enqueue_many: If True, `queue.enqueue_many` is used instead of
+      `queue.enqueue`.
+  """
+  enqueue_fn = queue.enqueue_many if enqueue_many else queue.enqueue
+  enqueue_ops = [enqueue_fn(tl) for tl in tensor_list_list]
+  runner = queue_runner.QueueRunner(queue, enqueue_ops)
+  queue_runner.add_queue_runner(runner)
+
+
+def _enqueue(queue, tensor_list, threads, enqueue_many):
+  """Registers a QueueRunner that repeats one enqueue op `threads` times.
+
+  Args:
+    queue: The queue to enqueue into.
+    tensor_list: The tensors to enqueue.
+    threads: How many times the single enqueue op is listed in the runner.
+    enqueue_many: If True, `queue.enqueue_many` is used instead of
+      `queue.enqueue`.
+  """
+  if enqueue_many:
+    enqueue_op = queue.enqueue_many(tensor_list)
+  else:
+    enqueue_op = queue.enqueue(tensor_list)
+  # The same op object is listed once per thread.
+  runner = queue_runner.QueueRunner(queue, [enqueue_op] * threads)
+  queue_runner.add_queue_runner(runner)
+
+
+# Batching functions ----------------------------------------------------------
+
+def batch(tensor_list, batch_size, num_threads=1, capacity=32,
+          enqueue_many=False, shapes=None, name=None):
+  """Groups the tensors in `tensor_list` into batches using a queue.
+
+  The tensors are enqueued into a `FIFOQueue`; a QueueRunner for that queue
+  is added to the current Graph's QUEUE_RUNNER collection. A scalar summary
+  reporting how full the queue is gets added as well.
+
+  Args:
+    tensor_list: The list of tensors to enqueue.
+    batch_size: The new batch size pulled from the queue.
+    num_threads: The number of threads enqueuing `tensor_list`.
+    capacity: Maximum number of elements in the queue. Controls how far
+      ahead the prefetching is allowed to get, and the memory usage.
+    enqueue_many: If False, `tensor_list` is assumed to represent a
+      single example. If True, `tensor_list` is assumed to represent
+      a batch of examples, where the first dimension is indexed by
+      example, and all members of `tensor_list` should have the same
+      size in the first dimension.
+    shapes: Optional. The shapes for each example. Defaults to the
+      inferred shapes for `tensor_list` (leaving off the first dimension
+      if `enqueue_many` is True).
+    name: A name for the operations (optional).
+
+  Returns:
+    A list of tensors with the same number and types as `tensor_list`.
+    If `enqueue_many` is False, then an input tensor with shape
+    `[x, y, z]` will be output as a tensor with shape
+    `[batch_size, x, y, z]`. If `enqueue_many` is True, and an
+    input tensor has shape `[*, x, y, z]`, the output will have
+    shape `[batch_size, x, y, z]`.
+  """
+  with ops.op_scope(tensor_list, name, "batch") as name:
+    tensor_list = _validate(tensor_list)
+    single_source = [tensor_list]
+    dtypes = _dtypes(single_source)
+    shapes = _shapes(single_source, shapes, enqueue_many)
+    # TODO(josh11b,mrry): Switch to BatchQueue once it is written.
+    queue = data_flow_ops.FIFOQueue(
+        capacity=capacity, dtypes=dtypes, shapes=shapes)
+    _enqueue(queue, tensor_list, num_threads, enqueue_many)
+    # Report the queue size as a fraction of its capacity.
+    summary_tag = "queue/%s/fraction_of_%d_full" % (queue.name, capacity)
+    fraction_full = (
+        math_ops.cast(queue.size(), types.float32) * (1. / capacity))
+    summary_ops.scalar_summary(summary_tag, fraction_full)
+    return queue.dequeue_many(batch_size, name=name)
+
+
+# TODO(josh11b): Add a thread_multiplier or num_threads (that has to be
+# a multiple of len(tensor_list_list)?) parameter, to address the use
+# case where you want more parallelism than you can support different
+# readers (either because you don't have that many files or can't
+# read that many files in parallel due to the number of seeks required).
+# Once this is done, batch() can be written as a call to batch_join().
+def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
+               shapes=None, name=None):
+  """Batches examples drawn from several lists of tensors using a queue.
+
+  Unlike `batch`, each inner list of tensors is enqueued by its own thread.
+  The tensors go into a `FIFOQueue`; a QueueRunner for that queue is added
+  to the current Graph's QUEUE_RUNNER collection. A scalar summary reporting
+  how full the queue is gets added as well.
+
+  Args:
+    tensor_list_list: A list of tuples of tensors to enqueue.
+      len(tensor_list_list) threads will be started, with the i-th
+      thread enqueuing the tensors from tensor_list_list[i].
+      tensor_list_list[i1][j] must match tensor_list_list[i2][j] in type
+      and shape (except in the first dimension if `enqueue_many` is True).
+    batch_size: The new batch size pulled from the queue.
+    capacity: Maximum number of elements in the queue. Controls how far
+      ahead the prefetching is allowed to get, and the memory usage.
+    enqueue_many: If False, each tensor_list_list[i] is assumed to
+      represent a single example. If True, tensor_list_list[i] is
+      assumed to represent a batch of examples, where the first
+      dimension is indexed by example, and all members of
+      tensor_list_list[i] should have the same size in the first
+      dimension.
+    shapes: Optional. The shapes for each example. Defaults to the
+      inferred shapes for tensor_list_list[i] (which must match, after
+      leaving off the first dimension if `enqueue_many` is True).
+    name: A name for the operations (optional).
+
+  Returns:
+    A list of tensors with the same number and types as
+    tensor_list_list[i]. If `enqueue_many` is False, then an input
+    tensor with shape `[x, y, z]` will be output as a tensor with
+    shape `[batch_size, x, y, z]`. If `enqueue_many` is True, and an
+    input tensor has shape `[*, x, y, z]`, the output will have
+    shape `[batch_size, x, y, z]`.
+  """
+  all_inputs = _flatten(tensor_list_list)
+  with ops.op_scope(all_inputs, name, "batch_join") as name:
+    tensor_list_list = _validate_join(tensor_list_list)
+    dtypes = _dtypes(tensor_list_list)
+    shapes = _shapes(tensor_list_list, shapes, enqueue_many)
+    # TODO(josh11b,mrry): Switch to BatchQueue once it is written.
+    queue = data_flow_ops.FIFOQueue(
+        capacity=capacity, dtypes=dtypes, shapes=shapes)
+    _enqueue_join(queue, tensor_list_list, enqueue_many)
+    # Report the queue size as a fraction of its capacity.
+    summary_tag = "queue/%s/fraction_of_%d_full" % (queue.name, capacity)
+    fraction_full = (
+        math_ops.cast(queue.size(), types.float32) * (1. / capacity))
+    summary_ops.scalar_summary(summary_tag, fraction_full)
+    return queue.dequeue_many(batch_size, name=name)
+
+
+def shuffle_batch(tensor_list, batch_size, capacity, min_after_dequeue,
+                  num_threads=1, seed=None, enqueue_many=False, shapes=None,
+                  name=None):
+  """Creates batches by randomly shuffling tensors.
+
+  This adds:
+
+  * a shuffling queue into which tensors from `tensor_list` are enqueued,
+  * a dequeue-many operation to create batches from the queue, and
+  * a QueueRunner, added to the current Graph's QUEUE_RUNNER collection,
+    to enqueue the tensors from `tensor_list`.
+
+  A scalar summary reporting how full the queue is above
+  `min_after_dequeue` gets added as well.
+
+  Args:
+    tensor_list: The list of tensors to enqueue.
+    batch_size: The new batch size pulled from the queue.
+    capacity: Maximum number of elements in the queue. Controls how far
+      ahead the prefetching is allowed to get, and the memory usage.
+    min_after_dequeue: Minimum number of elements in the queue after a
+      dequeue, used to ensure a level of mixing of elements.
+    num_threads: The number of threads enqueuing `tensor_list`.
+    seed: Seed for the random shuffling within the queue.
+    enqueue_many: If False, `tensor_list` is assumed to represent a
+      single example. If True, `tensor_list` is assumed to represent
+      a batch of examples, where the first dimension is indexed by
+      example, and all members of `tensor_list` should have the same
+      size in the first dimension.
+    shapes: Optional. The shapes for each example. Defaults to the
+      inferred shapes for `tensor_list` (leaving off the first dimension
+      if `enqueue_many` is True).
+    name: A name for the operations (optional).
+
+  Returns:
+    A list of tensors with the same number and types as `tensor_list`.
+    If `enqueue_many` is False, then an input tensor with shape
+    `[x, y, z]` will be output as a tensor with shape
+    `[batch_size, x, y, z]`. If `enqueue_many` is True, and an
+    input tensor has shape `[*, x, y, z]`, the output will have
+    shape `[batch_size, x, y, z]`.
+  """
+  with ops.op_scope(tensor_list, name, "shuffle_batch") as name:
+    tensor_list = _validate(tensor_list)
+    single_source = [tensor_list]
+    dtypes = _dtypes(single_source)
+    shapes = _shapes(single_source, shapes, enqueue_many)
+    queue = data_flow_ops.RandomShuffleQueue(
+        capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
+        dtypes=dtypes, shapes=shapes)
+    _enqueue(queue, tensor_list, num_threads, enqueue_many)
+    # Fraction of the space above min_after_dequeue that is in use.
+    excess = math_ops.cast(queue.size() - min_after_dequeue, types.float32)
+    full = excess * (1. / (capacity - min_after_dequeue))
+    # Note that name contains a '/' at the end so we intentionally do not place
+    # a '/' after %s below.
+    summary_name = "queue/%sfraction_over_%d_of_%d_full" % (
+        name, min_after_dequeue, capacity - min_after_dequeue)
+    summary_ops.scalar_summary(summary_name, full)
+
+    return queue.dequeue_many(batch_size, name=name)
+
+
+def shuffle_batch_join(tensor_list_list, batch_size, capacity,
+                       min_after_dequeue, seed=None, enqueue_many=False,
+                       shapes=None, name=None):
+  """Creates batches by randomly shuffling tensors from several sources.
+
+  This version enqueues a different list of tensors in different threads.
+  It adds:
+
+  * a shuffling queue into which tensors from `tensor_list_list` are
+    enqueued,
+  * a dequeue-many operation to create batches from the queue, and
+  * a QueueRunner, added to the current Graph's QUEUE_RUNNER collection,
+    to enqueue the tensors from `tensor_list_list`.
+
+  A scalar summary reporting how full the queue is above
+  `min_after_dequeue` gets added as well.
+
+  Args:
+    tensor_list_list: A list of tuples of tensors to enqueue.
+      len(tensor_list_list) threads will be started, with the i-th
+      thread enqueuing the tensors from tensor_list_list[i].
+      tensor_list_list[i1][j] must match tensor_list_list[i2][j] in type
+      and shape (except in the first dimension if `enqueue_many` is True).
+    batch_size: The new batch size pulled from the queue.
+    capacity: Maximum number of elements in the queue. Controls how far
+      ahead the prefetching is allowed to get, and the memory usage.
+    min_after_dequeue: Minimum number of elements in the queue after a
+      dequeue, used to ensure a level of mixing of elements.
+    seed: Seed for the random shuffling within the queue.
+    enqueue_many: If False, each tensor_list_list[i] is assumed to
+      represent a single example. If True, tensor_list_list[i] is
+      assumed to represent a batch of examples, where the first
+      dimension is indexed by example, and all members of
+      tensor_list_list[i] should have the same size in the first
+      dimension.
+    shapes: Optional. The shapes for each example. Defaults to the
+      inferred shapes for tensor_list_list[i] (which must match, after
+      leaving off the first dimension if `enqueue_many` is True).
+    name: A name for the operations (optional).
+
+  Returns:
+    A list of tensors with the same number and types as
+    tensor_list_list[i]. If `enqueue_many` is False, then an input
+    tensor with shape `[x, y, z]` will be output as a tensor with
+    shape `[batch_size, x, y, z]`. If `enqueue_many` is True, and an
+    input tensor has shape `[*, x, y, z]`, the output will have
+    shape `[batch_size, x, y, z]`.
+  """
+  all_inputs = _flatten(tensor_list_list)
+  with ops.op_scope(all_inputs, name, "shuffle_batch_join") as name:
+    tensor_list_list = _validate_join(tensor_list_list)
+    dtypes = _dtypes(tensor_list_list)
+    shapes = _shapes(tensor_list_list, shapes, enqueue_many)
+    queue = data_flow_ops.RandomShuffleQueue(
+        capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
+        dtypes=dtypes, shapes=shapes)
+    _enqueue_join(queue, tensor_list_list, enqueue_many)
+    # Fraction of the space above min_after_dequeue that is in use.
+    excess = math_ops.cast(queue.size() - min_after_dequeue, types.float32)
+    full = excess * (1. / (capacity - min_after_dequeue))
+    # Note that name contains a '/' at the end so we intentionally do not place
+    # a '/' after %s below.
+    summary_name = "queue/%sfraction_over_%d_of_%d_full" % (
+        name, min_after_dequeue, capacity - min_after_dequeue)
+    summary_ops.scalar_summary(summary_name, full)
+    return queue.dequeue_many(batch_size, name=name)
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
new file mode 100644
index 0000000000..fe8c195e77
--- /dev/null
+++ b/tensorflow/python/training/input_test.py
@@ -0,0 +1,477 @@
+"""Tests for training.input."""
+
+import os
+import itertools
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class MatchFilenamesOnceTest(tf.test.TestCase):
+  """Tests `tf.train.match_filenames_once` against files in the temp dir."""
+
+  def test(self):
+    """Checks a `*` glob, a `?` glob and a literal path."""
+    temp_dir = self.get_temp_dir()
+    # Files already present in the temp dir are expected to match "*" too.
+    filenames = [os.path.join(temp_dir, n) for n in os.listdir(temp_dir)]
+    additional = [os.path.join(self.get_temp_dir(), "match_filenames.%d" % i)
+                  for i in range(3)]
+    for name in additional:
+      # Close each file deterministically instead of relying on garbage
+      # collection of the anonymous file object.
+      with open(name, "w") as f:
+        f.write("Some contents")
+    filenames += additional
+    with self.test_session():
+      star = tf.train.match_filenames_once(
+          os.path.join(self.get_temp_dir(), "*"))
+      question = tf.train.match_filenames_once(
+          os.path.join(self.get_temp_dir(), "match_filenames.?"))
+      one = tf.train.match_filenames_once(additional[1])
+      tf.initialize_all_variables().run()
+      self.assertItemsEqual(filenames, star.eval())
+      self.assertItemsEqual(additional, question.eval())
+      self.assertItemsEqual([additional[1]], one.eval())
+
+
+class LimitEpochsTest(tf.test.TestCase):
+  """Tests `tf.train.limit_epochs` with and without an epoch limit."""
+
+  def testNoLimit(self):
+    """Without `num_epochs` the tensor can be evaluated repeatedly."""
+    with self.test_session():
+      seven_forever = tf.train.limit_epochs(tf.constant(7))
+      tf.initialize_all_variables().run()
+      for _ in range(100):
+        self.assertEqual(7, seven_forever.eval())
+
+  def testLimit(self):
+    """With `num_epochs=2` the third evaluation raises OutOfRangeError."""
+    with self.test_session():
+      love_me_two_times = tf.train.limit_epochs(
+          tf.constant("Love Me"), num_epochs=2)
+      tf.initialize_all_variables().run()
+      for _ in range(2):
+        self.assertEqual("Love Me", love_me_two_times.eval())
+      with self.assertRaises(tf.errors.OutOfRangeError):
+        love_me_two_times.eval()
+
+
+class StringInputProducerTest(tf.test.TestCase):
+  """Tests `tf.train.string_input_producer` with and without shuffling."""
+
+  def testNoShuffle(self):
+    """Unshuffled output is the input repeated `num_epochs` times."""
+    with self.test_session():
+      strings = ["to", "be", "or", "not", "to", "be"]
+      num_epochs = 3
+      queue = tf.train.string_input_producer(
+          strings, num_epochs=num_epochs, shuffle=False)
+      total_items = len(strings) * num_epochs
+      dequeue_many = queue.dequeue_many(total_items)
+      dequeue = queue.dequeue()
+      tf.initialize_all_variables().run()
+      threads = tf.train.start_queue_runners()
+
+      # No randomness, so just see repeated copies of the input.
+      self.assertAllEqual(strings * num_epochs, dequeue_many.eval())
+
+      # Reached the limit.
+      with self.assertRaises(tf.errors.OutOfRangeError):
+        dequeue.eval()
+      for thread in threads:
+        thread.join()
+
+  def testShuffle(self):
+    """Shuffling happens within an epoch and is roughly uniform."""
+    with self.test_session():
+      strings = ["a", "b", "c"]
+      num_epochs = 600
+      queue = tf.train.string_input_producer(
+          strings, num_epochs=num_epochs, shuffle=True, seed=271828)
+      dequeue_many = queue.dequeue_many(len(strings))
+      dequeue = queue.dequeue()
+      tf.initialize_all_variables().run()
+      threads = tf.train.start_queue_runners()
+
+      # Validate that we only shuffle the strings within an epoch and
+      # count how often each possible order appears.
+      expected = ["abc", "acb", "bac", "bca", "cab", "cba"]
+      frequency = dict.fromkeys(expected, 0)
+      for _ in range(num_epochs):
+        key = "".join(dequeue_many.eval())
+        self.assertIn(key, expected)
+        frequency[key] += 1
+
+      # Expect an approximately even distribution over all possible orders.
+      expected_frequency = num_epochs / len(expected)
+      margin = expected_frequency * 0.4
+      lower_bound = expected_frequency - margin
+      upper_bound = expected_frequency + margin
+      tf.logging.info("Observed counts: %s", frequency)
+      for key in expected:
+        self.assertGreater(frequency[key], lower_bound)
+        self.assertLess(frequency[key], upper_bound)
+
+      # Reached the limit.
+      with self.assertRaises(tf.errors.OutOfRangeError):
+        dequeue.eval()
+      for thread in threads:
+        thread.join()
+
+
+class RangeInputProducerTest(tf.test.TestCase):
+  """Tests `tf.train.range_input_producer` with and without shuffling."""
+
+  def testNoShuffle(self):
+    """Unshuffled output is 0..range_size-1 repeated `num_epochs` times."""
+    with self.test_session():
+      num_epochs = 3
+      range_size = 5
+      queue = tf.train.range_input_producer(
+          range_size, num_epochs=num_epochs, shuffle=False)
+      dequeue_many = queue.dequeue_many(range_size * num_epochs)
+      dequeue = queue.dequeue()
+      tf.initialize_all_variables().run()
+      threads = tf.train.start_queue_runners()
+
+      # No randomness, so just see repeated copies of the input.
+      output = dequeue_many.eval()
+      # range() must be materialized as a list before it can be repeated;
+      # a range object does not support `*`.
+      self.assertAllEqual(list(range(range_size)) * num_epochs, output)
+
+      # Reached the limit.
+      with self.assertRaises(tf.errors.OutOfRangeError):
+        dequeue.eval()
+      for thread in threads:
+        thread.join()
+
+  def testShuffle(self):
+    """Shuffling happens within an epoch and is roughly uniform."""
+    with self.test_session():
+      num_epochs = 200
+      range_size = 2
+      queue = tf.train.range_input_producer(
+          range_size, num_epochs=num_epochs, shuffle=True, seed=314159)
+      dequeue_many = queue.dequeue_many(range_size)
+      dequeue = queue.dequeue()
+      tf.initialize_all_variables().run()
+      threads = tf.train.start_queue_runners()
+
+      # Validate that we only shuffle the integers within an epoch and
+      # count how often each possible order appears.
+      expected = [12, 21]
+      frequency = {}
+      for e in expected:
+        frequency[e] = 0
+      for _ in range(num_epochs):
+        output = dequeue_many.eval()
+        # Encode the pair (a, b) as the two-digit number (a+1)(b+1).
+        key = 10 * (output[0] + 1) + (output[1] + 1)
+        self.assertIn(key, expected)
+        frequency[key] += 1
+
+      # Expect an approximately even distribution over all possible orders.
+      expected_frequency = num_epochs / len(expected)
+      margin = expected_frequency * 0.4
+      tf.logging.info("Observed counts: %s", frequency)
+      for key in expected:
+        value = frequency[key]
+        self.assertGreater(value, expected_frequency - margin)
+        self.assertLess(value, expected_frequency + margin)
+
+      # Reached the limit.
+      with self.assertRaises(tf.errors.OutOfRangeError):
+        dequeue.eval()
+      for thread in threads:
+        thread.join()
+
+
+class SliceInputProducerTest(tf.test.TestCase):
+  """Tests `tf.train.slice_input_producer` with and without shuffling."""
+
+  def testNoShuffle(self):
+    """Unshuffled slices come out in input order, `num_epochs` times."""
+    with self.test_session() as sess:
+      num_epochs = 3
+      source_strings = ["Alpha", "Beta", "Delta", "Gamma"]
+      source_ints = [2, 3, 5, 7]
+      slices = tf.train.slice_input_producer(
+          [source_strings, source_ints], num_epochs=num_epochs, shuffle=False)
+      tf.initialize_all_variables().run()
+      threads = tf.train.start_queue_runners()
+
+      # No randomness, so just see repeated copies of the input.
+      num_items = len(source_strings) * num_epochs
+      output = []
+      for _ in range(num_items):
+        output.append(sess.run(slices))
+      out_strings, out_ints = zip(*output)
+      self.assertAllEqual(source_strings * num_epochs, out_strings)
+      self.assertAllEqual(source_ints * num_epochs, out_ints)
+
+      # Reached the limit.
+      with self.assertRaises(tf.errors.OutOfRangeError):
+        sess.run(slices)
+      for thread in threads:
+        thread.join()
+
+  def testShuffle(self):
+    """Shuffling happens within an epoch and is roughly uniform."""
+    with self.test_session() as sess:
+      num_epochs = 1200
+      source_strings = ["A", "B", "D", "G"]
+      source_ints = [7, 3, 5, 2]
+      slices = tf.train.slice_input_producer(
+          [source_strings, source_ints], num_epochs=num_epochs, shuffle=True,
+          seed=161803)
+      tf.initialize_all_variables().run()
+      threads = tf.train.start_queue_runners()
+
+      # Validate that we only shuffle the integers within an epoch and
+      # count how often each possible order appears.
+      expected = [",".join(x) for x in
+                  itertools.permutations(["A7", "B3", "D5", "G2"])]
+      frequency = dict.fromkeys(expected, 0)
+      slices_per_epoch = len(source_strings)
+      for _ in range(num_epochs):
+        output = [sess.run(slices) for _ in range(slices_per_epoch)]
+        key = ",".join([s + str(i) for s, i in output])
+        self.assertIn(key, expected)
+        frequency[key] += 1
+
+      # Expect an approximately even distribution over all possible orders.
+      expected_frequency = num_epochs / len(expected)
+      margin = expected_frequency * 0.4
+      lower_bound = expected_frequency - margin
+      upper_bound = expected_frequency + margin
+      tf.logging.info("Observed counts: %s", frequency)
+      for key in expected:
+        self.assertGreater(frequency[key], lower_bound)
+        self.assertLess(frequency[key], upper_bound)
+
+      # Reached the limit.
+      with self.assertRaises(tf.errors.OutOfRangeError):
+        sess.run(slices)
+      for thread in threads:
+        thread.join()
+
+
+class BatchTest(tf.test.TestCase):
+ """Tests `tf.train.batch` with one and with several enqueue threads."""
+
+ def testOneThread(self):
+ """With the default single thread, counter values arrive in order."""
+ with self.test_session() as sess:
+ batch_size = 10
+ num_batches = 3
+ zero64 = tf.constant(0, dtype=tf.int64)
+ examples = tf.Variable(zero64)
+ # The counter is limited to exactly num_batches full batches.
+ counter = examples.count_up_to(num_batches * batch_size)
+ batched = tf.train.batch([counter, "string"], batch_size=batch_size)
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ for i in range(num_batches):
+ results = sess.run(batched)
+ # Batch i must hold the i-th consecutive run of batch_size counts.
+ self.assertAllEqual(results[0],
+ range(i * batch_size, (i + 1) * batch_size))
+ self.assertAllEqual(results[1], ["string"] * batch_size)
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+ def testManyThreads(self):
+ """With four threads, every counter value still appears exactly once."""
+ with self.test_session() as sess:
+ batch_size = 10
+ num_batches = 3
+ zero64 = tf.constant(0, dtype=tf.int64)
+ examples = tf.Variable(zero64)
+ counter = examples.count_up_to(num_batches * batch_size)
+ batched = tf.train.batch([counter, "string"], batch_size=batch_size,
+ num_threads=4)
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ all_counts = []
+ for i in range(num_batches):
+ results = sess.run(batched)
+ tf.logging.info("Batch %d: %s", i, results[0])
+ self.assertEqual(len(results[0]), batch_size)
+ all_counts.extend(results[0])
+ self.assertAllEqual(results[1], ["string"] * batch_size)
+ # Order is not asserted here, only that the collected counts are
+ # exactly 0..num_batches*batch_size-1.
+ self.assertItemsEqual(all_counts, range(num_batches * batch_size))
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+
+class BatchJoinTest(tf.test.TestCase):
+  """Tests `tf.train.batch_join` with two differently-sized sources."""
+
+  def testTwoThreads(self):
+    """Both sources are fully consumed, mixed, and "a" stays in order."""
+    with self.test_session() as sess:
+      # Two threads, the first generates (0..34, "a").
+      num_a = 35
+      zero64 = tf.constant(0, dtype=tf.int64)
+      examples = tf.Variable(zero64)
+      counter = examples.count_up_to(num_a)
+
+      # The second generates (99, "b") 45 times and then stops.
+      num_b = 45
+      ninety_nine = tf.train.limit_epochs(
+          tf.constant(99, dtype=tf.int64), num_b)
+
+      # These get joined together and grouped into batches of 5.
+      batch_size = 5
+      batched = tf.train.batch_join([[counter, "a"], [ninety_nine, "b"]],
+                                    batch_size=batch_size)
+      tf.initialize_all_variables().run()
+      threads = tf.train.start_queue_runners()
+
+      # Should see the "a" and "b" threads mixed together.
+      all_a = []
+      seen_b = 0
+      saw_both = 0
+      # Floor division keeps this an int (as range() requires) whether or
+      # not `/` performs true division.
+      num_batches = (num_a + num_b) // batch_size
+      for i in range(num_batches):
+        results = sess.run(batched)
+        tf.logging.info("Batch %d: %s", i, results[0])
+        self.assertEqual(len(results[0]), batch_size)
+        self.assertEqual(len(results[1]), batch_size)
+        # `idx` rather than `i`, so the batch index is never shadowed.
+        which_a = [idx for idx, s in enumerate(results[1]) if s == "a"]
+        which_b = [idx for idx, s in enumerate(results[1]) if s == "b"]
+        self.assertEqual(len(which_a) + len(which_b), batch_size)
+        if which_a and which_b:
+          saw_both += 1
+        all_a.extend([results[0][idx] for idx in which_a])
+        seen_b += len(which_b)
+        self.assertAllEqual([99] * len(which_b),
+                            [results[0][idx] for idx in which_b])
+
+      # Some minimum level of mixing of the results of both threads.
+      self.assertGreater(saw_both, 1)
+
+      # Verify the order of results from "a" were preserved.
+      self.assertAllEqual(all_a, range(num_a))
+      self.assertEqual(seen_b, num_b)
+
+      # Reached the limit.
+      with self.assertRaises(tf.errors.OutOfRangeError):
+        sess.run(batched)
+      for thread in threads:
+        thread.join()
+
+
+class ShuffleBatchTest(tf.test.TestCase):
+ """Tests `tf.train.shuffle_batch` with one and with several threads."""
+
+ def testOneThread(self):
+ """Output is scrambled but contains every counter value exactly once."""
+ with self.test_session() as sess:
+ batch_size = 10
+ num_batches = 3
+ zero64 = tf.constant(0, dtype=tf.int64)
+ examples = tf.Variable(zero64)
+ # The counter is limited to exactly num_batches full batches.
+ counter = examples.count_up_to(num_batches * batch_size)
+ batched = tf.train.shuffle_batch(
+ [counter, "string"], batch_size=batch_size, capacity=32,
+ min_after_dequeue=16, seed=141421)
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ all_counts = []
+ for i in range(num_batches):
+ results = sess.run(batched)
+ self.assertEqual(len(results[0]), batch_size)
+ all_counts.extend(results[0])
+ self.assertAllEqual(results[1], ["string"] * batch_size)
+ # Results scrambled, but include all the expected numbers.
+ # "Scrambled" is checked as: consecutive differences are not all equal.
+ deltas = [all_counts[i + 1] - all_counts[i]
+ for i in range(len(all_counts) - 1)]
+ self.assertFalse(all(d == deltas[0] for d in deltas))
+ self.assertItemsEqual(all_counts, range(num_batches * batch_size))
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+ def testManyThreads(self):
+ """Same checks as testOneThread, with four enqueue threads."""
+ with self.test_session() as sess:
+ batch_size = 10
+ num_batches = 3
+ zero64 = tf.constant(0, dtype=tf.int64)
+ examples = tf.Variable(zero64)
+ counter = examples.count_up_to(num_batches * batch_size)
+ batched = tf.train.shuffle_batch(
+ [counter, "string"], batch_size=batch_size, capacity=32,
+ min_after_dequeue=16, seed=173205, num_threads=4)
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ all_counts = []
+ for i in range(num_batches):
+ results = sess.run(batched)
+ tf.logging.info("Batch %d: %s", i, results[0])
+ self.assertEqual(len(results[0]), batch_size)
+ all_counts.extend(results[0])
+ self.assertAllEqual(results[1], ["string"] * batch_size)
+ # Results scrambled, but include all the expected numbers.
+ # "Scrambled" is checked as: consecutive differences are not all equal.
+ deltas = [all_counts[i + 1] - all_counts[i]
+ for i in range(len(all_counts) - 1)]
+ self.assertFalse(all(d == deltas[0] for d in deltas))
+ self.assertItemsEqual(all_counts, range(num_batches * batch_size))
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+
+class ShuffleBatchJoinTest(tf.test.TestCase):
+  """Tests `tf.train.shuffle_batch_join` with two sources."""
+
+  def testTwoThreads(self):
+    """Both sources are fully consumed, mixed, and "a" is scrambled."""
+    with self.test_session() as sess:
+      # Two threads, the first generates (0..24, "a").
+      num_a = 25
+      zero64 = tf.constant(0, dtype=tf.int64)
+      examples = tf.Variable(zero64)
+      counter = examples.count_up_to(num_a)
+
+      # The second generates (99, "b") 35 times and then stops.
+      num_b = 35
+      ninety_nine = tf.train.limit_epochs(
+          tf.constant(99, dtype=tf.int64), num_b)
+
+      # These get joined together and grouped into batches of 5.
+      batch_size = 5
+      batched = tf.train.shuffle_batch_join(
+          [[counter, "a"], [ninety_nine, "b"]], batch_size=batch_size,
+          capacity=32, min_after_dequeue=16, seed=223607)
+
+      tf.initialize_all_variables().run()
+      threads = tf.train.start_queue_runners()
+
+      # Should see the "a" and "b" threads mixed together.
+      all_a = []
+      seen_b = 0
+      saw_both = 0
+      # Floor division keeps this an int (as range() requires) whether or
+      # not `/` performs true division.
+      num_batches = (num_a + num_b) // batch_size
+      for i in range(num_batches):
+        results = sess.run(batched)
+        tf.logging.info("Batch %d: %s", i, results[0])
+        self.assertEqual(len(results[0]), batch_size)
+        self.assertEqual(len(results[1]), batch_size)
+        # `idx` rather than `i`, so the batch index is never shadowed.
+        which_a = [idx for idx, s in enumerate(results[1]) if s == "a"]
+        which_b = [idx for idx, s in enumerate(results[1]) if s == "b"]
+        self.assertEqual(len(which_a) + len(which_b), batch_size)
+        if which_a and which_b:
+          saw_both += 1
+        all_a.extend([results[0][idx] for idx in which_a])
+        seen_b += len(which_b)
+        self.assertAllEqual([99] * len(which_b),
+                            [results[0][idx] for idx in which_b])
+
+      # Some minimum level of mixing of the results of both threads.
+      self.assertGreater(saw_both, 1)
+
+      # Saw all the items from "a", but scrambled.
+      self.assertItemsEqual(all_a, range(num_a))
+      deltas = [all_a[k + 1] - all_a[k]
+                for k in range(len(all_a) - 1)]
+      self.assertFalse(all(d == deltas[0] for d in deltas))
+      self.assertEqual(seen_b, num_b)
+
+      # Reached the limit.
+      with self.assertRaises(tf.errors.OutOfRangeError):
+        sess.run(batched)
+      for thread in threads:
+        thread.join()
+
+
+# Run the test cases in this module when it is executed as a script.
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
new file mode 100644
index 0000000000..cafcb26d01
--- /dev/null
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -0,0 +1,65 @@
+"""Various learning rate decay functions."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+
+
+def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
+                      staircase=False, name=None):
+  """Applies exponential decay to the learning rate.
+
+  When training a model, it is often recommended to lower the learning rate as
+  the training progresses. This function applies an exponential decay function
+  to a provided initial learning rate. It requires a `global_step` value to
+  compute the decayed learning rate. You can just pass a TensorFlow variable
+  that you increment at each training step.
+
+  The function returns the decayed learning rate. It is computed as:
+
+  ```python
+  decayed_learning_rate = learning_rate *
+                          decay_rate ^ (global_step / decay_steps)
+  ```
+
+  If the argument `staircase` is `True`, then `global_step / decay_steps` is
+  rounded down to an integer, so the decayed learning rate follows a
+  staircase function.
+
+  Example: decay every 100000 steps with a base of 0.96:
+
+  ```python
+  ...
+  global_step = tf.Variable(0, trainable=False)
+  starter_learning_rate = 0.1
+  learning_rate = tf.train.exponential_decay(starter_learning_rate,
+                                             global_step, 100000, 0.96,
+                                             staircase=True)
+  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+  # Passing global_step to minimize() will increment it at each step.
+  optimizer.minimize(...my loss..., global_step=global_step)
+  ```
+
+  Args:
+    learning_rate: A scalar `float32` or `float64` `Tensor` or a
+      Python number.  The initial learning rate.
+    global_step: A scalar `int32` or `int64` `Tensor` or a Python number.
+      Global step to use for the decay computation.  Must not be negative.
+    decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number.
+      Must be positive.  See the decay computation above.
+    decay_rate: A scalar `float32` or `float64` `Tensor` or a
+      Python number.  The decay rate.
+    staircase: Boolean.  If `True`, decay the learning rate at discrete
+      intervals.
+    name: string.  Optional name of the operation.  Defaults to
+      'ExponentialDecay'.
+
+  Returns:
+    A scalar `Tensor` of the same type as `learning_rate`.  The decayed
+    learning rate.
+  """
+  scope_inputs = [learning_rate, global_step, decay_steps, decay_rate]
+  with ops.op_scope(scope_inputs, name, "ExponentialDecay") as name:
+    base_rate = ops.convert_to_tensor(learning_rate, name="learning_rate")
+    dtype = base_rate.dtype
+    # Everything is cast to the learning rate's dtype before the arithmetic.
+    step = math_ops.cast(global_step, dtype)
+    steps_per_decay = math_ops.cast(decay_steps, dtype)
+    rate = math_ops.cast(decay_rate, dtype)
+    exponent = step / steps_per_decay
+    if staircase:
+      exponent = math_ops.floor(exponent)
+    decay_factor = math_ops.pow(rate, exponent)
+    return math_ops.mul(base_rate, decay_factor, name=name)
diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py
new file mode 100644
index 0000000000..b85d58cae7
--- /dev/null
+++ b/tensorflow/python/training/learning_rate_decay_test.py
@@ -0,0 +1,60 @@
+"""Functional test for learning rate decay."""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import learning_rate_decay
+
+
+class LRDecayTest(test_util.TensorFlowTestCase):
+
+ def testContinuous(self):
+ with self.test_session():
+ step = 5
+ decayed_lr = learning_rate_decay.exponential_decay(0.05, step, 10, 0.96)
+ expected = .05 * 0.96 ** (5.0 / 10.0)
+ self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+
+ def testStaircase(self):
+ with self.test_session():
+ step = state_ops.variable_op([], types.int32)
+ assign_100 = state_ops.assign(step, 100)
+ assign_1 = state_ops.assign(step, 1)
+ assign_2 = state_ops.assign(step, 2)
+ decayed_lr = learning_rate_decay.exponential_decay(.1, step, 3, 0.96,
+ staircase=True)
+ # No change to learning rate
+ assign_1.op.run()
+ self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
+ assign_2.op.run()
+ self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
+ # Decayed learning rate
+ assign_100.op.run()
+ expected = .1 * 0.96 ** (100 / 3)
+ self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+
+ def testVariables(self):
+ with self.test_session():
+ step = variables.Variable(1)
+ assign_1 = step.assign(1)
+ assign_2 = step.assign(2)
+ assign_100 = step.assign(100)
+ decayed_lr = learning_rate_decay.exponential_decay(.1, step, 3, 0.96,
+ staircase=True)
+ variables.initialize_all_variables().run()
+ # No change to learning rate
+ assign_1.op.run()
+ self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
+ assign_2.op.run()
+ self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
+ # Decayed learning rate
+ assign_100.op.run()
+ expected = .1 * 0.96 ** (100 / 3)
+ self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/training/momentum.py b/tensorflow/python/training/momentum.py
new file mode 100644
index 0000000000..fdd434359f
--- /dev/null
+++ b/tensorflow/python/training/momentum.py
@@ -0,0 +1,51 @@
+"""Momentum for TensorFlow."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class MomentumOptimizer(optimizer.Optimizer):
+ """Optimizer that implements the Momentum algorithm.
+
+ @@__init__
+ """
+
+ def __init__(self, learning_rate, momentum,
+ use_locking=False, name="Momentum"):
+ """Construct a new Momentum optimizer.
+
+ Args:
+ learning_rate: A `Tensor` or a floating point value. The learning rate.
+ momentum: A `Tensor` or a floating point value. The momentum.
+ use_locking: If `True` use locks for update operations.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "Momentum".
+ """
+ super(MomentumOptimizer, self).__init__(use_locking, name)
+ self._learning_rate = learning_rate
+ self._momentum = momentum
+
+ def _create_slots(self, var_list):
+ for v in var_list:
+ self._zeros_slot(v, "momentum", self._name)
+
+ def _prepare(self):
+ self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
+ name="learning_rate")
+ self._momentum_tensor = ops.convert_to_tensor(self._momentum,
+ name="momentum")
+
+ def _apply_dense(self, grad, var):
+ mom = self.get_slot(var, "momentum")
+ return training_ops.apply_momentum(
+ var, mom,
+ self._learning_rate_tensor, grad, self._momentum_tensor,
+ use_locking=self._use_locking).op
+
+ def _apply_sparse(self, grad, var):
+ mom = self.get_slot(var, "momentum")
+ return training_ops.sparse_apply_momentum(
+ var, mom,
+ self._learning_rate_tensor, grad.values, grad.indices,
+ self._momentum_tensor, use_locking=self._use_locking).op
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
new file mode 100644
index 0000000000..2cf86d97c9
--- /dev/null
+++ b/tensorflow/python/training/momentum_test.py
@@ -0,0 +1,258 @@
+"""Tests for Momentum."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class MomentumOptimizerTest(tf.test.TestCase):
+
+ def testBasic(self):
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+ mom_opt = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
+ mom_update = mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ self.assertFalse(slot0 in tf.trainable_variables())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ self.assertFalse(slot1 in tf.trainable_variables())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the momentum accumulators were 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([0.1, 0.1]), slot0.eval())
+ self.assertAllClose(np.array([0.01, 0.01]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllClose(np.array([1.0 - (0.1 * 2.0),
+ 2.0 - (0.1 * 2.0)]),
+ var0.eval())
+ self.assertAllClose(np.array([3.0 - (0.01 * 2.0),
+ 4.0 - (0.01 * 2.0)]),
+ var1.eval())
+ # Step 2: the momentum accumulators contain the previous update.
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+ slot0.eval())
+ self.assertAllClose(np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllClose(
+ np.array([1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)]),
+ var0.eval())
+ self.assertAllClose(np.array([2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
+ 3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]),
+ var1.eval())
+
+ def testFloat64(self):
+ with self.test_session():
+ opt = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
+
+ # compute_gradients.
+ values = [1.0, 3.0]
+ good_vars = [tf.Variable([v]) for v in values]
+ bad_loss = tf.constant(2.0, tf.float64, name="bad_loss")
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_loss.*expected.*float32",
+ opt.compute_gradients, bad_loss, good_vars)
+ bad_vars = [
+ tf.Variable(np.array([v], np.float64), name="bad_var")
+ for v in values]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+ opt.compute_gradients, tf.cast(bad_vars[0] + bad_vars[1], tf.float32),
+ bad_vars)
+ opt.compute_gradients(good_vars[0] + good_vars[1], good_vars)
+
+ # apply_gradients.
+ bad_grads = [
+ tf.constant([0.1], dtype=np.float64, name="bad_grad"),
+ tf.constant([0.01])]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_grad.*expected.*float32",
+ opt.apply_gradients, zip(bad_grads, good_vars))
+ good_grads = [tf.constant([0.01]), tf.constant([0.02])]
+ self.assertRaisesRegexp(
+ ValueError, r"Invalid type.*float64.*bad_var.*expected.*float32",
+ opt.apply_gradients, zip(good_grads, bad_vars))
+ opt.apply_gradients(zip(good_grads, good_vars))
+
+ def _dbParamsMom01(self):
+ """Return dist-belief momentum values.
+
+ The return values were generated by the dist-belief momentum unittest,
+ running with a learning rate of 0.1 and a momentum of 0.1.
+
+ These values record how a parameter vector of size 10, initialized with 0.0,
+ gets updated with 10 consecutive momentum steps. It uses random gradients.
+
+ Returns:
+ db_grad: The gradients to apply
+ db_out: The parameters after the momentum update.
+ """
+ db_grad = [[]] * 10
+ db_out = [[]] * 10
+ # pylint: disable=line-too-long
+ db_grad[0] = [0.00096264342, 0.17914793, 0.93945462, 0.41396621, 0.53037018, 0.93197989, 0.78648776, 0.50036013, 0.55345792, 0.96722615]
+ db_out[0] = [-9.6264346e-05, -0.017914793, -0.093945466, -0.041396622, -0.053037018, -0.093197994, -0.078648776, -0.050036013, -0.055345792, -0.096722618]
+ db_grad[1] = [0.17075552, 0.88821375, 0.20873757, 0.25236958, 0.57578111, 0.15312378, 0.5513742, 0.94687688, 0.16012503, 0.22159521]
+ db_out[1] = [-0.017181443, -0.10852765, -0.12421377, -0.070773244, -0.11591884, -0.11783017, -0.14165108, -0.14972731, -0.076892875, -0.1285544]
+ db_grad[2] = [0.35077485, 0.47304362, 0.44412705, 0.44368884, 0.078527533, 0.81223965, 0.31168157, 0.43203235, 0.16792089, 0.24644311]
+ db_out[2] = [-0.053967446, -0.1648933, -0.1716533, -0.1180798, -0.13005978, -0.20151734, -0.17911947, -0.20289968, -0.095839672, -0.15638189]
+ db_grad[3] = [0.9694621, 0.75035888, 0.28171822, 0.83813518, 0.53807181, 0.3728098, 0.81454384, 0.03848977, 0.89759839, 0.93665648]
+ db_out[3] = [-0.15459226, -0.24556576, -0.20456907, -0.20662397, -0.18528105, -0.24716705, -0.2643207, -0.21206589, -0.18749419, -0.2528303]
+ db_grad[4] = [0.38578293, 0.8536852, 0.88722926, 0.66276771, 0.13678469, 0.94036359, 0.69107032, 0.81897682, 0.5433259, 0.67860287]
+ db_out[4] = [-0.20323303, -0.33900154, -0.29658359, -0.28175515, -0.20448165, -0.34576839, -0.34194785, -0.29488021, -0.25099224, -0.33033544]
+ db_grad[5] = [0.27885768, 0.76100707, 0.24625534, 0.81354135, 0.18959245, 0.48038563, 0.84163809, 0.41172323, 0.83259648, 0.44941229]
+ db_out[5] = [-0.23598288, -0.42444581, -0.33041057, -0.3706224, -0.22536094, -0.40366709, -0.43387437, -0.34433398, -0.34060168, -0.38302717]
+ db_grad[6] = [0.27233034, 0.056316052, 0.5039115, 0.24105175, 0.35697976, 0.75913221, 0.73577434, 0.16014607, 0.57500273, 0.071136251]
+ db_out[6] = [-0.26649091, -0.43862185, -0.38418442, -0.40361428, -0.26314685, -0.48537019, -0.51664448, -0.36529395, -0.40706289, -0.39540997]
+ db_grad[7] = [0.58697265, 0.2494842, 0.08106143, 0.39954534, 0.15892942, 0.12683646, 0.74053431, 0.16033, 0.66625422, 0.73515922]
+ db_out[7] = [-0.32823896, -0.46498787, -0.39766794, -0.446868, -0.28281838, -0.50622416, -0.59897494, -0.38342294, -0.48033443, -0.47016418]
+ db_grad[8] = [0.8215279, 0.41994119, 0.95172721, 0.68000203, 0.79439718, 0.43384039, 0.55561525, 0.22567581, 0.93331909, 0.29438227]
+ db_out[8] = [-0.41656655, -0.50961858, -0.49418902, -0.51919359, -0.36422527, -0.55169362, -0.6627695, -0.40780342, -0.58099347, -0.50707781]
+ db_grad[9] = [0.68297005, 0.67758518, 0.1748755, 0.13266537, 0.70697063, 0.055731893, 0.68593478, 0.50580865, 0.12602448, 0.093537711]
+ db_out[9] = [-0.49369633, -0.58184016, -0.52132869, -0.5396927, -0.44306302, -0.56181377, -0.73774242, -0.46082234, -0.60366184, -0.52012295]
+ # pylint: enable=line-too-long
+ return db_grad, db_out
+
+ def testLikeDistBeliefMom01(self):
+ with self.test_session():
+ db_grad, db_out = self._dbParamsMom01()
+ num_samples = len(db_grad)
+ var0 = tf.Variable([0.0] * num_samples)
+ grads0 = tf.constant([0.0] * num_samples)
+ mom_opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.1)
+ mom_update = mom_opt.apply_gradients(zip([grads0], [var0]))
+ tf.initialize_all_variables().run()
+ for i in xrange(num_samples):
+ mom_update.run(feed_dict={grads0: db_grad[i]})
+ self.assertAllClose(np.array(db_out[i]), var0.eval())
+
+ def testSparse(self):
+ with self.test_session():
+ var0 = tf.Variable(tf.zeros([4, 2]))
+ var1 = tf.Variable(
+ tf.constant(1.0, tf.float32, [4, 2]))
+ grads0 = tf.IndexedSlices(tf.constant([[.1, .1]]),
+ tf.constant([1]),
+ tf.constant([4, 2]))
+ grads1 = tf.IndexedSlices(tf.constant([[.01, .01], [.01, .01]]),
+ tf.constant([2, 3]),
+ tf.constant([4, 2]))
+ mom_opt = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
+ mom_update = mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([0, 0], var0.eval()[0])
+ self.assertAllClose([0, 0], var0.eval()[1])
+ self.assertAllClose([1, 1], var1.eval()[2])
+
+ # Step 1: the momentum accumulators are 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([0, 0]), slot0.eval()[0])
+ self.assertAllClose(np.array([.1, .1]), slot0.eval()[1])
+ self.assertAllClose(np.array([.01, .01]), slot1.eval()[2])
+ # Check that the parameters have been updated.
+ self.assertAllClose(np.array([0, 0]), var0.eval()[0])
+ self.assertAllClose(np.array([- (0.1 * 2.0),
+ - (0.1 * 2.0)]),
+ var0.eval()[1])
+ self.assertAllClose(np.array([1.0 - (0.01 * 2.0),
+ 1.0 - (0.01 * 2.0)]),
+ var1.eval()[2])
+ # Step 2: the momentum accumulators contain the previous update.
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([0, 0]), slot0.eval()[0])
+ self.assertAllClose(np.array([(0.9 * 0.1 + 0.1),
+ (0.9 * 0.1 + 0.1)]),
+ slot0.eval()[1])
+ self.assertAllClose(np.array([(0.9 * 0.01 + 0.01),
+ (0.9 * 0.01 + 0.01)]),
+ slot1.eval()[2])
+ # Check that the parameters have been updated.
+ self.assertAllClose(np.array([0, 0]), var0.eval()[0])
+ self.assertAllClose(
+ np.array([- (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)]),
+ var0.eval()[1])
+ self.assertAllClose(np.array([0.98 - ((0.9 * 0.01 + 0.01) * 2.0),
+ 0.98 - ((0.9 * 0.01 + 0.01) * 2.0)]),
+ var1.eval()[2])
+
+ def testSharing(self):
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+ mom_opt = tf.train.MomentumOptimizer(learning_rate=2.0, momentum=0.9)
+ mom_update1 = mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ mom_update2 = mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the momentum accumulators were 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update1.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([0.1, 0.1]), slot0.eval())
+ self.assertAllClose(np.array([0.01, 0.01]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllClose(np.array([1.0 - (0.1 * 2.0),
+ 2.0 - (0.1 * 2.0)]),
+ var0.eval())
+ self.assertAllClose(np.array([3.0 - (0.01 * 2.0),
+ 4.0 - (0.01 * 2.0)]),
+ var1.eval())
+ # Step 2: the second momentum accumulators contain the previous update.
+ mom_update2.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+ slot0.eval())
+ self.assertAllClose(np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllClose(
+ np.array([1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)]),
+ var0.eval())
+ self.assertAllClose(np.array([2.98 - ((0.9 * 0.01 + 0.01) * 2.0),
+ 3.98 - ((0.9 * 0.01 + 0.01) * 2.0)]),
+ var1.eval())
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
new file mode 100644
index 0000000000..becc71dfa2
--- /dev/null
+++ b/tensorflow/python/training/moving_averages.py
@@ -0,0 +1,247 @@
+"""Maintain moving averages of parameters."""
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+
+
+# TODO(mdevin): switch to variables.Variable.
+def assign_moving_average(variable, value, decay, name=None):
+ """Compute the moving average of a variable.
+
+ The moving average of 'variable' updated with 'value' is:
+ variable * decay + value * (1 - decay)
+
+ The returned Operation sets 'variable' to the newly computed moving average.
+
+ The new value of 'variable' can be set with the 'AssignSub' op as:
+ variable -= (1 - decay) * (variable - value)
+
+ Args:
+ variable: A Variable.
+ value: A tensor with the same shape as 'variable'
+ decay: A float Tensor or float value. The moving average decay.
+ name: Optional name of the returned operation.
+
+ Returns:
+ An Operation that updates 'variable' with the newly computed
+ moving average.
+ """
+ with ops.op_scope([variable, value, decay], name, "AssignMovingAvg") as name:
+ with ops.device(variable.device):
+ decay = ops.convert_to_tensor(1.0 - decay, name="decay")
+ if decay.dtype != variable.dtype.base_dtype:
+ decay = math_ops.cast(decay, variable.dtype.base_dtype)
+ return state_ops.assign_sub(variable, (variable - value) * decay,
+ name=name)
+
+
+class ExponentialMovingAverage(object):
+ """Maintains moving averages of variables by employing an exponential decay.
+
+ When training a model, it is often beneficial to maintain moving averages of
+ the trained parameters. Evaluations that use averaged parameters sometimes
+ produce significantly better results than the final trained values.
+
+ The `apply()` method adds shadow copies of trained variables and adds ops that
+ maintain a moving average of the trained variables in their shadow copies.
+ It is used when building the training model. The ops that maintain moving
+ averages are typically run after each training step.
+ The `average()` and `average_name()` methods give access to the shadow
+ variables and their names. They are useful when building an evaluation
+ model, or when restoring a model from a checkpoint file. They help use the
+ moving averages in place of the last trained values for evaluations.
+
+ The moving averages are computed using exponential decay. You specify the
+ decay value when creating the `ExponentialMovingAverage` object. The shadow
+ variables are initialized with the same initial values as the trained
+ variables. When you run the ops to maintain the moving averages, each
+ shadow variable is updated with the formula:
+
+ `shadow_variable -= (1 - decay) * (shadow_variable - variable)`
+
+ This is mathematically equivalent to the classic formula below, but the use
+ of an `assign_sub` op (the `"-="` in the formula) allows concurrent lockless
+ updates to the variables:
+
+ `shadow_variable = decay * shadow_variable + (1 - decay) * variable`
+
+ Reasonable values for `decay` are close to 1.0, typically in the
+ multiple-nines range: 0.999, 0.9999, etc.
+
+ Example usage when creating a training model:
+
+ ```python
+ # Create variables.
+ var0 = tf.Variable(...)
+ var1 = tf.Variable(...)
+ # ... use the variables to build a training model...
+ ...
+ # Create an op that applies the optimizer. This is what we usually
+ # would use as a training op.
+ opt_op = opt.minimize(my_loss, [var0, var1])
+
+ # Create an ExponentialMovingAverage object
+ ema = tf.train.ExponentialMovingAverage(decay=0.9999)
+
+ # Create the shadow variables, and add ops to maintain moving averages
+ # of var0 and var1.
+ maintain_averages_op = ema.apply([var0, var1])
+
+ # Create an op that will update the moving averages after each training
+ # step. This is what we will use in place of the usual training op.
+ with tf.control_dependencies([opt_op]):
+ training_op = tf.group(maintain_averages_op)
+
+ ...train the model by running training_op...
+ ```
+
+ There are two ways to use the moving averages for evaluations:
+
+ * Build a model that uses the shadow variables instead of the variables.
+ For this, use the `average()` method which returns the shadow variable
+ for a given variable.
+ * Build a model normally but load the checkpoint files to evaluate by using
+ the shadow variable names. For this use the `average_name()` method. See
+ the [Saver class](train.md#Saver) for more information on restoring saved
+ variables.
+
+ Example of restoring the shadow variable values:
+
+ ```python
+ # Create a Saver that loads variables from their saved shadow values.
+ shadow_var0_name = ema.average_name(var0)
+ shadow_var1_name = ema.average_name(var1)
+ saver = tf.train.Saver({shadow_var0_name: var0, shadow_var1_name: var1})
+ saver.restore(...checkpoint filename...)
+ # var0 and var1 now hold the moving average values
+ ```
+
+ @@__init__
+ @@apply
+ @@average_name
+ @@average
+ """
+
+ def __init__(self, decay, num_updates=None,
+ name="ExponentialMovingAverage"):
+ """Creates a new ExponentialMovingAverage object.
+
+ The `apply()` method has to be called to create shadow variables and add
+ ops to maintain moving averages.
+
+ The optional `num_updates` parameter allows one to tweak the decay rate
+ dynamically. It is typical to pass the count of training steps, usually
+ kept in a variable that is incremented at each step, in which case the
+ decay rate is lower at the start of training. This makes moving averages
+ move faster. If passed, the actual decay rate used is:
+
+ `min(decay, (1 + num_updates) / (10 + num_updates))`
+
+ Args:
+ decay: Float. The decay to use.
+ num_updates: Optional count of number of updates applied to variables.
+ name: String. Optional prefix name to use for the name of ops added in
+ `apply()`.
+ """
+ self._decay = decay
+ self._num_updates = num_updates
+ self._name = name
+ self._averages = {}
+
+ def apply(self, var_list=None):
+ """Maintains moving averages of variables.
+
+ `var_list` must be a list of `Variable` or `Tensor` objects. This method
+ creates shadow variables for all elements of `var_list`. Shadow variables
+ for `Variable` objects are initialized to the variable's initial value.
+ For `Tensor` objects, the shadow variables are initialized to 0.
+
+ Shadow variables are created with `trainable=False` and added to the
+ `GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to
+ `tf.all_variables()`.
+
+ Returns an op that updates all shadow variables as described above.
+
+ Note that `apply()` can be called multiple times with different lists of
+ variables.
+
+ Args:
+ var_list: A list of Variable or Tensor objects. The variables
+ and Tensors must be of types float32 or float64.
+
+ Returns:
+ An Operation that updates the moving averages.
+
+ Raises:
+ TypeError: If the arguments are not all float32 or float64.
+ ValueError: If the moving average of one of the variables is already
+ being computed.
+ """
+ # TODO(mdevin): op_scope
+ if var_list is None:
+ var_list = variables.trainable_variables()
+ for var in var_list:
+ if var.dtype.base_dtype not in [types.float32, types.float64]:
+ raise TypeError("The variables must be float or double: %s" % var)
+ if var in self._averages:
+ raise ValueError("Moving average already computed for: %s" % var)
+ with ops.name_scope(var.op.name + "/" + self._name) as scope:
+ with ops.device(var.device):
+ if isinstance(var, variables.Variable):
+ initial_value = var.initialized_value()
+ else:
+ initial_value = array_ops.zeros(var.get_shape().as_list())
+ avg = variables.Variable(initial_value, name=scope, trainable=False)
+ self._averages[var] = avg
+ with ops.name_scope(self._name) as scope:
+ decay = ops.convert_to_tensor(self._decay, name="decay")
+ if self._num_updates is not None:
+ num_updates = math_ops.cast(self._num_updates, types.float32,
+ name="num_updates")
+ decay = math_ops.minimum(decay,
+ (1.0 + num_updates) / (10.0 + num_updates))
+ updates = []
+ for var in var_list:
+ updates.append(assign_moving_average(self._averages[var], var, decay))
+ return control_flow_ops.group(*updates, name=scope)
+
+ def average(self, var):
+ """Returns the `Variable` holding the average of `var`.
+
+ Args:
+ var: A `Variable` object.
+
+ Returns:
+ A `Variable` object or `None` if the moving average of `var`
+ is not maintained.
+ """
+ return self._averages.get(var, None)
+
+ def average_name(self, var):
+ """Returns the name of the `Variable` holding the average for `var`.
+
+ The typical scenario for `ExponentialMovingAverage` is to compute moving
+ averages of variables during training, and restore the variables from the
+ computed moving averages during evaluations.
+
+ To restore variables, you have to know the name of the shadow variables.
+ That name and the original variable can then be passed to a `Saver()` object
+ to restore the variable from the moving average value with:
+ `saver = tf.train.Saver({ema.average_name(var): var})`
+
+ `average_name()` can be called whether or not `apply()` has been called.
+
+ Args:
+ var: A `Variable` object.
+
+ Returns:
+ A string: the name of the variable that will be used or was used
+ by the `ExponentialMovingAverage` class to hold the moving average of
+ `var`.
+ """
+ return var.op.name + "/" + self._name
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
new file mode 100644
index 0000000000..73ee94b400
--- /dev/null
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -0,0 +1,130 @@
+"""Functional test for moving_averages.py."""
+import tensorflow.python.platform
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import moving_averages
+
+
+class MovingAveragesTest(test_util.TensorFlowTestCase):
+
+ def testAssignMovingAverage(self):
+ with self.test_session():
+ var = variables.Variable([10.0, 11.0])
+ val = constant_op.constant([1.0, 2.0], types.float32)
+ decay = 0.25
+ assign = moving_averages.assign_moving_average(var, val, decay)
+ variables.initialize_all_variables().run()
+ self.assertAllClose([10.0, 11.0], var.eval())
+ assign.op.run()
+ self.assertAllClose([10.0 * 0.25 + 1.0 * (1.0 - 0.25),
+ 11.0 * 0.25 + 2.0 * (1.0 - 0.25)],
+ var.eval())
+
+def _Repeat(value, dim):
+ if dim == 1:
+ return value
+ return [value for _ in xrange(dim)]
+
+class ExponentialMovingAverageTest(test_util.TensorFlowTestCase):
+
+ def _CheckDecay(self, ema, actual_decay, dim):
+ tens = _Repeat(10.0, dim)
+ thirties = _Repeat(30.0, dim)
+ var0 = variables.Variable(tens, name="v0")
+ var1 = variables.Variable(thirties, name="v1")
+ variables.initialize_all_variables().run()
+ # Note that tensor2 is not a Variable but just a plain Tensor resulting
+ # from the sum operation.
+ tensor2 = var0 + var1
+ update = ema.apply([var0, var1, tensor2])
+ avg0 = ema.average(var0)
+ avg1 = ema.average(var1)
+ avg2 = ema.average(tensor2)
+
+ self.assertFalse(avg0 in variables.trainable_variables())
+ self.assertFalse(avg1 in variables.trainable_variables())
+ self.assertFalse(avg2 in variables.trainable_variables())
+ variables.initialize_all_variables().run()
+
+ self.assertEqual("v0/ExponentialMovingAverage:0", avg0.name)
+ self.assertEqual("v1/ExponentialMovingAverage:0", avg1.name)
+ self.assertEqual("add/ExponentialMovingAverage:0", avg2.name)
+
+ # Check initial values.
+ self.assertAllClose(tens, var0.eval())
+ self.assertAllClose(thirties, var1.eval())
+ self.assertAllClose(_Repeat(10.0 + 30.0, dim), tensor2.eval())
+
+ # Check that averages are initialized correctly.
+ self.assertAllClose(tens, avg0.eval())
+ self.assertAllClose(thirties, avg1.eval())
+ # Note that averages of Tensor's initialize to zeros_like since no value
+ # of the Tensor is known because the Op has not been run (yet).
+ self.assertAllClose(_Repeat(0.0, dim), avg2.eval())
+
+ # Update the averages and check.
+ update.run()
+ dk = actual_decay
+
+ expected = _Repeat(10.0 * dk + 10.0 * (1 - dk), dim)
+ self.assertAllClose(expected, avg0.eval())
+ expected = _Repeat(30.0 * dk + 30.0 * (1 - dk), dim)
+ self.assertAllClose(expected, avg1.eval())
+ expected = _Repeat(0.0 * dk + (10.0 + 30.0) * (1 - dk), dim)
+ self.assertAllClose(expected, avg2.eval())
+
+ # Again, update the averages and check.
+ update.run()
+ expected = _Repeat((10.0 * dk + 10.0 * (1 - dk)) * dk + 10.0 * (1 - dk),
+ dim)
+ self.assertAllClose(expected, avg0.eval())
+ expected = _Repeat((30.0 * dk + 30.0 * (1 - dk)) * dk + 30.0 * (1 - dk),
+ dim)
+ self.assertAllClose(expected, avg1.eval())
+ expected = _Repeat(((0.0 * dk + (10.0 + 30.0) * (1 - dk)) * dk +
+ (10.0 + 30.0) * (1 - dk)),
+ dim)
+ self.assertAllClose(expected, avg2.eval())
+
+ def testAverageVariablesNoNumUpdates_Scalar(self):
+ with self.test_session():
+ ema = moving_averages.ExponentialMovingAverage(0.25)
+ self._CheckDecay(ema, actual_decay=0.25, dim=1)
+
+ def testAverageVariablesNoNumUpdates_Vector(self):
+ with self.test_session():
+ ema = moving_averages.ExponentialMovingAverage(0.25)
+ self._CheckDecay(ema, actual_decay=0.25, dim=5)
+
+ def testAverageVariablesNumUpdates_Scalar(self):
+ with self.test_session():
+ # With num_updates 1, the decay applied is 0.1818
+ ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
+ self._CheckDecay(ema, actual_decay=0.181818, dim=1)
+
+ def testAverageVariablesNumUpdates_Vector(self):
+ with self.test_session():
+ # With num_updates 1, the decay applied is 0.1818
+ ema = moving_averages.ExponentialMovingAverage(0.25, num_updates=1)
+ self._CheckDecay(ema, actual_decay=0.181818, dim=5)
+
+ def testAverageVariablesNames(self):
+ v0 = variables.Variable(10.0, name="v0")
+ v1 = variables.Variable(30.0, name="v1")
+ tensor2 = v0 + v1
+ ema = moving_averages.ExponentialMovingAverage(0.25, name="foo_avg")
+ self.assertEqual("v0/foo_avg", ema.average_name(v0))
+ self.assertEqual("v1/foo_avg", ema.average_name(v1))
+ self.assertEqual("add/foo_avg", ema.average_name(tensor2))
+ ema.apply([v0, v1, tensor2])
+ self.assertEqual(ema.average_name(v0), ema.average(v0).op.name)
+ self.assertEqual(ema.average_name(v1), ema.average(v1).op.name)
+ self.assertEqual(ema.average_name(tensor2), ema.average(tensor2).op.name)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
new file mode 100644
index 0000000000..1186117169
--- /dev/null
+++ b/tensorflow/python/training/optimizer.py
@@ -0,0 +1,426 @@
+"""Base class for optimizers."""
+# pylint: disable=g-bad-name
+import types
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types as tf_types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+
+
+class Optimizer(object):
+ """Base class for optimizers.
+
+ This class defines the API to add Ops to train a model. You never use this
+ class directly, but instead instantiate one of its subclasses such as
+ `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
+
+ ### Usage
+
+ ```
+ # Create an optimizer with the desired parameters.
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+ # Add Ops to the graph to minimize a cost by updating a list of variables.
+ # "cost" is a Tensor, and the list of variables contains variables.Variable
+ # objects.
+ opt_op = opt.minimize(cost, <list of variables>)
+ ```
+
+ In the training program you will just have to run the returned Op.
+
+ ```
+ # Execute opt_op to do one step of training:
+ opt_op.run()
+ ```
+
+ ### Processing gradients before applying them.
+
+ Calling `minimize()` takes care of both computing the gradients and
+ applying them to the variables. If you want to process the gradients
+ before applying them you can instead use the optimizer in three steps:
+
+ 1. Compute the gradients with `compute_gradients()`.
+ 2. Process the gradients as you wish.
+ 3. Apply the processed gradients with `apply_gradients()`.
+
+ Example:
+
+ ```
+ # Create an optimizer.
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+
+ # Compute the gradients for a list of variables.
+ grads_and_vars = opt.compute_gradients(loss, <list of variables>)
+
+ # grads_and_vars is a list of tuples (gradient, variable). Do whatever you
+ # need to the 'gradient' part, for example cap them, etc.
+ capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
+
+ # Ask the optimizer to apply the capped gradients.
+ opt.apply_gradients(capped_grads_and_vars)
+ ```
+
+ @@__init__
+
+ @@minimize
+ @@compute_gradients
+ @@apply_gradients
+
+ ### Gating Gradients
+
+ Both `minimize()` and `compute_gradients()` accept a `gate_gradients` argument
+ that controls the degree of parallelism during the application of the
+ gradients.
+
+ The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
+
+ <b>GATE_NONE</b>: Compute and apply gradients in parallel. This provides the
+ maximum parallelism in execution, at the cost of some non-reproducibility in
+ the results. For example the two gradients of MatMul depend on the input
+ values: With `GATE_NONE` one of the gradients could be applied to one of the
+ inputs _before_ the other gradient is computed resulting in non-reproducible
+ results.
+
+ <b>GATE_OP</b>: For each Op, make sure all gradients are computed before they
+ are used. This prevents race conditions for Ops that generate gradients for
+ multiple inputs where the gradients depend on the inputs.
+
+ <b>GATE_GRAPH</b>: Make sure all gradients for all variables are computed
+ before any one of them is used. This provides the least parallelism but can
+ be useful if you want to process all gradients before applying any of them.
+
+ ### Slots
+
+ Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
+ allocate and manage additional variables associated with the variables to
+ train. These are called <i>Slots</i>. Slots have names and you can ask the
+ optimizer for the names of the slots that it uses. Once you have a slot name
+ you can ask the optimizer for the variable it created to hold the slot value.
+
+ This can be useful if you want to debug a training algorithm, report stats
+ about the slots, etc.
+
+ @@get_slot_names
+ @@get_slot
+ """
+
+ # Values for gate_gradients.
+ GATE_NONE = 0
+ GATE_OP = 1
+ GATE_GRAPH = 2
+
+ def __init__(self, use_locking, name):
+ """Create a new Optimizer.
+
+ This must be called by the constructors of subclasses.
+
+ Args:
+ use_locking: Bool. If True, use locks to prevent concurrent updates
+ to variables.
+ name: A non-empty string. The name to use for accumulators created
+ for the optimizer.
+
+ Raises:
+ ValueError: if name is malformed.
+ """
+ if not name:
+ raise ValueError("Must specify the optimizer name")
+ self._use_locking = use_locking
+ self._name = name
+ # Dictionary of slots.
+ # {slot_name : { variable_to_train: slot_for_the_variable, ...}, ... }
+ self._slots = {}
+
+ def minimize(self, loss, global_step=None, var_list=None,
+ gate_gradients=GATE_OP, name=None):
+ """Add operations to minimize 'loss' by updating 'var_list'.
+
+ This method simply combines calls to compute_gradients() and
+ apply_gradients(). If you want to process the gradient before applying them
+ call compute_gradients() and apply_gradients() explicitly instead of using
+ this function.
+
+ Args:
+ loss: A Tensor containing the value to minimize.
+ global_step: Optional Variable to increment by one after the
+ variables have been updated.
+ var_list: Optional list of variables.Variable to update to minimize
+ 'loss'. Defaults to the list of variables collected in the graph
+ under the key GraphKeys.TRAINABLE_VARIABLES.
+ gate_gradients: How to gate the computation of gradients. Can be
+ GATE_NONE, GATE_OP, or GATE_GRAPH.
+ name: Optional name for the returned operation.
+
+ Returns:
+ An Operation that updates the variables in 'var_list'. If 'global_step'
+ was not None, that operation also increments global_step.
+
+ Raises:
+ ValueError: if some of the variables are not variables.Variable objects.
+ """
+ grads_and_vars = self.compute_gradients(loss, var_list=var_list,
+ gate_gradients=gate_gradients)
+ return self.apply_gradients(grads_and_vars, global_step=global_step,
+ name=name)
+
+ def compute_gradients(self, loss, var_list=None, gate_gradients=GATE_OP):
+ """Compute gradients of "loss" for the variables in "var_list".
+
+ This is the first part of minimize(). It returns a list
+ of (gradient, variable) pairs where "gradient" is the gradient
+ for "variable". Note that "gradient" can be a Tensor, an
+ IndexedSlices, or None if there is no gradient for the
+ given variable.
+
+ Args:
+ loss: A Tensor containing the value to minimize.
+ var_list: Optional list of variables.Variable to update to minimize
+ "loss". Defaults to the list of variables collected in the graph
+ under the key GraphKeys.TRAINABLE_VARIABLES.
+ gate_gradients: How to gate the computation of gradients. Can be
+ GATE_NONE, GATE_OP, or GATE_GRAPH.
+
+ Returns:
+ A list of (gradient, variable) pairs.
+
+ Raises:
+ TypeError: If var_list contains anything else than variables.Variable.
+ ValueError: If some arguments are invalid.
+ """
+ if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
+ Optimizer.GATE_GRAPH]:
+ raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
+ "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
+ gate_gradients)
+ self._assert_valid_dtypes([loss])
+ if var_list is None:
+ var_list = variables.trainable_variables()
+ for var in var_list:
+ if not isinstance(var, variables.Variable):
+ raise TypeError("Argument is not a variables.Variable: %s" % var)
+ grads = gradients.gradients(
+ loss, var_list, gate_gradients=(gate_gradients == Optimizer.GATE_OP))
+ if gate_gradients == Optimizer.GATE_GRAPH:
+ grads = control_flow_ops.tuple(grads)
+ grads_and_vars = zip(grads, var_list)
+ self._assert_valid_dtypes([v for g, v in grads_and_vars if g is not None])
+ return grads_and_vars
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ """Apply gradients to variables.
+
+ This is the second part of minimize(). It returns an Operation that
+ applies gradients.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ compute_gradients().
+ global_step: Optional Variable to increment by one after the
+ variables have been updated.
+ name: Optional name for the returned operation. Defaults to the
+ name passed to the Optimizer constructor.
+
+ Returns:
+ An Operation that applies the specified gradients. If 'global_step'
+ was not None, that operation also increments global_step.
+
+ Raises:
+ TypeError: if grads_and_vars is malformed.
+ """
+ # This is a default implementation of apply_gradients() that can be shared
+ # by most optimizers. It relies on the subclass implementing the following
+ # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
+ for g, v in grads_and_vars:
+ if not isinstance(g, (ops.Tensor, ops.IndexedSlices, types.NoneType)):
+ raise TypeError(
+ "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
+ if not isinstance(v, variables.Variable):
+ raise TypeError(
+ "Variable must be a variables.Variable: %s" % v)
+ if g is not None:
+ self._assert_valid_dtypes([g, v])
+ self._create_slots([v for g, v in grads_and_vars if g is not None])
+ update_ops = []
+ with ops.op_scope([], name, self._name) as name:
+ self._prepare()
+ for grad, var in grads_and_vars:
+ if not grad:
+ continue
+ with ops.name_scope("update_" + var.op.name), ops.device(var.device):
+ if isinstance(grad, ops.Tensor):
+ update_ops.append(self._apply_dense(grad, var))
+ else:
+ update_ops.append(self._apply_sparse(grad, var))
+ if global_step is None:
+ return self._finish(update_ops, name)
+ else:
+ with ops.control_dependencies([self._finish(update_ops, "update")]):
+ with ops.device(global_step.device):
+ return state_ops.assign_add(global_step, 1, name=name).op
+
+ def get_slot(self, var, name):
+ """Return a slot named "name" created for "var" by the Optimizer.
+
+ Some Optimizer subclasses use additional variables. For example
+ Momentum and Adagrad use variables to accumulate updates. This method
+ gives access to these Variables if for some reason you need them.
+
+ Use get_slot_names() to get the list of slot names created by the Optimizer.
+
+ Args:
+ var: A variable passed to minimize() or apply_gradients().
+ name: A string.
+
+ Returns:
+ The Variable for the slot if it was created, None otherwise.
+ """
+ named_slots = self._slots.get(name, None)
+ if not named_slots:
+ return None
+ return named_slots.get(var, None)
+
+ def get_slot_names(self):
+ """Return a list of the names of slots created by the Optimizer.
+
+ See get_slot().
+
+ Returns:
+ A list of strings.
+ """
+ return sorted(self._slots.keys())
+
+ def _assert_valid_dtypes(self, tensors):
+ """Asserts tensors are all valid types (see _valid_dtypes).
+
+ Args:
+ tensors: tensors to check.
+ Raises:
+ ValueError: if any tensor is not a valid type.
+ """
+ valid_dtypes = self._valid_dtypes()
+ for t in tensors:
+ dtype = t.dtype.base_dtype
+ if dtype not in valid_dtypes:
+ raise ValueError(
+ "Invalid type %s for %s, expected: %s." % (
+ dtype, t.name, [v for v in valid_dtypes]))
+
+ # --------------
+ # Methods to be implemented by subclasses if they want to use the
+ # inherited implementation of apply_gradients() or compute_gradients().
+ # --------------
+ def _valid_dtypes(self):
+ """Valid types for loss, variables and gradients.
+
+ Defaults to float32. Subclasses should override to allow other types.
+
+ Returns:
+ Valid types for loss, variables and gradients.
+ """
+ return set([tf_types.float32])
+
+ def _create_slots(self, var_list):
+ """Create all slots needed by the variables.
+
+ Args:
+ var_list: A list of variables.Variable.
+ """
+ # No slots needed by default
+ pass
+
+ def _prepare(self):
+ """Create all needed tensors before applying gradients.
+
+ This is called with the name_scope using the "name" that
+ users have chosen for the application of gradients.
+ """
+ pass
+
+ def _apply_dense(self, grad, var):
+ """Add ops to apply dense gradients to "var".
+
+ Args:
+ grad: A Tensor.
+ var: A variables.Variable.
+
+ Returns:
+ An Operation.
+ """
+ raise NotImplementedError()
+
+ def _apply_sparse(self, grad, var):
+ """Add ops to apply sparse gradients to "var".
+
+ Args:
+ grad: IndexedSlices.
+ var: A variables.Variable.
+
+ Returns:
+ An Operation.
+ """
+ raise NotImplementedError()
+
+ def _finish(self, update_ops, name_scope):
+ """Do what is needed to finish the update.
+
+ This is called with the name_scope using the "name" that
+ users have chosen for the application of gradients.
+
+ Args:
+ update_ops: List of Operations to update variables. This list contains
+ the values returned by the _apply_dense() and _apply_sparse() calls.
+ name_scope: string. Name to use for the returned operation.
+
+ Returns:
+ The operation to apply updates.
+ """
+ return control_flow_ops.group(*update_ops, name=name_scope)
+
+ # --------------
+ # Utility methods for subclasses.
+ # --------------
+
+ def _get_or_make_slot(self, var, val, slot_name, op_name):
+ """Find or create a slot for a variable.
+
+ Args:
+ var: A variables.Variable.
+ val: A Tensor. The initial value of the slot.
+ slot_name: Name for the slot.
+ op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+
+ Returns:
+ A variables.Variable.
+ """
+ named_slots = self._slots.get(slot_name, None)
+ if named_slots is None:
+ named_slots = {}
+ self._slots[slot_name] = named_slots
+ slot = named_slots.get(var, None)
+ if slot is None:
+ # Scope the slot name in the namespace of the Variable and
+ # create the slot on the same device as the variable.
+ with ops.name_scope(var.op.name + "/" + op_name) as scope:
+ with ops.device(var.device):
+ slot = variables.Variable(val, name=scope, trainable=False)
+ named_slots[var] = slot
+ return slot
+
+ def _zeros_slot(self, var, slot_name, op_name):
+ """Find or create a slot initialized with 0.0.
+
+ Args:
+ var: A variables.Variable.
+ slot_name: Name for the slot.
+ op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+
+ Returns:
+ A variables.Variable.
+ """
+ val = array_ops.zeros(var.get_shape().as_list(), dtype=var.dtype)
+ return self._get_or_make_slot(var, val, slot_name, op_name)
diff --git a/tensorflow/python/training/queue_runner.py b/tensorflow/python/training/queue_runner.py
new file mode 100644
index 0000000000..fcf9927c79
--- /dev/null
+++ b/tensorflow/python/training/queue_runner.py
@@ -0,0 +1,233 @@
+"""Create threads to run multiple enqueue ops."""
+import threading
+
+import tensorflow.python.platform
+
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import logging
+
+
+class QueueRunner(object):
+ """Holds a list of enqueue operations for a queue, each to be run in a thread.
+
+ Queues are a convenient TensorFlow mechanism to compute tensors
+ asynchronously using multiple threads. For example in the canonical 'Input
+ Reader' setup one set of threads generates filenames in a queue; a second set
+ of threads reads records from the files, processes them, and enqueues tensors
+ on a second queue; a third set of threads dequeues these input records to
+ construct batches and runs them through training operations.
+
+ There are several delicate issues when running multiple threads that way:
+ closing the queues in sequence as the input is exhausted, correctly catching
+ and reporting exceptions, etc.
+
+ The `QueueRunner`, combined with the `Coordinator`, helps handle these issues.
+ """
+
+ def __init__(self, queue, enqueue_ops):
+ """Create a QueueRunner.
+
+ On construction the `QueueRunner` adds an op to close the queue. That op
+ will be run if the enqueue ops raise exceptions.
+
+ When you later call the `create_threads()` method, the `QueueRunner` will
+ create one thread for each op in `enqueue_ops`. Each thread will run its
+ enqueue op in parallel with the other threads. The enqueue ops do not have
+ to all be the same op, but it is expected that they all enqueue tensors in
+ `queue`.
+
+ Args:
+ queue: A `Queue`.
+ enqueue_ops: List of enqueue ops to run in threads later.
+ """
+ self._queue = queue
+ self._enqueue_ops = enqueue_ops
+ # Close when no more will be produced, but pending enqueues should be
+ # preserved.
+ self._close_op = self._queue.close()
+ # Close and cancel pending enqueues since there was an error and we want
+ # to unblock everything so we can cleanly exit.
+ self._cancel_op = self._queue.close(cancel_pending_enqueues=True)
+ # Protect the count of runs to wait for.
+ self._lock = threading.Lock()
+ self._runs = 0
+ # List of exceptions raised by the running threads.
+ self._exceptions_raised = []
+
+ @property
+ def exceptions_raised(self):
+ """Exceptions raised but not handled by the `QueueRunner` threads.
+
+ Exceptions raised in queue runner threads are handled in one of two ways
+ depending on whether or not a `Coordinator` was passed to
+ `create_threads()`:
+
+ * With a `Coordinator`, exceptions are reported to the coordinator and
+ forgotten by the `QueueRunner`.
+ * Without a `Coordinator`, exceptions are captured by the `QueueRunner` and
+ made available in this `exceptions_raised` property.
+
+ Returns:
+ A list of Python `Exception` objects. The list is empty if no exception
+ was captured. (No exceptions are captured when using a Coordinator.)
+ """
+ return self._exceptions_raised
+
+ # pylint: disable=broad-except
+ def _run(self, sess, enqueue_op, coord=None):
+ """Execute the enqueue op in a loop, close the queue in case of error.
+
+ Args:
+ sess: A Session.
+ enqueue_op: The Operation to run.
+ coord: Optional Coordinator object for reporting errors and checking
+ for stop conditions.
+ """
+ decremented = False
+ try:
+ while True:
+ if coord and coord.should_stop():
+ break
+ try:
+ sess.run(enqueue_op)
+ except errors.OutOfRangeError:
+ # This exception indicates that a queue was closed.
+ with self._lock:
+ self._runs -= 1
+ decremented = True
+ if self._runs == 0:
+ try:
+ sess.run(self._close_op)
+ except Exception, e:
+ # Intentionally ignore errors from close_op.
+ logging.vlog(1, "Ignored exception: %s", str(e))
+ return
+ except Exception, e:
+ # This catches all other exceptions.
+ if coord:
+ coord.request_stop(e)
+ else:
+ logging.error("Exception in QueueRunner: %s", str(e))
+ with self._lock:
+ self._exceptions_raised.append(e)
+ raise
+ finally:
+ # Make sure we account for all terminations: normal or errors.
+ if not decremented:
+ with self._lock:
+ self._runs -= 1
+
+ def _close_on_stop(self, sess, cancel_op, coord):
+ """Close the queue when the Coordinator requests stop.
+
+ Args:
+ sess: A Session.
+ cancel_op: The Operation to run.
+ coord: Coordinator.
+ """
+ coord.wait_for_stop()
+ try:
+ sess.run(cancel_op)
+ except Exception, e:
+ # Intentionally ignore errors from cancel_op.
+ logging.vlog(1, "Ignored exception: %s", str(e))
+ # pylint: enable=broad-except
+
+ def create_threads(self, sess, coord=None, daemon=False, start=False):
+ """Create threads to run the enqueue ops.
+
+ This method requires a session in which the graph was launched. It creates
+ a list of threads, optionally starting them. There is one thread for each
+ op passed in `enqueue_ops`.
+
+ The `coord` argument is an optional coordinator, that the threads will use
+ to terminate together and report exceptions. If a coordinator is given,
+ this method starts an additional thread to close the queue when the
+ coordinator requests a stop.
+
+ This method may be called again as long as all threads from a previous call
+ have stopped.
+
+ Args:
+ sess: A `Session`.
+ coord: Optional `Coordinator` object for reporting errors and checking
+ stop conditions.
+ daemon: Boolean. If `True` make the threads daemon threads.
+ start: Boolean. If `True` starts the threads. If `False` the
+ caller must call the `start()` method of the returned threads.
+
+ Returns:
+ A list of threads.
+
+ Raises:
+ RuntimeError: If threads from a previous call to `create_threads()` are
+ still running.
+ """
+ with self._lock:
+ if self._runs > 0:
+ raise RuntimeError(
+ "Threads are already running from a previous call to Threads() "
+ "for this queue runner.")
+ self._runs = len(self._enqueue_ops)
+ self._exceptions_raised = []
+
+ ret_threads = [threading.Thread(target=self._run, args=(sess, op, coord))
+ for op in self._enqueue_ops]
+ if coord:
+ ret_threads.append(threading.Thread(target=self._close_on_stop,
+ args=(sess, self._cancel_op, coord)))
+ for t in ret_threads:
+ if daemon:
+ t.daemon = True
+ if start:
+ t.start()
+ return ret_threads
+
+
+def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
+ """Adds a `QueueRunner` to a collection in the graph.
+
+ When building a complex model that uses many queues it is often difficult to
+ gather all the queue runners that need to be run. This convenience function
+ allows you to add a queue runner to a well known collection in the graph.
+
+ The companion method `start_queue_runners()` can be used to start threads for
+ all the collected queue runners.
+
+ Args:
+ qr: A `QueueRunner`.
+ collection: A `GraphKey` specifying the graph collection to add
+ the queue runner to. Defaults to `GraphKeys.QUEUE_RUNNERS`.
+ """
+ ops.add_to_collection(collection, qr)
+
+
+def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
+ collection=ops.GraphKeys.QUEUE_RUNNERS):
+ """Starts all queue runners collected in the graph.
+
+ This is a companion method to `add_queue_runner()`. It just starts
+ threads for all queue runners collected in the graph. It returns
+ the list of all threads.
+
+ Args:
+ sess: `Session` used to run the queue ops. Defaults to the
+ default session.
+ coord: Optional `Coordinator` for coordinating the started threads.
+ daemon: Whether the threads should be marked as `daemons`, meaning
+ they don't block program exit.
+ start: Set to `False` to only create the threads, not start them.
+ collection: A `GraphKey` specifying the graph collection to
+ get the queue runners from. Defaults to `GraphKeys.QUEUE_RUNNERS`.
+
+ Returns:
+ A list of threads.
+ """
+ if sess is None:
+ sess = ops.get_default_session()
+ threads = []
+ for qr in ops.get_collection(collection):
+ threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
+ start=start))
+ return threads
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py
new file mode 100644
index 0000000000..c94c02da66
--- /dev/null
+++ b/tensorflow/python/training/queue_runner_test.py
@@ -0,0 +1,186 @@
+"""Tests for QueueRunner."""
+import time
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class QueueRunnerTest(tf.test.TestCase):
+
+ def testBasic(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var = tf.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = tf.FIFOQueue(10, tf.float32)
+ tf.initialize_all_variables().run()
+ qr = tf.train.QueueRunner(queue, [count_up_to])
+ threads = qr.create_threads(sess)
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+ self.assertEqual(0, len(qr.exceptions_raised))
+ # The variable should be 3.
+ self.assertEqual(3, var.eval())
+
+ def testTwoOps(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var0 = tf.Variable(zero64)
+ count_up_to_3 = var0.count_up_to(3)
+ var1 = tf.Variable(zero64)
+ count_up_to_30 = var1.count_up_to(30)
+ queue = tf.FIFOQueue(10, tf.float32)
+ qr = tf.train.QueueRunner(queue, [count_up_to_3, count_up_to_30])
+ threads = qr.create_threads(sess)
+ tf.initialize_all_variables().run()
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+ self.assertEqual(0, len(qr.exceptions_raised))
+ self.assertEqual(3, var0.eval())
+ self.assertEqual(30, var1.eval())
+
+ def testExceptionsCaptured(self):
+ with self.test_session() as sess:
+ queue = tf.FIFOQueue(10, tf.float32)
+ qr = tf.train.QueueRunner(queue, ["i fail", "so fail"])
+ threads = qr.create_threads(sess)
+ tf.initialize_all_variables().run()
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+ exceptions = qr.exceptions_raised
+ self.assertEqual(2, len(exceptions))
+ self.assertTrue("Operation not in the graph" in str(exceptions[0]))
+ self.assertTrue("Operation not in the graph" in str(exceptions[1]))
+
+ def testRealDequeueEnqueue(self):
+ with self.test_session() as sess:
+ q0 = tf.FIFOQueue(3, tf.float32)
+ enqueue0 = q0.enqueue((10.0,))
+ close0 = q0.close()
+ q1 = tf.FIFOQueue(30, tf.float32)
+ enqueue1 = q1.enqueue((q0.dequeue(),))
+ dequeue1 = q1.dequeue()
+ qr = tf.train.QueueRunner(q1, [enqueue1])
+ threads = qr.create_threads(sess)
+ for t in threads:
+ t.start()
+ # Enqueue 2 values, then close queue0.
+ enqueue0.run()
+ enqueue0.run()
+ close0.run()
+ # Wait for the queue runner to terminate.
+ for t in threads:
+ t.join()
+ # It should have terminated cleanly.
+ self.assertEqual(0, len(qr.exceptions_raised))
+ # The 2 values should be in queue1.
+ self.assertEqual(10.0, dequeue1.eval())
+ self.assertEqual(10.0, dequeue1.eval())
+ # And queue1 should now be closed.
+ with self.assertRaisesRegexp(tf.errors.OutOfRangeError, "is closed"):
+ dequeue1.eval()
+
+ def testRespectCoordShouldStop(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var = tf.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = tf.FIFOQueue(10, tf.float32)
+ tf.initialize_all_variables().run()
+ qr = tf.train.QueueRunner(queue, [count_up_to])
+ # Ask the coordinator to stop. The queue runner should
+ # finish immediately.
+ coord = tf.train.Coordinator()
+ coord.request_stop()
+ threads = qr.create_threads(sess, coord)
+ for t in threads:
+ t.start()
+ coord.join(threads)
+ self.assertEqual(0, len(qr.exceptions_raised))
+ # The variable should be 0.
+ self.assertEqual(0, var.eval())
+
+ def testRequestStopOnException(self):
+ with self.test_session() as sess:
+ queue = tf.FIFOQueue(10, tf.float32)
+ qr = tf.train.QueueRunner(queue, ["not an op"])
+ coord = tf.train.Coordinator()
+ threads = qr.create_threads(sess, coord)
+ for t in threads:
+ t.start()
+ # The exception should be re-raised when joining.
+ with self.assertRaisesRegexp(ValueError, "Operation not in the graph"):
+ coord.join(threads)
+
+ def testGracePeriod(self):
+ with self.test_session() as sess:
+ # The enqueue will quickly block.
+ queue = tf.FIFOQueue(2, tf.float32)
+ enqueue = queue.enqueue((10.0,))
+ dequeue = queue.dequeue()
+ qr = tf.train.QueueRunner(queue, [enqueue])
+ coord = tf.train.Coordinator()
+ threads = qr.create_threads(sess, coord, start=True)
+ # Dequeue one element and then request stop.
+ dequeue.op.run()
+ time.sleep(0.02)
+ coord.request_stop()
+ # We should be able to join because the RequestStop() will cause
+ # the queue to be closed and the enqueue to terminate.
+ coord.join(threads, stop_grace_period_secs=0.05)
+
+ def testNoMultiThreads(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var = tf.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = tf.FIFOQueue(10, tf.float32)
+ tf.initialize_all_variables().run()
+ coord = tf.train.Coordinator()
+ qr = tf.train.QueueRunner(queue, [count_up_to])
+ threads = []
+ threads.extend(qr.create_threads(sess, coord=coord))
+ with self.assertRaisesRegexp(
+ RuntimeError,
+ "Threads are already running"):
+ threads.extend(qr.create_threads(sess, coord=coord))
+ coord.request_stop()
+ coord.join(threads, stop_grace_period_secs=0.5)
+
+ def testThreads(self):
+ with self.test_session() as sess:
+ # CountUpTo will raise OUT_OF_RANGE when it reaches the count.
+ zero64 = tf.constant(0, dtype=tf.int64)
+ var = tf.Variable(zero64)
+ count_up_to = var.count_up_to(3)
+ queue = tf.FIFOQueue(10, tf.float32)
+ tf.initialize_all_variables().run()
+ qr = tf.train.QueueRunner(queue, [count_up_to, "bad op"])
+ threads = qr.create_threads(sess, start=True)
+ for t in threads:
+ t.join()
+ exceptions = qr.exceptions_raised
+ self.assertEqual(1, len(exceptions))
+ self.assertTrue("Operation not in the graph" in str(exceptions[0]))
+
+ threads = qr.create_threads(sess, start=True)
+ for t in threads:
+ t.join()
+ exceptions = qr.exceptions_raised
+ self.assertEqual(1, len(exceptions))
+ self.assertTrue("Operation not in the graph" in str(exceptions[0]))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/rmsprop.py b/tensorflow/python/training/rmsprop.py
new file mode 100644
index 0000000000..6dc0ce11ea
--- /dev/null
+++ b/tensorflow/python/training/rmsprop.py
@@ -0,0 +1,81 @@
+"""One-line documentation for rmsprop module.
+
+rmsprop algorithm [tieleman2012rmsprop]
+
+A detailed description of rmsprop.
+
+- maintain a moving (discounted) average of the square of gradients
+- divide gradient by the root of this average
+
+mean_square = decay * mean_square{t-1} + (1-decay) * gradient ** 2
+mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square + epsilon)
+delta = - mom
+
+"""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import training_ops
+
+
+class RMSPropOptimizer(optimizer.Optimizer):
+ """Optimizer that implements the RMSProp algorithm.
+
+ @@__init__
+ """
+
+ def __init__(self, learning_rate, decay, momentum=0.0, epsilon=1e-10,
+ use_locking=False, name="RMSProp"):
+ """Construct a new RMSProp optimizer.
+
+ Args:
+ learning_rate: A Tensor or a floating point value. The learning rate.
+ decay: discounting factor for the history/coming gradient
+ momentum: a scalar tensor.
+ epsilon: small value to avoid zero denominator.
+ use_locking: If True use locks for update operation.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "RMSProp".
+ """
+ super(RMSPropOptimizer, self).__init__(use_locking, name)
+ self._learning_rate = learning_rate
+ self._decay = decay
+ self._momentum = momentum
+ self._epsilon = epsilon
+
+ # Tensors for learning rate and momentum. Created in _prepare.
+ self._learning_rate_tensor = None
+ self._decay_tensor = None
+ self._momentum_tensor = None
+ self._epsilon_tensor = None
+
+ def _create_slots(self, var_list):
+ for v in var_list:
+ self._get_or_make_slot(
+ v, constant_op.constant(1.0, dtype=v.dtype, shape=v.get_shape()),
+ "rms", self._name)
+ self._zeros_slot(v, "momentum", self._name)
+
+ def _prepare(self):
+ self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
+ name="learning_rate")
+ self._decay_tensor = ops.convert_to_tensor(self._decay, name="decay")
+ self._momentum_tensor = ops.convert_to_tensor(self._momentum,
+ name="momentum")
+ self._epsilon_tensor = ops.convert_to_tensor(self._epsilon,
+ name="epsilon")
+
+ def _apply_dense(self, grad, var):
+ rms = self.get_slot(var, "rms")
+ mom = self.get_slot(var, "momentum")
+ return training_ops.apply_rms_prop(
+ var, rms, mom,
+ self._learning_rate_tensor,
+ self._decay_tensor,
+ self._momentum_tensor,
+ self._epsilon_tensor,
+ grad, use_locking=self._use_locking).op
+
+ def _apply_sparse(self, grad, var):
+ raise NotImplementedError()
diff --git a/tensorflow/python/training/rmsprop_test.py b/tensorflow/python/training/rmsprop_test.py
new file mode 100644
index 0000000000..520df73ca8
--- /dev/null
+++ b/tensorflow/python/training/rmsprop_test.py
@@ -0,0 +1,158 @@
+"""Tests for rmsprop."""
+import math
+
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class RMSPropOptimizerTest(tf.test.TestCase):
+
+ def testWithoutMomentum(self):
+ """Runs two update steps with momentum=0.0 and checks rms and variables."""
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+ opt = tf.train.RMSPropOptimizer(learning_rate=2.0, decay=0.9,
+ momentum=0.0, epsilon=1.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertTrue(rms0 is not None)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertTrue(rms1 is not None)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertTrue(mom0 is not None)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertTrue(mom1 is not None)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the rms accumulators were 1, so the expected values below are
+ # rms = 0.9 * 1 + 0.1 * grad^2 and
+ # v -= learning_rate * grad / sqrt(rms + epsilon).
+ update.run()
+ # Check the root mean square accumulators.
+ self.assertAllClose(np.array([0.901, 0.901]), rms0.eval())
+ self.assertAllClose(np.array([0.90001, 0.90001]), rms1.eval())
+ # Check the parameters.
+ self.assertAllClose(np.array([1.0 - (0.1 * 2.0 / math.sqrt(0.901+1.0)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901+1.0))]),
+ var0.eval())
+ self.assertAllClose(np.array([3.0 - (0.01 * 2.0
+ / math.sqrt(0.90001+1.0)),
+ 4.0 - (0.01 * 2.0
+ / math.sqrt(0.90001+1.0))]),
+ var1.eval())
+ # Step 2: the root mean square accumulators contain the previous update.
+ update.run()
+ # Check the rms accumulators.
+ self.assertAllClose(np.array([0.901*0.9+0.001, 0.901*0.9+0.001]),
+ rms0.eval())
+ self.assertAllClose(np.array([0.90001*0.9+1e-5, 0.90001*0.9+1e-5]),
+ rms1.eval())
+ # Check the parameters.
+ self.assertAllClose(
+ np.array([1.0 - (0.1 * 2.0 / math.sqrt(0.901+1.0))
+ - (0.1 * 2.0 / math.sqrt(0.901*0.9+0.001+1.0)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901+1.0))
+ - (0.1 * 2.0 / math.sqrt(0.901*0.9+0.001+1.0))]),
+ var0.eval())
+ self.assertAllClose(np.array([3.0 - (0.01 * 2.0 / math.sqrt(0.90001+1.0))
+ - (0.01 * 2.0 /
+ math.sqrt(0.90001*0.9+1e-5+1.0)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001+1.0))
+ - (0.01 * 2.0 /
+ math.sqrt(0.90001*0.9+1e-5+1.0))]),
+ var1.eval())
+
+ def testWithMomentum(self):
+ """Runs two update steps with momentum=0.5; checks rms, momentum and vars."""
+ with self.test_session():
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([0.01, 0.01])
+
+ opt = tf.train.RMSPropOptimizer(learning_rate=2.0, decay=0.9,
+ momentum=0.5, epsilon=1e-5)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ tf.initialize_all_variables().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertTrue(rms0 is not None)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertTrue(rms1 is not None)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertTrue(mom0 is not None)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertTrue(mom1 is not None)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: rms = 1, mom = 0. The expected values below are
+ # mom = learning_rate * grad / sqrt(rms + epsilon) with the updated rms,
+ # and v -= mom.
+ update.run()
+ # Check the root mean square accumulators.
+ self.assertAllClose(np.array([0.901, 0.901]), rms0.eval())
+ self.assertAllClose(np.array([0.90001, 0.90001]), rms1.eval())
+ # Check the momentum accumulators
+ self.assertAllClose(np.array([(0.1 * 2.0 / math.sqrt(0.901+1e-5)),
+ (0.1 * 2.0 / math.sqrt(0.901+1e-5))]),
+ mom0.eval())
+ self.assertAllClose(np.array([(0.01 * 2.0/ math.sqrt(0.90001+1e-5)),
+ (0.01 * 2.0/ math.sqrt(0.90001+1e-5))]),
+ mom1.eval())
+
+ # Check the parameters.
+ self.assertAllClose(np.array([1.0 - (0.1 * 2.0 / math.sqrt(0.901+1e-5)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901+1e-5))]),
+ var0.eval())
+ self.assertAllClose(np.array([3.0 - (0.01 * 2.0/ math.sqrt(0.90001+1e-5)),
+ 4.0 - (0.01 * 2.0/ math.sqrt(0.90001+1e-5))]
+ ),
+ var1.eval())
+
+ # Step 2: the root mean square accumulators contain the previous update.
+ update.run()
+ # Check the rms accumulators.
+ self.assertAllClose(np.array([0.901*0.9+0.001, 0.901*0.9+0.001]),
+ rms0.eval())
+ self.assertAllClose(np.array([0.90001*0.9+1e-5, 0.90001*0.9+1e-5]),
+ rms1.eval())
+ # Check the momentum accumulators: 0.5 * previous momentum + new step.
+ self.assertAllClose(np.array([0.5 * (0.1 * 2.0 / math.sqrt(0.901+1e-5)) +
+ (0.1*2.0/math.sqrt(0.901*0.9+0.001+1e-5)),
+ 0.5 * (0.1 * 2.0 / math.sqrt(0.901+1e-5)) +
+ (0.1*2.0/math.sqrt(0.901*0.9+0.001+1e-5))
+ ]), mom0.eval())
+ self.assertAllClose(np.array([0.5 *(0.01 * 2.0/ math.sqrt(0.90001+1e-5))+
+ (0.01 * 2.0 /math.sqrt(0.90001*0.9+2e-5)),
+ 0.5 *(0.01 * 2.0/ math.sqrt(0.90001+1e-5))+
+ (0.01 * 2.0 / math.sqrt(0.90001*0.9+2e-5))
+ ]), mom1.eval())
+
+ # Check the parameters.
+ self.assertAllClose(
+ np.array([1.0 - (0.1 * 2.0 / math.sqrt(0.901+1e-5)) - (0.5 * (
+ 0.1 * 2.0 / math.sqrt(0.901+1e-5)) +(
+ 0.1 * 2.0 / math.sqrt(0.901*0.9+0.001+1e-5))),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901+1e-5)) - (0.5 * (
+ 0.1 * 2.0 / math.sqrt(0.901+1e-5)) +(
+ 0.1 * 2.0 / math.sqrt(0.901*0.9+0.001+1e-5)))
+ ]), var0.eval())
+
+ self.assertAllClose(
+ np.array([3.0 - (0.01 * 2.0 / math.sqrt(0.90001+1e-5))
+ - (0.5 *(0.01 * 2.0/ math.sqrt(0.90001+1e-5)) +
+ (0.01 * 2.0 /math.sqrt(0.90001*0.9+2e-5))),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001+1e-5))
+ - (0.5 *(0.01 * 2.0/ math.sqrt(0.90001+1e-5)) +
+ (0.01 * 2.0 / math.sqrt(0.90001*0.9+2e-5)))]),
+ var1.eval())
+
+
+# Run the tests in this file when it is executed as a script.
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/saver.proto b/tensorflow/python/training/saver.proto
new file mode 100644
index 0000000000..b9ba9f7e3c
--- /dev/null
+++ b/tensorflow/python/training/saver.proto
@@ -0,0 +1,30 @@
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+
+// Protocol buffer representing the configuration of a SaveRestoreHelper.
+message SaverDef {
+ // The name of the tensor in which to specify the filename when saving or
+ // restoring a model checkpoint.
+ string filename_tensor_name = 1;
+
+ // The name of the tensor to run when saving a model checkpoint.
+ string save_tensor_name = 2;
+
+ // The name of the operation to run when restoring a model checkpoint.
+ string restore_op_name = 3;
+
+ // Maximum number of checkpoints to keep. If 0, no checkpoints are deleted.
+ int32 max_to_keep = 4;
+
+ // Shard the save files, one per device that has Parameters nodes.
+ bool sharded = 5;
+
+ // How often to keep an additional checkpoint. If not specified, only the
+ // last "max_to_keep" checkpoints are kept; if specified, in addition to
+ // keeping the last "max_to_keep" checkpoints, an additional checkpoint will
+ // be kept for every n hours of training.
+ float keep_checkpoint_every_n_hours = 6;
+}
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
new file mode 100644
index 0000000000..505bbad4c6
--- /dev/null
+++ b/tensorflow/python/training/saver.py
@@ -0,0 +1,887 @@
+# pylint: disable=invalid-name
+"""Save and restore variables."""
+import collections
+import numbers
+import os.path
+import time
+
+from google.protobuf import text_format
+
+from tensorflow.python.client import graph_util
+from tensorflow.python.client import session
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_io_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import logging
+from tensorflow.python.training import saver_pb2
+from tensorflow.python.training import training_util
+from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
+
+
+class BaseSaverBuilder(object):
+ """Base class for Savers.
+
+ Can be extended to create different Ops.
+ """
+
+ class VarToSave(object):
+ """Class used to describe variable slices that need to be saved.
+
+ Attributes:
+ var: The variable tensor to save or restore.
+ slice_spec: The slice spec for the variable; callers in this file pass ""
+ for a variable that is not a slice.
+ name: The name under which the variable is stored in the checkpoint.
+ """
+
+ def __init__(self, var, slice_spec, name):
+ self.var = var
+ self.slice_spec = slice_spec
+ self.name = name
+
+ def __init__(self):
+ """Creates a builder; the base builder keeps no state."""
+ pass
+
+ def save_op(self, filename_tensor, vars_to_save):
+ """Create an Op to save 'vars_to_save'.
+
+ This is intended to be overridden by subclasses that want to generate
+ different Ops.
+
+ Args:
+ filename_tensor: String Tensor.
+ vars_to_save: a list of BaseSaverBuilder.VarToSave objects.
+
+ Returns:
+ An Operation that saves the variables.
+ """
+ # The three lists are parallel: entry i of each describes vars_to_save[i].
+ return io_ops._save(
+ filename=filename_tensor,
+ tensor_names=[vs.name for vs in vars_to_save],
+ tensors=[vs.var for vs in vars_to_save],
+ tensor_slices=[vs.slice_spec for vs in vars_to_save])
+
+ def restore_op(self, filename_tensor, var_to_save, preferred_shard):
+ """Create an Op to read the variable 'var_to_save'.
+
+ This is intended to be overridden by subclasses that want to generate
+ different Ops.
+
+ Args:
+ filename_tensor: String Tensor.
+ var_to_save: a BaseSaverBuilder.VarToSave object.
+ preferred_shard: Int. Shard to open first when loading a sharded file.
+
+ Returns:
+ A Tensor resulting from reading 'var_to_save' from the file named by
+ 'filename_tensor'.
+ """
+ # The value is read under the saved name and slice spec, with the dtype of
+ # the variable it will be assigned to.
+ return io_ops._restore_slice(
+ filename_tensor,
+ var_to_save.name,
+ var_to_save.slice_spec,
+ var_to_save.var.dtype,
+ preferred_shard=preferred_shard)
+
+ def sharded_filename(self, filename_tensor, shard, num_shards):
+ """Append sharding information to a filename.
+
+ Thin wrapper around `gen_io_ops._sharded_filename`.
+
+ Args:
+ filename_tensor: a string tensor.
+ shard: integer. The shard for the filename.
+ num_shards: an int Tensor for the number of shards.
+
+ Returns:
+ A string tensor.
+ """
+ return gen_io_ops._sharded_filename(filename_tensor, shard, num_shards)
+
+ def _AddSaveOps(self, filename_tensor, vars_to_save):
+ """Add ops to save variables that are on the same shard.
+
+ Args:
+ filename_tensor: String Tensor.
+ vars_to_save: a list of BaseSaverBuilder.VarToSave objects.
+
+ Returns:
+ A tensor with the filename used to save.
+ """
+ save = self.save_op(filename_tensor, vars_to_save)
+ # The returned filename is created with a dependency on the save op.
+ return control_flow_ops.with_dependencies([save], filename_tensor)
+
+ def _AddShardedSaveOps(self, filename_tensor, per_device):
+ """Add ops to save the params per shard.
+
+ Args:
+ filename_tensor: String Tensor.
+ per_device: A list of (device, list of BaseSaverBuilder.VarToSave) pairs,
+ as returned by _GroupByDevices().
+
+ Returns:
+ The result of `gen_io_ops._sharded_filespec()` for the filename and the
+ number of shards, created under control dependencies on every per-shard
+ save.
+ """
+ num_shards = len(per_device)
+ sharded_saves = []
+ num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
+ # One shard per entry of per_device; each save is placed on that device.
+ for shard, (device, vars_to_save) in enumerate(per_device):
+ with ops.device(device):
+ sharded_filename = self.sharded_filename(
+ filename_tensor, shard, num_shards_tensor)
+ sharded_saves.append(self._AddSaveOps(sharded_filename, vars_to_save))
+ # Return the sharded name for the save path.
+ with ops.control_dependencies([x.op for x in sharded_saves]):
+ return gen_io_ops._sharded_filespec(filename_tensor, num_shards_tensor)
+
+ def _AddRestoreOps(self,
+ filename_tensor,
+ vars_to_save,
+ restore_sequentially,
+ reshape,
+ preferred_shard=-1,
+ name="restore_all"):
+ """Add operations to restore vars_to_save.
+
+ Args:
+ filename_tensor: Tensor for the path of the file to load.
+ vars_to_save: a list of BaseSaverBuilder.VarToSave objects.
+ restore_sequentially: True if we want to restore variables sequentially
+ within a shard.
+ reshape: True if we want to reshape loaded tensors to the shape of
+ the corresponding variable.
+ preferred_shard: Shard to open first when loading a sharded file.
+ name: Name for the returned op.
+
+ Returns:
+ An Operation that restores the variables.
+ """
+ assign_ops = []
+ for vs in vars_to_save:
+ v = vs.var
+ # When restoring sequentially, each restore depends on the previous assign
+ # (the slice is empty for the first variable).
+ restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
+ # Load and optionally reshape on the CPU, as string tensors are not
+ # available on the GPU.
+ # TODO(mdevin): Re-enable restore on GPU when we can support annotating
+ # string tensors as "HostMemory" inputs.
+ with ops.device(graph_util.set_cpu0(v.device) if v.device else None):
+ with ops.control_dependencies(restore_control_inputs):
+ values = self.restore_op(filename_tensor, vs, preferred_shard)
+ if reshape:
+ shape = v.get_shape()
+ # Fall back to the dynamic shape when the static shape is incomplete.
+ if not shape.is_fully_defined():
+ shape = array_ops.shape(v)
+ values = array_ops.reshape(values, shape)
+
+ # Assign on the same device as the variable.
+ with ops.device(v.device):
+ # Shape validation is turned off exactly when reshaping is requested.
+ assign_ops.append(state_ops.assign(v,
+ values,
+ validate_shape=not reshape))
+
+ # Create a Noop that has control dependencies from all the updates.
+ return control_flow_ops.group(*assign_ops, name=name)
+
+ def _AddShardedRestoreOps(self, filename_tensor, per_device,
+ restore_sequentially, reshape):
+ """Add Ops to restore variables from multiple devices.
+
+ Args:
+ filename_tensor: Tensor for the path of the file to load.
+ per_device: A list of (device, list of BaseSaverBuilder.VarToSave) pairs,
+ as returned by _GroupByDevices().
+ restore_sequentially: True if we want to restore variables sequentially
+ within a shard.
+ reshape: True if we want to reshape loaded tensors to the shape of
+ the corresponding variable.
+
+ Returns:
+ An Operation that restores the variables.
+ """
+ sharded_restores = []
+ # Each device's variables are restored with that device's shard index as
+ # the preferred shard.
+ for shard, (device, vars_to_save) in enumerate(per_device):
+ with ops.device(device):
+ sharded_restores.append(self._AddRestoreOps(
+ filename_tensor,
+ vars_to_save,
+ restore_sequentially,
+ reshape,
+ preferred_shard=shard,
+ name="restore_shard"))
+ return control_flow_ops.group(*sharded_restores, name="restore_all")
+
+ def _IsVariable(self, v):
+ """Returns True if `v` is a Tensor produced by a variable op.
+
+ The accepted op types are "Variable" and "AutoReloadVariable".
+ """
+ return isinstance(v, ops.Tensor) and (
+ v.op.type == "Variable" or v.op.type == "AutoReloadVariable")
+
+ def _GroupByDevices(self, vars_to_save):
+ """Group Variable tensor slices per device.
+
+ TODO(mdevin): Make sure that all the devices found are on different
+ job/replica/task/cpu|gpu. It would be bad if 2 were on the same device.
+ It can happen if the devices are unspecified.
+
+ Args:
+ vars_to_save: a list of BaseSaverBuilder.VarToSave objects.
+
+ Returns:
+ A list of (device_name, list of BaseSaverBuilder.VarToSave) tuples.
+ The list is sorted by ascending device_name.
+ """
+ per_device = collections.defaultdict(lambda: [])
+ # The grouping key is the device of the underlying variable.
+ for var_to_save in vars_to_save:
+ per_device[var_to_save.var.device].append(var_to_save)
+ return sorted([(dev, tup) for dev, tup in per_device.iteritems()],
+ key=lambda t: t[0])
+
+ def _VarListToDict(self, var_list):
+ """Create a dictionary of names to variable lists.
+
+ Args:
+ var_list: A list, tuple, or set of Variables.
+
+ Returns:
+ A dictionary of variable names to the variables that must be saved under
+ that name. Variables with save_slice_info are grouped together in a list
+ under the same key, in no particular order; other variables are stored
+ directly under their op name.
+
+ Raises:
+ TypeError: If the type of var_list or its elements is not supported.
+ ValueError: If at least two variables share the same name.
+ """
+ if not isinstance(var_list, (list, tuple, set)):
+ raise TypeError("Variables to save should be passed in a dict or a "
+ "list: %s" % var_list)
+ # Converting to a set drops duplicate entries and loses the input order.
+ var_list = set(var_list)
+ names_to_variables = {}
+ for var in var_list:
+ # pylint: disable=protected-access
+ if isinstance(var, variables.Variable) and var._save_slice_info:
+ # A slice: collect every slice with the same name into one list.
+ name = var._save_slice_info.name
+ if name in names_to_variables:
+ if not isinstance(names_to_variables[name], list):
+ raise ValueError("Mixing slices and non-slices with the same name: "
+ "%s" % name)
+ names_to_variables[name].append(var)
+ else:
+ names_to_variables[name] = [var]
+ else:
+ # A regular variable: keyed by its op name, which must be unique.
+ var = ops.convert_to_tensor(var)
+ if not self._IsVariable(var):
+ raise TypeError("Variable to save is not a Variable: %s" % var)
+ name = var.op.name
+ if name in names_to_variables:
+ raise ValueError("At least two variables have the same name: %s" %
+ name)
+ names_to_variables[name] = var
+ # pylint: enable=protected-access
+ return names_to_variables
+
+ def _ValidateAndSliceInputs(self, names_to_variables):
+ """Returns the variables and names that will be used for a Saver.
+
+ Args:
+ names_to_variables: A dict (k, v) where k is the name of a variable and v
+ is a Variable to save or a BaseSaverBuilder.Saver.
+
+ Returns:
+ A list of BaseSaverBuilder.VarToSave objects.
+
+ Raises:
+ TypeError: if any of the keys are not strings or any of the
+ values are not one of Tensor or Variable.
+ ValueError: if the same variable is given in more than one value
+ (this also applies to slices of SlicedVariables).
+ """
+ if not isinstance(names_to_variables, dict):
+ names_to_variables = self._VarListToDict(names_to_variables)
+
+ vars_to_save = []
+ seen_variables = set()
+ for name in sorted(names_to_variables.iterkeys()):
+ if not isinstance(name, basestring):
+ raise TypeError("names_to_variables must be a dict mapping string "
+ "names to variable Tensors. Name is not a string: %s" %
+ name)
+ v = names_to_variables[name]
+ if isinstance(v, (list, tuple)):
+ # A set of slices.
+ slice_name = None
+ # pylint: disable=protected-access
+ for variable in v:
+ if not isinstance(variable, variables.Variable):
+ raise ValueError("Slices must all be Variables: %s" % variable)
+ if not variable._save_slice_info:
+ raise ValueError("Slices must all be slices: %s" % variable)
+ if slice_name is None:
+ slice_name = variable._save_slice_info.name
+ elif slice_name != variable._save_slice_info.name:
+ raise variable("Slices must all be from the same tensor: %s != %s"
+ % (slice_name, variable._save_slice_info.name))
+ self._AddVarToSave(vars_to_save, seen_variables,
+ variable, variable._save_slice_info.spec, name)
+ # pylint: enable=protected-access
+ else:
+ # A variable or tensor.
+ variable = ops.convert_to_tensor(v)
+ if not self._IsVariable(variable):
+ raise TypeError("names_to_variables must be a dict mapping string "
+ "names to Tensors/Variables. Not a variable: %s" %
+ variable)
+ self._AddVarToSave(vars_to_save, seen_variables, variable, "", name)
+ return vars_to_save
+
+ def _AddVarToSave(self, vars_to_save, seen_variables, variable, slice_spec,
+ name):
+ """Create a VarToSave and add it to the vars_to_save list.
+
+ Args:
+ vars_to_save: List to append the new VarToSave to.
+ seen_variables: Set of variables already processed. Used to check
+ that each variable is only saved once.
+ variable: Variable to save.
+ slice_spec: String. Slice spec for the variable.
+ name: Name to use to save the variable.
+
+ Raises:
+ ValueError: If the variable has already been processed.
+ """
+ if variable in seen_variables:
+ raise ValueError("The same variable will be restored with two names: %s",
+ variable)
+ vars_to_save.append(BaseSaverBuilder.VarToSave(variable, slice_spec, name))
+ seen_variables.add(variable)
+
+ def build(self,
+ names_to_variables,
+ reshape=False,
+ sharded=False,
+ max_to_keep=5,
+ keep_checkpoint_every_n_hours=10000.0,
+ name=None,
+ restore_sequentially=False):
+ """Adds save/restore nodes to the graph and creates a SaverDef proto.
+
+ Args:
+ names_to_variables: A dictionary mapping name to a Variable.
+ Each name will be associated with the
+ corresponding variable in the checkpoint.
+ reshape: If True, allow restoring parameters from a checkpoint
+ where the parameters have a different shape. This is
+ only needed when you try to restore from a Dist-Belief checkpoint,
+ and only some times.
+ sharded: If True, shard the checkpoints, one per device that has
+ Parameters nodes.
+ max_to_keep: maximum number of checkpoints to keep. As new checkpoints
+ are created, old ones are deleted. If None or 0, no checkpoints are
+ deleted. Presently the number is only roughly enforced. For example
+ in case of restarts more than max_to_keep checkpoints may be kept.
+ keep_checkpoint_every_n_hours: How often checkpoints should be kept.
+ Defaults to 10,000 hours.
+ name: string. Optional name to use as a prefix when adding operations.
+ restore_sequentially: A Bool, which if true, causes restore of different
+ variables to happen sequentially within each device.
+
+ Returns:
+ A SaverDef proto.
+
+ Raises:
+ TypeError: If 'names_to_variables' is not a dictionary mapping string
+ keys to variable Tensors.
+ ValueError: If any of the keys or values in 'names_to_variables' is not
+ unique.
+ """
+ vars_to_save = self._ValidateAndSliceInputs(names_to_variables)
+ # The SaverDef field is an int32, so None is stored as 0 ("keep all").
+ if max_to_keep is None:
+ max_to_keep = 0
+
+ with ops.op_scope([vs.var for vs in vars_to_save], name, "save") as name:
+ # Add the Constant string tensor for the filename.
+ filename_tensor = constant_op.constant("model")
+
+ # Add the save ops.
+ if sharded:
+ per_device = self._GroupByDevices(vars_to_save)
+ save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
+ restore_op = self._AddShardedRestoreOps(
+ filename_tensor, per_device, restore_sequentially, reshape)
+ else:
+ save_tensor = self._AddSaveOps(filename_tensor, vars_to_save)
+ restore_op = self._AddRestoreOps(
+ filename_tensor, vars_to_save, restore_sequentially, reshape)
+
+ # Both restore paths name their final op "restore_all".
+ assert restore_op.name.endswith("restore_all"), restore_op.name
+
+ # Only names are recorded in the proto, not the tensors/ops themselves.
+ return saver_pb2.SaverDef(
+ filename_tensor_name=filename_tensor.name,
+ save_tensor_name=save_tensor.name,
+ restore_op_name=restore_op.name,
+ max_to_keep=max_to_keep,
+ keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
+ sharded=sharded)
+
+def _GetCheckpointFilename(save_dir, latest_filename):
+ """Returns a filename for storing the CheckpointState.
+
+ Args:
+ save_dir: The directory for saving and restoring checkpoints.
+ latest_filename: Name of the file in 'save_dir' that is used
+ to store the CheckpointState. If None, "checkpoint" is used.
+
+ Returns:
+ The path of the file that contains the CheckpointState proto.
+ """
+ if latest_filename is None:
+ latest_filename = "checkpoint"
+ return os.path.join(save_dir, latest_filename)
+
+
+def update_checkpoint_state(save_dir,
+ model_checkpoint_path,
+ all_model_checkpoint_paths=None,
+ latest_filename=None):
+ """Updates the content of the 'checkpoint' file.
+
+ This updates the checkpoint file containing a CheckpointState
+ proto.
+
+ Args:
+ save_dir: Directory where the model was saved.
+ model_checkpoint_path: The checkpoint file.
+ all_model_checkpoint_paths: list of strings. Paths to all not-yet-deleted
+ checkpoints, sorted from oldest to newest. If this is a non-empty list,
+ the last element must be equal to model_checkpoint_path. These paths
+ are also saved in the CheckpointState proto.
+ latest_filename: Optional name of the checkpoint file. Default to
+ 'checkpoint'.
+
+ Raises:
+ RuntimeError: If the save paths conflict.
+ """
+ if all_model_checkpoint_paths is None:
+ all_model_checkpoint_paths = []
+ elif all_model_checkpoint_paths[-1] != model_checkpoint_path:
+ logging.warning(
+ "%s is not in all_model_checkpoint_paths! Manually adding it.",
+ model_checkpoint_path)
+ all_model_checkpoint_paths.append(model_checkpoint_path)
+ # Writes the "checkpoint" file for the coordinator for later restoration.
+ coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
+ if coord_checkpoint_filename == model_checkpoint_path:
+ raise RuntimeError("Save path '%s' conflicts with path used for "
+ "checkpoint state. Please use a different save path." %
+ model_checkpoint_path)
+ coord_checkpoint_proto = CheckpointState(
+ model_checkpoint_path=model_checkpoint_path,
+ all_model_checkpoint_paths=all_model_checkpoint_paths)
+ f = gfile.FastGFile(coord_checkpoint_filename, mode="w")
+ f.write(text_format.MessageToString(coord_checkpoint_proto))
+ f.close()
+
+
+def get_checkpoint_state(checkpoint_dir, latest_filename=None):
+ """Returns CheckpointState proto from the "checkpoint" file.
+
+ If the "checkpoint" file contains a valid CheckpointState
+ proto, returns it.
+
+ Args:
+ checkpoint_dir: The directory of checkpoints.
+ latest_filename: Optional name of the checkpoint file. Default to
+ 'checkpoint'.
+
+ Returns:
+ A CheckpointState if the state was available, None
+ otherwise.
+ """
+ ckpt = None
+ coord_checkpoint_filename = _GetCheckpointFilename(
+ checkpoint_dir, latest_filename)
+ f = None
+ try:
+ # Check that the file exists before opeining it to avoid
+ # many lines of errors from colossus in the logs.
+ if gfile.Exists(coord_checkpoint_filename):
+ f = gfile.FastGFile(coord_checkpoint_filename, mode="r")
+ ckpt = CheckpointState()
+ text_format.Merge(f.read(), ckpt)
+ except gfile.FileError:
+ # It's ok if the file cannot be read
+ return None
+ except text_format.ParseError, e:
+ logging.warning(str(e))
+ logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
+ return None
+ finally:
+ if f:
+ f.close()
+ return ckpt
+
+
+class Saver(object):
+ """Saves and restores variables.
+
+ See [Variables](../../how_tos/variables/index.md)
+ for an overview of variables, saving and restoring.
+
+ The `Saver` class adds ops to save and restore variables to and from
+ *checkpoints*. It also provides convenience methods to run these ops.
+
+ Checkpoints are binary files in a proprietary format which map variable names
+ to tensor values. The best way to examine the contents of a checkpoint is to
+ load it using a `Saver`.
+
+ Savers can automatically number checkpoint filenames with a provided counter.
+ This lets you keep multiple checkpoints at different steps while training a
+ model. For example you can number the checkpoint filenames with the training
+ step number. To avoid filling up disks, savers manage checkpoint files
+ automatically. For example, they can keep only the N most recent files, or
+ one checkpoint for every N hours of training.
+
+ You number checkpoint filenames by passing a value to the optional
+ `global_step` argument to `save()`:
+
+ ```python
+ saver.save('my-model', global_step=0) ==> filename: 'my-model-0'
+ ...
+ saver.save('my-model', global_step=1000) ==> filename: 'my-model-1000'
+ ```
+
+ Additionally, optional arguments to the `Saver()` constructor let you control
+ the proliferation of checkpoint files on disk:
+
+ * `max_to_keep` indicates the maximum number of recent checkpoint files to
+ keep. As new files are created, older files are deleted. If None or 0,
+ all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent
+ checkpoint files are kept.)
+
+ * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent
+ `max_to_keep` checkpoint files, you might want to keep one checkpoint file
+ for every N hours of training. This can be useful if you want to later
+ analyze how a model progressed during a long training session. For
+ example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
+ one checkpoint file for every 2 hours of training. The default value of
+ 10,000 hours effectively disables the feature.
+
+ Note that you still have to call the `save()` method to save the model.
+ Passing these arguments to the constructor will not save variables
+ automatically for you.
+
+ A training program that saves regularly looks like:
+
+ ```python
+ ...
+ # Create a saver.
+ saver = tf.train.Saver(...variables...)
+ # Launch the graph and train, saving the model every 1,000 steps.
+ sess = tf.Session()
+ for step in xrange(1000000):
+ sess.run(..training_op..)
+ if step % 1000 == 0:
+ # Append the step number to the checkpoint name:
+ saver.save(sess, 'my-model', global_step=step)
+ ```
+
+ In addition to checkpoint files, savers keep a protocol buffer on disk with
+ the list of recent checkpoints. This is used to manage numbered checkpoint
+ files and by `latest_checkpoint()`, which makes it easy to discover the path
+ to the most recent checkpoint. That protocol buffer is stored in a file named
+ 'checkpoint' next to the checkpoint files.
+
+ If you create several savers, you can specify a different filename for the
+ protocol buffer file in the call to `save()`.
+
+ @@__init__
+ @@save
+ @@restore
+
+ Other utility methods.
+
+ @@last_checkpoints
+ @@set_last_checkpoints
+ @@as_saver_def
+ """
+
+ def __init__(self,
+ var_list=None,
+ reshape=False,
+ sharded=False,
+ max_to_keep=5,
+ keep_checkpoint_every_n_hours=10000.0,
+ name=None,
+ restore_sequentially=False,
+ saver_def=None,
+ builder=None):
+ """Creates a `Saver`.
+
+ The constructor adds ops to save and restore variables.
+
+ `var_list` specifies the variables that will be saved and restored. It can
+ be passed as a `dict` or a list:
+
+ * A `dict` of names to variables: The keys are the names that will be
+ used to save or restore the variables in the checkpoint files.
+ * A list of variables: The variables will be keyed with their op name in
+ the checkpoint files.
+
+ For example:
+
+ ```python
+ v1 = tf.Variable(..., name='v1')
+ v2 = tf.Variable(..., name='v2')
+
+ # Pass the variables as a dict:
+ saver = tf.train.Saver({'v1': v1, 'v2': v2})
+
+ # Or pass them as a list.
+ saver = tf.train.Saver([v1, v2])
+ # Passing a list is equivalent to passing a dict with the variable op names
+ # as keys:
+ saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
+ ```
+
+ The optional `reshape` argument, if True, allows restoring a variable from
+ a save file where the variable had a different shape, but the same number
+ of elements and type. This is useful if you have reshaped a variable and
+ want to reload it from an older checkpoint.
+
+ The optional `sharded` argument, if True, instructs the saver to shard
+ checkpoints per device.
+
+ Args:
+ var_list: A list of Variables or a dictionary mapping names to
+ Variables. If None, defaults to the list of all variables.
+ reshape: If True, allows restoring parameters from a checkpoint
+ where the variables have a different shape.
+ sharded: If True, shard the checkpoints, one per device.
+ max_to_keep: maximum number of recent checkpoints to keep.
+ Defaults to 10,000 hours.
+ keep_checkpoint_every_n_hours: How often to keep checkpoints.
+ Defaults to 10,000 hours.
+ name: string. Optional name to use as a prefix when adding operations.
+ restore_sequentially: A Bool, which if true, causes restore of different
+ variables to happen sequentially within each device. This can lower
+ memory usage when restoring very large models.
+ saver_def: Optional SaverDef proto to use instead of running the builder.
+ This is only useful for specialty code that wants to recreate a Saver
+ object for a previously built Graph that had a Saver. The saver_def
+ proto should be the one returned by the as_saver_def() call of the
+ Saver that was created for that Graph.
+ builder: Optional SaverBuilder to use if a saver_def was not provided.
+ Defaults to BaseSaverBuilder().
+
+ Raises:
+ TypeError: If `var_list` is invalid.
+ ValueError: If any of the keys or values in `var_list` is not unique.
+ """
+ if saver_def is None:
+ if builder is None:
+ builder = BaseSaverBuilder()
+ if var_list is None:
+ var_list = variables.all_variables()
+ if not var_list:
+ raise ValueError("No variables to save")
+ saver_def = builder.build(
+ var_list,
+ reshape=reshape,
+ sharded=sharded,
+ max_to_keep=max_to_keep,
+ keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
+ name=name,
+ restore_sequentially=restore_sequentially)
+ if not isinstance(saver_def, saver_pb2.SaverDef):
+ raise ValueError("saver_def must if a saver_pb2.SaverDef: %s" % saver_def)
+ if not saver_def.save_tensor_name:
+ raise ValueError("saver_def must specify the save_tensor_name: %s"
+ % str(saver_def))
+ if not saver_def.restore_op_name:
+ raise ValueError("saver_def must specify the restore_op_name: %s"
+ % str(saver_def))
+ self._filename_tensor_name = saver_def.filename_tensor_name
+ self._save_tensor_name = saver_def.save_tensor_name
+ self._restore_op_name = saver_def.restore_op_name
+ self._max_to_keep = saver_def.max_to_keep
+ # If keep_checkpoint_every_n_hours is not set, set it to 10000 hours.
+ self._keep_checkpoint_every_n_hours = (
+ saver_def.keep_checkpoint_every_n_hours if
+ saver_def.keep_checkpoint_every_n_hours else 10000)
+ self._next_checkpoint_time = (
+ time.time() + self._keep_checkpoint_every_n_hours * 3600)
+ self._sharded = saver_def.sharded
+ self._last_checkpoints = []
+
+  def _CheckpointFilename(self, p):
+    """Extracts the file name from a checkpoint bookkeeping entry.
+
+    Args:
+      p: Either a (filename, time) tuple or a bare checkpoint filename.
+
+    Returns:
+      The first element of `p` when `p` is a tuple, otherwise `p` itself.
+    """
+    if isinstance(p, tuple):
+      return p[0]
+    return p
+
+  def _MaybeDeleteOldCheckpoints(self, latest_save_path):
+    """Deletes old checkpoints if necessary.
+
+    Always keep the last max_to_keep checkpoints.  If
+    keep_checkpoint_every_n_hours was specified, keep an additional checkpoint
+    every N hours. For example, if N is 0.5, an additional checkpoint is kept
+    for every 0.5 hours of training; if N is 10, an additional checkpoint is
+    kept for every 10 hours of training.
+
+    Does nothing when max_to_keep is unset (zero).
+
+    Args:
+      latest_save_path: Name including path of checkpoint file to save.
+    """
+    if not self._max_to_keep:
+      return
+    # Drop any earlier entry recorded under the same name.  The list contents
+    # are rebuilt rather than edited during iteration: removing an item from a
+    # list while looping over it skips the element following the removed one.
+    self._last_checkpoints[:] = [
+        p for p in self._last_checkpoints
+        if self._CheckpointFilename(p) != latest_save_path]
+    # Record the new checkpoint together with the time it was saved.
+    self._last_checkpoints.append((latest_save_path, time.time()))
+    if len(self._last_checkpoints) <= self._max_to_keep:
+      return
+    # More than max_to_keep entries: retire the oldest one.
+    oldest = self._last_checkpoints.pop(0)
+    # Do not delete the file if keep_checkpoint_every_n_hours is set and we
+    # have reached N hours of training.  Entries installed through
+    # set_last_checkpoints() are bare filenames without a save time; they
+    # cannot qualify for the periodic keep, and indexing them with [1] would
+    # compare a single character of the path against a timestamp.
+    has_save_time = isinstance(oldest, tuple)
+    if has_save_time and oldest[1] > self._next_checkpoint_time:
+      self._next_checkpoint_time += (
+          self._keep_checkpoint_every_n_hours * 3600)
+      return
+    # Otherwise delete the files.  The name is globbed because a sharded
+    # checkpoint name is a pattern that matches one file per shard.
+    for f in gfile.Glob(self._CheckpointFilename(oldest)):
+      try:
+        gfile.Remove(f)
+      except gfile.GOSError as e:
+        # Best effort: a file that cannot be removed is logged and skipped.
+        logging.warning("Ignoring: %s", str(e))
+
+  def as_saver_def(self):
+    """Builds a `SaverDef` proto describing this saver.
+
+    Returns:
+      A `SaverDef` proto filled in with the tensor/op names and the checkpoint
+      retention settings held by this saver.
+    """
+    fields = {
+        "filename_tensor_name": self._filename_tensor_name,
+        "save_tensor_name": self._save_tensor_name,
+        "restore_op_name": self._restore_op_name,
+        "max_to_keep": self._max_to_keep,
+        "keep_checkpoint_every_n_hours": self._keep_checkpoint_every_n_hours,
+        "sharded": self._sharded,
+    }
+    return saver_pb2.SaverDef(**fields)
+
+  @property
+  def last_checkpoints(self):
+    """Filenames of the checkpoints this saver has not yet deleted.
+
+    Any of the returned values can be passed to `restore()`.
+
+    Returns:
+      A new list of checkpoint filenames, ordered from oldest to newest.
+    """
+    filenames = []
+    for entry in self._last_checkpoints:
+      filenames.append(self._CheckpointFilename(entry))
+    return filenames
+
+  def set_last_checkpoints(self, last_checkpoints):
+    """Seeds the saver's list of not-yet-deleted checkpoint filenames.
+
+    Args:
+      last_checkpoints: a list of checkpoint filenames.
+
+    Raises:
+      AssertionError: if this saver already tracks checkpoints, or if
+        `last_checkpoints` is not a list.
+    """
+    already_tracking = bool(self._last_checkpoints)
+    assert not already_tracking
+    assert isinstance(last_checkpoints, list)
+    # Keep a private copy so the caller's list is not shared with the saver.
+    checkpoints_copy = list(last_checkpoints)
+    self._last_checkpoints = checkpoints_copy
+
+  def save(self, sess, save_path, global_step=None, latest_filename=None):
+    """Saves variables.
+
+    This method runs the ops added by the constructor for saving variables.
+    It requires a session in which the graph was launched.  The variables to
+    save must also have been initialized.
+
+    The method returns the path of the newly created checkpoint file.  This
+    path can be passed directly to a call to `restore()`.
+
+    Args:
+      sess: A Session to use to save the variables.
+      save_path: string.  Path to the checkpoint filename.  If the saver is
+        `sharded`, this is the prefix of the sharded checkpoint filename.
+      global_step: If provided the global step number is appended to
+        `save_path` to create the checkpoint filename. The optional argument
+        can be a Tensor, a Tensor name or an integer.
+      latest_filename: Optional name for the protocol buffer file that will
+        contain the list of most recent checkpoint filenames.  That file,
+        kept in the same directory as the checkpoint files, is automatically
+        managed by the saver to keep track of recent checkpoints.  Defaults to
+        'checkpoint'.
+
+    Returns:
+      A string: path at which the variables were saved.  If the saver is
+        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
+        is the number of shards created.
+
+    Raises:
+      TypeError: If `sess` is not a Session.
+    """
+    # Validate `sess` before anything uses it: resolving a non-numeric
+    # `global_step` below already runs against the session, so checking later
+    # could surface an unrelated error instead of the documented TypeError.
+    if not isinstance(sess, session.SessionInterface):
+      raise TypeError("'sess' must be a Session; %s" % sess)
+    if latest_filename is None:
+      latest_filename = "checkpoint"
+    if global_step is not None:
+      if not isinstance(global_step, numbers.Number):
+        # A Tensor or Tensor name: fetch its current value from the session.
+        global_step = training_util.global_step(sess, global_step)
+      checkpoint_file = "%s-%d" % (save_path, global_step)
+    else:
+      checkpoint_file = save_path
+    # update_checkpoint_state() is given the directory part of `save_path`.
+    checkpoint_dir = os.path.dirname(save_path)
+
+    model_checkpoint_path = sess.run(
+        self._save_tensor_name, {self._filename_tensor_name: checkpoint_file})
+    model_checkpoint_path = str(model_checkpoint_path)
+    self._MaybeDeleteOldCheckpoints(model_checkpoint_path)
+    update_checkpoint_state(checkpoint_dir, model_checkpoint_path,
+                            self.last_checkpoints, latest_filename)
+    return model_checkpoint_path
+
+  def restore(self, sess, save_path):
+    """Restores previously saved variables.
+
+    Runs the restore op added by the constructor, feeding `save_path` as the
+    checkpoint filename.  It requires a session in which the graph was
+    launched.  The variables to restore do not have to have been initialized,
+    as restoring is itself a way to initialize variables.
+
+    The `save_path` argument is typically a value previously returned from a
+    `save()` call, or a call to `latest_checkpoint()`.
+
+    Args:
+      sess: A Session to use to restore the parameters.
+      save_path: Path where parameters were previously saved.
+    """
+    feed = {self._filename_tensor_name: save_path}
+    fetches = [self._restore_op_name]
+    sess.run(fetches, feed)
+
+
+def latest_checkpoint(checkpoint_dir, latest_filename=None):
+  """Finds the filename of the latest saved checkpoint file.
+
+  The latest checkpoint is taken from the checkpoint state read by
+  `get_checkpoint_state()`; it is only returned if the file exists.
+
+  Args:
+    checkpoint_dir: Directory where the variables were saved.
+    latest_filename: Optional name for the protocol buffer file that
+      contains the list of most recent checkpoint filenames.
+      See the corresponding argument to `Saver.save()`.
+
+  Returns:
+    The full path to the latest checkpoint or None if no checkpoint was found.
+  """
+  state = get_checkpoint_state(checkpoint_dir, latest_filename)
+  if not state or not state.model_checkpoint_path:
+    return None
+  candidate = os.path.join(checkpoint_dir, state.model_checkpoint_path)
+  if not gfile.Exists(candidate):
+    return None
+  return candidate
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
new file mode 100644
index 0000000000..db378e9637
--- /dev/null
+++ b/tensorflow/python/training/saver_test.py
@@ -0,0 +1,563 @@
+"""Tests for tensorflow.python.training.saver."""
+import os.path
+import time
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+import numpy as np
+
+from tensorflow.python.platform import gfile
+
+
+class SaverTest(tf.test.TestCase):
+ """Save/restore round-trip tests for tf.train.Saver."""
+
+ def testBasics(self):
+ """Saves two float variables by name map and restores them twice."""
+ save_path = os.path.join(self.get_temp_dir(), "basics")
+
+ with self.test_session() as sess:
+ # Build a graph with 2 parameter nodes, and Save and
+ # Restore nodes for them.
+ v0 = tf.Variable(10.0, name="v0")
+ v1 = tf.Variable(20.0, name="v1")
+ save = tf.train.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
+ tf.initialize_all_variables().run()
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+ # Save the initialized values in the file at "save_path"
+ val = save.save(sess, save_path)
+ self.assertTrue(isinstance(val, basestring))
+ self.assertEqual(save_path, val)
+
+ # Start a second session. In that session the parameter nodes
+ # have not been initialized either.
+ with self.test_session() as sess:
+ v0 = tf.Variable(-1.0, name="v0")
+ v1 = tf.Variable(-1.0, name="v1")
+ save = tf.train.Saver({"v0": v0, "v1": v1})
+
+ with self.assertRaisesWithPredicateMatch(
+ tf.OpError, lambda e: "uninitialized value v0" in e.message):
+ sess.run(v0)
+ with self.assertRaisesWithPredicateMatch(
+ tf.OpError, lambda e: "uninitialized value v1" in e.message):
+ sess.run(v1)
+
+ # Restore the saved values in the parameter nodes.
+ save.restore(sess, save_path)
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+ # Build another graph with 2 nodes, initialized
+ # differently, and a Restore node for them.
+ with self.test_session() as sess:
+ v0_2 = tf.Variable(1000.0, name="v0")
+ v1_2 = tf.Variable(2000.0, name="v1")
+ save2 = tf.train.Saver({"v0": v0_2, "v1": v1_2})
+ tf.initialize_all_variables().run()
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(1000.0, v0_2.eval())
+ self.assertEqual(2000.0, v1_2.eval())
+ # Restore the values saved earlier in the parameter nodes.
+ save2.restore(sess, save_path)
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, v0_2.eval())
+ self.assertEqual(20.0, v1_2.eval())
+
+ def testInt64(self):
+ """Saves and restores a single np.int64 variable."""
+ save_path = os.path.join(self.get_temp_dir(), "int64")
+
+ with self.test_session() as sess:
+ # Build a graph with 1 node, and save and restore for them.
+ v = tf.Variable(np.int64(15), name="v")
+ save = tf.train.Saver({"v": v}, restore_sequentially=True)
+ tf.initialize_all_variables().run()
+
+ # Save the initialized values in the file at "save_path"
+ val = save.save(sess, save_path)
+ self.assertTrue(isinstance(val, basestring))
+ self.assertEqual(save_path, val)
+
+ with self.test_session() as sess:
+ v = tf.Variable(np.int64(-1), name="v")
+ save = tf.train.Saver({"v": v})
+
+ with self.assertRaisesWithPredicateMatch(
+ tf.OpError, lambda e: "uninitialized value v" in e.message):
+ sess.run(v)
+
+ # Restore the saved values in the parameter nodes.
+ save.restore(sess, save_path)
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(np.int64(15), v.eval())
+
+ def testSomeErrors(self):
+ """Checks that two variables saved under the same name raise ValueError."""
+ with tf.Graph().as_default():
+ v0 = tf.Variable([10.0], name="v0")
+ v1 = tf.Variable([20.0], name="v1")
+ v2 = tf.Variable([20.0], name="v2")
+ v2._set_save_slice_info(tf.Variable.SaveSliceInfo("v1", ""))
+
+ # By default the name used for "v2" will be "v1" and raise an error.
+ with self.assertRaisesRegexp(ValueError, "same name: v1"):
+ tf.train.Saver([v0, v1, v2])
+
+ # The names are different and will work.
+ tf.train.Saver({"vee1": v1, "other": [v2]})
+
+ def testBasicsWithListOfVariables(self):
+ """Same round trip as testBasics, passing the Saver a list of variables."""
+ save_path = os.path.join(self.get_temp_dir(), "basics_with_list")
+
+ with self.test_session(graph=tf.Graph()) as sess:
+ # Build a graph with 2 parameter nodes, and Save and
+ # Restore nodes for them.
+ v0 = tf.Variable(10.0, name="v0")
+ v1 = tf.Variable(20.0, name="v1")
+ save = tf.train.Saver([v0, v1])
+ tf.initialize_all_variables().run()
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+ # Save the initialized values in the file at "save_path"
+ val = save.save(sess, save_path)
+ self.assertTrue(isinstance(val, basestring))
+ self.assertEqual(save_path, val)
+
+ # Start a second session. In that session the variables
+ # have not been initialized either.
+ with self.test_session(graph=tf.Graph()) as sess:
+ v0 = tf.Variable(-1.0, name="v0")
+ v1 = tf.Variable(-1.0, name="v1")
+ save = tf.train.Saver([v0, v1])
+
+ with self.assertRaisesWithPredicateMatch(
+ tf.OpError, lambda e: "uninitialized value v0" in e.message):
+ sess.run(v0)
+ with self.assertRaisesWithPredicateMatch(
+ tf.OpError, lambda e: "uninitialized value v1" in e.message):
+ sess.run(v1)
+
+ # Restore the saved values in the parameter nodes.
+ save.restore(sess, save_path)
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+ # Build another graph with 2 nodes, initialized
+ # differently, and a Restore node for them.
+ with self.test_session(graph=tf.Graph()) as sess:
+ v0_2 = tf.Variable(1000.0, name="v0")
+ v1_2 = tf.Variable(2000.0, name="v1")
+ save2 = tf.train.Saver([v0_2, v1_2])
+ tf.initialize_all_variables().run()
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(1000.0, v0_2.eval())
+ self.assertEqual(2000.0, v1_2.eval())
+ # Restore the values saved earlier in the parameter nodes.
+ save2.restore(sess, save_path)
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, v0_2.eval())
+ self.assertEqual(20.0, v1_2.eval())
+
+ def _SaveAndLoad(self, var_name, var_value, other_value, save_path):
+ """Saves one variable, then restores it over a different initial value.
+
+ Args:
+ var_name: Name given to the variable and used as its key in the Saver.
+ var_value: Value that is saved and expected back after the restore.
+ other_value: Initial value of the variable that is restored into.
+ save_path: Checkpoint path passed to save() and restore().
+ """
+ with self.test_session() as sess:
+ var = tf.Variable(var_value, name=var_name)
+ save = tf.train.Saver({var_name: var})
+ var.initializer.run()
+ val = save.save(sess, save_path)
+ self.assertEqual(save_path, val)
+ with self.test_session() as sess:
+ var = tf.Variable(other_value, name=var_name)
+ save = tf.train.Saver({var_name: var})
+ save.restore(sess, save_path)
+ self.assertAllClose(var_value, var.eval())
+
+ def testCacheRereadsFile(self):
+ """Saves two different variables, one after the other, to the same path."""
+ save_path = os.path.join(self.get_temp_dir(), "cache_rereads")
+ # Save and reload one Variable named "var0".
+ self._SaveAndLoad("var0", 0.0, 1.0, save_path)
+ # Save and reload one Variable named "var1" in the same file.
+ # The cached readers should know to re-read the file.
+ self._SaveAndLoad("var1", 1.1, 2.2, save_path)
+
+ def testGPU(self):
+ """Round-trips a variable placed on /gpu:0; returns early without CUDA."""
+ if not tf.test.IsBuiltWithCuda():
+ return
+ save_path = os.path.join(self.get_temp_dir(), "gpu")
+ with tf.Session("", graph=tf.Graph()) as sess:
+ with sess.graph.device("/gpu:0"):
+ v0_1 = tf.Variable(123.45)
+ save = tf.train.Saver({"v0": v0_1})
+ tf.initialize_all_variables().run()
+ save.save(sess, save_path)
+
+ with tf.Session("", graph=tf.Graph()) as sess:
+ with sess.graph.device("/gpu:0"):
+ v0_2 = tf.Variable(543.21)
+ save = tf.train.Saver({"v0": v0_2})
+ tf.initialize_all_variables().run()
+ self.assertAllClose(543.21, v0_2.eval())
+ save.restore(sess, save_path)
+ self.assertAllClose(123.45, v0_2.eval())
+
+ def testVariables(self):
+ """Restores with a Saver built without an explicit variable list."""
+ save_path = os.path.join(self.get_temp_dir(), "variables")
+ with tf.Session("", graph=tf.Graph()) as sess:
+ one = tf.Variable(1.0)
+ twos = tf.Variable([2.0, 2.0, 2.0])
+ init = tf.initialize_all_variables()
+ save = tf.train.Saver(tf.all_variables())
+ init.run()
+ save.save(sess, save_path)
+
+ with tf.Session("", graph=tf.Graph()) as sess:
+ one = tf.Variable(0.0)
+ twos = tf.Variable([0.0, 0.0, 0.0])
+ # Saver with no arg, defaults to 'all variables'.
+ save = tf.train.Saver()
+ save.restore(sess, save_path)
+ self.assertAllClose(1.0, one.eval())
+ self.assertAllClose([2.0, 2.0, 2.0], twos.eval())
+
+ def testSaveWithGlobalStep(self):
+ """Checks the "-<step>" suffix for both tensor and integer global steps."""
+ save_path = os.path.join(self.get_temp_dir(), "ckpt_with_global_step")
+ global_step_int = 5
+ # Save and reload one Variable named "var0".
+ self._SaveAndLoad("var0", 0.0, 1.0, save_path)
+ for use_tensor in [True, False]:
+ with self.test_session() as sess:
+ var = tf.Variable(1.0, name="var0")
+ save = tf.train.Saver({var.op.name: var})
+ var.initializer.run()
+ if use_tensor:
+ global_step = tf.constant(global_step_int)
+ val = save.save(sess, save_path, global_step=global_step)
+ else:
+ val = save.save(sess, save_path, global_step=global_step_int)
+ expected_save_path = "%s-%d" % (save_path, global_step_int)
+ self.assertEqual(expected_save_path, val)
+
+
+class SaveRestoreShardedTest(tf.test.TestCase):
+ """Tests for Savers created with sharded=True."""
+
+ def testBasics(self):
+ """Saves two variables on two CPU devices, then restores from the shards."""
+ save_path = os.path.join(self.get_temp_dir(), "sharded")
+
+ # Build a graph with 2 parameter nodes on different devices.
+ with tf.Session(
+ target="",
+ config=tf.ConfigProto(device_count={"CPU": 2})) as sess:
+ with sess.graph.device("/cpu:0"):
+ v0 = tf.Variable(10, name="v0")
+ with sess.graph.device("/cpu:1"):
+ v1 = tf.Variable(20, name="v1")
+ save = tf.train.Saver({"v0": v0, "v1": v1}, sharded=True)
+ tf.initialize_all_variables().run()
+ val = save.save(sess, save_path)
+ self.assertEqual(save_path + "-?????-of-00002", val)
+
+ # Restore a different "v0" from shard 0 of the saved files.
+ with tf.Session(
+ target="",
+ config=tf.ConfigProto(device_count={"CPU": 2})) as sess:
+ with sess.graph.device("/cpu:0"):
+ v0 = tf.Variable(111, name="v0")
+ save = tf.train.Saver({"v0": v0}, sharded=True)
+ tf.initialize_all_variables().run()
+ self.assertEqual(111, v0.eval())
+ save.restore(sess, save_path + "-00000-of-00002")
+ self.assertEqual(10, v0.eval())
+
+ # Restore a different "v1" from shard 1 of the saved files.
+ with tf.Session(
+ target="",
+ config=tf.ConfigProto(device_count={"CPU": 2})) as sess:
+ with sess.graph.device("/cpu:0"):
+ v1 = tf.Variable(222)
+ save = tf.train.Saver({"v1": v1}, sharded=True)
+ tf.initialize_all_variables().run()
+ self.assertEqual(222, v1.eval())
+ save.restore(sess, save_path + "-00001-of-00002")
+ self.assertEqual(20, v1.eval())
+
+ # Now try a restore with the sharded filename.
+ with tf.Session(
+ target="",
+ config=tf.ConfigProto(device_count={"CPU": 2})) as sess:
+ with sess.graph.device("/cpu:0"):
+ v0 = tf.Variable(111, name="v0")
+ with sess.graph.device("/cpu:1"):
+ v1 = tf.Variable(222, name="v1")
+ save = tf.train.Saver({"v0": v0, "v1": v1}, sharded=True)
+ tf.initialize_all_variables().run()
+ self.assertEqual(111, v0.eval())
+ self.assertEqual(222, v1.eval())
+ save_path = os.path.join(self.get_temp_dir(), "sharded")
+ save.restore(sess, save_path + "-?????-of-?????")
+ self.assertEqual(10, v0.eval())
+ self.assertEqual(20, v1.eval())
+
+ def testSaverDef(self):
+ """Checks that as_saver_def() reports sharded=True for a sharded Saver."""
+ with self.test_session():
+ v0 = tf.Variable(123, name="v0")
+ save = tf.train.Saver({"v0": v0}, sharded=True)
+ sd = save.as_saver_def()
+ self.assertTrue(sd.sharded)
+
+
+class MaxToKeepTest(tf.test.TestCase):
+ """Tests that a Saver with max_to_keep=2 deletes older checkpoints."""
+
+ def testNonSharded(self):
+ """Exercises max_to_keep with three Savers sharing one save directory.
+
+ The second Saver is seeded with the first one's last_checkpoints; the third
+ is built from the same SaverDef but is given no checkpoint list.
+ """
+ save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_non_sharded")
+ try:
+ gfile.DeleteRecursively(save_dir)
+ except gfile.GOSError, _:
+ pass # Ignore
+ gfile.MakeDirs(save_dir)
+
+ with self.test_session() as sess:
+ v = tf.Variable(10.0, name="v")
+ save = tf.train.Saver({"v": v}, max_to_keep=2)
+ tf.initialize_all_variables().run()
+ self.assertEqual([], save.last_checkpoints)
+
+ s1 = save.save(sess, os.path.join(save_dir, "s1"))
+ self.assertEqual([s1], save.last_checkpoints)
+ self.assertTrue(gfile.Exists(s1))
+
+ s2 = save.save(sess, os.path.join(save_dir, "s2"))
+ self.assertEqual([s1, s2], save.last_checkpoints)
+ self.assertTrue(gfile.Exists(s1))
+ self.assertTrue(gfile.Exists(s2))
+
+ s3 = save.save(sess, os.path.join(save_dir, "s3"))
+ self.assertEqual([s2, s3], save.last_checkpoints)
+ self.assertFalse(gfile.Exists(s1))
+ self.assertTrue(gfile.Exists(s2))
+ self.assertTrue(gfile.Exists(s3))
+
+ # Create a second helper, identical to the first.
+ save2 = tf.train.Saver(saver_def=save.as_saver_def())
+ save2.set_last_checkpoints(save.last_checkpoints)
+
+ # Create a third helper, with the same configuration but no knowledge of
+ # previous checkpoints.
+ save3 = tf.train.Saver(saver_def=save.as_saver_def())
+
+ # Exercise the first helper.
+
+ # Adding s2 again (old s2 is removed first, then new s2 appended)
+ s2 = save.save(sess, os.path.join(save_dir, "s2"))
+ self.assertEqual([s3, s2], save.last_checkpoints)
+ self.assertFalse(gfile.Exists(s1))
+ self.assertTrue(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s2))
+
+ # Adding s1 (s3 should now be deleted as oldest in list)
+ s1 = save.save(sess, os.path.join(save_dir, "s1"))
+ self.assertEqual([s2, s1], save.last_checkpoints)
+ self.assertFalse(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s2))
+ self.assertTrue(gfile.Exists(s1))
+
+ # Exercise the second helper.
+
+ # Adding s2 again (old s2 is removed first, then new s2 appended)
+ s2 = save2.save(sess, os.path.join(save_dir, "s2"))
+ self.assertEqual([s3, s2], save2.last_checkpoints)
+ # Created by the first helper.
+ self.assertTrue(gfile.Exists(s1))
+ # Deleted by the first helper.
+ self.assertFalse(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s2))
+
+ # Adding s1 (s3 is now the oldest entry in the second helper's list; its
+ # file was already deleted by the first helper above)
+ s1 = save2.save(sess, os.path.join(save_dir, "s1"))
+ self.assertEqual([s2, s1], save2.last_checkpoints)
+ self.assertFalse(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s2))
+ self.assertTrue(gfile.Exists(s1))
+
+ # Exercise the third helper.
+
+ # Adding s2 again (but helper is unaware of previous s2)
+ s2 = save3.save(sess, os.path.join(save_dir, "s2"))
+ self.assertEqual([s2], save3.last_checkpoints)
+ # Created by the first helper.
+ self.assertTrue(gfile.Exists(s1))
+ # Deleted by the first helper.
+ self.assertFalse(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s2))
+
+ # Adding s1 (this helper is unaware of s3 and does not touch it; the s3
+ # file is absent because the first helper deleted it earlier)
+ s1 = save3.save(sess, os.path.join(save_dir, "s1"))
+ self.assertEqual([s2, s1], save3.last_checkpoints)
+ self.assertFalse(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s2))
+ self.assertTrue(gfile.Exists(s1))
+
+ def testSharded(self):
+ """Checks max_to_keep=2 for a sharded Saver by counting files per glob."""
+ save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_sharded")
+ try:
+ gfile.DeleteRecursively(save_dir)
+ except gfile.GOSError, _:
+ pass # Ignore
+ gfile.MakeDirs(save_dir)
+
+ with tf.Session(
+ target="",
+ config=tf.ConfigProto(device_count={"CPU": 2})) as sess:
+ with sess.graph.device("/cpu:0"):
+ v0 = tf.Variable(111, name="v0")
+ with sess.graph.device("/cpu:1"):
+ v1 = tf.Variable(222, name="v1")
+ save = tf.train.Saver({"v0": v0, "v1": v1}, sharded=True, max_to_keep=2)
+ tf.initialize_all_variables().run()
+ self.assertEqual([], save.last_checkpoints)
+
+ s1 = save.save(sess, os.path.join(save_dir, "s1"))
+ self.assertEqual([s1], save.last_checkpoints)
+ self.assertEquals(2, len(gfile.Glob(s1)))
+
+ s2 = save.save(sess, os.path.join(save_dir, "s2"))
+ self.assertEqual([s1, s2], save.last_checkpoints)
+ self.assertEquals(2, len(gfile.Glob(s1)))
+ self.assertEquals(2, len(gfile.Glob(s2)))
+
+ s3 = save.save(sess, os.path.join(save_dir, "s3"))
+ self.assertEqual([s2, s3], save.last_checkpoints)
+ self.assertEquals(0, len(gfile.Glob(s1)))
+ self.assertEquals(2, len(gfile.Glob(s2)))
+ self.assertEquals(2, len(gfile.Glob(s3)))
+
+
+class KeepCheckpointEveryNHoursTest(tf.test.TestCase):
+ """Tests the keep_checkpoint_every_n_hours option of tf.train.Saver."""
+
+ def testNonSharded(self):
+ """Checks that a sufficiently old checkpoint survives max_to_keep pruning.
+
+ This test depends on wall-clock timing: the keep interval is 0.7 seconds,
+ expressed in hours.
+ """
+ save_dir = os.path.join(self.get_temp_dir(),
+ "keep_checkpoint_every_n_hours")
+ try:
+ gfile.DeleteRecursively(save_dir)
+ except gfile.GOSError, _:
+ pass # Ignore
+ gfile.MakeDirs(save_dir)
+
+ with self.test_session() as sess:
+ v = tf.Variable([10.0], name="v")
+ # Run the initializer NOW to avoid the 0.5s overhead of the first Run()
+ # call, which throws the test timing off in fastbuild mode.
+ tf.initialize_all_variables().run()
+ # Create a saver that will keep the last 2 checkpoints plus one every 0.7
+ # seconds.
+ start_time = time.time()
+ save = tf.train.Saver({"v": v}, max_to_keep=2,
+ keep_checkpoint_every_n_hours=0.7 / 3600)
+ self.assertEqual([], save.last_checkpoints)
+
+ # Wait till 0.7 second have elapsed so s1 will be old enough to keep.
+ # Note: the sleep duration is 0.7s plus the time already elapsed since
+ # start_time, not the time remaining until start_time + 0.7s.
+ time.sleep((time.time() + 0.7) - start_time)
+ s1 = save.save(sess, os.path.join(save_dir, "s1"))
+ self.assertEqual([s1], save.last_checkpoints)
+
+ s2 = save.save(sess, os.path.join(save_dir, "s2"))
+ self.assertEqual([s1, s2], save.last_checkpoints)
+
+ # We now have 2 'last_checkpoints': [s1, s2]. The next call to Save(),
+ # would normally delete s1, because max_to_keep is 2. However, s1 is
+ # older than 0.7s so we must keep it.
+ s3 = save.save(sess, os.path.join(save_dir, "s3"))
+ self.assertEqual([s2, s3], save.last_checkpoints)
+
+ # s1 should still be here, we are Not checking now to reduce time
+ # variance in the test.
+
+ # We now have 2 'last_checkpoints': [s2, s3], and s1 on disk. The next
+ # call to Save(), will delete s2, because max_to_keep is 2, and because
+ # we already kept the old s1. s2 is very close in time to s1 so it gets
+ # deleted.
+ s4 = save.save(sess, os.path.join(save_dir, "s4"))
+ self.assertEqual([s3, s4], save.last_checkpoints)
+
+ # Check that s1 is still here, but s2 is gone.
+ self.assertTrue(gfile.Exists(s1))
+ self.assertFalse(gfile.Exists(s2))
+ self.assertTrue(gfile.Exists(s3))
+ self.assertTrue(gfile.Exists(s4))
+
+
+class SaveRestoreWithVariableNameMap(tf.test.TestCase):
+ """Tests saving and restoring under names that differ from the op names."""
+
+ def testNonReshape(self):
+ """Saves under "save_prefix/..." keys and restores through the same keys."""
+ save_path = os.path.join(self.get_temp_dir(), "basics")
+
+ with self.test_session() as sess:
+ # Build a graph with 2 parameter nodes, and Save and
+ # Restore nodes for them.
+ v0 = tf.Variable(10.0, name="v0")
+ v1 = tf.Variable(20.0, name="v1")
+ save = tf.train.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
+ tf.initialize_all_variables().run()
+
+ # Check that the parameter nodes have been initialized.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+ # Save the initialized values in the file at "save_path"
+ # Use a variable name map to set the saved tensor names
+ val = save.save(sess, save_path)
+ self.assertTrue(isinstance(val, basestring))
+ self.assertEqual(save_path, val)
+
+ # Verify that the original names are not in the Saved file
+ save = tf.train.Saver({"v0": v0, "v1": v1})
+ with self.assertRaisesOpError("not found in checkpoint"):
+ save.restore(sess, save_path)
+
+ # Verify that the mapped names are present in the Saved file and can be
+ # Restored using remapped names.
+ with self.test_session() as sess:
+ v0 = tf.Variable(-1.0, name="v0")
+ v1 = tf.Variable(-1.0, name="v1")
+
+ with self.assertRaisesOpError("uninitialized value v0"):
+ sess.run(v0)
+ with self.assertRaisesOpError("uninitialized value v1"):
+ sess.run(v1)
+
+ save = tf.train.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
+ save.restore(sess, save_path)
+
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+ # Add a prefix to the node names in the current graph and Restore using
+ # remapped names.
+ with self.test_session() as sess:
+ v0 = tf.Variable(-1.0, name="restore_prefix/v0")
+ v1 = tf.Variable(-1.0, name="restore_prefix/v1")
+
+ with self.assertRaisesOpError("uninitialized value restore_prefix/v0"):
+ sess.run(v0)
+ with self.assertRaisesOpError("uninitialized value restore_prefix/v1"):
+ sess.run(v1)
+
+ # Restore the saved values in the parameter nodes.
+ save = tf.train.Saver({"save_prefix/v0": v0, "save_prefix/v1": v1})
+ save.restore(sess, save_path)
+
+ # Check that the parameter nodes have been restored.
+ self.assertEqual(10.0, v0.eval())
+ self.assertEqual(20.0, v1.eval())
+
+
+# Run the tests through tf.test.main() when this file is executed as a script.
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/summary_io.py b/tensorflow/python/training/summary_io.py
new file mode 100644
index 0000000000..dd994c5311
--- /dev/null
+++ b/tensorflow/python/training/summary_io.py
@@ -0,0 +1,226 @@
+"""Reads Summaries from and writes Summaries to event files."""
+
+import os.path
+import Queue
+import threading
+import time
+
+from tensorflow.core.framework import summary_pb2
+from tensorflow.core.util import event_pb2
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.lib.io import tf_record
+from tensorflow.python.platform import gfile
+
+
+class SummaryWriter(object):
+  """Writes `Summary` protocol buffers to event files.
+
+  The `SummaryWriter` class provides a mechanism to create an event file in a
+  given directory and add summaries and events to it. The class updates the
+  file contents asynchronously. This allows a training program to call methods
+  to add data to the file directly from the training loop, without slowing down
+  training.
+
+  @@__init__
+
+  @@add_summary
+  @@add_event
+  @@add_graph
+
+  @@flush
+  @@close
+  """
+
+  def __init__(self, logdir, graph_def=None, max_queue=10, flush_secs=120):
+    """Creates a `SummaryWriter` and an event file.
+
+    On construction the summary writer creates a new event file in `logdir`.
+    This event file will contain `Event` protocol buffers constructed when you
+    call one of the following functions: `add_summary()`, `add_event()`, or
+    `add_graph()`.
+
+    If you pass a `graph_def` protocol buffer to the constructor it is added to
+    the event file. (This is equivalent to calling `add_graph()` later).
+
+    TensorBoard will pick the graph from the file and display it graphically so
+    you can interactively explore the graph you built. You will usually pass
+    the graph from the session in which you launched it:
+
+    ```python
+    ...create a graph...
+    # Launch the graph in a session.
+    sess = tf.Session()
+    # Create a summary writer, add the 'graph_def' to the event file.
+    writer = tf.train.SummaryWriter(<some-directory>, sess.graph_def)
+    ```
+
+    The other arguments to the constructor control the asynchronous writes to
+    the event file:
+
+    *  `flush_secs`: How often, in seconds, to flush the added summaries
+       and events to disk.
+    *  `max_queue`: Maximum number of summaries or events pending to be
+       written to disk before one of the 'add' calls block.
+
+    Args:
+      logdir: A string. Directory where event file will be written.
+      graph_def: A `GraphDef` protocol buffer.
+      max_queue: Integer. Size of the queue for pending events and summaries.
+      flush_secs: Number. How often, in seconds, to flush the
+        pending events and summaries to disk.
+    """
+    self._logdir = logdir
+    if not gfile.IsDirectory(self._logdir):
+      gfile.MakeDirs(self._logdir)
+    # Bounded queue: put() blocks once max_queue events are pending.
+    self._event_queue = Queue.Queue(max_queue)
+    self._ev_writer = pywrap_tensorflow.EventsWriter(
+        os.path.join(self._logdir, "events"))
+    # The worker thread drains the queue into the events writer.
+    self._worker = _EventLoggerThread(self._event_queue, self._ev_writer,
+                                      flush_secs)
+    self._worker.start()
+    if graph_def is not None:
+      self.add_graph(graph_def)
+
+  def add_summary(self, summary, global_step=None):
+    """Adds a `Summary` protocol buffer to the event file.
+
+    This method wraps the provided summary in an `Event` protocol buffer
+    and adds it to the event file.
+
+    You can pass the output of any summary op, as-is, to this function. You
+    can also pass a `Summary` protocol buffer that you manufacture with your
+    own data. This is commonly done to report evaluation results in event
+    files.
+
+    Args:
+      summary: A `Summary` protocol buffer, optionally serialized as a string.
+      global_step: Number. Optional global step value to record with the
+        summary.
+    """
+    if isinstance(summary, basestring):
+      # A serialized summary: parse it back into a `Summary` proto first.
+      summ = summary_pb2.Summary()
+      summ.ParseFromString(summary)
+      summary = summ
+    event = event_pb2.Event(wall_time=time.time(), summary=summary)
+    if global_step is not None:
+      event.step = long(global_step)
+    self.add_event(event)
+
+  def add_event(self, event):
+    """Adds an event to the event file.
+
+    The event is placed on the pending queue; the worker thread writes it.
+
+    Args:
+      event: An `Event` protocol buffer.
+    """
+    self._event_queue.put(event)
+
+  def add_graph(self, graph_def, global_step=None):
+    """Adds a `GraphDef` protocol buffer to the event file.
+
+    The graph described by the protocol buffer will be displayed by
+    TensorBoard. Most users pass a graph in the constructor instead.
+
+    Args:
+      graph_def: A `GraphDef` protocol buffer.
+      global_step: Number. Optional global step counter to record with the
+        graph.
+    """
+    event = event_pb2.Event(wall_time=time.time(), graph_def=graph_def)
+    if global_step is not None:
+      event.step = long(global_step)
+    # Go through add_event(), like add_summary() does, so every event is
+    # enqueued in one place.
+    self.add_event(event)
+
+  def flush(self):
+    """Flushes the event file to disk.
+
+    Call this method to make sure that all pending events have been written to
+    disk.
+    """
+    # Wait until the worker has processed every queued event, then flush.
+    self._event_queue.join()
+    self._ev_writer.Flush()
+
+  def close(self):
+    """Flushes the event file to disk and closes the file.
+
+    Call this method when you do not need the summary writer anymore.
+    """
+    self.flush()
+    self._ev_writer.Close()
+
+
+class _EventLoggerThread(threading.Thread):
+  """Daemon thread that writes queued events to an event writer."""
+
+  def __init__(self, queue, ev_writer, flush_secs):
+    """Creates an _EventLoggerThread.
+
+    Args:
+      queue: a Queue from which to dequeue events.
+      ev_writer: an event writer. Used to log brain events for
+        the visualizer.
+      flush_secs: How often, in seconds, to flush the
+        pending file to disk.
+    """
+    threading.Thread.__init__(self)
+    self.daemon = True
+    self._queue = queue
+    self._ev_writer = ev_writer
+    self._flush_secs = flush_secs
+    # A deadline of 0 is already in the past, so the first event written
+    # triggers a flush.
+    self._next_event_flush_time = 0
+
+  def run(self):
+    """Loops forever: dequeues an event, writes it, and flushes when due."""
+    while True:
+      pending = self._queue.get()
+      try:
+        self._ev_writer.WriteEvent(pending)
+        current_time = time.time()
+        flush_due = current_time > self._next_event_flush_time
+        if flush_due:
+          self._ev_writer.Flush()
+          # Schedule the next flush `flush_secs` seconds from now.
+          self._next_event_flush_time = current_time + self._flush_secs
+      finally:
+        # Always mark the item done so that Queue.join() callers are released
+        # even if writing raised.
+        self._queue.task_done()
+
+
+def summary_iterator(path):
+  """An iterator for reading `Event` protocol buffers from an event file.
+
+  You can use this function to read events written to an event file. It returns
+  a Python iterator that yields `Event` protocol buffers.
+
+  Example: Print the contents of an events file.
+
+  ```python
+  for e in tf.train.summary_iterator(path to events file):
+    print e
+  ```
+
+  Example: Print selected summary values.
+
+  ```python
+  # This example supposes that the events file contains summaries with a
+  # summary value tag 'loss'.  These could have been added by calling
+  # `add_summary()`, passing the output of a scalar summary op created
+  # with: `tf.scalar_summary(['loss'], loss_tensor)`.
+  for e in tf.train.summary_iterator(path to events file):
+    for v in e.summary.value:
+      if v.tag == 'loss':
+        print v.simple_value
+  ```
+
+  See the protocol buffer definitions of
+  [Event](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/util/event.proto)
+  and
+  [Summary](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto)
+  for more information about their attributes.
+
+  Args:
+    path: The path to an event file created by a `SummaryWriter`.
+
+  Yields:
+    `Event` protocol buffers.
+  """
+  # Each record in the file is one serialized `Event`.
+  for serialized_event in tf_record.tf_record_iterator(path):
+    yield event_pb2.Event.FromString(serialized_event)
diff --git a/tensorflow/python/training/summary_writer_test.py b/tensorflow/python/training/summary_writer_test.py
new file mode 100644
index 0000000000..2ec416f68f
--- /dev/null
+++ b/tensorflow/python/training/summary_writer_test.py
@@ -0,0 +1,151 @@
+ """Tests for SummaryWriter and summary_iterator (summary_io.py)."""
+import glob
+import os.path
+import shutil
+import time
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
class SummaryWriterTestCase(tf.test.TestCase):
  """Checks that events written by `tf.train.SummaryWriter` can be read back.

  Each test writes into its own sub-directory of the test temp dir, closes the
  writer, and then walks the resulting event file with
  `tf.train.summary_iterator`, checking the events one by one.
  """

  def _TestDir(self, test_name):
    """Returns the path of the directory used by the test `test_name`."""
    test_dir = os.path.join(self.get_temp_dir(), test_name)
    return test_dir

  def _CleanTestDir(self, test_name):
    """Deletes any existing directory for `test_name` and returns its path."""
    test_dir = self._TestDir(test_name)
    if os.path.exists(test_dir):
      shutil.rmtree(test_dir)
    return test_dir

  def _EventsReader(self, test_dir):
    """Returns an iterator over the events of the last event file in a dir.

    Args:
      test_dir: Directory that is expected to contain at least one file whose
        name starts with "event".

    Returns:
      The iterator produced by `tf.train.summary_iterator` for that file.
    """
    event_paths = glob.glob(os.path.join(test_dir, "event*"))
    # If the test runs multiple times in the same directory we can have
    # more than one matching event file.  We only want to read the last one.
    self.assertTrue(event_paths)
    # glob.glob() returns matches in arbitrary order, so sort before picking
    # the last entry.
    # NOTE(review): assumes event file names sort chronologically -- confirm
    # against the naming scheme used by the event writer.
    return tf.train.summary_iterator(sorted(event_paths)[-1])

  def _assertRecent(self, t):
    """Asserts that timestamp `t` is within 5 seconds of the current time."""
    self.assertTrue(abs(t - time.time()) < 5)

  def testBasics(self):
    """Writes two summaries and a graph, then reads all events back."""
    test_dir = self._CleanTestDir("basics")
    sw = tf.train.SummaryWriter(test_dir)
    sw.add_summary(
        tf.Summary(value=[tf.Summary.Value(tag="mee", simple_value=10.0)]),
        10)
    sw.add_summary(
        tf.Summary(value=[tf.Summary.Value(tag="boo", simple_value=20.0)]),
        20)
    with tf.Graph().as_default() as g:
      tf.constant([0], name="zero")
    gd = g.as_graph_def()
    sw.add_graph(gd, global_step=30)
    sw.close()
    rr = self._EventsReader(test_dir)

    # The first event should list the file_version.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual("brain.Event:1", ev.file_version)

    # The next event should have the value 'mee=10.0'.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(10, ev.step)
    self.assertProtoEquals("""
      value { tag: 'mee' simple_value: 10.0 }
      """, ev.summary)

    # The next event should have the value 'boo=20.0'.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(20, ev.step)
    self.assertProtoEquals("""
      value { tag: 'boo' simple_value: 20.0 }
      """, ev.summary)

    # The next event should have the graph_def.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(30, ev.step)
    self.assertProtoEquals(gd, ev.graph_def)

    # We should be done.
    self.assertRaises(StopIteration, lambda: next(rr))

  def testConstructWithGraph(self):
    """Passes a graph_def to the constructor and reads it back at step 0."""
    test_dir = self._CleanTestDir("basics_with_graph")
    with tf.Graph().as_default() as g:
      tf.constant([12], name="douze")
    gd = g.as_graph_def()
    sw = tf.train.SummaryWriter(test_dir, graph_def=gd)
    sw.close()
    rr = self._EventsReader(test_dir)

    # The first event should list the file_version.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual("brain.Event:1", ev.file_version)

    # The next event should have the graph.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(0, ev.step)
    self.assertProtoEquals(gd, ev.graph_def)

    # We should be done.
    self.assertRaises(StopIteration, lambda: next(rr))

  # Checks that values returned from session Run() calls are added correctly to
  # summaries.  These are numpy types so we need to check they fit in the
  # protocol buffers correctly.
  def testSummariesAndStopFromSessionRunCalls(self):
    """Uses evaluated int32/int64 tensors as the step of added summaries."""
    test_dir = self._CleanTestDir("global_step")
    sw = tf.train.SummaryWriter(test_dir)
    with self.test_session():
      i = tf.constant(1, dtype=tf.int32, shape=[])
      l = tf.constant(2, dtype=tf.int64, shape=[])
      # Test the summary can be passed serialized.
      summ = tf.Summary(value=[tf.Summary.Value(tag="i", simple_value=1.0)])
      sw.add_summary(summ.SerializeToString(), i.eval())
      sw.add_summary(
          tf.Summary(value=[tf.Summary.Value(tag="l", simple_value=2.0)]),
          l.eval())
      sw.close()

    rr = self._EventsReader(test_dir)

    # File_version.
    ev = next(rr)
    self.assertTrue(ev)
    self._assertRecent(ev.wall_time)
    self.assertEqual("brain.Event:1", ev.file_version)

    # Summary passed serialized.
    ev = next(rr)
    self.assertTrue(ev)
    self._assertRecent(ev.wall_time)
    self.assertEqual(1, ev.step)
    self.assertProtoEquals("""
      value { tag: 'i' simple_value: 1.0 }
      """, ev.summary)

    # Summary passed as SummaryObject.
    ev = next(rr)
    self.assertTrue(ev)
    self._assertRecent(ev.wall_time)
    self.assertEqual(2, ev.step)
    self.assertProtoEquals("""
      value { tag: 'l' simple_value: 2.0 }
      """, ev.summary)

    # We should be done.
    self.assertRaises(StopIteration, lambda: next(rr))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
new file mode 100644
index 0000000000..a400e9fa7d
--- /dev/null
+++ b/tensorflow/python/training/training.py
@@ -0,0 +1,138 @@
+# pylint: disable=wildcard-import,unused-import,g-bad-import-order,line-too-long
+ """This library provides a set of classes and functions that help train models.
+
+## Optimizers.
+
+The Optimizer base class provides methods to compute gradients for a loss and
+apply gradients to variables. A collection of subclasses implement classic
+optimization algorithms such as GradientDescent and Adagrad.
+
+You never instantiate the Optimizer class itself, but instead instantiate one
+of the subclasses.
+
+@@Optimizer
+
+@@GradientDescentOptimizer
+@@AdagradOptimizer
+@@MomentumOptimizer
+@@AdamOptimizer
+@@FtrlOptimizer
+@@RMSPropOptimizer
+
+## Gradient Computation.
+
+TensorFlow provides functions to compute the derivatives for a given
+TensorFlow computation graph, adding operations to the graph. The
+optimizer classes automatically compute derivatives on your graph, but
+creators of new Optimizers or expert users can call the lower-level
+functions below.
+
+@@gradients
+@@AggregationMethod
+
+@@stop_gradient
+
+
+## Gradient Clipping
+
+TensorFlow provides several operations that you can use to add clipping
+functions to your graph. You can use these functions to perform general data
+clipping, but they're particularly useful for handling exploding or vanishing
+gradients.
+
+@@clip_by_value
+@@clip_by_norm
+@@clip_by_average_norm
+@@clip_by_global_norm
+@@global_norm
+
+## Decaying the learning rate.
+@@exponential_decay
+
+## Moving Averages.
+
+ Some training algorithms, such as GradientDescent and Momentum, often benefit
+ from maintaining a moving average of variables during optimization. Using the
+ moving averages for evaluations often improves results significantly.
+
+@@ExponentialMovingAverage
+
+## Coordinator and QueueRunner.
+
+See [Threading and Queues](../../how_tos/threading_and_queues/index.md)
+for how to use threads and queues. For documentation on the Queue API,
+see [Queues](../../api_docs/python/io_ops.md#queues).
+
+@@Coordinator
+@@QueueRunner
+@@add_queue_runner
+@@start_queue_runners
+
+## Summary Operations.
+
+The following ops output
+[`Summary`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/summary.proto)
+protocol buffers as serialized string tensors.
+
+You can fetch the output of a summary op in a session, and pass it to a
+[SummaryWriter](train.md#SummaryWriter) to append it to an event file. You can
+then use TensorBoard to visualize the contents of the event files. See
+[TensorBoard and Summaries](../../how_tos/summaries_and_tensorboard/index.md)
+for more details.
+
+@@scalar_summary
+@@image_summary
+@@histogram_summary
+@@zero_fraction
+
+@@merge_summary
+@@merge_all_summaries
+
+## Adding Summaries to Event Files.
+
+See [Summaries and
+TensorBoard](../../how_tos/summaries_and_tensorboard/index.md) for an
+overview of summaries, event files, and visualization in TensorBoard.
+
+@@SummaryWriter
+@@summary_iterator
+
+## Training utilities.
+
+@@global_step
+@@write_graph
+"""
+
+# Optimizers.
+from tensorflow.python.training.adagrad import AdagradOptimizer
+from tensorflow.python.training.adam import AdamOptimizer
+from tensorflow.python.training.ftrl import FtrlOptimizer
+from tensorflow.python.training.momentum import MomentumOptimizer
+from tensorflow.python.training.moving_averages import ExponentialMovingAverage
+from tensorflow.python.training.optimizer import Optimizer
+from tensorflow.python.training.rmsprop import RMSPropOptimizer
+from tensorflow.python.training.gradient_descent import GradientDescentOptimizer
+
+# Utility classes for training.
+from tensorflow.python.training.coordinator import Coordinator
+from tensorflow.python.training.queue_runner import *
+
+# For the module level doc.
+from tensorflow.python.training import input as _input
+from tensorflow.python.training.input import *
+
+from tensorflow.python.training.saver import get_checkpoint_state
+from tensorflow.python.training.saver import latest_checkpoint
+from tensorflow.python.training.saver import Saver
+from tensorflow.python.training.saver import update_checkpoint_state
+from tensorflow.python.training.summary_io import summary_iterator
+from tensorflow.python.training.summary_io import SummaryWriter
+from tensorflow.python.training.training_util import write_graph
+from tensorflow.python.training.training_util import global_step
+
+# Training data protos.
+from tensorflow.core.example.example_pb2 import *
+from tensorflow.core.example.feature_pb2 import *
+
+# Utility op. Open Source. TODO(mdevin): move to nn?
+from tensorflow.python.training.learning_rate_decay import exponential_decay
diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py
new file mode 100644
index 0000000000..410b23e04d
--- /dev/null
+++ b/tensorflow/python/training/training_ops.py
@@ -0,0 +1,115 @@
+"""Python wrappers for training ops."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.training import gen_training_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.training.gen_training_ops import *
+# pylint: enable=wildcard-import
+
+
+# Shape functions for fused training ops
+# --------------------------------------
+#
+# The fused training ops all have the same basic structure: they take
+# one or more variables with the same shape, and emit a reference to
+# the original variable (which has the same shape as the first
+# input). In addition, they take one or more scalar tensors containing
+# hyperparameters.
+#
+# The sparse ops take the gradients as a Python IndexedSlices, which
+# means that the indices are a vector of length N, and the gradient
+# values are a tensor whose size is the same as the original variable,
+# except for the 0th dimension, which has size N.
+
+
def _AssertInputIsScalar(op, index):
  """Raises ValueError if `op.inputs[index]` is not scalar."""
  input_shape = op.inputs[index].get_shape()
  input_shape.assert_is_compatible_with(tensor_shape.scalar())
+
+
@ops.RegisterShape("ApplyAdagrad")
def _ApplyAdagradShape(op):
  """Shape function for the ApplyAdagrad op.

  Merges the shapes of inputs 0 (var), 1 (accum) and 3 (grad), checks that
  input 2 (lr) is a scalar, and returns the merged shape as the only output.
  """
  merged_shape = op.inputs[0].get_shape()  # var
  merged_shape = op.inputs[1].get_shape().merge_with(merged_shape)  # accum
  _AssertInputIsScalar(op, 2)  # lr
  merged_shape = op.inputs[3].get_shape().merge_with(merged_shape)  # grad
  return [merged_shape]
+
+
@ops.RegisterShape("ApplyAdam")
def _ApplyAdamShape(op):
  """Shape function for the ApplyAdam op.

  Merges the shapes of inputs 0 (var), 1 (m), 2 (v) and 9 (grad), checks that
  inputs 3 through 8 are scalars, and returns the merged shape as the only
  output.
  """
  merged_shape = op.inputs[0].get_shape()  # var
  for state_index in (1, 2):  # m, v
    merged_shape = op.inputs[state_index].get_shape().merge_with(merged_shape)
  # In order: beta1_power, beta2_power, lr, beta1, beta2, epsilon.
  for scalar_index in (3, 4, 5, 6, 7, 8):
    _AssertInputIsScalar(op, scalar_index)
  merged_shape = op.inputs[9].get_shape().merge_with(merged_shape)  # grad
  return [merged_shape]
+
+
@ops.RegisterShape("ApplyMomentum")
def _ApplyMomentumShape(op):
  """Shape function for the ApplyMomentum op.

  Merges the shapes of inputs 0 (var), 1 (accum) and 3 (grad), checks that
  inputs 2 (lr) and 4 (momentum) are scalars, and returns the merged shape as
  the only output.
  """
  merged_shape = op.inputs[0].get_shape()  # var
  merged_shape = op.inputs[1].get_shape().merge_with(merged_shape)  # accum
  _AssertInputIsScalar(op, 2)  # lr
  merged_shape = op.inputs[3].get_shape().merge_with(merged_shape)  # grad
  _AssertInputIsScalar(op, 4)  # momentum
  return [merged_shape]
+
+
@ops.RegisterShape("ApplyRMSProp")
def _ApplyRMSPropShape(op):
  """Shape function for the ApplyRMSProp op.

  Merges the shapes of inputs 0 (var), 1 (ms), 2 (mom) and 7 (grad), checks
  that inputs 3 through 6 are scalars, and returns the merged shape as the
  only output.
  """
  merged_shape = op.inputs[0].get_shape()  # var
  for state_index in (1, 2):  # ms, mom
    merged_shape = op.inputs[state_index].get_shape().merge_with(merged_shape)
  # In order: lr, rho, momentum, epsilon.
  for scalar_index in (3, 4, 5, 6):
    _AssertInputIsScalar(op, scalar_index)
  merged_shape = op.inputs[7].get_shape().merge_with(merged_shape)  # grad
  return [merged_shape]
+
+
@ops.RegisterShape("ApplyGradientDescent")
def _ApplyGradientDescentShape(op):
  """Shape function for the ApplyGradientDescent op.

  Checks that input 1 (alpha) is a scalar and returns the shape of input 2
  (delta) merged with the shape of input 0 (var) as the only output.
  """
  variable_shape = op.inputs[0].get_shape()
  _AssertInputIsScalar(op, 1)  # alpha
  return [op.inputs[2].get_shape().merge_with(variable_shape)]
+
+
@ops.RegisterShape("SparseApplyAdagrad")
def _SparseApplyAdagradShape(op):
  """Shape function for the SparseApplyAdagrad op.

  Validates that input 3 (grad) matches the accumulator shape in every
  dimension but the 0th, and that input 4 (indices) is a vector whose length
  is grad's 0th dimension.  The output has the merged var/accum shape.
  """
  variable_shape = op.inputs[0].get_shape()
  accum_shape = op.inputs[1].get_shape().merge_with(variable_shape)
  _AssertInputIsScalar(op, 2)  # lr
  # grad may have any size in dimension 0; the rest must match accum.
  expected_grad_shape = tensor_shape.TensorShape([None]).concatenate(
      accum_shape[1:])
  grad_shape = op.inputs[3].get_shape().merge_with(expected_grad_shape)
  # Called only for validation: one index per row of grad.
  op.inputs[4].get_shape().merge_with(tensor_shape.vector(grad_shape[0]))
  return [accum_shape]
+
+
@ops.RegisterShape("SparseApplyMomentum")
def _SparseApplyMomentumShape(op):
  """Shape function for the SparseApplyMomentum op.

  Validates that input 3 (grad) matches the accumulator shape in every
  dimension but the 0th, that input 4 (indices) is a vector whose length is
  grad's 0th dimension, and that inputs 2 (lr) and 5 (momentum) are scalars.
  The output has the merged var/accum shape.
  """
  variable_shape = op.inputs[0].get_shape()
  accum_shape = op.inputs[1].get_shape().merge_with(variable_shape)
  _AssertInputIsScalar(op, 2)  # lr
  # grad may have any size in dimension 0; the rest must match accum.
  expected_grad_shape = tensor_shape.TensorShape([None]).concatenate(
      accum_shape[1:])
  grad_shape = op.inputs[3].get_shape().merge_with(expected_grad_shape)
  # Called only for validation: one index per row of grad.
  op.inputs[4].get_shape().merge_with(tensor_shape.vector(grad_shape[0]))
  _AssertInputIsScalar(op, 5)  # momentum
  return [accum_shape]
diff --git a/tensorflow/python/training/training_ops_test.py b/tensorflow/python/training/training_ops_test.py
new file mode 100644
index 0000000000..902b9b0d78
--- /dev/null
+++ b/tensorflow/python/training/training_ops_test.py
@@ -0,0 +1,159 @@
+ """Tests for tensorflow.python.training.training_ops."""
+
+import itertools
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.python.framework import types
+from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import training_ops
+
+
class TrainingOpsTest(TensorFlowTestCase):
  """Compares the fused training ops against NumPy reference computations."""

  def _toType(self, dtype):
    """Returns the entry of `types` that corresponds to a NumPy dtype."""
    numpy_to_tf = (
        (np.float32, types.float32),
        (np.float64, types.float64),
        (np.int32, types.int32),
        (np.int64, types.int64),
    )
    for numpy_type, tf_type in numpy_to_tf:
      if dtype == numpy_type:
        return tf_type
    assert False, (dtype)

  def _testTypes(self, x, alpha, delta, use_gpu=None):
    """Runs apply_gradient_descent and checks it yields x - alpha * delta."""
    self.setUp()
    expected = x - alpha * delta
    with self.test_session(use_gpu=use_gpu):
      var = variables.Variable(x)
      variables.initialize_all_variables().run()
      self.assertAllEqual(x, var.eval())
      sgd_op = training_ops.apply_gradient_descent(var, alpha, delta)
      result = sgd_op.eval()
      self.assertShapeEqual(result, sgd_op)
      self.assertAllEqual(expected, result)

  def testApplyGradientDescent(self):
    """Covers float32/float64 with use_gpu False and True."""
    for dtype in [np.float32, np.float64]:
      for use_gpu in [False, True]:
        x = np.arange(100).astype(dtype)
        alpha = np.array(2.0).astype(dtype)
        delta = np.arange(100).astype(dtype)
        self._testTypes(x, alpha, delta, use_gpu)

  def _testTypesForAdagrad(self, x, y, lr, grad, use_gpu=None):
    """Runs apply_adagrad and checks the new variable and accumulator."""
    self.setUp()
    expected_accum = y + grad * grad
    with self.test_session(use_gpu=use_gpu):
      var = variables.Variable(x)
      accum = variables.Variable(y)
      variables.initialize_all_variables().run()

      self.assertAllEqual(x, var.eval())
      adagrad_op = training_ops.apply_adagrad(var, accum, lr, grad)
      result = adagrad_op.eval()
      self.assertShapeEqual(result, adagrad_op)
      self.assertAllClose(
          x - lr * grad * (y + grad * grad) ** (-0.5), result)
      self.assertAllEqual(expected_accum, accum.eval())

  def testApplyAdagrad(self):
    """Covers float32/float64 with use_gpu False and True."""
    for dtype in [np.float32, np.float64]:
      for use_gpu in [False, True]:
        x = np.arange(100).astype(dtype)
        y = np.arange(1, 101).astype(dtype)
        lr = np.array(2.0).astype(dtype)
        grad = np.arange(100).astype(dtype)
        self._testTypesForAdagrad(x, y, lr, grad, use_gpu)

  def _testTypesForSparseAdagrad(self, x, y, lr, grad, indices):
    """Runs sparse_apply_adagrad and checks every row named by `indices`."""
    self.setUp()
    with self.test_session(use_gpu=False):
      var = variables.Variable(x)
      accum = variables.Variable(y)
      variables.initialize_all_variables().run()

      self.assertAllEqual(x, var.eval())
      indices_t = constant_op.constant(indices, self._toType(indices.dtype))
      sparse_op = training_ops.sparse_apply_adagrad(
          var, accum, lr, grad, indices_t)
      result = sparse_op.eval()
      self.assertShapeEqual(result, sparse_op)

      # Row `position` of grad is applied to row `row` of var and accum.
      for (position, row) in enumerate(indices):
        row_grad = grad[position]
        self.assertAllClose(
            x[row] - lr * row_grad * (y[row] + row_grad * row_grad) ** (-0.5),
            var.eval()[row])
        self.assertAllEqual(y[row] + row_grad * row_grad, accum.eval()[row])

  def testSparseApplyAdagrad(self):
    """Covers float32/float64 values with int32/int64 indices."""
    for dtype in [np.float32, np.float64]:
      for index_type in [np.int32, np.int64]:
        x_val = [range(10), range(10, 20), range(20, 30)]
        y_val = [range(1, 11), range(11, 21), range(21, 31)]
        x = np.array(x_val).astype(dtype)
        y = np.array(y_val).astype(dtype)
        lr = np.array(2.0).astype(dtype)
        grad_val = [range(10), range(10)]
        grad = np.array(grad_val).astype(dtype)
        indices = np.array([0, 2]).astype(index_type)
        self._testTypesForSparseAdagrad(x, y, lr, grad, indices)

  def testApplyAdam(self):
    """Covers float32/float64 with use_gpu False and True."""
    for dtype in [np.float32, np.float64]:
      for use_gpu in [False, True]:
        var = np.arange(100).astype(dtype)
        m = np.arange(1, 101).astype(dtype)
        v = np.arange(101, 201).astype(dtype)
        grad = np.arange(100).astype(dtype)
        self._testTypesForAdam(var, m, v, grad, use_gpu)

  def _testTypesForAdam(self, var, m, v, grad, use_gpu):
    """Runs apply_adam for step t=1 and compares with `_adamUpdateNumpy`."""
    self.setUp()
    with self.test_session(use_gpu=use_gpu):
      var_t = variables.Variable(var)
      m_t = variables.Variable(m)
      v_t = variables.Variable(v)

      # Hyperparameters, as NumPy values of the variable's dtype.
      t = 1
      value_dtype = var.dtype
      beta1 = np.array(0.9, dtype=value_dtype)
      beta2 = np.array(0.999, dtype=value_dtype)
      beta1_power = beta1**t
      beta2_power = beta2**t
      lr = np.array(0.001, dtype=value_dtype)
      epsilon = np.array(1e-8, dtype=value_dtype)

      # The same hyperparameters as graph inputs.
      beta1_t = constant_op.constant(beta1, self._toType(value_dtype), [])
      beta2_t = constant_op.constant(beta2, self._toType(value_dtype), [])
      beta1_power_t = variables.Variable(beta1_power)
      beta2_power_t = variables.Variable(beta2_power)
      lr_t = constant_op.constant(lr, self._toType(value_dtype), [])
      epsilon_t = constant_op.constant(epsilon, self._toType(value_dtype), [])
      variables.initialize_all_variables().run()

      self.assertAllEqual(var, var_t.eval())
      expected_var, _, _ = self._adamUpdateNumpy(
          var, grad, t, m, v, lr, beta1, beta2, epsilon)
      adam_op = training_ops.apply_adam(
          var_t, m_t, v_t, beta1_power_t, beta2_power_t, lr_t,
          beta1_t, beta2_t, epsilon_t, grad)
      result = adam_op.eval()
      self.assertShapeEqual(result, adam_op)
      self.assertAllClose(expected_var, result)

  def _adamUpdateNumpy(self, param, g_t, t, m, v, alpha, beta1,
                       beta2, epsilon):
    """NumPy reference for one Adam step.

    Returns:
      A tuple (updated param, updated m, updated v).
    """
    step_size = alpha * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)

    new_m = beta1 * m + (1 - beta1) * g_t
    new_v = beta2 * v + (1 - beta2) * g_t * g_t

    new_param = param - step_size * new_m / (np.sqrt(new_v) + epsilon)
    return new_param, new_m, new_v
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
new file mode 100644
index 0000000000..14166e25c6
--- /dev/null
+++ b/tensorflow/python/training/training_util.py
@@ -0,0 +1,57 @@
+"""Utility functions for training."""
+import os.path
+
+from tensorflow.python.platform import gfile
+
+
def global_step(sess, global_step_tensor):
  """Fetches the current global step as a Python int.

  ```python
  # Creates a variable to hold the global_step.
  global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
  # Creates a session.
  sess = tf.Session()
  # Initializes the variable.
  sess.run(global_step_tensor.initializer)
  print('global_step: %d' % tf.train.global_step(sess, global_step_tensor))

  global_step: 10
  ```

  Args:
    sess: A TensorFlow `Session` object.
    global_step_tensor: `Tensor` or the `name` of the operation that contains
      the global step.

  Returns:
    The global step value, converted with `int()`.
  """
  fetched_step = sess.run(global_step_tensor)
  return int(fetched_step)
+
+
def write_graph(graph_def, logdir, name, as_text=True):
  """Writes a graph proto on disk.

  The graph is written as a binary proto unless as_text is `True`.

  ```python
  v = tf.Variable(0, name='my_variable')
  sess = tf.Session()
  tf.train.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')
  ```

  The directory that will contain the file is created before the file is
  opened.

  Args:
    graph_def: A `GraphDef` protocol buffer.
    logdir: Directory where to write the graph.
    name: Filename for the graph.
    as_text: If `True`, writes the graph as an ASCII proto.
  """
  path = os.path.join(logdir, name)
  # `name` may itself contain directories, so create the parent of the full
  # path rather than just `logdir`.
  gfile.MakeDirs(os.path.dirname(path))
  f = gfile.FastGFile(path, "w")
  try:
    if as_text:
      f.write(str(graph_def))
    else:
      f.write(graph_def.SerializeToString())
  finally:
    # Release the file handle even if serializing or writing fails.
    f.close()