aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/gradients_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/gradients_impl.py')
-rw-r--r--tensorflow/python/ops/gradients_impl.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 3268b38b86..196161c661 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -260,6 +260,12 @@ def _DefaultGradYs(grad_ys,
"Gradient type %s generated for complex-valued "
"tensor %s with type %s must be real" % (dtypes.as_dtype(
grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
+ elif y.dtype == dtypes.variant:
+ if grad_y.dtype != dtypes.variant:
+ raise TypeError(
+ "Gradient type %s generated for variant "
+ "tensor %s with type %s must be variant" % (dtypes.as_dtype(
+ grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
else:
raise TypeError(
"Tensor %s with type %s must be numeric "
@@ -298,7 +304,7 @@ def _IsBackpropagatable(tensor):
if _IsTrainable(tensor):
return True
dtype = dtypes.as_dtype(tensor.dtype)
- return dtype.base_dtype in (dtypes.bfloat16, dtypes.resource, dtypes.variant)
+ return dtype.base_dtype in (dtypes.bfloat16, dtypes.variant)
def _VerifyGeneratedGradients(grads, op):