diff options
author | Tim Shen <timshen@google.com> | 2018-09-24 17:44:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 17:48:32 -0700 |
commit | 391cdd80952e9cc546d82a8bf2fe7dd04f46cb2f (patch) | |
tree | a253b5bd01d2088d07e1cae505028a600d834384 /tensorflow/compiler/xla/service/gpu/pad_insertion.cc | |
parent | 9ab01c6732dae1143e22713375a9cc7758216787 (diff) |
Add cuDNN fused convolution forward support.
The tests are in the next patch.
PiperOrigin-RevId: 214362688
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/pad_insertion.cc | 6 |
1 file changed, 5 insertions, 1 deletion
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index eead408f10..7e77dc9ac6 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -162,8 +162,12 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
   // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract
   // out the shape of conv_result.
   VLOG(1) << "Canonicalizing forward conv";
+  std::vector<HloInstruction*> operands(conv->operands().begin(),
+                                        conv->operands().end());
+  operands[0] = new_input;
+  operands[1] = new_kernel;
   auto new_conv = conv->parent()->AddInstruction(
-      conv->CloneWithNewOperands(conv->shape(), {new_input, new_kernel}));
+      conv->CloneWithNewOperands(conv->shape(), operands));
   new_conv->set_window(new_conv_window);
   VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
           << new_conv->ToString();