aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-01-07 18:38:53 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2016-01-07 18:38:53 -0800
commit2b672c4a2f6aeaea8457fd4941f48f5a9e80d283 (patch)
tree8577fddd4a20eb90c1457e4fef4c319e34a80b3c
parent827163e960e8cb86d3dc3f70434c22713ac9f41c (diff)
Add gradients for DynamicPartition
Change: 111650709
-rw-r--r--tensorflow/python/kernel_tests/dynamic_partition_op_test.py14
-rw-r--r--tensorflow/python/ops/data_flow_grad.py17
2 files changed, 29 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
index 63d8d4f5a9..b46fef6951 100644
--- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
@@ -72,11 +72,21 @@ class DynamicPartitionTest(tf.test.TestCase):
partitions = np.random.randint(n, size=np.prod(shape)).reshape(shape)
for extra_shape in (), (6,), (6, 7):
data = np.random.randn(*(shape + extra_shape))
- outputs = tf.dynamic_partition(data, partitions, num_partitions=n)
+ partitions_t = tf.constant(partitions, dtype=tf.int32)
+ data_t = tf.constant(data)
+ outputs = tf.dynamic_partition(
+ data_t, partitions_t, num_partitions=n)
self.assertEqual(n, len(outputs))
- for i, output in enumerate(sess.run(outputs)):
+ outputs_val = sess.run(outputs)
+ for i, output in enumerate(outputs_val):
self.assertAllEqual(output, data[partitions == i])
+ # Test gradients
+ outputs_grad = [7 * output for output in outputs_val]
+ grads = tf.gradients(outputs, [data_t, partitions_t], outputs_grad)
+ self.assertEqual(grads[1], None) # Partitions has no gradients
+ self.assertAllEqual(7 * data, sess.run(grads[0]))
+
def testErrorIndexOutOfRange(self):
with self.test_session() as sess:
data = tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8],
diff --git a/tensorflow/python/ops/data_flow_grad.py b/tensorflow/python/ops/data_flow_grad.py
index c46c29ed57..3edc013adc 100644
--- a/tensorflow/python/ops/data_flow_grad.py
+++ b/tensorflow/python/ops/data_flow_grad.py
@@ -28,6 +28,23 @@ from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
+@ops.RegisterGradient("DynamicPartition")
+def _DynamicPartitionGrads(op, *grads):
+ """Gradients for DynamicPartition."""
+ data = op.inputs[0]
+ indices = op.inputs[1]
+ num_partitions = op.get_attr("num_partitions")
+
+ prefix_shape = array_ops.shape(indices)
+ original_indices = array_ops.reshape(
+ math_ops.range(math_ops.reduce_prod(prefix_shape)), prefix_shape)
+ partitioned_indices = data_flow_ops.dynamic_partition(
+ original_indices, indices, num_partitions)
+ reconstructed = data_flow_ops.dynamic_stitch(partitioned_indices, grads)
+ reconstructed = array_ops.reshape(reconstructed, array_ops.shape(data))
+ return [reconstructed, None]
+
+
@ops.RegisterGradient("DynamicStitch")
def _DynamicStitchGrads(op, grad):
"""Gradients for DynamicStitch."""