diff options
Diffstat (limited to 'tensorflow/python/ops/gradients_impl.py')
-rw-r--r-- | tensorflow/python/ops/gradients_impl.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 3268b38b86..196161c661 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -260,6 +260,12 @@ def _DefaultGradYs(grad_ys, "Gradient type %s generated for complex-valued " "tensor %s with type %s must be real" % (dtypes.as_dtype( grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name)) + elif y.dtype == dtypes.variant: + if grad_y.dtype != dtypes.variant: + raise TypeError( + "Gradient type %s generated for variant " + "tensor %s with type %s must be variant" % (dtypes.as_dtype( + grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name)) else: raise TypeError( "Tensor %s with type %s must be numeric " @@ -298,7 +304,7 @@ def _IsBackpropagatable(tensor): if _IsTrainable(tensor): return True dtype = dtypes.as_dtype(tensor.dtype) - return dtype.base_dtype in (dtypes.bfloat16, dtypes.resource, dtypes.variant) + return dtype.base_dtype in (dtypes.bfloat16, dtypes.variant) def _VerifyGeneratedGradients(grads, op): |