author	Alexandre Passos <apassos@google.com>	2017-10-19 16:42:45 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2017-10-19 16:49:12 -0700
commit	e885d1abdce5db4a67e0b3ba85dbcc708f856645 (patch)
tree	8052e4e3c06c026b2b32f272b4cd12d2c2aeab0b
parent	2977dccc96c343ca85cb00b50672b36c99656532 (diff)
One less error message in gradients_function
PiperOrigin-RevId: 172818233
-rw-r--r--	tensorflow/python/eager/backprop.py	11
-rw-r--r--	tensorflow/python/eager/backprop_test.py	8
2 files changed, 12 insertions(+), 7 deletions(-)
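
What the change does: previously, calling gradients_function (or the other wrappers built on _get_arg_spec) with params=None on a function whose signature is only *args/**kwds raised a ValueError. Parameter positions are now resolved from the actual call arguments inside decorated, so such functions can be differentiated. A minimal sketch of the newly supported call pattern, mirroring the test added below (run with eager execution enabled, as in the test harness):

  from tensorflow.python.eager import backprop

  def f(*args):
    return args[0] * args[0]  # declares no named positional parameters

  grad = backprop.gradients_function(f)
  print(grad(1.0)[0])  # d(x*x)/dx at x = 1.0 -> 2.0
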
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index da17be05b7..9580e84847 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -396,12 +396,11 @@ def implicit_grad(f):
   return grad_fn


-def _get_arg_spec(f, params):
+def _get_arg_spec(f, params, param_args):
   args = tf_inspect.getargspec(f).args
   if params is None:
     if not args:
-      raise ValueError("When params is None the differentiated function cannot"
-                       " only take arguments by *args and **kwds.")
+      return range(len(param_args))
     return range(len(args))
   elif all(isinstance(x, six.string_types) for x in params):
     return [args.index(n) for n in params]
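
The new fallback in isolation: when params is None and the wrapped function declares no named positional parameters, every argument actually passed at the call site is treated as a differentiable parameter instead of raising. A standalone sketch of that logic (hypothetical helper name; tf_inspect wraps the standard inspect module, so plain inspect is used here):

  import inspect

  def _positions(f, params, call_args):
    args = inspect.getfullargspec(f).args  # named positional parameters
    if params is None:
      if not args:
        # Pure *args/**kwds signature: differentiate w.r.t. every call arg.
        return range(len(call_args))
      return range(len(args))
    # String/integer params handling elided in this sketch.
    raise NotImplementedError
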
@@ -560,10 +559,9 @@ def val_and_grad_function(f, params=None):
     ValueError: if the params are not all strings or all integers.
   """
-  parameter_positions = _get_arg_spec(f, params)
-
   def decorated(*args, **kwds):
     """Computes the value and gradient of the decorated function."""
+    parameter_positions = _get_arg_spec(f, params, args)
     dy = kwds.pop("dy", None)
     if dy is not None:
       dy = ops.convert_to_tensor(dy)
@@ -616,10 +614,9 @@ def make_vjp(f, params=None):
   """
-  parameter_positions = _get_arg_spec(f, params)
-
   def decorated(*args, **kwds):
     """Computes the value and gradient of the decorated function."""
+    parameter_positions = _get_arg_spec(f, params, args)
     assert not kwds, "The gradient function can't take keyword arguments."
     tape.push_new_tape()
     sources = []
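
Since parameter_positions is now computed per call inside decorated, val_and_grad_function and make_vjp gain the same *args support as gradients_function. A hedged usage sketch, assuming val_and_grad_function's documented (value, gradients) return pair:

  def mul(*args):
    return args[0] * args[1]

  val_grad = backprop.val_and_grad_function(mul)
  value, grads = val_grad(2.0, 3.0)
  # value == 6.0; grads == [3.0, 2.0], since d(xy)/dx = y and d(xy)/dy = x
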
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 95d5f0adcb..7da8eb0c9b 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -381,6 +381,14 @@ class BackpropTest(test.TestCase):
         [tensor_shape.TensorShape(s).as_proto() for s in shape_list],
         backprop.make_attr([pywrap_tensorflow.TF_ATTR_SHAPE], shape_list))

+  def testArgsGradientFunction(self):
+
+    def f(*args):
+      return args[0] * args[0]
+
+    grad = backprop.gradients_function(f)
+    self.assertAllEqual(grad(1.0)[0], 2.0)
+
   def testMultiValueConvertToTensor(self):
     x = resource_variable_ops.ResourceVariable(
         initial_value=array_ops.constant([1.0]), name='x')