aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-10-04 09:21:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 09:30:41 -0700
commitac22e1583aed390d78d2e87a4bf8a6ec39400ec4 (patch)
tree4fcbb4a8078b50c31862c38c7b8c48e01d0b3a28 /tensorflow/python/kernel_tests
parenta7e8ad18a61b251ef42c0260dd80a12cea8f268c (diff)
Gracefully disallow updating resource variables with invalid shapes.
During graph construction, the shape function for AssignAddVariableOp etc. would raise an error when the value being "assign add"ed to the variable has an incompatible shape. With eager execution, no such validation was being made which triggerred an assertion failure in eigen: https://github.com/eigenteam/eigen-git-mirror/blob/7d97e1cbbe4424fda39e31c88def7c0863897640/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h#L479 This change prevents that assertion failure. PiperOrigin-RevId: 215749071
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 1365d4b240..a9fd93e9f8 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -142,7 +142,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
v = resource_variable_ops.ResourceVariable(1.0)
ops.reset_default_graph()
v.assign(2.0) # Note: this fails if we run convert_to_tensor on not the
- # variable graph.
+ # variable graph.
def testFetchHandle(self):
with self.cached_session():
@@ -908,6 +908,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(Exception, r"shape.*2.*3"):
state_ops.scatter_update(v, [0, 1], [0, 1, 2])
+ @test_util.run_in_graph_and_eager_modes
+ def testAssignIncompatibleShape(self):
+ v = resource_variable_ops.ResourceVariable([0, 1, 2, 3])
+ self.evaluate(v.initializer)
+ with self.assertRaisesRegexp(Exception, r"hapes must be equal"):
+ self.assertAllEqual(self.evaluate(v.assign_add(1)), [1, 2, 3, 4])
+
class _MixedPrecisionVariableTest(test_util.TensorFlowTestCase):