diff options
Diffstat (limited to 'tensorflow/python/ops/parallel_for/gradients.py')
-rw-r--r-- | tensorflow/python/ops/parallel_for/gradients.py | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/tensorflow/python/ops/parallel_for/gradients.py b/tensorflow/python/ops/parallel_for/gradients.py index 460de0a97f..1f026b3660 100644 --- a/tensorflow/python/ops/parallel_for/gradients.py +++ b/tensorflow/python/ops/parallel_for/gradients.py @@ -42,6 +42,7 @@ def jacobian(output, inputs, use_pfor=True): [y_1, ..., y_n, x_1, ..., x_m]. """ flat_inputs = nest.flatten(inputs) + output_tensor_shape = output.shape output_shape = array_ops.shape(output) output = array_ops.reshape(output, [-1]) @@ -65,6 +66,7 @@ def jacobian(output, inputs, use_pfor=True): new_shape = array_ops.concat( [output_shape, array_ops.shape(out)[1:]], axis=0) out = array_ops.reshape(out, new_shape) + out.set_shape(output_tensor_shape.concatenate(flat_inputs[i].shape)) pfor_outputs[i] = out return nest.pack_sequence_as(inputs, pfor_outputs) |