path: root/tensorflow/python/eager/tape.py
author    Alexandre Passos <apassos@google.com>    2017-08-29 15:06:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-08-29 15:12:10 -0700
commit    f282bb142f8d170549f6708f334bed7f40dd029e (patch)
tree      61b99a6905c62434e33a5e6ef96acd5d215a6367 /tensorflow/python/eager/tape.py
parent    232a4f3749e2534adb7695a1d24b8b0fbd985039 (diff)
Fix bug with second derivatives in eager mode.
Improper wrapping and unwrapping of tensors led to tracing being dropped.

PiperOrigin-RevId: 166910119
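The failure mode here is a nested (second-order) gradient in eager mode. As a minimal sketch of that usage pattern, shown with today's public tf.GradientTape API rather than the autograd-based internals this file used in 2017, a second derivative is taken by nesting one tape inside another:

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    y = x * x * x                      # y = x^3
  dy_dx = inner.gradient(y, x)         # first derivative: 3x^2 = 27
d2y_dx2 = outer.gradient(dy_dx, x)     # second derivative: 6x = 18

print(dy_dx.numpy(), d2y_dx2.numpy())  # 27.0 18.0

The bug fixed below broke the second step of exactly this kind of computation: the outer pass saw untraced values, so the second-order gradient was lost.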
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r-- tensorflow/python/eager/tape.py | 26
1 file changed, 16 insertions(+), 10 deletions(-)
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index f2915eba59..6fd19aee45 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -101,6 +101,9 @@ class NoneVSpace(ag_core.VSpace):
  def __init__(self, _):
    self.size = 0
+  def zeros(self):
+    return 0
+
ag_core.register_vspace(NoneVSpace, type(None))
@@ -197,28 +200,32 @@ class _EagerSequenceVSpace(container_types.SequenceVSpace):
    return True
-class _EagerList(list):
-  """Type used to bypass SequenceVSpace."""
+class EagerList(list):
+  """Type used to bypass SequenceVSpace.
+
+  SequenceVSpace has a very strict equality check which does not match
+  tensorflow semantics.
+  """
  def __init__(self, value):
-    super(_EagerList, self).__init__(value)
+    super(EagerList, self).__init__(value)
    for v in value:
      assert not ag_core.isnode(v)
-ag_core.register_vspace(_EagerSequenceVSpace, _EagerList)
-ag_core.register_node(_EagerSequenceNode, _EagerList)
+ag_core.register_vspace(_EagerSequenceVSpace, EagerList)
+ag_core.register_node(_EagerSequenceNode, EagerList)
@ag_core.primitive
def _record_operation(output_tensors, input_tensors, side_outputs,
                      backward_function):
  del input_tensors, side_outputs, backward_function
-  return _EagerList(output_tensors)
+  return EagerList(output_tensors)
def record_operation(o, i, s, b):
"""Primitive to trigger autograd tracing on outputs from inputs."""
- inputs = container_types.make_sequence(_EagerList, *i)
+ inputs = container_types.make_sequence(EagerList, *i)
return _record_operation(o, inputs, s, b)
@@ -227,9 +234,8 @@ def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors,
"""Gradient for _record_operation."""
del ans, vs, gvs, output_tensors, input_tensors
backward_args = tuple(g) + tuple(side_outputs)
- if ag_core.isnode(backward_args):
- backward_args = list(backward_args)
+ backward_args = container_types.make_sequence(EagerList, *backward_args)
tensors = nest.flatten(backward_function(*backward_args))
- return _EagerList([ag_core.getval(t) for t in tensors])
+ return container_types.make_sequence(EagerList, *tensors)
_record_operation.defvjp(_record_operation_vjp, argnum=1)
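The substance of the change to _record_operation_vjp is that the old code unwrapped every tensor with ag_core.getval before rebuilding the list, which strips the autograd node carrying the trace, whereas building the result with container_types.make_sequence lets traced values propagate to the next differentiation pass. A purely hypothetical, self-contained toy of that wrap/unwrap pattern (no autograd or TensorFlow; Traced, getval and make_sequence below only mimic the helpers used in tape.py):

class Traced(object):
  """Stands in for an autograd node: a value plus tracing metadata."""

  def __init__(self, value, trace):
    self.value = value
    self.trace = trace


def getval(x):
  # Mimics ag_core.getval: strips the tracing wrapper.
  return x.value if isinstance(x, Traced) else x


def make_sequence(seq_type, *items):
  # Simplified stand-in for container_types.make_sequence: keeps items as
  # given, traced or not.
  return seq_type(items)


class EagerList(list):
  pass


t = Traced(2.0, trace="first-order tape")

# Old behavior: getval() unwraps the value, so the trace is gone and a second
# differentiation pass only sees a plain float.
old = EagerList([getval(x) for x in [t]])
assert not isinstance(old[0], Traced)

# New behavior: the traced value passes through intact, so a second pass can
# still find the first-order trace.
new = make_sequence(EagerList, t)
assert isinstance(new[0], Traced) and new[0].trace == "first-order tape"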