about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-10-05 12:44:45 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-10-05 12:49:14 -0700
commit: ef838969b95de39353a3ba495c335cbb14a0c9b5 (patch)
tree: 800857a506c3d3695a7b3da2fd269a9fec85d93b /tensorflow/contrib
parent: 6919ab5787e6384d709adf051dc1ce99236b76bc (diff)
Brings V2 Optimizers into Keras w/ Keras signatures
PiperOrigin-RevId: 215950207
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- tensorflow/contrib/distribute/python/combinations.py 16
-rw-r--r-- tensorflow/contrib/distribute/python/minimize_loss_test.py 5
-rw-r--r-- tensorflow/contrib/optimizer_v2/BUILD 11
-rw-r--r-- tensorflow/contrib/optimizer_v2/adadelta.py 75
-rw-r--r-- tensorflow/contrib/optimizer_v2/adagrad.py 79
-rw-r--r-- tensorflow/contrib/optimizer_v2/adagrad_test.py 3
-rw-r--r-- tensorflow/contrib/optimizer_v2/adam.py 129
-rw-r--r-- tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py 68
-rw-r--r-- tensorflow/contrib/optimizer_v2/gradient_descent.py 40
-rw-r--r-- tensorflow/contrib/optimizer_v2/momentum.py 69
-rw-r--r-- tensorflow/contrib/optimizer_v2/optimizer_v2.py 1205
-rw-r--r-- tensorflow/contrib/optimizer_v2/rmsprop.py 154
12 files changed, 120 insertions, 1734 deletions
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index cff4b0a463..63a163e76c 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -349,26 +349,26 @@ mirrored_strategy_with_two_gpus = NamedDistribution(
required_gpus=2)
-adam_optimizer_v1_fn = NamedObject(
- "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
gradient_descent_optimizer_v1_fn = NamedObject(
"GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
adagrad_optimizer_v1_fn = NamedObject(
"AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
+adam_optimizer_v1_fn = NamedObject("AdamV1",
+ lambda: adam.AdamOptimizer(0.001, epsilon=1))
rmsprop_optimizer_v1_fn = NamedObject(
"RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))
-optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn,
- adagrad_optimizer_v1_fn]
-adam_optimizer_v2_fn = NamedObject(
- "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1))
+optimizers_v1 = [gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn]
+
gradient_descent_optimizer_v2_fn = NamedObject(
"GradientDescentV2",
lambda: gradient_descent_v2.GradientDescentOptimizer(0.2))
adagrad_optimizer_v2_fn = NamedObject(
"AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001))
-optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn,
- adagrad_optimizer_v2_fn]
+adam_optimizer_v2_fn = NamedObject(
+ "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1))
+
+optimizers_v2 = [gradient_descent_optimizer_v2_fn, adagrad_optimizer_v2_fn]
graph_and_eager_modes = ["graph", "eager"]
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index ba147e7824..60e134055f 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -179,11 +179,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def get_expected_variables(optimizer_fn, num_parameter_devices):
variables_map = {
"GradientDescent": ["dense/kernel", "dense/bias"],
- "Adam": [
- "dense/kernel", "dense/bias", "beta1_power", "beta2_power",
- "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam",
- "dense/bias/Adam_1"
- ],
"Adagrad": [
"dense/kernel/Adagrad", "dense/kernel",
"dense/bias/Adagrad", "dense/bias"
diff --git a/tensorflow/contrib/optimizer_v2/BUILD b/tensorflow/contrib/optimizer_v2/BUILD
index 3ba3ee29ec..2cf445a85e 100644
--- a/tensorflow/contrib/optimizer_v2/BUILD
+++ b/tensorflow/contrib/optimizer_v2/BUILD
@@ -47,15 +47,8 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:framework",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
+ "//tensorflow/python:util",
+ "//tensorflow/python/keras:optimizer_v2",
],
)
diff --git a/tensorflow/contrib/optimizer_v2/adadelta.py b/tensorflow/contrib/optimizer_v2/adadelta.py
index b206f9f61b..9d73bddd1c 100644
--- a/tensorflow/contrib/optimizer_v2/adadelta.py
+++ b/tensorflow/contrib/optimizer_v2/adadelta.py
@@ -18,17 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import adadelta
+from tensorflow.python.util import deprecation
-class AdadeltaOptimizer(optimizer_v2.OptimizerV2):
+class AdadeltaOptimizer(adadelta.Adadelta):
"""Optimizer that implements the Adadelta algorithm.
See [M. D. Zeiler](http://arxiv.org/abs/1212.5701)
([pdf](http://arxiv.org/pdf/1212.5701v1.pdf))
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate=0.001, rho=0.95, epsilon=1e-8,
use_locking=False, name="Adadelta"):
"""Construct a new Adadelta optimizer.
@@ -48,66 +52,5 @@ class AdadeltaOptimizer(optimizer_v2.OptimizerV2):
name: Optional name prefix for the operations created when applying
gradients. Defaults to "Adadelta".
"""
- super(AdadeltaOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("rho", rho)
- self._set_hyper("epsilon", epsilon)
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- state.zeros_slot(v, "accum")
- state.zeros_slot(v, "accum_update")
-
- def _apply_dense(self, grad, var, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.apply_adadelta(
- var,
- accum,
- accum_update,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _resource_apply_dense(self, grad, var, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.resource_apply_adadelta(
- var.handle,
- accum.handle,
- accum_update.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _apply_sparse(self, grad, var, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.sparse_apply_adadelta(
- var,
- accum,
- accum_update,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.resource_sparse_apply_adadelta(
- var.handle,
- accum.handle,
- accum_update.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad,
- indices,
- use_locking=self._use_locking)
+ super(AdadeltaOptimizer, self).__init__(
+ learning_rate=learning_rate, rho=rho, epsilon=epsilon, name=name)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad.py b/tensorflow/contrib/optimizer_v2/adagrad.py
index dab1e02716..716361e29c 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad.py
@@ -18,15 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import adagrad
+from tensorflow.python.util import deprecation
-class AdagradOptimizer(optimizer_v2.OptimizerV2):
+class AdagradOptimizer(adagrad.Adagrad):
"""Optimizer that implements the Adagrad algorithm.
See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
@@ -34,6 +30,10 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
[intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate, initial_accumulator_value=0.1,
use_locking=False, name="Adagrad"):
"""Construct a new Adagrad optimizer.
@@ -54,64 +54,7 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
Raises:
ValueError: If the `initial_accumulator_value` is invalid.
"""
- if initial_accumulator_value <= 0.0:
- raise ValueError("initial_accumulator_value must be positive: %s" %
- initial_accumulator_value)
- super(AdagradOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
-
- self._initial_accumulator_value = initial_accumulator_value
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- dtype = v.dtype.base_dtype
- if v.get_shape().is_fully_defined():
- init = init_ops.constant_initializer(self._initial_accumulator_value,
- dtype=dtype)
- else:
- def init(v=v, dtype=dtype):
- # Use a Tensor instead of initializer if variable does not have
- # static shape.
- init_constant = gen_array_ops.fill(array_ops.shape(v),
- self._initial_accumulator_value)
- return math_ops.cast(init_constant, dtype)
- state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
- "accumulator")
-
- def _apply_dense(self, grad, var, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.apply_adagrad(
- var,
- acc,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _resource_apply_dense(self, grad, var, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.resource_apply_adagrad(
- var.handle,
- acc.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _apply_sparse(self, grad, var, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.sparse_apply_adagrad(
- var,
- acc,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.resource_sparse_apply_adagrad(
- var.handle,
- acc.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- indices,
- use_locking=self._use_locking)
+ super(AdagradOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ initial_accumulator_value=initial_accumulator_value,
+ name=name)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad_test.py b/tensorflow/contrib/optimizer_v2/adagrad_test.py
index debaaaeeba..320e41567f 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad_test.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad_test.py
@@ -68,9 +68,6 @@ class AdagradOptimizerTest(test.TestCase):
def testBasicResource(self):
self.doTestBasic(use_locking=False, use_resource=True)
- def testBasicLocked(self):
- self.doTestBasic(use_locking=True)
-
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py
index 04b1552b61..363e020757 100644
--- a/tensorflow/contrib/optimizer_v2/adam.py
+++ b/tensorflow/contrib/optimizer_v2/adam.py
@@ -18,22 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.util import deprecation
-class AdamOptimizer(optimizer_v2.OptimizerV2):
+class AdamOptimizer(adam.Adam):
"""Optimizer that implements the Adam algorithm.
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
use_locking=False, name="Adam"):
"""Construct a new Adam optimizer.
@@ -87,111 +86,9 @@ class AdamOptimizer(optimizer_v2.OptimizerV2):
name: Optional name for the operations created when applying gradients.
Defaults to "Adam".
"""
- super(AdamOptimizer, self).__init__(use_locking, name)
-
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("beta1", beta1)
- self._set_hyper("beta2", beta2)
- self._set_hyper("epsilon", epsilon)
-
- def _get_beta_accumulators(self, state=None):
- if state is None:
- state = self._get_per_graph_state()
- return (state.get_non_slot("beta1_power"),
- state.get_non_slot("beta2_power"))
-
- def _create_vars(self, var_list, state):
- # Non-slot variables end up on the same device(s).
- state.create_non_slot(initial_value=lambda: state.get_hyper("beta1"),
- name="beta1_power")
- state.create_non_slot(initial_value=lambda: state.get_hyper("beta2"),
- name="beta2_power")
-
- # Create slots for the first and second moments.
- for v in var_list:
- state.zeros_slot(v, "m")
- state.zeros_slot(v, "v")
-
- def _apply_dense(self, grad, var, state):
- m = state.get_slot(var, "m")
- v = state.get_slot(var, "v")
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- return training_ops.apply_adam(
- var, m, v,
- math_ops.cast(beta1_power, var.dtype.base_dtype),
- math_ops.cast(beta2_power, var.dtype.base_dtype),
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("beta1", var.dtype.base_dtype),
- state.get_hyper("beta2", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad, use_locking=self._use_locking).op
-
- def _resource_apply_dense(self, grad, var, state):
- m = state.get_slot(var, "m")
- v = state.get_slot(var, "v")
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- return training_ops.resource_apply_adam(
- var.handle, m.handle, v.handle,
- math_ops.cast(beta1_power, grad.dtype.base_dtype),
- math_ops.cast(beta2_power, grad.dtype.base_dtype),
- state.get_hyper("learning_rate", grad.dtype.base_dtype),
- state.get_hyper("beta1", grad.dtype.base_dtype),
- state.get_hyper("beta2", grad.dtype.base_dtype),
- state.get_hyper("epsilon", grad.dtype.base_dtype),
- grad, use_locking=self._use_locking)
-
- def _apply_sparse_shared(self, grad, var, indices, scatter_add, state):
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
- beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
- lr_t = state.get_hyper("learning_rate", var.dtype.base_dtype)
- beta1_t = state.get_hyper("beta1", var.dtype.base_dtype)
- beta2_t = state.get_hyper("beta2", var.dtype.base_dtype)
- epsilon_t = state.get_hyper("epsilon", var.dtype.base_dtype)
- lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
- # m_t = beta1 * m + (1 - beta1) * g_t
- m = state.get_slot(var, "m")
- m_scaled_g_values = grad * (1 - beta1_t)
- m_t = state_ops.assign(m, m * beta1_t,
- use_locking=self._use_locking)
- with ops.control_dependencies([m_t]):
- m_t = scatter_add(m, indices, m_scaled_g_values)
- # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
- v = state.get_slot(var, "v")
- v_scaled_g_values = (grad * grad) * (1 - beta2_t)
- v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
- with ops.control_dependencies([v_t]):
- v_t = scatter_add(v, indices, v_scaled_g_values)
- v_sqrt = math_ops.sqrt(v_t)
- var_update = state_ops.assign_sub(var,
- lr * m_t / (v_sqrt + epsilon_t),
- use_locking=self._use_locking)
- return control_flow_ops.group(*[var_update, m_t, v_t])
-
- def _apply_sparse(self, grad, var, state):
- return self._apply_sparse_shared(
- grad.values, var, grad.indices,
- lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
- x, i, v, use_locking=self._use_locking),
- state)
-
- def _resource_scatter_add(self, x, i, v):
- with ops.control_dependencies(
- [resource_variable_ops.resource_scatter_add(
- x.handle, i, v)]):
- return x.value()
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- return self._apply_sparse_shared(
- grad, var, indices, self._resource_scatter_add, state)
-
- def _finish(self, state):
- # Update the power accumulators.
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- update_beta1 = beta1_power.assign(
- beta1_power * state.get_hyper("beta1"),
- use_locking=self._use_locking)
- update_beta2 = beta2_power.assign(
- beta2_power * state.get_hyper("beta2"),
- use_locking=self._use_locking)
- return control_flow_ops.group(update_beta1, update_beta2)
+ super(AdamOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ beta_1=beta1,
+ beta_2=beta2,
+ epsilon=epsilon,
+ name=name)
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index e13b82d1d2..3c68ef995a 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -130,8 +130,8 @@ class CheckpointingTests(test.TestCase):
# non-Layer dependency of the model
"model/_non_layer/a_variable",
# The optimizer creates two non-slot variables
- "optimizer/beta1_power",
- "optimizer/beta2_power",
+ "optimizer/beta_1_power",
+ "optimizer/beta_2_power",
# Slot variables
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
@@ -161,21 +161,20 @@ class CheckpointingTests(test.TestCase):
"my_model/dense/kernel",
named_variables["model/_named_dense/kernel" + suffix].full_name)
self.assertEqual(
- "beta1_power",
- named_variables["optimizer/beta1_power" + suffix].full_name)
+ "beta_1_power",
+ named_variables["optimizer/beta_1_power" + suffix].full_name)
self.assertEqual(
- "beta2_power",
- named_variables["optimizer/beta2_power" + suffix].full_name)
+ "beta_2_power",
+ named_variables["optimizer/beta_2_power" + suffix].full_name)
# Spot check the generated protocol buffers.
self.assertEqual("optimizer",
serialized_graph.nodes[0].children[1].local_name)
optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[
1].node_id]
- self.assertEqual("beta1_power",
- optimizer_node.children[0].local_name)
- self.assertEqual("beta1_power",
- serialized_graph.nodes[optimizer_node.children[0].node_id]
- .attributes[0].full_name)
+ self.assertEqual("beta_1_power", optimizer_node.children[0].local_name)
+ self.assertEqual(
+ "beta_1_power", serialized_graph.nodes[
+ optimizer_node.children[0].node_id].attributes[0].full_name)
self.assertEqual(
"my_model/dense/kernel",
serialized_graph.nodes[optimizer_node.slot_variables[0]
@@ -241,9 +240,10 @@ class CheckpointingTests(test.TestCase):
on_create_model = MyModel()
on_create_optimizer = adam.AdamOptimizer(
0.001,
- # Preserve beta1_power and beta2_power when appying gradients so we can
- # test that they've been restored correctly.
- beta1=1.0, beta2=1.0)
+ # Preserve beta_1_power and beta_2_power when applying gradients
+ # so we can test that they've been restored correctly.
+ beta1=1.0,
+ beta2=1.0)
on_create_root = util.Checkpoint(
optimizer=on_create_optimizer, model=on_create_model)
# Deferred restoration
@@ -263,9 +263,9 @@ class CheckpointingTests(test.TestCase):
dummy_var = resource_variable_ops.ResourceVariable([1.])
on_create_optimizer.minimize(loss=dummy_var.read_value)
status.assert_consumed()
- beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
- self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
- self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power))
+ beta_1_power, beta_2_power = on_create_optimizer._get_beta_accumulators()
+ self.assertAllEqual(optimizer_variables[0], self.evaluate(beta_1_power))
+ self.assertAllEqual(optimizer_variables[1], self.evaluate(beta_2_power))
# TODO(allenl): Debug garbage created by this test in python3.
def testDeferredRestorationUsageEager(self):
@@ -477,7 +477,7 @@ class CheckpointingTests(test.TestCase):
no_slot_status.run_restore_ops()
self.assertEqual(12., self.evaluate(new_root.var))
new_root.optimizer = adam.AdamOptimizer(0.1)
- with self.assertRaisesRegexp(AssertionError, "beta1_power"):
+ with self.assertRaisesRegexp(AssertionError, "beta_1_power"):
slot_status.assert_consumed()
self.assertEqual(12., self.evaluate(new_root.var))
if context.executing_eagerly():
@@ -556,8 +556,8 @@ class CheckpointingTests(test.TestCase):
self.evaluate(first_variable.assign([1.]))
self.evaluate(optimizer.get_slot(
var=first_variable, name="m").assign([2.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(3.))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
# Save and load in a second graph
second_graph = ops.Graph()
@@ -571,29 +571,29 @@ class CheckpointingTests(test.TestCase):
self.evaluate(second_variable.assign([4.]))
self.evaluate(optimizer.get_slot(
var=second_variable, name="m").assign([5.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(6.))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(6.))
save_path = second_root_checkpointable.save(checkpoint_prefix)
self.evaluate(second_variable.assign([7.]))
self.evaluate(optimizer.get_slot(
var=second_variable, name="m").assign([8.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.assertAllEqual(6., self.evaluate(beta1_power))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
status = second_root_checkpointable.restore(save_path)
status.assert_consumed().run_restore_ops()
self.assertAllEqual([4.], self.evaluate(second_variable))
self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
var=second_variable, name="m")))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.assertAllEqual(6., self.evaluate(beta1_power))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
# Check that the first graph is unmolested
with first_graph.as_default(), first_session.as_default():
self.assertAllEqual([1.], self.evaluate(first_variable))
self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
var=first_variable, name="m")))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.assertAllEqual(3., self.evaluate(beta1_power))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
class TemplateTests(test.TestCase):
@@ -659,8 +659,8 @@ class CheckpointCompatibilityTests(test.TestCase):
self.evaluate(model._named_dense.bias.assign([1.]))
self.evaluate(optimizer.get_slot(
var=model._named_dense.bias, name="m").assign([2.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(3.))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
return root_checkpointable
def _set_sentinels(self, root_checkpointable):
@@ -669,8 +669,8 @@ class CheckpointCompatibilityTests(test.TestCase):
root_checkpointable.optimizer.get_slot(
var=root_checkpointable.model._named_dense.bias, name="m")
.assign([102.]))
- beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(103.))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(103.))
def _check_sentinels(self, root_checkpointable):
self.assertAllEqual(
@@ -678,8 +678,8 @@ class CheckpointCompatibilityTests(test.TestCase):
self.assertAllEqual([2.], self.evaluate(
root_checkpointable.optimizer.get_slot(
var=root_checkpointable.model._named_dense.bias, name="m")))
- beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
- self.assertAllEqual(3., self.evaluate(beta1_power))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
def _write_name_based_checkpoint(self):
checkpoint_directory = self.get_temp_dir()
diff --git a/tensorflow/contrib/optimizer_v2/gradient_descent.py b/tensorflow/contrib/optimizer_v2/gradient_descent.py
index 945c8de559..8bdf408217 100644
--- a/tensorflow/contrib/optimizer_v2/gradient_descent.py
+++ b/tensorflow/contrib/optimizer_v2/gradient_descent.py
@@ -18,15 +18,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.util import deprecation
-class GradientDescentOptimizer(optimizer_v2.OptimizerV2):
+class GradientDescentOptimizer(sgd.SGD):
"""Optimizer that implements the gradient descent algorithm."""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate, use_locking=False, name="GradientDescent"):
"""Construct a new gradient descent optimizer.
@@ -41,29 +43,5 @@ class GradientDescentOptimizer(optimizer_v2.OptimizerV2):
name: Optional name prefix for the operations created when applying
gradients. Defaults to "GradientDescent".
"""
- super(GradientDescentOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
-
- def _apply_dense(self, grad, var, state):
- return training_ops.apply_gradient_descent(
- var,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking).op
-
- def _resource_apply_dense(self, grad, handle, state):
- lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
- return training_ops.resource_apply_gradient_descent(
- handle.handle, lr, grad, use_locking=self._use_locking)
-
- def _resource_apply_sparse_duplicate_indices(
- self, grad, handle, indices, state):
- lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
- return resource_variable_ops.resource_scatter_add(
- handle.handle, indices, -grad * lr)
-
- def _apply_sparse_duplicate_indices(self, grad, var, state):
- delta = ops.IndexedSlices(
- grad.values * state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad.indices, grad.dense_shape)
- return var.scatter_sub(delta, use_locking=self._use_locking)
+ super(GradientDescentOptimizer, self).__init__(
+ learning_rate=learning_rate, name=name)
diff --git a/tensorflow/contrib/optimizer_v2/momentum.py b/tensorflow/contrib/optimizer_v2/momentum.py
index 0a5aadc2d1..0636f7e356 100644
--- a/tensorflow/contrib/optimizer_v2/momentum.py
+++ b/tensorflow/contrib/optimizer_v2/momentum.py
@@ -18,11 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.util import deprecation
-class MomentumOptimizer(optimizer_v2.OptimizerV2):
+class MomentumOptimizer(sgd.SGD):
"""Optimizer that implements the Momentum algorithm.
Computes (if `use_nesterov = False`):
@@ -39,6 +39,10 @@ class MomentumOptimizer(optimizer_v2.OptimizerV2):
when that part of the variable was used in the forward pass.
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate, momentum,
use_locking=False, name="Momentum", use_nesterov=False):
"""Construct a new Momentum optimizer.
@@ -68,57 +72,8 @@ class MomentumOptimizer(optimizer_v2.OptimizerV2):
optimizer functions.
@end_compatibility
"""
- super(MomentumOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("momentum", momentum)
- self._use_nesterov = use_nesterov
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- state.zeros_slot(v, "momentum")
-
- def _apply_dense(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.apply_momentum(
- var,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov).op
-
- def _resource_apply_dense(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.resource_apply_momentum(
- var.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov)
-
- def _apply_sparse(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.sparse_apply_momentum(
- var,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad.values,
- grad.indices,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov).op
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.resource_sparse_apply_momentum(
- var.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- indices,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov)
+ super(MomentumOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ momentum=momentum,
+ name=name,
+ nesterov=use_nesterov)
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 53e27c08c4..9c98dd93b4 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -20,462 +20,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import abc
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.util import deprecation
-from tensorflow.python.eager import backprop
-from tensorflow.python.eager import context
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gradients
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.training import distribute as distribute_lib
-from tensorflow.python.training import distribution_strategy_context
-from tensorflow.python.training import optimizer as optimizer_v1
-from tensorflow.python.training import slot_creator
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.util import nest
-
-class _OptimizableVariable(object):
- """Interface for abstracting over variables in the optimizers."""
-
- @abc.abstractmethod
- def target(self):
- """Returns the optimization target for this variable."""
- raise NotImplementedError("Calling an abstract method.")
-
- @abc.abstractmethod
- def update_op(self, optimizer, g, *args):
- """Returns the update ops for updating the variable."""
- raise NotImplementedError("Calling an abstract method.")
-
-
-class _RefVariableProcessor(_OptimizableVariable):
- """Processor for Variable."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v._ref() # pylint: disable=protected-access
-
- def update_op(self, optimizer, g, *args):
- if isinstance(g, ops.Tensor):
- update_op = optimizer._apply_dense(g, self._v, *args) # pylint: disable=protected-access
- if self._v.constraint is not None:
- with ops.control_dependencies([update_op]):
- return self._v.assign(self._v.constraint(self._v))
- else:
- return update_op
- else:
- assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
- "tensor nor IndexedSlices.")
- if self._v.constraint is not None:
- raise RuntimeError(
- "Cannot use a constraint function on a sparse variable.")
- # pylint: disable=protected-access
- return optimizer._apply_sparse_duplicate_indices(g, self._v, *args)
-
-
-class _DenseReadResourceVariableProcessor(_OptimizableVariable):
- """Processor for dense ResourceVariables."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g, *args):
- # pylint: disable=protected-access
- update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0], *args)
- if self._v.constraint is not None:
- with ops.control_dependencies([update_op]):
- return self._v.assign(self._v.constraint(self._v))
- else:
- return update_op
-
-
-class _DenseResourceVariableProcessor(_OptimizableVariable):
- """Processor for dense ResourceVariables."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g, *args):
- # pylint: disable=protected-access
- if isinstance(g, ops.IndexedSlices):
- if self._v.constraint is not None:
- raise RuntimeError(
- "Cannot use a constraint function on a sparse variable.")
- return optimizer._resource_apply_sparse_duplicate_indices(
- g.values, self._v, g.indices, *args)
- update_op = optimizer._resource_apply_dense(g, self._v, *args)
- if self._v.constraint is not None:
- with ops.control_dependencies([update_op]):
- return self._v.assign(self._v.constraint(self._v))
- else:
- return update_op
-
-
-class _TensorProcessor(_OptimizableVariable):
- """Processor for ordinary Tensors.
-
- Even though a Tensor can't really be updated, sometimes it is useful to
- compute the gradients with respect to a Tensor using the optimizer. Updating
- the Tensor is, of course, unsupported.
- """
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g, *args):
- raise NotImplementedError("Trying to update a Tensor ", self._v)
-
-
-def _get_processor(v):
- """The processor of v."""
- if context.executing_eagerly():
- if isinstance(v, ops.Tensor):
- return _TensorProcessor(v)
- else:
- return _DenseResourceVariableProcessor(v)
- if v.op.type == "VarHandleOp":
- return _DenseResourceVariableProcessor(v)
- if isinstance(v, variables.Variable):
- return _RefVariableProcessor(v)
- if isinstance(v, ops.Tensor):
- return _TensorProcessor(v)
- raise NotImplementedError("Trying to optimize unsupported type ", v)
-
-
-def _var_key_v2(var):
- """Key for representing a primary variable, for looking up slots."""
- # pylint: disable=protected-access
- if hasattr(var, "_distributed_container"):
- distributed_container = var._distributed_container()
- assert distributed_container is not None
- if context.executing_eagerly():
- return distributed_container._unique_id
- return distributed_container._shared_name
- if context.executing_eagerly():
- return var._unique_id
- return var.op.name
-
-
-def _resolve(value, name):
- if callable(value):
- value = value()
- return ops.convert_to_tensor(value, name=name)
-
-
-def _is_dynamic(value):
- """Returns true if __init__ arg `value` should be re-evaluated each step."""
- if callable(value): return True
- # Don't need to do anything special in graph mode, since dynamic values
- # will propagate correctly automatically.
- # TODO(josh11b): Add per-device caching across steps using variables for
- # truly static values once we add distributed support.
- if context.executing_eagerly() and isinstance(
- value, resource_variable_ops.ResourceVariable):
- return True
- return False
-
-
-class _OptimizerV2State(object):
- """Holds per-graph and per-step optimizer state.
-
- Use _init_with_static_hyper() to create the state for a graph, and then
- _copy_with_dynamic_hyper() to convert that to state for a particular step.
- The difference between the two is that the former only has hyper
- parameter values that are static and the latter also has values that
- can change every step (according to _is_dynamic()).
- """
-
- def __init__(self, op_name):
- self._op_name = op_name
-
- def _init_with_static_hyper(self, hyper):
- """Initialize a fresh state object from hyper dict."""
- # self._hyper contains a dict from name to a dict with the Tensor values.
- # This dict starts with a single item with key "None" with the hyper
- # parameter value converted to a Tensor. Other items have dtype keys
- # with that Tensor cast to that dtype.
- with ops.init_scope():
- self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
- for name, (dynamic, value) in sorted(hyper.items())
- if not dynamic}
- self._slots = {}
- self._non_slot_dict = {}
- # Extra state to help Optimizers implement Checkpointable. Holds information
- # about variables which will be restored as soon as they're created.
- self._deferred_dependencies = {} # Non-slot variables
- self._deferred_slot_restorations = {} # Slot variables
-
- def _copy_with_dynamic_hyper(self, hyper, distribution, non_slot_devices):
- """Create a new state object for a particular step."""
- ret = _OptimizerV2State(self._op_name)
- # pylint: disable=protected-access
- ret._slots = self._slots
- ret._non_slot_dict = self._non_slot_dict
- ret._deferred_dependencies = self._deferred_dependencies
- ret._deferred_slot_restorations = self._deferred_slot_restorations
- ret._hyper = {name: {None: _resolve(value, name)}
- for name, (dynamic, value) in sorted(hyper.items())
- if dynamic}
- ret._hyper.update(self._hyper)
- ret._non_slot_devices = non_slot_devices
- ret._distribution = distribution
- return ret
-
- def _variables(self):
- """Returns a list of all variables held by self."""
- optimizer_variables = list(self._non_slot_dict.values())
- for variable_dict in self._slots.values():
- for slot_for_variable in variable_dict.values():
- optimizer_variables.append(slot_for_variable)
- # Sort variables by name so that the return is deterministic.
- return sorted(optimizer_variables, key=lambda v: v.name)
-
- def _slot_dict(self, slot_name):
- """Returns a dict for caching slots created under the given name.
-
- Args:
- slot_name: Name for the slot.
-
- Returns:
- A dict that maps primary `Variable` objects to the slot created
- for that variable, under the given slot name.
- """
- named_slots = self._slots.get(slot_name, None)
- if named_slots is None:
- named_slots = {}
- self._slots[slot_name] = named_slots
- return named_slots
-
- def create_slot(self, var, val, slot_name, optional_op_name=None):
- """Find or create a slot for a variable.
-
- Args:
- var: A `Variable` object.
- val: A `Tensor`. The initial value of the slot.
- slot_name: Name for the slot.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
-
- Returns:
- A `Variable` object.
- """
- named_slots = self._slot_dict(slot_name)
- var_key = _var_key_v2(var)
- if var_key not in named_slots:
- new_slot_variable = slot_creator.create_slot(
- var, val, optional_op_name or self._op_name)
- self._restore_slot_variable(
- slot_name=slot_name, variable=var,
- slot_variable=new_slot_variable)
- named_slots[var_key] = new_slot_variable
- return named_slots[var_key]
-
- def create_slot_with_initializer(self, var, initializer, shape, dtype,
- slot_name, optional_op_name=None):
- """Find or create a slot for a variable, using an Initializer.
-
- Args:
- var: A `Variable` object.
- initializer: An `Initializer`. The initial value of the slot.
- shape: Shape of the initial value of the slot.
- dtype: Type of the value of the slot.
- slot_name: Name for the slot.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
-
- Returns:
- A `Variable` object.
- """
- named_slots = self._slot_dict(slot_name)
- var_key = _var_key_v2(var)
- if var_key not in named_slots:
- new_slot_variable = slot_creator.create_slot_with_initializer(
- var, initializer, shape, dtype, optional_op_name or self._op_name)
- self._restore_slot_variable(
- slot_name=slot_name, variable=var,
- slot_variable=new_slot_variable)
- named_slots[var_key] = new_slot_variable
- return named_slots[var_key]
-
- def zeros_slot(self, var, slot_name, optional_op_name=None):
- """Find or create a slot initialized with 0.0.
-
- Args:
- var: A `Variable` object.
- slot_name: Name for the slot.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
-
- Returns:
- A `Variable` object.
- """
- named_slots = self._slot_dict(slot_name)
- var_key = _var_key_v2(var)
- if var_key not in named_slots:
- new_slot_variable = slot_creator.create_zeros_slot(
- var, optional_op_name or self._op_name)
- self._restore_slot_variable(
- slot_name=slot_name, variable=var,
- slot_variable=new_slot_variable)
- named_slots[var_key] = new_slot_variable
- return named_slots[var_key]
-
- def _create_or_restore_slot_variable(
- self, slot_variable_position, slot_name, variable,
- optional_op_name=None):
- """Restore a slot variable's value, possibly creating it.
-
- Called when a variable which has an associated slot variable is created or
- restored. When executing eagerly, we create the slot variable with a
- restoring initializer.
-
- No new variables are created when graph building. Instead,
- _restore_slot_variable catches these after normal creation and adds restore
- ops to the graph. This method is nonetheless important when graph building
- for the case when a slot variable has already been created but `variable`
- has just been added to a dependency graph (causing us to realize that the
- slot variable needs to be restored).
-
- Args:
- slot_variable_position: A `checkpointable._CheckpointPosition` object
- indicating the slot variable `Checkpointable` object to be restored.
- slot_name: The name of this `Optimizer`'s slot to restore into.
- variable: The variable object this slot is being created for.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
- """
- slot_variable = self.get_slot(var=variable, name=slot_name)
- if (slot_variable is None and context.executing_eagerly() and
- slot_variable_position.is_simple_variable()
- # Defer slot variable creation if there is an active variable creator
- # scope. Generally we'd like to eagerly create/restore slot variables
- # when possible, but this may mean that scopes intended to catch
- # `variable` also catch its eagerly created slot variable
- # unintentionally (specifically make_template would add a dependency on
- # a slot variable if not for this case). Deferring is mostly harmless
- # (aside from double initialization), and makes variable creator scopes
- # behave the same way they do when graph building.
- and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
- initializer = checkpointable.CheckpointInitialValue(
- checkpoint_position=slot_variable_position)
- slot_variable = self.create_slot(
- var=variable,
- val=initializer,
- slot_name=slot_name,
- optional_op_name=optional_op_name)
- # Optimizers do not have unconditional dependencies on their slot
- # variables (nor do any other objects). They are only saved if the
- # variables they were created for are also saved.
- if slot_variable is not None:
- # If we've either made this slot variable, or if we've pulled out an
- # existing slot variable, we should restore it.
- slot_variable_position.restore(slot_variable)
- else:
- # We didn't make the slot variable. Defer restoring until it gets created
- # normally. We keep a list rather than the one with the highest restore
- # UID in case slot variables have their own dependencies, in which case
- # those could differ between restores.
- variable_key = _var_key_v2(variable)
- self._deferred_slot_restorations.setdefault(
- slot_name, {}).setdefault(variable_key, []).append(
- slot_variable_position)
-
- def get_slot(self, var, name):
- """Return a slot named `name` created for `var` by the Optimizer.
-
- Some `Optimizer` subclasses use additional variables. For example
- `Momentum` and `Adagrad` use variables to accumulate updates. This method
- gives access to these `Variable` objects if for some reason you need them.
-
- Use `get_slot_names()` to get the list of slot names created by the
- `Optimizer`.
-
- Args:
- var: A variable passed to `minimize()` or `apply_gradients()`.
- name: A string.
-
- Returns:
- The `Variable` for the slot if it was created, `None` otherwise.
- """
- named_slots = self._slots.get(name, None)
- if not named_slots:
- return None
- return named_slots.get(_var_key_v2(var), None)
-
- def get_slot_names(self):
- """Return a list of the names of slots created by the `Optimizer`.
-
- See `get_slot()`.
-
- Returns:
- A list of strings.
- """
- return sorted(self._slots.keys())
-
- def create_non_slot(self, initial_value, name, colocate_with=None):
- """Add an extra variable, not associated with a slot."""
- v = self._non_slot_dict.get(name, None)
- if v is None:
- if colocate_with is None: colocate_with = self._non_slot_devices
- with self._distribution.colocate_vars_with(colocate_with):
- # TODO(josh11b): Use get_variable() except for the legacy Adam use case.
- v = variable_scope.variable(initial_value, name=name, trainable=False)
- self._non_slot_dict[name] = v
- deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
- for checkpoint_position in sorted(
- deferred_dependencies_list,
- key=lambda restore: restore.checkpoint.restore_uid,
- reverse=True):
- checkpoint_position.restore(v)
- return v
-
- def _restore_slot_variable(self, slot_name, variable, slot_variable):
- """Restore a newly created slot variable's value."""
- variable_key = _var_key_v2(variable)
- deferred_restorations = self._deferred_slot_restorations.get(
- slot_name, {}).pop(variable_key, [])
- # Iterate over restores, highest restore UID first to minimize the number
- # of assignments.
- deferred_restorations.sort(key=lambda position: position.restore_uid,
- reverse=True)
- for checkpoint_position in deferred_restorations:
- checkpoint_position.restore(slot_variable)
-
- def get_non_slot(self, name):
- """Returns the non-slot variable identified by `name`."""
- return self._non_slot_dict.get(name, None)
-
- def get_hyper(self, name, dtype=None):
- """Returns the `name` hyper parameter, optionally cast to `dtype`."""
- dtype_dict = self._hyper[name]
- # Do we have the value cast to dtype already cached? This should always
- # succeed when dtype is None.
- if dtype in dtype_dict:
- return dtype_dict[dtype]
- # Not cached, cast to dtype and save the result in the cache.
- result = math_ops.cast(dtype_dict[None], dtype)
- dtype_dict[dtype] = result
- return result
-
-
-class OptimizerV2(optimizer_v1.Optimizer):
+class OptimizerV2(optimizer_v2.OptimizerV2):
"""Updated base class for optimizers.
This class defines the API to add Ops to train a model. You never use this
@@ -586,6 +135,10 @@ class OptimizerV2(optimizer_v1.Optimizer):
GATE_OP = 1
GATE_GRAPH = 2
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, use_locking, name):
"""Create a new Optimizer.
@@ -606,746 +159,4 @@ class OptimizerV2(optimizer_v1.Optimizer):
RuntimeError: If _create_slots has been overridden instead of
_create_vars.
"""
- # Note: We intentionally don't call parent __init__.
-
- # Optimizer._create_slots was replaced by _create_vars in OptimizerV2.
- if (self.__class__._create_slots.__code__ is not # pylint: disable=protected-access
- OptimizerV2._create_slots.__code__):
- raise RuntimeError("Override _create_vars instead of _create_slots when "
- "descending from OptimizerV2 (class %s)" %
- self.__class__.__name__)
- if not name:
- raise ValueError("Must specify the optimizer name")
-
- self._use_locking = use_locking
- self._name = name
- # Map from graph_key to state for that graph. We use the graph_key
- # since it works in both eager and graph mode, and gives the outer
- # graph inside functions.
- tower_context = distribution_strategy_context.get_tower_context()
- if tower_context is None:
- # In a cross-tower context for a DistributionStrategy, which means
- # only one Optimizer will be created, not one per tower.
- self._per_graph_state = {}
- else:
- # We use get_tower_context().merge_call() to get a single dict
- # shared across all model replicas when running with a
- # DistributionStrategy.
- self._per_graph_state = tower_context.merge_call(lambda _: {})
-
- # Hyper parameters, and whether they should be re-evaluated every step.
- self._hyper = {}
-
- def _set_hyper(self, name, value):
- self._hyper[name] = (_is_dynamic(value), value)
-
- def minimize(self, loss, global_step=None, var_list=None,
- gate_gradients=GATE_OP, aggregation_method=None,
- colocate_gradients_with_ops=False, name=None,
- grad_loss=None, stop_gradients=None,
- scale_loss_by_num_towers=None):
- """Add operations to minimize `loss` by updating `var_list`.
-
- This method simply combines calls `compute_gradients()` and
- `apply_gradients()`. If you want to process the gradient before applying
- them call `compute_gradients()` and `apply_gradients()` explicitly instead
- of using this function.
-
- Args:
- loss: A `Tensor` containing the value to minimize.
- global_step: Optional `Variable` to increment by one after the
- variables have been updated.
- var_list: Optional list or tuple of `Variable` objects to update to
- minimize `loss`. Defaults to the list of variables collected in
- the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
- gate_gradients: How to gate the computation of gradients. Can be
- `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
- aggregation_method: Specifies the method used to combine gradient terms.
- Valid values are defined in the class `AggregationMethod`.
- colocate_gradients_with_ops: If True, try colocating gradients with
- the corresponding op.
- name: Optional name for the returned operation.
- grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
- stop_gradients: Optional. A Tensor or list of tensors not to differentiate
- through.
- scale_loss_by_num_towers: Optional boolean. If true, scale the loss
- down by the number of towers. By default, auto-detects whether this
- is needed.
-
- Returns:
- An Operation that updates the variables in `var_list`. If `global_step`
- was not `None`, that operation also increments `global_step`.
-
- Raises:
- ValueError: If some of the variables are not `Variable` objects.
-
- @compatibility(eager)
- When eager execution is enabled, `loss` should be a Python function that
- takes elements of `var_list` as arguments and computes the value to be
- minimized. If `var_list` is None, `loss` should take no arguments.
- Minimization (and gradient computation) is done with respect to the
- elements of `var_list` if not None, else with respect to any trainable
- variables created during the execution of the `loss` function.
- `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
- `grad_loss` are ignored when eager execution is enabled.
- @end_compatibility
- """
- grads_and_vars = self.compute_gradients(
- loss, var_list=var_list, gate_gradients=gate_gradients,
- aggregation_method=aggregation_method,
- colocate_gradients_with_ops=colocate_gradients_with_ops,
- grad_loss=grad_loss, stop_gradients=stop_gradients,
- scale_loss_by_num_towers=scale_loss_by_num_towers)
-
- vars_with_grad = [v for g, v in grads_and_vars if g is not None]
- if not vars_with_grad:
- raise ValueError(
- "No gradients provided for any variable, check your graph for ops"
- " that do not support gradients, between variables %s and loss %s." %
- ([str(v) for _, v in grads_and_vars], loss))
-
- return self.apply_gradients(grads_and_vars, global_step=global_step,
- name=name)
-
- def compute_gradients(self, loss, var_list=None,
- gate_gradients=GATE_OP,
- aggregation_method=None,
- colocate_gradients_with_ops=False,
- grad_loss=None, stop_gradients=None,
- scale_loss_by_num_towers=None):
- """Compute gradients of `loss` for the variables in `var_list`.
-
- This is the first part of `minimize()`. It returns a list
- of (gradient, variable) pairs where "gradient" is the gradient
- for "variable". Note that "gradient" can be a `Tensor`, an
- `IndexedSlices`, or `None` if there is no gradient for the
- given variable.
-
- Args:
- loss: A Tensor containing the value to minimize or a callable taking
- no arguments which returns the value to minimize. When eager execution
- is enabled it must be a callable.
- var_list: Optional list or tuple of `tf.Variable` to update to minimize
- `loss`. Defaults to the list of variables collected in the graph
- under the key `GraphKeys.TRAINABLE_VARIABLES`.
- gate_gradients: How to gate the computation of gradients. Can be
- `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
- aggregation_method: Specifies the method used to combine gradient terms.
- Valid values are defined in the class `AggregationMethod`.
- colocate_gradients_with_ops: If True, try colocating gradients with
- the corresponding op.
- grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
- stop_gradients: Optional. A Tensor or list of tensors not to differentiate
- through.
- scale_loss_by_num_towers: Optional boolean. If true, scale the loss
- down by the number of towers. By default, auto-detects whether this
- is needed.
-
- Returns:
- A list of (gradient, variable) pairs. Variable is always present, but
- gradient can be `None`.
-
- Raises:
- TypeError: If `var_list` contains anything else than `Variable` objects.
- ValueError: If some arguments are invalid.
- RuntimeError: If called with eager execution enabled and `loss` is
- not callable.
-
- @compatibility(eager)
- When eager execution is enabled, `gate_gradients`, `aggregation_method`,
- and `colocate_gradients_with_ops` are ignored.
- @end_compatibility
- """
- # TODO(josh11b): Test that we handle weight decay in a reasonable way.
- if callable(loss):
- with backprop.GradientTape() as tape:
- if var_list is not None:
- tape.watch(var_list)
- loss_value = loss()
-
- # Scale loss for number of towers (callable-loss case). In this case,
- # we have to be careful to call distribute_lib.get_loss_reduction()
- # *after* loss() is evaluated, so we know what loss reduction it uses.
- if scale_loss_by_num_towers is None:
- scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() ==
- variable_scope.VariableAggregation.MEAN)
- if scale_loss_by_num_towers:
- num_towers = distribution_strategy_context.get_distribution_strategy(
- ).num_towers
- if num_towers > 1:
- loss_value *= 1. / num_towers
-
- if var_list is None:
- var_list = tape.watched_variables()
- grads = tape.gradient(loss_value, var_list, grad_loss)
- return list(zip(grads, var_list))
- if context.executing_eagerly():
- raise RuntimeError(
- "`loss` passed to Optimizer.compute_gradients should "
- "be a function when eager execution is enabled.")
-
- # Scale loss for number of towers (non-callable-loss case).
- if scale_loss_by_num_towers is None:
- scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() ==
- variable_scope.VariableAggregation.MEAN)
- if scale_loss_by_num_towers:
- num_towers = distribution_strategy_context.get_distribution_strategy(
- ).num_towers
- if num_towers > 1:
- loss *= 1. / num_towers
-
- if gate_gradients not in [optimizer_v1.Optimizer.GATE_NONE,
- optimizer_v1.Optimizer.GATE_OP,
- optimizer_v1.Optimizer.GATE_GRAPH]:
- raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
- "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
- gate_gradients)
- self._assert_valid_dtypes([loss])
- if grad_loss is not None:
- self._assert_valid_dtypes([grad_loss])
- if var_list is None:
- var_list = (
- variables.trainable_variables() +
- ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
- else:
- var_list = nest.flatten(var_list)
- # pylint: disable=protected-access
- var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
- # pylint: enable=protected-access
- processors = [_get_processor(v) for v in var_list]
- if not var_list:
- raise ValueError("No variables to optimize.")
- var_refs = [p.target() for p in processors]
- grads = gradients.gradients(
- loss, var_refs, grad_ys=grad_loss,
- gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP),
- aggregation_method=aggregation_method,
- colocate_gradients_with_ops=colocate_gradients_with_ops,
- stop_gradients=stop_gradients)
- if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH:
- grads = control_flow_ops.tuple(grads)
- grads_and_vars = list(zip(grads, var_list))
- self._assert_valid_dtypes(
- [v for g, v in grads_and_vars
- if g is not None and v.dtype != dtypes.resource])
- return grads_and_vars
-
- def apply_gradients(self, grads_and_vars, global_step=None, name=None):
- """Apply gradients to variables.
-
- This is the second part of `minimize()`. It returns an `Operation` that
- applies gradients.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs as returned by
- `compute_gradients()`.
- global_step: Optional `Variable` to increment by one after the
- variables have been updated.
- name: Optional name for the returned operation. Default to the
- name passed to the `Optimizer` constructor.
-
- Returns:
- An `Operation` that applies the specified gradients. If `global_step`
- was not None, that operation also increments `global_step`.
-
- Raises:
- TypeError: If `grads_and_vars` is malformed.
- ValueError: If none of the variables have gradients.
- """
- # This is a default implementation of apply_gradients() that can be shared
- # by most optimizers. It relies on the subclass implementing the following
- # methods: _create_vars(), _prepare(), _apply_dense(), and _apply_sparse().
-
- # Filter out variables with gradients of `None`.
- grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works.
- if not grads_and_vars:
- raise ValueError("No variables provided.")
- filtered = tuple((g, v) for (g, v) in grads_and_vars if g is not None)
- if not filtered:
- raise ValueError("No gradients provided for any variable: %s." %
- ([str(v) for _, v in grads_and_vars],))
- return distribution_strategy_context.get_tower_context().merge_call(
- self._distributed_apply, filtered, global_step=global_step, name=name)
-
- def _get_or_create_state(self, var_list=None):
- """Either looks up or creates `_OptimizerV2State`.
-
- If any variables are available, they should be passed via the `var_list`
- argument, and these will be used to determine the graph to create/retrieve
- state for. Otherwise the returned state is for the current default graph.
-
- Args:
- var_list: A list of variables to extract a graph from.
-
- Returns:
- An `_OptimizerV2State` object.
- """
- # Determine the graph_key from the current graph.
- eager_execution = context.executing_eagerly()
- if eager_execution or var_list is None:
- graph = ops.get_default_graph()
- else:
- graph = ops._get_graph_from_inputs(var_list) # pylint: disable=protected-access
- assert graph is not None
- graph_key = graph._graph_key # pylint: disable=protected-access
-
- # Get the per graph state by looking up the graph_key.
- if graph_key in self._per_graph_state:
- per_graph_state = self._per_graph_state[graph_key]
- else:
- per_graph_state = _OptimizerV2State(self._name)
- per_graph_state._init_with_static_hyper(self._hyper) # pylint: disable=protected-access
- self._per_graph_state[graph_key] = per_graph_state
- return per_graph_state
-
- def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
- """`apply_gradients` for use with a `DistributionStrategy`."""
- reduced_grads = distribution.batch_reduce(
- variable_scope.VariableAggregation.SUM, grads_and_vars)
- var_list = [v for _, v in grads_and_vars]
- grads_and_vars = zip(reduced_grads, var_list)
-
- unwrapped_var_list = [x for v in var_list for x in distribution.unwrap(v)]
- eager_execution = context.executing_eagerly()
- if eager_execution:
- # Give a clear error in this case instead of "name not supported
- # for Eager Tensors" when we compute non_slot_devices.
- for v in unwrapped_var_list:
- if isinstance(v, ops.Tensor):
- raise NotImplementedError("Trying to update a Tensor ", v)
-
- with ops.name_scope(name, self._name) as name:
- per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list)
- # Include the current value of any dynamic hyper parameters in `state`.
- non_slot_devices = distribution.non_slot_devices(var_list)
- state = per_graph_state._copy_with_dynamic_hyper( # pylint: disable=protected-access
- self._hyper, distribution, non_slot_devices)
-
- # Create any slot and non-slot variables we need in `state`.
- with ops.init_scope():
- self._create_vars(var_list, state)
-
- with ops.name_scope(name): # Re-enter name_scope created above
- # Give the child class a chance to do something before we start
- # applying gradients.
- self._prepare(state)
-
- def update(v, g):
- """Update variable `v` using gradient `g`."""
- assert v is not None
-
- # Convert the grad to Tensor or IndexedSlices if necessary, and
- # look up a processor for each variable's type.
- try:
- g = ops.convert_to_tensor_or_indexed_slices(g)
- except TypeError:
- raise TypeError(
- "Gradient must be convertible to a Tensor"
- " or IndexedSlices, or None: %s" % g)
- if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
- raise TypeError(
- "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
- processor = _get_processor(v)
-
- # We colocate all ops created in _apply_dense or _apply_sparse
- # on the same device as the variable.
- # TODO(apassos): figure out how to get the variable name here.
- scope_name = "" if eager_execution else v.op.name
- # device_policy is set because non-mirrored tensors will be read in
- # `update_op`.
- # TODO(josh11b): Make different state objects for each device to
- # avoid needing to set the device_policy.
- with ops.name_scope("update_" + scope_name), \
- context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- return processor.update_op(self, g, state)
-
- # Use the processors to update the variables.
- update_ops = []
- for grad, var in grads_and_vars:
- update_ops.extend(distribution.update(var, update, grad, grouped=False))
-
- # Give the child class a chance to do something after applying
- # gradients
- def finish():
- # TODO(josh11b): Make different state objects for each device to
- # avoid needing to set the device_policy.
- with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- return self._finish(state)
-
- update_ops = control_flow_ops.group(update_ops)
- with ops.control_dependencies([update_ops]):
- finish_updates = distribution.update_non_slot(
- non_slot_devices, finish, grouped=False)
- # We said grouped=False, which means finish_updates is always a list.
- # It will be [None] when finish() returns None.
- if finish_updates == [None]:
- finish_updates = [update_ops]
-
- # Update `global_step` (if any).
- if global_step is None:
- apply_updates = distribution.group(finish_updates, name=name)
- else:
- with ops.control_dependencies(finish_updates):
-
- def update_global_step(global_step, name):
- return global_step.assign_add(1, read_value=False, name=name)
-
- apply_updates = distribution.update(
- global_step, update_global_step, name)
-
- # Add the training op to the TRAIN_OP graph collection in graph mode.
- if not eager_execution:
- if isinstance(apply_updates, ops.Tensor):
- apply_updates = apply_updates.op
- train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
- if apply_updates not in train_op:
- train_op.append(apply_updates)
-
- return apply_updates
-
- def get_slot(self, var, name):
- """Return a slot named `name` created for `var` by the Optimizer.
-
- Some `Optimizer` subclasses use additional variables. For example
- `Momentum` and `Adagrad` use variables to accumulate updates. This method
- gives access to these `Variable` objects if for some reason you need them.
-
- Use `get_slot_names()` to get the list of slot names created by the
- `Optimizer`.
-
- Args:
- var: A variable passed to `minimize()` or `apply_gradients()`.
- name: A string.
-
- Returns:
- The `Variable` for the slot if it was created, `None` otherwise.
- """
- state = self._get_state_for_var(var)
- return state.get_slot(var, name) if state is not None else None
-
- def get_slot_names(self):
- """Return a list of the names of slots created by the `Optimizer`.
-
- See `get_slot()`.
-
- Returns:
- A list of strings.
- """
- state = self._get_per_graph_state()
- return state.get_slot_names() if state is not None else []
-
- def variables(self):
- """A list of variables which encode the current state of `Optimizer`.
-
- Includes slot variables and additional global variables created by the
- optimizer in the current default graph.
-
- Returns:
- A list of variables.
- """
- state = self._get_per_graph_state()
- return state._variables() if state is not None else [] # pylint: disable=protected-access
-
- # --------------
- # Methods to be implemented by subclasses if they want to use the
- # inherited implementation of apply_gradients() or compute_gradients().
- # --------------
- def _create_vars(self, var_list, state):
- """Create all slots needed by the variables and any non-slot variables.
-
- Args:
- var_list: A list of `Variable` objects.
- state: An object with these methods:
- `create_slot(var, val, slot_name, optional_op_name)`,
- `create_slot_with_initializer(`
- `var, initializer, shape, dtype, slot_name, optional_op_name)`,
- `zeros_slot(var, slot_name, optional_op_name)`,
- `create_non_slot_variable(initial_value, name, colocate_with)`,
- `get_hyper(name)`
- """
- # No slots needed by default
- pass
-
- def _prepare(self, state):
- """Code to execute before applying gradients.
-
- Note that most uses of _prepare() in Optimizer have been subsumed
- by explicit support for hyper parameters in OptimizerV2
-
- Args:
- state: An object with a `get_hyper(name)` method.
-
- Returns:
- Return value will be ignored.
- """
- pass
-
- def _apply_dense(self, grad, var, state):
- """Add ops to apply dense gradients to `var`.
-
- Args:
- grad: A `Tensor`.
- var: A `Variable` object.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation`.
- """
- raise NotImplementedError()
-
- def _resource_apply_dense(self, grad, handle, state):
- """Add ops to apply dense gradients to the variable `handle`.
-
- Args:
- grad: a `Tensor` representing the gradient.
- handle: a `Tensor` of dtype `resource` which points to the variable
- to be updated.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation` which updates the value of the variable.
- """
- raise NotImplementedError()
-
- def _resource_apply_sparse_duplicate_indices(
- self, grad, handle, indices, state):
- """Add ops to apply sparse gradients to `handle`, with repeated indices.
-
- Optimizers which override this method must deal with repeated indices. See
- the docstring of `_apply_sparse_duplicate_indices` for details. By default
- the correct behavior, to sum non-unique indices and their associated
- gradients, is enforced by first pre-processing `grad` and `indices` and
- passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
- with duplicate indices may instead override this method to avoid the
- overhead of summing.
-
- Args:
- grad: a `Tensor` representing the gradient for the affected indices.
- handle: a `Tensor` of dtype `resource` which points to the variable
- to be updated.
- indices: a `Tensor` of integral type representing the indices for
- which the gradient is nonzero. Indices may be repeated.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation` which updates the value of the variable.
- """
- # pylint: disable=protected-access
- summed_grad, unique_indices = optimizer_v1._deduplicate_indexed_slices(
- values=grad, indices=indices)
- # pylint: enable=protected-access
- return self._resource_apply_sparse(
- summed_grad, handle, unique_indices, state)
-
- def _resource_apply_sparse(self, grad, handle, indices, state):
- """Add ops to apply sparse gradients to the variable `handle`.
-
- Similar to `_apply_sparse`, the `indices` argument to this method has been
- de-duplicated. Optimizers which deal correctly with non-unique indices may
- instead override `_resource_apply_sparse_duplicate_indices` to avoid this
- overhead.
-
- Args:
- grad: a `Tensor` representing the gradient for the affected indices.
- handle: a `Tensor` of dtype `resource` which points to the variable
- to be updated.
- indices: a `Tensor` of integral type representing the indices for
- which the gradient is nonzero. Indices are unique.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation` which updates the value of the variable.
- """
- raise NotImplementedError()
-
- def _apply_sparse_duplicate_indices(self, grad, var, state):
- """Add ops to apply sparse gradients to `var`, with repeated sparse indices.
-
- Optimizers which override this method must deal with IndexedSlices objects
- such as the following:
-
- IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])
-
- The correct interpretation is:
-
- IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])
-
- Many optimizers deal incorrectly with repeated indices when updating based
- on sparse gradients (e.g. summing squares rather than squaring the sum, or
- applying momentum terms multiple times). Adding first is always the correct
- behavior, so this is enforced here by reconstructing the IndexedSlices to
- have only unique indices, then calling _apply_sparse.
-
- Optimizers which deal correctly with repeated indices may instead override
- this method to avoid the overhead of summing indices.
-
- Args:
- grad: `IndexedSlices`.
- var: A `Variable` object.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation`.
- """
- # pylint: disable=protected-access
- summed_values, unique_indices = optimizer_v1._deduplicate_indexed_slices(
- values=grad.values, indices=grad.indices)
- # pylint: enable=protected-access
- gradient_no_duplicate_indices = ops.IndexedSlices(
- indices=unique_indices,
- values=summed_values,
- dense_shape=grad.dense_shape)
- return self._apply_sparse(gradient_no_duplicate_indices, var, state)
-
- def _apply_sparse(self, grad, var, state):
- """Add ops to apply sparse gradients to `var`.
-
- The IndexedSlices object passed to `grad` in this function is by default
- pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
- indices (see its docstring for details). Optimizers which can tolerate or
- have correct special cases for duplicate sparse indices may override
- `_apply_sparse_duplicate_indices` instead of this function, avoiding that
- overhead.
-
- Args:
- grad: `IndexedSlices`, with no repeated indices.
- var: A `Variable` object.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation`.
- """
- raise NotImplementedError()
-
- def _finish(self, state):
- """Do what is needed to finish the update.
-
- This is called inside a scope colocated with any non-slot variables.
-
- Args:
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- The operation to apply updates, or None if no updates.
- """
- return None
-
- # --------------
- # Utility methods for subclasses.
- # --------------
- def _get_per_graph_state(self):
- # pylint: disable=protected-access
- return self._per_graph_state.get(ops.get_default_graph()._graph_key, None)
-
- def _get_state_for_var(self, var):
- # pylint: disable=protected-access
- return self._per_graph_state.get(var._graph_key, None)
-
- # --------------
- # Overridden methods from Checkpointable.
- # --------------
-
- def _track_checkpointable(self, *args, **kwargs):
- """Optimizers may not track dependencies. Raises an error."""
- raise NotImplementedError(
- "Optimizers may not have dependencies. File a feature request if this "
- "limitation bothers you.")
-
- @property
- def _checkpoint_dependencies(self):
- """From Checkpointable. Gather graph-specific non-slot variables to save."""
- current_graph_non_slot_variables = []
- state = self._get_per_graph_state()
- if state is not None:
- for name, variable_object in sorted(
- state._non_slot_dict.items(), # pylint: disable=protected-access
- # Avoid comparing variables
- key=lambda item: item[0]):
- current_graph_non_slot_variables.append(
- checkpointable.CheckpointableReference(
- name=name, ref=variable_object))
- # Note: ignores super(); Optimizers may not have any dependencies outside of
- # state objects.
- return current_graph_non_slot_variables
-
- def _lookup_dependency(self, name):
- """From Checkpointable. Find a non-slot variable in the current graph."""
- state = self._get_per_graph_state()
- if state is None:
- return None
- else:
- return state.get_non_slot(name)
-
- @property
- def _deferred_dependencies(self):
- """Lets Checkpointable know where non-slot variables are created.
-
- If necessary, creates a new state object for the current default graph.
- Checkpointable will then add entries to that state's deferred dependency
- dictionary. The state object will check that dictionary when creating
- non-slot variables, restoring their value if an entry is found.
-
- Returns:
- A dictionary which holds deferred dependencies for the current default
- graph.
- """
- state = self._get_or_create_state()
- return state._deferred_dependencies # pylint: disable=protected-access
-
- def _create_or_restore_slot_variable(
- self, slot_variable_position, slot_name, variable):
- """Checkpointable: Restore a slot variable's value, possibly creating it.
-
- Called when a variable which has an associated slot variable is created or
- restored.
-
- Args:
- slot_variable_position: A `checkpointable._CheckpointPosition` object
- indicating the slot variable `Checkpointable` object to be restored.
- slot_name: The name of this `Optimizer`'s slot to restore into.
- variable: The variable object this slot is being created for.
- """
- state = self._get_or_create_state(var_list=[variable])
- state._create_or_restore_slot_variable( # pylint: disable=protected-access
- slot_variable_position=slot_variable_position,
- slot_name=slot_name,
- variable=variable,
- optional_op_name=self._name)
-
- # --------------
- # Unsupported parent methods
- # --------------
- def _slot_dict(self, slot_name):
- raise NotImplementedError(
- "_slot_dict() method unsupported in OptimizerV2")
-
- def _get_or_make_slot(self, var, val, slot_name, op_name):
- raise NotImplementedError(
- "_get_or_make_slot() method unsupported in OptimizerV2")
-
- def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
- slot_name, op_name):
- raise NotImplementedError(
- "_get_or_make_slot_with_initializer() method unsupported in "
- "OptimizerV2")
-
- def _create_non_slot_variable(self, initial_value, name, colocate_with):
- raise NotImplementedError(
- "_create_non_slot_variable() method unsupported in OptimizerV2")
-
- def _get_non_slot_variable(self, name, graph=None):
- raise NotImplementedError(
- "_get_non_slot_variable() method unsupported in OptimizerV2")
-
- def _non_slot_variables(self):
- raise NotImplementedError(
- "_non_slot_variables() method unsupported in OptimizerV2")
+ super(OptimizerV2, self).__init__(name)
diff --git a/tensorflow/contrib/optimizer_v2/rmsprop.py b/tensorflow/contrib/optimizer_v2/rmsprop.py
index 3de53405ec..090e257ddc 100644
--- a/tensorflow/contrib/optimizer_v2/rmsprop.py
+++ b/tensorflow/contrib/optimizer_v2/rmsprop.py
@@ -41,19 +41,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.ops import array_ops
+from tensorflow.python.keras.optimizer_v2 import rmsprop
+from tensorflow.python.util import deprecation
-from tensorflow.python.training import training_ops
-
-class RMSPropOptimizer(optimizer_v2.OptimizerV2):
+class RMSPropOptimizer(rmsprop.RMSProp):
"""Optimizer that implements the RMSProp algorithm.
See the
[paper](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf).
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self,
learning_rate,
decay=0.9,
@@ -96,138 +98,10 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
name: Optional name prefix for the operations created when applying
gradients. Defaults to "RMSProp".
"""
- super(RMSPropOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("decay", decay)
- self._set_hyper("momentum", momentum)
- self._set_hyper("epsilon", epsilon)
-
- self._centered = centered
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- init_rms = state.get_hyper(
- "epsilon", v.dtype.base_dtype) * array_ops.ones_like(v)
- state.create_slot_with_initializer(v, init_rms, v.get_shape(),
- v.dtype.base_dtype, "rms")
- if self._centered:
- state.zeros_slot(v, "mg")
- state.zeros_slot(v, "momentum")
-
- def _apply_dense(self, grad, var, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = state.get_slot(var, "mg")
- return training_ops.apply_centered_rms_prop(
- var,
- mg,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- # epsilon is now the rms initial value and is not added to the
- # denominator anymore, hence calling the kernel op with epsilon=0.
- 0,
- grad,
- use_locking=self._use_locking).op
- else:
- return training_ops.apply_rms_prop(
- var,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- use_locking=self._use_locking).op
-
- def _resource_apply_dense(self, grad, var, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = state.get_slot(var, "mg")
- return training_ops.resource_apply_centered_rms_prop(
- var.handle,
- mg.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- use_locking=self._use_locking)
- else:
- return training_ops.resource_apply_rms_prop(
- var.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- use_locking=self._use_locking)
-
- def _apply_sparse(self, grad, var, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = state.get_slot(var, "mg")
- return training_ops.sparse_apply_centered_rms_prop(
- var,
- mg,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
- else:
- return training_ops.sparse_apply_rms_prop(
- var,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = self.get_slot(var, "mg")
- return training_ops.resource_sparse_apply_centered_rms_prop(
- var.handle,
- mg.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- indices,
- use_locking=self._use_locking)
- else:
- return training_ops.resource_sparse_apply_rms_prop(
- var.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- indices,
- use_locking=self._use_locking)
+ super(RMSPropOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ rho=decay,
+ momentum=momentum,
+ epsilon=epsilon,
+ centered=centered,
+ name=name)