Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc | 148 |
1 file changed, 148 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
new file mode 100644
index 0000000000..2c3fc0abbc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
@@ -0,0 +1,148 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
+
+#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace cpu {
+
+StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
+  legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
+  if (!flags->xla_cpu_use_eigen) {
+    return false;
+  }
+
+  bool changed = false;
+  for (HloInstruction* hlo :
+       module->entry_computation()->MakeInstructionPostOrder()) {
+    if (hlo->opcode() == HloOpcode::kConvolution &&
+        !PotentiallyImplementedAsEigenConvolution(*hlo)) {
+      const ConvolutionDimensionNumbers& dnums =
+          hlo->convolution_dimension_numbers();
+      auto batch_dim = dnums.batch_dimension();
+      auto feature_dim = dnums.feature_dimension();
+      auto kernel_input_feature_dim = dnums.kernel_input_feature_dimension();
+      auto kernel_output_feature_dim = dnums.kernel_output_feature_dimension();
+
+      int num_spatial_dims = dnums.spatial_dimensions_size();
+      int num_dims = num_spatial_dims + 2;
+
+      // A canonical convolution's dimension numbers need to satisfy the
+      // following conditions (see cs/PotentiallyImplementedAsEigenConvolution).
+      //
+      // - the input is in NHWC or NWHC order.
+      // - the kernel is in HWIO or WHIO order.
+      // - the spatial dimensions are in the same relative order in the input,
+      //   kernel, and output.
+      //
+      // For simplicity, as a first step, we reshape the input and filter to
+      // NHWC and HWIO order, respectively. This may lose precision but will
+      // not break the soundness.
+      HloInstruction* input = hlo->mutable_operand(0);
+
+      std::vector<int64> new_input_dim_order(num_dims);
+      std::vector<int64> new_input_dims(num_dims);
+      new_input_dim_order[0] = batch_dim;
+      new_input_dims[0] = input->shape().dimensions(batch_dim);
+      for (int i = 0; i < num_spatial_dims; ++i) {
+        new_input_dim_order[i + 1] = dnums.spatial_dimensions(i);
+        new_input_dims[i + 1] =
+            input->shape().dimensions(dnums.spatial_dimensions(i));
+      }
+      new_input_dim_order[num_dims - 1] = feature_dim;
+      new_input_dims[num_dims - 1] = input->shape().dimensions(feature_dim);
+
+      Shape new_input_shape =
+          ShapeUtil::MakeShape(input->shape().element_type(), new_input_dims);
+      HloInstruction* new_input = module->entry_computation()->AddInstruction(
+          HloInstruction::CreateTranspose(new_input_shape, input,
+                                          new_input_dim_order));
+
+      HloInstruction* kernel = hlo->mutable_operand(1);
+
+      std::vector<int64> new_kernel_dim_order(num_dims);
+      std::vector<int64> new_kernel_dims(num_dims);
+      for (int i = 0; i < num_spatial_dims; ++i) {
+        new_kernel_dim_order[i] = dnums.kernel_spatial_dimensions(i);
+        new_kernel_dims[i] =
+            kernel->shape().dimensions(dnums.kernel_spatial_dimensions(i));
+      }
+      new_kernel_dim_order[num_dims - 2] = kernel_input_feature_dim;
+      new_kernel_dims[num_dims - 2] =
+          kernel->shape().dimensions(kernel_input_feature_dim);
+      new_kernel_dim_order[num_dims - 1] = kernel_output_feature_dim;
+      new_kernel_dims[num_dims - 1] =
+          kernel->shape().dimensions(kernel_output_feature_dim);
+
+      Shape new_kernel_shape =
+          ShapeUtil::MakeShape(kernel->shape().element_type(), new_kernel_dims);
+      HloInstruction* new_kernel = module->entry_computation()->AddInstruction(
+          HloInstruction::CreateTranspose(new_kernel_shape, kernel,
+                                          new_kernel_dim_order));
+
+      std::vector<int64> new_conv_dims(num_dims);
+      new_conv_dims[0] = hlo->shape().dimensions(batch_dim);
+      for (int i = 0; i < num_spatial_dims; ++i) {
+        new_conv_dims[i + 1] =
+            hlo->shape().dimensions(dnums.spatial_dimensions(i));
+      }
+      new_conv_dims[num_dims - 1] = hlo->shape().dimensions(feature_dim);
+      Shape new_conv_shape =
+          ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims);
+
+      ConvolutionDimensionNumbers new_dnums;
+      new_dnums.set_batch_dimension(0);
+      for (int i = 0; i < num_spatial_dims; ++i) {
+        new_dnums.add_spatial_dimensions(i + 1);
+        new_dnums.add_kernel_spatial_dimensions(i);
+      }
+      new_dnums.set_feature_dimension(num_dims - 1);
+      new_dnums.set_kernel_input_feature_dimension(num_dims - 2);
+      new_dnums.set_kernel_output_feature_dimension(num_dims - 1);
+
+      // The window of the old convolution is reused, because reshapes only
+      // change the dimension mapping but not the dimension sizes. For
+      // example, input height and width are the same as before the reshapes.
+      HloInstruction* new_conv = module->entry_computation()->AddInstruction(
+          HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel,
+                                         hlo->window(), new_dnums));
+
+      // kConvolution inherits the dimension mapping of its input, so we need
+      // to transpose the output back to the shape of the original
+      // convolution. This is done by applying the inverse of the permutation
+      // used to transpose the input.
+      module->entry_computation()->ReplaceWithNewInstruction(
+          hlo, HloInstruction::CreateTranspose(
+                   hlo->shape(), new_conv,
+                   InversePermutation(new_input_dim_order)));
+      changed = true;
+    }
+  }
+
+  return changed;
+}
+
+}  // namespace cpu
+}  // namespace xla
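
The heart of the pass is permutation bookkeeping: build a transpose order that moves the batch dimension first and the feature dimension last, then undo it on the output with the inverse permutation. The following standalone sketch checks that arithmetic. It is plain C++ with no XLA dependencies; Permute and the local InversePermutation are illustrative stand-ins for the XLA utilities used in the diff, and the NCHW example shape is made up for the demonstration.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for xla::InversePermutation: inverse[perm[i]] = i, so a
// transpose by `inverse` undoes a transpose by `perm`.
std::vector<int64_t> InversePermutation(const std::vector<int64_t>& perm) {
  std::vector<int64_t> inverse(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    inverse[perm[i]] = static_cast<int64_t>(i);
  }
  return inverse;
}

// Transpose semantics: output dimension i is input dimension order[i],
// mirroring how the pass builds new_input_dims from new_input_dim_order.
std::vector<int64_t> Permute(const std::vector<int64_t>& dims,
                             const std::vector<int64_t>& order) {
  std::vector<int64_t> out(order.size());
  for (size_t i = 0; i < order.size(); ++i) {
    out[i] = dims[order[i]];
  }
  return out;
}

int main() {
  // A hypothetical NCHW input: batch_dim = 0, feature_dim = 1,
  // spatial_dimensions = {2, 3}, with shape N=8, C=32, H=28, W=28.
  const std::vector<int64_t> input_dims = {8, 32, 28, 28};

  // Built as in the pass: batch first, spatial dims in their existing
  // relative order, feature last. Here: {0, 2, 3, 1}, i.e. NCHW -> NHWC.
  const std::vector<int64_t> new_input_dim_order = {0, 2, 3, 1};

  const std::vector<int64_t> nhwc = Permute(input_dims, new_input_dim_order);
  assert((nhwc == std::vector<int64_t>{8, 28, 28, 32}));

  // The pass's final CreateTranspose uses the inverse permutation,
  // {0, 3, 1, 2} here, to restore the original dimension order.
  const std::vector<int64_t> inverse = InversePermutation(new_input_dim_order);
  assert((Permute(nhwc, inverse) == input_dims));

  std::cout << "round-trip OK: NCHW -> NHWC -> NCHW\n";
  return 0;
}

The same round trip also shows why hlo->window() can be reused unchanged: the transposes permute which dimension is which, but every dimension size, and hence every window bound, stays the same.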