aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2018-09-12 08:41:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 08:46:35 -0700
commit9333978b4b08e4b3fdc7f63ec0873a7e00dcc4b7 (patch)
tree177268e284ed978959862ed056f33ed232900a68 /tensorflow/python/ops
parent9098f75af917df9b9d4f5ecc423037fd2fb365f9 (diff)
Support providing default gradient for variant tensors in tf.gradients call.
PiperOrigin-RevId: 212645190
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/gradients_impl.py8
-rw-r--r--tensorflow/python/ops/gradients_test.py21
2 files changed, 28 insertions, 1 deletions
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 3268b38b86..196161c661 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -260,6 +260,12 @@ def _DefaultGradYs(grad_ys,
"Gradient type %s generated for complex-valued "
"tensor %s with type %s must be real" % (dtypes.as_dtype(
grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
+ elif y.dtype == dtypes.variant:
+ if grad_y.dtype != dtypes.variant:
+ raise TypeError(
+ "Gradient type %s generated for variant "
+ "tensor %s with type %s must be variant" % (dtypes.as_dtype(
+ grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
else:
raise TypeError(
"Tensor %s with type %s must be numeric "
@@ -298,7 +304,7 @@ def _IsBackpropagatable(tensor):
if _IsTrainable(tensor):
return True
dtype = dtypes.as_dtype(tensor.dtype)
- return dtype.base_dtype in (dtypes.bfloat16, dtypes.resource, dtypes.variant)
+ return dtype.base_dtype in (dtypes.bfloat16, dtypes.variant)
def _VerifyGeneratedGradients(grads, op):
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 3759d8a543..6243be6c9e 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -45,6 +45,7 @@ from tensorflow.python.ops import data_flow_ops # pylint: disable=unused-import
from tensorflow.python.ops import functional_ops # pylint: disable=unused-import
from tensorflow.python.ops import gradients
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad # pylint: disable=unused-import
@@ -1004,5 +1005,25 @@ class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):
self._assert_indexed_slices_equal(total, result)
+class TensorListGradientsTest(test_util.TensorFlowTestCase):
+
+ def testDefaultGradYs(self):
+ with ops.Graph().as_default():
+ tl = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ a = constant(1.0)
+ tl = list_ops.tensor_list_push_back(tl, a)
+
+ grad_tl = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32,
+ element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
+ grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))
+
+ grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
+ with self.cached_session() as sess:
+ self.assertEquals(sess.run(grad), 5.)
+
+
if __name__ == "__main__":
googletest.main()