author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-08-28 11:41:04 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-08-28 11:48:38 -0700
commit     95d240c5fbecec9fdbef55dc1154c4f454752633 (patch)
tree       117a2b52af4ed1df28c2094f959a9098d87bd34d /tensorflow
parent     2a6c8897f59e2cbf943f52b222a1968fa7e2f158 (diff)
Fix Adam in Eager mode and test adam/momentum
PiperOrigin-RevId: 166733547
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/python/framework/test_util.py     |   8
-rw-r--r--  tensorflow/python/ops/variable_scope.py      |   2
-rw-r--r--  tensorflow/python/training/adam.py           |   8
-rw-r--r--  tensorflow/python/training/adam_test.py      |  79
-rw-r--r--  tensorflow/python/training/momentum_test.py  | 128
-rw-r--r--  tensorflow/python/training/optimizer.py      |   2
-rw-r--r--  tensorflow/python/training/slot_creator.py   |   4
7 files changed, 138 insertions, 93 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index b494b988b3..aceebbc9cc 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -276,7 +276,8 @@ def enable_c_api(fn):
def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
- use_gpu=False, force_gpu=False):
+ use_gpu=False, force_gpu=False,
+ reset_test=False):
"""Runs the test in both graph and eager modes.
Args:
@@ -286,6 +287,7 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
session.
use_gpu: If True, attempt to run as many ops as possible on GPU.
force_gpu: If True, pin all ops to `/device:GPU:0`.
+ reset_test: If True, tearDown and SetUp the test case again.
Returns:
Returns a decorator that will run the decorated test function
@@ -302,6 +304,10 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
with self.test_session(graph, config, use_gpu, force_gpu):
f(self)
+ if reset_test:
+ self.tearDown()
+ self.setUp()
+
def run_eager_mode():
if force_gpu:
gpu_name = gpu_device_name()
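
The new reset_test flag lets the decorator call tearDown/setUp between the graph-mode and eager-mode runs, so tests that create named resource variables start the second run from a clean fixture. A minimal pure-Python sketch of that control flow (run_twice_with_reset is a hypothetical stand-in; the real decorator also switches execution modes and sessions, which is omitted here):

import functools
import unittest


def run_twice_with_reset(reset_test=False):
  """Hypothetical stand-in: run the decorated test twice, optionally resetting."""
  def decorator(f):
    @functools.wraps(f)
    def wrapper(self):
      f(self)              # first run (graph mode in the real decorator)
      if reset_test:
        self.tearDown()    # mirror the new reset_test=True behaviour
        self.setUp()
      f(self)              # second run (eager mode in the real decorator)
    return wrapper
  return decorator


class ExampleTest(unittest.TestCase):

  def setUp(self):
    self.calls = []

  @run_twice_with_reset(reset_test=True)
  def test_runs_with_fresh_fixture(self):
    self.calls.append("run")
    # Without the reset, the second run would see two entries and fail.
    self.assertEqual(["run"], self.calls)


if __name__ == "__main__":
  unittest.main()
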
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 908b6b2111..b7913890e4 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -978,7 +978,7 @@ class VariableScope(object):
def set_partitioner(self, partitioner):
"""Set partitioner for this scope."""
- if context.in_eager_mode():
+ if partitioner and context.in_eager_mode():
raise NotImplementedError("Partitioned variables are not yet supported "
"in Eager mode.")
self._partitioner = partitioner
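
The previous check raised in eager mode even when the partitioner being set was None, so code paths that merely clear or copy a scope with no partitioner would fail; only a non-None partitioner is rejected now. A tiny pure-Python sketch of the guard (Scope and in_eager_mode are hypothetical stand-ins, not the TensorFlow API):

def in_eager_mode():
  return True  # pretend we are executing eagerly


class Scope(object):
  """Hypothetical stand-in for VariableScope."""

  def __init__(self):
    self._partitioner = None

  def set_partitioner(self, partitioner):
    if partitioner and in_eager_mode():  # only reject a real partitioner
      raise NotImplementedError(
          "Partitioned variables are not yet supported in Eager mode.")
    self._partitioner = partitioner


scope = Scope()
scope.set_partitioner(None)  # clearing the partitioner is now allowed
try:
  scope.set_partitioner(lambda shape, dtype: [1, 1])
except NotImplementedError as e:
  print(e)                   # a real partitioner is still rejected
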
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index 796402425a..cdc532a38e 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -118,8 +119,11 @@ class AdamOptimizer(optimizer.Optimizer):
# silently ignored).
first_var = min(var_list, key=lambda x: x.name)
- if (self._beta1_power is None or
- self._beta1_power.graph is not first_var.graph):
+ create_new = self._beta1_power is None
+ if not create_new and context.in_graph_mode():
+ create_new = (self._beta1_power.graph is not first_var.graph)
+
+ if create_new:
with ops.colocate_with(first_var):
self._beta1_power = variable_scope.variable(self._beta1,
name="beta1_power",
diff --git a/tensorflow/python/training/adam_test.py b/tensorflow/python/training/adam_test.py
index 62b171e234..defcf33714 100644
--- a/tensorflow/python/training/adam_test.py
+++ b/tensorflow/python/training/adam_test.py
@@ -21,9 +21,11 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@@ -149,49 +151,60 @@ class AdamOptimizerTest(test.TestCase):
repeated_index_update_var.eval())
def doTestBasic(self, use_resource=False):
- for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
- # Initialize variables for numpy implementation.
- m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
- var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
- grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
- var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
- grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
- if use_resource:
- var0 = resource_variable_ops.ResourceVariable(var0_np)
- var1 = resource_variable_ops.ResourceVariable(var1_np)
- else:
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
- grads0 = constant_op.constant(grads0_np)
- grads1 = constant_op.constant(grads1_np)
- opt = adam.AdamOptimizer()
- update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
+ opt = adam.AdamOptimizer()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ if context.in_graph_mode():
+ self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
- beta1_power, beta2_power = opt._get_beta_accumulators()
+ beta1_power, beta2_power = opt._get_beta_accumulators()
- # Run 3 steps of Adam
- for t in range(1, 4):
- self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
- self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
- update.run()
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ if context.in_graph_mode():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
- var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
- var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta1_power))
+ self.assertAllCloseAccordingToType(0.999**(t + 1),
+ self.evaluate(beta2_power))
- # Validate updated params
- self.assertAllCloseAccordingToType(var0_np, var0.eval())
- self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
def testBasic(self):
- self.doTestBasic(use_resource=False)
+ with self.test_session():
+ self.doTestBasic(use_resource=False)
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
def testResourceBasic(self):
self.doTestBasic(use_resource=True)
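
The eager branch above checks against the same NumPy reference as the graph branch (adam_update_numpy, defined earlier in the test file and not shown in this hunk). For context, a minimal NumPy sketch of the bias-corrected Adam step that reference computes, assuming TensorFlow's default hyperparameters (lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):

import numpy as np


def adam_step(param, grad, t, m, v,
              lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
  """One bias-corrected Adam update; t is the 1-based step count."""
  lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t)
  m = beta1 * m + (1 - beta1) * grad           # first-moment estimate
  v = beta2 * v + (1 - beta2) * grad * grad    # second-moment estimate
  param = param - lr_t * m / (np.sqrt(v) + epsilon)
  return param, m, v


var0, m0, v0 = np.array([1.0, 2.0]), 0.0, 0.0
for t in range(1, 4):                          # the tests run three steps
  var0, m0, v0 = adam_step(var0, np.array([0.1, 0.1]), t, m0, v0)
print(var0)                                    # the value the test expects var0 to hold after step 3
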
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index 9d6221b560..ba9f763831 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -21,9 +21,11 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
@@ -43,66 +45,82 @@ class MomentumOptimizerTest(test.TestCase):
return var, accum
def doTestBasic(self, use_resource=False):
- for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
- if use_resource:
- var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
- var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
- else:
- var0 = variables.Variable([1.0, 2.0], dtype=dtype)
- var1 = variables.Variable([3.0, 4.0], dtype=dtype)
- grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
- grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
- mom_opt = momentum_lib.MomentumOptimizer(
- learning_rate=2.0, momentum=0.9)
- mom_update = mom_opt.apply_gradients(
- zip([grads0, grads1], [var0, var1]))
- variables.global_variables_initializer().run()
- # Check we have slots
- self.assertEqual(["momentum"], mom_opt.get_slot_names())
- slot0 = mom_opt.get_slot(var0, "momentum")
- self.assertEquals(slot0.get_shape(), var0.get_shape())
- self.assertFalse(slot0 in variables.trainable_variables())
- slot1 = mom_opt.get_slot(var1, "momentum")
- self.assertEquals(slot1.get_shape(), var1.get_shape())
- self.assertFalse(slot1 in variables.trainable_variables())
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtype, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ mom_opt = momentum_lib.MomentumOptimizer(
+ learning_rate=2.0, momentum=0.9)
+ mom_update = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ if context.in_graph_mode():
+ self.evaluate(variables.global_variables_initializer())
# Fetch params to validate initial values
- self.assertAllClose([1.0, 2.0], var0.eval())
- self.assertAllClose([3.0, 4.0], var1.eval())
- # Step 1: the momentum accumulators where 0. So we should see a normal
- # update: v -= grad * learning_rate
- mom_update.run()
- # Check that the momentum accumulators have been updated.
- self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
- self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
- # Check that the parameters have been updated.
- self.assertAllCloseAccordingToType(
- np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
- self.assertAllCloseAccordingToType(
- np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
- # Step 2: the momentum accumulators contain the previous update.
- mom_update.run()
- # Check that the momentum accumulators have been updated.
- self.assertAllCloseAccordingToType(
- np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
- self.assertAllCloseAccordingToType(
- np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
- # Check that the parameters have been updated.
- self.assertAllCloseAccordingToType(
- np.array([
- 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
- 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
- ]), var0.eval())
- self.assertAllCloseAccordingToType(
- np.array([
- 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
- (0.9 * 0.01 + 0.01) * 2.0)
- ]), var1.eval())
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ self.assertFalse(slot0 in variables.trainable_variables())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ self.assertFalse(slot1 in variables.trainable_variables())
+
+ # Step 1: the momentum accumulators where 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ if context.in_graph_mode():
+ self.evaluate(mom_update)
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
+ self.evaluate(slot0))
+ self.assertAllCloseAccordingToType(np.array([0.01, 0.01]),
+ self.evaluate(slot1))
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
+ self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
+ self.evaluate(var1))
+ # Step 2: the momentum accumulators contain the previous update.
+ if context.in_graph_mode():
+ self.evaluate(mom_update)
+ else:
+ mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+ self.evaluate(slot0))
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ self.evaluate(slot1))
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+ ]), self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), self.evaluate(var1))
def testBasic(self):
- self.doTestBasic(use_resource=False)
+ with self.test_session():
+ self.doTestBasic(use_resource=False)
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
def testResourceBasic(self):
self.doTestBasic(use_resource=True)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 250e22f91e..86ba8e2c8e 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -69,6 +69,8 @@ def _deduplicate_indexed_slices(values, indices):
def _var_key(var):
+ if context.in_eager_mode():
+ return var._shared_name # pylint: disable=protected-access
return (var.op.graph, var.op.name)
diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py
index 4371e92bd3..ea28b5ddfc 100644
--- a/tensorflow/python/training/slot_creator.py
+++ b/tensorflow/python/training/slot_creator.py
@@ -39,6 +39,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@@ -139,7 +140,8 @@ def create_slot_with_initializer(primary, initializer, shape, dtype, name,
# and the same name has been previously used, the scope name will add '_N'
# as suffix for unique identifications.
validate_shape = shape.is_fully_defined()
- with variable_scope.variable_scope(None, primary.op.name + "/" + name):
+ prefix = primary.op.name if context.in_graph_mode() else primary._shared_name # pylint: disable=protected-access
+ with variable_scope.variable_scope(None, prefix + "/" + name):
if colocate_with_primary:
with ops.colocate_with(primary):
return _create_slot_var(primary, initializer, "", validate_shape, shape,
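
The last two hunks address the same gap: in eager mode a variable has no graph or op name, so both the optimizer's slot dictionary key (_var_key) and the slot variable-scope prefix fall back to the variable's shared name. That is presumably also why the updated tests give their ResourceVariables explicit names like "var0_%d" % i, keeping the shared name unique across loop iterations. A pure-Python sketch of the selection (FakeVariable, var_key and slot_scope_prefix are hypothetical stand-ins, not the TensorFlow API):

class FakeVariable(object):
  """Hypothetical stand-in with the attributes the two hunks rely on."""

  def __init__(self, shared_name, graph=None, op_name=None):
    self._shared_name = shared_name
    self.graph = graph
    self.op_name = op_name


def var_key(var, in_eager_mode):
  if in_eager_mode:
    return var._shared_name        # no graph in eager mode; the name is the key
  return (var.graph, var.op_name)


def slot_scope_prefix(primary, in_graph_mode):
  return primary.op_name if in_graph_mode else primary._shared_name


graph = object()
v = FakeVariable("var0_0", graph=graph, op_name="var0_0")
print(var_key(v, in_eager_mode=True))                           # 'var0_0'
print(var_key(v, in_eager_mode=False))                          # (<graph>, 'var0_0')
print(slot_scope_prefix(v, in_graph_mode=False) + "/momentum")  # 'var0_0/momentum'
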