aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-06-01 15:32:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-01 15:34:52 -0700
commit5e0b2f2b0d0d938152334ae1ef1c9b25d229e280 (patch)
tree2774166e3d5f04577d39fc2227d6d0e364bd249a /tensorflow/compiler
parentaf1d59aff9bf3b43dfff4d99e50d22f527201e76 (diff)
[XLA] Move xla/tools/parser/* into xla/service.
Now that we're using the parser inside of xla/service, it's awkward for it to live inside of xla/tools, because everything else in there is a standalone tool. We've already had one person be confused by this. PiperOrigin-RevId: 198935921
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/BUILD95
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD6
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/BUILD4
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/g3doc/hlo_parser.md (renamed from tensorflow/compiler/xla/tools/parser/README.md)0
-rw-r--r--tensorflow/compiler/xla/service/gather_expander_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD4
-rw-r--r--tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc32
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_execution_profile_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_lexer.cc (renamed from tensorflow/compiler/xla/tools/parser/hlo_lexer.cc)26
-rw-r--r--tensorflow/compiler/xla/service/hlo_lexer.h (renamed from tensorflow/compiler/xla/tools/parser/hlo_lexer.h)17
-rw-r--r--tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_dce_test.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc (renamed from tensorflow/compiler/xla/tools/parser/hlo_parser.cc)252
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.h (renamed from tensorflow/compiler/xla/tools/parser/hlo_parser.h)24
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc (renamed from tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc)90
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_token.h (renamed from tensorflow/compiler/xla/tools/parser/hlo_token.h)11
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion_test.cc20
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/pattern_matcher_test.cc6
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/tuple_util_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/while_util_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/BUILD10
-rw-r--r--tensorflow/compiler/xla/tests/cross_replica_sum_test.cc11
-rw-r--r--tensorflow/compiler/xla/tests/gather_operation_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.cc2
-rw-r--r--tensorflow/compiler/xla/tests/hlo_verified_test_base.cc4
-rw-r--r--tensorflow/compiler/xla/tests/reduce_hlo_test.cc4
-rw-r--r--tensorflow/compiler/xla/tools/parser/BUILD73
49 files changed, 442 insertions, 436 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 2b14b63ea8..0102e4f003 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -349,8 +349,8 @@ tf_cc_test(
":hlo",
":pattern_matcher",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
@@ -388,8 +388,8 @@ cc_library(
deps = [
":hlo",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@@ -399,6 +399,7 @@ tf_cc_test(
srcs = ["hlo_matchers_test.cc"],
deps = [
":hlo_matchers",
+ ":hlo_parser",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
@@ -420,6 +421,7 @@ tf_cc_test(
srcs = ["hlo_instruction_test.cc"],
deps = [
":hlo",
+ ":hlo_parser",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:shape_util",
@@ -429,7 +431,6 @@ tf_cc_test(
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -444,9 +445,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -989,9 +990,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@@ -1027,9 +1028,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -1130,9 +1131,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -1165,9 +1166,9 @@ tf_cc_test(
deps = [
":hlo_matchers",
":instruction_fusion",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -1339,9 +1340,9 @@ tf_cc_test(
deps = [
":gather_expander",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:test_macros_header",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -1691,9 +1692,9 @@ tf_cc_test(
":cpu_plugin",
":hlo_cost_analysis",
":hlo_execution_profile",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@@ -1874,9 +1875,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@@ -2211,11 +2212,11 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@@ -2237,9 +2238,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@@ -2310,10 +2311,10 @@ tf_cc_test(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@@ -2415,12 +2416,12 @@ tf_cc_test(
":hlo",
":hlo_domain_isolator",
":hlo_domain_remover",
+ ":hlo_parser",
":hlo_sharding_metadata",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
@@ -2506,10 +2507,10 @@ xla_test(
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -2655,10 +2656,10 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@@ -2795,7 +2796,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
@@ -2831,8 +2832,8 @@ tf_cc_test(
":tuple_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -2857,8 +2858,8 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -2884,8 +2885,8 @@ tf_cc_test(
":hlo_matchers",
":while_loop_invariant_code_motion",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
@@ -2911,8 +2912,8 @@ tf_cc_test(
":hlo_matchers",
":while_loop_constant_sinking",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
@@ -2965,9 +2966,57 @@ tf_cc_test(
":hlo_matchers",
":indexed_array_analysis",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:test",
],
)
+
+cc_library(
+ name = "hlo_parser",
+ srcs = ["hlo_parser.cc"],
+ hdrs = ["hlo_parser.h"],
+ deps = [
+ ":hlo",
+ ":hlo_lexer",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_cc_test(
+ name = "hlo_parser_test",
+ size = "small",
+ srcs = ["hlo_parser_test.cc"],
+ deps = [
+ ":hlo_parser",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main", # fixdeps: keep
+ ],
+)
+
+cc_library(
+ name = "hlo_lexer",
+ srcs = ["hlo_lexer.cc"],
+ hdrs = [
+ "hlo_lexer.h",
+ "hlo_token.h",
+ ],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:regexp_internal",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index bdcea92882..7e86c33687 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -32,12 +32,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
@@ -1793,7 +1793,7 @@ ENTRY %test_module {
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(module_str));
+ ParseHloString(module_str));
// Run CopyInsertion and check if the graph constructed above doesn't need
// any copies inserted for BufferAssignment to run.
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index a15e41fee0..f10d71fdba 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -633,10 +633,10 @@ tf_cc_test(
deps = [
":cpu_instruction_fusion",
"//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
],
)
@@ -690,9 +690,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -942,7 +942,7 @@ tf_cc_test(
":ir_emission_utils",
":target_machine_features_fake",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc
index d12fa6bb9a..8727c72b6e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace cpu {
@@ -40,7 +40,7 @@ ENTRY DotOperation {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloInstruction* dot = module->entry_computation()->root_instruction();
@@ -71,7 +71,7 @@ ENTRY ConvOperation {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloInstruction* conv = module->entry_computation()->root_instruction();
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index 46fe060817..97e10a89a2 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -19,9 +19,9 @@ limitations under the License.
#include <set>
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace op = xla::testing::opcode_matchers;
@@ -172,7 +172,7 @@ ENTRY DotOperationFusion_TransposeFusion {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloComputation* computation = module->entry_computation();
TransposeFolding transpose_folding(
@@ -202,7 +202,7 @@ ENTRY DotOperationFusion_TransposeFusion {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloComputation* computation = module->entry_computation();
TransposeFolding transpose_folding(
@@ -233,7 +233,7 @@ ENTRY DotOperationFusion_TransposeFusion {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloComputation* computation = module->entry_computation();
TransposeFolding transpose_folding(
@@ -775,7 +775,7 @@ TEST_P(GatherLoopFusionTest, GatherLoopFusion) {
string hlo_string = tensorflow::strings::StrCat(
"HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
RunFusionAndCheckOpcodesWereFused(
module.get(),
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc
index abb2471e6a..530ebce854 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace {
@@ -35,7 +35,7 @@ ENTRY Conv {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloComputation* entry_computation = module->entry_computation();
diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD
index 67f776e7b5..66ae5ef0f6 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD
@@ -152,9 +152,9 @@ tf_cc_test(
srcs = ["cpu_literal_caching_test.cc"],
deps = [
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@@ -166,9 +166,9 @@ tf_cc_test(
srcs = ["cpu_outfeed_test.cc"],
deps = [
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
index 3cb25c5c19..27044b1d62 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
namespace xla {
namespace cpu {
@@ -60,7 +60,7 @@ CHECK-NOT: private constant [12 x float]
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text));
+ ParseHloString(hlo_text));
CpuAotCompilationOptions options{
/*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
@@ -105,7 +105,7 @@ CHECK-NOT: private constant [2 x float]
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text));
+ ParseHloString(hlo_text));
CpuAotCompilationOptions options{
/*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
index 1a948fb4fe..1ee279290b 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
namespace xla {
namespace cpu {
@@ -41,7 +41,7 @@ CHECK: private constant [12 x float]
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text));
+ ParseHloString(hlo_text));
CpuAotCompilationOptions options{
/*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"",
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
index b43dc0c65d..8980d43033 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc
@@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace {
@@ -33,7 +33,7 @@ class ElementalIrEmitterExecutionTest : public HloTestBase {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text, config));
+ ParseHloString(hlo_text, config));
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt));
}
};
diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/service/g3doc/hlo_parser.md
index f0f3dd7785..f0f3dd7785 100644
--- a/tensorflow/compiler/xla/tools/parser/README.md
+++ b/tensorflow/compiler/xla/service/g3doc/hlo_parser.md
diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc
index 1c72ca0665..020ffcd106 100644
--- a/tensorflow/compiler/xla/service/gather_expander_test.cc
+++ b/tensorflow/compiler/xla/service/gather_expander_test.cc
@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gather_expander.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace {
@@ -36,7 +36,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text));
+ ParseHloString(hlo_text));
Status status = GatherExpander{}.Run(module.get()).status();
EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
@@ -63,7 +63,7 @@ ENTRY main {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text));
+ ParseHloString(hlo_text));
TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
ASSERT_TRUE(changed);
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 68297ad4ae..6bd9d4c31d 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -416,9 +416,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -460,9 +460,9 @@ tf_cc_test(
":instruction_fusion",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
index 2217776c7d..b22bb1d39b 100644
--- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace gpu {
@@ -40,7 +40,7 @@ class FusionMergerTest : public HloTestBase {};
// Tuple
//
TEST_F(FusionMergerTest, MergeSharedFusionInstruction) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule MergeSharedFusionInstruction
comp.3 {
@@ -104,7 +104,7 @@ ENTRY MergeSharedFusionInstruction.Computation0 {
//
// Fusion2 is not merged because it exceeds the threshold flops-to-bytes ratio.
TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule FlopsToBytesRatioThresholdExceeded
comp.2 {
@@ -162,7 +162,7 @@ ENTRY FlopsToBytesRatioThresholdExceeded.Computation1 {
// is merged into Fusion0 and Fusion1) would exceed the bytes transferred
// threshold.
TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule BytesTransferredThresholdExeceeded
comp.2 {
@@ -210,7 +210,7 @@ ENTRY BytesTransferredThresholdExeceeded.Computation2 {
// Fusion2 is reduced for this test which makes the merge operation into its
// operand below the bytes transferred threshold.
TEST_F(FusionMergerTest, BytesTransferredThresholdNotExeceeded) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule BytesTransferredThresholdNotExeceeded
comp.2 {
@@ -253,7 +253,7 @@ ENTRY BytesTransferredThresholdNotExeceeded.Computation2 {
// Check that we're willing to merge f1_computation into f2_computation, even
// though f2 is an input fusion node.
TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule m
f1_computation {
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index ec60f3a167..426b1d235c 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
namespace op = xla::testing::opcode_matchers;
@@ -143,7 +143,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
// Tests that broadcasts fused into a fusion with a reduce root.
TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
add {
@@ -172,7 +172,7 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) {
}
TEST_F(InstructionFusionTest, BitcastIntoAdd) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY BroadcastIntoAdd {
@@ -194,7 +194,7 @@ TEST_F(InstructionFusionTest, BitcastIntoAdd) {
}
TEST_F(InstructionFusionTest, AddIntoBitcast) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY BroadcastIntoAdd {
@@ -216,7 +216,7 @@ TEST_F(InstructionFusionTest, AddIntoBitcast) {
}
TEST_F(InstructionFusionTest, DontFuseGTE) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY DontFuseGTE {
p0 = (f32[10], f32[10]) parameter(0)
@@ -232,7 +232,7 @@ TEST_F(InstructionFusionTest, DontFuseGTE) {
}
TEST_F(InstructionFusionTest, DotOutputFusion) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
alpha = f32[] constant(3)
@@ -261,7 +261,7 @@ TEST_F(InstructionFusionTest, DotOutputFusion) {
// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
// duplicated and fused into both reduces.
TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
Add {
lhs = f32[] parameter(0)
@@ -292,7 +292,7 @@ TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
// is *not* duplicated and fused into both reduces, because we say that integer
// division is not cheap.
TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
Add {
lhs = s32[] parameter(0)
@@ -317,7 +317,7 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
}
TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY NoOutputFusion {
alpha = f32[] constant(3)
@@ -371,7 +371,7 @@ static StatusOr<const HloInstruction*> FindHloInstruction(
TEST_F(InstructionFusionTest, MultiOutputFusion) {
// sub --> add --> tuple
// \---------------/
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -403,7 +403,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusion) {
TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) {
// tanh --> add --> tuple
// \---------------/
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -424,7 +424,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) {
TEST_F(InstructionFusionTest, MultiOutputFusion2) {
// sub --> add1 --\--------\
// \----------> add2 --> tuple
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -457,7 +457,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusion2) {
TEST_F(InstructionFusionTest, MultiOutputFusion3) {
// sub --> add1 ----\--------\
// \ --> add2 --> add3 --> tuple
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -492,7 +492,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusion3) {
TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) {
// sub --> mul ---\
// \--> call --> add --> tuple
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
c = f32[] constant(42)
@@ -527,7 +527,7 @@ TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) {
TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) {
// sub[2,3] --> add[4,3] --> tuple([2,3], [4,3])
// \-------------------------/
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[2,3]{1,0} parameter(0)
@@ -548,7 +548,7 @@ TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) {
}
TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
add_computation {
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
index ad55728c45..7749201cbc 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
@@ -457,8 +457,8 @@ class WhileBodyComputationMatcher : public MatcherBase {
return InvalidArgument("Unexpected tuple index instruction : %s",
inst->name().c_str());
} else if (tag == "loop_increment") {
- // Parse the constant which represents the loop induction variable
- // increment value.
+ // ParseHloString the constant which represents the loop induction
+ // variable increment value.
TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_));
} else if (tag == "param0" &&
inst != computation_->parameter_instruction(0)) {
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index e8c5ca347b..16db374566 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -32,10 +32,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
@@ -486,7 +486,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) {
}
TEST_F(HloCseTest, CompareComputations) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule m
add_computation {
diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc
index f29aac29c0..5553ddb153 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -17,10 +17,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_domain_remover.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
@@ -68,7 +68,7 @@ class HloDomainTest : public HloTestBase {
tensorflow::StringPiece hlo_string) {
HloModuleConfig config;
config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- return tools::Parse(hlo_string, config);
+ return ParseHloString(hlo_string, config);
}
};
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
index 4900c813fd..eba80c0f19 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
@@ -29,7 +29,7 @@ using ::testing::ContainsRegex;
class HloExecutionProfileTest : public HloTestBase {};
TEST_F(HloExecutionProfileTest, Basic) {
- auto hlo_module = tools::Parse(R"(
+ auto hlo_module = ParseHloString(R"(
HloModule test_module
ENTRY entry_computation {
lhs = f32[30,30]{1,0} parameter(0)
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index a1a8814384..313033ddad 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -24,11 +24,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
@@ -1533,7 +1533,7 @@ ENTRY entry (param: s32[]) -> s32[] {
// Check that deep clones really deep clones every instruction and
// computations, without leaving dangling pointers to the old module.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
std::unique_ptr<HloModule> clone = module->Clone();
for (HloComputation* computation : clone->computations()) {
EXPECT_EQ(computation->parent(), clone.get());
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc
index 350db12653..f0d9fdbc8f 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
+#include "tensorflow/compiler/xla/service/hlo_lexer.h"
#include <unordered_map>
@@ -26,9 +26,8 @@ limitations under the License.
#include "tensorflow/core/platform/regexp.h"
namespace xla {
-namespace tools {
-using tensorflow::StringPiece;
+using ::tensorflow::StringPiece;
namespace {
@@ -67,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const {
return ptr < buf_.end() && ptr >= buf_.begin();
}
-StringPiece HloLexer::StringPieceFromPointers(const char* begin,
- const char* end) const {
+tensorflow::StringPiece HloLexer::StringPieceFromPointers(
+ const char* begin, const char* end) const {
CHECK(begin <= end);
CHECK(begin == buf_.end() || CanDereference(begin));
CHECK(end == buf_.end() || CanDereference(end));
- return StringPiece(begin, end - begin);
+ return tensorflow::StringPiece(begin, end - begin);
}
tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers(
@@ -197,7 +196,8 @@ TokKind HloLexer::LexIdentifier() {
return TokKind::kAttributeName;
}
- StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_);
+ tensorflow::StringPiece identifier =
+ StringPieceFromPointers(token_start_, current_ptr_);
// See if this is a keyword.
#define KEYWORD(STR) \
@@ -332,23 +332,24 @@ std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
line_no_cache_.last_query = ptr;
line_no_cache_.line_no_of_query = line_no;
size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n');
- if (line_offset == StringPiece::npos) {
+ if (line_offset == tensorflow::StringPiece::npos) {
line_offset = 0;
}
return {line_no, ptr - start - line_offset};
}
-StringPiece HloLexer::GetLine(LocTy loc) const {
+tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const {
if (!CanDereference(loc)) {
return "LINE OUT OF RANGE";
}
size_t line_start =
StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n');
- const char* start = line_start == StringPiece::npos
+ const char* start = line_start == tensorflow::StringPiece::npos
? buf_.begin()
: buf_.begin() + line_start + 1;
size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n');
- const char* end = line_end == StringPiece::npos ? buf_.end() : loc + line_end;
+ const char* end =
+ line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end;
return StringPieceFromPointers(start, end);
}
@@ -370,7 +371,7 @@ TokKind HloLexer::LexString() {
static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
if (RE2::Consume(&consumable, *escaping_pattern)) {
current_ptr_ = consumable.begin();
- StringPiece raw =
+ tensorflow::StringPiece raw =
StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1);
string error;
if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) {
@@ -453,5 +454,4 @@ string TokKindToString(TokKind kind) {
}
}
-} // namespace tools
} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h
index 27880b9b8a..ceb674f25e 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
+++ b/tensorflow/compiler/xla/service/hlo_lexer.h
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
-#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_
#include <string>
-#include "tensorflow/compiler/xla/tools/parser/hlo_token.h"
+#include "tensorflow/compiler/xla/service/hlo_token.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -27,9 +27,11 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace xla {
-namespace tools {
// Lexer for the HloModule::ToString() format text.
+//
+// This class is meant to be used by hlo_parser.cc. You shouldn't need to use
+// it directly.
class HloLexer {
public:
explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) {
@@ -57,7 +59,7 @@ class HloLexer {
CHECK(GetKind() == TokKind::kShape);
return shape_val_;
}
- int64 GetInt64Val() const {
+ tensorflow::int64 GetInt64Val() const {
CHECK(GetKind() == TokKind::kInt);
return int64_val_;
}
@@ -114,7 +116,7 @@ class HloLexer {
TokKind current_kind_;
string str_val_;
Shape shape_val_;
- int64 int64_val_;
+ tensorflow::int64 int64_val_;
double decimal_val_;
struct LineNoCacheTy {
@@ -125,7 +127,6 @@ class HloLexer {
mutable LineNoCacheTy line_no_cache_{nullptr, 0};
};
-} // namespace tools
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_
diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
index 8e2e2c7627..0275294a1a 100644
--- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc
@@ -18,12 +18,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -59,7 +59,7 @@ class HloLivenessAnalysisTest : public HloTestBase {
// Test that add instruction at entry root is live at all output shape indices.
TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -75,7 +75,7 @@ TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) {
// Test that a dead add instruction is marked as dead by analysis.
TEST_F(HloLivenessAnalysisTest, DeadAdd) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -94,7 +94,7 @@ TEST_F(HloLivenessAnalysisTest, DeadAdd) {
// Test that all output shape indices of entry root tuple (and defining
// instruction in its output) are marked live.
TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -113,7 +113,7 @@ TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) {
// Tests that all outputs of nested tuple and entry root (and defining
// instruction values appearing in its output) are marked live.
TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(1)
@@ -140,7 +140,7 @@ TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) {
// Tests that GTE at entry root of Tuple instruction only propgates liveness
// to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfTuple) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -162,7 +162,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfTuple) {
// Tests that GTE at entry root of nested Tuple instruction only propgates
// liveness to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -199,7 +199,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) {
// Tests that GTE of GTE (at entry root) of nested Tuple instruction only
// propgates liveness to the live elements in tuple.
TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleModule
ENTRY SimpleComputation {
constant.1 = s32[] constant(0)
@@ -240,7 +240,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) {
// Test that live/dead while tuple elements are marked live/dead correctly.
TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@@ -291,7 +291,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) {
// Tests that a tuple element live in while.cond computation, propagates
// liveness to while.body.root/while.result/while.operand (where it is unused).
TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@@ -345,7 +345,7 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) {
// Tests that a use of while.result{0} propagates liveness to
// while.body.param{1} to while.body.root{1}, and then to while.body.param{2}.
TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[], s32[]) parameter(0)
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index dfefad3634..c570b420c2 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -17,8 +17,8 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
@@ -329,7 +329,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
tensorflow::StringPiece sharding) {
return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher(
- xla::tools::ParseSharding(sharding).ValueOrDie()));
+ ParseSharding(sharding).ValueOrDie()));
}
// Verifies that no HloSharding is set for an HLO instruction.
inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
index 1d10e3c4fe..9a3010cf1f 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace op = xla::testing::opcode_matchers;
@@ -194,7 +195,7 @@ ENTRY DotOperationFusion_TransposeFusion {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1),
diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
index 53b7d0ed39..363862e490 100644
--- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc
@@ -19,11 +19,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/types.h"
@@ -73,7 +73,7 @@ class HloModuleDceTest : public HloTestBase {
// Tests that a while with all outputs live is unmodified.
TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@@ -110,7 +110,7 @@ TEST_F(HloModuleDceTest, WhileWithLiveOutputs) {
// Tests a while loop with one unused output (which is used in the while loop
// body by an instruction with side-effects: rng) is unmodified.
TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], f32[]) parameter(0)
@@ -150,7 +150,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) {
// Tests that a while loop with one dead tuple element at {1} has its while
// loop body modified to make that tuple element pass-through the while body.
TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@@ -193,7 +193,7 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) {
// dead in while.body{1} and at while.result{1}) propgates liveness of this
// tuple element to while.body{1} and at while.result{1}.
TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body {
loop_var.1 = (s32[], s32[]) parameter(0)
@@ -235,7 +235,7 @@ TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) {
// Tests that HloModuleDCE can remove a dead tuple element at index {1} between
// two dependent while loops.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body0 {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
@@ -303,7 +303,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) {
// Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and
// while.2{1}, between two dependent while loops.
TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule SimpleLoop
SimpleLoop.body0 {
loop_var.1 = (s32[3]{0}, s32[]) parameter(0)
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 37a7fbad97..cfe5dace05 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -22,10 +22,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -310,7 +310,7 @@ ENTRY while.v11 {
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(module_str));
+ ParseHloString(module_str));
DependencyHloOrdering ordering(module.get());
ordering.ToString(); // Shouldn't crash.
}
@@ -347,7 +347,7 @@ ENTRY root {
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(module_str));
+ ParseHloString(module_str));
TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
DependencyHloOrdering ordering(module.get());
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index ef10ca4bff..cefc6ff915 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -24,18 +24,17 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
-namespace tools {
namespace {
-using tensorflow::StringPiece;
-using tensorflow::gtl::optional;
-using tensorflow::str_util::Join;
-using tensorflow::str_util::Split;
-using tensorflow::str_util::SplitAndParseAsInts;
-using tensorflow::strings::Printf;
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
+using ::tensorflow::StringPiece;
+using ::tensorflow::gtl::optional;
+using ::tensorflow::str_util::Join;
+using ::tensorflow::str_util::Split;
+using ::tensorflow::str_util::SplitAndParseAsInts;
+using ::tensorflow::strings::Printf;
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::strings::StrCat;
const double kF16max = 65504;
@@ -83,11 +82,15 @@ class HloParser {
// Sets the sub-value of literal at the given index to the given value. The
// literal's shape must have the default layout.
- bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal);
- bool SetValueInLiteral(double value, int64 linear_index, Literal* literal);
- bool SetValueInLiteral(bool value, int64 linear_index, Literal* literal);
+ bool SetValueInLiteral(tensorflow::int64 value,
+ tensorflow::int64 linear_index, Literal* literal);
+ bool SetValueInLiteral(double value, tensorflow::int64 linear_index,
+ Literal* literal);
+ bool SetValueInLiteral(bool value, tensorflow::int64 linear_index,
+ Literal* literal);
template <typename LiteralNativeT, typename ParsedElemT>
- bool SetValueInLiteralHelper(ParsedElemT value, int64 linear_index,
+ bool SetValueInLiteralHelper(ParsedElemT value,
+ tensorflow::int64 linear_index,
Literal* literal);
bool ParseOperands(std::vector<HloInstruction*>* operands);
@@ -99,9 +102,9 @@ class HloParser {
// Describes the start, limit, and stride on every dimension of the operand
// being sliced.
struct SliceRanges {
- std::vector<int64> starts;
- std::vector<int64> limits;
- std::vector<int64> strides;
+ std::vector<tensorflow::int64> starts;
+ std::vector<tensorflow::int64> limits;
+ std::vector<tensorflow::int64> strides;
};
// Types of attributes.
@@ -179,13 +182,14 @@ class HloParser {
bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
// Parses a sub-attribute of the window attribute, e.g.,size=1x2x3.
- bool ParseDxD(const string& name, std::vector<int64>* result);
+ bool ParseDxD(const string& name, std::vector<tensorflow::int64>* result);
// Parses window's pad sub-attriute, e.g., pad=0_0x3x3.
- bool ParseWindowPad(std::vector<std::vector<int64>>* pad);
+ bool ParseWindowPad(std::vector<std::vector<tensorflow::int64>>* pad);
bool ParseSliceRanges(SliceRanges* result);
bool ParseInt64List(const TokKind start, const TokKind end,
- const TokKind delim, std::vector<int64>* result);
+ const TokKind delim,
+ std::vector<tensorflow::int64>* result);
bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
bool ParseParamList();
@@ -197,7 +201,7 @@ class HloParser {
bool ParseFftType(FftType* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
bool ParseRandomDistribution(RandomDistribution* result);
- bool ParseInt64(int64* result);
+ bool ParseInt64(tensorflow::int64* result);
bool ParseDouble(double* result);
bool ParseBool(bool* result);
bool ParseToken(TokKind kind, const string& msg);
@@ -455,7 +459,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
- int64 parameter_number;
+ tensorflow::int64 parameter_number;
if (!ParseToken(TokKind::kLparen,
"expects '(' before parameter number") ||
!ParseInt64(&parameter_number) ||
@@ -611,7 +615,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kRecv: {
- optional<int64> channel_id;
+ optional<tensorflow::int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseAttributes(attrs)) {
@@ -622,7 +626,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kRecvDone: {
- optional<int64> channel_id;
+ optional<tensorflow::int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
@@ -636,7 +640,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kSend: {
- optional<int64> channel_id;
+ optional<tensorflow::int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
@@ -647,7 +651,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kSendDone: {
- optional<int64> channel_id;
+ optional<tensorflow::int64> channel_id;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
@@ -661,7 +665,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kGetTupleElement: {
- optional<int64> index;
+ optional<tensorflow::int64> index;
attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
@@ -719,7 +723,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
case HloOpcode::kFft: {
optional<FftType> fft_type;
- optional<std::vector<int64>> fft_length;
+ optional<std::vector<tensorflow::int64>> fft_length;
attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
&fft_length};
@@ -732,7 +736,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kBroadcast: {
- optional<std::vector<int64>> broadcast_dimensions;
+ optional<std::vector<tensorflow::int64>> broadcast_dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&broadcast_dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
@@ -744,7 +748,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kConcatenate: {
- optional<std::vector<int64>> dimensions;
+ optional<std::vector<tensorflow::int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
@@ -770,7 +774,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<HloComputation*> reduce_computation;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&reduce_computation};
- optional<std::vector<int64>> dimensions_to_reduce;
+ optional<std::vector<tensorflow::int64>> dimensions_to_reduce;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions_to_reduce};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
@@ -783,7 +787,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kReverse: {
- optional<std::vector<int64>> dimensions;
+ optional<std::vector<tensorflow::int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
@@ -827,7 +831,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kDynamicSlice: {
- optional<std::vector<int64>> dynamic_slice_sizes;
+ optional<std::vector<tensorflow::int64>> dynamic_slice_sizes;
attrs["dynamic_slice_sizes"] = {
/*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
@@ -851,7 +855,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kTranspose: {
- optional<std::vector<int64>> dimensions;
+ optional<std::vector<tensorflow::int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
@@ -865,7 +869,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kBatchNormTraining: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
- optional<int64> feature_index;
+ optional<tensorflow::int64> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, /*expected_size=*/3) ||
@@ -881,7 +885,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kBatchNormInference: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
- optional<int64> feature_index;
+ optional<tensorflow::int64> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, /*expected_size=*/5) ||
@@ -898,7 +902,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kBatchNormGrad: {
optional<float> epsilon;
attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
- optional<int64> feature_index;
+ optional<tensorflow::int64> feature_index;
attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
&feature_index};
if (!ParseOperands(&operands, /*expected_size=*/5) ||
@@ -969,8 +973,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kReducePrecision: {
- optional<int64> exponent_bits;
- optional<int64> mantissa_bits;
+ optional<tensorflow::int64> exponent_bits;
+ optional<tensorflow::int64> mantissa_bits;
attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
&exponent_bits};
attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
@@ -1015,7 +1019,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
case HloOpcode::kHostCompute: {
optional<string> channel_name;
- optional<int64> cost_estimate_ns;
+ optional<tensorflow::int64> cost_estimate_ns;
attrs["channel_name"] = {/*required=*/true, AttrTy::kString,
&channel_name};
attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64,
@@ -1028,16 +1032,16 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kDot: {
- optional<std::vector<int64>> lhs_contracting_dims;
+ optional<std::vector<tensorflow::int64>> lhs_contracting_dims;
attrs["lhs_contracting_dims"] = {
/*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
- optional<std::vector<int64>> rhs_contracting_dims;
+ optional<std::vector<tensorflow::int64>> rhs_contracting_dims;
attrs["rhs_contracting_dims"] = {
/*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
- optional<std::vector<int64>> lhs_batch_dims;
+ optional<std::vector<tensorflow::int64>> lhs_batch_dims;
attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&lhs_batch_dims};
- optional<std::vector<int64>> rhs_batch_dims;
+ optional<std::vector<tensorflow::int64>> rhs_batch_dims;
attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
&rhs_batch_dims};
@@ -1069,20 +1073,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kGather: {
- optional<std::vector<int64>> output_window_dims;
+ optional<std::vector<tensorflow::int64>> output_window_dims;
attrs["output_window_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &output_window_dims};
- optional<std::vector<int64>> elided_window_dims;
+ optional<std::vector<tensorflow::int64>> elided_window_dims;
attrs["elided_window_dims"] = {
/*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims};
- optional<std::vector<int64>> gather_dims_to_operand_dims;
+ optional<std::vector<tensorflow::int64>> gather_dims_to_operand_dims;
attrs["gather_dims_to_operand_dims"] = {/*required=*/true,
AttrTy::kBracedInt64List,
&gather_dims_to_operand_dims};
- optional<int64> index_vector_dim;
+ optional<tensorflow::int64> index_vector_dim;
attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64,
&index_vector_dim};
- optional<std::vector<int64>> window_bounds;
+ optional<std::vector<tensorflow::int64>> window_bounds;
attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List,
&window_bounds};
@@ -1178,8 +1182,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
LocTy loc = lexer_.GetLoc();
bool maximal = false;
bool replicated = false;
- std::vector<int64> devices;
- std::vector<int64> tile_assignment_dimensions;
+ std::vector<tensorflow::int64> devices;
+ std::vector<tensorflow::int64> tile_assignment_dimensions;
Shape tile_shape;
while (lexer_.GetKind() != TokKind::kRbrace) {
switch (lexer_.GetKind()) {
@@ -1206,7 +1210,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
}
do {
- int64 dim;
+ tensorflow::int64 dim;
if (!ParseInt64(&dim)) {
return false;
}
@@ -1218,7 +1222,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
return false;
}
do {
- int64 device;
+ tensorflow::int64 device;
if (!ParseInt64(&device)) {
return false;
}
@@ -1277,10 +1281,10 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
}
sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
*sharding->mutable_tile_shape() = tile_shape;
- for (int64 dim : tile_assignment_dimensions) {
+ for (tensorflow::int64 dim : tile_assignment_dimensions) {
sharding->add_tile_assignment_dimensions(dim);
}
- for (int64 device : devices) {
+ for (tensorflow::int64 device : devices) {
sharding->add_tile_assignment_devices(device);
}
}
@@ -1315,40 +1319,50 @@ bool HloParser::ParseInstructionNames(
"expects '}' at the end of instruction name list");
}
-bool HloParser::SetValueInLiteral(int64 value, int64 linear_index,
+bool HloParser::SetValueInLiteral(tensorflow::int64 value,
+ tensorflow::int64 linear_index,
Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
case S8:
- return SetValueInLiteralHelper<int8>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::int8>(value, linear_index,
+ literal);
case S16:
- return SetValueInLiteralHelper<int16>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::int16>(value, linear_index,
+ literal);
case S32:
- return SetValueInLiteralHelper<int32>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::int32>(value, linear_index,
+ literal);
case S64:
- return SetValueInLiteralHelper<int64>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::int64>(value, linear_index,
+ literal);
case U8:
- return SetValueInLiteralHelper<uint8>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::uint8>(value, linear_index,
+ literal);
case U16:
- return SetValueInLiteralHelper<uint8>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::uint8>(value, linear_index,
+ literal);
case U32:
- return SetValueInLiteralHelper<uint32>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::uint32>(value, linear_index,
+ literal);
case U64:
- return SetValueInLiteralHelper<uint64>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::uint64>(value, linear_index,
+ literal);
default:
LOG(FATAL) << "unknown integral primitive type "
<< PrimitiveType_Name(shape.element_type());
}
}
-bool HloParser::SetValueInLiteral(double value, int64 linear_index,
+bool HloParser::SetValueInLiteral(double value, tensorflow::int64 linear_index,
Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
case F16:
- return SetValueInLiteralHelper<half>(value, linear_index, literal);
+ return SetValueInLiteralHelper<Eigen::half>(value, linear_index, literal);
case BF16:
- return SetValueInLiteralHelper<bfloat16>(value, linear_index, literal);
+ return SetValueInLiteralHelper<tensorflow::bfloat16>(value, linear_index,
+ literal);
case F32:
return SetValueInLiteralHelper<float>(value, linear_index, literal);
case F64:
@@ -1359,7 +1373,7 @@ bool HloParser::SetValueInLiteral(double value, int64 linear_index,
}
}
-bool HloParser::SetValueInLiteral(bool value, int64 linear_index,
+bool HloParser::SetValueInLiteral(bool value, tensorflow::int64 linear_index,
Literal* literal) {
const Shape& shape = literal->shape();
switch (shape.element_type()) {
@@ -1372,7 +1386,8 @@ bool HloParser::SetValueInLiteral(bool value, int64 linear_index,
}
template <typename LiteralNativeT, typename ParsedElemT>
-bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index,
+bool HloParser::SetValueInLiteralHelper(ParsedElemT value,
+ tensorflow::int64 linear_index,
Literal* literal) {
// Check that linear_index is in range.
if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) {
@@ -1484,7 +1499,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
- const int64 rank = ShapeUtil::Rank(shape);
+ const tensorflow::int64 rank = ShapeUtil::Rank(shape);
if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
return false;
}
@@ -1492,8 +1507,8 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
// Create a literal with the given shape in default layout.
*literal = Literal::CreateFromDimensions(shape.element_type(),
AsInt64Slice(shape.dimensions()));
- int64 nest_level = 0;
- int64 linear_index = 0;
+ tensorflow::int64 nest_level = 0;
+ tensorflow::int64 linear_index = 0;
// elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
// the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
// when we are parsing the 2nd '{' (right before '1'), we are seeing a
@@ -1501,14 +1516,14 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
// the first '}' (right after '3'), it means the sub-array ends, and the
// sub-array is supposed to contain exactly 3 elements, so check if
// elems_seen_per_dim[1] is 3.
- std::vector<int64> elems_seen_per_dim(rank);
+ std::vector<tensorflow::int64> elems_seen_per_dim(rank);
auto get_index_str = [&elems_seen_per_dim](int dim) -> string {
- std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
- elems_seen_per_dim.begin() + dim);
+ std::vector<tensorflow::int64> elems_seen_until_dim(
+ elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim);
return StrCat("[",
Join(elems_seen_until_dim, ",",
- [](string* out, const int64& num_elems) {
- tensorflow::strings::StrAppend(out, num_elems - 1);
+ [](string* out, const tensorflow::int64& num_elems) {
+ StrAppend(out, num_elems - 1);
}),
"]");
};
@@ -1584,7 +1599,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
lexer_.Lex();
} else if (primitive_util::IsIntegralType(shape.element_type())) {
LocTy loc = lexer_.GetLoc();
- int64 value;
+ tensorflow::int64 value;
if (!ParseInt64(&value)) {
return Error(loc, StrCat("expects integer for primitive type: ",
PrimitiveType_Name(shape.element_type())));
@@ -1624,29 +1639,29 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
switch (shape.element_type()) {
case PRED:
- return ParseSparseLiteralHelper<uint8>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::uint8>(literal, shape);
case S8:
- return ParseSparseLiteralHelper<int8>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::int8>(literal, shape);
case S16:
- return ParseSparseLiteralHelper<int16>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::int16>(literal, shape);
case S32:
- return ParseSparseLiteralHelper<int32>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::int32>(literal, shape);
case S64:
- return ParseSparseLiteralHelper<int64>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::int64>(literal, shape);
case U8:
- return ParseSparseLiteralHelper<uint8>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::uint8>(literal, shape);
case U16:
- return ParseSparseLiteralHelper<uint16>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::uint16>(literal, shape);
case U32:
- return ParseSparseLiteralHelper<uint32>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::uint32>(literal, shape);
case U64:
- return ParseSparseLiteralHelper<uint64>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::uint64>(literal, shape);
case F16:
- return ParseSparseLiteralHelper<half>(literal, shape);
+ return ParseSparseLiteralHelper<Eigen::half>(literal, shape);
case F32:
return ParseSparseLiteralHelper<float>(literal, shape);
case BF16:
- return ParseSparseLiteralHelper<bfloat16>(literal, shape);
+ return ParseSparseLiteralHelper<tensorflow::bfloat16>(literal, shape);
case F64:
return ParseSparseLiteralHelper<double>(literal, shape);
default:
@@ -1659,9 +1674,9 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
template <typename LiteralNativeT>
bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
const Shape& shape) {
- std::vector<int64> index;
+ std::vector<tensorflow::int64> index;
- int64 rank = ShapeUtil::Rank(shape);
+ tensorflow::int64 rank = ShapeUtil::Rank(shape);
*literal = MakeUnique<Literal>(shape);
@@ -1679,7 +1694,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
LocTy index_loc = lexer_.GetLoc();
index.clear();
if (lexer_.GetKind() == TokKind::kInt) {
- int64 single_index = lexer_.GetInt64Val();
+ tensorflow::int64 single_index = lexer_.GetInt64Val();
lexer_.Lex();
if (rank != 1) {
return Error(
@@ -1712,7 +1727,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
value = static_cast<LiteralNativeT>(lexer_.GetKind() == TokKind::kw_true);
lexer_.Lex();
} else if (primitive_util::IsIntegralType(shape.element_type())) {
- int64 value_s64;
+ tensorflow::int64 value_s64;
if (!ParseInt64(&value_s64)) {
return Error(value_loc,
StrCat("expects integer for primitive type: ",
@@ -1885,23 +1900,24 @@ bool HloParser::ParseAttributeHelper(
LocTy attr_loc = lexer_.GetLoc();
switch (attr_type) {
case AttrTy::kInt64: {
- int64 result;
+ tensorflow::int64 result;
if (!ParseInt64(&result)) {
return false;
}
- static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
+ static_cast<optional<tensorflow::int64>*>(attr_out_ptr)
+ ->emplace(result);
return true;
}
case AttrTy::kInt32: {
- int64 result;
+ tensorflow::int64 result;
if (!ParseInt64(&result)) {
return false;
}
- if (result != static_cast<int32>(result)) {
+ if (result != static_cast<tensorflow::int32>(result)) {
return Error(attr_loc, "value out of range for int32");
}
- static_cast<optional<int32>*>(attr_out_ptr)
- ->emplace(static_cast<int32>(result));
+ static_cast<optional<tensorflow::int32>*>(attr_out_ptr)
+ ->emplace(static_cast<tensorflow::int32>(result));
return true;
}
case AttrTy::kFloat: {
@@ -1977,12 +1993,12 @@ bool HloParser::ParseAttributeHelper(
return true;
}
case AttrTy::kBracedInt64List: {
- std::vector<int64> result;
+ std::vector<tensorflow::int64> result;
if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
&result)) {
return false;
}
- static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
+ static_cast<optional<std::vector<tensorflow::int64>>*>(attr_out_ptr)
->emplace(result);
return true;
}
@@ -2157,7 +2173,7 @@ bool HloParser::ParseConvolutionDimensionNumbers(
<< str;
}
- const int64 rank = lhs_rhs_out[0].length();
+ const tensorflow::int64 rank = lhs_rhs_out[0].length();
if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) {
return TokenError(
"convolution lhs, rhs, and output must have the same rank");
@@ -2271,7 +2287,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) {
if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
return false;
}
- std::vector<std::vector<int64>> ranges;
+ std::vector<std::vector<tensorflow::int64>> ranges;
if (lexer_.GetKind() == TokKind::kRbrace) {
// empty
return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
@@ -2305,7 +2321,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) {
// ::= int64_val (delim int64_val)*
bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
- std::vector<int64>* result) {
+ std::vector<tensorflow::int64>* result) {
if (!ParseToken(start, StrCat("expects an int64 list starting with ",
TokKindToString(start)))) {
return false;
@@ -2314,7 +2330,7 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
// empty
} else {
do {
- int64 i;
+ tensorflow::int64 i;
if (!ParseInt64(&i)) {
return false;
}
@@ -2431,7 +2447,8 @@ bool HloParser::ParseString(string* result) {
return true;
}
-bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
+bool HloParser::ParseDxD(const string& name,
+ std::vector<tensorflow::int64>* result) {
LocTy loc = lexer_.GetLoc();
if (!result->empty()) {
return Error(loc,
@@ -2439,7 +2456,7 @@ bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
}
// 1D
if (lexer_.GetKind() == TokKind::kInt) {
- int64 number;
+ tensorflow::int64 number;
if (!ParseInt64(&number)) {
return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str()));
}
@@ -2459,7 +2476,8 @@ bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
return TokenError("expects token type kInt or kDxD");
}
-bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
+bool HloParser::ParseWindowPad(
+ std::vector<std::vector<tensorflow::int64>>* pad) {
LocTy loc = lexer_.GetLoc();
if (!pad->empty()) {
return Error(loc, "sub-attribute 'pad=' already exists");
@@ -2470,7 +2488,7 @@ bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
string str = lexer_.GetStrVal();
std::vector<string> padding_str = Split(str, 'x');
for (int i = 0; i < padding_str.size(); i++) {
- std::vector<int64> low_high;
+ std::vector<tensorflow::int64> low_high;
if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) ||
low_high.size() != 2) {
return Error(loc,
@@ -2494,7 +2512,7 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) {
string str = lexer_.GetStrVal();
std::vector<string> padding_str = Split(str, 'x');
for (const auto& padding_dim_str : padding_str) {
- std::vector<int64> padding_dim;
+ std::vector<tensorflow::int64> padding_dim;
if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) ||
(padding_dim.size() != 2 && padding_dim.size() != 3)) {
return Error(loc,
@@ -2516,7 +2534,7 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) {
optional<string> op_type;
optional<string> op_name;
optional<string> source_file;
- optional<int32> source_line;
+ optional<tensorflow::int32> source_line;
attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
@@ -2603,7 +2621,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
return true;
}
-bool HloParser::ParseInt64(int64* result) {
+bool HloParser::ParseInt64(tensorflow::int64* result) {
VLOG(1) << "ParseInt64";
if (lexer_.GetKind() != TokKind::kInt) {
return TokenError("expects integer");
@@ -2726,8 +2744,8 @@ HloParser::ParseConvolutionDimensionNumbersOnly() {
} // namespace
-StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str,
- const HloModuleConfig& config) {
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(
+ tensorflow::StringPiece str, const HloModuleConfig& config) {
HloParser parser(str, config);
if (!parser.Run()) {
return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str());
@@ -2735,9 +2753,10 @@ StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str,
return parser.ConsumeHloModule();
}
-StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str) {
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(
+ tensorflow::StringPiece str) {
HloModuleConfig config;
- return Parse(str, config);
+ return ParseHloString(str, config);
}
StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str) {
@@ -2759,5 +2778,4 @@ StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
return parser.ParseConvolutionDimensionNumbersOnly();
}
-} // namespace tools
} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 902c45cebc..3f3a51215e 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -13,28 +13,31 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
-#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_lexer.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
-namespace tools {
+
+// For details about the syntax accepted by this parser, see
+// g3doc/hlo_parser.md.
// The api of the hlo parser. Given a string in the HloModule::ToString()
// format, parses the string and creates a HloModule with the given config.
-StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str,
- const HloModuleConfig& config);
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(
+ tensorflow::StringPiece str, const HloModuleConfig& config);
// The api of the hlo parser. Given a string in the HloModule::ToString()
// format, parses the string and creates a HloModule with default config.
-StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str);
+StatusOr<std::unique_ptr<HloModule>> ParseHloString(
+ tensorflow::StringPiece str);
// Parses the result of HloSharding::ToString(), e.g. "{replicated}".
StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
@@ -47,7 +50,10 @@ StatusOr<Window> ParseWindow(tensorflow::StringPiece str);
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers(
tensorflow::StringPiece str);
-} // namespace tools
+// ParseHloString sharding from str. str is supposed to contain the body of the
+// sharding, i.e. just the rhs of the "sharding={...}" attribute string.
+StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str);
+
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 3c5957b96a..9a18b4f845 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include <string>
#include "tensorflow/compiler/xla/window_util.h"
@@ -23,10 +23,10 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace xla {
-namespace tools {
+
namespace {
-using tensorflow::StringPiece;
+using ::tensorflow::StringPiece;
struct TestData {
string test_name;
@@ -901,12 +901,12 @@ class HloParserTest : public ::testing::Test,
<< "'" << s << "' does not contain '" << expected << "'";
}
- // Expects "ToString(Parse(string)) == string", that is, parses the string,
- // asserts that it succeeded, stringifies the parsed module, and checks that
- // the it equals the original string.
+ // Expects "ToString(ParseHloString(string)) == string", that is, parses the
+ // string, asserts that it succeeded, stringifies the parsed module, and
+ // checks that the it equals the original string.
void ExpectEqual() {
const string& original = GetParam().module_string;
- auto result = Parse(original);
+ auto result = ParseHloString(original);
TF_ASSERT_OK(result.status());
EXPECT_EQ(original, result.ValueOrDie()->ToString(
HloPrintOptions().set_print_large_constants(true)));
@@ -917,7 +917,7 @@ class HloParserShortTest : public HloParserTest {
protected:
void ExpectEqualShort() {
const string& original = GetParam().module_string;
- auto result = Parse(original);
+ auto result = ParseHloString(original);
TF_ASSERT_OK(result.status());
EXPECT_EQ(original,
result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable()));
@@ -938,13 +938,13 @@ INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest,
TEST_F(HloParserTest, Empty) {
const string original = "";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
}
TEST_F(HloParserTest, Garbage) {
const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
}
@@ -958,7 +958,7 @@ ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
}
@@ -970,7 +970,7 @@ ENTRY %blabla (x: g32[]) -> g32[] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
}
@@ -983,7 +983,7 @@ ENTRY %blabla (x: f32[]) -> pred[] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
}
@@ -994,7 +994,7 @@ ENTRY %blabla (x: f32[]) -> pred[] {
%eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y)
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
}
@@ -1009,7 +1009,7 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
TF_EXPECT_OK(result.status());
// Constant instructions have no name. The string will be parsed successfully
// but the constant names will not be exactly the same.
@@ -1020,7 +1020,7 @@ TEST_F(HloParserTest, ConfigurationField) {
ENTRY %configuration_test() -> s32[] {
%constant = s32[] constant(42), backend_config="foo bar"
})";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
TF_ASSERT_OK(result.status());
EXPECT_EQ("foo bar", result.ValueOrDie()
->entry_computation()
@@ -1036,7 +1036,7 @@ ENTRY %some_2 () -> f32[2] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"expects nested array in rank 1, but sees larger");
@@ -1050,7 +1050,7 @@ ENTRY %some_2x3 () -> f32[2,3] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"expects nested array in rank 2, but sees 1");
@@ -1064,7 +1064,7 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"expects 3 elements in the [0]th element");
@@ -1079,7 +1079,7 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
EXPECT_NE(Status::OK(), result.status());
ExpectHasSubstr(result.status().error_message(),
"is out of range for literal's primitive type F16");
@@ -1093,7 +1093,7 @@ ENTRY %ConstantWithExp.v4 () -> f32[] {
}
)";
- auto result = Parse(original);
+ auto result = ParseHloString(original);
TF_EXPECT_OK(result.status());
// The string will be parsed successfully but the output strings are not
// exactly the same, because "3e2" is parsed into value 300 and will be
@@ -1111,7 +1111,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
}
)";
- TF_EXPECT_OK(Parse(original).status());
+ TF_EXPECT_OK(ParseHloString(original).status());
}
TEST_F(HloParserTest, InvalidDimLabels) {
@@ -1127,17 +1127,18 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
)";
+ ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat(
+ prefix, ",dim_labels=00_01_10", suffix))
+ .status()
+ .error_message(),
+ "expects dim labels pattern");
+
ExpectHasSubstr(
- Parse(tensorflow::strings::StrCat(prefix, ",dim_labels=00_01_10", suffix))
+ ParseHloString(tensorflow::strings::StrCat(
+ prefix, ",dim_labels=010_1100->010", suffix))
.status()
.error_message(),
- "expects dim labels pattern");
-
- ExpectHasSubstr(Parse(tensorflow::strings::StrCat(
- prefix, ",dim_labels=010_1100->010", suffix))
- .status()
- .error_message(),
- "must have the same rank");
+ "must have the same rank");
}
TEST_F(HloParserTest, UnexpectedAttribute) {
@@ -1152,7 +1153,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
}
)";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"unexpected attribute \"calls\"");
}
@@ -1168,7 +1169,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
}
)";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"attribute channel_id is expected but not seen");
}
@@ -1184,7 +1185,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
}
)";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"'done' is not defined");
}
@@ -1197,7 +1198,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
}
)";
- TF_EXPECT_OK(Parse(original).status());
+ TF_EXPECT_OK(ParseHloString(original).status());
}
TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
@@ -1211,7 +1212,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
}
)";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"expects padding_low and padding_high separated by '_'");
}
@@ -1223,7 +1224,7 @@ ENTRY %test_comma.v4 () -> f32[] {
}
)";
- TF_EXPECT_OK(Parse(original).status());
+ TF_EXPECT_OK(ParseHloString(original).status());
}
TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) {
@@ -1233,7 +1234,7 @@ ENTRY %CustomCall () -> f32[1] {
%constant = f32[1]{0} constant({12345})
ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
})";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"Shape of computation CustomCall, f32[1], is not compatible "
"with that of its root instruction foo, f32[1,2,3]");
}
@@ -1252,7 +1253,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
})";
- auto module = Parse(original);
+ auto module = ParseHloString(original);
TF_ASSERT_OK(module.status());
auto program_layout = module.ValueOrDie()->host_entry_computation_layout();
ASSERT_EQ(program_layout.parameter_count(), 1);
@@ -1275,7 +1276,7 @@ c1 {
c2 {
const2 = f32[1]{0} constant({67890})
})";
- auto module = Parse(original);
+ auto module = ParseHloString(original);
TF_ASSERT_OK(module.status());
EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2");
}
@@ -1286,7 +1287,7 @@ ENTRY consts {
first = f32[1]{0} constant({12345})
last = f32[1]{0} constant({67890})
})";
- auto module = Parse(original);
+ auto module = ParseHloString(original);
TF_ASSERT_OK(module.status());
EXPECT_EQ(
module.ValueOrDie()->entry_computation()->root_instruction()->name(),
@@ -1301,7 +1302,7 @@ ENTRY c1 {
ENTRY c2 {
const2 = f32[1]{0} constant({67890})
})";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"expects only one ENTRY");
}
@@ -1311,7 +1312,7 @@ ENTRY consts {
ROOT const1 = f32[1]{0} constant({12345})
ROOT const2 = f32[1]{0} constant({12345})
})";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
"one computation should have only one ROOT");
}
@@ -1323,7 +1324,7 @@ comp {
comp {
const2 = f32[1]{0} constant({67890})
})";
- ExpectHasSubstr(Parse(original).status().error_message(),
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
R"(was parsing 2:1: error: computation previously defined here
comp {
^)");
@@ -1346,7 +1347,7 @@ ENTRY entry {
ROOT call1 = s32[] call(param), to_apply=tcallb
})";
ExpectHasSubstr(
- Parse(original).status().error_message(),
+ ParseHloString(original).status().error_message(),
"was parsing 8:39: error: instruction does not exist: aparam");
}
@@ -1371,5 +1372,4 @@ TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
}
} // namespace
-} // namespace tools
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index 31e13da0c0..e1f9d8efd4 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -22,9 +22,9 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -36,7 +36,7 @@ HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string,
const DebugOptions& debug_options) {
HloModuleConfig config;
config.set_debug_options(debug_options);
- return tools::Parse(hlo_string, config);
+ return ParseHloString(hlo_string, config);
}
namespace {
@@ -80,7 +80,7 @@ HloRunner::ReadModuleFromHloTextFile(const std::string& filename,
filename, &hlo_string));
HloModuleConfig config;
config.set_debug_options(debug_options);
- return tools::Parse(hlo_string, config);
+ return ParseHloString(hlo_string, config);
}
HloRunner::HloRunner(se::Platform* platform) {
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
index 0bc930f9ea..db7ef6f0d4 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -22,9 +22,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -158,7 +158,7 @@ ENTRY root {
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(module_str));
+ ParseHloString(module_str));
auto size_fn = [](const BufferValue& buffer) {
return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 94d1a3226b..ee7133689b 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -19,11 +19,11 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
@@ -311,10 +311,10 @@ TEST_F(HloShardingTest, OstreamTest) {
EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}");
}
-TEST_F(HloShardingTest, Parse) {
+TEST_F(HloShardingTest, ParseHloString) {
auto check = [](const HloSharding& sharding) {
TF_ASSERT_OK_AND_ASSIGN(auto parsed_sharding,
- tools::ParseSharding(sharding.ToString()));
+ ParseSharding(sharding.ToString()));
EXPECT_EQ(sharding, parsed_sharding);
};
check(HloSharding::Replicate());
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h
index 7928bee5c2..533429608b 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_token.h
+++ b/tensorflow/compiler/xla/service/hlo_token.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
-#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_
#include <string>
@@ -22,9 +22,11 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace xla {
-namespace tools {
// Defines different kinds of tokens in a hlo module string.
+//
+// You shouldn't need to use this directly unless you're using HloLexer
+// directly, and you probably don't need to do that. Use hlo_parser instead.
enum class TokKind {
// Markers
kEof,
@@ -72,7 +74,6 @@ enum class TokKind {
string TokKindToString(TokKind kind);
-} // namespace tools
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
index df109df787..21db233899 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
@@ -47,7 +47,7 @@ class InstructionFusionForTesting : public InstructionFusion {
};
TEST_F(InstructionFusionTest, FuseInstructions) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY entry_computation {
p0 = f32[4,3]{1,0} parameter(0)
@@ -67,7 +67,7 @@ TEST_F(InstructionFusionTest, FuseInstructions) {
}
TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
fused_computation {
p1 = f32[4,3] parameter(0)
@@ -90,7 +90,7 @@ TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) {
}
TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY entry_computation {
p0 = f32[4,3]{1,0} parameter(0)
@@ -195,7 +195,7 @@ static int Count(const HloModule& module, HloOpcode op) {
}
TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -220,7 +220,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
//
// p0 -> add -------------------------> sub
// \-> abs1 -> rng -> abs2 -/
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -251,7 +251,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
// p0 -> add -------------------------> sub
// \-> abs1 -> log -> abs2 -/
// \-> send
- module = tools::Parse(R"(
+ module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -282,7 +282,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
// \ \-> add2 -/
// \-> log -/
// \-> send
- module = tools::Parse(R"(
+ module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -314,7 +314,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) {
// \------> sub1
// log -/
// \-> send
- module = tools::Parse(R"(
+ module = ParseHloString(R"(
HloModule test_module
ENTRY OutputFusion {
p0 = f32[4,3]{1,0} parameter(0)
@@ -390,7 +390,7 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) {
TEST_F(InstructionFusionTest,
WideningConvertsAreAlwaysDuplicableIntoConsumers) {
- auto module = tools::Parse(R"(
+ auto module = ParseHloString(R"(
HloModule test_module
ENTRY Test {
p0 = f16[100] parameter(0)
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 7508013199..bf0448a676 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -29,13 +29,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
@@ -651,7 +651,7 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) {
}
)";
- auto module = tools::Parse(module_str).ValueOrDie();
+ auto module = ParseHloString(module_str).ValueOrDie();
module =
backend()
@@ -691,7 +691,7 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) {
}
)";
- auto module = tools::Parse(module_str).ValueOrDie();
+ auto module = ParseHloString(module_str).ValueOrDie();
ComputationLayout computation_layout(
module->entry_computation()->ComputeProgramShape());
Shape param_shape = ShapeUtil::MakeTupleShape(
diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
index 204e8c9920..fef3c132b0 100644
--- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc
+++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
@@ -29,7 +29,7 @@ TEST(PatternMatcherTest, AddOp) {
ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two)
}
)";
- TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr));
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
const HloInstruction* matched_inst;
HloInstruction* matched_operand;
@@ -182,7 +182,7 @@ TEST(PatternMatcherTest, FusionKind) {
p0 = f32[] parameter(0)
ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation
})";
- TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr));
+ TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr));
auto* root = hlo_module->entry_computation()->root_instruction();
EXPECT_TRUE(Match(
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index f73f1227aa..3139801ea3 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -27,12 +27,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
@@ -69,7 +69,7 @@ ENTRY entry_computation {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
FoldTranspose(module.get());
@@ -91,7 +91,7 @@ ENTRY entry_computation {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
TransposeFolding transpose_folding(
[](const HloInstruction& dot,
@@ -119,7 +119,7 @@ ENTRY entry_computation {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
TransposeFolding transpose_folding(
[](const HloInstruction& dot,
@@ -147,7 +147,7 @@ ENTRY entry_computation {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
FoldTranspose(module.get());
@@ -205,7 +205,7 @@ ENTRY entry_computation {
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
FoldTranspose(module.get());
const HloComputation* callee = module->GetComputationWithName("callee");
diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc
index 754fd8ef16..d33d5bb8f3 100644
--- a/tensorflow/compiler/xla/service/tuple_util_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_util_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace {
@@ -37,7 +37,7 @@ ENTRY entry {
)";
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
*entry_computation = module->entry_computation();
*param0 = (*entry_computation)->parameter_instruction(0);
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
index 0d2288d8ea..393e758038 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
@@ -55,7 +55,7 @@ ENTRY entry {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@@ -95,7 +95,7 @@ ENTRY entry {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@@ -136,7 +136,7 @@ ENTRY entry {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
@@ -184,7 +184,7 @@ ENTRY entry {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
WhileLoopConstantSinking{}.Run(module.get()));
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
index e1ec12192f..8831c513ee 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -16,9 +16,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc
index bcc545c61d..d79d329721 100644
--- a/tensorflow/compiler/xla/service/while_util_test.cc
+++ b/tensorflow/compiler/xla/service/while_util_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
@@ -50,7 +50,7 @@ ENTRY entry {
)";
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
*entry_computation = module->entry_computation();
*param0 = (*entry_computation)->parameter_instruction(0);
@@ -151,7 +151,7 @@ ENTRY main {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloComputation* while_body = module->GetComputationWithName("body");
@@ -190,7 +190,7 @@ ENTRY main {
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_string));
+ ParseHloString(hlo_string));
HloComputation* main = module->GetComputationWithName("main");
HloInstruction* while_instr = main->root_instruction();
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index a62d49e9c7..7f6bbe6f87 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -117,11 +117,11 @@ cc_library(
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend
"//tensorflow/compiler/xla/service:platform_util",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
@@ -138,8 +138,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_verifier",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@@ -697,8 +697,8 @@ xla_test(
"//tensorflow/compiler/xla:execution_options_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
],
)
@@ -1195,9 +1195,9 @@ xla_test(
],
deps = [
":client_library_test_base",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
@@ -1520,11 +1520,11 @@ xla_test(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/client/xla_client:xla_computation",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
index b159887765..c960b3c15f 100644
--- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
+++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
@@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
namespace xla {
namespace {
@@ -36,7 +36,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
p = f32[3] parameter(0)
ROOT crs = f32[3] cross-replica-sum(p)
})";
- auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie();
+ auto module =
+ ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal = Literal::CreateR1<float>({1, 2, 3});
EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()}));
}
@@ -49,7 +50,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
p1 = f32[2] parameter(1)
ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
})";
- auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie();
+ auto module =
+ ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = Literal::CreateR1<float>({1, 2, 3});
auto literal1 = Literal::CreateR1<float>({10, 20});
EXPECT_EQ(
@@ -68,7 +70,8 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
p1 = f32[2] constant({10, 20})
ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
})";
- auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie();
+ auto module =
+ ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
auto literal0 = Literal::CreateR1<float>({1, 2, 3});
auto literal1 = Literal::CreateR1<float>({10, 20});
EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}),
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index 4854c649c1..143ffbdeb4 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -14,12 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/execution_options_util.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
// NB! TODO(b/74360564): These tests do not test out of bounds behavior since
// that hasn't been specced yet.
@@ -41,7 +41,7 @@ class GatherOperationTest : public HloTestBase {
HloModuleConfig config;
config.set_debug_options(GetDebugOptionsForTest());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
- tools::Parse(hlo_text, config));
+ ParseHloString(hlo_text, config));
EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt));
}
};
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index 36e19e6507..08ed826c80 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -23,11 +23,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index da4cf4ae0c..c8a05c2e9e 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -15,10 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -67,7 +67,7 @@ HloModule& HloVerifiedTestBase::module() {
void HloVerifiedTestBase::ParseAndVerifyModule(
tensorflow::StringPiece hlo_text) {
CHECK(!module_) << "Called ParseModule when test already has a module.";
- TF_ASSERT_OK_AND_ASSIGN(module_, tools::Parse(hlo_text));
+ TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text));
VerifyModule();
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
index c0a2c0ca4c..9052b188ed 100644
--- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include <array>
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
-#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
@@ -73,7 +73,7 @@ ENTRY reduce.1 {
}
)";
- return tools::Parse(hlo_string);
+ return ParseHloString(hlo_string);
}
// TODO(b/72454718): XLA:GPU does not support executing code compiled without
diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD
deleted file mode 100644
index 76f35afd53..0000000000
--- a/tensorflow/compiler/xla/tools/parser/BUILD
+++ /dev/null
@@ -1,73 +0,0 @@
-# Build file for the Hlo parser.
-
-licenses(["notice"]) # Apache 2.0
-
-package(
- default_visibility = [":friends"],
-)
-
-package_group(
- name = "friends",
- includes = [
- "//tensorflow/compiler/xla:friends",
- ],
-)
-
-# Filegroup used to collect source files for dependency checking.
-filegroup(
- name = "c_srcs",
- data = glob([
- "**/*.cc",
- "**/*.h",
- ]),
-)
-
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-
-cc_library(
- name = "hlo_lexer",
- srcs = ["hlo_lexer.cc"],
- hdrs = [
- "hlo_lexer.h",
- "hlo_token.h",
- ],
- deps = [
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/core:lib",
- "//tensorflow/core:regexp_internal",
- ],
-)
-
-cc_library(
- name = "hlo_parser",
- srcs = ["hlo_parser.cc"],
- hdrs = ["hlo_parser.h"],
- deps = [
- ":hlo_lexer",
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
- ],
-)
-
-tf_cc_test(
- name = "hlo_parser_test",
- size = "small",
- srcs = ["hlo_parser_test.cc"],
- deps = [
- ":hlo_parser",
- "//tensorflow/compiler/xla:window_util",
- "//tensorflow/core:lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)