author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-24 17:48:21 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-05-24 17:50:57 -0700
commit    f6066436884476d7bc32cf2ad6cfc8d9c52b5482 (patch)
tree      1a6aece3e70ab0c0bddc758f401a12cab67e8bd1
parent    0c940ff33add2e8481cc1a5a166d8af72a5a21f9 (diff)
Add a heuristic for picking the NHWC layout for (V100, fp16) convolutions.
Also move AlgorithmPicker after layout assignment, as cudnn_convolution_runner
will now return failures on invalid input layouts.

Also add a backend debug option to switch the layout heuristic. By default it
has the old behavior (all NCHW).

PiperOrigin-RevId: 197983747
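Note: the heuristic is opt-in. Per gpu_options.cc below, it is enabled whenever
the key "xla_gpu_experimental_conv_use_layout_heuristic" is present in the
backend extra options, which the new GPU test target in this patch passes as:

    --xla_backend_extra_options=xla_gpu_experimental_conv_use_layout_heuristic

Without it, layout assignment keeps the old all-NCHW behavior.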
-rw-r--r--  tensorflow/compiler/xla/layout_util.cc                            |  10
-rw-r--r--  tensorflow/compiler/xla/layout_util.h                             |   4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD                         |  26
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc   |  17
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_compiler.cc               |  56
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc      | 121
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h       |  13
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc |  12
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_options.cc                |  28
-rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_options.h                 |  33
-rw-r--r--  tensorflow/compiler/xla/service/gpu/stream_executor_util.cc       | 151
-rw-r--r--  tensorflow/compiler/xla/service/gpu/stream_executor_util.h        |  46
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD                               |  50
13 files changed, 459 insertions(+), 108 deletions(-)
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index a76fdcda25..89cafa1a7d 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -65,6 +65,16 @@ void SetDefaultLayoutToContainer(
return layout;
}
+/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
+ tensorflow::gtl::ArraySlice<int64> major_to_minor) {
+ Layout layout;
+ layout.set_format(DENSE);
+ for (int i = major_to_minor.size() - 1; i >= 0; i--) {
+ layout.add_minor_to_major(major_to_minor[i]);
+ }
+ return layout;
+}
+
/* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) {
Layout layout;
layout.set_format(SPARSE);
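For illustration, a minimal usage sketch of the new helper (hypothetical
dimension values): it is MakeLayout with the index order reversed.

    // Both calls construct the same layout, with minor_to_major = {0, 1, 2, 3}.
    Layout a = LayoutUtil::MakeLayout({0, 1, 2, 3});
    Layout b = LayoutUtil::MakeLayoutFromMajorToMinor({3, 2, 1, 0});
    CHECK(LayoutUtil::Equal(a, b));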
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
index d3d6a2cc94..739bbe7367 100644
--- a/tensorflow/compiler/xla/layout_util.h
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -36,6 +36,10 @@ class LayoutUtil {
// convenience function for protobuf construction.)
static Layout MakeLayout(tensorflow::gtl::ArraySlice<int64> minor_to_major);
+ // Similar to MakeLayout, but takes indices in reverse order.
+ static Layout MakeLayoutFromMajorToMinor(
+ tensorflow::gtl::ArraySlice<int64> major_to_minor);
+
// Creates a sparse layout with the given maximum number of elements. (This is
// a convenience function for protobuf construction.)
static Layout MakeSparseLayout(int64 max_sparse_elements);
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index aafb61b583..ffb1af2d87 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -338,6 +338,7 @@ cc_library(
srcs = ["cudnn_convolution_runner.cc"],
hdrs = ["cudnn_convolution_runner.h"],
deps = [
+ ":stream_executor_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
@@ -590,14 +591,18 @@ cc_library(
srcs = ["gpu_layout_assignment.cc"],
hdrs = ["gpu_layout_assignment.h"],
deps = [
+ ":gpu_options",
":ir_emission_utils",
+ ":stream_executor_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:layout_assignment",
"//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
],
)
@@ -694,6 +699,27 @@ cc_library(
],
)
+cc_library(
+ name = "gpu_options",
+ srcs = ["gpu_options.cc"],
+ hdrs = ["gpu_options.h"],
+ deps = [
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "stream_executor_util",
+ srcs = ["stream_executor_util.cc"],
+ hdrs = ["stream_executor_util.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
tf_cc_test(
name = "gpu_hlo_support_checker_test",
srcs = ["gpu_hlo_support_checker_test.cc"],
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
index 10b4c3de89..0645fbb3ad 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
@@ -113,8 +115,17 @@ Status RunCudnnConvolution(
// cuDNN's convolution APIs support the BDYX layout for activations/output and
// the OIYX layout for weights.
+ DataLayout input_dl;
+ FilterLayout filter_dl;
+ DataLayout output_dl;
+
+ TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
+ XlaConvLayoutsToStreamExecutorLayouts(
+ dnums, input_shape.layout(), filter_shape.layout(),
+ output_shape.layout()));
+
BatchDescriptor input_descriptor(effective_num_dimensions);
- input_descriptor.set_layout(DataLayout::kBatchDepthYX)
+ input_descriptor.set_layout(input_dl)
.set_feature_map_count(
input_shape.dimensions(dnums.input_feature_dimension()))
.set_count(input_shape.dimensions(dnums.input_batch_dimension()));
@@ -126,7 +137,7 @@ Status RunCudnnConvolution(
}
FilterDescriptor filter_descriptor(effective_num_dimensions);
- filter_descriptor.set_layout(FilterLayout::kOutputInputYX)
+ filter_descriptor.set_layout(filter_dl)
.set_input_feature_map_count(
filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
.set_output_feature_map_count(
@@ -149,7 +160,7 @@ Status RunCudnnConvolution(
}
BatchDescriptor output_descriptor(effective_num_dimensions);
- output_descriptor.set_layout(DataLayout::kBatchDepthYX)
+ output_descriptor.set_layout(output_dl)
.set_feature_map_count(
output_shape.dimensions(dnums.output_feature_dimension()))
.set_count(output_shape.dimensions(dnums.output_batch_dimension()));
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 1445684e5d..5ef422c90b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -202,18 +202,28 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
pipeline.AddInvariantChecker<HloVerifier>();
pipeline.AddPass<CudnnConvolutionRewriter>();
pipeline.AddPass<PadInsertion>();
+ TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+ }
+
+ {
+ HloPassPipeline pipeline("layout_assignment");
+ pipeline.AddPass<GpuLayoutAssignment>(
+ hlo_module->mutable_device_entry_computation_layout(), stream_exec);
+
+ // The LayoutAssignment pass may leave behind kCopy instructions which are
+ // duplicate or NOPs, so remove them with algebraic simplification and CSE.
+ pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
+ /*is_layout_sensitive=*/true,
+ /*valid_bitcast_callback=*/[](const Shape&, const Shape&) {
+ return true;
+ });
// Choose the fastest algorithm for each conv.
//
- // In theory doing this here is way too early: It needs to happen after
- // layout assignment, because the layout of the inputs/outputs affects the
- // speed of the conv. But currently we allow only one input/output
- // layout when calling cudnn, so there's no ambiguity.
- //
- // We pick the algorithm at this early stage so we can generate better HLO.
- // After CudnnConvolutionRewriter, our convolutions are CustomCalls which
- // return a tuple (conv_result, scratch_memory), and each conv uses 0
- // bytes of scratch:
+ // We pick the algorithm before fusion so we can generate better HLO. After
+ // CudnnConvolutionRewriter, our convolutions are CustomCalls which return a
+ // tuple (conv_result, scratch_memory), and each conv uses 0 bytes of
+ // scratch:
//
// customcall = (f32[...], f32[0])
// return gte(customcall, 0)
@@ -229,35 +239,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// The new tuple and gte instructions can then be simplified away, because
// nobody is expected to use the scratch value.
//
- // However, if we were to run CudnnConvolutionAlgorithmPicker after layout
- // assignment, fusion would already have run, and the gte(customcall, 0)
- // would probably already be absorbed into a fusion node. We can't simplify across
- // HloComputation boundaries, so in this case we wouldn't be able to
- // simplify away the new_tuple bits.
- //
- // We'll need to revisit this if we ever allow multiple layouts for the
- // inputs/outputs of a cudnn convolution.
+ // However, if we were to run CudnnConvolutionAlgorithmPicker after fusion,
+ // the gte(customcall, 0) would probably already be absorbed into a fusion
+ // node. We can't simplify across HloComputation boundaries, so in this case
+ // we wouldn't be able to simplify away the new_tuple bits.
pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(stream_exec,
device_allocator);
// Clean up new_tuple described above.
pipeline.AddPass<TupleSimplifier>();
- pipeline.AddPass<HloDCE>();
-
- TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
- }
-
- {
- HloPassPipeline pipeline("layout_assignment");
- pipeline.AddPass<GpuLayoutAssignment>(
- hlo_module->mutable_device_entry_computation_layout());
- // The LayoutAssignment pass may leave behind kCopy instructions which are
- // duplicate or NOPs, so remove them with algebraic simplification and CSE.
- pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
- /*is_layout_sensitive=*/true,
- /*valid_bitcast_callback=*/[](const Shape&, const Shape&) {
- return true;
- });
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
}
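The resulting pass ordering, as read from this diff:

    conv_canonicalization:  CudnnConvolutionRewriter, PadInsertion
    layout_assignment:      GpuLayoutAssignment (now given a StreamExecutor),
                            AlgebraicSimplifier (layout-sensitive, to fixpoint),
                            CudnnConvolutionAlgorithmPicker,
                            TupleSimplifier, HloCSE (layout-sensitive)

That is, algorithm picking now runs after layouts are fixed, but still before
fusion, so the tuple-simplification argument above continues to hold.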
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 89f1e62588..178457721a 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -18,31 +18,72 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
namespace gpu {
-// cuDNN convolutions are called with specific layouts on the input, output,
-// and filter:
-//
-// input: DataLayout::kBatchDepthYX
-// output: DataLayout::kBatchDepthYX
-// filter: FilterLayout::kOutputInputYX
-//
-// The order dimensions in the constant name is major-to-minor (eg, the
-// most-major dimension of the input is batch, most-minor is X). The
-// specific dimension numbers these named dimensions correspond to is
-// determined by the ConvolutionDimensionNumbers argument. Y is spatial
-// dimension 0, and X is spatial dimension 1.
-//
-// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls.
-static Status AddBackendConstraintsToDnnConvCustomCall(
+using stream_executor::dnn::DataLayout;
+using stream_executor::dnn::FilterLayout;
+
+static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) {
+ int major, minor;
+ CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major,
+ &minor));
+ return major >= 7;
+}
+
+// Returns (input, filter, output) layouts.
+static std::tuple<DataLayout, FilterLayout, DataLayout>
+HeuristicLayoutAssignment(const HloInstruction* instr,
+ stream_executor::StreamExecutor* stream_executor) {
+ // DataLayout and FilterLayout use weird enum names. Translations:
+ // N <=> Batch or Output
+ // C <=> Depth or Input
+ // H <=> Y
+ // W <=> X
+ //
+ // Therefore kOutputInputYX and kBatchDepthYX mean NCHW, while
+ // kOutputYXInput and kBatchYXDepth mean NHWC.
+
+ // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x
+ // fp16 with the mostly-NHWC layout. The heuristic may change as the cudnn
+ // version changes and as the hardware is updated.
+ if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 &&
+ IsVoltaOrLater(*stream_executor))) {
+ return std::make_tuple(DataLayout::kBatchDepthYX,
+ FilterLayout::kOutputInputYX,
+ DataLayout::kBatchDepthYX);
+ }
+ VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();
+ // For BackwardInput that has stride, full NHWC layouts run significantly
+ // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC).
+ //
+ // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW,
+ // NHWC).
+ if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
+ window_util::HasStride(instr->window())) {
+ return std::make_tuple(DataLayout::kBatchYXDepth,
+ FilterLayout::kOutputInputYX,
+ DataLayout::kBatchDepthYX);
+ }
+ return std::make_tuple(DataLayout::kBatchYXDepth,
+ FilterLayout::kOutputYXInput,
+ DataLayout::kBatchYXDepth);
+}
+
+// Adds layout constraints on the cudnn custom-call instruction. The layout
+// constraints are represented in terms of minor_to_major fields of both
+// operands and the output shape. Depending on the underlying algorithm, one of
+// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen.
+Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
HloInstruction* instr, LayoutConstraints* constraints) {
CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString();
Shape input_shape;
@@ -66,39 +107,25 @@ static Status AddBackendConstraintsToDnnConvCustomCall(
<< instr->custom_call_target();
}
- // Construct minor-to-major dimension orders for operands and result.
- // cuDNN's convolution APIs support the BDYX layout for activations/output
- // and the OIYX layout for weights.
- // TODO(b/29399649): Be more flexible about handling layouts of cuDNN
- // calls after we switch to cuDNN v5.
- const ConvolutionDimensionNumbers& dimension_numbers =
- instr->convolution_dimension_numbers();
- std::vector<int64> input_layout;
- for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0;
- --i) {
- input_layout.push_back(dimension_numbers.input_spatial_dimensions(i));
- }
- input_layout.push_back(dimension_numbers.input_feature_dimension());
- input_layout.push_back(dimension_numbers.input_batch_dimension());
- *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout);
-
- std::vector<int64> filter_layout;
- for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0;
- --i) {
- filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i));
- }
- filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension());
- filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension());
- *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout);
-
- std::vector<int64> output_layout;
- for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0;
- --i) {
- output_layout.push_back(dimension_numbers.output_spatial_dimensions(i));
+ {
+ DataLayout input;
+ FilterLayout filter;
+ DataLayout output;
+ if (ConvUseLayoutHeuristic(instr->GetModule()->config())) {
+ std::tie(input, filter, output) =
+ HeuristicLayoutAssignment(instr, stream_executor_);
+ } else {
+ input = DataLayout::kBatchDepthYX;
+ filter = FilterLayout::kOutputInputYX;
+ output = DataLayout::kBatchDepthYX;
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(),
+ *output_shape.mutable_layout()),
+ StreamExecutorConvLayoutsToXlaLayouts(
+ instr->convolution_dimension_numbers(), input, filter, output));
}
- output_layout.push_back(dimension_numbers.output_feature_dimension());
- output_layout.push_back(dimension_numbers.output_batch_dimension());
- *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout);
// The custom call returns a tuple of (actual_result, scratch_buffer);
// call_result_buf is the logical buffer for actual_result, the thing that
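Read together, HeuristicLayoutAssignment reduces to this decision table of
(input, filter, output) layouts:

    not (fp16 and Volta-or-later)            -> (NCHW, NCHW, NCHW)  // old default
    fp16 + Volta, BackwardInput with stride  -> (NHWC, NCHW, NCHW)
    fp16 + Volta, otherwise                  -> (NHWC, NHWC, NHWC)

where NCHW denotes kBatchDepthYX / kOutputInputYX and NHWC denotes
kBatchYXDepth / kOutputYXInput.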
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index 86a3a7111f..ce24af1cf8 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
namespace gpu {
@@ -27,8 +28,10 @@ namespace gpu {
// layout constraints for operands and results of library calls.
class GpuLayoutAssignment : public LayoutAssignment {
public:
- explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout)
- : LayoutAssignment(entry_computation_layout) {}
+ explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout,
+ se::StreamExecutor* stream_executor)
+ : LayoutAssignment(entry_computation_layout),
+ stream_executor_(stream_executor) {}
~GpuLayoutAssignment() override {}
protected:
@@ -41,6 +44,12 @@ class GpuLayoutAssignment : public LayoutAssignment {
LayoutConstraints* constraints) override;
bool CustomCallRequiresMajorFirstLayout(
const HloInstruction* instruction) override;
+
+ private:
+ Status AddBackendConstraintsToDnnConvCustomCall(
+ HloInstruction* instr, LayoutConstraints* constraints);
+
+ se::StreamExecutor* stream_executor_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 4c45d2e94a..e48165c142 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -69,7 +69,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) {
*computation_layout.mutable_result_layout() =
ShapeLayout(result_shape_with_layout);
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
for (const HloInstruction* operand : add->operands()) {
@@ -156,7 +157,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) {
*computation_layout.mutable_result_layout() = ShapeLayout(result_shape);
}
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -225,7 +227,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) {
{result_shape, offset_scale_shape, offset_scale_shape}));
}
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first operand to batchnorm should have the same layout as the
@@ -305,7 +308,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
{result_shape, scale_shape, scale_shape}));
}
- GpuLayoutAssignment layout_assignment(&computation_layout);
+ GpuLayoutAssignment layout_assignment(
+ &computation_layout, backend().default_stream_executor());
EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
// The first and fourth operands to the batchnorm call should have the
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_options.cc
new file mode 100644
index 0000000000..174aaf122c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_options.cc
@@ -0,0 +1,28 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace xla {
+namespace gpu {
+
+bool ConvUseLayoutHeuristic(const HloModuleConfig& config) {
+ return config.debug_options().xla_backend_extra_options().count(
+ "xla_gpu_experimental_conv_use_layout_heuristic");
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/gpu_options.h
new file mode 100644
index 0000000000..498d4a9495
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_options.h
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+
+// Helper functions for querying options that are specific to the GPU backend.
+
+namespace xla {
+namespace gpu {
+
+// Returns true if we should use heuristics to assign convolution layouts, as
+// opposed to always assigning NCHW.
+bool ConvUseLayoutHeuristic(const HloModuleConfig& config);
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
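A sketch of opting in programmatically, assuming the usual DebugOptions and
HloModuleConfig accessors (ConvUseLayoutHeuristic only tests for the key's
presence, so any value works):

    DebugOptions debug_options;
    (*debug_options.mutable_xla_backend_extra_options())
        ["xla_gpu_experimental_conv_use_layout_heuristic"] = "";
    HloModuleConfig config;
    config.set_debug_options(debug_options);
    CHECK(xla::gpu::ConvUseLayoutHeuristic(config));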
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
new file mode 100644
index 0000000000..a50ddf6ac6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
+
+#include "tensorflow/compiler/xla/layout_util.h"
+
+namespace xla {
+namespace gpu {
+
+using stream_executor::dnn::DataLayout;
+using stream_executor::dnn::DataLayoutString;
+using stream_executor::dnn::FilterLayout;
+using stream_executor::dnn::FilterLayoutString;
+
+StatusOr<std::tuple<Layout, Layout, Layout>>
+StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
+ DataLayout input, FilterLayout filter,
+ DataLayout output) {
+ std::vector<int64> input_layout;
+ switch (input) {
+ case DataLayout::kBatchDepthYX:
+ input_layout.push_back(dnums.input_batch_dimension());
+ input_layout.push_back(dnums.input_feature_dimension());
+ input_layout.insert(input_layout.end(),
+ dnums.input_spatial_dimensions().begin(),
+ dnums.input_spatial_dimensions().end());
+ break;
+ case DataLayout::kBatchYXDepth:
+ input_layout.push_back(dnums.input_batch_dimension());
+ input_layout.insert(input_layout.end(),
+ dnums.input_spatial_dimensions().begin(),
+ dnums.input_spatial_dimensions().end());
+ input_layout.push_back(dnums.input_feature_dimension());
+ break;
+ default:
+ return tensorflow::errors::Internal("Invalid input layout: ",
+ DataLayoutString(input));
+ }
+
+ std::vector<int64> filter_layout;
+ switch (filter) {
+ case FilterLayout::kOutputInputYX:
+ filter_layout.push_back(dnums.kernel_output_feature_dimension());
+ filter_layout.push_back(dnums.kernel_input_feature_dimension());
+ filter_layout.insert(filter_layout.end(),
+ dnums.kernel_spatial_dimensions().begin(),
+ dnums.kernel_spatial_dimensions().end());
+ break;
+ case FilterLayout::kOutputYXInput:
+ filter_layout.push_back(dnums.kernel_output_feature_dimension());
+ filter_layout.insert(filter_layout.end(),
+ dnums.kernel_spatial_dimensions().begin(),
+ dnums.kernel_spatial_dimensions().end());
+ filter_layout.push_back(dnums.kernel_input_feature_dimension());
+ break;
+ default:
+ return tensorflow::errors::Internal("Invalid filter layout: ",
+ FilterLayoutString(filter));
+ }
+
+ std::vector<int64> output_layout;
+ switch (output) {
+ case DataLayout::kBatchDepthYX:
+ output_layout.push_back(dnums.output_batch_dimension());
+ output_layout.push_back(dnums.output_feature_dimension());
+ output_layout.insert(output_layout.end(),
+ dnums.output_spatial_dimensions().begin(),
+ dnums.output_spatial_dimensions().end());
+ break;
+ case DataLayout::kBatchYXDepth:
+ output_layout.push_back(dnums.output_batch_dimension());
+ output_layout.insert(output_layout.end(),
+ dnums.output_spatial_dimensions().begin(),
+ dnums.output_spatial_dimensions().end());
+ output_layout.push_back(dnums.output_feature_dimension());
+ break;
+ default:
+ return tensorflow::errors::Internal("Invalid output layout: ",
+ DataLayoutString(output));
+ }
+
+ return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout),
+ LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout),
+ LayoutUtil::MakeLayoutFromMajorToMinor(output_layout));
+}
+
+StatusOr<std::tuple<DataLayout, FilterLayout, DataLayout>>
+XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
+ const Layout& input, const Layout& filter,
+ const Layout& output) {
+ Layout nchw_input, nchw_filter, nchw_output;
+ std::tie(nchw_input, nchw_filter, nchw_output) =
+ StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX,
+ FilterLayout::kOutputInputYX,
+ DataLayout::kBatchDepthYX)
+ .ConsumeValueOrDie();
+
+ Layout nhwc_input, nhwc_filter, nhwc_output;
+ std::tie(nhwc_input, nhwc_filter, nhwc_output) =
+ StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth,
+ FilterLayout::kOutputYXInput,
+ DataLayout::kBatchYXDepth)
+ .ConsumeValueOrDie();
+
+ DataLayout input_layout;
+ if (LayoutUtil::Equal(input, nchw_input)) {
+ input_layout = DataLayout::kBatchDepthYX;
+ } else if (LayoutUtil::Equal(input, nhwc_input)) {
+ input_layout = DataLayout::kBatchYXDepth;
+ } else {
+ return tensorflow::errors::Internal("Invalid input layout: ",
+ input.ShortDebugString());
+ }
+
+ FilterLayout filter_layout;
+ if (LayoutUtil::Equal(filter, nchw_filter)) {
+ filter_layout = FilterLayout::kOutputInputYX;
+ } else if (LayoutUtil::Equal(filter, nhwc_filter)) {
+ filter_layout = FilterLayout::kOutputYXInput;
+ } else {
+ return tensorflow::errors::Internal("Invalid filter layout: ",
+ filter.ShortDebugString());
+ }
+
+ DataLayout output_layout;
+ if (LayoutUtil::Equal(output, nchw_output)) {
+ output_layout = DataLayout::kBatchDepthYX;
+ } else if (LayoutUtil::Equal(output, nhwc_output)) {
+ output_layout = DataLayout::kBatchYXDepth;
+ } else {
+ return tensorflow::errors::Internal("Invalid output layout: ",
+ output.ShortDebugString());
+ }
+
+ return std::make_tuple(input_layout, filter_layout, output_layout);
+}
+} // namespace gpu
+} // namespace xla
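The two conversions are inverses of each other; a round-trip sketch (dnums is
any valid ConvolutionDimensionNumbers; this snippet is illustrative, not part
of the patch):

    Layout in, fil, out;
    std::tie(in, fil, out) =
        StreamExecutorConvLayoutsToXlaLayouts(
            dnums, DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput,
            DataLayout::kBatchYXDepth)
            .ValueOrDie();
    DataLayout in2;
    FilterLayout fil2;
    DataLayout out2;
    std::tie(in2, fil2, out2) =
        XlaConvLayoutsToStreamExecutorLayouts(dnums, in, fil, out)
            .ValueOrDie();
    // in2 == kBatchYXDepth, fil2 == kOutputYXInput, out2 == kBatchYXDepth.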
diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h
new file mode 100644
index 0000000000..8218f4fd11
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
+
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+// Helper functions for interacting with StreamExecutor.
+
+namespace xla {
+namespace gpu {
+
+// Returns (input, filter, output) XLA Layout protos given the StreamExecutor
+// layouts.
+StatusOr<std::tuple<Layout, Layout, Layout>>
+StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums,
+ stream_executor::dnn::DataLayout input,
+ stream_executor::dnn::FilterLayout filter,
+ stream_executor::dnn::DataLayout output);
+
+// Returns (input, filter, output) StreamExecutor layouts given the XLA layouts.
+StatusOr<std::tuple<stream_executor::dnn::DataLayout,
+ stream_executor::dnn::FilterLayout,
+ stream_executor::dnn::DataLayout>>
+XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums,
+ const Layout& input, const Layout& filter,
+ const Layout& output);
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index fd54ac761c..1a12fd0113 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -776,30 +776,42 @@ xla_test(
],
)
+CONVOLUTION_TEST_DEPS = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+]
+
xla_test(
name = "convolution_test",
timeout = "long",
srcs = ["convolution_test.cc"],
shard_count = 25,
- deps = [
- "//tensorflow/compiler/xla:array2d",
- "//tensorflow/compiler/xla:array4d",
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:reference_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:global_data",
- "//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/client:padding",
- "//tensorflow/compiler/xla/client/xla_client:xla_builder",
- "//tensorflow/compiler/xla/tests:client_library_test_base",
- "//tensorflow/compiler/xla/tests:literal_test_util",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- ],
+ deps = CONVOLUTION_TEST_DEPS,
+)
+
+xla_test(
+ name = "convolution_test_gpu_alternative_layout",
+ timeout = "long",
+ srcs = ["convolution_test.cc"],
+ backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_use_layout_heuristic"]},
+ backends = ["gpu"],
+ shard_count = 25,
+ deps = CONVOLUTION_TEST_DEPS,
)
xla_test(