diff options
author | 2017-06-28 15:32:11 -0700 | |
---|---|---|
committer | 2017-06-28 19:40:34 -0700 | |
commit | e6a45475735ee8a31c7d6c8e28e9164cda7d1853 (patch) | |
tree | 24a61658eba569f2614c9442d2beec7e0eacb69e | |
parent | 7ab72bf2205b1775607932b6ccbcd7099368705e (diff) |
[XLA] Move the xla_eliminate_hlo_implicit_broadcast flag from user_computation_flags into debug_options_flags
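This requires plumbing a DebugOptions argument from computation_tracker through
UserComputation::BuildHloComputation into ComputationLowerer, so that the
lowering code reads the flag from the debug options instead of a global.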
PiperOrigin-RevId: 160459822
15 files changed, 57 insertions, 169 deletions
diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index 9444cc1061..27b1d0c3bb 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -125,18 +125,6 @@ cc_library( ], ) -cc_library( - name = "user_computation_flags", - srcs = ["user_computation_flags.cc"], - hdrs = ["user_computation_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 43634792ce..3d5ea4d32a 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -41,6 +41,7 @@ struct DebugOptionsFlags { bool xla_embed_ir_in_executable; string xla_dump_ir_to; string xla_dump_debug_json_to; + bool xla_eliminate_hlo_implicit_broadcast; bool xla_cpu_multi_thread_eigen; @@ -76,6 +77,7 @@ void AllocateFlags() { flag_values->xla_embed_ir_in_executable = false; flag_values->xla_dump_ir_to = ""; flag_values->xla_dump_debug_json_to = ""; + flag_values->xla_eliminate_hlo_implicit_broadcast = false; flag_values->xla_cpu_multi_thread_eigen = true; flag_values->xla_gpu_cuda_data_dir = "./cuda_sdk_lib"; flag_values->xla_gpu_ftz = false; @@ -137,6 +139,11 @@ void AllocateFlags() { "Embed the compiler IR as a string in the executable."), tensorflow::Flag("xla_dump_ir_to", &flag_values->xla_dump_ir_to, "Dump the compiler IR into this file/path."), + tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast", + &flag_values->xla_eliminate_hlo_implicit_broadcast, + "Eliminate implicit broadcasts when lowering user " + "computations to HLO instructions; use explicit " + "broadcast instead."), 
tensorflow::Flag("xla_cpu_multi_thread_eigen", &flag_values->xla_cpu_multi_thread_eigen, "When generating calls to Eigen in the CPU backend, " @@ -192,6 +199,8 @@ xla::DebugOptions GetDebugOptionsFromFlags() { options.set_xla_embed_ir_in_executable( flag_values->xla_embed_ir_in_executable); options.set_xla_dump_ir_to(flag_values->xla_dump_ir_to); + options.set_xla_eliminate_hlo_implicit_broadcast( + flag_values->xla_eliminate_hlo_implicit_broadcast); options.set_xla_dump_debug_json_to(flag_values->xla_dump_debug_json_to); options.set_xla_cpu_multi_thread_eigen( flag_values->xla_cpu_multi_thread_eigen); diff --git a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc b/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc deleted file mode 100644 index a9597d0cd8..0000000000 --- a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex. 
-#include <vector> - -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static UserComputationFlags* flags; -static std::vector<tensorflow::Flag>* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new UserComputationFlags; - flags->xla_eliminate_hlo_implicit_broadcast = false; - flag_list = new std::vector<tensorflow::Flag>({ - tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast", - &flags->xla_eliminate_hlo_implicit_broadcast, - "Eliminate implicit broadcast on when lowering user " - "computation to HLO instructions, use explicit " - "broadcast instead."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's hlo_pass_pipeline -// module. -void AppendUserComputationFlags(std::vector<tensorflow::Flag>* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the UserComputationFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-UserComputationFlags* GetUserComputationFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h b/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h deleted file mode 100644 index f5222c927c..0000000000 --- a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_ - -// Legacy flags for XLA's user_computation module. - -#include <vector> - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flags definitions associated with XLA's user_computation -// module. -void AppendUserComputationFlags(std::vector<tensorflow::Flag>* flag_list); - -typedef struct { - // Eliminate implicit broadcast on when lowering user computation to HLO - // instructions, use explicit broadcast instead. 
- bool xla_eliminate_hlo_implicit_broadcast; -} UserComputationFlags; - -// Return a pointer to the UserComputationFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -UserComputationFlags* GetUserComputationFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_ diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ec12a8acf0..2a74fd1961 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -288,7 +288,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:lib", ], ) @@ -306,7 +306,7 @@ cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", ], diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc index 9aa32a1fb7..70e25eebdb 100644 --- a/tensorflow/compiler/xla/service/computation_tracker.cc +++ b/tensorflow/compiler/xla/service/computation_tracker.cc @@ -216,6 +216,7 @@ StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule( TF_ASSIGN_OR_RETURN( std::unique_ptr<HloComputation> hlo_computation, computation->BuildHloComputation(versioned_handle.version, resolver, + config.debug_options(), include_unreachable_instructions)); // Add the newly created computation to VersionedHandle-to-HloComputation diff --git a/tensorflow/compiler/xla/service/user_computation.cc 
b/tensorflow/compiler/xla/service/user_computation.cc index 92b8c7bb21..90a24fb44d 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -22,7 +22,6 @@ limitations under the License. #include <utility> #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -1931,26 +1930,31 @@ class ComputationLowerer { const SessionComputation& session_computation, VersionedComputationHandle::Version version, UserComputation::HloComputationResolver hlo_resolver, + const DebugOptions& debug_options, bool include_unreachable_instructions) { ComputationLowerer lowerer(computation_name, session_computation, version, - std::move(hlo_resolver)); - return lowerer.Lower(include_unreachable_instructions); + std::move(hlo_resolver), debug_options, + include_unreachable_instructions); + return lowerer.Lower(); } private: ComputationLowerer(const string& computation_name, const SessionComputation& session_computation, VersionedComputationHandle::Version version, - UserComputation::HloComputationResolver hlo_resolver) + UserComputation::HloComputationResolver hlo_resolver, + const DebugOptions& debug_options, + bool include_unreachable_instructions) : hlo_builder_(computation_name), session_computation_(session_computation), version_(version), - hlo_resolver_(std::move(hlo_resolver)) {} + hlo_resolver_(std::move(hlo_resolver)), + debug_options_(debug_options), + include_unreachable_instructions_(include_unreachable_instructions) {} // Build an HLO computation from the SessionComputation at the given // version. 
- StatusOr<std::unique_ptr<HloComputation>> Lower( - bool include_unreachable_instructions); + StatusOr<std::unique_ptr<HloComputation>> Lower(); private: // Traverses the computation 'root' using a DFS, calling 'visit' in postorder. @@ -1980,6 +1984,8 @@ class ComputationLowerer { const SessionComputation& session_computation_; const VersionedComputationHandle::Version version_; const UserComputation::HloComputationResolver hlo_resolver_; + const DebugOptions& debug_options_; + const bool include_unreachable_instructions_; }; // Calls 'apply' on each operand of 'request'. @@ -2273,8 +2279,7 @@ void ComputationLowerer::TraversePostorder( } } -StatusOr<std::unique_ptr<HloComputation>> ComputationLowerer::Lower( - bool include_unreachable_instructions) { +StatusOr<std::unique_ptr<HloComputation>> ComputationLowerer::Lower() { // Map from ComputationDataHandle to HLO instruction. Serves as a record of // which operations have been visited as well as a cache for looking up // ComputationDataHandles as HloInstructions. @@ -2290,7 +2295,7 @@ StatusOr<std::unique_ptr<HloComputation>> ComputationLowerer::Lower( HloInstruction* hlo_root = instructions.at(root_request->output_handle().handle()); - if (include_unreachable_instructions) { + if (include_unreachable_instructions_) { // Iterate through all computation data handles, and visit any unvisited // operations. for (int64 request_num = 1; request_num <= version_; ++request_num) { @@ -2785,8 +2790,7 @@ void ComputationLowerer::Visit( lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; } - if (legacy_flags::GetUserComputationFlags() - ->xla_eliminate_hlo_implicit_broadcast) { + if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { // lhs side is being implicitly broadcast. Change to explicit. 
lhs = @@ -2845,7 +2849,7 @@ void ComputationLowerer::Visit( StatusOr<std::unique_ptr<HloComputation>> UserComputation::BuildHloComputation( VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, + HloComputationResolver hlo_resolver, const DebugOptions& debug_options, bool include_unreachable_instructions) const { tensorflow::mutex_lock lock(mutex_); @@ -2857,7 +2861,7 @@ StatusOr<std::unique_ptr<HloComputation>> UserComputation::BuildHloComputation( std::unique_ptr<HloComputation> hlo_computation, ComputationLowerer::Lower( tensorflow::strings::StrCat(name(), ".v", version), - session_computation_, version, std::move(hlo_resolver), + session_computation_, version, std::move(hlo_resolver), debug_options, include_unreachable_instructions)); XLA_VLOG_LINES(2, hlo_computation->ToString()); diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 9bb7bf491a..3cc3bd0918 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -264,7 +265,7 @@ class UserComputation { std::function<HloComputation*(const VersionedComputationHandle& handle)>; StatusOr<std::unique_ptr<HloComputation>> BuildHloComputation( VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, + HloComputationResolver hlo_resolver, const DebugOptions& debug_options, bool include_unreachable_instructions = true) const; // Return a vector containing the embedded computations used by this diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index 41bb641f43..0d50810dc4 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -92,7 +92,8 @@ TEST_F(UserComputationTest, SimpleComputation) { // Build the HLO computation. TF_ASSIGN_OR_ASSERT_OK( std::unique_ptr<HloComputation> hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver)); + computation.BuildHloComputation(latest_version.version, hlo_resolver, + DebugOptions())); // There should be one HloInstruction per UserComputation operation. 
EXPECT_EQ(3, hlo_computation->instruction_count()); // The root of the instruction should be the parameter instruction (not the @@ -117,9 +118,10 @@ TEST_F(UserComputationTest, SimpleComputation) { // There should be two instructions, one for the constant and one for the // parameter. The outfeed instruction should not be included. - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<HloComputation> hlo_computation, - computation.BuildHloComputation( - version_at_param.version, hlo_resolver)); + TF_ASSIGN_OR_ASSERT_OK( + std::unique_ptr<HloComputation> hlo_computation, + computation.BuildHloComputation(version_at_param.version, hlo_resolver, + DebugOptions())); EXPECT_EQ(2, hlo_computation->instruction_count()); EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); } @@ -130,10 +132,11 @@ TEST_F(UserComputationTest, SimpleComputation) { computation.GetVersionedHandle(); // Build the HLO computation. - TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<HloComputation> hlo_computation, - computation.BuildHloComputation( - latest_version.version, hlo_resolver, - /*include_unreachable_instructions=*/false)); + TF_ASSIGN_OR_ASSERT_OK( + std::unique_ptr<HloComputation> hlo_computation, + computation.BuildHloComputation( + latest_version.version, hlo_resolver, DebugOptions(), + /*include_unreachable_instructions=*/false)); // There is only one reachable instruction, the parameter. EXPECT_EQ(1, hlo_computation->instruction_count()); // The root of the instruction should be the parameter instruction (not the @@ -145,8 +148,8 @@ TEST_F(UserComputationTest, SimpleComputation) { } TEST_F(UserComputationTest, EliminateScalarBroadcast) { - if (!legacy_flags::GetUserComputationFlags() - ->xla_eliminate_hlo_implicit_broadcast) { + if (!legacy_flags::GetDebugOptionsFromFlags() + .xla_eliminate_hlo_implicit_broadcast()) { return; } @@ -184,7 +187,8 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { // Build the HLO computation. 
TF_ASSIGN_OR_ASSERT_OK( std::unique_ptr<HloComputation> hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver)); + computation.BuildHloComputation(latest_version.version, hlo_resolver, + DebugOptions())); // The binary operation has implicit scalar broadcast, should be converted // to an explicit broadcast intruction and a binary instruction. EXPECT_EQ(4, hlo_computation->instruction_count()); @@ -196,8 +200,8 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { } TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { - if (!legacy_flags::GetUserComputationFlags() - ->xla_eliminate_hlo_implicit_broadcast) { + if (!legacy_flags::GetDebugOptionsFromFlags() + .xla_eliminate_hlo_implicit_broadcast()) { return; } @@ -240,7 +244,8 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { // Build the HLO computation. TF_ASSIGN_OR_ASSERT_OK( std::unique_ptr<HloComputation> hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver)); + computation.BuildHloComputation(latest_version.version, hlo_resolver, + DebugOptions())); // The binary operation has in-dim broadcast and degenerate broadcast, should // first do the in-dim broadcast then convert the degnerate broadcast into a @@ -266,7 +271,7 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { int main(int argc, char** argv) { std::vector<tensorflow::Flag> flag_list; - xla::legacy_flags::AppendUserComputationFlags(&flag_list); + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index f5a72a9bcc..cb050f6f4d 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -478,7 +478,6 @@ xla_test( 
"//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -500,7 +499,6 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1015,7 +1013,6 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", ], @@ -1468,7 +1465,6 @@ xla_test( srcs = ["deep_graph_test.cc"], deps = [ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", ], ) diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 024988743c..198a997799 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -27,7 +27,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -1870,7 +1869,6 @@ INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount, int main(int argc, char** argv) { std::vector<tensorflow::Flag> flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendUserComputationFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index aab2c74634..2a57835ca9 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -699,7 +698,6 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { int main(int argc, char** argv) { std::vector<tensorflow::Flag> flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendUserComputationFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc index 7a5601ada3..60953a7421 100644 --- a/tensorflow/compiler/xla/tests/deep_graph_test.cc +++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -14,7 +14,6 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" namespace xla { @@ -42,7 +41,6 @@ TEST_F(ClientLibraryTestBase, DeepGraph) { int main(int argc, char** argv) { std::vector<tensorflow::Flag> flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendUserComputationFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc index 3c87fffadb..a66c9b4487 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -247,7 +246,6 @@ INSTANTIATE_TEST_CASE_P(ReducePrecisionTest, ReducePrecisionTest, int main(int argc, char** argv) { std::vector<tensorflow::Flag> flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendUserComputationFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 9ac89d1b1c..15fa1255ae 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -73,6 +73,10 @@ message DebugOptions { // Dump the compiler IR into this file/path. string xla_dump_ir_to = 34; + // Eliminate implicit broadcasts when lowering user computations to HLO + // instructions; use explicit broadcast instead. + bool xla_eliminate_hlo_implicit_broadcast = 35; + // When generating calls to Eigen in the CPU backend, use multi-threaded Eigen // mode. bool xla_cpu_multi_thread_eigen = 60; |