From 4831740f90eaf266a99d3ffa7d390d54325b689f Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 5 Oct 2018 17:05:17 -0700 Subject: [XLA:GPU] Remove hidden flag for disabling heuristic layout assignment. Heuristic NCHW/NHWC layout assignment works great; we've never had to flip this flag. Might as well remove it and simplify things a bit. PiperOrigin-RevId: 215989807 --- tensorflow/compiler/xla/service/gpu/BUILD | 11 -------- .../xla/service/gpu/gpu_layout_assignment.cc | 11 ++------ tensorflow/compiler/xla/service/gpu/gpu_options.cc | 28 ------------------ tensorflow/compiler/xla/service/gpu/gpu_options.h | 33 ---------------------- 4 files changed, 2 insertions(+), 81 deletions(-) delete mode 100644 tensorflow/compiler/xla/service/gpu/gpu_options.cc delete mode 100644 tensorflow/compiler/xla/service/gpu/gpu_options.h (limited to 'tensorflow/compiler') diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 7b84f691f6..350fd32537 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -781,7 +781,6 @@ cc_library( srcs = ["gpu_layout_assignment.cc"], hdrs = ["gpu_layout_assignment.h"], deps = [ - ":gpu_options", ":ir_emission_utils", ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", @@ -882,16 +881,6 @@ cc_library( ], ) -cc_library( - name = "gpu_options", - srcs = ["gpu_options.cc"], - hdrs = ["gpu_options.h"], - deps = [ - "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/core:lib_internal", - ], -) - cc_library( name = "stream_executor_util", srcs = ["stream_executor_util.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 74352f26aa..1ffe855750 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -125,14 +124,8 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( DataLayout input; FilterLayout filter; DataLayout output; - if (ConvUseLayoutHeuristic(instr->GetModule()->config())) { - std::tie(input, filter, output) = - HeuristicLayoutAssignment(instr, stream_executor_); - } else { - input = DataLayout::kBatchDepthYX; - filter = FilterLayout::kOutputInputYX; - output = DataLayout::kBatchDepthYX; - } + std::tie(input, filter, output) = + HeuristicLayoutAssignment(instr, stream_executor_); TF_ASSIGN_OR_RETURN( std::tie(*input_shape->mutable_layout(), diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_options.cc deleted file mode 100644 index 35b4b4e20b..0000000000 --- a/tensorflow/compiler/xla/service/gpu/gpu_options.cc +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" -#include "tensorflow/core/lib/gtl/map_util.h" - -namespace xla { -namespace gpu { - -bool ConvUseLayoutHeuristic(const HloModuleConfig& config) { - return !config.debug_options().xla_backend_extra_options().count( - "xla_gpu_experimental_conv_disable_layout_heuristic"); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/gpu_options.h deleted file mode 100644 index 498d4a9495..0000000000 --- a/tensorflow/compiler/xla/service/gpu/gpu_options.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ - -#include "tensorflow/compiler/xla/service/hlo_module_config.h" - -// Helper functions for querying options that are specific to the GPU backend. - -namespace xla { -namespace gpu { - -// Returns true if we should use heuristics to assign convolution layouts, as -// opposed to always assigning NCHW. -bool ConvUseLayoutHeuristic(const HloModuleConfig& config); - -} // namespace gpu -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ -- cgit v1.2.3