aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
diff options
context:
space:
mode:
authorGravatar Tim Shen <timshen@google.com>2018-09-24 17:44:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 17:48:32 -0700
commit391cdd80952e9cc546d82a8bf2fe7dd04f46cb2f (patch)
treea253b5bd01d2088d07e1cae505028a600d834384 /tensorflow/compiler/xla/service/gpu/pad_insertion.cc
parent9ab01c6732dae1143e22713375a9cc7758216787 (diff)
Add cuDNN fused convolution forward support.
The tests are in the next patch. PiperOrigin-RevId: 214362688
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc6
1 files changed, 5 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index eead408f10..7e77dc9ac6 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -162,8 +162,12 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
// The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract
// out the shape of conv_result.
VLOG(1) << "Canonicalizing forward conv";
+ std::vector<HloInstruction*> operands(conv->operands().begin(),
+ conv->operands().end());
+ operands[0] = new_input;
+ operands[1] = new_kernel;
auto new_conv = conv->parent()->AddInstruction(
- conv->CloneWithNewOperands(conv->shape(), {new_input, new_kernel}));
+ conv->CloneWithNewOperands(conv->shape(), operands));
new_conv->set_window(new_conv_window);
VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
<< new_conv->ToString();