aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/data_flow_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/data_flow_grad.py')
-rw-r--r--tensorflow/python/ops/data_flow_grad.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/python/ops/data_flow_grad.py b/tensorflow/python/ops/data_flow_grad.py
new file mode 100644
index 0000000000..d2473490ce
--- /dev/null
+++ b/tensorflow/python/ops/data_flow_grad.py
@@ -0,0 +1,37 @@
+"""Gradients for operators defined in data_flow_ops.py."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_data_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+@ops.RegisterGradient("DynamicStitch")
+def _DynamicStitchGrads(op, grad):
+ """Gradients for DynamicStitch."""
+
+ num_values = len(op.inputs) / 2
+ indices_grad = [None] * num_values
+
+ def AsInt32(x):
+ return (x if op.inputs[0].dtype == types.int32 else
+ math_ops.cast(x, types.int32))
+ inputs = [AsInt32(op.inputs[i]) for i in range(num_values)]
+ if isinstance(grad, ops.IndexedSlices):
+ output_shape = array_ops.shape(op.outputs[0])
+ output_rows = output_shape[0]
+ grad = math_ops.unsorted_segment_sum(grad.values, grad.indices, output_rows)
+ values_grad = [array_ops.gather(grad, inp) for inp in inputs]
+ return indices_grad + values_grad
+
+
+ops.NoGradient("Queue")
+ops.NoGradient("QueueEnqueue")
+ops.NoGradient("QueueEnqueueMany")
+ops.NoGradient("QueueDequeue")
+ops.NoGradient("QueueDequeueMany")
+ops.NoGradient("QueueClose")
+ops.NoGradient("QueueSize")