diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-01-07 18:38:53 -0800 |
---|---|---|
committer | Vijay Vasudevan <vrv@google.com> | 2016-01-07 18:38:53 -0800 |
commit | 2b672c4a2f6aeaea8457fd4941f48f5a9e80d283 (patch) | |
tree | 8577fddd4a20eb90c1457e4fef4c319e34a80b3c | |
parent | 827163e960e8cb86d3dc3f70434c22713ac9f41c (diff) |
Add gradients for DynamicPartition
Change: 111650709
-rw-r--r-- | tensorflow/python/kernel_tests/dynamic_partition_op_test.py | 14 | ||||
-rw-r--r-- | tensorflow/python/ops/data_flow_grad.py | 17 |
2 files changed, 29 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py index 63d8d4f5a9..b46fef6951 100644 --- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py +++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py @@ -72,11 +72,21 @@ class DynamicPartitionTest(tf.test.TestCase): partitions = np.random.randint(n, size=np.prod(shape)).reshape(shape) for extra_shape in (), (6,), (6, 7): data = np.random.randn(*(shape + extra_shape)) - outputs = tf.dynamic_partition(data, partitions, num_partitions=n) + partitions_t = tf.constant(partitions, dtype=tf.int32) + data_t = tf.constant(data) + outputs = tf.dynamic_partition( + data_t, partitions_t, num_partitions=n) self.assertEqual(n, len(outputs)) - for i, output in enumerate(sess.run(outputs)): + outputs_val = sess.run(outputs) + for i, output in enumerate(outputs_val): self.assertAllEqual(output, data[partitions == i]) + # Test gradients + outputs_grad = [7 * output for output in outputs_val] + grads = tf.gradients(outputs, [data_t, partitions_t], outputs_grad) + self.assertEqual(grads[1], None) # Partitions has no gradients + self.assertAllEqual(7 * data, sess.run(grads[0])) + def testErrorIndexOutOfRange(self): with self.test_session() as sess: data = tf.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], diff --git a/tensorflow/python/ops/data_flow_grad.py b/tensorflow/python/ops/data_flow_grad.py index c46c29ed57..3edc013adc 100644 --- a/tensorflow/python/ops/data_flow_grad.py +++ b/tensorflow/python/ops/data_flow_grad.py @@ -28,6 +28,23 @@ from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import math_ops +@ops.RegisterGradient("DynamicPartition") +def _DynamicPartitionGrads(op, *grads): + """Gradients for DynamicPartition.""" + data = op.inputs[0] + indices = op.inputs[1] + num_partitions = op.get_attr("num_partitions") + + prefix_shape = array_ops.shape(indices) + original_indices = array_ops.reshape( + math_ops.range(math_ops.reduce_prod(prefix_shape)), prefix_shape) + partitioned_indices = data_flow_ops.dynamic_partition( + original_indices, indices, num_partitions) + reconstructed = data_flow_ops.dynamic_stitch(partitioned_indices, grads) + reconstructed = array_ops.reshape(reconstructed, array_ops.shape(data)) + return [reconstructed, None] + + @ops.RegisterGradient("DynamicStitch") def _DynamicStitchGrads(op, grad): """Gradients for DynamicStitch.""" |