diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-24 17:48:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-24 17:50:57 -0700 |
commit | f6066436884476d7bc32cf2ad6cfc8d9c52b5482 (patch) | |
tree | 1a6aece3e70ab0c0bddc758f401a12cab67e8bd1 | |
parent | 0c940ff33add2e8481cc1a5a166d8af72a5a21f9 (diff) |
Add heuristic on picking NHWC layout for (V100, fp16) convolutions.
Also move AlgorithmPicker after layout assignment, as now
cudnn_convolution_runner will return failures on invalid input layouts.
Also add a backend debug option to switch the layout heuristic. By default
it has the old behavior (all NCHW).
PiperOrigin-RevId: 197983747
-rw-r--r-- | tensorflow/compiler/xla/layout_util.cc | 10 | ||||
-rw-r--r-- | tensorflow/compiler/xla/layout_util.h | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/BUILD | 26 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc | 17 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_compiler.cc | 56 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc | 121 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h | 13 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_options.cc | 28 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_options.h | 33 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/stream_executor_util.cc | 151 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/stream_executor_util.h | 46 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/BUILD | 50 |
13 files changed, 459 insertions, 108 deletions
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index a76fdcda25..89cafa1a7d 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -65,6 +65,16 @@ void SetDefaultLayoutToContainer( return layout; } +/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( + tensorflow::gtl::ArraySlice<int64> major_to_minor) { + Layout layout; + layout.set_format(DENSE); + for (int i = major_to_minor.size() - 1; i >= 0; i--) { + layout.add_minor_to_major(major_to_minor[i]); + } + return layout; +} + /* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) { Layout layout; layout.set_format(SPARSE); diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index d3d6a2cc94..739bbe7367 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -36,6 +36,10 @@ class LayoutUtil { // convenience function for protobuf construction.) static Layout MakeLayout(tensorflow::gtl::ArraySlice<int64> minor_to_major); + // Similar to MakeLayout, but take indices in reverse order. + static Layout MakeLayoutFromMajorToMinor( + tensorflow::gtl::ArraySlice<int64> major_to_minor); + // Creates a sparse layout with the given maximum number of elements. (This is // a convenience function for protobuf construction.) static Layout MakeSparseLayout(int64 max_sparse_elements); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index aafb61b583..ffb1af2d87 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -338,6 +338,7 @@ cc_library( srcs = ["cudnn_convolution_runner.cc"], hdrs = ["cudnn_convolution_runner.h"], deps = [ + ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -590,14 +591,18 @@ cc_library( srcs = ["gpu_layout_assignment.cc"], hdrs = ["gpu_layout_assignment.h"], deps = [ + ":gpu_options", ":ir_emission_utils", + ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -694,6 +699,27 @@ cc_library( ], ) +cc_library( + name = "gpu_options", + srcs = ["gpu_options.cc"], + hdrs = ["gpu_options.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "stream_executor_util", + srcs = ["stream_executor_util.cc"], + hdrs = ["stream_executor_util.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + tf_cc_test( name = "gpu_hlo_support_checker_test", srcs = ["gpu_hlo_support_checker_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 10b4c3de89..0645fbb3ad 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -113,8 +115,17 @@ Status RunCudnnConvolution( // cuDNN's convolution APIs support the BDYX layout for activations/output and // the OIYX layout for weights. + DataLayout input_dl; + FilterLayout filter_dl; + DataLayout output_dl; + + TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), + XlaConvLayoutsToStreamExecutorLayouts( + dnums, input_shape.layout(), filter_shape.layout(), + output_shape.layout())); + BatchDescriptor input_descriptor(effective_num_dimensions); - input_descriptor.set_layout(DataLayout::kBatchDepthYX) + input_descriptor.set_layout(input_dl) .set_feature_map_count( input_shape.dimensions(dnums.input_feature_dimension())) .set_count(input_shape.dimensions(dnums.input_batch_dimension())); @@ -126,7 +137,7 @@ Status RunCudnnConvolution( } FilterDescriptor filter_descriptor(effective_num_dimensions); - filter_descriptor.set_layout(FilterLayout::kOutputInputYX) + filter_descriptor.set_layout(filter_dl) .set_input_feature_map_count( filter_shape.dimensions(dnums.kernel_input_feature_dimension())) .set_output_feature_map_count( @@ -149,7 +160,7 @@ Status RunCudnnConvolution( } BatchDescriptor output_descriptor(effective_num_dimensions); - output_descriptor.set_layout(DataLayout::kBatchDepthYX) + output_descriptor.set_layout(output_dl) .set_feature_map_count( output_shape.dimensions(dnums.output_feature_dimension())) .set_count(output_shape.dimensions(dnums.output_batch_dimension())); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 1445684e5d..5ef422c90b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -202,18 +202,28 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker<HloVerifier>(); pipeline.AddPass<CudnnConvolutionRewriter>(); pipeline.AddPass<PadInsertion>(); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + { + HloPassPipeline pipeline("layout_assignment"); + pipeline.AddPass<GpuLayoutAssignment>( + hlo_module->mutable_device_entry_computation_layout(), stream_exec); + + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>( + /*is_layout_sensitive=*/true, + /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { + return true; + }); // Choose the fastest algorithm for each conv. // - // In theory doing this here is way too early: It needs to happen after - // layout assignment, because the layout of the inputs/outputs affects the - // speed of the conv. But currently we only allow only one input/output - // layout when calling cudnn, so there's no ambiguity. - // - // We pick the algorithm at this early stage so we can generate better HLO. - // After CudnnConvolutionRewriter, our convolutions are CustomCalls which - // return a tuple (conv_result, scratch_memory), and the each conv uses 0 - // bytes of scratch: + // We pick the algorithm before fusion so we can generate better HLO. After + // CudnnConvolutionRewriter, our convolutions are CustomCalls which return a + // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of + // scratch: // // customcall = (f32[...], f32[0]) // return gte(customcall, 0) @@ -229,35 +239,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // The new tuple and gte instructions then be simplified away, because // nobody is expected to use the scratch value. // - // However, if we were to run CudnnConvolutionAlgorithmPicker after layout - // assignment, fusion would already have run, and the gte(customcall, 0) - // would probably already be into a fusion node. We can't simplify across - // HloComputation boundaries, so in this case we wouldn't be able to - // simplify away the new_tuple bits. - // - // We'll need to revisit this if we ever allow multiple layouts for the - // inputs/outputs of a cudnn convolution. + // However, if we were to run CudnnConvolutionAlgorithmPicker after fusion + // the gte(customcall, 0) would probably already be into a fusion node. We + // can't simplify across HloComputation boundaries, so in this case we + // wouldn't be able to simplify away the new_tuple bits. pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(stream_exec, device_allocator); // Clean up new_tuple described above. pipeline.AddPass<TupleSimplifier>(); - pipeline.AddPass<HloDCE>(); - - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } - - { - HloPassPipeline pipeline("layout_assignment"); - pipeline.AddPass<GpuLayoutAssignment>( - hlo_module->mutable_device_entry_computation_layout()); - // The LayoutAssignment pass may leave behind kCopy instructions which are - // duplicate or NOPs, so remove them with algebraic simplification and CSE. - pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>( - /*is_layout_sensitive=*/true, - /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { - return true; - }); pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 89f1e62588..178457721a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -18,31 +18,72 @@ limitations under the License. #include <memory> #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { namespace gpu { -// cuDNN convolutions are called with specific layouts on the input, output, -// and filter: -// -// input: DataLayout::kBatchDepthYX -// output: DataLayout::kBatchDepthYX -// filter: FilterLayout::kOutputInputYX -// -// The order dimensions in the constant name is major-to-minor (eg, the -// most-major dimension of the input is batch, most-minor is X). The -// specific dimension numbers these named dimensions correspond to is -// determined by the ConvolutionDimensionNumbers argument. Y is spatial -// dimension 0, and X is spatial dimension 1. -// -// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. -static Status AddBackendConstraintsToDnnConvCustomCall( +using stream_executor::dnn::DataLayout; +using stream_executor::dnn::FilterLayout; + +static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { + int major, minor; + CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, + &minor)); + return major >= 7; +} + +// Returns (input, filter, output) layouts. +static std::tuple<DataLayout, FilterLayout, DataLayout> +HeuristicLayoutAssignment(const HloInstruction* instr, + stream_executor::StreamExecutor* stream_executor) { + // DataLayout and FilterLayout uses weird enum names. Translations: + // N <=> Batch or Output + // C <=> Depth or Input + // H <=> Y + // W <=> X + // + // Therefore kOutputInputYX means NHWC; kBatchDepthYX means NCHW. + + // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x + // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version + // changes, as well as the hardware updates. + if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 && + IsVoltaOrLater(*stream_executor))) { + return std::make_tuple(DataLayout::kBatchDepthYX, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + } + VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString(); + // For BackwardInput that has stride, full NHWC layouts run significantly + // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC). + // + // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW, + // NHWC). + if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && + window_util::HasStride(instr->window())) { + return std::make_tuple(DataLayout::kBatchYXDepth, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + } + return std::make_tuple(DataLayout::kBatchYXDepth, + FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth); +} + +// Adds layout constraints on the cudnn custom-call instruction. The layout +// constraints are represented in terms of minor_to_major fields of both +// operands and the output shape. Depending on the underlying algorithm, one of +// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen. +Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( HloInstruction* instr, LayoutConstraints* constraints) { CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); Shape input_shape; @@ -66,39 +107,25 @@ static Status AddBackendConstraintsToDnnConvCustomCall( << instr->custom_call_target(); } - // Construct minor-to-major dimension orders for operands and result. - // cuDNN's convolution APIs support the BDYX layout for activations/output - // and the OIYX layout for weights. - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN - // calls after we switch to cuDNN v5. - const ConvolutionDimensionNumbers& dimension_numbers = - instr->convolution_dimension_numbers(); - std::vector<int64> input_layout; - for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0; - --i) { - input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); - } - input_layout.push_back(dimension_numbers.input_feature_dimension()); - input_layout.push_back(dimension_numbers.input_batch_dimension()); - *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); - - std::vector<int64> filter_layout; - for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0; - --i) { - filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); - } - filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension()); - filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension()); - *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); - - std::vector<int64> output_layout; - for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0; - --i) { - output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); + { + DataLayout input; + FilterLayout filter; + DataLayout output; + if (ConvUseLayoutHeuristic(instr->GetModule()->config())) { + std::tie(input, filter, output) = + HeuristicLayoutAssignment(instr, stream_executor_); + } else { + input = DataLayout::kBatchDepthYX; + filter = FilterLayout::kOutputInputYX; + output = DataLayout::kBatchDepthYX; + } + + TF_ASSIGN_OR_RETURN( + std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(), + *output_shape.mutable_layout()), + StreamExecutorConvLayoutsToXlaLayouts( + instr->convolution_dimension_numbers(), input, filter, output)); } - output_layout.push_back(dimension_numbers.output_feature_dimension()); - output_layout.push_back(dimension_numbers.output_batch_dimension()); - *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); // The custom call returns a tuple of (actual_result, scratch_buffer); // call_result_buf is the logical buffer for actual_result, the thing that diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 86a3a7111f..ce24af1cf8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { @@ -27,8 +28,10 @@ namespace gpu { // layout constraints for operands and results of library calls. class GpuLayoutAssignment : public LayoutAssignment { public: - explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout) - : LayoutAssignment(entry_computation_layout) {} + explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout, + se::StreamExecutor* stream_executor) + : LayoutAssignment(entry_computation_layout), + stream_executor_(stream_executor) {} ~GpuLayoutAssignment() override {} protected: @@ -41,6 +44,12 @@ class GpuLayoutAssignment : public LayoutAssignment { LayoutConstraints* constraints) override; bool CustomCallRequiresMajorFirstLayout( const HloInstruction* instruction) override; + + private: + Status AddBackendConstraintsToDnnConvCustomCall( + HloInstruction* instr, LayoutConstraints* constraints); + + se::StreamExecutor* stream_executor_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 4c45d2e94a..e48165c142 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -69,7 +69,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape_with_layout); - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); for (const HloInstruction* operand : add->operands()) { @@ -156,7 +157,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape); } - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -225,7 +227,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { {result_shape, offset_scale_shape, offset_scale_shape})); } - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -305,7 +308,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { {result_shape, scale_shape, scale_shape})); } - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first and fourth operands to the batchnorm call should have the diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_options.cc new file mode 100644 index 0000000000..174aaf122c --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_options.cc @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace xla { +namespace gpu { + +bool ConvUseLayoutHeuristic(const HloModuleConfig& config) { + return config.debug_options().xla_backend_extra_options().count( + "xla_gpu_experimental_conv_use_layout_heuristic"); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/gpu_options.h new file mode 100644 index 0000000000..498d4a9495 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_options.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ + +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +// Helper functions for querying options that are specific to the GPU backend. + +namespace xla { +namespace gpu { + +// Returns true if we should use heuristics to assign convolution layouts, as +// opposed to always assigning NCHW. +bool ConvUseLayoutHeuristic(const HloModuleConfig& config); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc new file mode 100644 index 0000000000..a50ddf6ac6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" + +#include "tensorflow/compiler/xla/layout_util.h" + +namespace xla { +namespace gpu { + +using stream_executor::dnn::DataLayout; +using stream_executor::dnn::DataLayoutString; +using stream_executor::dnn::FilterLayout; +using stream_executor::dnn::FilterLayoutString; + +StatusOr<std::tuple<Layout, Layout, Layout>> +StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, + DataLayout input, FilterLayout filter, + DataLayout output) { + std::vector<int64> input_layout; + switch (input) { + case DataLayout::kBatchDepthYX: + input_layout.push_back(dnums.input_batch_dimension()); + input_layout.push_back(dnums.input_feature_dimension()); + input_layout.insert(input_layout.end(), + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + break; + case DataLayout::kBatchYXDepth: + input_layout.push_back(dnums.input_batch_dimension()); + input_layout.insert(input_layout.end(), + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + input_layout.push_back(dnums.input_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid input layout: ", + DataLayoutString(input)); + } + + std::vector<int64> filter_layout; + switch (filter) { + case FilterLayout::kOutputInputYX: + filter_layout.push_back(dnums.kernel_output_feature_dimension()); + filter_layout.push_back(dnums.kernel_input_feature_dimension()); + filter_layout.insert(filter_layout.end(), + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + break; + case FilterLayout::kOutputYXInput: + filter_layout.push_back(dnums.kernel_output_feature_dimension()); + filter_layout.insert(filter_layout.end(), + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + filter_layout.push_back(dnums.kernel_input_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid filter layout: ", + FilterLayoutString(filter)); + } + + std::vector<int64> output_layout; + switch (output) { + case DataLayout::kBatchDepthYX: + output_layout.push_back(dnums.output_batch_dimension()); + output_layout.push_back(dnums.output_feature_dimension()); + output_layout.insert(output_layout.end(), + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + break; + case DataLayout::kBatchYXDepth: + output_layout.push_back(dnums.output_batch_dimension()); + output_layout.insert(output_layout.end(), + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + output_layout.push_back(dnums.output_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid output layout: ", + DataLayoutString(output)); + } + + return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout), + LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout), + LayoutUtil::MakeLayoutFromMajorToMinor(output_layout)); +} + +StatusOr<std::tuple<DataLayout, FilterLayout, DataLayout>> +XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, + const Layout& input, const Layout& filter, + const Layout& output) { + Layout nchw_input, nchw_filter, nchw_output; + std::tie(nchw_input, nchw_filter, nchw_output) = + StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX) + .ConsumeValueOrDie(); + + Layout nhwc_input, nhwc_filter, nhwc_output; + std::tie(nhwc_input, nhwc_filter, nhwc_output) = + StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth, + FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth) + .ConsumeValueOrDie(); + + DataLayout input_layout; + if (LayoutUtil::Equal(input, nchw_input)) { + input_layout = DataLayout::kBatchDepthYX; + } else if (LayoutUtil::Equal(input, nhwc_input)) { + input_layout = DataLayout::kBatchYXDepth; + } else { + return tensorflow::errors::Internal("Invalid input layout: ", + input.ShortDebugString()); + } + + FilterLayout filter_layout; + if (LayoutUtil::Equal(filter, nchw_filter)) { + filter_layout = FilterLayout::kOutputInputYX; + } else if (LayoutUtil::Equal(filter, nhwc_filter)) { + filter_layout = FilterLayout::kOutputYXInput; + } else { + return tensorflow::errors::Internal("Invalid filter layout: ", + filter.ShortDebugString()); + } + + DataLayout output_layout; + if (LayoutUtil::Equal(output, nchw_output)) { + output_layout = DataLayout::kBatchDepthYX; + } else if (LayoutUtil::Equal(output, nhwc_output)) { + output_layout = DataLayout::kBatchYXDepth; + } else { + return tensorflow::errors::Internal("Invalid output layout: ", + output.ShortDebugString()); + } + + return std::make_tuple(input_layout, filter_layout, output_layout); +} +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h new file mode 100644 index 0000000000..8218f4fd11 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ + +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +// Helper functions for interacting with StreamExecutor. + +namespace xla { +namespace gpu { + +// Returns (input, filter, output) XLA Layout protos given the StreamExecutor +// layouts. +StatusOr<std::tuple<Layout, Layout, Layout>> +StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, + stream_executor::dnn::DataLayout input, + stream_executor::dnn::FilterLayout filter, + stream_executor::dnn::DataLayout output); + +// Returns (input, filter, output) StreamExecutor layouts given the XLA layouts. +StatusOr<std::tuple<stream_executor::dnn::DataLayout, + stream_executor::dnn::FilterLayout, + stream_executor::dnn::DataLayout>> +XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, + const Layout& input, const Layout& filter, + const Layout& output); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index fd54ac761c..1a12fd0113 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -776,30 +776,42 @@ xla_test( ], ) +CONVOLUTION_TEST_DEPS = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", +] + xla_test( name = "convolution_test", timeout = "long", srcs = ["convolution_test.cc"], shard_count = 25, - deps = [ - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], + deps = CONVOLUTION_TEST_DEPS, +) + +xla_test( + name = "convolution_test_gpu_alternative_layout", + timeout = "long", + srcs = ["convolution_test.cc"], + backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_use_layout_heuristic"]}, + backends = ["gpu"], + shard_count = 25, + deps = CONVOLUTION_TEST_DEPS, ) xla_test( |