diff options
author | Geoffrey Irving <geoffreyi@google.com> | 2016-04-14 11:52:13 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-04-14 13:01:53 -0700 |
commit | a58e5caa9a3e890a2137791e8e4a0870828fa882 (patch) | |
tree | 8da87df94353c87422800ab8b42839a4a9d9582c | |
parent | 0d088fbaaa724a3b5d19870c1ae334671af3269e (diff) |
Extend singleton case in _AggregatedGrads to IndexedSlices
If we are aggregating a single gradient, we can just pass it through unchanged.
Previously, the code did this only for tf.Tensor, not for tf.IndexedSlices.
Change: 119880763
-rw-r--r-- | tensorflow/python/ops/gradients.py | 12 | ||||
-rw-r--r-- | tensorflow/python/ops/gradients_test.py | 12 |
2 files changed, 18 insertions, 6 deletions
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py index fcf3af8183..f07454028c 100644 --- a/tensorflow/python/ops/gradients.py +++ b/tensorflow/python/ops/gradients.py @@ -638,13 +638,13 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None): "or all IndexedSlices") # Aggregate multiple gradients, and convert [] to None. if out_grad: - if all([isinstance(g, ops.Tensor) for g in out_grad if g is not None]): + if len(out_grad) < 2: + used = "nop" + out_grads[i] = out_grad[0] + elif all([isinstance(g, ops.Tensor) for g in out_grad if g is not None]): tensor_shape = _AccumulatorShape(out_grad) - if len(out_grad) < 2: - used = "nop" - out_grads[i] = out_grad[0] - elif (aggregation_method == AggregationMethod.EXPERIMENTAL_ACCUMULATE_N - and len(out_grad) > 2 and tensor_shape.is_fully_defined()): + if (aggregation_method == AggregationMethod.EXPERIMENTAL_ACCUMULATE_N + and len(out_grad) > 2 and tensor_shape.is_fully_defined()): # The benefit of using AccumulateN is that its inputs can be combined # in any order and this can allow the expression to be evaluated with # a smaller memory footprint. When used with gpu_allocator_retry, diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 492d60931f..77711569b3 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -260,6 +260,18 @@ class GradientsTest(test_util.TensorFlowTestCase): grads = gradients.gradients(z, [c]) self.assertTrue(isinstance(grads[0], ops.Tensor)) + def testSingletonIndexedSlices(self): + with ops.Graph().as_default(): + x = tf.placeholder(tf.float32) + y = tf.identity(x) + dy = tf.IndexedSlices(tf.placeholder(tf.float32), + tf.placeholder(tf.int32)) + dx, = gradients.gradients(y, x, grad_ys=dy) + # The gradient of tf.identity should pass the value through unchanged. + # A previous version of the code did this only for tf.Tensor, not + # tf.IndexedSlices. 
+ self.assertEqual(dx, dy) + class FunctionGradientsTest(test_util.TensorFlowTestCase): |