diff options
author | Asim Shankar <ashankar@google.com> | 2018-10-04 09:21:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 09:30:41 -0700 |
commit | ac22e1583aed390d78d2e87a4bf8a6ec39400ec4 (patch) | |
tree | 4fcbb4a8078b50c31862c38c7b8c48e01d0b3a28 /tensorflow/python/kernel_tests | |
parent | a7e8ad18a61b251ef42c0260dd80a12cea8f268c (diff) |
Gracefully disallow updating resource variables with invalid shapes.
During graph construction, the shape function for AssignAddVariableOp etc.
would raise an error when the value being "assign add"ed to the variable
has an incompatible shape.
With eager execution, no such validation was being made which triggerred
an assertion failure in eigen:
https://github.com/eigenteam/eigen-git-mirror/blob/7d97e1cbbe4424fda39e31c88def7c0863897640/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h#L479
This change prevents that assertion failure.
PiperOrigin-RevId: 215749071
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r-- | tensorflow/python/kernel_tests/resource_variable_ops_test.py | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 1365d4b240..a9fd93e9f8 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -142,7 +142,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): v = resource_variable_ops.ResourceVariable(1.0) ops.reset_default_graph() v.assign(2.0) # Note: this fails if we run convert_to_tensor on not the - # variable graph. + # variable graph. def testFetchHandle(self): with self.cached_session(): @@ -908,6 +908,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(Exception, r"shape.*2.*3"): state_ops.scatter_update(v, [0, 1], [0, 1, 2]) + @test_util.run_in_graph_and_eager_modes + def testAssignIncompatibleShape(self): + v = resource_variable_ops.ResourceVariable([0, 1, 2, 3]) + self.evaluate(v.initializer) + with self.assertRaisesRegexp(Exception, r"hapes must be equal"): + self.assertAllEqual(self.evaluate(v.assign_add(1)), [1, 2, 3, 4]) + class _MixedPrecisionVariableTest(test_util.TensorFlowTestCase): |