aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
diff options
context:
space:
mode:
authorGravatar Tim Shen <timshen@google.com>2018-09-24 16:09:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 16:17:17 -0700
commitd25b23d5ec6a0a7828e86fa8868f7a6574f9f827 (patch)
treeafc40810ae059f2459bd0ea9f8ac2a0235c101e4 /tensorflow/compiler/xla/service/gpu/pad_insertion.cc
parent29a67eaedd8d95866011bb1c87a9d1739d448686 (diff)
Remove the public uses of CreateCudnnConv* in the favor of
CloneWithNewOperands. CreateCudnnConv* is easy to use wrongly, as it doesn't propagate backend_config. PiperOrigin-RevId: 214348788
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc25
1 files changed, 13 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 2a6415d0b6..eead408f10 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -161,12 +161,10 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
// The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract
// out the shape of conv_result.
- Shape old_conv_shape = conv->shape().tuple_shapes(0);
-
VLOG(1) << "Canonicalizing forward conv";
- auto new_conv = CreateCudnnConvForward(
- old_conv_shape, new_input, new_kernel, new_conv_window,
- conv->convolution_dimension_numbers(), conv->feature_group_count());
+ auto new_conv = conv->parent()->AddInstruction(
+ conv->CloneWithNewOperands(conv->shape(), {new_input, new_kernel}));
+ new_conv->set_window(new_conv_window);
VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n "
<< new_conv->ToString();
TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv));
@@ -242,10 +240,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// The shape of the backward_conv CustomCall is a tuple (conv_result,
// scratch_buffer). Extract out the shape of conv_result.
- Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0);
- HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter(
- backward_conv_shape, padded_input, output, new_backward_conv_window,
- backward_conv_dnums, backward_conv->feature_group_count());
+ HloInstruction* new_backward_conv =
+ computation->AddInstruction(backward_conv->CloneWithNewOperands(
+ backward_conv->shape(), {padded_input, output}));
+ new_backward_conv->set_window(new_backward_conv_window);
VLOG(1) << "Canonicalizing backward filter conv";
VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n "
@@ -308,9 +306,12 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution(
HloInstruction* output = backward_conv->mutable_operand(0);
HloInstruction* filter = backward_conv->mutable_operand(1);
- HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput(
- new_backward_conv_shape, output, filter, new_backward_conv_window,
- backward_conv_dnums, backward_conv->feature_group_count());
+ HloInstruction* new_backward_conv_call =
+ computation->AddInstruction(backward_conv->CloneWithNewOperands(
+ ShapeUtil::MakeTupleShape(
+ {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}),
+ {output, filter}));
+ new_backward_conv_call->set_window(new_backward_conv_window);
// The CustomCall created above returns a tuple (conv_result, scratch_memory).
// Extract out the two elements.