aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eli Bendersky <eliben@google.com>2017-06-28 15:32:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-28 19:40:34 -0700
commite6a45475735ee8a31c7d6c8e28e9164cda7d1853 (patch)
tree24a61658eba569f2614c9442d2beec7e0eacb69e
parent7ab72bf2205b1775607932b6ccbcd7099368705e (diff)
[XLA] Move the flag from user_computation_flags into debug_options_flags
This requires some plumbing in user_computation to pipe the debug options through a few layers. PiperOrigin-RevId: 160459822
-rw-r--r--tensorflow/compiler/xla/legacy_flags/BUILD12
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc9
-rw-r--r--tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc64
-rw-r--r--tensorflow/compiler/xla/legacy_flags/user_computation_flags.h48
-rw-r--r--tensorflow/compiler/xla/service/BUILD4
-rw-r--r--tensorflow/compiler/xla/service/computation_tracker.cc1
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc32
-rw-r--r--tensorflow/compiler/xla/service/user_computation.h3
-rw-r--r--tensorflow/compiler/xla/service/user_computation_test.cc37
-rw-r--r--tensorflow/compiler/xla/tests/BUILD4
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_simple_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/deep_graph_test.cc2
-rw-r--r--tensorflow/compiler/xla/tests/reduce_precision_test.cc2
-rw-r--r--tensorflow/compiler/xla/xla.proto4
15 files changed, 57 insertions, 169 deletions
diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD
index 9444cc1061..27b1d0c3bb 100644
--- a/tensorflow/compiler/xla/legacy_flags/BUILD
+++ b/tensorflow/compiler/xla/legacy_flags/BUILD
@@ -125,18 +125,6 @@ cc_library(
],
)
-cc_library(
- name = "user_computation_flags",
- srcs = ["user_computation_flags.cc"],
- hdrs = ["user_computation_flags.h"],
- deps = [
- ":parse_flags_from_env",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:lib",
- ],
-)
-
# -----------------------------------------------------------------------------
filegroup(
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index 43634792ce..3d5ea4d32a 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -41,6 +41,7 @@ struct DebugOptionsFlags {
bool xla_embed_ir_in_executable;
string xla_dump_ir_to;
string xla_dump_debug_json_to;
+ bool xla_eliminate_hlo_implicit_broadcast;
bool xla_cpu_multi_thread_eigen;
@@ -76,6 +77,7 @@ void AllocateFlags() {
flag_values->xla_embed_ir_in_executable = false;
flag_values->xla_dump_ir_to = "";
flag_values->xla_dump_debug_json_to = "";
+ flag_values->xla_eliminate_hlo_implicit_broadcast = false;
flag_values->xla_cpu_multi_thread_eigen = true;
flag_values->xla_gpu_cuda_data_dir = "./cuda_sdk_lib";
flag_values->xla_gpu_ftz = false;
@@ -137,6 +139,11 @@ void AllocateFlags() {
"Embed the compiler IR as a string in the executable."),
tensorflow::Flag("xla_dump_ir_to", &flag_values->xla_dump_ir_to,
"Dump the compiler IR into this file/path."),
+ tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast",
+ &flag_values->xla_eliminate_hlo_implicit_broadcast,
+ "Eliminate implicit broadcasts when lowering user "
+ "computations to HLO instructions; use explicit "
+ "broadcast instead."),
tensorflow::Flag("xla_cpu_multi_thread_eigen",
&flag_values->xla_cpu_multi_thread_eigen,
"When generating calls to Eigen in the CPU backend, "
@@ -192,6 +199,8 @@ xla::DebugOptions GetDebugOptionsFromFlags() {
options.set_xla_embed_ir_in_executable(
flag_values->xla_embed_ir_in_executable);
options.set_xla_dump_ir_to(flag_values->xla_dump_ir_to);
+ options.set_xla_eliminate_hlo_implicit_broadcast(
+ flag_values->xla_eliminate_hlo_implicit_broadcast);
options.set_xla_dump_debug_json_to(flag_values->xla_dump_debug_json_to);
options.set_xla_cpu_multi_thread_eigen(
flag_values->xla_cpu_multi_thread_eigen);
diff --git a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc b/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc
deleted file mode 100644
index a9597d0cd8..0000000000
--- a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.cc
+++ /dev/null
@@ -1,64 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
-#include <vector>
-
-#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
-#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace xla {
-namespace legacy_flags {
-
-// Pointers to the parsed value of the flags and flag descriptors, initialized
-// via flags_init.
-static UserComputationFlags* flags;
-static std::vector<tensorflow::Flag>* flag_list;
-static std::once_flag flags_init;
-
-// Allocate *flags. Called via call_once(&flags_init,...).
-static void AllocateFlags() {
- flags = new UserComputationFlags;
- flags->xla_eliminate_hlo_implicit_broadcast = false;
- flag_list = new std::vector<tensorflow::Flag>({
- tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast",
- &flags->xla_eliminate_hlo_implicit_broadcast,
- "Eliminate implicit broadcast on when lowering user "
- "computation to HLO instructions, use explicit "
- "broadcast instead."),
- });
- ParseFlagsFromEnv(*flag_list);
-}
-
-// Append to *append_to flag definitions associated with XLA's hlo_pass_pipeline
-// module.
-void AppendUserComputationFlags(std::vector<tensorflow::Flag>* append_to) {
- std::call_once(flags_init, &AllocateFlags);
- append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
-}
-
-// Return a pointer to the UserComputationFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-UserComputationFlags* GetUserComputationFlags() {
- std::call_once(flags_init, &AllocateFlags);
- return flags;
-}
-
-} // namespace legacy_flags
-} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h b/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h
deleted file mode 100644
index f5222c927c..0000000000
--- a/tensorflow/compiler/xla/legacy_flags/user_computation_flags.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_
-
-// Legacy flags for XLA's user_computation module.
-
-#include <vector>
-
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/command_line_flags.h"
-
-namespace xla {
-namespace legacy_flags {
-
-// Append to *flag_list flags definitions associated with XLA's user_computation
-// module.
-void AppendUserComputationFlags(std::vector<tensorflow::Flag>* flag_list);
-
-typedef struct {
- // Eliminate implicit broadcast on when lowering user computation to HLO
- // instructions, use explicit broadcast instead.
- bool xla_eliminate_hlo_implicit_broadcast;
-} UserComputationFlags;
-
-// Return a pointer to the UserComputationFlags struct;
-// repeated calls return the same pointer.
-// This should be called only after Flags::Parse() has returned.
-UserComputationFlags* GetUserComputationFlags();
-
-} // namespace legacy_flags
-} // namespace xla
-
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index ec12a8acf0..2a74fd1961 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -288,7 +288,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
+ "//tensorflow/compiler/xla:xla_proto",
"//tensorflow/core:lib",
],
)
@@ -306,7 +306,7 @@ cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
+ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:test",
],
diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc
index 9aa32a1fb7..70e25eebdb 100644
--- a/tensorflow/compiler/xla/service/computation_tracker.cc
+++ b/tensorflow/compiler/xla/service/computation_tracker.cc
@@ -216,6 +216,7 @@ StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloComputation> hlo_computation,
computation->BuildHloComputation(versioned_handle.version, resolver,
+ config.debug_options(),
include_unreachable_instructions));
// Add the newly created computation to VersionedHandle-to-HloComputation
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 92b8c7bb21..90a24fb44d 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -1931,26 +1930,31 @@ class ComputationLowerer {
const SessionComputation& session_computation,
VersionedComputationHandle::Version version,
UserComputation::HloComputationResolver hlo_resolver,
+ const DebugOptions& debug_options,
bool include_unreachable_instructions) {
ComputationLowerer lowerer(computation_name, session_computation, version,
- std::move(hlo_resolver));
- return lowerer.Lower(include_unreachable_instructions);
+ std::move(hlo_resolver), debug_options,
+ include_unreachable_instructions);
+ return lowerer.Lower();
}
private:
ComputationLowerer(const string& computation_name,
const SessionComputation& session_computation,
VersionedComputationHandle::Version version,
- UserComputation::HloComputationResolver hlo_resolver)
+ UserComputation::HloComputationResolver hlo_resolver,
+ const DebugOptions& debug_options,
+ bool include_unreachable_instructions)
: hlo_builder_(computation_name),
session_computation_(session_computation),
version_(version),
- hlo_resolver_(std::move(hlo_resolver)) {}
+ hlo_resolver_(std::move(hlo_resolver)),
+ debug_options_(debug_options),
+ include_unreachable_instructions_(include_unreachable_instructions) {}
// Build an HLO computation from the SessionComputation at the given
// version.
- StatusOr<std::unique_ptr<HloComputation>> Lower(
- bool include_unreachable_instructions);
+ StatusOr<std::unique_ptr<HloComputation>> Lower();
private:
// Traverses the computation 'root' using a DFS, calling 'visit' in postorder.
@@ -1980,6 +1984,8 @@ class ComputationLowerer {
const SessionComputation& session_computation_;
const VersionedComputationHandle::Version version_;
const UserComputation::HloComputationResolver hlo_resolver_;
+ const DebugOptions& debug_options_;
+ const bool include_unreachable_instructions_;
};
// Calls 'apply' on each operand of 'request'.
@@ -2273,8 +2279,7 @@ void ComputationLowerer::TraversePostorder(
}
}
-StatusOr<std::unique_ptr<HloComputation>> ComputationLowerer::Lower(
- bool include_unreachable_instructions) {
+StatusOr<std::unique_ptr<HloComputation>> ComputationLowerer::Lower() {
// Map from ComputationDataHandle to HLO instruction. Serves as a record of
// which operations have been visited as well as a cache for looking up
// ComputationDataHandles as HloInstructions.
@@ -2290,7 +2295,7 @@ StatusOr<std::unique_ptr<HloComputation>> ComputationLowerer::Lower(
HloInstruction* hlo_root =
instructions.at(root_request->output_handle().handle());
- if (include_unreachable_instructions) {
+ if (include_unreachable_instructions_) {
// Iterate through all computation data handles, and visit any unvisited
// operations.
for (int64 request_num = 1; request_num <= version_; ++request_num) {
@@ -2785,8 +2790,7 @@ void ComputationLowerer::Visit(
lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs;
rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs;
}
- if (legacy_flags::GetUserComputationFlags()
- ->xla_eliminate_hlo_implicit_broadcast) {
+ if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) {
if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
// lhs side is being implicitly broadcast. Change to explicit.
lhs =
@@ -2845,7 +2849,7 @@ void ComputationLowerer::Visit(
StatusOr<std::unique_ptr<HloComputation>> UserComputation::BuildHloComputation(
VersionedComputationHandle::Version version,
- HloComputationResolver hlo_resolver,
+ HloComputationResolver hlo_resolver, const DebugOptions& debug_options,
bool include_unreachable_instructions) const {
tensorflow::mutex_lock lock(mutex_);
@@ -2857,7 +2861,7 @@ StatusOr<std::unique_ptr<HloComputation>> UserComputation::BuildHloComputation(
std::unique_ptr<HloComputation> hlo_computation,
ComputationLowerer::Lower(
tensorflow::strings::StrCat(name(), ".v", version),
- session_computation_, version, std::move(hlo_resolver),
+ session_computation_, version, std::move(hlo_resolver), debug_options,
include_unreachable_instructions));
XLA_VLOG_LINES(2, hlo_computation->ToString());
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
index 9bb7bf491a..3cc3bd0918 100644
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
@@ -264,7 +265,7 @@ class UserComputation {
std::function<HloComputation*(const VersionedComputationHandle& handle)>;
StatusOr<std::unique_ptr<HloComputation>> BuildHloComputation(
VersionedComputationHandle::Version version,
- HloComputationResolver hlo_resolver,
+ HloComputationResolver hlo_resolver, const DebugOptions& debug_options,
bool include_unreachable_instructions = true) const;
// Return a vector containing the embedded computations used by this
diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc
index 41bb641f43..0d50810dc4 100644
--- a/tensorflow/compiler/xla/service/user_computation_test.cc
+++ b/tensorflow/compiler/xla/service/user_computation_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/user_computation.h"
-#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -92,7 +92,8 @@ TEST_F(UserComputationTest, SimpleComputation) {
// Build the HLO computation.
TF_ASSIGN_OR_ASSERT_OK(
std::unique_ptr<HloComputation> hlo_computation,
- computation.BuildHloComputation(latest_version.version, hlo_resolver));
+ computation.BuildHloComputation(latest_version.version, hlo_resolver,
+ DebugOptions()));
// There should be one HloInstruction per UserComputation operation.
EXPECT_EQ(3, hlo_computation->instruction_count());
// The root of the instruction should be the parameter instruction (not the
@@ -117,9 +118,10 @@ TEST_F(UserComputationTest, SimpleComputation) {
// There should be two instructions, one for the constant and one for the
// parameter. The outfeed instruction should not be included.
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<HloComputation> hlo_computation,
- computation.BuildHloComputation(
- version_at_param.version, hlo_resolver));
+ TF_ASSIGN_OR_ASSERT_OK(
+ std::unique_ptr<HloComputation> hlo_computation,
+ computation.BuildHloComputation(version_at_param.version, hlo_resolver,
+ DebugOptions()));
EXPECT_EQ(2, hlo_computation->instruction_count());
EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter());
}
@@ -130,10 +132,11 @@ TEST_F(UserComputationTest, SimpleComputation) {
computation.GetVersionedHandle();
// Build the HLO computation.
- TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<HloComputation> hlo_computation,
- computation.BuildHloComputation(
- latest_version.version, hlo_resolver,
- /*include_unreachable_instructions=*/false));
+ TF_ASSIGN_OR_ASSERT_OK(
+ std::unique_ptr<HloComputation> hlo_computation,
+ computation.BuildHloComputation(
+ latest_version.version, hlo_resolver, DebugOptions(),
+ /*include_unreachable_instructions=*/false));
// There is only one reachable instruction, the parameter.
EXPECT_EQ(1, hlo_computation->instruction_count());
// The root of the instruction should be the parameter instruction (not the
@@ -145,8 +148,8 @@ TEST_F(UserComputationTest, SimpleComputation) {
}
TEST_F(UserComputationTest, EliminateScalarBroadcast) {
- if (!legacy_flags::GetUserComputationFlags()
- ->xla_eliminate_hlo_implicit_broadcast) {
+ if (!legacy_flags::GetDebugOptionsFromFlags()
+ .xla_eliminate_hlo_implicit_broadcast()) {
return;
}
@@ -184,7 +187,8 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) {
// Build the HLO computation.
TF_ASSIGN_OR_ASSERT_OK(
std::unique_ptr<HloComputation> hlo_computation,
- computation.BuildHloComputation(latest_version.version, hlo_resolver));
+ computation.BuildHloComputation(latest_version.version, hlo_resolver,
+ DebugOptions()));
// The binary operation has implicit scalar broadcast, should be converted
// to an explicit broadcast instruction and a binary instruction.
EXPECT_EQ(4, hlo_computation->instruction_count());
@@ -196,8 +200,8 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) {
}
TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
- if (!legacy_flags::GetUserComputationFlags()
- ->xla_eliminate_hlo_implicit_broadcast) {
+ if (!legacy_flags::GetDebugOptionsFromFlags()
+ .xla_eliminate_hlo_implicit_broadcast()) {
return;
}
@@ -240,7 +244,8 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
// Build the HLO computation.
TF_ASSIGN_OR_ASSERT_OK(
std::unique_ptr<HloComputation> hlo_computation,
- computation.BuildHloComputation(latest_version.version, hlo_resolver));
+ computation.BuildHloComputation(latest_version.version, hlo_resolver,
+ DebugOptions()));
// The binary operation has in-dim broadcast and degenerate broadcast, should
// first do the in-dim broadcast then convert the degenerate broadcast into a
@@ -266,7 +271,7 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
- xla::legacy_flags::AppendUserComputationFlags(&flag_list);
+ xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index f5a72a9bcc..cb050f6f4d 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -478,7 +478,6 @@ xla_test(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
- "//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
@@ -500,7 +499,6 @@ xla_test(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
- "//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
@@ -1015,7 +1013,6 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
- "//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
],
@@ -1468,7 +1465,6 @@ xla_test(
srcs = ["deep_graph_test.cc"],
deps = [
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
- "//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
"//tensorflow/compiler/xla/tests:client_library_test_base",
],
)
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 024988743c..198a997799 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -1870,7 +1869,6 @@ INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
- xla::legacy_flags::AppendUserComputationFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index aab2c74634..2a57835ca9 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -699,7 +698,6 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
- xla::legacy_flags::AppendUserComputationFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc
index 7a5601ada3..60953a7421 100644
--- a/tensorflow/compiler/xla/tests/deep_graph_test.cc
+++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
namespace xla {
@@ -42,7 +41,6 @@ TEST_F(ClientLibraryTestBase, DeepGraph) {
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
- xla::legacy_flags::AppendUserComputationFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index 3c87fffadb..a66c9b4487 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
-#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -247,7 +246,6 @@ INSTANTIATE_TEST_CASE_P(ReducePrecisionTest, ReducePrecisionTest,
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendDebugOptionsFlags(&flag_list);
- xla::legacy_flags::AppendUserComputationFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 9ac89d1b1c..15fa1255ae 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -73,6 +73,10 @@ message DebugOptions {
// Dump the compiler IR into this file/path.
string xla_dump_ir_to = 34;
+ // Eliminate implicit broadcasts when lowering user computations to HLO
+ // instructions; use explicit broadcast instead.
+ bool xla_eliminate_hlo_implicit_broadcast = 35;
+
// When generating calls to Eigen in the CPU backend, use multi-threaded Eigen
// mode.
bool xla_cpu_multi_thread_eigen = 60;