aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eli Bendersky <eliben@google.com>2017-06-07 09:10:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-07 09:14:06 -0700
commitb9d5e144193f587020c1bac5d7505af88baa24d9 (patch)
treeff7df73b7323cf6a41231fddb3f718d52e74c2bc
parent3b6fe94bbb4f5d074ee52d7394ae093396ca7b23 (diff)
[XLA] Start collecting flags for debug options in a single place.
ClientLibraryTestBase will now parse command-line flags for debug options automatically, permitting subclasses to override certain options by using mutable_debug_options. main() still has to call AppendDebugOptionsFlags() explicitly before running the TF flag parser. In the mean-time, this CL leaves flag handling to the current "legacy" approach. However, this is part of a larger plan to move *all* debugging flags for XLA into the DebugOptions message and expose them as flags from a single place. The other flags (which are not controlling debugging options) will have to be propagated more explicitly. PiperOrigin-RevId: 158276294
-rw-r--r--tensorflow/compiler/aot/BUILD2
-rw-r--r--tensorflow/compiler/aot/tfcompile_main.cc4
-rw-r--r--tensorflow/compiler/xla/legacy_flags/BUILD26
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc84
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_flags.h38
-rw-r--r--tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc62
-rw-r--r--tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h48
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.cc27
-rw-r--r--tensorflow/compiler/xla/tests/BUILD4
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc7
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h6
-rw-r--r--tensorflow/compiler/xla/tests/compute_constant_test.cc1
-rw-r--r--tensorflow/compiler/xla/tests/convert_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/map_test.cc16
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_simple_test.cc8
16 files changed, 174 insertions, 168 deletions
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index 5e368749a0..71c6b17d51 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -128,8 +128,8 @@ cc_library(
"//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags",
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
"//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags",
- "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags",
"//tensorflow/compiler/xla/legacy_flags:llvm_util_flags",
"//tensorflow/compiler/xla/legacy_flags:service_flags",
"//tensorflow/compiler/xla/legacy_flags:util_flags",
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index 63ec261e01..6fed46b432 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -28,8 +28,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/util_flags.h"
@@ -142,7 +142,7 @@ int main(int argc, char** argv) {
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
xla::legacy_flags::AppendHloGraphDumperFlags(&flag_list);
- xla::legacy_flags::AppendHloPassPipelineFlags(&flag_list);
+ xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::legacy_flags::AppendLlvmUtilFlags(&flag_list);
xla::legacy_flags::AppendServiceFlags(&flag_list);
xla::legacy_flags::AppendUtilFlags(&flag_list);
diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD
index 017cb5bb0e..b124e2d425 100644
--- a/tensorflow/compiler/xla/legacy_flags/BUILD
+++ b/tensorflow/compiler/xla/legacy_flags/BUILD
@@ -66,6 +66,20 @@ cc_library(
)
cc_library(
+ name = "debug_options_flags",
+ srcs = ["debug_options_flags.cc"],
+ hdrs = ["debug_options_flags.h"],
+ deps =
+ [
+ ":parse_flags_from_env",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "cpu_compiler_flags",
srcs = ["cpu_compiler_flags.cc"],
hdrs = ["cpu_compiler_flags.h"],
@@ -161,18 +175,6 @@ cc_library(
)
cc_library(
- name = "hlo_pass_pipeline_flags",
- srcs = ["hlo_pass_pipeline_flags.cc"],
- hdrs = ["hlo_pass_pipeline_flags.h"],
- deps = [
- ":parse_flags_from_env",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
name = "alias_analysis_flags",
srcs = ["alias_analysis_flags.cc"],
hdrs = ["alias_analysis_flags.h"],
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
new file mode 100644
index 0000000000..0211462cb1
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace xla {
+namespace legacy_flags {
+
+struct DebugOptionsFlags {
+ string xla_generate_hlo_graph;
+
+ string xla_disable_hlo_passes;
+};
+
+namespace {
+
+DebugOptionsFlags* flag_values;
+std::vector<tensorflow::Flag>* flag_objects;
+std::once_flag flags_init;
+
+// Allocates flag_values and flag_objects; this function must not be called more
+// than once - its call done via call_once.
+void AllocateFlags() {
+ flag_values = new DebugOptionsFlags;
+ flag_values->xla_generate_hlo_graph = "";
+ flag_values->xla_disable_hlo_passes = "";
+
+ flag_objects = new std::vector<tensorflow::Flag>(
+ {tensorflow::Flag(
+ "xla_generate_hlo_graph", &flag_values->xla_generate_hlo_graph,
+ "HLO modules matching this regex will be dumped to a .dot file "
+ "throughout various stages in compilation."),
+
+ tensorflow::Flag(
+ "xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes,
+ "Comma-separated list of HLO passes to be disabled. These names "
+ "must "
+ "exactly match the passes' names; no whitespace around commas.")});
+ ParseFlagsFromEnv(*flag_objects);
+}
+
+} // namespace
+
+void AppendDebugOptionsFlags(std::vector<tensorflow::Flag>* flag_list) {
+ std::call_once(flags_init, &AllocateFlags);
+ flag_list->insert(flag_list->end(), flag_objects->begin(),
+ flag_objects->end());
+}
+
+xla::DebugOptions GetDebugOptionsFromFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+
+ DebugOptions options;
+
+ options.set_xla_generate_hlo_graph(flag_values->xla_generate_hlo_graph);
+
+ std::vector<string> disabled_passes =
+ tensorflow::str_util::Split(flag_values->xla_disable_hlo_passes, ',');
+ for (const auto& passname : disabled_passes) {
+ options.add_xla_disable_hlo_passes(passname);
+ }
+
+ return options;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h
new file mode 100644
index 0000000000..d0ef8e66ab
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Appends flag definitions for debug options to flag_list.
+void AppendDebugOptionsFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// Fetches a DebugOptions proto message from flags provided to the program.
+// Flags must be registered with the flags parser using AppendDebugOptionsFlags
+// first.
+xla::DebugOptions GetDebugOptionsFromFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc
deleted file mode 100644
index edc04d51a7..0000000000
--- a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc
+++ /dev/null
@@ -1,62 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// Legacy flags for XLA's hlo_pass_pipeline module.
-
-#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
-#include <vector>
-
-#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace xla {
-namespace legacy_flags {
-
-// Pointers to the parsed value of the flags and flag descriptors, initialized
-// via flags_init.
-static HloPassPipelineFlags* flags;
-static std::vector<tensorflow::Flag>* flag_list;
-static std::once_flag flags_init;
-
-// Allocate *flags. Called via call_once(&flags_init,...).
-static void AllocateFlags() {
- flags = new HloPassPipelineFlags;
- flags->xla_disable_hlo_passes = "";
- flag_list = new std::vector<tensorflow::Flag>({
- tensorflow::Flag("xla_disable_hlo_passes", &flags->xla_disable_hlo_passes,
- "Comma-separated list of HLO passes to disable."),
- });
- ParseFlagsFromEnv(*flag_list);
-}
-
-// Append to *append_to flag definitions associated with XLA's hlo_pass_pipeline
-// module.
-void AppendHloPassPipelineFlags(std::vector<tensorflow::Flag>* append_to) {
- std::call_once(flags_init, &AllocateFlags);
- append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
-}
-
-// Return a pointer to the HloPassPipelineFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-HloPassPipelineFlags* GetHloPassPipelineFlags() {
- std::call_once(flags_init, &AllocateFlags);
- return flags;
-}
-
-} // namespace legacy_flags
-} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h
deleted file mode 100644
index 520759bbf0..0000000000
--- a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_
-#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_
-
-// Legacy flags for XLA's hlo_pass_pipeline module.
-
-#include <vector>
-
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace xla {
-namespace legacy_flags {
-
-// Append to *flag_list flag definitions associated with XLA's hlo_pass_pipeline
-// module.
-void AppendHloPassPipelineFlags(std::vector<tensorflow::Flag>* flag_list);
-
-// The values of flags associated with XLA's hlo_pass_pipeline module.
-typedef struct {
- // Comma-separated list of HLO passes to disable.
- string xla_disable_hlo_passes;
-} HloPassPipelineFlags;
-
-// Return a pointer to the HloPassPipelineFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-HloPassPipelineFlags* GetHloPassPipelineFlags();
-
-} // namespace legacy_flags
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index aa1349a350..7cb3c95ffa 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1446,7 +1446,6 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 4e258c2a88..afc4d3733c 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include <functional>
-#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
@@ -44,23 +43,13 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
VLOG(1) << "Running HLO pass pipeline " << name();
- legacy_flags::HloPassPipelineFlags* flags =
- legacy_flags::GetHloPassPipelineFlags();
- std::unique_ptr<tensorflow::gtl::FlatSet<string>> disabled_passes;
- if (!flags->xla_disable_hlo_passes.empty()) {
- std::vector<string> passes_vec =
- tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ',');
- disabled_passes = MakeUnique<tensorflow::gtl::FlatSet<string>>(
- passes_vec.begin(), passes_vec.end());
- } else {
- auto repeated_field =
- module->config().debug_options().xla_disable_hlo_passes();
- disabled_passes = MakeUnique<tensorflow::gtl::FlatSet<string>>(
- repeated_field.begin(), repeated_field.end());
- }
- if (!disabled_passes->empty()) {
+ auto repeated_field =
+ module->config().debug_options().xla_disable_hlo_passes();
+ tensorflow::gtl::FlatSet<string> disabled_passes(repeated_field.begin(),
+ repeated_field.end());
+ if (!disabled_passes.empty()) {
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
- << tensorflow::str_util::Join(*disabled_passes, ", ");
+ << tensorflow::str_util::Join(disabled_passes, ", ");
}
auto run_invariant_checkers = [this, module]() -> Status {
@@ -75,8 +64,8 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
bool changed = false;
string message;
for (auto& pass : passes_) {
- if (!disabled_passes->empty() &&
- disabled_passes->count(pass->name().ToString()) > 0) {
+ if (!disabled_passes.empty() &&
+ disabled_passes.count(pass->name().ToString()) > 0) {
VLOG(1) << " Skipping HLO pass " << pass->name()
<< ", disabled by --xla_disable_hlo_passes";
continue;
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index eb31e8cdbf..1971868a38 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -150,7 +150,7 @@ cc_library(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
@@ -1153,6 +1153,7 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
@@ -1212,7 +1213,6 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
- "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 2d052e7a4d..03552d7bbf 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -45,7 +45,10 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) {
} // namespace
ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform)
- : client_(GetOrCreateLocalClientOrDie(platform)) {}
+ : client_(GetOrCreateLocalClientOrDie(platform)) {
+ *(execution_options_.mutable_debug_options()) =
+ legacy_flags::GetDebugOptionsFromFlags();
+}
string ClientLibraryTestBase::TestName() const {
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 9f0d6272f4..e6fc0f457a 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -57,8 +57,10 @@ class ClientLibraryTestBase : public ::testing::Test {
void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
- void SetDebugOptions(const DebugOptions& debug_options) {
- *(execution_options_.mutable_debug_options()) = debug_options;
+ // Provides mutable access to the execution DebugOptions field; this lets
+ // tests tweak the options that will be used to compile/run the graph.
+ DebugOptions* mutable_debug_options() {
+ return execution_options_.mutable_debug_options();
}
// TODO(b/25566808): Add helper that populates a literal from a testdata file.
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index 25b645557e..72a8d47ac9 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
index 7b2f201d1b..f6178608c8 100644
--- a/tensorflow/compiler/xla/tests/convert_test.cc
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -37,10 +38,8 @@ class ConvertTest : public ClientLibraryTestBase {
public:
explicit ConvertTest(perftools::gputools::Platform* platform = nullptr)
: ClientLibraryTestBase(platform) {
- DebugOptions debug_options;
- debug_options.add_xla_disable_hlo_passes("algsimp");
- debug_options.add_xla_disable_hlo_passes("inline");
- SetDebugOptions(debug_options);
+ mutable_debug_options()->mutable_xla_disable_hlo_passes()->Add("algsimp");
+ mutable_debug_options()->mutable_xla_disable_hlo_passes()->Add("inline");
}
};
@@ -199,6 +198,7 @@ TEST_F(ConvertTest, ConvertReshape) {
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index e263400929..6c82460c7c 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -42,10 +43,8 @@ class MapTest : public ClientLibraryTestBase {
public:
explicit MapTest(perftools::gputools::Platform* platform = nullptr)
: ClientLibraryTestBase(platform) {
- DebugOptions debug_options;
- debug_options.add_xla_disable_hlo_passes("algsimp");
- debug_options.add_xla_disable_hlo_passes("inline");
- SetDebugOptions(debug_options);
+ mutable_debug_options()->mutable_xla_disable_hlo_passes()->Add("algsimp");
+ mutable_debug_options()->mutable_xla_disable_hlo_passes()->Add("inline");
}
// Creates a function that adds its scalar argument with the constant 1.0.
@@ -103,8 +102,8 @@ class MapTest : public ClientLibraryTestBase {
// Creates a function that adds its scalar argument with the constant 1.0 and
// then multiplies by the original element.
//
- // /---------------\
- // / \
+ // /------------------|
+ // / |
// x {R0F32} ----> (add) ----> (mul)
// /
// 1.0f ---------/
@@ -150,8 +149,8 @@ class MapTest : public ClientLibraryTestBase {
// Creates a function that adds three scalar arguments
//
- // x {R0F32} ----\
- // \
+ // x {R0F32} -------|
+ // |
// y {R0F32} ----> (add) ---> (add)
// /
// z {R0F32} ---------------/
@@ -624,6 +623,7 @@ TEST_F(MapTestWithFullOpt, MapSquare) {
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
index c380c046ce..a41c2797bf 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -42,10 +43,8 @@ class VecOpsSimpleTest : public ClientLibraryTestBase {
public:
explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr)
: ClientLibraryTestBase(platform) {
- DebugOptions debug_options;
- debug_options.add_xla_disable_hlo_passes("algsimp");
- debug_options.add_xla_disable_hlo_passes("inline");
- SetDebugOptions(debug_options);
+ mutable_debug_options()->mutable_xla_disable_hlo_passes()->Add("algsimp");
+ mutable_debug_options()->mutable_xla_disable_hlo_passes()->Add("inline");
}
ErrorSpec error_spec_{0.0001};
@@ -443,6 +442,7 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) {
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {