diff options
Diffstat (limited to 'tensorflow/python/ops/resource_variable_ops.py')
-rw-r--r-- | tensorflow/python/ops/resource_variable_ops.py | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 1d747f8400..1471b5909e 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -19,11 +19,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from autograd import core as ag_core + from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import variable_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import custom_gradient from tensorflow.python.eager import tape +from tensorflow.python.eager import tensor_node from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -574,7 +577,14 @@ class ResourceVariable(variables.Variable): def _run_op(a, *args): # pylint: disable=protected-access - return getattr(ops.Tensor, operator)(a._AsTensor(), *args) + value = a._AsTensor() + if ag_core.isnode(value): + # This avoids autograd trying to wrap a ResourceVariable. + value = ops.convert_to_tensor(value) + args = [ops.convert_to_tensor(x) for x in args] + return getattr(tensor_node.TensorNode, operator)(value, *args) + else: + return getattr(ops.Tensor, operator)(value, *args) # Propagate __doc__ to wrapper try: |