author     Asim Shankar <ashankar@google.com>              2018-04-10 10:43:14 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org> 2018-04-10 10:45:31 -0700
commit  36a07c59954b8ace54879b8732b6a7ae2dce6450 (patch)
tree    0489f1adb914dfc732271a9758461abf49760f30
parent  bd718c410478d066ed1c41d5ffe31970075b808a (diff)
Simplify test_util.run_in_graph_and_eager_modes
- Get rid of unnecessary options
- Update various resource variable tests so that they correctly exercise the
  cases where the variables are placed on GPU (these "with tf.device('/cpu:0')"
  blocks that were added for eager execution are no longer necessary)

PiperOrigin-RevId: 192309109
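As a rough sketch of the resulting usage (the test class and method names below are hypothetical, and it assumes a TensorFlow build that includes this change, where `use_gpu=True` is the default and the `graph`/`force_gpu` options are gone):

```python
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class ExampleTest(test.TestCase):

  # Previously a test might have been decorated as
  #   @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  # or passed graph=/force_gpu= arguments. Those options are removed;
  # running on the GPU (when one is available) is now the default.
  # The decorated body runs once in graph mode inside a test session
  # and once with eager execution enabled.
  @test_util.run_in_graph_and_eager_modes()
  def testAdd(self):
    x = constant_op.constant([1.0, 2.0])
    y = constant_op.constant([3.0, 4.0])
    self.assertAllEqual([4.0, 6.0], self.evaluate(math_ops.add(x, y)))

  # Tests that must stay off the GPU now opt out explicitly; in eager
  # mode this pins the body to /cpu:0.
  @test_util.run_in_graph_and_eager_modes(use_gpu=False)
  def testAddOnCpu(self):
    x = constant_op.constant([1.0, 2.0])
    self.assertAllEqual([2.0, 4.0], self.evaluate(math_ops.add(x, x)))


if __name__ == "__main__":
  test.main()
```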
-rw-r--r--  tensorflow/contrib/eager/python/checkpointable_utils_test.py  |  10
-rw-r--r--  tensorflow/contrib/optimizer_v2/momentum_test.py              |  24
-rw-r--r--  tensorflow/python/framework/test_util.py                      |  90
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/embeddings_test.py |   2
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/pooling_test.py    |  18
-rw-r--r--  tensorflow/python/kernel_tests/resource_variable_ops_test.py  | 340
-rw-r--r--  tensorflow/python/training/momentum_test.py                   |  24
7 files changed, 256 insertions, 252 deletions
diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py
index e6498ddb06..3ec5c3de39 100644
--- a/tensorflow/contrib/eager/python/checkpointable_utils_test.py
+++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py
@@ -719,8 +719,9 @@ class CheckpointingTests(test.TestCase):
checkpoint_directory = self.get_temp_dir()
root = checkpointable.Checkpointable()
- root.var = checkpointable_utils.add_variable(
- root, name="var", initializer=0.)
+ with ops.device("/cpu:0"):
+ root.var = checkpointable_utils.add_variable(
+ root, name="var", initializer=0.)
optimizer = adam.AdamOptimizer(0.1)
if context.executing_eagerly():
optimizer.minimize(root.var.read_value)
@@ -750,8 +751,9 @@ class CheckpointingTests(test.TestCase):
new_root).restore(no_slots_path)
with self.assertRaises(AssertionError):
no_slot_status.assert_consumed()
- new_root.var = checkpointable_utils.add_variable(
- new_root, name="var", shape=[])
+ with ops.device("/cpu:0"):
+ new_root.var = checkpointable_utils.add_variable(
+ new_root, name="var", shape=[])
no_slot_status.assert_consumed()
no_slot_status.run_restore_ops()
self.assertEqual(12., self.evaluate(new_root.var))
diff --git a/tensorflow/contrib/optimizer_v2/momentum_test.py b/tensorflow/contrib/optimizer_v2/momentum_test.py
index f37eb48181..26724f66c2 100644
--- a/tensorflow/contrib/optimizer_v2/momentum_test.py
+++ b/tensorflow/contrib/optimizer_v2/momentum_test.py
@@ -237,7 +237,17 @@ class MomentumOptimizerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes(reset_test=True)
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ # This test invokes the ResourceSparseApplyMomentum operation, which
+ # did not have a registered GPU kernel as of April 2018. With graph
+ # execution, the placement algorithm notices this and automatically
+ # places the variable in CPU (host) memory. With eager execution,
+ # the variable would be placed in GPU memory if available, which
+ # would then conflict with the future invocation of the
+ # ResourceSparseApplyMomentum operation.
+ # To work around this discrepancy, for now we force the variable
+ # to be placed on CPU.
+ with ops.device("/cpu:0"):
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
# pylint: disable=cell-var-from-loop
def loss():
@@ -256,7 +266,17 @@ class MomentumOptimizerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes(reset_test=True)
def testMinimizeWith2DIndiciesForEmbeddingLookup(self):
- var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))
+ # This test invokes the ResourceSparseApplyMomentum operation, which
+ # did not have a registered GPU kernel as of April 2018. With graph
+ # execution, the placement algorithm notices this and automatically
+ # places the variable in CPU (host) memory. With eager execution,
+ # the variable would be placed in GPU memory if available, which
+ # would then conflict with the future invocation of the
+ # ResourceSparseApplyMomentum operation.
+ # To work around this discrepancy, for now we force the variable
+ # to be placed on CPU.
+ with ops.device("/cpu:0"):
+ var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))
def loss():
return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]]))
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index bf00fa6439..eea27d76c6 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -615,45 +615,68 @@ def assert_no_garbage_created(f):
def run_in_graph_and_eager_modes(__unused__=None,
- graph=None,
config=None,
- use_gpu=False,
- force_gpu=False,
+ use_gpu=True,
reset_test=True,
assert_no_eager_garbage=False):
- """Runs the test in both graph and eager modes.
+ """Execute the decorated test with and without enabling eager execution.
+
+ This function returns a decorator intended to be applied to test methods in
+ a @{tf.test.TestCase} class. Doing so will cause the contents of the test
+ method to be executed twice - once normally, and once with eager execution
+ enabled. This allows unittests to confirm the equivalence between eager
+ and graph execution (see @{tf.enable_eager_execution}).
+
+ For example, consider the following unittest:
+
+ ```python
+ class MyTests(tf.test.TestCase):
+
+ @run_in_graph_and_eager_modes()
+ def test_foo(self):
+ x = tf.constant([1, 2])
+ y = tf.constant([3, 4])
+ z = tf.add(x, y)
+ self.assertAllEqual([4, 6], self.evaluate(z))
+
+ if __name__ == "__main__":
+ tf.test.main()
+ ```
+
+ This test validates that `tf.add()` has the same behavior when computed with
+ eager execution enabled as it does when constructing a TensorFlow graph and
+ executing the `z` tensor in a session.
+
Args:
__unused__: Prevents silently skipping tests.
- graph: Optional graph to use during the returned session.
config: An optional config_pb2.ConfigProto to use to configure the
- session.
- use_gpu: If True, attempt to run as many ops as possible on GPU.
- force_gpu: If True, pin all ops to `/device:GPU:0`.
- reset_test: If True, tearDown and SetUp the test case again.
+ session when executing graphs.
+ use_gpu: If True, attempt to run as many operations as possible on GPU.
+ reset_test: If True, tearDown and SetUp the test case between the two
+ executions of the test (once with and once without eager execution).
assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
collector and asserts that no extra garbage has been created when running
- the test in eager mode. This will fail if there are reference cycles
- (e.g. a = []; a.append(a)). Off by default because some tests may create
- garbage for legitimate reasons (e.g. they define a class which inherits
- from `object`), and because DEBUG_SAVEALL is sticky in some Python
- interpreters (meaning that tests which rely on objects being collected
- elsewhere in the unit test file will not work). Additionally, checks that
- nothing still has a reference to Tensors that the test allocated.
+ the test with eager execution enabled. This will fail if there are
+ reference cycles (e.g. a = []; a.append(a)). Off by default because some
+ tests may create garbage for legitimate reasons (e.g. they define a class
+ which inherits from `object`), and because DEBUG_SAVEALL is sticky in some
+ Python interpreters (meaning that tests which rely on objects being
+ collected elsewhere in the unit test file will not work). Additionally,
+ checks that nothing still has a reference to Tensors that the test
+ allocated.
Returns:
- Returns a decorator that will run the decorated test function
- using both a graph and using eager execution.
+ Returns a decorator that will run the decorated test method twice:
+ once by constructing and executing a graph in a session and once with
+ eager execution enabled.
"""
assert not __unused__, "Add () after run_in_graph_and_eager_modes."
def decorator(f):
- """Test method decorator."""
-
def decorated(self, **kwargs):
- """Decorated the test method."""
with context.graph_mode():
- with self.test_session(graph, config, use_gpu, force_gpu):
+ with self.test_session(use_gpu=use_gpu):
f(self, **kwargs)
if reset_test:
@@ -663,27 +686,20 @@ def run_in_graph_and_eager_modes(__unused__=None,
self._tempdir = None
self.setUp()
- def run_eager_mode(self, **kwargs):
- if force_gpu:
- gpu_name = gpu_device_name()
- if not gpu_name:
- gpu_name = "/device:GPU:0"
- with context.device(gpu_name):
- f(self)
- elif use_gpu:
- # TODO(xpan): Support softplacement and gpu by default when available.
- f(self, **kwargs)
- else:
- with context.device("/device:CPU:0"):
+ def run_eagerly(self, **kwargs):
+ if not use_gpu:
+ with ops.device("/cpu:0"):
f(self, **kwargs)
+ else:
+ f(self, **kwargs)
if assert_no_eager_garbage:
- run_eager_mode = assert_no_new_tensors(
- assert_no_garbage_created(run_eager_mode))
+ run_eagerly = assert_no_new_tensors(
+ assert_no_garbage_created(run_eagerly))
with context.eager_mode():
with ops.Graph().as_default():
- run_eager_mode(self, **kwargs)
+ run_eagerly(self, **kwargs)
return decorated
diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py b/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py
index 26fd1f1c11..9f6793eac8 100644
--- a/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py
@@ -26,7 +26,7 @@ from tensorflow.python.platform import test
class EmbeddingTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes()
+ @tf_test_util.run_in_graph_and_eager_modes(use_gpu=False)
def test_embedding(self):
testing_utils.layer_test(
keras.layers.Embedding,
diff --git a/tensorflow/python/keras/_impl/keras/layers/pooling_test.py b/tensorflow/python/keras/_impl/keras/layers/pooling_test.py
index bb003c1ddd..2c08b647ea 100644
--- a/tensorflow/python/keras/_impl/keras/layers/pooling_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/pooling_test.py
@@ -27,14 +27,14 @@ from tensorflow.python.platform import test
class GlobalPoolingTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ @tf_test_util.run_in_graph_and_eager_modes()
def test_globalpooling_1d(self):
testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
input_shape=(3, 4, 5))
testing_utils.layer_test(
keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5))
- @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ @tf_test_util.run_in_graph_and_eager_modes()
def test_globalpooling_2d(self):
testing_utils.layer_test(
keras.layers.pooling.GlobalMaxPooling2D,
@@ -53,7 +53,7 @@ class GlobalPoolingTest(test.TestCase):
kwargs={'data_format': 'channels_last'},
input_shape=(3, 5, 6, 4))
- @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ @tf_test_util.run_in_graph_and_eager_modes()
def test_globalpooling_3d(self):
testing_utils.layer_test(
keras.layers.pooling.GlobalMaxPooling3D,
@@ -75,7 +75,7 @@ class GlobalPoolingTest(test.TestCase):
class Pooling2DTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ @tf_test_util.run_in_graph_and_eager_modes()
def test_maxpooling_2d(self):
pool_size = (3, 3)
for strides in [(1, 1), (2, 2)]:
@@ -88,7 +88,7 @@ class Pooling2DTest(test.TestCase):
},
input_shape=(3, 5, 6, 4))
- @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ @tf_test_util.run_in_graph_and_eager_modes()
def test_averagepooling_2d(self):
testing_utils.layer_test(
keras.layers.AveragePooling2D,
@@ -122,7 +122,7 @@ class Pooling2DTest(test.TestCase):
class Pooling3DTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ @tf_test_util.run_in_graph_and_eager_modes()
def test_maxpooling_3d(self):
pool_size = (3, 3, 3)
testing_utils.layer_test(
@@ -141,7 +141,7 @@ class Pooling3DTest(test.TestCase):
},
input_shape=(3, 4, 11, 12, 10))
- @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ @tf_test_util.run_in_graph_and_eager_modes()
def test_averagepooling_3d(self):
pool_size = (3, 3, 3)
testing_utils.layer_test(
@@ -163,7 +163,7 @@ class Pooling3DTest(test.TestCase):
class Pooling1DTest(test.TestCase):
- @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ @tf_test_util.run_in_graph_and_eager_modes()
def test_maxpooling_1d(self):
for padding in ['valid', 'same']:
for stride in [1, 2]:
@@ -173,7 +173,7 @@ class Pooling1DTest(test.TestCase):
'padding': padding},
input_shape=(3, 5, 4))
- @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ @tf_test_util.run_in_graph_and_eager_modes()
def test_averagepooling_1d(self):
for padding in ['valid', 'same']:
for stride in [1, 2]:
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index edc63264a3..6d33086936 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -174,215 +174,161 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32))
self.assertEqual(read, 2)
- @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ @test_util.run_in_graph_and_eager_modes()
def testScatterAdd(self):
- with ops.device("cpu:0"):
- handle = resource_variable_ops.var_handle_op(
- dtype=dtypes.int32, shape=[1, 1])
- self.evaluate(resource_variable_ops.assign_variable_op(
- handle, constant_op.constant([[1]], dtype=dtypes.int32)))
- self.evaluate(resource_variable_ops.resource_scatter_add(
- handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
- read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(self.evaluate(read), [[3]])
-
- @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[1]], dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_add(
+ handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[3]])
+
+ @test_util.run_in_graph_and_eager_modes()
def testScatterSub(self):
- with ops.device("cpu:0"):
- handle = resource_variable_ops.var_handle_op(
- dtype=dtypes.int32, shape=[1, 1])
- self.evaluate(
- resource_variable_ops.assign_variable_op(handle,
- constant_op.constant(
- [[1]],
- dtype=dtypes.int32)))
- self.evaluate(
- resource_variable_ops.resource_scatter_sub(handle, [0],
- constant_op.constant(
- [[2]],
- dtype=dtypes.int32)))
- read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(self.evaluate(read), [[-1]])
-
- @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[1]], dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_sub(
+ handle, [0], constant_op.constant([[2]], dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[-1]])
+
+ @test_util.run_in_graph_and_eager_modes()
def testScatterMul(self):
- with ops.device("cpu:0"):
- handle = resource_variable_ops.var_handle_op(
- dtype=dtypes.int32, shape=[1, 1])
- self.evaluate(
- resource_variable_ops.assign_variable_op(handle,
- constant_op.constant(
- [[1]],
- dtype=dtypes.int32)))
- self.evaluate(
- resource_variable_ops.resource_scatter_mul(handle, [0],
- constant_op.constant(
- [[5]],
- dtype=dtypes.int32)))
- read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(self.evaluate(read), [[5]])
-
- @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[1]], dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_mul(
+ handle, [0], constant_op.constant([[5]], dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[5]])
+
+ @test_util.run_in_graph_and_eager_modes()
def testScatterDiv(self):
- with ops.device("cpu:0"):
- handle = resource_variable_ops.var_handle_op(
- dtype=dtypes.int32, shape=[1, 1])
- self.evaluate(
- resource_variable_ops.assign_variable_op(handle,
- constant_op.constant(
- [[6]],
- dtype=dtypes.int32)))
- self.evaluate(
- resource_variable_ops.resource_scatter_div(handle, [0],
- constant_op.constant(
- [[3]],
- dtype=dtypes.int32)))
- read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(self.evaluate(read), [[2]])
-
- @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_div(
+ handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[2]])
+
+ @test_util.run_in_graph_and_eager_modes()
def testScatterMin(self):
- with ops.device("cpu:0"):
- handle = resource_variable_ops.var_handle_op(
- dtype=dtypes.int32, shape=[1, 1])
- self.evaluate(
- resource_variable_ops.assign_variable_op(handle,
- constant_op.constant(
- [[6]],
- dtype=dtypes.int32)))
- self.evaluate(
- resource_variable_ops.resource_scatter_min(handle, [0],
- constant_op.constant(
- [[3]],
- dtype=dtypes.int32)))
- read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(self.evaluate(read), [[3]])
-
- @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_min(
+ handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[3]])
+
+ @test_util.run_in_graph_and_eager_modes()
def testScatterMax(self):
- with ops.device("cpu:0"):
- handle = resource_variable_ops.var_handle_op(
- dtype=dtypes.int32, shape=[1, 1])
- self.evaluate(
- resource_variable_ops.assign_variable_op(handle,
- constant_op.constant(
- [[6]],
- dtype=dtypes.int32)))
- self.evaluate(
- resource_variable_ops.resource_scatter_max(handle, [0],
- constant_op.constant(
- [[3]],
- dtype=dtypes.int32)))
- read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(self.evaluate(read), [[6]])
-
- @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_max(
+ handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[6]])
+
+ @test_util.run_in_graph_and_eager_modes()
def testScatterAddScalar(self):
- with ops.device("cpu:0"):
- handle = resource_variable_ops.var_handle_op(
- dtype=dtypes.int32, shape=[1, 1])
- self.evaluate(
- resource_variable_ops.assign_variable_op(handle,
- constant_op.constant(
- [[1]],
- dtype=dtypes.int32)))
- self.evaluate(
- resource_variable_ops.resource_scatter_add(handle, [0],
- constant_op.constant(
- 2,
- dtype=dtypes.int32)))
- read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(self.evaluate(read), [[3]])
-
- @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[1]], dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_add(
+ handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[3]])
+
+ @test_util.run_in_graph_and_eager_modes()
def testScatterSubScalar(self):
- with ops.device("cpu:0"):
- handle = resource_variable_ops.var_handle_op(
- dtype=dtypes.int32, shape=[1, 1])
- self.evaluate(
- resource_variable_ops.assign_variable_op(handle,
- constant_op.constant(
- [[1]],
- dtype=dtypes.int32)))
- self.evaluate(
- resource_variable_ops.resource_scatter_sub(handle, [0],
- constant_op.constant(
- 2,
- dtype=dtypes.int32)))
- read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(self.evaluate(read), [[-1]])
-
- @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[1]], dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_sub(
+ handle, [0], constant_op.constant(2, dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[-1]])
+
+ @test_util.run_in_graph_and_eager_modes()
def testScatterMulScalar(self):
- with ops.device("cpu:0"):
- handle = resource_variable_ops.var_handle_op(
- dtype=dtypes.int32, shape=[1, 1])
- self.evaluate(
- resource_variable_ops.assign_variable_op(handle,
- constant_op.constant(
- [[1]],
- dtype=dtypes.int32)))
- self.evaluate(
- resource_variable_ops.resource_scatter_mul(handle, [0],
- constant_op.constant(
- 5,
- dtype=dtypes.int32)))
- read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(self.evaluate(read), [[5]])
-
- @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[1]], dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_mul(
+ handle, [0], constant_op.constant(5, dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[5]])
+
+ @test_util.run_in_graph_and_eager_modes()
def testScatterDivScalar(self):
- with ops.device("cpu:0"):
- handle = resource_variable_ops.var_handle_op(
- dtype=dtypes.int32, shape=[1, 1])
- self.evaluate(
- resource_variable_ops.assign_variable_op(handle,
- constant_op.constant(
- [[6]],
- dtype=dtypes.int32)))
- self.evaluate(
- resource_variable_ops.resource_scatter_div(handle, [0],
- constant_op.constant(
- 3,
- dtype=dtypes.int32)))
- read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(self.evaluate(read), [[2]])
-
- @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_div(
+ handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[2]])
+
+ @test_util.run_in_graph_and_eager_modes()
def testScatterMinScalar(self):
- with ops.device("cpu:0"):
- handle = resource_variable_ops.var_handle_op(
- dtype=dtypes.int32, shape=[1, 1])
- self.evaluate(
- resource_variable_ops.assign_variable_op(handle,
- constant_op.constant(
- [[6]],
- dtype=dtypes.int32)))
- self.evaluate(
- resource_variable_ops.resource_scatter_min(handle, [0],
- constant_op.constant(
- 3,
- dtype=dtypes.int32)))
- read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(self.evaluate(read), [[3]])
-
- @test_util.run_in_graph_and_eager_modes(use_gpu=True)
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_min(
+ handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[3]])
+
+ @test_util.run_in_graph_and_eager_modes()
def testScatterMaxScalar(self):
- with ops.device("cpu:0"):
- handle = resource_variable_ops.var_handle_op(
- dtype=dtypes.int32, shape=[1, 1])
- self.evaluate(
- resource_variable_ops.assign_variable_op(handle,
- constant_op.constant(
- [[6]],
- dtype=dtypes.int32)))
- self.evaluate(
- resource_variable_ops.resource_scatter_max(handle, [0],
- constant_op.constant(
- 3,
- dtype=dtypes.int32)))
- read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
- self.assertEqual(self.evaluate(read), [[6]])
+ handle = resource_variable_ops.var_handle_op(
+ dtype=dtypes.int32, shape=[1, 1])
+ self.evaluate(
+ resource_variable_ops.assign_variable_op(
+ handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+ self.evaluate(
+ resource_variable_ops.resource_scatter_max(
+ handle, [0], constant_op.constant(3, dtype=dtypes.int32)))
+ read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+ self.assertEqual(self.evaluate(read), [[6]])
def testScatterUpdateString(self):
handle = resource_variable_ops.var_handle_op(
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index 297a8bbde5..7bd57ad3d8 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -237,7 +237,17 @@ class MomentumOptimizerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes(reset_test=True)
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ # This test invokes the ResourceSparseApplyMomentum operation, which
+ # did not have a registered GPU kernel as of April 2018. With graph
+ # execution, the placement algorithm notices this and automatically
+ # places the variable in CPU (host) memory. With eager execution,
+ # the variable would be placed in GPU memory if available, which
+ # would then conflict with the future invocation of the
+ # ResourceSparseApplyMomentum operation.
+ # To work around this discrepancy, for now we force the variable
+ # to be placed on CPU.
+ with ops.device("/cpu:0"):
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
# pylint: disable=cell-var-from-loop
def loss():
@@ -256,7 +266,17 @@ class MomentumOptimizerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes(reset_test=True)
def testMinimizeWith2DIndiciesForEmbeddingLookup(self):
- var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))
+ # This test invokes the ResourceSparseApplyMomentum operation, which
+ # did not have a registered GPU kernel as of April 2018. With graph
+ # execution, the placement algorithm notices this and automatically
+ # places the variable in CPU (host) memory. With eager execution,
+ # the variable would be placed in GPU memory if available, which
+ # would then conflict with the future invocation of the
+ # ResourceSparseApplyMomentum operation.
+ # To work around this discrepancy, for now we force the variable
+ # to be placed on CPU.
+ with ops.device("/cpu:0"):
+ var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))
def loss():
return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]]))