diff options
author | Alexandre Passos <apassos@google.com> | 2017-08-29 15:06:31 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-29 15:12:10 -0700 |
commit | f282bb142f8d170549f6708f334bed7f40dd029e (patch) | |
tree | 61b99a6905c62434e33a5e6ef96acd5d215a6367 /tensorflow/python/eager/tape.py | |
parent | 232a4f3749e2534adb7695a1d24b8b0fbd985039 (diff) |
Fix bug with second derivatives in eager mode.
Improper wrapping and unwrapping of tensors led to tracing being dropped.
PiperOrigin-RevId: 166910119
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r-- | tensorflow/python/eager/tape.py | 26 |
1 files changed, 16 insertions, 10 deletions
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index f2915eba59..6fd19aee45 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -101,6 +101,9 @@ class NoneVSpace(ag_core.VSpace): def __init__(self, _): self.size = 0 + def zeros(self): + return 0 + ag_core.register_vspace(NoneVSpace, type(None)) @@ -197,28 +200,32 @@ class _EagerSequenceVSpace(container_types.SequenceVSpace): return True -class _EagerList(list): - """Type used to bypass SequenceVSpace.""" +class EagerList(list): + """Type used to bypass SequenceVSpace. + + SequenceVSpace has a very strict equality check which does not match + tensorflow semantics. + """ def __init__(self, value): - super(_EagerList, self).__init__(value) + super(EagerList, self).__init__(value) for v in value: assert not ag_core.isnode(v) -ag_core.register_vspace(_EagerSequenceVSpace, _EagerList) -ag_core.register_node(_EagerSequenceNode, _EagerList) +ag_core.register_vspace(_EagerSequenceVSpace, EagerList) +ag_core.register_node(_EagerSequenceNode, EagerList) @ag_core.primitive def _record_operation(output_tensors, input_tensors, side_outputs, backward_function): del input_tensors, side_outputs, backward_function - return _EagerList(output_tensors) + return EagerList(output_tensors) def record_operation(o, i, s, b): """Primitive to trigger autograd tracing on outputs from inputs.""" - inputs = container_types.make_sequence(_EagerList, *i) + inputs = container_types.make_sequence(EagerList, *i) return _record_operation(o, inputs, s, b) @@ -227,9 +234,8 @@ def _record_operation_vjp(g, ans, vs, gvs, output_tensors, input_tensors, """Gradient for _record_operation.""" del ans, vs, gvs, output_tensors, input_tensors backward_args = tuple(g) + tuple(side_outputs) - if ag_core.isnode(backward_args): - backward_args = list(backward_args) + backward_args = container_types.make_sequence(EagerList, *backward_args) tensors = nest.flatten(backward_function(*backward_args)) - 
return _EagerList([ag_core.getval(t) for t in tensors]) + return container_types.make_sequence(EagerList, *tensors) _record_operation.defvjp(_record_operation_vjp, argnum=1) |