aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/tape.py
diff options
context:
space:
mode:
authorGravatar Ali Yahya <alive@google.com>2017-08-18 11:02:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-18 11:09:17 -0700
commit360bff8ae51f2341e3586e247195cc80a4a3933f (patch)
treeee6dd536ee3c4a212af3a524059959e6167ae068 /tensorflow/python/eager/tape.py
parent80bd004cdcac67468d69a407665b8d552efab3a8 (diff)
Makes tape.watch() work with ResourceVariables.
To this end, also adds a property, `device`, to TensorNode. PiperOrigin-RevId: 165726368
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r--tensorflow/python/eager/tape.py5
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index f2915eba59..5a0959a75e 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -25,6 +25,7 @@ from autograd import core as ag_core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
@@ -143,6 +144,10 @@ def watch(tensor):
Returns:
The tensor, potentially wrapped by all tapes in the stack.
"""
+ if isinstance(tensor, resource_variable_ops.ResourceVariable):
+ tensor._handle = watch(tensor.handle) # pylint: disable=protected-access
+ return tensor
+
for t in _tape_stack.stack:
tensor = _watch_with_tape(t, tensor)
return tensor