aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/parallel_for/gradients.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/parallel_for/gradients.py')
-rw-r--r--tensorflow/python/ops/parallel_for/gradients.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py
index 460de0a97f..1f026b3660 100644
--- a/tensorflow/python/ops/parallel_for/gradients.py
+++ b/tensorflow/python/ops/parallel_for/gradients.py
@@ -42,6 +42,7 @@ def jacobian(output, inputs, use_pfor=True):
[y_1, ..., y_n, x_1, ..., x_m].
"""
flat_inputs = nest.flatten(inputs)
+ output_tensor_shape = output.shape
output_shape = array_ops.shape(output)
output = array_ops.reshape(output, [-1])
@@ -65,6 +66,7 @@ def jacobian(output, inputs, use_pfor=True):
new_shape = array_ops.concat(
[output_shape, array_ops.shape(out)[1:]], axis=0)
out = array_ops.reshape(out, new_shape)
+ out.set_shape(output_tensor_shape.concatenate(flat_inputs[i].shape))
pfor_outputs[i] = out
return nest.pack_sequence_as(inputs, pfor_outputs)