diff options
author | Tim Shen <timshen@google.com> | 2018-09-24 16:09:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 16:17:17 -0700 |
commit | d25b23d5ec6a0a7828e86fa8868f7a6574f9f827 (patch) | |
tree | afc40810ae059f2459bd0ea9f8ac2a0235c101e4 /tensorflow/compiler/xla/service/gpu/pad_insertion.cc | |
parent | 29a67eaedd8d95866011bb1c87a9d1739d448686 (diff) |
Remove the public uses of CreateCudnnConv* in favor of
CloneWithNewOperands. CreateCudnnConv* is easy to use incorrectly, as it
doesn't propagate backend_config.
PiperOrigin-RevId: 214348788
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/pad_insertion.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/pad_insertion.cc | 25 |
1 file changed, 13 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 2a6415d0b6..eead408f10 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -161,12 +161,10 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract // out the shape of conv_result. - Shape old_conv_shape = conv->shape().tuple_shapes(0); - VLOG(1) << "Canonicalizing forward conv"; - auto new_conv = CreateCudnnConvForward( - old_conv_shape, new_input, new_kernel, new_conv_window, - conv->convolution_dimension_numbers(), conv->feature_group_count()); + auto new_conv = conv->parent()->AddInstruction( + conv->CloneWithNewOperands(conv->shape(), {new_input, new_kernel})); + new_conv->set_window(new_conv_window); VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " << new_conv->ToString(); TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv)); @@ -242,10 +240,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // The shape of the backward_conv CustomCall is a tuple (conv_result, // scratch_buffer). Extract out the shape of conv_result. 
- Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); - HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter( - backward_conv_shape, padded_input, output, new_backward_conv_window, - backward_conv_dnums, backward_conv->feature_group_count()); + HloInstruction* new_backward_conv = + computation->AddInstruction(backward_conv->CloneWithNewOperands( + backward_conv->shape(), {padded_input, output})); + new_backward_conv->set_window(new_backward_conv_window); VLOG(1) << "Canonicalizing backward filter conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " @@ -308,9 +306,12 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( HloInstruction* output = backward_conv->mutable_operand(0); HloInstruction* filter = backward_conv->mutable_operand(1); - HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput( - new_backward_conv_shape, output, filter, new_backward_conv_window, - backward_conv_dnums, backward_conv->feature_group_count()); + HloInstruction* new_backward_conv_call = + computation->AddInstruction(backward_conv->CloneWithNewOperands( + ShapeUtil::MakeTupleShape( + {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}), + {output, filter})); + new_backward_conv_call->set_window(new_backward_conv_window); // The CustomCall created above returns a tuple (conv_result, scratch_memory). // Extract out the two elements. |