aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2017-08-29 11:12:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-29 11:16:45 -0700
commite87b2be0e68412e6a72e5b7184968e1e0b1f9178 (patch)
treea8c3ff0d6f10014a8b8944962de443744c615139
parent5234d453855815957c533921db4f4d5d7c328b37 (diff)
Does not silently turn None into NaNs when aggregating gradients.
PiperOrigin-RevId: 166873653
-rw-r--r--tensorflow/python/eager/backprop_test.py5
-rw-r--r--tensorflow/python/eager/tensor_node.py4
2 files changed, 9 insertions, 0 deletions
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index 37007378a0..863a43f859 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -23,6 +23,7 @@ from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.eager import tensor
+from tensorflow.python.eager import tensor_node
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
@@ -74,6 +75,10 @@ class BackpropTest(test.TestCase):
self.assertAllClose(grad.numpy(), tf_dense_grad.eval())
+ def testTensoVspaceNoneMutAdd(self):
+ t = tensor.Tensor(1.0)
+ self.assertEqual(tensor_node.TensorVSpace(t).mut_add(t, None).numpy(), 1.0)
+
def testImplicitGradWithResourceVariable(self):
x = resource_variable_ops.ResourceVariable(
initial_value=tensor.Tensor(1.0), name='x')
diff --git a/tensorflow/python/eager/tensor_node.py b/tensorflow/python/eager/tensor_node.py
index ae5afaa970..331bf7eef8 100644
--- a/tensorflow/python/eager/tensor_node.py
+++ b/tensorflow/python/eager/tensor_node.py
@@ -300,6 +300,10 @@ class TensorVSpace(ag_core.VSpace):
x = _indexed_slices_to_tensor(x)
if isinstance(y, ops.IndexedSlices):
y = _indexed_slices_to_tensor(y)
+ if x is None:
+ return y
+ if y is None:
+ return x
return math_ops.add(x, y)