author | Alexandre Passos <apassos@google.com> | 2017-08-30 19:51:38 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-30 19:55:36 -0700
commit | 424aa9aa9559f6fa29d8ccf3d74ff25528b39209 (patch)
tree | 99bcce951c10cf6081940919c0124bab80a236be /tensorflow
parent | 9acea81c16c222f145515354f6176852dfdb03d7 (diff)
Eager-graph mode should work with gradient computation.
PiperOrigin-RevId: 167086826
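
This change makes gradient computation work when an eager program calls into a graph function built with `function.defun`. The new `testGraphModeWithGradients` test in this diff exercises exactly that path; the same pattern as a standalone sketch follows. It is adapted directly from the test, assumes eager execution is already enabled (the test harness handles that in the original), and imports module paths that are TensorFlow-internal as of this commit, not stable public API:

# Standalone sketch of the pattern this commit fixes, adapted from the
# testGraphModeWithGradients test added below. Assumes eager execution
# is already enabled; module paths are internal as of this commit.
from tensorflow.python.eager import backprop
from tensorflow.python.eager import function
from tensorflow.python.eager import tape
from tensorflow.python.ops import resource_variable_ops

v = resource_variable_ops.ResourceVariable(1.0)

@function.defun  # builds a graph function, called from eager mode
def step():
  def inner():
    tape.watch(v.handle)  # record reads of the variable on the tape
    return v * v

  # implicit_grad differentiates inner() with respect to the watched
  # variables; [0][1] picks the gradient out of the first returned pair,
  # mirroring the indexing used in the test.
  return backprop.implicit_grad(inner)()[0][1]

print(step().numpy())  # d(v*v)/dv at v = 1.0 is 2.0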
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/BUILD | 4
-rw-r--r-- | tensorflow/python/eager/backprop.py | 6
-rw-r--r-- | tensorflow/python/eager/function_test.py | 14
-rw-r--r-- | tensorflow/python/framework/op_def_library.py | 1
-rw-r--r-- | tensorflow/python/ops/resource_variable_ops.py | 12
5 files changed, 30 insertions, 7 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index b75f79fbf4..98dce82ee3 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1766,6 +1766,8 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":array_ops",
+        ":array_ops_gen",
+        ":dtypes",
         ":framework_ops",
         ":resource_variable_ops_gen",
         ":tensor_shape",
@@ -1775,7 +1777,7 @@ py_library(
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:custom_gradient",
         "//tensorflow/python/eager:tape",
-        "//tensorflow/python/eager:tensor",
+        "//tensorflow/python/eager:tensor_node",
     ],
 )
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index ca3ad1a2c3..326f56ebf9 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -169,10 +169,6 @@ def _record_gradient(op_name, inputs, attrs, results, name):
 execute.record_gradient = _record_gradient
 
 
-def _ones(shape, dtype):
-  return array_ops.fill(shape, tensor.Tensor(1, dtype=dtype))
-
-
 def _aggregate_grads(gradients):
   """Aggregate gradients of the same tensor."""
   grad_lists = dict()
@@ -225,7 +221,7 @@ def implicit_val_and_grad(f):
                        (end_node.progenitors, repr(start_node)))
     output_gradients = kwds.get("output_gradients", None)
     if output_gradients is None:
-      output_gradients = _ones(end_node.shape, end_node.dtype)
+      output_gradients = array_ops.ones_like(end_node.value)
     grad = ag_core.backward_pass(output_gradients, end_node, start_node)
     return end_node.value, _aggregate_grads(grad.gradients)
 
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 18b722e792..c15dde9e48 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import function as tf_function
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
 
 
 class FunctionTest(test.TestCase):
@@ -52,6 +53,19 @@ class FunctionTest(test.TestCase):
     out = sq(t)
     self.assertAllEqual(out.numpy(), math_ops.matmul(t, t).numpy())
 
+  def testGraphModeWithGradients(self):
+    v = resource_variable_ops.ResourceVariable(1.0)
+
+    @function.defun
+    def step():
+      def inner():
+        tape.watch(v.handle)
+        return v * v
+
+      return backprop.implicit_grad(inner)()[0][1]
+
+    self.assertAllEqual(step().numpy(), 2.0)
+
   def testTensorConversionWithDefun(self):
 
     @function.defun
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index aa37360066..76424ef579 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -784,6 +784,7 @@ class OpDefLibrary(object):
                                  if arg.is_ref]
       with _MaybeColocateWith(must_colocate_inputs):
         # Add Op to graph
+        inputs = [ag_core.getval(x) for x in inputs]
         op = g.create_op(op_type_name, inputs, output_types, name=scope,
                          input_types=input_types, attrs=attr_protos,
                          op_def=op_def)
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 1d747f8400..1471b5909e 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -19,11 +19,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from autograd import core as ag_core
+
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import variable_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.eager import custom_gradient
 from tensorflow.python.eager import tape
+from tensorflow.python.eager import tensor_node
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -574,7 +577,14 @@ class ResourceVariable(variables.Variable):
 
   def _run_op(a, *args):
     # pylint: disable=protected-access
-    return getattr(ops.Tensor, operator)(a._AsTensor(), *args)
+    value = a._AsTensor()
+    if ag_core.isnode(value):
+      # This avoids autograd trying to wrap a ResourceVariable.
+      value = ops.convert_to_tensor(value)
+      args = [ops.convert_to_tensor(x) for x in args]
+      return getattr(tensor_node.TensorNode, operator)(value, *args)
+    else:
+      return getattr(ops.Tensor, operator)(value, *args)
 
   # Propagate __doc__ to wrapper
   try:
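
A note on why the one-line op_def_library.py change matters: autograd's tape wraps traced tensors in node objects, and `g.create_op` only accepts plain tensors, so inputs must be unwrapped before graph construction. The following is a conceptual sketch of the `isnode`/`getval` helpers this relies on, a simplification for illustration rather than autograd's actual implementation (the real node type carries tracing metadata beyond the wrapped value):

# Conceptual sketch of autograd-style node unwrapping, as used by the
# op_def_library change above. Simplified; not autograd's real code.
class Node(object):
  """Wraps a value that is being traced for differentiation."""

  def __init__(self, value):
    self.value = value  # may itself be another Node


def isnode(x):
  return isinstance(x, Node)


def getval(x):
  # Strip trace wrappers until a raw value (e.g. a tf.Tensor) remains.
  while isnode(x):
    x = x.value
  return x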