diff options
author | 2017-10-19 16:42:45 -0700 | |
---|---|---|
committer | 2017-10-19 16:49:12 -0700 | |
commit | e885d1abdce5db4a67e0b3ba85dbcc708f856645 (patch) | |
tree | 8052e4e3c06c026b2b32f272b4cd12d2c2aeab0b | |
parent | 2977dccc96c343ca85cb00b50672b36c99656532 (diff) |
One less error message in gradients_function
PiperOrigin-RevId: 172818233
-rw-r--r-- | tensorflow/python/eager/backprop.py | 11 | ||||
-rw-r--r-- | tensorflow/python/eager/backprop_test.py | 8 |
2 files changed, 12 insertions, 7 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index da17be05b7..9580e84847 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -396,12 +396,11 @@ def implicit_grad(f): return grad_fn -def _get_arg_spec(f, params): +def _get_arg_spec(f, params, param_args): args = tf_inspect.getargspec(f).args if params is None: if not args: - raise ValueError("When params is None the differentiated function cannot" - " only take arguments by *args and **kwds.") + return range(len(param_args)) return range(len(args)) elif all(isinstance(x, six.string_types) for x in params): return [args.index(n) for n in params] @@ -560,10 +559,9 @@ def val_and_grad_function(f, params=None): ValueError: if the params are not all strings or all integers. """ - parameter_positions = _get_arg_spec(f, params) - def decorated(*args, **kwds): """Computes the value and gradient of the decorated function.""" + parameter_positions = _get_arg_spec(f, params, args) dy = kwds.pop("dy", None) if dy is not None: dy = ops.convert_to_tensor(dy) @@ -616,10 +614,9 @@ def make_vjp(f, params=None): """ - parameter_positions = _get_arg_spec(f, params) - def decorated(*args, **kwds): """Computes the value and gradient of the decorated function.""" + parameter_positions = _get_arg_spec(f, params, args) assert not kwds, "The gradient function can't take keyword arguments." tape.push_new_tape() sources = [] diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 95d5f0adcb..7da8eb0c9b 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -381,6 +381,14 @@ class BackpropTest(test.TestCase): [tensor_shape.TensorShape(s).as_proto() for s in shape_list], backprop.make_attr([pywrap_tensorflow.TF_ATTR_SHAPE], shape_list)) + def testArgsGradientFunction(self): + + def f(*args): + return args[0] * args[0] + + grad = backprop.gradients_function(f) + self.assertAllEqual(grad(1.0)[0], 2.0) + def testMultiValueConvertToTensor(self): x = resource_variable_ops.ResourceVariable( initial_value=array_ops.constant([1.0]), name='x') |