diff options
author | 2018-10-04 09:21:05 -0700 | |
---|---|---|
committer | 2018-10-04 09:30:41 -0700 | |
commit | ac22e1583aed390d78d2e87a4bf8a6ec39400ec4 (patch) | |
tree | 4fcbb4a8078b50c31862c38c7b8c48e01d0b3a28 | |
parent | a7e8ad18a61b251ef42c0260dd80a12cea8f268c (diff) |
Gracefully disallow updating resource variables with invalid shapes.
During graph construction, the shape function for AssignAddVariableOp etc.
would raise an error when the value being "assign add"ed to the variable
has an incompatible shape.
With eager execution, no such validation was being made which triggerred
an assertion failure in eigen:
https://github.com/eigenteam/eigen-git-mirror/blob/7d97e1cbbe4424fda39e31c88def7c0863897640/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h#L479
This change prevents that assertion failure.
PiperOrigin-RevId: 215749071
-rw-r--r-- | tensorflow/core/kernels/resource_variable_ops.cc | 6 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/resource_variable_ops_test.py | 9 |
2 files changed, 14 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 23d76986bf..678d675c4a 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -426,6 +426,12 @@ class AssignUpdateVariableOp : public OpKernel { // ADD if value's refcount was 1. mutex_lock ml(*variable->mu()); Tensor* var_tensor = variable->tensor(); + OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()), + errors::InvalidArgument("Cannot update variable with shape ", + var_tensor->shape().DebugString(), + " using a Tensor with shape ", + value.shape().DebugString(), + ", shapes must be equal.")); OP_REQUIRES_OK(context, PrepareToUpdateVariable<Device, T>(context, var_tensor)); functor::DenseUpdate<Device, T, Op> update_functor; diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 1365d4b240..a9fd93e9f8 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -142,7 +142,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): v = resource_variable_ops.ResourceVariable(1.0) ops.reset_default_graph() v.assign(2.0) # Note: this fails if we run convert_to_tensor on not the - # variable graph. + # variable graph. def testFetchHandle(self): with self.cached_session(): @@ -908,6 +908,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(Exception, r"shape.*2.*3"): state_ops.scatter_update(v, [0, 1], [0, 1, 2]) + @test_util.run_in_graph_and_eager_modes + def testAssignIncompatibleShape(self): + v = resource_variable_ops.ResourceVariable([0, 1, 2, 3]) + self.evaluate(v.initializer) + with self.assertRaisesRegexp(Exception, r"hapes must be equal"): + self.assertAllEqual(self.evaluate(v.assign_add(1)), [1, 2, 3, 4]) + class _MixedPrecisionVariableTest(test_util.TensorFlowTestCase): |