diff options
author | Saurabh Saxena <srbs@google.com> | 2018-09-12 08:41:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-12 08:46:35 -0700 |
commit | 9333978b4b08e4b3fdc7f63ec0873a7e00dcc4b7 (patch) | |
tree | 177268e284ed978959862ed056f33ed232900a68 /tensorflow/python/ops | |
parent | 9098f75af917df9b9d4f5ecc423037fd2fb365f9 (diff) |
Support providing default gradient for variant tensors in tf.gradients call.
PiperOrigin-RevId: 212645190
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- | tensorflow/python/ops/gradients_impl.py | 8 | ||||
-rw-r--r-- | tensorflow/python/ops/gradients_test.py | 21 |
2 files changed, 28 insertions, 1 deletions
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 3268b38b86..196161c661 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -260,6 +260,12 @@ def _DefaultGradYs(grad_ys, "Gradient type %s generated for complex-valued " "tensor %s with type %s must be real" % (dtypes.as_dtype( grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name)) + elif y.dtype == dtypes.variant: + if grad_y.dtype != dtypes.variant: + raise TypeError( + "Gradient type %s generated for variant " + "tensor %s with type %s must be variant" % (dtypes.as_dtype( + grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name)) else: raise TypeError( "Tensor %s with type %s must be numeric " @@ -298,7 +304,7 @@ def _IsBackpropagatable(tensor): if _IsTrainable(tensor): return True dtype = dtypes.as_dtype(tensor.dtype) - return dtype.base_dtype in (dtypes.bfloat16, dtypes.resource, dtypes.variant) + return dtype.base_dtype in (dtypes.bfloat16, dtypes.variant) def _VerifyGeneratedGradients(grads, op): diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 3759d8a543..6243be6c9e 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -45,6 +45,7 @@ from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import from tensorflow.python.ops import functional_ops # pylint: disable=unused-import from tensorflow.python.ops import gradients from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import list_ops from tensorflow.python.ops import math_grad # pylint: disable=unused-import from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_grad # pylint: disable=unused-import @@ -1004,5 +1005,25 @@ class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase): self._assert_indexed_slices_equal(total, result) +class TensorListGradientsTest(test_util.TensorFlowTestCase): + + def testDefaultGradYs(self): + with ops.Graph().as_default(): + tl = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, + element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) + a = constant(1.0) + tl = list_ops.tensor_list_push_back(tl, a) + + grad_tl = list_ops.empty_tensor_list( + element_dtype=dtypes.float32, + element_shape=ops.convert_to_tensor([], dtype=dtypes.int32)) + grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0)) + + grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0] + with self.cached_session() as sess: + self.assertEquals(sess.run(grad), 5.) + + if __name__ == "__main__": googletest.main() |