aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/resource_variable_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/resource_variable_ops.py')
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 1d747f8400..1471b5909e 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -19,11 +19,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from autograd import core as ag_core
+
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import custom_gradient
from tensorflow.python.eager import tape
+from tensorflow.python.eager import tensor_node
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -574,7 +577,14 @@ class ResourceVariable(variables.Variable):
def _run_op(a, *args):
# pylint: disable=protected-access
- return getattr(ops.Tensor, operator)(a._AsTensor(), *args)
+ value = a._AsTensor()
+ if ag_core.isnode(value):
+ # This avoids autograd trying to wrap a ResourceVariable.
+ value = ops.convert_to_tensor(value)
+ args = [ops.convert_to_tensor(x) for x in args]
+ return getattr(tensor_node.TensorNode, operator)(value, *args)
+ else:
+ return getattr(ops.Tensor, operator)(value, *args)
# Propagate __doc__ to wrapper
try: