diff options
Diffstat (limited to 'tensorflow/compiler/xla/service')
352 files changed, 7815 insertions, 4785 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index a65bdebf51..4aef093b04 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -99,6 +99,7 @@ cc_library( ":bfloat16_support", ":hlo", ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", @@ -175,6 +176,9 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -226,6 +230,7 @@ cc_library( hdrs = ["hlo_evaluator.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_query", ":shape_inference", "//tensorflow/compiler/xla:literal", @@ -237,6 +242,11 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -263,6 +273,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -311,6 +322,10 @@ cc_library( "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -337,7 +352,7 @@ cc_library( deps = [ ":hlo", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -389,7 +404,8 @@ cc_library( "//tensorflow/compiler/xla:test", 
"//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -449,6 +465,9 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -517,6 +536,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -552,6 +572,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -574,6 +595,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -615,6 +638,9 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], alwayslink = 1, ) @@ -647,6 +673,9 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -669,6 +698,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -719,6 +749,9 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + 
"@com_google_absl//absl/strings:str_format", ], ) @@ -736,6 +769,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:ptr_util", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -766,6 +800,8 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -813,6 +849,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -831,6 +869,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -847,6 +887,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -864,6 +905,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -874,6 +917,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -908,6 +952,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -917,12 +963,14 @@ tf_cc_test( deps = [ ":buffer_liveness", ":hlo", + ":hlo_dataflow_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", 
"//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/memory", ], ) @@ -950,6 +998,9 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -977,6 +1028,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -996,6 +1048,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1031,6 +1085,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1049,6 +1104,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1059,12 +1115,15 @@ cc_library( deps = [ ":hlo", ":hlo_casting_utils", + ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:optional", ], ) @@ -1074,6 +1133,7 @@ cc_library( hdrs = ["hlo_module_group_util.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_module_group_metadata", ":hlo_reachability", "//tensorflow/compiler/xla:status", @@ -1082,6 +1142,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1101,6 +1163,7 @@ cc_library( 
"//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], ) @@ -1108,17 +1171,18 @@ tf_cc_test( name = "hlo_scheduling_test", srcs = ["hlo_scheduling_test.cc"], deps = [ - ":buffer_value", ":heap_simulator", ":hlo", + ":hlo_dce", ":hlo_ordering", + ":hlo_parser", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -1142,6 +1206,7 @@ cc_library( ":hlo_pass", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1167,6 +1232,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1181,6 +1247,9 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1198,6 +1267,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -1216,6 +1286,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", ], ) @@ -1231,6 +1302,7 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1245,6 +1317,7 @@ cc_library( ":while_util", "//tensorflow/compiler/xla:literal_util", 
"//tensorflow/compiler/xla:statusor", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1267,6 +1340,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1276,6 +1350,7 @@ cc_library( hdrs = ["algebraic_simplifier.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_creation_utils", ":hlo_pass", ":hlo_query", @@ -1289,6 +1364,10 @@ cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -1298,6 +1377,7 @@ tf_cc_test( deps = [ ":algebraic_simplifier", ":hlo", + ":hlo_casting_utils", ":hlo_matchers", ":hlo_pass", "//tensorflow/compiler/xla:literal", @@ -1312,6 +1392,8 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1323,8 +1405,7 @@ cc_library( ":hlo", ":hlo_creation_utils", ":hlo_pass", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1377,6 +1458,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1414,6 +1496,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1439,8 +1523,7 @@ cc_library( deps = [ ":hlo", ":hlo_evaluator", - "//tensorflow/compiler/xla:literal", - "//tensorflow/core:lib", + "@com_google_absl//absl/types:optional", ], ) @@ -1455,6 +1538,8 @@ cc_library( ":while_loop_analysis", 
"//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -1468,6 +1553,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/strings", ], ) @@ -1582,6 +1668,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -1602,6 +1689,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1635,6 +1723,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/memory", ], ) @@ -1654,6 +1743,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], alwayslink = True, # Contains per-platform computation placer registration ) @@ -1667,6 +1758,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -1744,6 +1837,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", ], ) @@ -1758,6 +1853,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1789,6 +1885,8 @@ tf_cc_binary( "//tensorflow/compiler/xla:util", 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1805,6 +1903,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -1820,6 +1919,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", ], ) @@ -1847,6 +1947,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", ], ) @@ -1864,6 +1965,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1882,6 +1985,9 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1923,6 +2029,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -1959,6 +2067,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -1979,6 +2088,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2016,6 +2126,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", ], ) @@ -2028,7 +2139,6 @@ cc_library( ":hlo_dataflow_analysis", ":logical_buffer", 
":logical_buffer_analysis", - "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -2036,6 +2146,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -2086,6 +2200,9 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -2108,6 +2225,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2175,7 +2293,10 @@ cc_library( ":hlo_pass", ":shape_inference", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -2212,13 +2333,16 @@ cc_library( ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", - ":tuple_simplifier", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -2258,6 +2382,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -2339,6 +2464,9 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", 
"//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -2376,6 +2504,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2392,6 +2521,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2402,6 +2532,7 @@ tf_cc_test( ":hlo", ":hlo_constant_folding", ":hlo_matchers", + ":hlo_parser", ":hlo_pass", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -2423,6 +2554,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2437,6 +2569,7 @@ cc_library( "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -2497,6 +2630,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/memory", ], ) @@ -2552,6 +2686,7 @@ cc_library( hdrs = ["elemental_ir_emitter.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_module_config", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -2560,11 +2695,14 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", + 
"@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:transform_utils", ], @@ -2596,10 +2734,11 @@ cc_library( ":computation_layout", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -2612,6 +2751,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2648,8 +2788,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -2683,6 +2823,9 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", ], alwayslink = 1, ) @@ -2699,6 +2842,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -2780,9 +2924,9 @@ cc_library( hdrs = ["stream_pool.h"], deps = [ "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -2880,6 +3024,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", + "@com_google_absl//absl/memory", ], ) @@ -2926,7 +3071,8 @@ cc_library( ":hlo_creation_utils", ":tuple_util", "//tensorflow/compiler/xla:literal_util", - "//tensorflow/core:lib", + 
"@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -2940,6 +3086,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2955,6 +3102,8 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -2982,6 +3131,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -3015,13 +3165,13 @@ cc_library( cc_library( name = "source_map_util", - srcs = ["source_map_util.cc"], + srcs = [], hdrs = ["source_map_util.h"], deps = [ ":executable", "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", ], ) @@ -3036,6 +3186,10 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:ptr_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -3067,8 +3221,11 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", ], ) @@ -3077,11 +3234,13 @@ tf_cc_test( size = "small", srcs = ["hlo_parser_test.cc"], deps = [ + ":hlo_matchers", ":hlo_parser", "//tensorflow/compiler/xla:window_util", "//tensorflow/core:lib", "//tensorflow/core:test", 
"//tensorflow/core:test_main", # fixdeps: keep + "@com_google_absl//absl/strings", ], ) @@ -3100,6 +3259,8 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index f7812d9661..19bb4da9a6 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -22,13 +22,19 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" @@ -41,7 +47,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -266,7 +271,7 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { StatusOr<HloInstruction*> OptimizeDotOfConcat(HloInstruction* dot); StatusOr<HloInstruction*> OptimizeDotOfConcatHelper( - const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); StatusOr<HloInstruction*> OptimizeDotOfGather(HloInstruction* dot); @@ -540,7 +545,7 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // If a literal is all the same element replace it with a scalar broadcast. if (ShapeUtil::ElementsIn(constant->shape()) > 1 && constant->literal().IsAllFirst()) { - std::unique_ptr<Literal> unique_scalar = MakeUnique<Literal>( + std::unique_ptr<Literal> unique_scalar = absl::make_unique<Literal>( LiteralUtil::GetFirstScalarLiteral(constant->literal())); HloInstruction* scalar = computation_->AddInstruction( HloInstruction::CreateConstant(std::move(unique_scalar))); @@ -827,18 +832,18 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcat( TF_ASSIGN_OR_RETURN( HloInstruction * optimized_lhs_concat, - OptimizeDotOfConcatHelper(dot->shape(), lhs, lhs_contracting_dim, rhs, + OptimizeDotOfConcatHelper(*dot, lhs, lhs_contracting_dim, rhs, rhs_contracting_dim, /*swapped=*/false)); if (optimized_lhs_concat) { return optimized_lhs_concat; } - return OptimizeDotOfConcatHelper(dot->shape(), rhs, rhs_contracting_dim, lhs, + return OptimizeDotOfConcatHelper(*dot, rhs, rhs_contracting_dim, lhs, lhs_contracting_dim, /*swapped=*/true); } StatusOr<HloInstruction*> 
AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( - const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + const HloInstruction& dot, HloInstruction* lhs, int64 lhs_contracting_dim, HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) { bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate && lhs->concatenate_dimension() == lhs_contracting_dim && @@ -937,11 +942,12 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( } auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot_shape, new_dot_lhs, new_dot_rhs, new_dot_dnums)); + dot.shape(), new_dot_lhs, new_dot_rhs, new_dot_dnums)); + new_dot->set_precision_config(dot.precision_config()); if (add_result) { add_result = computation_->AddInstruction(HloInstruction::CreateBinary( - dot_shape, HloOpcode::kAdd, add_result, new_dot)); + dot.shape(), HloOpcode::kAdd, add_result, new_dot)); } else { add_result = new_dot; } @@ -1040,6 +1046,7 @@ StatusOr<HloInstruction*> AlgebraicSimplifierVisitor::OptimizeDotOfGather( auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( memoized_shape, left_operand, right_operand, dnums)); + memoized_inst->set_precision_config(dot->precision_config()); // Get pair {start, 0} or {0, start}. HloInstruction* original_start_indices = lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); @@ -1137,6 +1144,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers)); + new_dot->set_precision_config(dot->precision_config()); return ReplaceWithNewInstruction( dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); } @@ -1232,7 +1240,7 @@ namespace { // return value = {1, 3} // // Precondition: input_dim_indices is sorted. 
-std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified( +absl::optional<std::vector<int64>> ReshapeLeavesDimensionsUnmodified( const HloInstruction* hlo, tensorflow::gtl::ArraySlice<int64> input_dim_indices) { CHECK_EQ(HloOpcode::kReshape, hlo->opcode()); @@ -1252,11 +1260,11 @@ std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified( } if (i >= unmodified_dims.size() || unmodified_dims[i].first != input_dim_index) { - return std::make_pair(false, std::vector<int64>()); + return absl::nullopt; } output_dim_indices.push_back(unmodified_dims[i].second); } - return std::make_pair(true, output_dim_indices); + return output_dim_indices; } // Returns true if the output of "instruction" is a permutation of the @@ -1385,6 +1393,15 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } + // broadcast(iota) -> iota. + if (operand->opcode() == HloOpcode::kIota) { + return ReplaceWithNewInstruction( + broadcast, + HloInstruction::CreateIota( + broadcast->shape(), + dims[Cast<HloIotaInstruction>(operand)->iota_dimension()])); + } + // Merge two consecutive broadcasts into a single one. if (operand->opcode() == HloOpcode::kBroadcast) { std::vector<int64> new_dimensions; @@ -1713,12 +1730,25 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) { auto opt_dims = ReshapeLeavesDimensionsUnmodified( reshape, reshape->operand(0)->dimensions()); - if (opt_dims.first) { + if (opt_dims.has_value()) { return ReplaceWithNewInstruction( reshape, HloInstruction::CreateBroadcast( reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0), - opt_dims.second)); + *opt_dims)); + } + } + + // reshape(iota) -> iota. 
+ if (operand->opcode() == HloOpcode::kIota) { + auto* iota = Cast<HloIotaInstruction>(operand); + auto opt_dims = + ReshapeLeavesDimensionsUnmodified(reshape, {iota->iota_dimension()}); + if (opt_dims.has_value()) { + CHECK_EQ(opt_dims->size(), 1); + return ReplaceWithNewInstruction( + reshape, + HloInstruction::CreateIota(reshape->shape(), opt_dims->front())); } } @@ -1752,8 +1782,8 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { } auto is_unstrided_slice = [](const HloInstruction* hlo) { - return c_all_of(hlo->slice_strides(), - [](int64 stride) { return stride == 1; }); + return absl::c_all_of(hlo->slice_strides(), + [](int64 stride) { return stride == 1; }); }; if (slice->operand(0)->opcode() == HloOpcode::kSlice && is_unstrided_slice(slice) && is_unstrided_slice(slice->operand(0))) { @@ -1930,7 +1960,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { // This should make fusion easier or use less memory bandwidth in the unfused // case. if (arg->opcode() == HloOpcode::kConcatenate && - c_linear_search(reduce->dimensions(), arg->concatenate_dimension())) { + absl::c_linear_search(reduce->dimensions(), + arg->concatenate_dimension())) { HloInstruction* old_reduce = nullptr; for (HloInstruction* operand : arg->operands()) { HloInstruction* new_reduce = computation_->AddInstruction( @@ -1983,9 +2014,9 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( VLOG(10) << "Considering folding Pad: " << pad->ToString() << "\ninto reduce-window: " << reduce_window->ToString() - << (convert != nullptr ? tensorflow::strings::StrCat( - "\nvia convert: ", convert->ToString()) - : ""); + << (convert != nullptr + ? absl::StrCat("\nvia convert: ", convert->ToString()) + : ""); // Do not fold interior padding into ReduceWindow since the backends do not // support it. 
@@ -2294,6 +2325,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto dot = computation_->AddInstruction(HloInstruction::CreateDot( dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); + dot->set_precision_config(convolution->precision_config()); + return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index c48196e861..b864c372fa 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -47,7 +47,7 @@ class AlgebraicSimplifier : public HloPassInterface { enable_dot_strength_reduction_(enable_dot_strength_reduction), enable_conv_simplification_(enable_conv_simplification) {} ~AlgebraicSimplifier() override = default; - tensorflow::StringPiece name() const override { return "algsimp"; } + absl::string_view name() const override { return "algsimp"; } // Run algebraic simplification on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 5837391d75..1900a05750 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -18,11 +18,15 @@ limitations under the License. 
#include <memory> #include <utility> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" @@ -34,13 +38,12 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" - -using ::testing::ElementsAre; namespace xla { namespace { +using ::testing::ElementsAre; + namespace op = xla::testing::opcode_matchers; AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { @@ -51,7 +54,12 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloVerifiedTestBase {}; +class AlgebraicSimplifierTest : public HloVerifiedTestBase { + public: + AlgebraicSimplifierTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { @@ -1820,6 +1828,105 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { op::Reshape(op::Broadcast(param))); } +TEST_F(AlgebraicSimplifierTest, IotaAndReshapeMerged) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction(HloInstruction::CreateIota( + ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 
1}), 2)); + Shape result_shape = ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}); + builder.AddInstruction(HloInstruction::CreateReshape(result_shape, iota)); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), result_shape)); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x1_3) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 1}), 1)); + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), iota)); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4_6x1x1x4) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4}), 2)); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), 
op::Iota()); + EXPECT_EQ(Cast<HloIotaInstruction>(computation->root_instruction()) + ->iota_dimension(), + 3); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_1_3x2x1_6x1x1x1) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 1}), 2)); + builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Iota()); + const int64 iota_dim = + Cast<HloIotaInstruction>(computation->root_instruction()) + ->iota_dimension(); + EXPECT_THAT(iota_dim, ::testing::AnyOf(1, 2, 3)); +} + +TEST_F(AlgebraicSimplifierTest, IotaAndReshape_4_3x2x4x2_6x8) { + HloComputation::Builder builder(TestName()); + auto iota = builder.AddInstruction( + HloInstruction::CreateIota(ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), 2)); + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6, 8}), iota)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Iota())); +} + TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { HloComputation::Builder builder(TestName()); HloInstruction* param = @@ -2037,7 +2144,7 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { // Builds a convolution from <options> and runs algebraic simplification on // the computation. 
Returns a string description of the result of // simplification. - auto build_and_simplify = [&options]() -> string { + auto build_and_simplify = [&]() -> string { HloComputation::Builder b(TestName()); Window window; @@ -2143,9 +2250,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { root->operand(0)->opcode() == HloOpcode::kDot) { auto lhs_shape = root->operand(0)->operand(0)->shape(); auto rhs_shape = root->operand(0)->operand(1)->shape(); - return tensorflow::strings::StrCat( - tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ", - tensorflow::str_util::Join(rhs_shape.dimensions(), "x")); + return absl::StrCat(absl::StrJoin(lhs_shape.dimensions(), "x"), " DOT ", + absl::StrJoin(rhs_shape.dimensions(), "x")); } return "UNEXPECTED CHANGE"; }; @@ -2648,6 +2754,47 @@ TEST_F(AlgebraicSimplifierTest, MergeBroadcasts2) { EXPECT_THAT(root->dimensions(), ElementsAre(1, 3)); } +// Test that a broadcast of an iota can be merged to one iota. +TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota) { + HloComputation::Builder builder(TestName()); + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* iota = + builder.AddInstruction(HloInstruction::CreateIota(r2f32, 1)); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 2, 2}); + builder.AddInstruction(HloInstruction::CreateBroadcast(r3f32, iota, {0, 2})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); + EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2); +} + +// Test that a broadcast of an iota can be merged to one iota. 
+TEST_F(AlgebraicSimplifierTest, MergeBroadcastAndIota2) { + HloComputation::Builder builder(TestName()); + Shape r3f32 = ShapeUtil::MakeShape(F32, {2, 5, 3}); + HloInstruction* iota = + builder.AddInstruction(HloInstruction::CreateIota(r3f32, 1)); + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 2, 5, 3}); + builder.AddInstruction( + HloInstruction::CreateBroadcast(r4f32, iota, {1, 2, 3})); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Iota()); + EXPECT_EQ(Cast<HloIotaInstruction>(root)->iota_dimension(), 2); +} + struct PadReduceWindowEffectiveBroadcastCase { std::vector<int64> input_spatials; std::vector<int64> symmetric_pad_spatials; @@ -2660,11 +2807,10 @@ struct PadReduceWindowEffectiveBroadcastCase { bool should_become_broadcast; string ToTestCaseName() const { - return tensorflow::strings::StrCat( - tensorflow::str_util::Join(input_spatials, ","), ";", - tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";", - tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a, - ";", should_become_broadcast); + return absl::StrCat(absl::StrJoin(input_spatials, ","), ";", + absl::StrJoin(symmetric_pad_spatials, ","), ";", + absl::StrJoin(reduce_window_spatials, ","), ";", + prepend_a, ";", should_become_broadcast); } }; @@ -2852,7 +2998,12 @@ struct DotOfConcatTestSpec { class DotOfConcatSimplificationTest : public HloVerifiedTestBase, - public ::testing::WithParamInterface<DotOfConcatTestSpec> {}; + public ::testing::WithParamInterface<DotOfConcatTestSpec> { + public: + DotOfConcatSimplificationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // Test that 
we transform // dot(const, concat(A, B, C)) @@ -3025,7 +3176,12 @@ struct DotOfGatherTestSpec { class DotOfGatherSimplificationTest : public HloVerifiedTestBase, - public ::testing::WithParamInterface<DotOfGatherTestSpec> {}; + public ::testing::WithParamInterface<DotOfGatherTestSpec> { + public: + DotOfGatherSimplificationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // input: dot(DS(ctA), ctB)) // where DS(ctA) = DS({M x K}, {s, 0}, {1, K}) and ctB = {K x N}. diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 51ebc4763b..1ed6142dce 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -17,15 +17,15 @@ limitations under the License. #include <utility> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -69,8 +69,7 @@ StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal( return InvalidArgument( "AllocationTracker for platform %s cannot register buffer from " "platform %s", - backend_->platform()->Name().c_str(), - shaped_buffer.platform()->Name().c_str()); + backend_->platform()->Name(), shaped_buffer.platform()->Name()); } } @@ -91,8 +90,9 @@ StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal( // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer // into a regular ShapedBuffer, which is 
stored in // handle_to_shaped_buffers_. - handle_to_shaped_buffers_[handle].emplace_back(MakeUnique<ShapedBuffer>( - ReleaseIfScopedShapedBuffer(std::move(shaped_buffer)))); + handle_to_shaped_buffers_[handle].emplace_back( + absl::make_unique<ShapedBuffer>( + ReleaseIfScopedShapedBuffer(std::move(shaped_buffer)))); } GlobalDataHandle result; @@ -124,7 +124,7 @@ Status AllocationTracker::Unregister(const GlobalDataHandle& data) { // "handle does not exist". auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { - return NotFound("no allocation record for global data handle: %lld", + return NotFound("no allocation record for global data handle: %d", data.handle()); } for (auto& shaped_buffer : it->second) { @@ -143,7 +143,7 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple( // the same for all buffers across replicas. const ShapedBuffer* shaped_buffer = replicated_buffers[0]; if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { - return InvalidArgument("global data handle %lld is not a tuple", + return InvalidArgument("global data handle %d is not a tuple", data.handle()); } // If the on-host representation is a tuple, then the on-device one should be @@ -200,14 +200,14 @@ StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::ResolveInternal( VLOG(2) << "resolve:" << data.handle(); auto it = handle_to_shaped_buffers_.find(data.handle()); if (it == handle_to_shaped_buffers_.end()) { - return NotFound("no allocation record for global data handle: %lld", + return NotFound("no allocation record for global data handle: %d", data.handle()); } std::vector<const ShapedBuffer*> replicated_buffers; for (const auto& shaped_buffer : it->second) { if (shaped_buffer == nullptr) { - return InvalidArgument( - "global data handle %lld was previously deallocated", data.handle()); + return InvalidArgument("global data handle %d was previously deallocated", + data.handle()); } 
replicated_buffers.push_back(shaped_buffer.get()); } diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index d12be3e007..a6889cb171 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -21,6 +21,7 @@ limitations under the License. #include <string> #include <utility> +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -127,8 +128,8 @@ Backend::Backend( } } // Create a memory allocator for the valid stream executors. - memory_allocator_ = - MakeUnique<StreamExecutorMemoryAllocator>(platform, stream_executors); + memory_allocator_ = absl::make_unique<StreamExecutorMemoryAllocator>( + platform, stream_executors); CHECK(!stream_executors_.empty()) << "Service found no devices for backend " << platform_->Name() << '.'; @@ -176,7 +177,7 @@ StatusOr<se::StreamExecutor*> Backend::stream_executor( } } return InvalidArgument("device %s not supported by XLA service", - device_name(device_ordinal).c_str()); + device_name(device_ordinal)); } StatusOr<bool> Backend::devices_equivalent(int device_ordinal_a, diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index 1bc3796fa4..4a6a78daf0 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -21,6 +21,7 @@ limitations under the License. #include <string> #include <vector> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -29,7 +30,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -130,7 +130,7 @@ class Backend { // Return a string identifier for the given device, eg: "GPU:3". string device_name(int device_ordinal) const { - return tensorflow::strings::StrCat(platform_->Name(), ":", device_ordinal); + return absl::StrCat(platform_->Name(), ":", device_ordinal); } // Returns true if the devices with the given ordinals are equivalent from diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index 2099916509..a16b85a0a5 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -63,6 +64,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, MakeDotHlo(new_lhs, new_rhs, new_dim_numbers)); + new_dot->set_precision_config(batch_dot->precision_config()); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, MakeReshapeHlo(batch_dot->shape(), new_dot)); @@ -76,7 +78,7 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( return true; } -tensorflow::StringPiece BatchDotSimplification::name() const { +absl::string_view BatchDotSimplification::name() const { return "batch-dot-simplification"; } @@ -84,10 +86,10 @@ StatusOr<bool> BatchDotSimplification::Run(HloModule* module) { bool changed = false; std::vector<HloInstruction*> dot_instrs; for (HloComputation* computation : module->MakeNonfusionComputations()) { - c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), - [](HloInstruction* instr) { - return instr->opcode() == HloOpcode::kDot; - }); + absl::c_copy_if(computation->instructions(), std::back_inserter(dot_instrs), + [](HloInstruction* instr) { + return instr->opcode() == HloOpcode::kDot; + }); } for (HloInstruction* dot_instr : dot_instrs) { TF_ASSIGN_OR_RETURN(bool elided_batch_dim_from_one, diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h index c0ca8d8eba..79d37f08d3 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.h +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -28,7 +28,7 @@ namespace xla { class BatchDotSimplification : public HloPassInterface { public: StatusOr<bool> Run(HloModule* module) override; - tensorflow::StringPiece name() const override; + absl::string_view name() 
const override; private: StatusOr<bool> ElideDegenerateBatchDimensionFromBatchDot( diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc index 38f1a5d3a6..b342acb025 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification_test.cc @@ -24,7 +24,12 @@ namespace { namespace op = xla::testing::opcode_matchers; -class BatchDotSimplificationTest : public HloVerifiedTestBase {}; +class BatchDotSimplificationTest : public HloVerifiedTestBase { + public: + BatchDotSimplificationTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; TEST_F(BatchDotSimplificationTest, ElideSingleDegenerateBatchDotDim_VectorVector) { diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index c4cd60c120..01931b2d02 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -20,6 +20,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -43,7 +43,7 @@ namespace xla { namespace { -using tensorflow::gtl::optional; +using absl::optional; // BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. 
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 7ae202c583..76e32174f3 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -36,7 +36,7 @@ class BatchNormExpander : public HloPassInterface { rewrite_inference_op_(rewrite_inference_op), rewrite_grad_op_(rewrite_grad_op) {} ~BatchNormExpander() = default; - tensorflow::StringPiece name() const override { return "batchnorm_expander"; } + absl::string_view name() const override { return "batchnorm_expander"; } // Run operation expander on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc index a725351462..aba0d9bb5b 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include <memory> #include <utility> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -32,7 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h index c939838709..5dcd31b83d 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h @@ -37,7 +37,7 @@ class BFloat16ConversionFolding : public HloPassInterface { : bfloat16_support_(bfloat16_support) {} ~BFloat16ConversionFolding() override = default; - tensorflow::StringPiece name() const override { return "bfloat16-fold"; } + absl::string_view name() const override { return "bfloat16-fold"; } // Run BF16 conversion folding on the given computation. Returns whether the // computation was changed. diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 7cf05ca443..6363a21c3b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -235,8 +235,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, - sum, /*replica_group_ids=*/{}, /*barrier=*/"", - /*all_reduce_id=*/tensorflow::gtl::nullopt)); + sum, /*replica_groups=*/{}, /*barrier=*/"", + /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc 
b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 16e99b5722..32573ed355 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -34,11 +35,6 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; - // Special handling for cross-replica-sum and sort which can have a tuple - // output. - Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleSort(HloInstruction* sort) override; - static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { BFloat16NormalizationVisitor visitor(computation, bfloat16_support); @@ -150,23 +146,6 @@ Status BFloat16NormalizationVisitor::ConvertCalledComputations( return Status::OK(); } -Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( - HloInstruction* crs) { - if (!ShapeUtil::IsTuple(crs->shape())) { - return HandleInstruction(crs); - } else { - return HandleMultipleOutputs(crs); - } -} - -Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) { - if (!ShapeUtil::IsTuple(sort->shape())) { - return HandleInstruction(sort); - } else { - return HandleMultipleOutputs(sort); - } -} - Status BFloat16NormalizationVisitor::HandleMultipleOutputs( HloInstruction* hlo) { std::vector<PrimitiveType> operand_types(hlo->operand_count()); @@ -380,6 +359,11 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kConditional) { return Status::OK(); } + 
if ((hlo->opcode() == HloOpcode::kSort || + hlo->opcode() == HloOpcode::kCrossReplicaSum) && + ShapeUtil::IsTuple(hlo->shape())) { + return HandleMultipleOutputs(hlo); + } return HandleInstruction(hlo); } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h index 2a60fe0af3..30b6346312 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.h +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h @@ -31,7 +31,7 @@ class BFloat16Normalization : public HloPassInterface { : bfloat16_support_(bfloat16_support) {} ~BFloat16Normalization() override = default; - tensorflow::StringPiece name() const override { return "bf16-normalization"; } + absl::string_view name() const override { return "bf16-normalization"; } // Run BF16 normalization on the given computation. Returns whether the // computation was changed. @@ -54,7 +54,7 @@ class BFloat16MixedPrecisionRemoval : public HloPassInterface { ~BFloat16MixedPrecisionRemoval() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "bf16-mixed-precision-removal"; } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index f9f1f64998..b08705d4c2 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -76,7 +76,8 @@ class BFloat16NormalizationTest : public HloTestBase { StatusOr<bool> result = normalization.Run(module); EXPECT_IS_OK(result.status()); - HloVerifier verifier(/*allow_mixed_precision=*/true); + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true); EXPECT_IS_OK(verifier.Run(module).status()); return result.ValueOrDie(); @@ -251,8 +252,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = 
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, - /*replica_group_ids=*/{}, /*barrier=*/"", - /*all_reduce_id=*/tensorflow::gtl::nullopt)); + /*replica_groups=*/{}, /*barrier=*/"", + /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 02b8cad089..1ee64971ab 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -64,9 +64,7 @@ class BFloat16Propagation : public HloPassInterface { ~BFloat16Propagation() override = default; - tensorflow::StringPiece name() const override { - return "bfloat16-propagation"; - } + absl::string_view name() const override { return "bfloat16-propagation"; } // Runs the pass on the given module. Returns whether the module was changed // (precision reductions were added). diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index cfd26fc778..b11f15ec7b 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -22,8 +22,10 @@ limitations under the License. #include <ostream> #include <utility> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -36,20 +38,15 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { +namespace { +using absl::StrAppend; +using absl::StrAppendFormat; using ::tensorflow::gtl::FlatMap; using ::tensorflow::gtl::FlatSet; -using ::tensorflow::strings::Appendf; using ::tensorflow::strings::HumanReadableNumBytes; -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrAppend; - -namespace { template <typename T> string ColocatedBufferSetsToString(const T& container, const char* title) { @@ -107,7 +104,7 @@ Status GatherComputationsByAllocationType( return InvalidArgument( "computation %s has conflicting allocation requirements (global " "and thread-local)", - computation->name().c_str()); + computation->name()); } if (is_thread_local) { @@ -130,7 +127,7 @@ Status GatherComputationsByAllocationType( return InvalidArgument( "computation %s cannot contain call/while op because it " "requires thread-local buffer allocations", - computation->name().c_str()); + computation->name()); } worklist.push_back(std::make_pair(subcomputation, false)); // Not thread local. @@ -147,9 +144,8 @@ Status GatherComputationsByAllocationType( true)); // Thread local. 
break; default: - return InternalError( - "Unexpected calling opcode: %s", - HloOpcodeString(instruction->opcode()).c_str()); + return InternalError("Unexpected calling opcode: %s", + HloOpcodeString(instruction->opcode())); } } } @@ -236,8 +232,8 @@ size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { } string BufferAllocation::Slice::ToString() const { - return tensorflow::strings::StrCat("{index:", index(), ", offset:", offset_, - ", size:", size_, "}"); + return absl::StrCat("{index:", index(), ", offset:", offset_, + ", size:", size_, "}"); } BufferAllocation::Slice BufferAllocation::GetSlice( @@ -298,7 +294,7 @@ BufferAllocationProto BufferAllocation::ToProto() const { string BufferAllocation::ToString() const { string output; - Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size()); + StrAppendFormat(&output, "allocation %d: %p, size %d", index_, this, size()); if (color().value() != 0) { StrAppend(&output, ", color ", color().value()); } @@ -330,11 +326,10 @@ string BufferAllocation::ToString() const { }); for (const LogicalBuffer* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); - StrAppend(&output, - tensorflow::strings::Printf( - " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(), - offset_size.offset, offset_size.size, - ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str())); + StrAppend(&output, absl::StrFormat( + " %s [%d,%d]: %s\n", buffer->ToString(), + offset_size.offset, offset_size.size, + ShapeUtil::HumanStringWithLayout(buffer->shape()))); } return output; } @@ -427,7 +422,7 @@ StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueSlice( return FailedPrecondition( "BufferAllocation::Slice for instruction %s at index %s cannot " "be determined at compile-time.", - instruction->name().c_str(), index.ToString().c_str()); + instruction->name(), index.ToString()); } } else { VLOG(3) << "No allocation"; @@ -436,7 +431,7 @@ StatusOr<BufferAllocation::Slice> 
BufferAssignment::GetUniqueSlice( if (result.allocation() == nullptr) { return FailedPrecondition( "BufferAllocation::Slice not assigned for instruction %s at index %s", - instruction->name().c_str(), index.ToString().c_str()); + instruction->name(), index.ToString()); } return result; } @@ -627,7 +622,7 @@ Status BufferAssignment::ComputeSummaryStats() { stats_.total_allocation_bytes += allocation.size(); } - // Only compute total fragmentation if all computations are sequential. + // Only compute total fragmentation if all computations have schedules. SequentialHloOrdering::HloModuleSequence module_sequence; for (const auto& computation : module_->computations()) { const std::vector<const HloInstruction*>* sequence = @@ -648,39 +643,38 @@ Status BufferAssignment::ComputeSummaryStats() { string BufferAssignment::Stats::ToString() const { string s; - Appendf(&s, "BufferAssignment stats:\n"); - Appendf(&s, " parameter allocation: %10s\n", - HumanReadableNumBytes(parameter_allocation_bytes).c_str()); - Appendf(&s, " constant allocation: %10s\n", - HumanReadableNumBytes(constant_allocation_bytes).c_str()); - Appendf(&s, " maybe_live_out allocation: %10s\n", - HumanReadableNumBytes(maybe_live_out_allocation_bytes).c_str()); - Appendf(&s, " preallocated temp allocation: %10s\n", - HumanReadableNumBytes(preallocated_temp_allocation_bytes).c_str()); + StrAppendFormat(&s, "BufferAssignment stats:\n"); + StrAppendFormat(&s, " parameter allocation: %10s\n", + HumanReadableNumBytes(parameter_allocation_bytes)); + StrAppendFormat(&s, " constant allocation: %10s\n", + HumanReadableNumBytes(constant_allocation_bytes)); + StrAppendFormat(&s, " maybe_live_out allocation: %10s\n", + HumanReadableNumBytes(maybe_live_out_allocation_bytes)); + StrAppendFormat(&s, " preallocated temp allocation: %10s\n", + HumanReadableNumBytes(preallocated_temp_allocation_bytes)); if (preallocated_temp_fragmentation_bytes >= 0) { const double percent = 100. 
* preallocated_temp_fragmentation_bytes / preallocated_temp_allocation_bytes; - Appendf( + StrAppendFormat( &s, " preallocated temp fragmentation: %10s (%.2f%%)\n", - HumanReadableNumBytes(preallocated_temp_fragmentation_bytes).c_str(), - percent); + HumanReadableNumBytes(preallocated_temp_fragmentation_bytes), percent); } - Appendf(&s, " total allocation: %10s\n", - HumanReadableNumBytes(total_allocation_bytes).c_str()); + StrAppendFormat(&s, " total allocation: %10s\n", + HumanReadableNumBytes(total_allocation_bytes)); if (total_fragmentation_bytes >= 0) { const double percent = 100. * total_fragmentation_bytes / total_allocation_bytes; - Appendf(&s, " total fragmentation: %10s (%.2f%%)\n", - HumanReadableNumBytes(total_fragmentation_bytes).c_str(), percent); + StrAppendFormat(&s, " total fragmentation: %10s (%.2f%%)\n", + HumanReadableNumBytes(total_fragmentation_bytes), percent); } return s; } string BufferAssignment::ToString() const { string output; - tensorflow::strings::StrAppend(&output, "BufferAssignment:\n"); + absl::StrAppend(&output, "BufferAssignment:\n"); for (auto& allocation : allocations_) { - tensorflow::strings::StrAppend(&output, allocation.ToString()); + absl::StrAppend(&output, allocation.ToString()); } return output; } @@ -1100,8 +1094,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>( - MakeUnique<LazyBestFitHeap>(alignment)), + HeapSimulator::Run(absl::make_unique<DecreasingSizeRunsHeap>( + absl::make_unique<LazyBestFitHeap>(alignment)), assignment->module(), module_sequence, assignment->points_to_analysis(), assignment->buffer_size_, options)); @@ -1130,11 +1124,12 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - 
HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>( - MakeUnique<LazyBestFitHeap>(alignment)), - *computation, *instruction_sequence, - assignment->points_to_analysis(), - assignment->buffer_size_, options)); + HeapSimulator::Run( + absl::make_unique<DecreasingSizeRunsHeap>( + absl::make_unique<LazyBestFitHeap>(alignment)), + *computation, *instruction_sequence, + assignment->points_to_analysis(), assignment->buffer_size_, + options)); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1646,7 +1641,8 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment( XLA_VLOG_LINES(3, liveness->ToString()); XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString()); - // Can't use MakeUnique because BufferAssignment constructor is private. + // Can't use absl::make_unique because BufferAssignment constructor is + // private. std::unique_ptr<BufferAssignment> assignment( new BufferAssignment(module, std::move(liveness), std::move(buffer_size), std::move(color_alignment))); diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index eccb146a0d..52abda16c4 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include <utility> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" @@ -87,7 +87,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique<DependencyHloOrdering>(module), + module, absl::make_unique<DependencyHloOrdering>(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -98,7 +98,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr<BufferAssignment> RunBufferAssignmentNoBuffersForConstants( HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique<DependencyHloOrdering>(module), + module, absl::make_unique<DependencyHloOrdering>(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -109,7 +109,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr<BufferAssignment> RunColoredBufferAssignment( HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) { return BufferAssigner::Run( - module, xla::MakeUnique<DependencyHloOrdering>(module), + module, absl::make_unique<DependencyHloOrdering>(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -127,7 +127,8 @@ class BufferAssignmentTest : public HloTestBase { instruction_sequence.end()); return BufferAssigner::Run( module, - xla::MakeUnique<SequentialHloOrdering>(module, module_sequence), + 
absl::make_unique<SequentialHloOrdering>(module, + module_sequence), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -1769,7 +1770,8 @@ class WhileBufferAssignmentTest : public HloTestBase { auto sequence = ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, xla::MakeUnique<SequentialHloOrdering>(module, sequence), + module, + absl::make_unique<SequentialHloOrdering>(module, sequence), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -2083,7 +2085,7 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { auto assignment, BufferAssigner::Run( module.get(), - xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence), + absl::make_unique<SequentialHloOrdering>(module.get(), sequence), backend().compiler()->BufferSizeBytesFunction(), [](LogicalBuffer::Color) { return 1; }, /*allow_input_output_aliasing=*/false, @@ -2340,7 +2342,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto assignment = BufferAssigner::Run( module.get(), - xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence), + absl::make_unique<SequentialHloOrdering>(module.get(), sequence), ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true) diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 810d597e73..9b2783a214 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -20,6 +20,8 @@ limitations under the License. 
#include <utility> #include <vector> +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -28,8 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -75,27 +75,25 @@ Status BufferLiveness::Analyze() { string BufferLiveness::ToString() const { std::vector<string> pieces; - pieces.push_back(tensorflow::strings::Printf("BufferLiveness(module=%s):", - module_->name().c_str())); + pieces.push_back( + absl::StrFormat("BufferLiveness(module=%s):", module_->name())); pieces.push_back("HloOrdering:"); pieces.push_back(hlo_ordering_->ToString()); - pieces.push_back(tensorflow::strings::Printf("Aliased buffers:")); + pieces.push_back("Aliased buffers:"); for (const LogicalBuffer* buffer : aliased_buffers_) { - pieces.push_back( - tensorflow::strings::Printf(" %s", buffer->ToString().c_str())); + pieces.push_back(absl::StrFormat(" %s", buffer->ToString())); } - pieces.push_back(tensorflow::strings::Printf("Live out buffers:")); + pieces.push_back("Live out buffers:"); for (const LogicalBuffer* buffer : maybe_live_out_buffers_) { - pieces.push_back( - tensorflow::strings::Printf(" %s", buffer->ToString().c_str())); + pieces.push_back(absl::StrFormat(" %s", buffer->ToString())); } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, const LogicalBuffer& b) const { - TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a)); - TF_CHECK_OK(points_to_analysis_->VerifyBuffer(b)); + 
TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(a)); + TF_DCHECK_OK(points_to_analysis_->VerifyBuffer(b)); if (!hlo_ordering_->ExecutesBefore(a.instruction(), b.instruction())) { return false; diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 4a927b5767..26e26e316d 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -18,8 +18,9 @@ limitations under the License. #include <memory> #include <string> -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -119,8 +120,8 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique<DependencyHloOrdering>(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique<DependencyHloOrdering>(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); @@ -167,10 +168,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>( - module.get(), sequence)) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique<SequentialHloOrdering>( + module.get(), sequence)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. 
@@ -215,8 +216,8 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique<DependencyHloOrdering>(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique<DependencyHloOrdering>(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -249,8 +250,8 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique<DependencyHloOrdering>(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique<DependencyHloOrdering>(module.get())) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -293,10 +294,10 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { SequentialHloOrdering::HloModuleSequence module_sequence; std::vector<const HloInstruction*> order = {param, negate, exp, add}; module_sequence.emplace(computation, order); - auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique<SequentialHloOrdering>( + module.get(), module_sequence)) + .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -342,10 +343,10 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { std::vector<const HloInstruction*> order = {param, add, recv, recv_done, send, send_done}; module_sequence.emplace(computation, order); - auto liveness = - BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique<SequentialHloOrdering>( + module.get(), 
module_sequence)) + .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); // Check the root instruction (add) buffer interferes with the recv buffer. @@ -376,8 +377,8 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique<DependencyHloOrdering>(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique<DependencyHloOrdering>(module.get())) .ConsumeValueOrDie(); // All buffers should be live out except the param @@ -412,8 +413,8 @@ TEST_F(BufferLivenessTest, EmbeddedComputation) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique<DependencyHloOrdering>(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique<DependencyHloOrdering>(module.get())) .ConsumeValueOrDie(); // Buffers in different computations should always interfere. @@ -453,8 +454,8 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { module->AddEntryComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique<DependencyHloOrdering>(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique<DependencyHloOrdering>(module.get())) .ConsumeValueOrDie(); // Only the element buffers of the tuple constant which are pointed to by @@ -518,8 +519,8 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { module->AddEmbeddedComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique<DependencyHloOrdering>(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique<DependencyHloOrdering>(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -580,8 +581,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { module->AddEmbeddedComputation(builder.Build()); auto liveness = - BufferLiveness::Run(module.get(), - 
xla::MakeUnique<DependencyHloOrdering>(module.get())) + BufferLiveness::Run( + module.get(), absl::make_unique<DependencyHloOrdering>(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -610,11 +611,8 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { protected: // Builds and runs a computation (see test case computation graphs below). - // Runs BufferLiveness on this computation. - // Returns whether buffer interference is detected between tuple-shaped - // parameter and root instructions at tuple element 1. - bool Run(const bool update_uses_tuple_element1, - const bool fuse_gte0 = false) { + std::unique_ptr<HloModule> BuildModule(const bool update_uses_tuple_element1, + const bool fuse_gte0) { auto builder = HloComputation::Builder(TestName()); // Create param0 Tuple. Shape data_shape = ShapeUtil::MakeShape(F32, {8}); @@ -645,12 +643,12 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); // Create output tuple. - auto tuple_root = builder.AddInstruction( + builder.AddInstruction( HloInstruction::CreateTuple({gte0, dynamic_update_slice})); // Build module and get reference to entry computation. auto module = CreateNewModule(); - module->AddEntryComputation(BuildDummyComputation()); - auto* computation = module->AddEmbeddedComputation(builder.Build()); + module->AddEntryComputation(builder.Build()); + auto* computation = module->entry_computation(); // Create fusion instruction based on number of tuple element 1 users. 
if (update_uses_tuple_element1) { computation->CreateFusionInstruction( @@ -666,16 +664,39 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { computation->CreateFusionInstruction({gte0}, HloInstruction::FusionKind::kLoop); } + return module; + } + // Returns whether buffer interference is detected between tuple-shaped + // parameter and root instructions at tuple element 1. + bool Run(const bool update_uses_tuple_element1, + const bool fuse_gte0 = false) { + auto module = BuildModule(update_uses_tuple_element1, fuse_gte0); // Run BufferLiveness on 'module'. - auto liveness = - BufferLiveness::Run( - module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get())) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run( + module.get(), + absl::make_unique<DependencyHloOrdering>(module.get())) + .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. + auto tuple_param0 = FindInstruction(module.get(), "param0"); + auto tuple_root = module->entry_computation()->root_instruction(); return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); } + bool RunWithHloDataflowAnalysis(const bool update_uses_tuple_element1, + const bool fuse_gte0 = false) { + auto module = BuildModule(update_uses_tuple_element1, fuse_gte0); + // Run BufferLiveness on 'module'. + auto dataflow = HloDataflowAnalysis::Run(*module).ConsumeValueOrDie(); + auto hlo_ordering = absl::make_unique<DependencyHloOrdering>(module.get()); + // Return whether or not buffers interference is detected between + // 'tuple_param0' and 'tuple_root' at shape index '{1}'. 
+ auto tuple_param0 = FindInstruction(module.get(), "param0"); + auto tuple_root = module->entry_computation()->root_instruction(); + return hlo_ordering->MayInterfere( + dataflow->GetUniqueValueAt(tuple_param0, {1}), + dataflow->GetUniqueValueAt(tuple_root, {1}), *dataflow); + } }; // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion) @@ -693,6 +714,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { // TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) { EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false)); + EXPECT_FALSE( + RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false)); } // Tests that live ranges of buffers Param0[1] and Tuple[1] (which aliases @@ -712,6 +735,8 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterference) { // TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) { EXPECT_FALSE(Run(/*update_uses_tuple_element1=*/false, /*fuse_gte0=*/true)); + EXPECT_FALSE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/false, + /*fuse_gte0=*/true)); } // Tests that live ranges of buffers Param0[1] and Tuple[1] (which alias fusion) @@ -736,6 +761,7 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, NoInterferenceWithUnrelatedFusion) { // TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) { EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true)); + EXPECT_TRUE(RunWithHloDataflowAnalysis(/*update_uses_tuple_element1=*/true)); } class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { @@ -780,10 +806,10 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { module->AddEntryComputation(BuildDummyComputation()); module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. 
- auto liveness = - BufferLiveness::Run( - module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get())) - .ConsumeValueOrDie(); + auto liveness = BufferLiveness::Run( + module.get(), + absl::make_unique<DependencyHloOrdering>(module.get())) + .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}); diff --git a/tensorflow/compiler/xla/service/buffer_value.cc b/tensorflow/compiler/xla/service/buffer_value.cc index 2bc556a9e2..fdf822c666 100644 --- a/tensorflow/compiler/xla/service/buffer_value.cc +++ b/tensorflow/compiler/xla/service/buffer_value.cc @@ -17,11 +17,10 @@ limitations under the License. #include <iosfwd> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 985ff30e80..23b2a32709 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -17,21 +17,21 @@ limitations under the License. 
#include <queue> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/types.h" namespace xla { -using ::tensorflow::strings::Appendf; -using ::tensorflow::strings::StrCat; +using absl::StrAppendFormat; +using absl::StrCat; string CallContextToString(CallContext context) { switch (context) { @@ -71,10 +71,10 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { } string CallSite::ToString() const { - return StrCat(instruction()->name(), " calls in context ", - CallContextToString(context()), ": ", - tensorflow::str_util::Join( - called_computations(), ", ", + return StrCat( + instruction()->name(), " calls in context ", + CallContextToString(context()), ": ", + absl::StrJoin(called_computations(), ", ", [](string* out, const HloComputation* computation) { out->append(computation->name()); })); @@ -237,8 +237,8 @@ void CallGraph::SetCallContexts() { /* static */ std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) { - // Constructor for CallGraph is private so MakeUnique can't be used. - auto call_graph = WrapUnique<CallGraph>(new CallGraph(module)); + // Constructor for CallGraph is private so absl::make_unique can't be used. 
+ auto call_graph = absl::WrapUnique<CallGraph>(new CallGraph(module)); VLOG(2) << "Building call graph for:"; XLA_VLOG_LINES(2, module->ToString()); @@ -356,20 +356,20 @@ CallGraph::NearestAncestorsInSameComputation(HloInstruction* a, string CallGraph::ToString() const { string out; - Appendf(&out, "Call graph for module %s:\n", module_->name().c_str()); + StrAppendFormat(&out, "Call graph for module %s:\n", module_->name()); for (const CallGraphNode& node : nodes()) { - Appendf(&out, "Computation %s:\n", node.computation()->name().c_str()); - Appendf(&out, " calls:\n"); + StrAppendFormat(&out, "Computation %s:\n", node.computation()->name()); + StrAppendFormat(&out, " calls:\n"); for (const HloComputation* callee : node.callees()) { - Appendf(&out, " %s\n", callee->name().c_str()); + StrAppendFormat(&out, " %s\n", callee->name()); } - Appendf(&out, " called by:\n"); + StrAppendFormat(&out, " called by:\n"); for (const HloComputation* caller : node.callers()) { - Appendf(&out, " %s\n", caller->name().c_str()); + StrAppendFormat(&out, " %s\n", caller->name()); } - Appendf(&out, " callsites:\n"); + StrAppendFormat(&out, " callsites:\n"); for (const CallSite& callsite : node.callsites()) { - Appendf(&out, " %s\n", callsite.ToString().c_str()); + StrAppendFormat(&out, " %s\n", callsite.ToString()); } } return out; diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 97d3811508..3af2ab5edf 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -15,8 +15,8 @@ limitations under the License. // Call graph for an HLO module. 
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_ #include <ostream> @@ -272,4 +272,4 @@ class CallGraph { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CALL_GRAPH_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_GRAPH_H_ diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 256d05a73e..1d42140444 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -96,7 +96,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { if (it == subcomputation_hlo_to_new_hlo_.end()) { return NotFound( "Could not find mapping from subcomputation HLO %s to a cloned HLO.", - subcomputation_hlo->ToString().c_str()); + subcomputation_hlo->ToString()); } return it->second; } diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h index a8345a394d..c5cd88b9ea 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_ #include <deque> @@ -35,11 +35,11 @@ class CallInliner : public HloPassInterface { static StatusOr<InlinedInstructionMap> Inline(HloInstruction* call); ~CallInliner() override = default; - tensorflow::StringPiece name() const override { return "CallInliner"; } + absl::string_view name() const override { return "CallInliner"; } StatusOr<bool> Run(HloModule* module) override; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE__CALL_INLINER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CALL_INLINER_H_ diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index ff968bca29..5d85a3f173 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include <memory> #include <utility> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -32,7 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace op = xla::testing::opcode_matchers; diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index 13008efed1..3c2d1ae6d8 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -15,14 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/channel_tracker.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -73,20 +73,20 @@ ChannelHandle ChannelTracker::AllocateHandle(ChannelHandle::ChannelType type) { Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { if (opaque_to_channel_.count(handle.handle()) == 0) { - return NotFound("channel handle not found: %lld", handle.handle()); + return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; if (channel.type == ChannelHandle::HOST_TO_DEVICE) { return FailedPrecondition( "host-to-device channels cannot be used with a Send operation; " - "channel handle: %lld", + "channel handle: %d", handle.handle()); } if (channel.has_sender) { return FailedPrecondition( "when registering send, passed a channel handle that is already used " - "by a 
sender: %lld", + "by a sender: %d", handle.handle()); } channel.has_sender = true; @@ -95,13 +95,13 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { if (opaque_to_channel_.count(handle.handle()) == 0) { - return NotFound("channel handle not found: %lld", handle.handle()); + return NotFound("channel handle not found: %d", handle.handle()); } Channel& channel = opaque_to_channel_[handle.handle()]; if (channel.type == ChannelHandle::DEVICE_TO_HOST) { return FailedPrecondition( "device-to-host channels cannot be used with a Recv operation; " - "channel handle: %lld", + "channel handle: %d", handle.handle()); } @@ -109,7 +109,7 @@ Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { if (channel.receiver_count >= 1) { return FailedPrecondition( "when registering recv, passed a channel handle that is already used " - "by a receiver: %lld", + "by a receiver: %d", handle.handle()); } channel.receiver_count += 1; diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 7426672a7a..3079695e96 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -19,6 +19,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -28,7 +29,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -76,9 +76,9 @@ CompileOnlyService::CompileAheadOfTime( if (!directory_path.empty()) { HloSnapshot hlo_snapshot; *hlo_snapshot.mutable_hlo()->mutable_hlo_module() = instance.computation; - string filename = tensorflow::strings::StrCat( - "computation_", instance.computation.id(), "__", - instance.computation.entry_computation_name()); + string filename = + absl::StrCat("computation_", instance.computation.id(), "__", + instance.computation.entry_computation_name()); const string& per_host_path = tensorflow::io::JoinPath( directory_path, tensorflow::port::Hostname()); diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 6b3b9820f0..687ecafe0c 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -101,7 +101,7 @@ Compiler::GetPlatformCompilers() { return NotFound( "could not find registered compiler for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } // And then we invoke the factory, placing the result into the mapping. diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc index cb61f3da39..af8f7f1027 100644 --- a/tensorflow/compiler/xla/service/computation_layout.cc +++ b/tensorflow/compiler/xla/service/computation_layout.cc @@ -17,9 +17,9 @@ limitations under the License. 
#include <algorithm> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -52,9 +52,8 @@ string ComputationLayout::ToString() const { for (auto& param_layout : parameter_layouts_) { params.push_back(param_layout.ToString()); } - return tensorflow::strings::StrCat("(", - tensorflow::str_util::Join(params, ", "), - ") => ", result_layout_.ToString()); + return absl::StrCat("(", absl::StrJoin(params, ", "), ") => ", + result_layout_.ToString()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index 187ce568cb..2210a8578a 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -19,8 +19,9 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -29,12 +30,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; namespace xla { @@ -60,8 +60,8 @@ DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { "computation_count=%d", proto.replica_count(), proto.computation_count()); } - auto assignment = MakeUnique<DeviceAssignment>(proto.replica_count(), - proto.computation_count()); + auto assignment = absl::make_unique<DeviceAssignment>( + proto.replica_count(), proto.computation_count()); for (int computation = 0; computation < proto.computation_count(); ++computation) { const auto& computation_device = proto.computation_devices(computation); @@ -132,7 +132,7 @@ StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices( return NotFound( "could not find registered computation placer for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } if (it->second.placer == nullptr) { @@ -156,7 +156,7 @@ ComputationPlacer::GetPlatformComputationPlacers() { } // namespace xla static std::unique_ptr<xla::ComputationPlacer> CreateComputationPlacer() { - return xla::MakeUnique<xla::ComputationPlacer>(); + return absl::make_unique<xla::ComputationPlacer>(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.cc b/tensorflow/compiler/xla/service/conditional_simplifier.cc index b7be3ba605..4ea3a13f28 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include <utility> #include <vector> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -28,8 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h index 063261e26d..3de50cbd7f 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.h +++ b/tensorflow/compiler/xla/service/conditional_simplifier.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CONDITIONAL_SIMPLIFIER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -27,9 +27,7 @@ namespace xla { // with their true or false computation as appropriate. 
class ConditionalSimplifier : public HloPassInterface { public: - tensorflow::StringPiece name() const override { - return "simplify-conditional"; - } + absl::string_view name() const override { return "simplify-conditional"; } StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc index c43a31b167..6c477da038 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/conditional_simplifier_test.cc @@ -39,6 +39,10 @@ namespace op = xla::testing::opcode_matchers; class ConditionalSimplifierTest : public HloVerifiedTestBase { public: + ConditionalSimplifierTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + // Makes a computation that contains a conditional with constant predicate. HloComputation* MakeConditional(HloModule* module); }; diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 45252fc1ee..9c81a86bbb 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -18,9 +18,9 @@ limitations under the License. 
#include <memory> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -214,7 +214,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { expanded_filter = add(HloInstruction::CreateConcatenate( expanded_filter_shape, concat_operands, input_feature_dim)); } - auto zero = add(HloInstruction::CreateConstant(MakeUnique<Literal>( + auto zero = add(HloInstruction::CreateConstant(absl::make_unique<Literal>( LiteralUtil::Zero(expanded_filter_shape.element_type())))); auto zero_filter = add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {})); @@ -224,6 +224,7 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { auto new_convolution = HloInstruction::CreateConvolve( convolution->shape(), convolution->mutable_operand(0), new_filter, convolution->window(), dim_numbers, /*feature_group_count=*/1); + new_convolution->set_precision_config(convolution->precision_config()); TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( convolution, std::move(new_convolution))); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h index f213cc8709..498894737f 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -16,10 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CONVOLUTION_FEATURE_GROUP_CONVERTER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -29,7 +29,7 @@ class ConvolutionFeatureGroupConverter : public HloPassInterface { public: ConvolutionFeatureGroupConverter() {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "convolution-feature-group-converter"; } diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 3e39c1bab1..1b7a7b36ea 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -31,18 +33,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { - -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace { +using absl::StrAppend; + bool IsEntryParameterValue(const HloValue& value) { const HloComputation* computation = value.defining_instruction()->parent(); return value.defining_instruction()->opcode() == HloOpcode::kParameter && @@ -381,7 +378,7 @@ class CopyRemover { } string ToString() const { - string out = StrCat("CopyRemover, module ", module_->name(), "\n"); + string out = absl::StrCat("CopyRemover, module ", module_->name(), "\n"); StrAppend(&out, " Buffer values, in dependency order:\n"); for (const HloBuffer& buffer : alias_analysis_.buffers()) { StrAppend(&out, " HloBuffer ", buffer.id(), ":\n"); @@ -863,16 +860,16 @@ class CopyRemover { for (const ValueNode* p = head; p != nullptr; p = Next(*p)) { values.push_back(p->value); } - return StrCat("{", - Join(values, ", ", - [](string* s, const HloValue* value) { - StrAppend(s, value->ToShortString()); - }), - "}"); + return absl::StrCat("{", + absl::StrJoin(values, ", ", + [](string* s, const HloValue* value) { + StrAppend(s, value->ToShortString()); + }), + "}"); } string ToString() const { - string out = StrCat("BufferValueTracker:\n"); + string out = absl::StrCat("BufferValueTracker:\n"); StrAppend(&out, " Def-use chains in each buffer:\n"); for (const ValueNode* head : value_lists_) { StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), @@ -880,10 +877,10 @@ class CopyRemover { const ValueNode* p = head; do { StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", - Join(p->uses, "; ", - [](string* s, const HloUse* use) { - StrAppend(s, use->ToString()); - }), + 
absl::StrJoin(p->uses, "; ", + [](string* s, const HloUse* use) { + StrAppend(s, use->ToString()); + }), "\n"); p = p->next; @@ -960,16 +957,11 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { return Status::OK(); } -// Add copies to address special constraints on the roots of computations not -// related to live range interference: -// -// (1) Entry computation root must be unambiguous and distinct. -// -// (2) Any computation called by a kCall instruction must have an -// unambiguous root. -// -// (3) Constants and parameters cannot be live out of the entry computation -// +Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) { + std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); + return AddSpecialCaseCopies(*call_graph, module); +} + Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, @@ -1065,15 +1057,6 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, for (HloInstruction* user : users) { TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); } - // Special case copies are not eligible for later copy elision passes. 
- indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) { - if (has_copy) { - HloInstruction* copy = *copies_added.mutable_element(index); - if (copy != nullptr) { - copy->SetCopyElisionAllowed(false); - } - } - }); if (instruction == instruction->parent()->root_instruction()) { instruction->parent()->set_root_instruction(deep_copy); } @@ -1081,10 +1064,10 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, return Status::OK(); } -Status CopyInsertion::VerifyNoLiveRangeInterference(HloModule* module) { +Status CopyInsertion::VerifyNoLiveRangeInterference(const HloOrdering& ordering, + HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, HloAliasAnalysis::Run(module, fusion_can_share_buffer_)); - DependencyHloOrdering ordering(module); TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering)); return Status::OK(); } @@ -1101,8 +1084,7 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering, std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy && - instruction->CopyElisionAllowed()) { + if (instruction->opcode() == HloOpcode::kCopy) { TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); } } @@ -1168,10 +1150,10 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) { TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); - TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + DependencyHloOrdering dep_ordering(module); + TF_DCHECK_OK(VerifyNoLiveRangeInterference(dep_ordering, module)); - DependencyHloOrdering ordering(module); - TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, module)); + TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(dep_ordering, module)); 
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); @@ -1179,7 +1161,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) { TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); TF_RETURN_IF_ERROR(dce.Run(module).status()); - TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + TF_DCHECK_OK( + VerifyNoLiveRangeInterference(DependencyHloOrdering(module), module)); MaybeDumpModule("after copy insertion", *module); diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 5ba64b78a3..d308f6bc84 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -45,7 +45,7 @@ namespace xla { // InstructionAliasSet::IsDistinct return true. class CopyInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } // fusion_can_share_buffer: backend specific function that decides whether a // fusion can share buffer with its operand. @@ -77,15 +77,29 @@ class CopyInsertion : public HloPassInterface { Status RemoveUnnecessaryCopies(const HloOrdering& ordering, HloModule* module); - private: - // Verifies that no HLO values have interfering live ranged assuming the - // ordering used by copy insertion. - Status VerifyNoLiveRangeInterference(HloModule* module); + // Add copies to address special constraints on the roots of computations not + // related to live range interference: + // + // (1) Entry computation root must be unambiguous and distinct. + // + // (2) Any computation called by a kCall instruction must have an + // unambiguous root. 
+ // + // (3) Constants and parameters cannot be live out of the entry computation + // + Status AddSpecialCaseCopies(HloModule* module); - Status AddCopiesToResolveInterference(HloModule* module); + // Verifies that no HLO values have interfering live ranges using the given + // ordering. + Status VerifyNoLiveRangeInterference(const HloOrdering& ordering, + HloModule* module); + private: + // Override which requires the caller to pass in a call graph. Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module); + Status AddCopiesToResolveInterference(HloModule* module); + // Backend specific function that decides whether a fusion can share buffer // with its operand. HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_; diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index fe1ef78533..4cd192873f 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -50,6 +50,7 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -85,6 +86,9 @@ cc_library( ":ir_emitter", ":parallel_task_assignment", ":simple_orc_jit", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ":target_machine_features", "//tensorflow/compiler/tf2xla:cpu_function_runtime", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", @@ -178,6 +182,7 @@ cc_library( ":runtime_single_threaded_conv2d", ":runtime_single_threaded_fft", ":runtime_single_threaded_matmul", + "@com_google_absl//absl/memory", "@llvm//:execution_engine", "@llvm//:core", "@llvm//:mc", # fixdeps: keep @@ -229,6 +234,8 @@ cc_library( "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", 
"//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@llvm//:orc_jit", ], ) @@ -271,11 +278,14 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@llvm//:code_gen", "@llvm//:core", "@llvm//:support", @@ -320,6 +330,7 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -330,12 +341,12 @@ cc_library( hdrs = ["parallel_loop_emitter.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:core", ], ) @@ -362,6 +373,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -382,6 +394,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", ], ) @@ -395,6 +408,7 @@ cc_library( 
"//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:mc", "@llvm//:mc_disassembler", "@llvm//:object", @@ -418,6 +432,7 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", "@llvm//:analysis", "@llvm//:core", "@llvm//:ipo", @@ -634,6 +649,8 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//third_party/eigen3", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -648,6 +665,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -810,6 +828,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", ], ) @@ -846,6 +866,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -893,6 +914,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", "@llvm//:core", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 128eea4828..73b03440cb 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include <utility> #include <vector> +#include "absl/memory/memory.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -35,7 +36,6 @@ limitations under the License. #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -205,7 +205,7 @@ void CompilerFunctor::AddTargetInfoPasses( llvm::legacy::PassManagerBase* passes) const { llvm::Triple target_triple(target_machine_->getTargetTriple()); auto target_library_info_impl = - MakeUnique<llvm::TargetLibraryInfoImpl>(target_triple); + absl::make_unique<llvm::TargetLibraryInfoImpl>(target_triple); target_library_info_impl->addVectorizableFunctions( VectorFunctionsForTargetLibraryInfoImpl()); passes->add( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 0985b9297f..098ce17a56 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -132,6 +132,7 @@ StatusOr<bool> ConvCanonicalization::Run(HloModule* module) { HloInstruction* new_conv = module->entry_computation()->AddInstruction( HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel, hlo->window(), new_dnums)); + new_conv->set_precision_config(hlo->precision_config()); // Reshape the output back to the shape of the original convolution. 
TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index e6fd1499ed..59437e88af 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -38,7 +38,7 @@ class ConvCanonicalization : public HloPassInterface { : target_machine_features_(*target_machine_features) {} ~ConvCanonicalization() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "convolution-canonicalization"; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index fde8fbd486..6420180b13 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -26,6 +26,8 @@ limitations under the License. // IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc" // IWYU pragma: no_include "llvm/Config/Targets.def.inc" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" @@ -42,7 +44,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/batch_dot_simplification.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" @@ -101,8 +102,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace cpu { @@ -235,15 +234,15 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { std::unordered_map<const HloInstruction*, int64>* hlo_to_profile_idx_; const std::unordered_map<const HloInstruction*, int64>& assigned_indices_; }; -} // namespace -Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, - llvm::TargetMachine* target_machine) { - LLVMTargetMachineFeatures target_machine_features(target_machine); +} // namespace - // Optimization pipeline. - HloPassPipeline pipeline("CPU"); - pipeline.AddInvariantChecker<HloVerifier>(); +Status CpuCompiler::RunHloPassesThroughLayoutAssn( + HloModule* module, bool /*is_aot_compile*/, + LLVMTargetMachineFeatures* target_machine_features) { + HloPassPipeline pipeline("HLO passes through layout assignment"); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass<CpuHloSupportChecker>(); ReducePrecisionInsertion::AddPasses( @@ -260,11 +259,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pipeline.AddPass<BatchDotSimplification>(); pipeline.AddPass<DotDecomposer>(); pipeline.AddPass<ConvolutionFeatureGroupConverter>(); - pipeline.AddPass<ConvCanonicalization>(&target_machine_features); + pipeline.AddPass<ConvCanonicalization>(target_machine_features); { auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification"); - pass.AddInvariantChecker<HloVerifier>(); + pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pass.AddPass<BatchNormExpander>( /*rewrite_training_op=*/true, @@ -291,10 +291,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, } 
pipeline.AddPass<IndexedArrayAnalysisPrinterPass>(); pipeline.AddPass<TransposeFolding>( - [&target_machine_features]( - const HloInstruction& dot, + [&](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { - return PotentiallyImplementedAsEigenDot(dot, target_machine_features) + return PotentiallyImplementedAsEigenDot(dot, *target_machine_features) ? candidate_operands : TransposeFolding::OperandIndices{}; }, @@ -309,12 +308,28 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass<CpuLayoutAssignment>( - module->mutable_entry_computation_layout(), &target_machine_features); + module->mutable_entry_computation_layout(), target_machine_features); + return pipeline.Run(module).status(); +} + +Status CpuCompiler::RunHloPassesAfterLayoutAssn( + HloModule* module, bool is_aot_compile, + LLVMTargetMachineFeatures* target_machine_features) { + HloPassPipeline pipeline("HLO passes after layout assignment"); + // After layout assignment, use a layout-sensitive verifier. + auto& after_layout_assn = + pipeline.AddPass<HloPassPipeline>("after layout assignment"); + after_layout_assn.AddInvariantChecker<HloVerifier>( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); + // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. 
{ auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>( - "after layout assignement"); + "simplification after layout assignment"); + pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); pass.AddPass<HloPassFix<AlgebraicSimplifier>>( /*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return true; }, @@ -322,7 +337,9 @@ pass.AddPass<HloDCE>(); pass.AddPass<HloCSE>(/*is_layout_sensitive=*/true); } + pipeline.AddPass<HloElementTypeConverter>(BF16, F32); + // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = module->config().intra_op_parallelism_threads() > 0 @@ -335,14 +352,14 @@ // binary size (and most AOT applications are single-threaded). // TODO(b/29630486) Support multi-threaded AOT. pipeline.AddPass<ParallelTaskAssigner>( - max_parallelism, ShapeSizeBytesFunction(), &target_machine_features); + max_parallelism, ShapeSizeBytesFunction(), target_machine_features); } - // Copy insertion should be performed immediately before IR emission to avoid - // inserting unnecessary copies (later pass adds an instruction which - // materializes the value) or missing a necessary copy (later pass removes an - // instruction which materializes a value). DCE must be run immediately before - // (and sometime after) copy insertion, to avoid dead code from interfering - // with the rewrites. + // Copy insertion should be performed immediately before IR emission to + // avoid inserting unnecessary copies (later pass adds an instruction which + // materializes the value) or missing a necessary copy (later pass removes + // an instruction which materializes a value). DCE must be run immediately + // before (and sometime after) copy insertion, to avoid dead code from + // interfering with the rewrites.
pipeline.AddPass<HloDCE>(); pipeline.AddPass<FlattenCallGraph>(); pipeline.AddPass<CpuCopyInsertion>(); @@ -350,6 +367,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, return pipeline.Run(module).status(); } +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, + llvm::TargetMachine* target_machine) { + LLVMTargetMachineFeatures target_machine_features(target_machine); + TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(module, is_aot_compile, + &target_machine_features)); + return RunHloPassesAfterLayoutAssn(module, is_aot_compile, + &target_machine_features); +} + namespace { // Align buffers to 16-byte boundaries. @@ -453,7 +479,7 @@ Status CreateHloProfilingArtifacts( computation_to_profile_idx, std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map, std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) { - *hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(module); + *hlo_profile_index_map = absl::make_unique<HloProfileIndexMap>(module); const HloComputation& entry_computation = *module.entry_computation(); TF_ASSIGN_OR_RETURN( @@ -520,11 +546,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( &pre_optimization_ir_hook, &post_optimization_ir_hook)); // Compile must be thread-safe so create a new LLVM context for the module. - auto llvm_context = xla::MakeUnique<llvm::LLVMContext>(); + auto llvm_context = absl::make_unique<llvm::LLVMContext>(); auto llvm_module = - xla::MakeUnique<llvm::Module>("__compute_module", *llvm_context); + absl::make_unique<llvm::Module>("__compute_module", *llvm_context); - auto jit = xla::MakeUnique<SimpleOrcJIT>( + auto jit = absl::make_unique<SimpleOrcJIT>( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), @@ -566,12 +592,12 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( // temporary buffers are required to run the computation. 
TF_ASSIGN_OR_RETURN( std::unique_ptr<BufferAssignment> assignment, - BufferAssigner::Run( - module.get(), - xla::MakeUnique<SequentialHloOrdering>(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module.get(), + absl::make_unique<SequentialHloOrdering>( + module.get(), module_sequence), + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -679,8 +705,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, const llvm::Target* target = llvm::TargetRegistry::lookupTarget(triple.getTriple(), error); if (target == nullptr) { - return InternalError("TargetRegistry::lookupTarget failed: %s", - error.c_str()); + return InternalError("TargetRegistry::lookupTarget failed: %s", error); } llvm::Reloc::Model reloc_model = llvm::Reloc::Static; @@ -716,7 +741,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name()); llvm::StringRef features = llvm_ir::AsStringRef(options.features()); llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config()); - std::unique_ptr<llvm::TargetMachine> target_machine = WrapUnique( + std::unique_ptr<llvm::TargetMachine> target_machine = absl::WrapUnique( target->createTargetMachine(triple.getTriple(), cpu_name, features, CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None, opt_level)); @@ -757,7 +782,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, std::unique_ptr<BufferAssignment> assignment, BufferAssigner::Run( module, - xla::MakeUnique<SequentialHloOrdering>(module, module_sequence), + 
absl::make_unique<SequentialHloOrdering>(module, module_sequence), BufferSizeBytesFunction(), memory_alignment, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); @@ -851,7 +876,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, assignment->GetUniqueTopLevelOutputSlice()); - results.emplace_back(MakeUnique<CpuAotCompilationResult>( + results.emplace_back(absl::make_unique<CpuAotCompilationResult>( std::move(object_file_data), std::move(buffer_infos), result_slice.index(), std::move(hlo_profile_printer_data))); } @@ -874,7 +899,7 @@ HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const { static bool InitModule() { xla::Compiler::RegisterCompilerFactory( stream_executor::host::kHostPlatformId, - []() { return xla::MakeUnique<xla::cpu::CpuCompiler>(); }); + []() { return absl::make_unique<xla::cpu::CpuCompiler>(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 04e1c48872..47b5edabff 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/tf2xla/cpu_function_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" @@ -157,6 +158,16 @@ class CpuCompiler : public LLVMCompiler { Status RunHloPasses(HloModule* module, bool is_aot_compile, llvm::TargetMachine* target_machine); + // Runs HLO passes up to and including layout assignment. 
+ Status RunHloPassesThroughLayoutAssn( + HloModule* module, bool /*is_aot_compile*/, + LLVMTargetMachineFeatures* target_machine_features); + + // Runs HLO passes after layout assignment. + Status RunHloPassesAfterLayoutAssn( + HloModule* module, bool is_aot_compile, + LLVMTargetMachineFeatures* target_machine_features); + TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h index 3313d1e6eb..d49f7d7cc2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -32,11 +32,11 @@ namespace xla { // (module-scoped). 
class CpuCopyInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } StatusOr<bool> Run(HloModule* module) override; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index c376864c3e..08773693fb 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -22,6 +22,9 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/computation_layout.h" @@ -35,9 +38,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" @@ -171,20 +171,18 @@ Status CpuExecutable::ExecuteComputeFunction( void* result_buffer = buffer_pointers[result_slice.index()]; if (VLOG_IS_ON(3)) { VLOG(3) << "Executing compute function:"; - VLOG(3) << tensorflow::strings::Printf( - " func(void* result, void* params[null], void* temps[%zu], " - "uint64 profile_counters[%zu])", + VLOG(3) << absl::StrFormat( + " func(void* result, void* params[null], void* temps[%u], " + "uint64 profile_counters[%u])", buffer_pointers.size(), profile_counters_size); - VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); + VLOG(3) << absl::StrFormat(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { - tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); + absl::StrAppend(out, absl::StrFormat("%p", p)); }; VLOG(3) << " params = nullptr"; - VLOG(3) << tensorflow::strings::Printf( - " temps = [%s]", - tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); - VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p", - profile_counters); + VLOG(3) << absl::StrFormat( + " temps = [%s]", absl::StrJoin(buffer_pointers, ", ", ptr_printer)); + VLOG(3) << absl::StrFormat(" profile_counters = %p", profile_counters); } compute_function_(result_buffer, run_options, nullptr, buffer_pointers.data(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc index 7bd4741a04..7fbe0fa157 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc 
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc @@ -34,9 +34,8 @@ StatusOr<bool> CpuHloSupportChecker::Run(HloModule* module) { return xla::Unimplemented( "CPU backend does not support HLO instruction %s with shape " "containing a sparse layout: %s", - instruction->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction->shape()) - .c_str()); + instruction->ToString(), + ShapeUtil::HumanStringWithLayout(instruction->shape())); } return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h index 2924b63659..6af724b2a5 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -28,9 +28,7 @@ class CpuHloSupportChecker : public HloPassInterface { CpuHloSupportChecker() = default; ~CpuHloSupportChecker() override = default; - tensorflow::StringPiece name() const override { - return "cpu_hlo_support_checker"; - } + absl::string_view name() const override { return "cpu_hlo_support_checker"; } // Note: always returns false (no instructions are ever modified by this // pass). 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index b40d264c03..7f867fa149 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -78,7 +78,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } if (!CanBeLoopFused(*producer)) { - VLOG(2) << "Producer is not fusile."; + VLOG(2) << "Producer is not fusible."; return false; } @@ -140,7 +140,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } if (CanBeLoopFused(*consumer)) { - VLOG(2) << "Fusing: consumer is elementwise or fusile."; + VLOG(2) << "Fusing: consumer is elementwise or fusible."; return true; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index e6130c7d76..28aaa28cdb 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include <algorithm> #include <set> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" @@ -566,7 +567,7 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { HloOpcode::kParameter, HloOpcode::kParameter}); } -TEST_F(OpcodeFusionTest, MessOfFusileNodes) { +TEST_F(OpcodeFusionTest, MessOfFusibleNodes) { auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); @@ -773,8 +774,8 @@ class GatherLoopFusionTest TEST_P(GatherLoopFusionTest, GatherLoopFusion) { const GatherLoopFusionTestSpec& spec = GetParam(); - string hlo_string = tensorflow::strings::StrCat( - "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text); + string hlo_string = absl::StrCat("HloModule ", spec.test_name, "\n\n", + spec.hlo_computation_text); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, ParseHloString(hlo_string)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index aa872d5ec9..bfecbd6e01 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -34,8 +34,8 @@ namespace cpu { // instruction stream. namespace { -using ::tensorflow::gtl::nullopt; -using ::tensorflow::gtl::optional; +using absl::nullopt; +using absl::optional; using ShouldMakeOperandColMajorCache = tensorflow::gtl::FlatMap<const HloInstruction*, bool>; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index 3ed7876715..b8ace57026 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -15,8 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace { @@ -45,17 +46,16 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config) { return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0; } -tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor( - const HloModuleConfig& config) { +absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config) { const auto& extra_options_map = config.debug_options().xla_backend_extra_options(); auto it = extra_options_map.find(kLlvmIrDotTilingFactor); int64 tiling_factor; if (it != extra_options_map.end() && - tensorflow::strings::safe_strto64(it->second, &tiling_factor)) { + absl::SimpleAtoi(it->second, &tiling_factor)) { return tiling_factor; } - return tensorflow::gtl::nullopt; + return absl::nullopt; } bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { @@ -64,38 +64,37 @@ bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; } -static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str, - tensorflow::StringPiece suffix) { +static absl::string_view RemoveSuffix(absl::string_view str, + absl::string_view suffix) { CHECK_GE(str.size(), suffix.size()); CHECK_EQ(str.substr(str.size() - suffix.size()), suffix); return str.substr(0, str.size() - suffix.size()); } -tensorflow::gtl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize( +absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize( const HloModuleConfig& config) { const auto& extra_options_map = config.debug_options().xla_backend_extra_options(); auto it = extra_options_map.find(kLlvmIrGemmTileSize); if (it == extra_options_map.end()) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } - std::vector<string> 
tile_components = - tensorflow::str_util::Split(it->second, ':'); + std::vector<string> tile_components = absl::StrSplit(it->second, ':'); CHECK_EQ(tile_components.size(), 3); int64 tile_size_m; int64 tile_size_k; int64 tile_size_n_in_vector_width; - CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m)); - CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k)); + CHECK(absl::SimpleAtoi(tile_components[0], &tile_size_m)); + CHECK(absl::SimpleAtoi(tile_components[1], &tile_size_k)); - tensorflow::StringPiece tile_size_n_in_vector_width_str = + absl::string_view tile_size_n_in_vector_width_str = RemoveSuffix(tile_components[2], "*vectwidth"); - CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str, - &tile_size_n_in_vector_width)); + CHECK(absl::SimpleAtoi(tile_size_n_in_vector_width_str, + &tile_size_n_in_vector_width)); return std::tuple<int64, int64, int64>(tile_size_m, tile_size_k, tile_size_n_in_vector_width); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 429b9e16cb..47c7eb13b6 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -27,9 +27,8 @@ namespace options { bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); -tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor( - const HloModuleConfig& config); -tensorflow::gtl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize( +absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config); +absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize( const HloModuleConfig& config); } // namespace options diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index 
2ac950e6d9..1ae3aa5711 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -19,16 +19,16 @@ limitations under the License. #include <string> #include <tuple> +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -46,7 +46,7 @@ std::unique_ptr<Array2D<float>> MaybeTransposeArray2D(const Array2D<T>& array, if (transpose) { std::swap(output_width, output_height); } - auto output = MakeUnique<Array2D<float>>(output_height, output_width); + auto output = absl::make_unique<Array2D<float>>(output_height, output_width); for (int y = 0; y < array.height(); y++) { for (int x = 0; x < array.width(); x++) { if (transpose) { @@ -93,7 +93,7 @@ std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a, // Since we're going to transpose c before returning it. Swap the order of the // dimension sizes to ensure the returned array is properly dimensioned. 
- auto c_transpose = MakeUnique<Array2D<float>>(n, m); + auto c_transpose = absl::make_unique<Array2D<float>>(n, m); if (single_threaded) { __xla_cpu_runtime_EigenSingleThreadedMatMulF32( nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(), @@ -142,10 +142,10 @@ class EigenMatMulTest : public CpuRuntimeTest, bool transpose_rhs = std::get<2>(info.param); bool single_threaded = std::get<3>(info.param); - return tensorflow::strings::Printf( - "EigenMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, - transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", - single_threaded ? "single" : "multi"); + return absl::StrFormat("EigenMatMul_%d_%d_%d_%s%s%s_threaded", shape.m, + shape.k, shape.n, transpose_lhs ? "Tlhs_" : "", + transpose_rhs ? "Trhs_" : "", + single_threaded ? "single" : "multi"); } }; @@ -178,10 +178,10 @@ class MKLMatMulTest : public CpuRuntimeTest, bool transpose_rhs = std::get<2>(info.param); bool single_threaded = std::get<3>(info.param); - return tensorflow::strings::Printf( - "MKLMatMul_%lld_%lld_%lld_%s%s%s_threaded", shape.m, shape.k, shape.n, - transpose_lhs ? "Tlhs_" : "", transpose_rhs ? "Trhs_" : "", - single_threaded ? "single" : "multi"); + return absl::StrFormat("MKLMatMul_%d_%d_%d_%s%s%s_threaded", shape.m, + shape.k, shape.n, transpose_lhs ? "Tlhs_" : "", + transpose_rhs ? "Trhs_" : "", + single_threaded ? "single" : "multi"); } }; @@ -204,7 +204,7 @@ std::unique_ptr<Array2D<float>> MKLMatrixMultiply(const Array2D<float>& a, // Since we're going to transpose c before returning it, swap the order of the // dimension sizes to ensure the returned array is properly dimensioned. 
- auto c_transpose = MakeUnique<Array2D<float>>(n, m); + auto c_transpose = absl::make_unique<Array2D<float>>(n, m); if (single_threaded) { __xla_cpu_runtime_MKLSingleThreadedMatMulF32( nullptr, c_transpose->data(), a_transpose->data(), b_transpose->data(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 59bc7e0e16..0df2abf001 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" @@ -103,7 +104,7 @@ Status CpuTransferManager::TransferLiteralToInfeed( if (ShapeUtil::IsNestedTuple(shape)) { return Unimplemented( "Infeed with a nested tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); + ShapeUtil::HumanString(literal.shape())); } // For a tuple, we transfer each of its elements to the device and @@ -151,11 +152,11 @@ CpuTransferManager::TransferBufferToInfeedInternal(se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits<int32>::max()) { - return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); + return InvalidArgument("Infeed shape is too large: needs %d bytes", size); } if (size <= 0) { - return InvalidArgument("Infeed shape must have positive size; got %lld", + return InvalidArgument("Infeed shape must have positive size; got %d", size); } @@ -243,12 +244,12 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal( for (auto b : buffer_data) { int64 size = b.second; if (size > std::numeric_limits<int32>::max()) { - return InvalidArgument("Outfeed shape is too large: needs %lld bytes", + return 
InvalidArgument("Outfeed shape is too large: needs %d bytes", size); } if (size <= 0) { - return InvalidArgument("Outfeed shape must have positive size; got %lld", + return InvalidArgument("Outfeed shape must have positive size; got %d", size); } @@ -256,7 +257,7 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal( VLOG(2) << "Enqueueing outfeed buffer (for the device to populate) of length " << size_32 << "B"; - buffers.emplace_back(MakeUnique<CpuOutfeedBuffer>(b.first, size_32)); + buffers.emplace_back(absl::make_unique<CpuOutfeedBuffer>(b.first, size_32)); } std::vector<cpu::runtime::XfeedBuffer*> buffer_pointers; @@ -283,7 +284,7 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal( } // namespace xla static std::unique_ptr<xla::TransferManager> CreateCpuTransferManager() { - return xla::MakeUnique<xla::CpuTransferManager>(); + return absl::make_unique<xla::CpuTransferManager>(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h index 80ef953d53..7b938e9fd7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_ #include <vector> @@ -76,4 +76,4 @@ class CpuTransferManager : public GenericTransferManager { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc index e4c674e227..3ae64142cd 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.cc +++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc @@ -21,13 +21,13 @@ limitations under the License. #include <type_traits> #include <vector> +#include "absl/strings/str_format.h" #include "llvm/MC/MCInst.h" #include "llvm/Support/TargetRegistry.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -151,7 +151,7 @@ StatusOr<DisassemblerResult> Disassembler::DisassembleObjectFile( size = 1; } - ostream << tensorflow::strings::Printf("0x%08lx", index) << " "; + ostream << absl::StrFormat("0x%08lx", index) << " "; if (decode_status == llvm::MCDisassembler::Success) { // For branches, try to determine the actual address and emit it as an @@ -163,7 +163,7 @@ StatusOr<DisassemblerResult> Disassembler::DisassembleObjectFile( uint64_t target; if (inst_analysis_->evaluateBranch( instruction, section_address + index, size, target)) { - annotation = tensorflow::strings::Printf("[0x%08lx]", target); + 
annotation = absl::StrFormat("[0x%08lx]", target); } } inst_printer_->printInst(&instruction, ostream, annotation.c_str(), diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index f2ac742b6e..dd060f54a2 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -18,6 +18,7 @@ limitations under the License. #include <memory> #include <vector> +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -146,9 +147,9 @@ class GemvConfig { bool has_addend() const { return has_addend_; } string GetCacheKey() const { - return tensorflow::strings::StrCat( - name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_", - tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : ""); + return absl::StrCat(name_, "_", PrimitiveType_Name(scalar_type()), "_", + tile_rows(), "_", tile_cols(), "_", m(), "_", k(), + has_addend() ? "_with_addend" : ""); } protected: @@ -621,19 +622,19 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( } // This class implements a tiled matrix multiplication algorithm, intended for -// use as the innermost GEBP loop in a GEMM kernel (GEBP is described in "Goto, -// Kazushige, and Robert Van De Geijn. "High-performance implementation of the -// level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 (2008): -// 4). +// multiplying small matrices that don't need cache tiling. +// +// In the future this can be used as the innermost GEBP loop in a GEMM kernel as +// described in "Goto, Kazushige, and Robert A. Geijn. "Anatomy of +// high-performance matrix multiplication." ACM Transactions on Mathematical +// Software (TOMS) 34.3 (2008): 12.". // // This only supports canonical dot operations (i.e. 
where the lhs contraction // dimension is 1 and the rhs contraction dimension is 0) over row major // matrices. -class MatrixMatrixBlockPanelEmitter { +class TiledSmallGemmEmitter { public: - // Describe the dimensions of the GEBP kernel. These will usually not be the - // dimensions of the GEMM itself, the GEMM will usually be broken up into GEBP - // kernels with smaller dimensions. + // Describe the dimensions of the kernel. class Dimensions { public: explicit Dimensions(int64 m, int64 k, int64 n) : m_(m), k_(k), n_(n) {} @@ -642,9 +643,7 @@ class MatrixMatrixBlockPanelEmitter { int64 k() const { return k_; } int64 n() const { return n_; } - string ToString() const { - return tensorflow::strings::StrCat(m(), "x", k(), "x", n()); - } + string ToString() const { return absl::StrCat(m(), "x", k(), "x", n()); } private: const int64 m_; @@ -652,9 +651,9 @@ class MatrixMatrixBlockPanelEmitter { const int64 n_; }; - // Represents the configuration of the GEBP emitter. The LLVM IR emitted by - // the emitter, modulo the LLVM values holding the input and output buffers, - // must be a function of the instance of `Config` passed to it. + // Represents the configuration of the emitter. The LLVM IR emitted by the + // emitter, modulo the LLVM values holding the input and output buffers, must + // be a function of the instance of `Config` passed to it. // // `dims` holds the matrix multiplication dimensions. 
// @@ -687,10 +686,10 @@ class MatrixMatrixBlockPanelEmitter { tile_size_k_(tile_size_k) {} string GetCacheKey() const { - return tensorflow::strings::StrCat( - "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(), - "_", max_vectorization_width(), "_", min_vectorization_width(), "_", - tile_size_m(), "_", tile_size_k()); + return absl::StrCat("gemm_", PrimitiveType_Name(scalar_type()), "_", + dims().ToString(), "_", max_vectorization_width(), + "_", min_vectorization_width(), "_", tile_size_m(), + "_", tile_size_k()); } PrimitiveType scalar_type() const { return scalar_type_; } @@ -712,11 +711,11 @@ class MatrixMatrixBlockPanelEmitter { int64 tile_size_k_; }; - // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies + // Creates an instance of TiledSmallGemmEmitter that matrix-multiplies // `lhs` with `rhs` and stores the result in `result`. - explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* result, - llvm::IRBuilder<>* b) + explicit TiledSmallGemmEmitter(Config config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* b) : lhs_(lhs), rhs_(rhs), result_(result), @@ -780,9 +779,9 @@ class MatrixMatrixBlockPanelEmitter { KernelSupportLibrary ksl_; }; -void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); } +void TiledSmallGemmEmitter::Emit() { HandleResiduesOnN(); } -void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { +void TiledSmallGemmEmitter::HandleResiduesOnN() { // We can only iterate the `n` dimension for an extent that is divisible by // the vectorization width. 
So we emit an outer loop that first processes the // largest extent in `n` that is divisible by max_vectorization_width, then @@ -799,7 +798,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { int64 n_end = dims().n() - (dims().n() % current_vectorization_width); if (n_start != n_end) { VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, b_, - "gebp"); + "gemm"); HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); n_start = n_end; } @@ -813,7 +812,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { } if (n_start != dims().n()) { - VectorSupportLibrary vsl(scalar_type(), 1, b_, "gebp"); + VectorSupportLibrary vsl(scalar_type(), 1, b_, "gemm"); ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { llvm::Value* n_i_next = b_->CreateAdd(n_i, b_->getInt64(1)); HandleResiduesOnK(&vsl, n_i, n_i_next); @@ -821,9 +820,9 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { } } -void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, - llvm::Value* n_start, - llvm::Value* n_end) { +void TiledSmallGemmEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { int64 k_start = 0; int64 k_end = dims().k() - (dims().k() % tile_size_k()); if (k_end != k_start) { @@ -838,7 +837,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, } } -void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM( +void TiledSmallGemmEmitter::HandleResiduesOnM( VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { const int64 m_end = dims().m() - dims().m() % tile_size_m(); @@ -921,7 +920,7 @@ void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM( // +-------------------+-------------------+-------------------+--------- // | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... 
// +-------------------+-------------------+-------------------+--------- -void MatrixMatrixBlockPanelEmitter::EmitTiledGemm( +void TiledSmallGemmEmitter::EmitTiledGemm( VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { @@ -1001,12 +1000,22 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, return dot_emitter.Emit(); } -bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( +bool DotOpEmitter::EmitSmallGemmIfProfitable( const DotOpEmitter::MatMultDims& mat_mult_dims) { - if (!EnableExperimentalLlvmIrGemm() || ShouldUseMultiThreadedEigen()) { + if (ShouldUseMultiThreadedEigen()) { return false; } + if (!EnableExperimentalLlvmIrGemm()) { + // TODO(sanjoy): We should make these numbers micro-arch specific. + bool small_gemm = mat_mult_dims.k <= 128 && + ((mat_mult_dims.m <= 32 && mat_mult_dims.n <= 128) || + (mat_mult_dims.m <= 128 && mat_mult_dims.n <= 32)); + if (!small_gemm) { + return false; + } + } + if (mat_mult_dims.lhs_non_canonical || mat_mult_dims.rhs_non_canonical) { return false; } @@ -1054,15 +1063,15 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = GetGemmTileSize(); - MatrixMatrixBlockPanelEmitter::Config config( + TiledSmallGemmEmitter::Config config( /*scalar_type=*/primitive_type, - MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, + TiledSmallGemmEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, /*max_vectorization_width=*/max_target_vector_width, /*max_vector_count=*/tile_size_n_in_vector_width, /*min_vectorization_width=*/std::min<int64>(4, max_target_vector_width), /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); - VLOG(2) << "Emitting GEBP kernel in LLVM IR with config " + VLOG(2) << "Emitting GEMM kernel in LLVM IR with config " << config.GetCacheKey(); const bool enable_fast_math = @@ -1075,10 
+1084,10 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( /*optimize_for_size=*/optimize_for_size, b_, config.GetCacheKey(), lhs, rhs, target, [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { - MatrixMatrixBlockPanelEmitter gebp_emitter(config, /*lhs=*/lhs, - /*rhs=*/rhs, - /*result=*/target, b_); - gebp_emitter.Emit(); + TiledSmallGemmEmitter small_gemm_emitter(config, /*lhs=*/lhs, + /*rhs=*/rhs, + /*result=*/target, b_); + small_gemm_emitter.Emit(); }); return true; @@ -1136,7 +1145,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { } if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { - return EmitExperimentalGebpDotIfEnabled(mat_mult_dims); + return EmitSmallGemmIfProfitable(mat_mult_dims); } int64 tiling_factor = GetGemvTilingFactor(); @@ -1458,7 +1467,7 @@ Status DotOpEmitter::EmitCallToRuntime() { break; default: return Unimplemented("Invalid type %s for dot operation", - PrimitiveType_Name(type).c_str()); + PrimitiveType_Name(type)); } llvm::Type* float_ptr_type = float_type->getPointerTo(); @@ -1610,7 +1619,7 @@ bool PotentiallyImplementedAsEigenDot( // For vector-matrix dot products, it is always profitable to make the Rhs // column major. -tensorflow::gtl::optional<int64> ProfitableToMakeDotOperandColumnMajor( +absl::optional<int64> ProfitableToMakeDotOperandColumnMajor( const HloInstruction& hlo) { if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 && hlo.shape().dimensions(0) == 1) { diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index 590032fbe9..4c2041b556 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ +#include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -38,7 +38,7 @@ bool PotentiallyImplementedAsEigenDot( // Returns the index for an operand to `hlo` that should ideally be column // major. Returns nullopt if there is no such operand or if `hlo` is not a dot // or a fusion containing a dot. -tensorflow::gtl::optional<int64> ProfitableToMakeDotOperandColumnMajor( +absl::optional<int64> ProfitableToMakeDotOperandColumnMajor( const HloInstruction& hlo); // Returns true to indicate that we can generate a tiled LLVM IR implementation @@ -121,7 +121,7 @@ class DotOpEmitter { // of rank 2 as well). MatMultDims GetMatMultDims() const; - bool EmitExperimentalGebpDotIfEnabled(const MatMultDims& mat_mult_dims); + bool EmitSmallGemmIfProfitable(const MatMultDims& mat_mult_dims); // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector // registers. diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index db54454707..c8312d80bd 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -30,15 +30,16 @@ limitations under the License. 
namespace xla { namespace cpu { -StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2( - PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { +StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) { string function_name; bool cast_result_to_fp16 = false; switch (prim_type) { case F16: cast_result_to_fp16 = true; - lhs = b_->CreateFPCast(lhs, b_->getFloatTy()); - rhs = b_->CreateFPCast(rhs, b_->getFloatTy()); + lhs = FPCast(lhs, b_->getFloatTy()); + rhs = FPCast(rhs, b_->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "atan2f"; @@ -58,21 +59,21 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitAtan2( function->setDoesNotThrow(); function->setDoesNotAccessMemory(); // Create an instruction to call the function. - llvm::Value* result = b_->CreateCall(function, {lhs, rhs}); + llvm::Value* result = Call(function, {lhs, rhs}); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } -StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, + llvm::Value* value) { bool cast_result_to_fp16 = false; string function_name; switch (prim_type) { case F16: cast_result_to_fp16 = true; - value = b_->CreateFPCast(value, b_->getFloatTy()); + value = FPCast(value, b_->getFloatTy()); TF_FALLTHROUGH_INTENDED; case F32: function_name = "tanhf"; @@ -91,16 +92,16 @@ StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitTanh( function->setDoesNotThrow(); function->setDoesNotAccessMemory(); // Create an instruction to call the function. 
- llvm::Value* result = b_->CreateCall(function, value); + llvm::Value* result = Call(function, value); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const { + const HloToElementGeneratorMap& operand_to_generator) { if (hlo->opcode() == HloOpcode::kMap) { return [this, hlo, &operand_to_generator]( const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> { diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 76833e765d..e3fba9306b 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -36,13 +36,13 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const override; + const HloToElementGeneratorMap& operand_to_generator) override; protected: StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 6f433b4f30..460363e18f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -27,6 +27,8 @@ limitations under the License. 
#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/BasicBlock.h" @@ -67,8 +69,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -170,9 +170,9 @@ IrEmitter::~IrEmitter() {} Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { VLOG(2) << "HandleBitcast: " << bitcast->ToString(); emitted_value_[bitcast] = - b_.CreateBitCast(GetEmittedValueFor(bitcast->operand(0)), - IrShapeType(bitcast->shape())->getPointerTo(), - AsStringRef(IrName(bitcast))); + BitCast(GetEmittedValueFor(bitcast->operand(0)), + IrShapeType(bitcast->shape())->getPointerTo(), + AsStringRef(IrName(bitcast))); return Status::OK(); } @@ -230,9 +230,8 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { // Use the elemental emitter for array shapes. return DefaultAction(copy); } - return Unimplemented( - "unsupported operand type %s for copy instruction", - PrimitiveType_Name(copy->shape().element_type()).c_str()); + return Unimplemented("unsupported operand type %s for copy instruction", + PrimitiveType_Name(copy->shape().element_type())); } // Calculate the alignment of a buffer allocated for a given primitive type. 
@@ -389,7 +388,7 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, int64 length = ByteSizeOf(shape); if (length <= 0 || length > std::numeric_limits<int32>::max()) { return InvalidArgument( - "xfeed (infeed or outfeed) buffer length %lld is outside the valid " + "xfeed (infeed or outfeed) buffer length %d is outside the valid " "size range", length); } @@ -440,22 +439,22 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, // of size exactly 'length_32', and the runtime is responsible for // check-failing the process if there is a mismatch, versus passing us back a // buffer that we might overrun. - llvm::Value* acquired_pointer = b_.CreateCall( - acquire_func, - {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)}); + llvm::Value* acquired_pointer = + Call(acquire_func, + {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)}); if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. - b_.CreateMemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer, - /*SrcAlign=*/1, length_32); + MemCpy(program_buffer_address, /*DstAlign=*/1, acquired_pointer, + /*SrcAlign=*/1, length_32); } else { // Outfeed -- copy from the in-program address to the acquired buffer. 
- b_.CreateMemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address, - /*SrcAlign=*/1, length_32); + MemCpy(acquired_pointer, /*DstAlign=*/1, program_buffer_address, + /*SrcAlign=*/1, length_32); } - b_.CreateCall(release_func, {b_.getInt32(length_32), acquired_pointer, - shape_ptr, b_.getInt32(shape_length)}); + Call(release_func, {b_.getInt32(length_32), acquired_pointer, shape_ptr, + b_.getInt32(shape_length)}); return Status::OK(); } @@ -502,7 +501,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { llvm::Value* IrEmitter::EmitElementalMap( const HloMapInstruction& map_instr, tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands, - tensorflow::StringPiece name) { + absl::string_view name) { return EmitThreadLocalCall(*map_instr.to_apply(), elemental_operands, name); } @@ -519,8 +518,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), "reduce_window_accumulator_address", &b_, MinimumAlignmentForPrimitiveType(operand_element_type)); - b_.CreateStore(b_.CreateLoad(GetEmittedValueFor(reduce_window->operand(1))), - accumulator_address); + Store(Load(GetEmittedValueFor(reduce_window->operand(1))), + accumulator_address); llvm_ir::ForLoopNest loops(IrName(reduce_window, "inner"), &b_); std::vector<int64> window_size; @@ -537,22 +536,21 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm::Value* in_bounds_condition = nullptr; for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = - b_.CreateNSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_index[i] = - b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); + NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); + input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); // We need to check if 0 <= input_index[i] 
< bound, as otherwise we are in // the padding so that we can skip the computation. That is equivalent to // input_index[i] < bound as an *unsigned* comparison, since a negative // value will wrap to a large positive value. - llvm::Value* index_condition = b_.CreateICmpULT( - input_index[i], - b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + llvm::Value* index_condition = + ICmpULT(input_index[i], + b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); if (in_bounds_condition == nullptr) { in_bounds_condition = index_condition; } else { - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + in_bounds_condition = And(in_bounds_condition, index_condition); } } CHECK(in_bounds_condition != nullptr); @@ -565,12 +563,12 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( llvm_ir::IrArray input_array(GetIrArrayFor(operand)); llvm::Value* input_value = input_array.EmitReadArrayElement(input_index, &b_); llvm::Value* result = EmitThreadLocalCall( - *reduce_window->to_apply(), - {b_.CreateLoad(accumulator_address), input_value}, "reducer_function"); - b_.CreateStore(result, accumulator_address); + *reduce_window->to_apply(), {Load(accumulator_address), input_value}, + "reducer_function"); + Store(result, accumulator_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_address); + return Load(accumulator_address); } Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { @@ -647,7 +645,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { select_and_scatter, /*desc=*/IrName(select_and_scatter, "init"), [this, init_value](const llvm_ir::IrArray::Index& target_index) { llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - return b_.CreateLoad(init_value_addr); + return Load(init_value_addr); })); // Create a loop to iterate over the source array to scatter to the output. 
@@ -667,7 +665,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { b_.getInt64Ty(), b_.getInt32(rank), "selected_index_address", &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( b_.getInt1Ty(), "initialized_flag_address", &b_); - b_.CreateStore(b_.getInt1(false), initialized_flag_address); + Store(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "window"), &b_); @@ -685,15 +683,14 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::IrArray::Index operand_index(b_.getInt64Ty(), source_index.size()); llvm::Value* in_bounds_condition = b_.getTrue(); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = b_.CreateNSWMul( - source_index[i], b_.getInt64(window.dimensions(i).stride())); - operand_index[i] = - b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); - llvm::Value* index_condition = b_.CreateICmpULT( - operand_index[i], - b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + llvm::Value* strided_index = + NSWMul(source_index[i], b_.getInt64(window.dimensions(i).stride())); + operand_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), + b_.getInt64(window.dimensions(i).padding_low())); + llvm::Value* index_condition = + ICmpULT(operand_index[i], + b_.getInt64(ShapeUtil::GetDimension(operand->shape(), i))); + in_bounds_condition = And(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); @@ -703,7 +700,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); SetToFirstInsertPoint(if_in_bounds.true_block, &b_); llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - 
b_.CreateLoad(initialized_flag_address), "initialized", &b_); + Load(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. @@ -712,38 +709,37 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { [&](const llvm_ir::IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - b_.CreateStore(operand_index[i], selected_index_address_slot); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + Store(operand_index[i], selected_index_address_slot); } }; llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); - b_.CreateStore(operand_data, selected_value_address); + Store(operand_data, selected_value_address); save_operand_index(operand_index); - b_.CreateStore(b_.getInt1(true), initialized_flag_address); + Store(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to potentially // update the selected value and index with the currently visiting operand. SetToFirstInsertPoint(if_initialized.true_block, &b_); llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &b_); - llvm::Value* operand_element = b_.CreateLoad(operand_address); + llvm::Value* operand_element = Load(operand_address); llvm::Value* result = EmitThreadLocalCall( *select_and_scatter->select(), - {b_.CreateLoad(selected_value_address), operand_element}, - "select_function"); + {Load(selected_value_address), operand_element}, "select_function"); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. 
- llvm::Value* cond = b_.CreateICmpNE( + llvm::Value* cond = ICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); SetToFirstInsertPoint(if_select_lhs.false_block, &b_); - b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); + Store(Load(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to @@ -754,8 +750,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { llvm_ir::IrArray::Index selected_index(source_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(Load(selected_index_address_slot)); } llvm_ir::IrArray source_array(GetIrArrayFor(source)); llvm::Value* source_value = @@ -837,7 +833,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution( lhs_llvm_type, "convolution_sum_address", &b_, MinimumAlignmentForPrimitiveType(lhs_element_type)); llvm::Value* constant_zero = llvm::Constant::getNullValue(lhs_llvm_type); - b_.CreateStore(constant_zero, sum_address); + Store(constant_zero, sum_address); llvm_ir::ForLoopNest loops(IrName(convolution, "inner"), &b_); std::vector<llvm::Value*> kernel_spatial(num_spatial_dims); @@ -846,7 +842,7 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution( loops .AddLoop( 0, rhs->shape().dimensions(dnums.kernel_spatial_dimensions(i)), - tensorflow::strings::StrCat("k", i)) + absl::StrCat("k", i)) ->GetIndVarValue(); } llvm::Value* input_feature = @@ -864,11 +860,11 @@ StatusOr<llvm::Value*> 
IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm::Value* kernel_index, const WindowDimension& window_dim) { llvm::Value* strided_index = - b_.CreateNSWMul(output_index, b_.getInt64(window_dim.stride())); - llvm::Value* dilated_kernel_index = b_.CreateNSWMul( - kernel_index, b_.getInt64(window_dim.window_dilation())); - return b_.CreateNSWSub(b_.CreateNSWAdd(strided_index, dilated_kernel_index), - b_.getInt64(window_dim.padding_low())); + NSWMul(output_index, b_.getInt64(window_dim.stride())); + llvm::Value* dilated_kernel_index = + NSWMul(kernel_index, b_.getInt64(window_dim.window_dilation())); + return NSWSub(NSWAdd(strided_index, dilated_kernel_index), + b_.getInt64(window_dim.padding_low())); }; std::vector<llvm::Value*> input_spatial(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -885,9 +881,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution( // Also need to check that the input coordinates are not in one of the // holes created by base dilation. 
const auto not_in_hole = [&](llvm::Value* input_index, int64 base_dilation) { - llvm::Value* remainder = - b_.CreateSRem(input_index, b_.getInt64(base_dilation)); - return b_.CreateICmpEQ(remainder, b_.getInt64(0)); + llvm::Value* remainder = SRem(input_index, b_.getInt64(base_dilation)); + return ICmpEQ(remainder, b_.getInt64(0)); }; llvm::Value* in_bounds_condition = b_.getInt1(true); @@ -895,17 +890,17 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm::ConstantInt* input_bound = b_.getInt64(window_util::DilatedBound( lhs->shape().dimensions(dnums.input_spatial_dimensions(i)), window.dimensions(i).base_dilation())); - llvm::Value* dim_in_bound = b_.CreateICmpULT(input_spatial[i], input_bound); + llvm::Value* dim_in_bound = ICmpULT(input_spatial[i], input_bound); llvm::Value* dim_not_in_hole = not_in_hole(input_spatial[i], window.dimensions(i).base_dilation()); - llvm::Value* dim_ok = b_.CreateAnd(dim_in_bound, dim_not_in_hole); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, dim_ok); + llvm::Value* dim_ok = And(dim_in_bound, dim_not_in_hole); + in_bounds_condition = And(in_bounds_condition, dim_ok); } // Now we need to map the dilated base coordinates back to the actual // data indices on the lhs. const auto undilate = [&](llvm::Value* input_index, int64 base_dilation) { - return b_.CreateSDiv(input_index, b_.getInt64(base_dilation)); + return SDiv(input_index, b_.getInt64(base_dilation)); }; for (int i = 0; i < num_spatial_dims; ++i) { input_spatial[i] = @@ -930,8 +925,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution( for (int i = 0; i < num_spatial_dims; ++i) { kernel_index[dnums.kernel_spatial_dimensions(i)] = window.dimensions(i).window_reversal() - ? b_.CreateNSWSub(b_.getInt64(window.dimensions(i).size() - 1), - kernel_spatial[i]) + ? 
NSWSub(b_.getInt64(window.dimensions(i).size() - 1), + kernel_spatial[i]) : kernel_spatial[i]; } @@ -940,13 +935,13 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForConvolution( llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); llvm::Value* product = - b_.CreateFMul(input_array.EmitReadArrayElement(input_index, &b_), - kernel_array.EmitReadArrayElement(kernel_index, &b_)); - llvm::Value* sum = b_.CreateFAdd(b_.CreateLoad(sum_address), product); - b_.CreateStore(sum, sum_address); + FMul(input_array.EmitReadArrayElement(input_index, &b_), + kernel_array.EmitReadArrayElement(kernel_index, &b_)); + llvm::Value* sum = FAdd(Load(sum_address), product); + Store(sum, sum_address); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(sum_address); + return Load(sum_address); } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { @@ -1072,34 +1067,32 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { conv_func->setCallingConv(llvm::CallingConv::C); conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); - b_.CreateCall( - conv_func, - { - GetExecutableRunOptionsArgument(), - b_.CreateBitCast(GetEmittedValueFor(convolution), ir_ptr_type), - b_.CreateBitCast(lhs_address, ir_ptr_type), - b_.CreateBitCast(rhs_address, ir_ptr_type), - b_.getInt64(input_batch), - b_.getInt64(input_rows), - b_.getInt64(input_cols), - b_.getInt64(input_channels), - b_.getInt64(kernel_rows), - b_.getInt64(kernel_cols), - b_.getInt64(kernel_channels), - b_.getInt64(kernel_filters), - b_.getInt64(output_rows), - b_.getInt64(output_cols), - b_.getInt64(row_stride), - b_.getInt64(col_stride), - b_.getInt64(padding_top), - b_.getInt64(padding_bottom), - b_.getInt64(padding_left), - b_.getInt64(padding_right), - b_.getInt64(lhs_row_dilation), - b_.getInt64(lhs_col_dilation), - b_.getInt64(rhs_row_dilation), - b_.getInt64(rhs_col_dilation), - }); + Call(conv_func, { + GetExecutableRunOptionsArgument(), + 
BitCast(GetEmittedValueFor(convolution), ir_ptr_type), + BitCast(lhs_address, ir_ptr_type), + BitCast(rhs_address, ir_ptr_type), + b_.getInt64(input_batch), + b_.getInt64(input_rows), + b_.getInt64(input_cols), + b_.getInt64(input_channels), + b_.getInt64(kernel_rows), + b_.getInt64(kernel_cols), + b_.getInt64(kernel_channels), + b_.getInt64(kernel_filters), + b_.getInt64(output_rows), + b_.getInt64(output_cols), + b_.getInt64(row_stride), + b_.getInt64(col_stride), + b_.getInt64(padding_top), + b_.getInt64(padding_bottom), + b_.getInt64(padding_left), + b_.getInt64(padding_right), + b_.getInt64(lhs_row_dilation), + b_.getInt64(lhs_col_dilation), + b_.getInt64(rhs_row_dilation), + b_.getInt64(rhs_col_dilation), + }); return Status::OK(); } @@ -1159,15 +1152,14 @@ Status IrEmitter::HandleFft(HloInstruction* fft) { fft_func->setDoesNotThrow(); fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); const int fft_rank = fft_length.size(); - b_.CreateCall( - fft_func, - {GetExecutableRunOptionsArgument(), - b_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type), - b_.CreateBitCast(operand_address, int8_ptr_type), - b_.getInt32(fft->fft_type()), b_.getInt32(fft_rank), - b_.getInt64(input_batch), b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), - b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), - b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); + Call(fft_func, + {GetExecutableRunOptionsArgument(), + BitCast(GetEmittedValueFor(fft), int8_ptr_type), + BitCast(operand_address, int8_ptr_type), b_.getInt32(fft->fft_type()), + b_.getInt32(fft_rank), b_.getInt64(input_batch), + b_.getInt64(fft_rank > 0 ? fft_length[0] : 0), + b_.getInt64(fft_rank > 1 ? fft_length[1] : 0), + b_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); return Status::OK(); } @@ -1206,8 +1198,8 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { operand_ptrs.push_back(EmitTempBufferPointer(out_slice, operand_shape)); // TODO(b/63762267): Be more aggressive about specifying alignment. 
- b_.CreateMemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, - /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); + MemCpy(operand_ptrs.back(), /*DstAlign=*/1, in_ptr, + /*SrcAlign=*/1, ShapeUtil::ByteSizeOf(operand_shape)); } llvm_ir::EmitTuple(GetIrArrayFor(crs), operand_ptrs, &b_, module_); return Status::OK(); @@ -1466,19 +1458,19 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( accumulator_shard_type, "accumulator", &b_, 0)); } - llvm::Value* init_value_ssa = b_.CreateLoad(GetEmittedValueFor(init_value)); + llvm::Value* init_value_ssa = Load(GetEmittedValueFor(init_value)); for (llvm::Value* accumulator_shard : accumulator) { llvm::Value* initial_value; auto shard_type = accumulator_shard->getType()->getPointerElementType(); if (auto vector_type = llvm::dyn_cast<llvm::VectorType>(shard_type)) { initial_value = - b_.CreateVectorSplat(vector_type->getNumElements(), init_value_ssa); + VectorSplat(vector_type->getNumElements(), init_value_ssa); } else { initial_value = init_value_ssa; } - b_.CreateAlignedStore(initial_value, accumulator_shard, element_alignment); + AlignedStore(initial_value, accumulator_shard, element_alignment); } llvm_ir::ForLoopNest reduction_loop_nest(IrName(arg, "vectorized_inner"), @@ -1500,24 +1492,24 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( } CHECK(output_index.end() == it); - llvm::Value* input_address = b_.CreateBitCast( + llvm::Value* input_address = BitCast( arg_array.EmitArrayElementAddress(input_index, &b_), b_.getInt8PtrTy()); for (int i = 0; i < accumulator.size(); i++) { auto input_address_typed = - b_.CreateBitCast(input_address, accumulator[i]->getType()); + BitCast(input_address, accumulator[i]->getType()); auto current_accumulator_value = - b_.CreateAlignedLoad(accumulator[i], element_alignment); - auto addend = b_.CreateAlignedLoad(input_address_typed, element_alignment); + AlignedLoad(accumulator[i], element_alignment); + auto addend = AlignedLoad(input_address_typed, element_alignment); 
arg_array.AnnotateLoadStoreInstructionWithMetadata(addend); auto reduced_result = reduction_generator(&b_, current_accumulator_value, addend); - b_.CreateAlignedStore(reduced_result, accumulator[i], element_alignment); + AlignedStore(reduced_result, accumulator[i], element_alignment); if (i != (accumulator.size() - 1)) { - input_address = b_.CreateConstInBoundsGEP1_32(reduced_result->getType(), - input_address_typed, 1); + input_address = ConstInBoundsGEP1_32(reduced_result->getType(), + input_address_typed, 1); } } @@ -1526,8 +1518,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( ShardedVector result_ssa; result_ssa.reserve(accumulator.size()); for (auto accumulator_shard : accumulator) { - result_ssa.push_back( - b_.CreateAlignedLoad(accumulator_shard, element_alignment)); + result_ssa.push_back(AlignedLoad(accumulator_shard, element_alignment)); } return result_ssa; } @@ -1536,18 +1527,18 @@ void IrEmitter::EmitShardedVectorStore( llvm::Value* store_address, const std::vector<llvm::Value*>& value_to_store, const int alignment, const llvm_ir::IrArray& containing_array) { for (int i = 0; i < value_to_store.size(); i++) { - auto store_address_typed = b_.CreateBitCast( - store_address, - llvm::PointerType::getUnqual(value_to_store[i]->getType())); + auto store_address_typed = + BitCast(store_address, + llvm::PointerType::getUnqual(value_to_store[i]->getType())); - auto store_instruction = b_.CreateAlignedStore( - value_to_store[i], store_address_typed, alignment); + auto store_instruction = + AlignedStore(value_to_store[i], store_address_typed, alignment); containing_array.AnnotateLoadStoreInstructionWithMetadata( store_instruction); if (i != (value_to_store.size() - 1)) { - store_address = b_.CreateConstInBoundsGEP1_32( - value_to_store[i]->getType(), store_address_typed, 1); + store_address = ConstInBoundsGEP1_32(value_to_store[i]->getType(), + store_address_typed, 1); } } } @@ -1620,9 +1611,8 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce( int64 dimension 
= LayoutUtil::Minor(reduce->shape().layout(), i); int64 start_index = 0; int64 end_index = reduce->shape().dimensions(dimension); - std::unique_ptr<llvm_ir::ForLoop> loop = - loop_nest.AddLoop(start_index, end_index, - tensorflow::strings::Printf("dim.%lld", dimension)); + std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop( + start_index, end_index, absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } @@ -1641,9 +1631,9 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce( int64 start_index = 0; int64 end_index = (innermost_dimension_size / vectorization_factor) * vectorization_factor; - std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop( - start_index, end_index, vectorization_factor, - tensorflow::strings::Printf("dim.%lld", innermost_dimension)); + std::unique_ptr<llvm_ir::ForLoop> loop = + loop_nest.AddLoop(start_index, end_index, vectorization_factor, + absl::StrFormat("dim.%d", innermost_dimension)); array_index[innermost_dimension] = loop->GetIndVarValue(); SetToFirstInsertPoint(loop->GetBodyBasicBlock(), &b_); @@ -1713,8 +1703,8 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce( llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator", &b_, MinimumAlignmentForPrimitiveType(accumulator_type)); llvm::Value* init_value_addr = GetEmittedValueFor(init_value); - llvm::Value* load_init_value = b_.CreateLoad(init_value_addr); - b_.CreateStore(load_init_value, accumulator_addr); + llvm::Value* load_init_value = Load(init_value_addr); + Store(load_init_value, accumulator_addr); // The enclosing loops go over all the target elements. Now we have to compute // the actual target element. For this, we build a new loop nest to iterate @@ -1747,12 +1737,12 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduce( // Apply the reduction function to the loaded value. 
llvm::Value* input_element = arg_array.EmitReadArrayElement(input_index, &b_); llvm::Value* result = EmitThreadLocalCall( - *reduce->to_apply(), {b_.CreateLoad(accumulator_addr), input_element}, + *reduce->to_apply(), {Load(accumulator_addr), input_element}, "reduce_function"); - b_.CreateStore(result, accumulator_addr); + Store(result, accumulator_addr); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_addr); + return Load(accumulator_addr); } Status IrEmitter::HandleReduce(HloInstruction* reduce) { @@ -1990,7 +1980,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { [this, pad](const llvm_ir::IrArray::Index& target_index) { const HloInstruction* padding_value = pad->operand(1); llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value); - return b_.CreateLoad(padding_value_addr); + return Load(padding_value_addr); })); // Create a loop to iterate over the operand elements and update the output @@ -2012,10 +2002,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { const PaddingConfig& padding_config = pad->padding_config(); llvm_ir::IrArray::Index output_index(operand_index.GetType()); for (size_t i = 0; i < operand_index.size(); ++i) { - llvm::Value* offset = b_.CreateMul( - operand_index[i], - b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); - llvm::Value* index = b_.CreateAdd( + llvm::Value* offset = + Mul(operand_index[i], + b_.getInt64(padding_config.dimensions(i).interior_padding() + 1)); + llvm::Value* index = Add( offset, b_.getInt64(padding_config.dimensions(i).edge_padding_low())); output_index.push_back(index); } @@ -2118,7 +2108,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { gtl::ArraySlice<HloInstruction*> operands(custom_call->operands()); - tensorflow::StringPiece custom_call_target(custom_call->custom_call_target()); + absl::string_view 
custom_call_target(custom_call->custom_call_target()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( @@ -2126,10 +2116,10 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { for (size_t i = 0; i < operands.size(); ++i) { const HloInstruction* operand = operands[i]; llvm::Value* operand_as_i8ptr = - b_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type); + PointerCast(GetEmittedValueFor(operand), i8_ptr_type); llvm::Value* slot_in_operands_alloca = - b_.CreateInBoundsGEP(operands_alloca, {b_.getInt64(i)}); - b_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca); + InBoundsGEP(operands_alloca, {b_.getInt64(i)}); + Store(operand_as_i8ptr, slot_in_operands_alloca); } auto* custom_call_ir_function = llvm::cast<llvm::Function>(module_->getOrInsertFunction( @@ -2141,9 +2131,9 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); auto* output_address_arg = - b_.CreatePointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); + PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type); - b_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca}); + Call(custom_call_ir_function, {output_address_arg, operands_alloca}); return Status::OK(); } @@ -2170,8 +2160,8 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { return InternalError( "instruction %s %s does not share slice with " "instruction %s %s", - a->ToString().c_str(), slice_a.ToString().c_str(), - b->ToString().c_str(), slice_b.ToString().c_str()); + a->ToString(), slice_a.ToString(), b->ToString(), + slice_b.ToString()); } return Status::OK(); }; @@ -2202,15 +2192,14 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "header")), compute_function_->function()); - b_.CreateBr(header_bb); + 
Br(header_bb); b_.SetInsertPoint(header_bb); // Calls the condition function to determine whether to proceed with the // body. It must return a bool, so use the scalar call form. EmitGlobalCall(*xla_while->while_condition(), IrName(xla_while, "cond")); - llvm::Value* while_predicate = b_.CreateICmpNE( - b_.CreateLoad( - GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), + llvm::Value* while_predicate = ICmpNE( + Load(GetBufferForGlobalCallReturnValue(*xla_while->while_condition())), llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. @@ -2219,7 +2208,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); - b_.CreateCondBr(while_predicate, body_bb, exit_bb); + CondBr(while_predicate, body_bb, exit_bb); // Calls the body function from the body block. b_.SetInsertPoint(body_bb); @@ -2228,7 +2217,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { EmitGlobalCall(*xla_while->while_body(), IrName(xla_while, "body")); // Finishes with a branch back to the header. - b_.CreateBr(header_bb); + Br(header_bb); // Adds the exit block to the function and sets the insert point there. compute_function_->function()->getBasicBlockList().push_back(exit_bb); @@ -2275,7 +2264,6 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate( output_min2maj.end()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); - llvm::Type* i8_type = b_.getInt8Ty(); TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate)); llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); @@ -2298,9 +2286,9 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate( // Contiguous subregions from each operand to the concatenate contribute to a // contiguous subregion in the target buffer starting at target_region_begin. 
llvm::Value* target_region_begin = - b_.CreateBitCast(target_array.EmitArrayElementAddress( - outer_dims_index, &b_, "target_region"), - i8_ptr_type); + BitCast(target_array.EmitArrayElementAddress(outer_dims_index, &b_, + "target_region"), + i8_ptr_type); int64 byte_offset_into_target_region = 0; int64 inner_dims_product = @@ -2314,13 +2302,12 @@ StatusOr<bool> IrEmitter::EmitFastConcatenate( for (HloInstruction* operand : operands) { const Shape& input_shape = operand->shape(); llvm_ir::IrArray source_array = GetIrArrayFor(operand); - llvm::Value* copy_source_address = b_.CreateBitCast( + llvm::Value* copy_source_address = BitCast( source_array.EmitArrayElementAddress(outer_dims_index, &b_, "src_addr"), i8_ptr_type); llvm::Value* copy_target_address = - b_.CreateGEP(i8_type, target_region_begin, - b_.getInt64(byte_offset_into_target_region)); + GEP(target_region_begin, b_.getInt64(byte_offset_into_target_region)); EmitTransferElements( copy_target_address, copy_source_address, @@ -2352,15 +2339,15 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); if (element_count == 1) { - auto* load_instruction = b_.CreateAlignedLoad( - b_.CreateBitCast(source, primitive_ptr_type), element_alignment); + auto* load_instruction = + AlignedLoad(BitCast(source, primitive_ptr_type), element_alignment); source_array.AnnotateLoadStoreInstructionWithMetadata(load_instruction); - auto* store_instruction = b_.CreateAlignedStore( - load_instruction, b_.CreateBitCast(target, primitive_ptr_type), - element_alignment); + auto* store_instruction = + AlignedStore(load_instruction, BitCast(target, primitive_ptr_type), + element_alignment); target_array.AnnotateLoadStoreInstructionWithMetadata(store_instruction); } else { - auto* memcpy_instruction = b_.CreateMemCpy( + auto* memcpy_instruction = MemCpy( target, /*DstAlign=*/element_alignment, source, /*SrcAlign=*/element_alignment, element_count * 
primitive_type_size); @@ -2422,9 +2409,9 @@ Status IrEmitter::HandleConditional(HloInstruction* conditional) { // cond_result = true_computation(true_operand) // else // cond_result = false_computation(false_operand) - llvm::LoadInst* pred_value = b_.CreateLoad( - GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = b_.CreateICmpNE( + llvm::LoadInst* pred_value = + Load(GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ICmpNE( pred_value, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); @@ -2450,11 +2437,6 @@ Status IrEmitter::HandleAfterAll(HloInstruction* gen_token) { return Status::OK(); } -Status IrEmitter::HandleIota(HloInstruction* iota) { - // TODO(b/64798317): implement iota on CPU. - return Unimplemented("Iota is not implemented on CPU."); -} - Status IrEmitter::HandleRng(HloInstruction* rng) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : rng->operands()) { @@ -2511,8 +2493,8 @@ llvm::Value* IrEmitter::GetProfileCounterCommon( int64 prof_counter_idx = it->second; string counter_name = IrName("prof_counter", hlo.name()); - return b_.CreateGEP(GetProfileCountersArgument(), - b_.getInt64(prof_counter_idx), AsStringRef(counter_name)); + return GEP(GetProfileCountersArgument(), b_.getInt64(prof_counter_idx), + AsStringRef(counter_name)); } void IrEmitter::ProfilingState::UpdateProfileCounter(llvm::IRBuilder<>* b, @@ -2666,8 +2648,7 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( llvm::Value* params = compute_function_->parameters_arg(); llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(params, param_number, &b_); - llvm::LoadInst* param_address_untyped = - b_.CreateLoad(param_address_offset); + llvm::LoadInst* param_address_untyped = Load(param_address_offset); if (!ShapeUtil::IsOpaque(target_shape)) { 
AttachAlignmentMetadataForLoad(param_address_untyped, target_shape); @@ -2687,17 +2668,15 @@ llvm::Value* IrEmitter::EmitThreadLocalTempBufferPointer( auto buf_it = thread_local_buffers_.find(key); if (buf_it == thread_local_buffers_.end()) { llvm::Value* buffer = llvm_ir::EmitAllocaAtFunctionEntry( - IrShapeType(shape), - tensorflow::strings::StrCat("thread_local", slice.ToString()), &b_, - MinimumAlignmentForShape(target_shape)); + IrShapeType(shape), absl::StrCat("thread_local", slice.ToString()), + &b_, MinimumAlignmentForShape(target_shape)); auto it_inserted_pair = thread_local_buffers_.insert({key, buffer}); CHECK(it_inserted_pair.second); buf_it = it_inserted_pair.first; } return buf_it->second; }(); - return b_.CreateBitCast(tempbuf_address, - IrShapeType(target_shape)->getPointerTo()); + return BitCast(tempbuf_address, IrShapeType(target_shape)->getPointerTo()); } llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( @@ -2705,7 +2684,7 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( const BufferAllocation& allocation = *slice.allocation(); llvm::Value* tempbuf_address_ptr = llvm_ir::EmitBufferIndexingGEP( GetTempBuffersArgument(), slice.index(), &b_); - llvm::LoadInst* tempbuf_address_base = b_.CreateLoad(tempbuf_address_ptr); + llvm::LoadInst* tempbuf_address_base = Load(tempbuf_address_ptr); if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { tempbuf_address_base->setMetadata( @@ -2719,10 +2698,10 @@ llvm::Value* IrEmitter::EmitGlobalTempBufferPointer( if (slice.offset() > 0) { // Adjust the address to account for the slice offset. 
tempbuf_address_untyped = - b_.CreateInBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); + InBoundsGEP(tempbuf_address_base, b_.getInt64(slice.offset())); } - return b_.CreateBitCast(tempbuf_address_untyped, - IrShapeType(target_shape)->getPointerTo()); + return BitCast(tempbuf_address_untyped, + IrShapeType(target_shape)->getPointerTo()); } llvm::Value* IrEmitter::EmitTempBufferPointer( @@ -2753,7 +2732,7 @@ Status IrEmitter::EmitTargetElementLoop( } Status IrEmitter::EmitTargetElementLoop( - HloInstruction* target_op, tensorflow::StringPiece desc, + HloInstruction* target_op, absl::string_view desc, const llvm_ir::ElementGenerator& element_generator) { VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString(); @@ -2808,8 +2787,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, llvm::Value* destination_value = GetEmittedValueFor(&destination); int64 source_size = ByteSizeOf(source.shape()); // TODO(b/63762267): Be more aggressive about specifying alignment. 
- b_.CreateMemCpy(destination_value, /*DstAlign=*/1, source_value, - /*SrcAlign=*/1, source_size); + MemCpy(destination_value, /*DstAlign=*/1, source_value, + /*SrcAlign=*/1, source_size); return Status::OK(); } @@ -2827,8 +2806,8 @@ Status IrEmitter::ElementTypesSameAndSupported( if (std::find(supported_types.begin(), supported_types.end(), primitive_type) == supported_types.end()) { return Unimplemented("unsupported operand type %s in op %s", - PrimitiveType_Name(primitive_type).c_str(), - HloOpcodeString(instruction.opcode()).c_str()); + PrimitiveType_Name(primitive_type), + HloOpcodeString(instruction.opcode())); } return Status::OK(); } @@ -2848,7 +2827,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { llvm::Value* IrEmitter::EmitThreadLocalCall( const HloComputation& callee, tensorflow::gtl::ArraySlice<llvm::Value*> parameters, - tensorflow::StringPiece name) { + absl::string_view name) { const Shape& return_shape = callee.root_instruction()->shape(); // Lifting this restriction to allow "small" arrays should be easy. 
Allowing @@ -2863,38 +2842,37 @@ llvm::Value* IrEmitter::EmitThreadLocalCall( CHECK(!parameter->getType()->isPointerTy()); llvm::Value* parameter_addr = llvm_ir::EmitAllocaAtFunctionEntry( parameter->getType(), "arg_addr", &b_); - b_.CreateStore(parameter, parameter_addr); + Store(parameter, parameter_addr); parameter_addrs.push_back(parameter_addr); } llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(return_type, module_), - tensorflow::strings::StrCat(name, "_retval_addr"), &b_, + absl::StrCat(name, "_retval_addr"), &b_, MinimumAlignmentForPrimitiveType(return_type)); - b_.CreateCall( - FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - parameter_addrs, &b_, name, - /*return_value_buffer=*/return_value_buffer, - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), - /*profile_counters_arg=*/GetProfileCountersArgument())); + Call(FindOrDie(emitted_functions_, &callee), + GetArrayFunctionCallArguments( + parameter_addrs, &b_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()->getPointerTo()), + /*profile_counters_arg=*/GetProfileCountersArgument())); - return b_.CreateLoad(return_value_buffer); + return Load(return_value_buffer); } void IrEmitter::EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name) { - b_.CreateCall(FindOrDie(emitted_functions_, &callee), - GetArrayFunctionCallArguments( - /*parameter_addresses=*/{}, &b_, name, - /*return_value_buffer=*/ - llvm::Constant::getNullValue(b_.getInt8PtrTy()), - /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), - /*temp_buffers_arg=*/GetTempBuffersArgument(), - /*profile_counters_arg=*/GetProfileCountersArgument())); + absl::string_view name) { + Call(FindOrDie(emitted_functions_, 
&callee), + GetArrayFunctionCallArguments( + /*parameter_addresses=*/{}, &b_, name, + /*return_value_buffer=*/ + llvm::Constant::getNullValue(b_.getInt8PtrTy()), + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); } llvm::Value* IrEmitter::GetBufferForGlobalCallReturnValue( diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index c9a1dab62d..f98891246b 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -23,6 +23,7 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "absl/strings/string_view.h" #include "llvm/ADT/Triple.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -39,12 +40,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" @@ -55,7 +56,8 @@ namespace cpu { // This class is the top-level API for the XLA HLO --> LLVM IR compiler. It // implements the DfsHloVisitor interface and emits HLO computations as LLVM IR // functions. 
-class IrEmitter : public DfsHloVisitorWithDefault { +class IrEmitter : public DfsHloVisitorWithDefault, + public IrBuilderMixin<IrEmitter> { public: // Create a new LLVM IR emitter. // @@ -100,6 +102,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::IRBuilder<>* b() { return &b_; } + // builder() is for IrBuilderMixin. + llvm::IRBuilder<>* builder() { return &b_; } + // Emit an LLVM global variable for every constant buffer allocation. Status EmitConstantGlobals(); @@ -107,7 +112,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* EmitElementalMap( const HloMapInstruction& map_instr, tensorflow::gtl::ArraySlice<llvm::Value*> elemental_operands, - tensorflow::StringPiece name); + absl::string_view name); protected: // @@ -152,7 +157,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleConditional(HloInstruction* conditional) override; Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* gen_token) override; - Status HandleIota(HloInstruction* iota) override; Status HandleRng(HloInstruction* rng) override; Status FinishVisit(HloInstruction* root) override; @@ -239,7 +243,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { // function that a map operation applies. StatusOr<llvm::Function*> EmitFunction( HloComputation* function, // The function to emit. - tensorflow::StringPiece + absl::string_view function_name_suffix); // Used for LLVM IR register names. // Emits a call to a thread local function (e.g. to the computation nested @@ -251,14 +255,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* EmitThreadLocalCall( const HloComputation& callee, tensorflow::gtl::ArraySlice<llvm::Value*> parameters, - tensorflow::StringPiece name); + absl::string_view name); // Emits a call to a "global" function (e.g. to the computation nested within // a kWhile or a kCall). 
Buffer assignment unabiguously assignes buffers to // the parameters and return values for these computations so there is no need // to explicitly pass parameters or return results. - void EmitGlobalCall(const HloComputation& callee, - tensorflow::StringPiece name); + void EmitGlobalCall(const HloComputation& callee, absl::string_view name); // Returns the buffer to which a global call to `callee` would have written // its result. @@ -285,7 +288,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, const llvm_ir::ElementGenerator& element_generator); Status EmitTargetElementLoop( - HloInstruction* target_op, tensorflow::StringPiece desc, + HloInstruction* target_op, absl::string_view desc, const llvm_ir::ElementGenerator& element_generator); // Emits a memcpy from the source instruction's result value to the diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc index 2db4d000f5..784045313d 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -189,7 +190,7 @@ void IrFunction::Initialize(const string& function_name, llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { CHECK_GT(num_dynamic_loop_bounds_, 0); CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); - string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + string name = absl::StrCat("dynamic_loop_bound_", offset); return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), b_->getInt64(offset), AsStringRef(name))); } @@ -200,7 +201,7 @@ llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { // address buffer). std::vector<llvm::Value*> GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, - llvm::IRBuilder<>* b, tensorflow::StringPiece name, + llvm::IRBuilder<>* b, absl::string_view name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { llvm::Value* parameter_addresses_buffer; @@ -211,13 +212,13 @@ std::vector<llvm::Value*> GetArrayFunctionCallArguments( } else { parameter_addresses_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( b->getInt8PtrTy(), b->getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), b); + absl::StrCat(name, "_parameter_addresses"), b); for (size_t i = 0; i < parameter_addresses.size(); ++i) { llvm::Value* parameter_as_i8ptr = b->CreateBitCast(parameter_addresses[i], b->getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat( - name, "_parameter_", i, "_address_as_i8ptr"))); + AsStringRef(absl::StrCat(name, "_parameter_", i, + "_address_as_i8ptr"))); llvm::Value* slot_in_param_addresses = 
b->CreateInBoundsGEP(parameter_addresses_buffer, {b->getInt64(i)}); b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); @@ -320,8 +321,7 @@ Status EmitCallToParallelForkJoin( /*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/partitions_array, /*Name=*/ - AsStringRef( - tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + AsStringRef(absl::StrCat(name, "_parallel_dimension_partitions"))); // Add argument specifying parallel dimension partitions. fork_join_arguments.push_back( diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h index a41cbb64cd..ee7595f6e9 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_function.h +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -116,7 +116,7 @@ class IrFunction { // Returns an array of compute function call argument ir values. std::vector<llvm::Value*> GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses, - llvm::IRBuilder<>* b, tensorflow::StringPiece name, + llvm::IRBuilder<>* b, absl::string_view name, llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc index 8560e4296a..f8441c3e34 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -15,9 +15,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace cpu { @@ -30,8 +30,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( dynamic_loop_bounds_(dynamic_loop_bounds) {} std::vector<llvm_ir::IrArray::Index> -ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, + llvm::Type* index_type) { CHECK_NE(index_type, nullptr); CHECK(!ShapeUtil::IsTuple(shape_)); @@ -52,15 +52,15 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second; std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop( - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), - start_index, end_index); + /*suffix=*/absl::StrFormat("dim.%d", dimension), start_index, + end_index); array_index[dimension] = loop->GetIndVarValue(); } else { // Emit static loop bounds for this dimension. 
std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + /*suffix=*/absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h index 076c683ca5..a604e1db22 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -61,7 +61,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type) override; private: const DynamicLoopBounds* dynamic_loop_bounds_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 4fa5984b04..b4c0c09ec0 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" @@ -109,7 +111,7 @@ ParallelTaskAssignment::ParallelTaskAssignment( : target_machine_features_(*target_machine_features) { VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; // Run cost analysis on 'module'. 
- auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size); + auto cost_analysis = absl::make_unique<HloCostAnalysis>(shape_size); HloComputation* computation = module->entry_computation(); Status status = computation->root_instruction()->Accept(cost_analysis.get()); if (status.ok()) { @@ -216,8 +218,7 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper( // Outline 'instruction' in 'computation' for parallel task assignment. auto* call = module->OutlineExpressionFromComputation( - {instruction}, - tensorflow::strings::StrCat("parallel_", instruction->name()), + {instruction}, absl::StrCat("parallel_", instruction->name()), computation); // Set assigned dimension partitioning to 'instruction'. diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 8becc8fa23..a99cd99c14 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -73,7 +73,7 @@ class ParallelTaskAssigner : public HloPassInterface { target_machine_features_(*target_machine_features) {} ~ParallelTaskAssigner() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cpu-parallel-task-assigner"; } diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc index ee272b5f4f..a84ee78b19 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { @@ -36,7 +35,9 @@ class ParallelTaskAssignmentTest : public HloVerifiedTestBase { cpu::TargetMachineFeaturesWithFakeAlignmentLogic target_machine_features_; ParallelTaskAssignmentTest() - : target_machine_features_([](int64 shape_size) { + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false), + target_machine_features_([](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }) {} diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index f227e4ae13..942e2ddd39 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -16,6 +16,7 @@ limitations under the License. #include <memory> #include <string> +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" @@ -27,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -67,8 +67,8 @@ int main(int argc, char** argv) { /*execution_profile=*/&profile); std::unique_ptr<xla::Literal> actual = result.ConsumeValueOrDie(); - LOG(INFO) << tensorflow::strings::Printf("computation took %lldns", - profile.compute_time_ns()); + LOG(INFO) << absl::StrFormat("computation took %dns", + profile.compute_time_ns()); LOG(INFO) << actual->ToString(); return 0; diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index be772cfb7e..bf98064647 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -20,13 +20,13 @@ limitations under the License. #include <list> #include <utility> +#include "absl/memory/memory.h" #include "llvm/ExecutionEngine/ExecutionEngine.h" #include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/IR/Mangler.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/Host.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" @@ -170,15 +170,14 @@ namespace { bool RegisterKnownJITSymbols() { CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global(); -#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ - do { \ - auto* function_address = \ - reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \ - registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ - function_address); \ - CHECK_EQ( \ - 
tensorflow::StringPiece(xla::cpu::runtime::k##base_name##SymbolName), \ - "__xla_cpu_runtime_" #base_name); \ +#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ + do { \ + auto* function_address = \ + reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \ + registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ + function_address); \ + CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \ + "__xla_cpu_runtime_" #base_name); \ } while (false) REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 181cec3cdd..2384166fd2 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -51,6 +51,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -94,6 +95,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", "@llvm//:core", ], ) @@ -108,6 +110,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) @@ -121,6 +124,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index 6fcce42eaa..fcd87b36b3 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -19,10 +19,10 @@ limitations under the License. 
#include <cctype> #include <string> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc index d98856fdbf..22721051e5 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_fusion_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -129,8 +129,8 @@ TEST_F(CpuFusionTest, FuseElementwiseOpChain) { error_spec_); } -TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusableInstruction) { - // Test a chain of fusable ops with a non-fusable op (a reduce) thrown in the +TEST_F(CpuFusionTest, ElementwiseOpChainWithNonfusibleInstruction) { + // Test a chain of fusible ops with a non-fusible op (a reduce) thrown in the // middle. auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc index 973aac8766..a434c04a98 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -17,10 +17,10 @@ limitations under the License. 
#include <cctype> #include <string> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -32,9 +32,9 @@ const char* const kTriple_android_arm = "armv7-none-android"; struct IntrinsicTestSpec { HloOpcode opcode; - tensorflow::StringPiece triple; - tensorflow::StringPiece features; - tensorflow::StringPiece check_lines; + absl::string_view triple; + absl::string_view features; + absl::string_view check_lines; }; // Tests that unary functions get lowered using intrinsic calls. @@ -65,9 +65,8 @@ class CpuUnaryIntrinsicTest features = ""; } - return tensorflow::strings::StrCat(opcode.c_str(), "_On_", triple.c_str(), - features.empty() ? "" : "_With", - features.c_str()); + return absl::StrCat(opcode, "_On_", triple, + (features.empty() ? "" : "_With"), features); } }; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index 01daed4bcd..bb105194f1 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include <memory> #include <utility> +#include "absl/memory/memory.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -62,7 +62,8 @@ TEST_F(CpuNoAliasTest, Concat) { // Now that we have an HLO module, build an llvm_ir::AliasAnalysis for it. 
auto status_or_buffer_assn = BufferAssigner::Run( - hlo_module.get(), MakeUnique<DependencyHloOrdering>(hlo_module.get()), + hlo_module.get(), + absl::make_unique<DependencyHloOrdering>(hlo_module.get()), backend().compiler()->BufferSizeBytesFunction(), [](LogicalBuffer::Color) { return /*alignment=*/1; }); ASSERT_EQ(status_or_buffer_assn.status(), Status::OK()); diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 3274be8d9d..962ea69c09 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "absl/algorithm/container.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -422,8 +423,8 @@ TileVariable::TileVariable(VectorSupportLibrary* vector_support, std::vector<llvm::Value*> TileVariable::Get() const { std::vector<llvm::Value*> result; - c_transform(storage_, std::back_inserter(result), - [&](VectorVariable vect_var) { return vect_var.Get(); }); + absl::c_transform(storage_, std::back_inserter(result), + [&](VectorVariable vect_var) { return vect_var.Get(); }); return result; } diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h index 56b28fd22d..c326beb899 100644 --- a/tensorflow/compiler/xla/service/defuser.h +++ b/tensorflow/compiler/xla/service/defuser.h @@ -29,7 +29,7 @@ class Defuser : public HloPassInterface { public: Defuser() {} ~Defuser() override {} - tensorflow::StringPiece name() const override { return "defuser"; } + absl::string_view name() const override { return "defuser"; } // Run defusion on the given module. Returns whether the module was // changed. 
diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc index e727ba49cb..37d1895d41 100644 --- a/tensorflow/compiler/xla/service/defuser_test.cc +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -26,6 +26,11 @@ namespace xla { namespace { class DefuserTest : public HloVerifiedTestBase { + public: + DefuserTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: // Returns the number of fusion instructions in the module. int FusionCount() { diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index 48e4471499..ba2a674d9a 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -27,9 +27,7 @@ namespace { class ControlDepRemover : public HloPassInterface { public: ControlDepRemover() = default; - tensorflow::StringPiece name() const override { - return "control-dep-remover"; - } + absl::string_view name() const override { return "control-dep-remover"; } StatusOr<bool> Run(HloModule* module) override { bool changed = false; diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h index cc1695b7f8..7be70add2f 100644 --- a/tensorflow/compiler/xla/service/despecializer.h +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -33,7 +33,7 @@ namespace xla { class Despecializer : public HloPassInterface { public: Despecializer(); - tensorflow::StringPiece name() const override { return "despecializer"; } + absl::string_view name() const override { return "despecializer"; } StatusOr<bool> Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index e228bb56bc..1d0297cfbf 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ 
b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -36,9 +36,8 @@ StatusOr<OwningDeviceMemory> StreamExecutorMemoryAllocator::Allocate( se::DeviceMemoryBase result = stream_executor->AllocateArray<uint8>(size); if (size > 0 && result == nullptr) { return ResourceExhausted( - "Failed to allocate request for %s (%lluB) on device ordinal %d", - tensorflow::strings::HumanReadableNumBytes(size).c_str(), size, - device_ordinal); + "Failed to allocate request for %s (%uB) on device ordinal %d", + tensorflow::strings::HumanReadableNumBytes(size), size, device_ordinal); } return OwningDeviceMemory(result, device_ordinal, this); } @@ -61,12 +60,12 @@ StatusOr<se::StreamExecutor*> StreamExecutorMemoryAllocator::GetStreamExecutor( } if (device_ordinal >= stream_executors_.size()) { return InvalidArgument( - "device ordinal value (%d) >= number of devices (%zu)", device_ordinal, + "device ordinal value (%d) >= number of devices (%u)", device_ordinal, stream_executors_.size()); } if (stream_executors_[device_ordinal] == nullptr) { return NotFound("Device %s:%d present but not supported", - platform()->Name().c_str(), device_ordinal); + platform()->Name(), device_ordinal); } return stream_executors_[device_ordinal]; } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index 2172ae0a29..3e7373adc5 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -28,14 +28,14 @@ template <typename HloInstructionPtr> Status DfsHloVisitorBase<HloInstructionPtr>::HandleElementwiseUnary( HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); } template <typename HloInstructionPtr> Status DfsHloVisitorBase<HloInstructionPtr>::HandleElementwiseBinary( HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: 
%s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); } template <typename HloInstructionPtr> diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 86d57581f8..f6f8fc5a2a 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -19,13 +19,13 @@ limitations under the License. #include <type_traits> #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" @@ -107,6 +107,7 @@ class DfsHloVisitorBase { virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; + virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); } @@ -208,7 +209,6 @@ class DfsHloVisitorBase { virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; - virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; virtual Status HandleRng(HloInstructionPtr hlo) = 0; virtual Status HandleReverse(HloInstructionPtr hlo) = 0; virtual Status HandleSort(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 617a5a2eb4..4f620e4c3a 100644 --- 
a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -16,13 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -94,8 +94,11 @@ class DfsHloVisitorWithDefaultBase Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } - Status HandleAllToAll(HloInstructionPtr crs) override { - return DefaultAction(crs); + Status HandleAllToAll(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } + Status HandleCollectivePermute(HloInstructionPtr hlo) override { + return DefaultAction(hlo); } Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); @@ -106,9 +109,6 @@ class DfsHloVisitorWithDefaultBase Status HandleOutfeed(HloInstructionPtr outfeed) override { return DefaultAction(outfeed); } - Status HandleHostCompute(HloInstructionPtr host_compute) override { - return DefaultAction(host_compute); - } Status HandleReverse(HloInstructionPtr reverse) override { return DefaultAction(reverse); } diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index 12faed6967..09cb10d6ee 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -136,6 +136,7 @@ 
Status DecomposeBatchDot(HloInstruction* dot) { dot_dnums.add_rhs_contracting_dimensions(0); auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot( dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums)); + dot_r2->set_precision_config(dot->precision_config()); // Reshape Dot to R3 so we can concat along batch dimension. auto dot_r3 = computation->AddInstruction( diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h index 1959b687f1..fc38e31700 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.h +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -29,7 +29,7 @@ class DotDecomposer : public HloPassInterface { DotDecomposer(bool decompose_batch_dot = true) : decompose_batch_dot_(decompose_batch_dot) {} ~DotDecomposer() = default; - tensorflow::StringPiece name() const override { return "dot_decomposer"; } + absl::string_view name() const override { return "dot_decomposer"; } // Run DotDecomposer pass on computations in 'module'. // Returns whether the 'module' was changed. diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 891ae42141..813e93fafa 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -21,11 +21,15 @@ limitations under the License. 
#include <vector> // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -38,17 +42,16 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +using absl::StrCat; using llvm_ir::AsStringRef; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; -using tensorflow::strings::StrCat; namespace { @@ -203,7 +206,7 @@ llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, } // namespace StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { if (op->opcode() == HloOpcode::kCopy) { return operand_value; } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || @@ -217,7 +220,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp( } StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { switch (op->opcode()) { case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -229,14 +232,14 @@ 
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp( } if (to_type == PRED) { return b_->CreateZExt( - b_->CreateICmpNE(operand_value, llvm::ConstantInt::get( - operand_value->getType(), 0)), + ICmpNE(operand_value, + llvm::ConstantInt::get(operand_value->getType(), 0)), llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } if (primitive_util::IsIntegralType(to_type)) { - return b_->CreateIntCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_), - primitive_util::IsSignedIntegralType(from_type)); + return IntCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_), + primitive_util::IsSignedIntegralType(from_type)); } if (primitive_util::IsFloatingPointType(to_type)) { if (to_type == BF16) { @@ -252,19 +255,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp( primitive_util::ComplexComponentType(to_type), module_); if (primitive_util::IsSignedIntegralType(from_type)) { return EmitComposeComplex( - op, b_->CreateSIToFP(operand_value, to_ir_component_type), - nullptr); + op, SIToFP(operand_value, to_ir_component_type), nullptr); } if (primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED) { return EmitComposeComplex( - op, b_->CreateUIToFP(operand_value, to_ir_component_type), - nullptr); + op, UIToFP(operand_value, to_ir_component_type), nullptr); } } return Unimplemented("conversion from primitive type %s to %s", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str()); + PrimitiveType_Name(from_type), + PrimitiveType_Name(to_type)); } case HloOpcode::kBitcastConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -275,14 +276,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return b_->CreateBitCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return BitCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, 
module_)); } return InvalidArgument( "bitcast conversion from primitive type %s to %s with unequal " "bit-widths (%u versus %u) ", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str(), + PrimitiveType_Name(from_type), PrimitiveType_Name(to_type), primitive_util::BitWidth(from_type), primitive_util::BitWidth(to_type)); } @@ -292,10 +292,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp( if (is_signed) { auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto zero = llvm::ConstantInt::get(type, 0); - auto cmp = b_->CreateICmpSGE(operand_value, zero); - return b_->CreateSelect(cmp, operand_value, - b_->CreateNeg(operand_value)); + auto cmp = ICmpSGE(operand_value, GetZero(type)); + return Select(cmp, operand_value, Neg(operand_value)); } else { return operand_value; } @@ -307,44 +305,37 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp( {operand_value->getType()}, b_); } case HloOpcode::kSign: { - bool is_signed = - primitive_util::IsSignedIntegralType(op->shape().element_type()); + CHECK(primitive_util::IsSignedIntegralType(op->shape().element_type())) + << op->shape().element_type(); auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto zero = llvm::ConstantInt::get(type, 0); - auto cmp = b_->CreateICmpEQ(operand_value, zero); - if (is_signed) { - auto ashr = - b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1); - return b_->CreateSelect(cmp, zero, b_->CreateOr(ashr, 1)); - } else { - return b_->CreateSelect(cmp, zero, llvm::ConstantInt::get(type, 1)); - } + auto cmp = ICmpEQ(operand_value, GetZero(type)); + auto ashr = AShr(operand_value, type->getIntegerBitWidth() - 1); + return Select(cmp, GetZero(type), Or(ashr, 1)); } case HloOpcode::kNegate: - return b_->CreateNeg(operand_value); + return Neg(operand_value); case HloOpcode::kNot: { auto type = op->shape().element_type(); if (type == PRED) { // It is not sufficient to just call 
CreateNot() here because a PRED // is represented as an i8 and the truth value is stored only in the // bottom bit. - return b_->CreateZExt( - b_->CreateNot(b_->CreateTrunc(operand_value, b_->getInt1Ty())), - llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + return b_->CreateZExt(Not(Trunc(operand_value, b_->getInt1Ty())), + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } else if (primitive_util::IsIntegralType(type)) { - return b_->CreateNot(operand_value); + return Not(operand_value); } return Unimplemented("unary op Not is not defined for type '%d'", type); } default: return Unimplemented("unary integer op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { switch (op->opcode()) { case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -361,8 +352,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( } return EmitComposeComplex( op, - b_->CreateFPCast(operand_value, llvm_ir::PrimitiveTypeToIrType( - to_component_type, module_)), + FPCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), nullptr); } if (from_type == BF16) { @@ -378,26 +369,25 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( } if (to_type == PRED) { return b_->CreateZExt( - b_->CreateFCmpUNE( - operand_value, - llvm::ConstantFP::get(operand_value->getType(), 0.0)), + FCmpUNE(operand_value, + llvm::ConstantFP::get(operand_value->getType(), 0.0)), llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } if (primitive_util::IsFloatingPointType(to_type)) { - return b_->CreateFPCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsSignedIntegralType(to_type)) { - return 
b_->CreateFPToSI( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPToSI(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsUnsignedIntegralType(to_type)) { - return b_->CreateFPToUI( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return FPToUI(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return Unimplemented("unhandled conversion operation: %s => %s", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str()); + PrimitiveType_Name(from_type), + PrimitiveType_Name(to_type)); } case HloOpcode::kBitcastConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -408,14 +398,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( } if (primitive_util::BitWidth(from_type) == primitive_util::BitWidth(to_type)) { - return b_->CreateBitCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + return BitCast(operand_value, + llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return InvalidArgument( "bitcast conversion from primitive type %s to %s with unequal " "bit-widths (%u versus %u) ", - PrimitiveType_Name(from_type).c_str(), - PrimitiveType_Name(to_type).c_str(), + PrimitiveType_Name(from_type), PrimitiveType_Name(to_type), primitive_util::BitWidth(from_type), primitive_util::BitWidth(to_type)); } @@ -453,11 +442,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( // TODO(b/32151903): Ensure consistent sign behavior for -0.0. 
auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(operand_value, zero); - auto olt = b_->CreateFCmpOLT(operand_value, zero); - return b_->CreateSelect( - oeq, zero, - b_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0), + auto oeq = FCmpOEQ(operand_value, zero); + auto olt = FCmpOLT(operand_value, zero); + return Select(oeq, zero, + Select(olt, llvm::ConstantFP::get(type, -1.0), llvm::ConstantFP::get(type, 1.0))); } case HloOpcode::kIsFinite: { @@ -467,24 +455,24 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp( auto abs_value = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::fabs, {operand_value}, {type}, b_); auto infinity = llvm::ConstantFP::getInfinity(type); - auto not_infinite = b_->CreateFCmpONE(abs_value, infinity); + auto not_infinite = FCmpONE(abs_value, infinity); return b_->CreateZExt(not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: - return b_->CreateFNeg(operand_value); + return FNeg(operand_value); case HloOpcode::kReal: return operand_value; case HloOpcode::kImag: return llvm::ConstantFP::get(operand_value->getType(), 0.0); default: return Unimplemented("unary floating-point op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { + const HloInstruction* op, llvm::Value* operand_value) { PrimitiveType input_type = op->operand(0)->shape().element_type(); PrimitiveType component_type = primitive_util::IsComplexType(input_type) @@ -496,12 +484,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); - auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b)); + auto sum_sq = FAdd(FMul(a, a), FMul(b, b)); TF_ASSIGN_OR_RETURN(auto 
log_sum_sq, EmitLog(component_type, sum_sq)); TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); } case HloOpcode::kLog1p: { // log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1) @@ -509,14 +496,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); auto one = llvm::ConstantFP::get(llvm_ty, 1.0); - auto a_plus_one = b_->CreateFAdd(a, one); - auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a_plus_one, a_plus_one), - b_->CreateFMul(b, b)); + auto a_plus_one = FAdd(a, one); + auto sum_sq = FAdd(FMul(a_plus_one, a_plus_one), FMul(b, b)); TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex(op, FMul(one_half, log_sum_sq), angle); } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -530,11 +515,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( primitive_util::ComplexComponentType(to_type); auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); - return EmitComposeComplex(op, - b_->CreateFPCast(EmitExtractReal(operand_value), - to_ir_component_type), - b_->CreateFPCast(EmitExtractImag(operand_value), - to_ir_component_type)); + return EmitComposeComplex( + op, FPCast(EmitExtractReal(operand_value), to_ir_component_type), + FPCast(EmitExtractImag(operand_value), to_ir_component_type)); } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) @@ -544,8 +527,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( 
auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); - return EmitComposeComplex(op, b_->CreateFMul(exp_a, cos_b), - b_->CreateFMul(exp_a, sin_b)); + return EmitComposeComplex(op, FMul(exp_a, cos_b), FMul(exp_a, sin_b)); } case HloOpcode::kExpm1: { // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i @@ -556,8 +538,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN( auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0); - auto real_result = b_->CreateFSub(b_->CreateFMul(exp_a, cos_b), one); - auto imag_result = b_->CreateFMul(exp_a, sin_b); + auto real_result = FSub(FMul(exp_a, cos_b), one); + auto imag_result = FMul(exp_a, sin_b); return EmitComposeComplex(op, real_result, imag_result); } case HloOpcode::kCos: { @@ -572,14 +554,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); - auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); - auto half_exp_neg_b = - b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); - return EmitComposeComplex( - op, b_->CreateFMul(cos_a, b_->CreateFAdd(half_exp_neg_b, half_exp_b)), - b_->CreateFMul(sin_a, b_->CreateFSub(half_exp_neg_b, half_exp_b))); + return EmitComposeComplex(op, + FMul(cos_a, FAdd(half_exp_neg_b, half_exp_b)), + FMul(sin_a, FSub(half_exp_neg_b, half_exp_b))); } case HloOpcode::kSin: { // sin(z) = .5i(e^(-iz) - e^(iz)) @@ -595,14 +576,13 @@ StatusOr<llvm::Value*> 
ElementalIrEmitter::EmitComplexUnaryOp( auto b = EmitExtractImag(operand_value); auto type = a->getType(); TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); - auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); - auto half_exp_neg_b = - b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_b = FMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = FDiv(llvm::ConstantFP::get(type, 0.5), exp_b); TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); - return EmitComposeComplex( - op, b_->CreateFMul(sin_a, b_->CreateFAdd(half_exp_b, half_exp_neg_b)), - b_->CreateFMul(cos_a, b_->CreateFSub(half_exp_b, half_exp_neg_b))); + return EmitComposeComplex(op, + FMul(sin_a, FAdd(half_exp_b, half_exp_neg_b)), + FMul(cos_a, FSub(half_exp_b, half_exp_neg_b))); } case HloOpcode::kTanh: { /* @@ -630,74 +610,63 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp( TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a)); TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); - auto exp_neg_a = - b_->CreateFDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); - auto exp_2a_minus_exp_neg_2a = b_->CreateFSub( - b_->CreateFMul(exp_a, exp_a), b_->CreateFMul(exp_neg_a, exp_neg_a)); - auto cos_b_sq = b_->CreateFMul(cos_b, cos_b); - auto sin_b_sq = b_->CreateFMul(sin_b, sin_b); - auto real_num = - b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), - b_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); - auto cos_b_sin_b = b_->CreateFMul(cos_b, sin_b); - auto exp_a_plus_exp_neg_a = b_->CreateFAdd(exp_a, exp_neg_a); + auto exp_neg_a = FDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = + FSub(FMul(exp_a, exp_a), FMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = FMul(cos_b, cos_b); + auto sin_b_sq = FMul(sin_b, sin_b); + auto 
real_num = FAdd(FMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + FMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = FMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = FAdd(exp_a, exp_neg_a); auto exp_a_plus_exp_neg_a_sq = - b_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); - auto exp_a_minus_exp_neg_a = b_->CreateFSub(exp_a, exp_neg_a); + FMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = FSub(exp_a, exp_neg_a); auto exp_a_minus_exp_neg_a_sq = - b_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); - auto imag_num = b_->CreateFMul( - cos_b_sin_b, - b_->CreateFSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq)); - auto denom = - b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), - b_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); - return EmitComposeComplex(op, b_->CreateFDiv(real_num, denom), - b_->CreateFDiv(imag_num, denom)); + FMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = FMul( + cos_b_sin_b, FSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq)); + auto denom = FAdd(FMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + FMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, FDiv(real_num, denom), + FDiv(imag_num, denom)); } case HloOpcode::kAbs: { - auto sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - b_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto sum_sq = FAdd( + FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)), + FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value))); return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_); } case HloOpcode::kSign: { // Sign(c) = c / |c| - auto sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value), - EmitExtractReal(operand_value)), - b_->CreateFMul(EmitExtractImag(operand_value), - EmitExtractImag(operand_value))); + auto 
sum_sq = FAdd( + FMul(EmitExtractReal(operand_value), EmitExtractReal(operand_value)), + FMul(EmitExtractImag(operand_value), EmitExtractImag(operand_value))); auto cplx_abs = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_); auto type = cplx_abs->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(cplx_abs, zero); - return b_->CreateSelect( + auto oeq = FCmpOEQ(cplx_abs, zero); + return Select( oeq, EmitComposeComplex(op, zero, zero), - EmitComposeComplex( - op, b_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs), - b_->CreateFDiv(EmitExtractImag(operand_value), cplx_abs))); + EmitComposeComplex(op, FDiv(EmitExtractReal(operand_value), cplx_abs), + FDiv(EmitExtractImag(operand_value), cplx_abs))); } case HloOpcode::kNegate: - return EmitComposeComplex(op, - b_->CreateFNeg(EmitExtractReal(operand_value)), - b_->CreateFNeg(EmitExtractImag(operand_value))); + return EmitComposeComplex(op, FNeg(EmitExtractReal(operand_value)), + FNeg(EmitExtractImag(operand_value))); case HloOpcode::kReal: return EmitExtractReal(operand_value); case HloOpcode::kImag: return EmitExtractImag(operand_value); default: return Unimplemented("unary complex op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { PrimitiveType operand_type = op->operand(0)->shape().element_type(); if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || operand_type == PRED) { @@ -712,21 +681,20 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp( } StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { 
switch (op->opcode()) { case HloOpcode::kComplex: return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: - return b_->CreateFAdd(lhs_value, rhs_value); + return FAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: - return b_->CreateFSub(lhs_value, rhs_value); + return FSub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return b_->CreateFMul(lhs_value, rhs_value); + return FMul(lhs_value, rhs_value); case HloOpcode::kDivide: - return b_->CreateFDiv(lhs_value, rhs_value); + return FDiv(lhs_value, rhs_value); case HloOpcode::kRemainder: - return b_->CreateFRem(lhs_value, rhs_value); + return FRem(lhs_value, rhs_value); // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered // comparisons always return false when one of the operands is NaN, whereas // unordered comparisons return true. @@ -763,66 +731,52 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp( return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value); default: return Unimplemented("binary floating point op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { switch (op->opcode()) { case HloOpcode::kAdd: - return EmitComposeComplex(op, - b_->CreateFAdd(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFAdd(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex( + op, FAdd(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FAdd(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); case HloOpcode::kSubtract: - return EmitComposeComplex(op, - b_->CreateFSub(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFSub(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))); + return EmitComposeComplex( + op, 
FSub(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FSub(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); case HloOpcode::kMultiply: return EmitComposeComplex( op, - b_->CreateFSub(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value))), - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value)))); + FSub(FMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))), + FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)))); case HloOpcode::kDivide: { // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di)) // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) auto rhs_sum_sq = - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(rhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(rhs_value), - EmitExtractImag(rhs_value))); + FAdd(FMul(EmitExtractReal(rhs_value), EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(rhs_value), EmitExtractImag(rhs_value))); auto type = rhs_sum_sq->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); - auto oeq = b_->CreateFCmpOEQ(rhs_sum_sq, zero); - auto real_inf_or_nan = b_->CreateFDiv(EmitExtractReal(lhs_value), zero); - auto imag_inf_or_nan = b_->CreateFDiv(EmitExtractImag(lhs_value), zero); - return b_->CreateSelect( + auto oeq = FCmpOEQ(rhs_sum_sq, zero); + auto real_inf_or_nan = FDiv(EmitExtractReal(lhs_value), zero); + auto imag_inf_or_nan = FDiv(EmitExtractImag(lhs_value), zero); + return Select( oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan), - EmitComposeComplex( - op, - b_->CreateFDiv( - b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), - 
EmitExtractImag(rhs_value))), - rhs_sum_sq), - b_->CreateFDiv( - b_->CreateFSub(b_->CreateFMul(EmitExtractImag(lhs_value), - EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractReal(lhs_value), - EmitExtractImag(rhs_value))), - rhs_sum_sq))); + EmitComposeComplex(op, + FDiv(FAdd(FMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), + rhs_sum_sq), + FDiv(FSub(FMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)), + FMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value))), + rhs_sum_sq))); } // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered // comparisons always return false when one of the operands is NaN, whereas @@ -832,21 +786,19 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp( // unordered comparison. This makes x != y equivalent to !(x == y), and // matches C++'s semantics. case HloOpcode::kEq: - return b_->CreateAnd( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); + return And(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); case HloOpcode::kNe: - return b_->CreateOr( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractReal(lhs_value), - EmitExtractReal(rhs_value), b_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, - EmitExtractImag(lhs_value), - EmitExtractImag(rhs_value), b_)); + return Or(llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), b_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), b_)); case HloOpcode::kPower: { // 
(a+bi)^(c+di) = @@ -858,45 +810,43 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp( auto b = EmitExtractImag(lhs_value); auto c = EmitExtractReal(rhs_value); auto d = EmitExtractImag(rhs_value); - auto aa_p_bb = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b)); + auto aa_p_bb = FAdd(FMul(a, a), FMul(b, b)); auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto half_c = b_->CreateFMul(one_half, c); + auto half_c = FMul(one_half, c); TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, EmitPow(component_type, aa_p_bb, half_c)); - auto neg_d = b_->CreateFNeg(d); + auto neg_d = FNeg(d); TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); - auto neg_d_arg_lhs = b_->CreateFMul(neg_d, arg_lhs); + auto neg_d_arg_lhs = FMul(neg_d, arg_lhs); TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, EmitExp(component_type, neg_d_arg_lhs)); - auto coeff = b_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + auto coeff = FMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); - auto half_d = b_->CreateFMul(one_half, d); - auto q = b_->CreateFAdd(b_->CreateFMul(c, arg_lhs), - b_->CreateFMul(half_d, ln_aa_p_bb)); + auto half_d = FMul(one_half, d); + auto q = FAdd(FMul(c, arg_lhs), FMul(half_d, ln_aa_p_bb)); TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); - return EmitComposeComplex(op, b_->CreateFMul(coeff, cos_q), - b_->CreateFMul(coeff, sin_q)); + return EmitComposeComplex(op, FMul(coeff, cos_q), FMul(coeff, sin_q)); } default: return Unimplemented("binary complex op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + llvm::Value* rhs_value) { return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_); } llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, - 
llvm::Value* rhs_value) const { + llvm::Value* rhs_value) { return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, - llvm::Value* x) const { + llvm::Value* x) { if (prim_type != F32) { // TODO(b/34339814): Implement inverse erf for F64. return Unimplemented( @@ -909,9 +859,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, auto multiply_add = [&](tensorflow::gtl::ArraySlice<float> coefficients, llvm::Value* w) { llvm::Value* p = getFloat(coefficients.front()); - coefficients.pop_front(); + coefficients.remove_prefix(1); for (float coefficient : coefficients) { - p = b_->CreateFAdd(b_->CreateFMul(p, w), getFloat(coefficient)); + p = FAdd(FMul(p, w), getFloat(coefficient)); } return p; }; @@ -931,25 +881,24 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration( module_, llvm::Intrinsic::log, {b_->getFloatTy()}); - llvm::Value* w = b_->CreateFNeg(b_->CreateCall( - logf_fn, {b_->CreateFMul(b_->CreateFSub(getFloat(1.0f), x), - b_->CreateFAdd(getFloat(1.0f), x))})); + llvm::Value* w = FNeg( + Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))})); llvm::Value* p_addr = llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_); llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_->CreateFCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_); + FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_); // Handle true BB. 
SetToFirstInsertPoint(if_data.true_block, b_); { - llvm::Value* lw = b_->CreateFSub(w, getFloat(2.5f)); + llvm::Value* lw = FSub(w, getFloat(2.5f)); tensorflow::gtl::ArraySlice<float> lq{ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, -4.39150654e-06f, 0.00021858087f, -0.00125372503f, -0.00417768164f, 0.246640727f, 1.50140941f}; llvm::Value* p = multiply_add(lq, lw); - b_->CreateStore(p, p_addr); + Store(p, p_addr); } // Handle false BB. @@ -958,76 +907,73 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); - llvm::Value* gw = - b_->CreateFSub(b_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f)); + llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f)); tensorflow::gtl::ArraySlice<float> gq{ -0.000200214257f, 0.000100950558f, 0.00134934322f, -0.00367342844f, 0.00573950773f, -0.0076224613f, 0.00943887047f, 1.00167406f, 2.83297682f}; llvm::Value* p = multiply_add(gq, gw); - b_->CreateStore(p, p_addr); + Store(p, p_addr); } SetToFirstInsertPoint(if_data.after_block, b_); - llvm::Value* p = b_->CreateLoad(p_addr); - return b_->CreateFMul(p, x); + llvm::Value* p = Load(p_addr); + return FMul(p, x); } -StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type, + llvm::Value* value) { // Compute erfcinv(value) by calculating erfinv(1.0 - value). 
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); - return EmitErfInv(prim_type, b_->CreateFSub(one, value)); + return EmitErfInv(prim_type, FSub(one, value)); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value}, {value->getType()}, b_); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { auto x = value; auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); auto negative_half = llvm::ConstantFP::get(type, -0.5); // When x is large, the naive evaluation of ln(x + 1) is more // accurate than the Taylor series. - TF_ASSIGN_OR_RETURN(auto for_large_x, - EmitLog(prim_type, b_->CreateFAdd(x, one))); + TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one))); // The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + …. 
- auto for_small_x = - b_->CreateFMul(b_->CreateFAdd(b_->CreateFMul(negative_half, x), one), x); + auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x); const auto kAntilogarithmIsSmallThreshold = 1e-4; auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); - auto x_is_small = b_->CreateFCmpOLT( + auto x_is_small = FCmpOLT( abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold)); - return b_->CreateSelect(x_is_small, for_small_x, for_large_x); + return Select(x_is_small, for_small_x, for_large_x); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, {value->getType()}, b_); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value}, {value->getType()}, b_); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value}, {value->getType()}, b_); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { auto x = value; auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); @@ -1035,40 +981,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, // When the exponent is large, the naive evaluation of e^(x) - 1 is more // accurate than the Taylor series. TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value)); - auto for_large_x = b_->CreateFSub(exp_x, one); + auto for_large_x = FSub(exp_x, one); // The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + …. // We want exp(x)-1 which is x + x^2/2 + x^3/6 + …. 
- auto x_squared = b_->CreateFAdd(x, x); - auto x_squared_over_two = b_->CreateFMul(x_squared, half); - auto for_small_x = b_->CreateFAdd(x, x_squared_over_two); + auto x_squared = FAdd(x, x); + auto x_squared_over_two = FMul(x_squared, half); + auto for_small_x = FAdd(x, x_squared_over_two); const auto kExponentIsSmallThreshold = 1e-5; auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_); - auto x_is_small = b_->CreateFCmpOLT( - abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); - return b_->CreateSelect(x_is_small, for_small_x, for_large_x); + auto x_is_small = + FCmpOLT(abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold)); + return Select(x_is_small, for_small_x, for_large_x); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs}, {lhs->getType()}, b_); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return Unimplemented("atan2"); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const { + llvm::Value* value) { return Unimplemented("tanh"); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision( - const HloInstruction* hlo, llvm::Value* x) const { + const HloInstruction* hlo, llvm::Value* x) { if (hlo->operand(0)->shape().element_type() != F32) { return Unimplemented("reduce-precision only implemented for F32"); } @@ -1099,23 +1045,103 @@ static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b, return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value); } +llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) { + return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1); +} + +llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) { + return 
llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0); +} + +llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) { + auto* integer_type = llvm::cast<llvm::IntegerType>(type); + return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue( + integer_type->getBitWidth())); +} + +llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) { + auto* integer_type = llvm::cast<llvm::IntegerType>(type); + return llvm::ConstantInt::get( + integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth())); +} + +llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) { + return ICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0)); +} + +llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(llvm::Value* lhs, + llvm::Value* rhs) { + return And(ICmpEQ(lhs, GetIntSMin(lhs->getType())), + ICmpEQ(rhs, GetMinusOne(rhs->getType()))); +} + +llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs, + llvm::Value* rhs, + bool is_signed) { + // Integer division overflow behavior: + // + // X / 0 == -1 + // INT_SMIN /s -1 = INT_SMIN + + if (!is_signed) { + llvm::Value* udiv_is_unsafe = IsZero(rhs); + llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_div = UDiv(lhs, safe_rhs); + return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div); + } + + llvm::Value* has_zero_divisor = IsZero(rhs); + llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs); + llvm::Value* sdiv_is_unsafe = Or(has_int_min_overflow, has_zero_divisor); + llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_div = SDiv(lhs, safe_rhs); + + return Select( + has_zero_divisor, GetMinusOne(lhs->getType()), + Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div)); +} + +llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs, + llvm::Value* rhs, + bool is_signed) { + // Integer remainder overflow behavior: + // + // X % 0 == X + 
// INT_SMIN %s -1 = 0 + + if (!is_signed) { + llvm::Value* urem_is_unsafe = IsZero(rhs); + llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_rem = URem(lhs, safe_rhs); + return Select(urem_is_unsafe, lhs, safe_rem); + } + + llvm::Value* has_zero_divisor = IsZero(rhs); + llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs); + llvm::Value* srem_is_unsafe = Or(has_int_min_overflow, has_zero_divisor); + llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_rem = SRem(lhs, safe_rhs); + + return Select( + has_zero_divisor, lhs, + Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem)); +} + StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { + bool is_signed) { switch (op->opcode()) { // TODO(jingyue): add the "nsw" attribute for signed types. case HloOpcode::kAdd: - return b_->CreateAdd(lhs_value, rhs_value); + return Add(lhs_value, rhs_value); case HloOpcode::kSubtract: - return b_->CreateSub(lhs_value, rhs_value); + return Sub(lhs_value, rhs_value); case HloOpcode::kMultiply: - return b_->CreateMul(lhs_value, rhs_value); + return Mul(lhs_value, rhs_value); case HloOpcode::kDivide: - return is_signed ? b_->CreateSDiv(lhs_value, rhs_value) - : b_->CreateUDiv(lhs_value, rhs_value); + return EmitIntegerDivide(lhs_value, rhs_value, is_signed); case HloOpcode::kRemainder: - return is_signed ? 
b_->CreateSRem(lhs_value, rhs_value) - : b_->CreateURem(lhs_value, rhs_value); + return EmitIntegerRemainder(lhs_value, rhs_value, is_signed); case HloOpcode::kEq: return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, rhs_value, b_); @@ -1143,11 +1169,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp( case HloOpcode::kMaximum: return EmitIntegralMax(lhs_value, rhs_value, is_signed); case HloOpcode::kAnd: - return b_->CreateAnd(lhs_value, rhs_value); + return And(lhs_value, rhs_value); case HloOpcode::kOr: - return b_->CreateOr(lhs_value, rhs_value); + return Or(lhs_value, rhs_value); case HloOpcode::kXor: - return b_->CreateXor(lhs_value, rhs_value); + return Xor(lhs_value, rhs_value); // Shifting out bits >= the number of bits in the type being shifted // produces a poison value in LLVM which is basically "deferred undefined @@ -1156,43 +1182,43 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp( // UB. case HloOpcode::kShiftRightArithmetic: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateAShr(lhs_value, rhs_value), + AShr(lhs_value, rhs_value), /*saturate_to_sign_bit=*/true); case HloOpcode::kShiftLeft: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateShl(lhs_value, rhs_value), + Shl(lhs_value, rhs_value), /*saturate_to_sign_bit=*/false); case HloOpcode::kShiftRightLogical: return SaturateShiftIfNecessary(b_, lhs_value, rhs_value, - b_->CreateLShr(lhs_value, rhs_value), + LShr(lhs_value, rhs_value), /*saturate_to_sign_bit=*/false); default: return Unimplemented("binary integer op '%s'", - HloOpcodeString(op->opcode()).c_str()); + HloOpcodeString(op->opcode())); } } llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { - return b_->CreateSelect(b_->CreateICmp(is_signed ? 
llvm::ICmpInst::ICMP_SGE - : llvm::ICmpInst::ICMP_UGE, - lhs_value, rhs_value), - lhs_value, rhs_value); + bool is_signed) { + return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE + : llvm::ICmpInst::ICMP_UGE, + lhs_value, rhs_value), + lhs_value, rhs_value); } llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const { - return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE - : llvm::ICmpInst::ICMP_ULE, - lhs_value, rhs_value), - lhs_value, rhs_value); + bool is_signed) { + return Select(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE + : llvm::ICmpInst::ICMP_ULE, + lhs_value, rhs_value), + lhs_value, rhs_value); } llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no) const { + int64 operand_no) { CHECK(hlo.IsElementwise()) << "HLO " << hlo.ToString() << " is not elementwise."; @@ -1233,7 +1259,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const { + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) { TF_ASSIGN_OR_RETURN(llvm::Value * a_or_mean, operand_to_generator.at(hlo->operand(0))(index)); TF_ASSIGN_OR_RETURN(llvm::Value * b_or_sigma, @@ -1251,17 +1277,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution( // Perform the division using the float type with the same number of bits // as the raw value to avoid overflow. 
if (raw_value_size_in_bits == 32) { - elem_value = b_->CreateUIToFP(elem_value, b_->getFloatTy()); - elem_value = b_->CreateFDiv( - elem_value, llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); + elem_value = UIToFP(elem_value, b_->getFloatTy()); + elem_value = FDiv(elem_value, + llvm::ConstantFP::get(b_->getFloatTy(), std::exp2(32))); } else { - elem_value = b_->CreateUIToFP(elem_value, b_->getDoubleTy()); - elem_value = b_->CreateFDiv( + elem_value = UIToFP(elem_value, b_->getDoubleTy()); + elem_value = FDiv( elem_value, llvm::ConstantFP::get(b_->getDoubleTy(), std::exp2(64))); } if (elem_ir_ty != elem_value->getType()) { - elem_value = b_->CreateFPTrunc(elem_value, elem_ir_ty); + elem_value = FPTrunc(elem_value, elem_ir_ty); } } @@ -1269,9 +1295,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution( switch (hlo->random_distribution()) { case RNG_UNIFORM: { if (elem_ir_ty->isFloatingPointTy()) { - return b_->CreateFAdd( - b_->CreateFMul(b_->CreateFSub(b_or_sigma, a_or_mean), elem_value), - a_or_mean); + return FAdd(FMul(FSub(b_or_sigma, a_or_mean), elem_value), a_or_mean); } else { // To generate a uniform random value in [a, b) from a raw random sample // in range [0, 2^N), we let range = b - a and return @@ -1284,22 +1308,21 @@ StatusOr<llvm::Value*> ElementalIrEmitter::ConvertValueForDistribution( // the same cost as if the whole warp were to re-sample. So an // efficient re-sampling implementation on GPU would need to do // nontrivial work to share entropy between threads in the warp. 
- auto range = b_->CreateSub(b_or_sigma, a_or_mean); - return b_->CreateAdd(a_or_mean, b_->CreateURem(elem_value, range)); + auto range = Sub(b_or_sigma, a_or_mean); + return Add(a_or_mean, URem(elem_value, range)); } } case RNG_NORMAL: { TF_ASSIGN_OR_RETURN( llvm::Value * r, - EmitErfcInv(elem_prim_ty, - b_->CreateFMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), - elem_value))); - return b_->CreateFAdd(b_->CreateFMul(r, b_or_sigma), a_or_mean); + EmitErfcInv(elem_prim_ty, FMul(llvm::ConstantFP::get(elem_ir_ty, 2.0), + elem_value))); + return FAdd(FMul(r, b_or_sigma), a_or_mean); } default: return InvalidArgument( "unhandled distribution %s", - RandomDistribution_Name(hlo->random_distribution()).c_str()); + RandomDistribution_Name(hlo->random_distribution())); } } @@ -1414,8 +1437,7 @@ std::array<llvm::Value*, 4> CalculateSampleValues( // Precondition: the RNG instruction is not fused. llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( const HloInstruction* hlo, - const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) - const { + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) { VLOG(3) << "Using philox RNG algorithm"; CHECK(!hlo->IsFused()); // A random number generated by the per module random number generator. @@ -1438,7 +1460,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // Load the global state variable for the Philox RNG algorithm. llvm::GlobalVariable* rng_state_ptr = llvm_ir::GetOrCreateVariableForPhiloxRngState(module_, b_); - llvm::Value* rng_state = b_->CreateLoad(rng_state_ptr, "rng_state_value"); + llvm::Value* rng_state = Load(rng_state_ptr, "rng_state_value"); // Build and return the elemental IR generator to generate a random value for // the element corresponding to the current thread. @@ -1464,8 +1486,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // element within the sample. 
llvm::Value* elems_per_sample_value = llvm::ConstantInt::get(index_ty, elems_per_sample); - llvm::Value* sample_idx = b_->CreateUDiv(elem_idx, elems_per_sample_value); - llvm::Value* elem_offset = b_->CreateURem(elem_idx, elems_per_sample_value); + llvm::Value* sample_idx = UDiv(elem_idx, elems_per_sample_value); + llvm::Value* elem_offset = URem(elem_idx, elems_per_sample_value); std::array<llvm::Value*, 4> counter_values = CalculateSampleValues( sample_idx, hlo_random_value, global_random_number, rng_state, b_); @@ -1473,18 +1495,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( // Store the four counter_values into the sample_address alloca so we can // load the elem_offset'th one below. for (int idx = 0; idx < 4; ++idx) { - b_->CreateStore(counter_values[idx], - b_->CreateInBoundsGEP(sample_address, b_->getInt32(idx))); + Store(counter_values[idx], + InBoundsGEP(sample_address, b_->getInt32(idx))); } llvm::Type* int64_ty = b_->getInt64Ty(); CHECK(elems_per_sample == 2 || elems_per_sample == 4); llvm::Type* raw_value_ty = elems_per_sample == 2 ? int64_ty : int32_ty; // Retrieve the raw value for the current element from the current sample. 
- llvm::Value* raw_elem_value = b_->CreateLoad( - b_->CreateInBoundsGEP( - b_->CreatePointerCast(sample_address, raw_value_ty->getPointerTo()), - elem_offset), + llvm::Value* raw_elem_value = Load( + InBoundsGEP(PointerCast(sample_address, raw_value_ty->getPointerTo()), + elem_offset), "raw_elem_value"); return ConvertValueForDistribution(hlo, operand_to_generator, index, @@ -1495,7 +1516,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakePhiloxRngElementGenerator( StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * pred_value, operand_to_generator.at(hlo->operand(0))( ElementwiseSourceIndex(index, *hlo, 0))); @@ -1505,14 +1526,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect( TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value, operand_to_generator.at(hlo->operand(2))( ElementwiseSourceIndex(index, *hlo, 2))); - return b_->CreateSelect(b_->CreateTrunc(pred_value, b_->getInt1Ty()), - on_true_value, on_false_value); + return Select(Trunc(pred_value, b_->getInt1Ty()), on_true_value, + on_false_value); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { TF_ASSIGN_OR_RETURN(llvm::Value * min_value, operand_to_generator.at(hlo->operand(0))( ElementwiseSourceIndex(index, *hlo, 0))); @@ -1531,14 +1552,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp( max_value, EmitIntegralMax(min_value, arg_value, is_signed), is_signed); } else { return Unimplemented("Clamp unimplemented for %s", - PrimitiveType_Name(prim_type).c_str()); + PrimitiveType_Name(prim_type)); } } StatusOr<llvm::Value*> 
ElementalIrEmitter::EmitElementalConcatenate( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& target_index) const { + const llvm_ir::IrArray::Index& target_index) { const int64 concat_dim = hlo->dimensions(0); auto source_index = target_index; @@ -1560,9 +1581,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate( } llvm_ir::SetToFirstInsertPoint(exit_block, b_); - llvm::PHINode* output = b_->CreatePHI( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), - hlo->operands().size()); + llvm::PHINode* output = + PHI(llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), + hlo->operands().size()); auto prior_insert_point = b_->GetInsertPoint(); b_->SetInsertPoint(init_block); @@ -1577,9 +1598,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate( auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), operand->shape().dimensions(concat_dim)); - b_->CreateCondBr( - b_->CreateICmpULT(source_index[concat_dim], concat_dim_size), - true_block, false_block); + CondBr(ICmpULT(source_index[concat_dim], concat_dim_size), true_block, + false_block); // Create the terminator of the true block before calling operand // generators, because they require non-degenerate basic blocks. @@ -1592,11 +1612,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate( // Subtract the size of the concat dimension of the current operand // from the source index. 
b_->SetInsertPoint(false_block); - source_index[concat_dim] = - b_->CreateSub(source_index[concat_dim], concat_dim_size); + source_index[concat_dim] = Sub(source_index[concat_dim], concat_dim_size); } - b_->CreateUnreachable(); + Unreachable(); b_->SetInsertPoint(exit_block, prior_insert_point); return output; } @@ -1604,7 +1623,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate( StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { // Emit IR to read dynamic start indices from hlo->operand(1). const HloInstruction* input_hlo = hlo->operand(0); const int64 rank = ShapeUtil::Rank(input_hlo->shape()); @@ -1621,7 +1640,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice( // Clamp the start index so that the sliced portion fits in the operand: // start_index = clamp(start_index, 0, operand_dim_size - output_dim_size) - start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); + start_index_value = SExtOrTrunc(start_index_value, index_type); int64 largest_valid_start_index = input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i); CHECK_GE(largest_valid_start_index, 0); @@ -1641,7 +1660,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice( for (int64 i = 0; i < rank; ++i) { // Emit IR which computes: // input_index = start_index + offset_index - input_index[i] = b_->CreateAdd(slice_start_index[i], index[i]); + input_index[i] = Add(slice_start_index[i], index[i]); } return operand_to_generator.at(input_hlo)(input_index); } @@ -1649,7 +1668,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice( StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const 
llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { const Shape& operand_shape = hlo->operand(0)->shape(); const Shape& indices_shape = hlo->operand(1)->shape(); const Shape& output_shape = hlo->shape(); @@ -1672,7 +1691,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather( std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1); for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0; i < e; i++) { - if (c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { + if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { operand_index.push_back(index.GetConstantWithIndexType(0)); } else { int64 output_window_dim = dim_numbers.offset_dims(operand_index_dim++); @@ -1686,7 +1705,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather( { std::vector<llvm::Value*> gather_index_index_components; for (int64 i = 0, e = output_shape.dimensions_size(); i < e; i++) { - if (!c_binary_search(dim_numbers.offset_dims(), i)) { + if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { gather_index_index.push_back(index[i]); } } @@ -1698,7 +1717,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather( auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) { llvm::Value* gather_dim_component_extended = - b_->CreateSExtOrTrunc(index_component, index_type); + SExtOrTrunc(index_component, index_type); int64 operand_dim = dim_numbers.start_index_map(dim); int64 output_dim = operand_to_output_dim[operand_dim]; // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim. 
@@ -1722,8 +1741,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather( gather_dim_component_extended, is_signed), is_signed); - operand_index[operand_dim] = b_->CreateAdd( - operand_index[operand_dim], gather_dim_component_extended_inbound); + operand_index[operand_dim] = + Add(operand_index[operand_dim], gather_dim_component_extended_inbound); }; if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) { @@ -1747,7 +1766,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather( StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const { + const llvm_ir::IrArray::Index& index) { const HloInstruction* input_hlo = hlo->operand(0); const HloInstruction* update_hlo = hlo->operand(1); const HloInstruction* start_hlo = hlo->operand(2); @@ -1770,7 +1789,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Clamp the start index so that the update region fits in the operand. 
// start_index = clamp(start_index, 0, input_dim_size - update_dim_size) - start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type); + start_index_value = SExtOrTrunc(start_index_value, index_type); llvm::Value* update_dim_size = index_typed_const(update_hlo->shape().dimensions(i)); int64 largest_valid_start_index = @@ -1786,14 +1805,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice( start_index_value->setName( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = start_index_value; - slice_limit_index[i] = b_->CreateAdd(slice_start_index[i], update_dim_size); - - slice_intersection = b_->CreateAnd( - slice_intersection, b_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection"); - slice_intersection = b_->CreateAnd( - slice_intersection, b_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection"); + slice_limit_index[i] = Add(slice_start_index[i], update_dim_size); + + slice_intersection = + And(slice_intersection, ICmpSGE(index[i], slice_start_index[i]), + "slice_intersection"); + slice_intersection = + And(slice_intersection, ICmpSLT(index[i], slice_limit_index[i]), + "slice_intersection"); } // Emit: @@ -1810,26 +1829,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice( // Compute update index for intersection case. 
llvm_ir::IrArray::Index update_index(index.GetType(), rank); for (int64 i = 0; i < rank; ++i) { - update_index[i] = b_->CreateSub(index[i], slice_start_index[i]); + update_index[i] = Sub(index[i], slice_start_index[i]); } TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); - b_->CreateStore(true_value, ret_value_addr); + Store(true_value, ret_value_addr); // Handle false BB (return data from 'input') SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * false_value, operand_to_generator.at(input_hlo)(index)); - b_->CreateStore(false_value, ret_value_addr); + Store(false_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, b_); - return b_->CreateLoad(ret_value_addr); + return Load(ret_value_addr); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& padded_index) const { + const llvm_ir::IrArray::Index& padded_index) { auto index = padded_index; llvm::Value* in_bounds = b_->getTrue(); for (size_t i = 0; i < index.size(); ++i) { @@ -1837,26 +1856,22 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad( return llvm::ConstantInt::get(index[i]->getType(), n); }; const auto& pad_dim = hlo->padding_config().dimensions(i); - index[i] = - b_->CreateSub(index[i], index_typed_const(pad_dim.edge_padding_low())); - in_bounds = b_->CreateAnd(in_bounds, - b_->CreateICmpSGE(index[i], index_typed_const(0)), - "in_bounds"); - in_bounds = b_->CreateAnd( + index[i] = Sub(index[i], index_typed_const(pad_dim.edge_padding_low())); + in_bounds = + And(in_bounds, ICmpSGE(index[i], index_typed_const(0)), "in_bounds"); + in_bounds = And( in_bounds, - b_->CreateICmpEQ( + ICmpEQ( index_typed_const(0), - b_->CreateURem(index[i], - index_typed_const(pad_dim.interior_padding() + 1))), - "in_bounds"); - index[i] = b_->CreateSDiv( - index[i], 
index_typed_const(pad_dim.interior_padding() + 1)); - in_bounds = b_->CreateAnd( - in_bounds, - b_->CreateICmpSLT( - index[i], - index_typed_const(hlo->operand(0)->shape().dimensions(i))), + URem(index[i], index_typed_const(pad_dim.interior_padding() + 1))), "in_bounds"); + index[i] = + SDiv(index[i], index_typed_const(pad_dim.interior_padding() + 1)); + in_bounds = + And(in_bounds, + ICmpSLT(index[i], + index_typed_const(hlo->operand(0)->shape().dimensions(i))), + "in_bounds"); } // if (in_bounds) { @@ -1872,26 +1887,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad( SetToFirstInsertPoint(if_data.true_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, operand_to_generator.at(hlo->operand(0))(index)); - b_->CreateStore(operand_value, ret_value_addr); + Store(operand_value, ret_value_addr); SetToFirstInsertPoint(if_data.false_block, b_); TF_ASSIGN_OR_RETURN(llvm::Value * padding_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index.GetType()))); - b_->CreateStore(padding_value, ret_value_addr); + Store(padding_value, ret_value_addr); SetToFirstInsertPoint(if_data.after_block, b_); // Don't create phi(operand_value, padding_value) here, because invoking // operand_to_generator may create new basic blocks, making the parent // of operand_value or padding_value no longer a predecessor of // if_data.after_block. 
- return b_->CreateLoad(ret_value_addr); + return Load(ret_value_addr); } StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& dot_result_index) const { + const llvm_ir::IrArray::Index& dot_result_index) { auto lhs_generator = operand_to_generator.at(hlo->operand(0)); auto rhs_generator = operand_to_generator.at(hlo->operand(1)); @@ -1919,8 +1934,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot( llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_); - b_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm), - accumulator_alloca); + Store(llvm::Constant::getNullValue(primitive_type_llvm), accumulator_alloca); SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_); @@ -1942,42 +1956,37 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot( } rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue()); - llvm::Value* current_accumulator = b_->CreateLoad(accumulator_alloca); + llvm::Value* current_accumulator = Load(accumulator_alloca); TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); llvm::Value* next_accumulator; if (primitive_util::IsComplexType(primitive_type)) { - llvm::Value* product_real = b_->CreateFSub( - b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); - llvm::Value* product_imag = b_->CreateFAdd( - b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), - b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))); - next_accumulator = b_->CreateInsertValue( + llvm::Value* product_real = + FSub(FMul(EmitExtractReal(lhs_value), 
EmitExtractReal(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value))); + llvm::Value* product_imag = + FAdd(FMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)), + FMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value))); + next_accumulator = InsertValue( current_accumulator, - b_->CreateFAdd(EmitExtractReal(current_accumulator), product_real), - {0}); - next_accumulator = b_->CreateInsertValue( + FAdd(EmitExtractReal(current_accumulator), product_real), {0}); + next_accumulator = InsertValue( next_accumulator, - b_->CreateFAdd(EmitExtractImag(current_accumulator), product_imag), - {1}); + FAdd(EmitExtractImag(current_accumulator), product_imag), {1}); } else if (primitive_util::IsFloatingPointType(primitive_type)) { - next_accumulator = b_->CreateFAdd(current_accumulator, - b_->CreateFMul(lhs_value, rhs_value)); + next_accumulator = FAdd(current_accumulator, FMul(lhs_value, rhs_value)); } else { - next_accumulator = - b_->CreateAdd(current_accumulator, b_->CreateMul(lhs_value, rhs_value)); + next_accumulator = Add(current_accumulator, Mul(lhs_value, rhs_value)); } - b_->CreateStore(next_accumulator, accumulator_alloca); + Store(next_accumulator, accumulator_alloca); SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_); - return b_->CreateLoad(accumulator_alloca); + return Load(accumulator_alloca); } llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) - const { + const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: @@ -2071,10 +2080,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); auto source_index = target_index; for (int64 dim : hlo->dimensions()) { - source_index[dim] = b_->CreateSub( - 
llvm::ConstantInt::get(target_index[dim]->getType(), - hlo->shape().dimensions(dim) - 1), - target_index[dim]); + source_index[dim] = + Sub(llvm::ConstantInt::get(target_index[dim]->getType(), + hlo->shape().dimensions(dim) - 1), + target_index[dim]); } return operand_to_generator.at(operand)(source_index); }; @@ -2088,6 +2097,50 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(), hlo->dimensions(), b_)); }; + case HloOpcode::kIota: + return [this, hlo]( + const IrArray::Index& target_index) -> StatusOr<llvm::Value*> { + auto* iota = Cast<HloIotaInstruction>(hlo); + PrimitiveType element_type = iota->shape().element_type(); + IrArray::Index elem_index = + ShapeUtil::Rank(iota->shape()) > 1 + ? target_index.SourceIndexOfBroadcast( + iota->shape(), + ShapeUtil::MakeShapeWithDescendingLayout( + element_type, + {iota->shape().dimensions(iota->iota_dimension())}), + {iota->iota_dimension()}, b_) + : target_index; + llvm::Value* elem_index_linear = elem_index.linear(); + if (elem_index_linear == nullptr) { + std::vector<int64> iota_bound = { + iota->shape().dimensions(iota->iota_dimension())}; + elem_index_linear = elem_index.Linearize(iota_bound, b_); + } + if (ShapeUtil::ElementIsIntegral(iota->shape())) { + return b_->CreateIntCast( + elem_index_linear, + llvm_ir::PrimitiveTypeToIrType(element_type, module_), + /*isSigned=*/false); + } else { + TF_RET_CHECK(ShapeUtil::ElementIsFloating(iota->shape())) + << element_type; + llvm::Type* float_ir_type; + if (element_type == BF16) { + float_ir_type = llvm_ir::PrimitiveTypeToIrType(F32, module_); + } else { + float_ir_type = + llvm_ir::PrimitiveTypeToIrType(element_type, module_); + } + llvm::Value* float_val = + b_->CreateUIToFP(elem_index_linear, float_ir_type); + if (element_type == BF16) { + return EmitF32ToBF16(float_val, b_); + } else { + return float_val; + } + } + }; case HloOpcode::kSlice: return [this, hlo, 
&operand_to_generator]( const IrArray::Index& index) -> StatusOr<llvm::Value*> { @@ -2153,28 +2206,28 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( default: return [hlo](const IrArray::Index& index) { return Unimplemented("Unhandled opcode for elemental IR emission: %s", - HloOpcodeString(hlo->opcode()).c_str()); + HloOpcodeString(hlo->opcode())); }; } } -llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const { - return b_->CreateExtractValue(value, {0}); +llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) { + return ExtractValue(value, {0}); } -llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const { - return b_->CreateExtractValue(value, {1}); +llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) { + return ExtractValue(value, {1}); } llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const { + llvm::Value* imag) { auto cplx_type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); - auto complex = b_->CreateInsertValue( - llvm::ConstantAggregateZero::get(cplx_type), real, {0}); + auto complex = + InsertValue(llvm::ConstantAggregateZero::get(cplx_type), real, {0}); if (imag != nullptr) { - complex = b_->CreateInsertValue(complex, imag, {1}); + complex = InsertValue(complex, imag, {1}); } return complex; } diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 1598a4dd85..d3e2acaabd 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -23,12 +23,13 @@ limitations under the License. 
#include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { -class ElementalIrEmitter { +class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> { public: using HloToElementGeneratorMap = std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>; @@ -40,100 +41,114 @@ class ElementalIrEmitter { virtual ~ElementalIrEmitter() = default; virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op, - llvm::Value* operand_value) const; + llvm::Value* operand_value); virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); // Returns a function to generate an element of the output of `hlo`, given a // map of functions to generate elements of its operands. virtual llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const; + const HloToElementGeneratorMap& operand_to_generator); - llvm::IRBuilder<>* b() const { return b_; } - llvm::Module* module() const { return module_; } + llvm::IRBuilder<>* b() { return b_; } + + // builder() is for IrBuilderMixin. 
+ llvm::IRBuilder<>* builder() { return b_; } + + llvm::Module* module() { return module_; } protected: - virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); + + virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); - virtual StatusOr<llvm::Value*> EmitFloatUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(const HloInstruction* op, + llvm::Value* operand_value); - virtual StatusOr<llvm::Value*> EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const; + llvm::Value* IsZero(llvm::Value* v); + llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* GetZero(llvm::Type* type); + llvm::Value* GetOne(llvm::Type* type); + llvm::Value* GetIntSMin(llvm::Type* type); + llvm::Value* GetMinusOne(llvm::Type* type); + + llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs, + bool is_signed); + llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs, + bool is_signed); virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); - virtual StatusOr<llvm::Value*> EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); - virtual StatusOr<llvm::Value*> EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + virtual StatusOr<llvm::Value*> EmitComplexBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value); virtual llvm::Value* 
EmitFloatMax(llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - bool is_signed) const; + bool is_signed); virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, - llvm::Value* lhs, - llvm::Value* rhs) const; + llvm::Value* lhs, llvm::Value* rhs); virtual StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, - llvm::Value* lhs, - llvm::Value* rhs) const; + llvm::Value* lhs, llvm::Value* rhs); virtual StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const; + llvm::Value* value); virtual StatusOr<llvm::Value*> EmitReducePrecision(const HloInstruction* hlo, - llvm::Value* x) const; + llvm::Value* x); - virtual llvm::Value* EmitExtractReal(llvm::Value* value) 
const; - virtual llvm::Value* EmitExtractImag(llvm::Value* value) const; + virtual llvm::Value* EmitExtractReal(llvm::Value* value); + virtual llvm::Value* EmitExtractImag(llvm::Value* value); // Composes a complex struct. imag may be nullptr for simple cast operations. llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const; + llvm::Value* imag); // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and // the target array index, computes the source array index of its @@ -142,50 +157,50 @@ class ElementalIrEmitter { // Precondition: `hlo` is an elementwise op. llvm_ir::IrArray::Index ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, - int64 operand_no) const; + int64 operand_no); // Identifier of the thread unique among all threads on the device - virtual llvm::Value* EmitThreadId() const { return b_->getIntN(128, 0); } + virtual llvm::Value* EmitThreadId() { return b_->getIntN(128, 0); } StatusOr<llvm::Value*> EmitElementalSelect( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr<llvm::Value*> EmitElementalClamp( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr<llvm::Value*> EmitElementalConcatenate( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& target_index) const; + const llvm_ir::IrArray::Index& target_index); StatusOr<llvm::Value*> EmitElementalDynamicSlice( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr<llvm::Value*> EmitElementalGather( const HloInstruction* hlo, const HloToElementGeneratorMap& 
operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr<llvm::Value*> EmitElementalDynamicUpdateSlice( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index) const; + const llvm_ir::IrArray::Index& index); StatusOr<llvm::Value*> EmitElementalPad( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& padded_index) const; + const llvm_ir::IrArray::Index& padded_index); StatusOr<llvm::Value*> EmitElementalDot( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& dot_result_index) const; + const llvm_ir::IrArray::Index& dot_result_index); llvm::IRBuilder<>* const b_; @@ -200,13 +215,13 @@ class ElementalIrEmitter { // random number generation algorithm. llvm_ir::ElementGenerator MakePhiloxRngElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const; + const HloToElementGeneratorMap& operand_to_generator); // Converts the raw value generated by a random number generation algorithm // to the distribution requested by the RNG HloInstruction. StatusOr<llvm::Value*> ConvertValueForDistribution( const HloInstruction* hlo, const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator, - const llvm_ir::IrArray::Index& index, llvm::Value* raw_value) const; + const llvm_ir::IrArray::Index& index, llvm::Value* raw_value); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index addb016b04..5ab0756219 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -24,7 +24,7 @@ limitations under the License. 
namespace xla { namespace { -using tensorflow::gtl::nullopt; +using absl::nullopt; class ElementalIrEmitterExecutionTest : public HloTestBase { protected: diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index fd75847d0c..78edf918a4 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/executable.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/status.h" @@ -22,7 +24,6 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/proto_serialization.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" using tensorflow::gtl::ArraySlice; @@ -76,8 +77,8 @@ StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper( std::unique_ptr<HloExecutionProfile> profile_ptr = module_config().debug_options().xla_hlo_profile() && hlo_profiling_enabled() - ? MakeUnique<HloExecutionProfile>(&hlo_profile_printer_data(), - &hlo_profile_index_map()) + ? 
absl::make_unique<HloExecutionProfile>(&hlo_profile_printer_data(), + &hlo_profile_index_map()) : nullptr; StatusOr<ScopedShapedBuffer> return_value = @@ -154,9 +155,9 @@ Status Executable::DumpHloSnapshot() { const string& directory_path = module_config().debug_options().xla_dump_executions_to(); const auto& module = hlo_snapshot_->hlo().hlo_module(); - string filename = tensorflow::strings::Printf( - "computation_%lld__%s__execution_%lld", module.id(), - module.entry_computation_name().c_str(), ++execution_count_); + string filename = + absl::StrFormat("computation_%d__%s__execution_%d", module.id(), + module.entry_computation_name(), ++execution_count_); return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); } diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index 228c3fac95..997db7c058 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include <utility> -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -53,8 +53,8 @@ ExecutionHandle ExecutionTracker::Register(Backend* backend, tensorflow::mutex_lock lock(execution_mutex_); int64 handle = next_handle_++; auto inserted = handle_to_execution_.emplace( - handle, - MakeUnique<AsyncExecution>(backend, std::move(streams), profile, result)); + handle, absl::make_unique<AsyncExecution>(backend, std::move(streams), + profile, result)); CHECK(inserted.second); ExecutionHandle execution_handle; @@ -66,7 +66,7 @@ Status ExecutionTracker::Unregister(const ExecutionHandle& handle) { tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { - return NotFound("no execution record for execution handle: %lld", + return NotFound("no execution record for execution handle: %d", handle.handle()); } handle_to_execution_.erase(handle.handle()); @@ -78,7 +78,7 @@ StatusOr<const AsyncExecution*> ExecutionTracker::Resolve( tensorflow::mutex_lock lock(execution_mutex_); auto it = handle_to_execution_.find(handle.handle()); if (it == handle_to_execution_.end()) { - return NotFound("no execution record for execution handle: %lld", + return NotFound("no execution record for execution handle: %d", handle.handle()); } return it->second.get(); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h index d3efab3614..3cccec9862 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.h +++ b/tensorflow/compiler/xla/service/flatten_call_graph.h @@ -28,7 +28,7 @@ namespace xla { // points-to analysis (see b/36865746 for details). 
class FlattenCallGraph : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "flatten-call-graph"; } + absl::string_view name() const override { return "flatten-call-graph"; } // Duplicates computations called from multiple call- or while-nodes to // flatten the call graph. diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc index 9370c88710..3f1a881372 100644 --- a/tensorflow/compiler/xla/service/gather_expander.cc +++ b/tensorflow/compiler/xla/service/gather_expander.cc @@ -15,6 +15,7 @@ limitations under the License. #include <utility> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gather_expander.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -230,7 +231,7 @@ static StatusOr<HloInstruction*> CreateGatherLoopAccumulatorInitValue( accumulator_state_shape_dims.reserve(1 + slice_sizes.size()); accumulator_state_shape_dims.push_back(gather_loop_trip_count); for (int64 i = 0; i < slice_sizes.size(); i++) { - if (!c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { + if (!absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { accumulator_state_shape_dims.push_back(slice_sizes[i]); } } @@ -251,7 +252,7 @@ static StatusOr<HloInstruction*> PermuteBatchAndOffsetDims( int64 batch_idx_counter = 0; int64 offset_idx_counter = output_rank - offset_dims.size(); for (int64 i = 0; i < output_rank; i++) { - bool is_offset_dim = c_binary_search(offset_dims, i); + bool is_offset_dim = absl::c_binary_search(offset_dims, i); if (is_offset_dim) { permutation.push_back(offset_idx_counter++); } else { @@ -322,7 +323,7 @@ StatusOr<HloInstruction*> GatherExpander::ExpandGather( return Unimplemented( "Gather operations with more than 2147483647 gather indices are not " "supported. 
This error occurred for %s.", - gather_instr->ToString().c_str()); + gather_instr->ToString()); } TF_ASSIGN_OR_RETURN( @@ -373,8 +374,8 @@ StatusOr<bool> GatherExpander::Run(HloModule* module) { std::vector<HloInstruction*> gather_instrs; for (HloComputation* computation : module->MakeNonfusionComputations()) { - c_copy_if(computation->instructions(), std::back_inserter(gather_instrs), - is_nontrivial_gather); + absl::c_copy_if(computation->instructions(), + std::back_inserter(gather_instrs), is_nontrivial_gather); } for (HloInstruction* inst : gather_instrs) { diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index c1fc8574da..7bd9ea5984 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -25,7 +25,7 @@ namespace xla { // nevertheless have a minimum level of support. class GatherExpander : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "gather_expander"; } + absl::string_view name() const override { return "gather_expander"; } StatusOr<bool> Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 8ef72850dc..82290bfea8 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -56,6 +56,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -91,6 +93,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -107,6 +110,8 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + 
"@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -126,6 +131,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -171,6 +177,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:dynamic_update_slice_util", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:ir_builder_mixin", "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:kernel_tiling", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", @@ -180,6 +187,11 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@llvm//:core", "@llvm//:support", ], @@ -224,6 +236,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:math_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:support", ], @@ -243,6 +256,7 @@ cc_library( "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -257,6 +271,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -337,6 +352,10 @@ cc_library( "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep 
"//tensorflow/stream_executor", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", ], ) @@ -373,6 +392,9 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", ], ) @@ -390,6 +412,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) @@ -420,7 +443,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:shape_inference", - "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], @@ -466,6 +489,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:multi_output_fusion", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", ], ) @@ -483,6 +507,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -513,6 +538,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) @@ -544,6 +571,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:shape_inference", + "@com_google_absl//absl/memory", ], ) @@ -600,6 +628,7 @@ cc_library( 
"//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", "@llvm//:core", ], alwayslink = True, # Contains per-platform transfer manager registration @@ -670,6 +699,9 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@llvm//:core", ], alwayslink = True, # Contains compiler registration @@ -702,8 +734,8 @@ cc_library( ":xfeed_queue", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) @@ -718,6 +750,7 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -756,6 +789,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/strings", ], ) @@ -767,12 +801,12 @@ cc_library( ":stream_assignment", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:buffer_value", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/compiler/xla/service:hlo_scheduling", + "@com_google_absl//absl/memory", ], ) @@ -789,6 +823,8 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], ) @@ -839,7 +875,9 @@ 
cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -868,9 +906,8 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service:hlo_runner", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 537295292b..528209abc7 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -17,8 +17,8 @@ limitations under the License. #include <utility> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -40,7 +40,7 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build( const BufferAssignment* buffer_assignment, int device_ordinal, DeviceMemoryAllocator* memory_allocator) { const int64 num_buffers = buffer_assignment->Allocations().size(); - auto buffer_allocations = WrapUnique(new BufferAllocations( + auto buffer_allocations = absl::WrapUnique(new BufferAllocations( num_buffers, device_ordinal, memory_allocator, buffer_assignment)); for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { @@ -62,7 +62,7 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build( if (reinterpret_cast<uintptr_t>(address.opaque()) % expected_alignment != 0) { 
return InternalError( - "Address of registered buffer %lld must be a multiple of %llx, but " + "Address of registered buffer %d must be a multiple of %x, but " "was %p", i, kEntryParameterAlignBytes, address.opaque()); } @@ -83,7 +83,7 @@ StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build( 0) { return InternalError( "Address returned by memory_allocator->Allocate must be a " - "multiple of %llx, but was %p", + "multiple of 0x%x, but was %p", kXlaAllocatedBufferAlignBytes, buffer.opaque()); } // We do manual memory management within BufferAllocations. Be sure not diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc index 6a285a6b98..13c83c9199 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include <cmath> +#include "absl/strings/str_replace.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace gpu { @@ -74,9 +74,8 @@ ENTRY MaxDifference { %error = f32[SIZE] divide(%sub_abs, %denominator) ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32 })"; - auto size_string = std::to_string(num_elements); - return tensorflow::str_util::StringReplace( - kF16CompHloText, "SIZE", {size_string.data(), size_string.size()}, true); + return absl::StrReplaceAll(kF16CompHloText, + {{"SIZE", absl::StrCat(num_elements)}}); } StatusOr<F16BufferComparator> F16BufferComparator::Create( @@ -125,7 +124,7 @@ StatusOr<F16BufferComparator> F16BufferComparator::Create( StatusOr<bool> F16BufferComparator::CompareEqualImpl( se::DeviceMemory<Eigen::half> test_buffer) { if (ref_buffer_.root_buffer().size() != test_buffer.size()) { - 
return InternalError("Mismatched buffer size: %lld vs %lld", + return InternalError("Mismatched buffer size: %d vs %d", ref_buffer_.root_buffer().size(), test_buffer.size()); } diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc index 5780e0af40..9ed523998b 100644 --- a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -59,7 +59,7 @@ Status ConditionalThunk::ExecuteOnStream( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to retrieve predicate value on stream %p: %s.", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } // Execute the true or the false computation depending on the value of the diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 7833a4077e..eea31f3de1 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -17,12 +17,11 @@ limitations under the License. 
#include <string> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index d76ca6698d..f7952787c1 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h index e09cde9abf..6e2e330edd 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h @@ -54,9 +54,7 @@ namespace gpu { // BatchNormRewriter. 
class CudnnBatchNormRewriter : public HloPassInterface { public: - tensorflow::StringPiece name() const override { - return "cudnn_batchnorm_rewriter"; - } + absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; } StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc index 7b172812c3..bc3c6f72f6 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -17,12 +17,11 @@ limitations under the License. #include <string> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index caeb89d78e..dbdf8e7a0e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -14,24 +14,25 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" namespace xla { namespace gpu { namespace { +using absl::optional; using se::DeviceMemoryBase; using se::dnn::AlgorithmConfig; using se::dnn::AlgorithmDesc; -using tensorflow::gtl::optional; class ScratchAllocator : public se::ScratchAllocator { public: @@ -59,8 +60,8 @@ StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes( if (byte_size > GetMemoryLimitInBytes(stream)) { return se::port::Status( se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size, GetMemoryLimitInBytes(stream))); } @@ -128,14 +129,14 @@ std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind, string AlgorithmToString(const AlgorithmDesc& algo) { if (algo.tensor_ops_enabled()) { - return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); + return absl::StrCat(algo.algo_id(), "+TC"); } - return tensorflow::strings::StrCat(algo.algo_id()); + return absl::StrCat(algo.algo_id()); } string NumBytesToString(int64 bytes) { - return tensorflow::strings::StrCat( - tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)"); + return 
absl::StrCat(tensorflow::strings::HumanReadableNumBytes(bytes), " (", + bytes, "B)"); } // Acquires a process-global lock on the device pointed to by the given @@ -361,7 +362,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( return InternalError( "All algorithms tried for convolution %s failed. Falling back to " "default algorithm.", - instr->ToString().c_str()); + instr->ToString()); } StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index 8b7749628a..f76d273e8c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -16,12 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -39,7 +39,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { Compiler* compiler) : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cudnn-convolution-algorithm-picker"; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 
905b5ee876..0b1ee2dc33 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -234,6 +234,23 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput( << "Backward input convolution should reverse all kernel dimensions."; return no_match_result; } + } else if (reverse_filter->IsConstant()) { + // If the filter is a constant, we're willing to pattern-match to a + // backwards-input conv, on the theory that + // + // a) reversing a constant is free, and + // b) even if the user specified this filter as reverse(constant), we would + // long ago have constant-folded away the reverse. + // + // If the constant has any other uses, reversing it isn't entirely free, + // since we'd now have two constants to keep in memory. But hopefully it's + // free enough. + // + // TODO(jlebar): Should we do this even if the filter is not a constant? + // Reversing a non-constant filter is probably cheaper than padding the + // input! + + // Nothing to do, just fall through. } else { // Possibly 1x1 filter. for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) { @@ -373,22 +390,25 @@ std::tuple<bool, Window, ConvolutionDimensionNumbers> MatchBackwardInput( } } - // Fuse the matched HLOs into a backward convolution instruction. - // - // If the reverse is omitted (for 1x1 filters) in the original pattern, we add - // it back in the fusion instruction so that later passes (such as - // PadInsertion) can handle such fusion instructions easily. + // OK, it's a match! Canonicalize the conv's filter so that it's a reverse. + // This simplifies things for our caller, and algebraic-simplifier will later + // remove any unnecessary reverses. if (reverse_filter->opcode() != HloOpcode::kReverse) { - reverse_filter = reverse_filter->parent()->AddInstruction( + // Create a double-reverse, which is a nop. 
+ HloComputation* c = conv->parent(); + reverse_filter = c->AddInstruction( + HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, + AsInt64Slice(kernel_spatial_dims))); + reverse_filter = c->AddInstruction( HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter, AsInt64Slice(kernel_spatial_dims))); TF_CHECK_OK(conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter)); } + dnums.set_kernel_input_feature_dimension( conv->convolution_dimension_numbers().kernel_output_feature_dimension()); dnums.set_kernel_output_feature_dimension( conv->convolution_dimension_numbers().kernel_input_feature_dimension()); - return std::make_tuple(true, new_window, dnums); } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h index 0c0578d888..fbe7e98494 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h @@ -26,7 +26,7 @@ namespace gpu { // backwards-input convolutions into CustomCall HLOs that call into cuDNN. class CudnnConvolutionRewriter : public HloPassInterface { public: - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "cudnn-convolution-rewriter"; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index 65588b6aaf..46c23db465 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -32,10 +32,13 @@ namespace gpu { namespace { namespace op = xla::testing::opcode_matchers; +using ::testing::_; -class CudnnConvolutionRewriterTest : public HloTestBase { +class CudnnConvolutionRewriterTest : public HloVerifiedTestBase { public: - CudnnConvolutionRewriterTest() { + CudnnConvolutionRewriterTest() + : HloVerifiedTestBase(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false) { for (int i = 0; i < 2; ++i) { WindowDimension* window_dim = default_conv_window_.add_dimensions(); window_dim->set_size(1); @@ -114,7 +117,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -142,7 +145,7 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -172,7 +175,7 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 
0)); @@ -202,7 +205,7 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -230,7 +233,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); @@ -280,7 +283,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( @@ -325,7 +328,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -357,7 +360,7 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -410,7 +413,7 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); 
HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -457,7 +460,7 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); @@ -510,7 +513,7 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); const HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); ASSERT_THAT(entry_computation->root_instruction(), op::GetTupleElement( op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); @@ -562,12 +565,38 @@ TEST_F(CudnnConvolutionRewriterTest, auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(RunPass(module.get())); + EXPECT_TRUE(RunPass(module)); EXPECT_THAT( entry_computation->root_instruction(), op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } +// Check that we will materialize a reversed version of a constant in order to +// pattern-match a backwards input convolution. 
+TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveConstantFilter) { + Array4D<float> constant_arr(4, 4, 2, 2); + constant_arr.FillIota(0); + string constant_str = + LiteralUtil::CreateR4FromArray4D(constant_arr)->ToString(); + ParseAndVerifyModule(absl::StrFormat(R"( + HloModule test + + ENTRY entry_computation { + param0 = f32[128,2,16,16]{3,2,1,0} parameter(0) + constant = f32[4,4,2,2]{3,2,1,0} constant(%s) + ROOT convolution = f32[128,2,32,32]{3,2,1,0} convolution(param0, constant), + window={size=4x4 pad=2_2x2_2 lhs_dilate=2x2}, + dim_labels=bf01_01oi->bf01, feature_group_count=1 + })", + constant_str)); + EXPECT_TRUE(RunPass(&module())); + EXPECT_THAT( + module().entry_computation()->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvBackwardInputCallTarget, _, + op::Reverse(op::Constant())), + 0)); +} + } // anonymous namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 7b0d9e53d6..07b96fbd3f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -56,7 +57,7 @@ class ScratchBufAllocator : public se::ScratchAllocator { "Can't allocate twice from a ScratchBufAllocator."); } if (byte_size > scratch_.size()) { - return se::port::InternalError(tensorflow::strings::StrCat( + return se::port::InternalError(absl::StrCat( "Can't allocate ", byte_size, " bytes from a ScratchBufAllocator of size ", scratch_.size())); } @@ -196,8 +197,8 @@ Status RunCudnnConvolution( if (!stream->ok()) { return InternalError( - "Unable to launch convolution with type %s and algorithm (%lld, %lld)", - CudnnConvKindToString(kind).c_str(), algorithm.algorithm().algo_id(), + "Unable to launch convolution with type %s and algorithm (%d, %d)", + CudnnConvKindToString(kind), algorithm.algorithm().algo_id(), algorithm.algorithm_no_scratch().algo_id()); } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 9b6de115ad..57a3a43a6f 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -23,6 +23,8 @@ limitations under the License. #include "tensorflow/core/platform/types.h" // IWYU pragma: no_include "llvm/IR/Attributes.gen.inc" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/APInt.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Instructions.h" @@ -43,16 +45,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace gpu { +using absl::StrAppend; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; -using tensorflow::strings::StrAppend; namespace { // Returns whether operand is a floating-point literal with the given value. @@ -77,7 +77,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall( const string& callee_name, tensorflow::gtl::ArraySlice<llvm::Value*> operands, tensorflow::gtl::ArraySlice<PrimitiveType> input_types, - PrimitiveType output_type) const { + PrimitiveType output_type) { // The libdevice math functions differentiate between "double" and "float" by // appending an 'f' to the function's name. libdevice doesn't have f16 math // functions, so we convert the operands to f32 before calling the function @@ -94,7 +94,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall( for (int64 i = 0; i < operands.size(); ++i) { if (input_types[i] == F16) { converted_operands[i] = - b_->CreateFPCast(converted_operands[i], b_->getFloatTy()); + FPCast(converted_operands[i], b_->getFloatTy()); converted_input_types[i] = F32; } } @@ -107,13 +107,13 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall( break; default: return Unimplemented("Bad type for libdevice math call: %s", - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(output_type)); } llvm::Value* result = EmitMathCall(munged_callee, converted_operands, converted_input_types, output_type) .ValueOrDie(); if (cast_result_to_fp16) { - result = b_->CreateFPCast(result, b_->getHalfTy()); + result = FPCast(result, b_->getHalfTy()); } return result; } @@ -122,7 +122,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( const string& callee_name, 
tensorflow::gtl::ArraySlice<llvm::Value*> operands, tensorflow::gtl::ArraySlice<PrimitiveType> input_types, - PrimitiveType output_type) const { + PrimitiveType output_type) { // llvm intrinsics differentiate between half/float/double functions via // the suffixes ".f16", ".f32" and ".f64". string munged_callee = callee_name; @@ -138,7 +138,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( break; default: return Unimplemented("Bad type for llvm intrinsic math call: %s", - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(output_type)); } return EmitMathCall(munged_callee, operands, input_types, output_type); } @@ -147,13 +147,13 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall( const string& callee_name, tensorflow::gtl::ArraySlice<llvm::Value*> operands, tensorflow::gtl::ArraySlice<PrimitiveType> input_types, - PrimitiveType output_type) const { + PrimitiveType output_type) { // Binary math functions transform are of type [T] -> T. for (PrimitiveType input_type : input_types) { if (output_type != input_type) { return Unimplemented("Input type ≠output type: %s ≠%s", - PrimitiveType_Name(input_type).c_str(), - PrimitiveType_Name(output_type).c_str()); + PrimitiveType_Name(input_type), + PrimitiveType_Name(output_type)); } } @@ -163,8 +163,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall( } StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); @@ -183,8 +182,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp( } StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp( - const HloInstruction* op, llvm::Value* 
lhs_value, - llvm::Value* rhs_value) const { + const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) { CHECK_EQ(op->opcode(), HloOpcode::kPower); PrimitiveType lhs_input_type = op->operand(0)->shape().element_type(); PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); @@ -218,7 +216,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp( // TODO(jlebar): Does this happen with fastmath disabled? If not, should // we force-enable it? TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt()); - return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); + return FDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); } VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString(); @@ -227,55 +225,56 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp( } StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitErfcInv( - PrimitiveType prim_type, llvm::Value* value) const { + PrimitiveType prim_type, llvm::Value* value) { return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type); } -StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type); } -StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLog1p(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_log1p", {value}, {prim_type}, prim_type); } -StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitSin(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type); } -StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos( - 
PrimitiveType prim_type, llvm::Value* value) const { +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitCos(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type); } -StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExp(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type); } -StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitExpm1(PrimitiveType prim_type, + llvm::Value* value) { return EmitLibdeviceMathCall("__nv_expm1", {value}, {prim_type}, prim_type); } StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const { + llvm::Value* rhs) { return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type}, prim_type); } -StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2( - PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) { return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type}, prim_type); } -StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh( - PrimitiveType prim_type, llvm::Value* value) const { +StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(PrimitiveType prim_type, + llvm::Value* value) { // Emit a fast approximation of tanh instead of calling __nv_tanh. // __nv_tanh is particularly bad because it contains branches, thus // preventing LLVM's load-store vectorizer from working its magic across a @@ -285,9 +284,9 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh( // Upcast F16 to F32 if necessary. llvm::Type* type = prim_type == F16 ? 
b_->getFloatTy() : value->getType(); - llvm::Value* input = b_->CreateFPCast(value, type); + llvm::Value* input = FPCast(value, type); llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); - return b_->CreateFPCast(fast_tanh, value->getType()); + return FPCast(fast_tanh, value->getType()); } llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( @@ -295,7 +294,7 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( tensorflow::gtl::ArraySlice<llvm::Value*> operands, tensorflow::gtl::ArraySlice<PrimitiveType> input_types, PrimitiveType output_type, - tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) const { + tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) { std::vector<llvm::Type*> ir_input_types; for (PrimitiveType input_type : input_types) { ir_input_types.push_back( @@ -315,29 +314,28 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( callee->addFnAttr(attribute); } - return b_->CreateCall(callee, llvm_ir::AsArrayRef(operands)); + return Call(callee, llvm_ir::AsArrayRef(operands)); } -llvm::Value* GpuElementalIrEmitter::EmitThreadId() const { - llvm::Value* block_id = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "block.id"); - llvm::Value* thread_id_in_block = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); - llvm::Value* threads_per_block = b_->CreateIntCast( - llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, - {}, {}, b_), - b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); - return b_->CreateNSWAdd(b_->CreateNSWMul(block_id, threads_per_block), - thread_id_in_block); +llvm::Value* GpuElementalIrEmitter::EmitThreadId() { + llvm::Value* block_id = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, 
b_), + b_->getIntNTy(128), /*isSigned=*/true, "block.id"); + llvm::Value* thread_id_in_block = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); + llvm::Value* threads_per_block = + IntCast(llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); + return NSWAdd(NSWMul(block_id, threads_per_block), thread_id_in_block); } llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const { + const HloToElementGeneratorMap& operand_to_generator) { switch (hlo->opcode()) { case HloOpcode::kMap: return [=, &operand_to_generator]( @@ -383,7 +381,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * init_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index.GetType()))); - b_->CreateStore(init_value, accum_ptr); + Store(init_value, accum_ptr); } llvm::Type* index_type = index.GetType(); @@ -405,22 +403,21 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( IrArray::Index input_index(index_type, index.size()); llvm::Value* in_bounds = b_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* stridden_index = b_->CreateNSWMul( + llvm::Value* stridden_index = NSWMul( index[i], index_typed_const(window.dimensions(i).stride())); - input_index[i] = b_->CreateNSWSub( - b_->CreateNSWAdd(stridden_index, window_index[i]), - index_typed_const(window.dimensions(i).padding_low())); + input_index[i] = + NSWSub(NSWAdd(stridden_index, window_index[i]), + index_typed_const(window.dimensions(i).padding_low())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. 
This // comparison is equivalent to the unsigned comparison // input_index[i] < bound, as a negative value wraps to a large // positive value. - in_bounds = b_->CreateAnd( - in_bounds, - b_->CreateICmpULT( - input_index[i], - index_typed_const(operand->shape().dimensions(i)))); + in_bounds = + And(in_bounds, + ICmpULT(input_index[i], + index_typed_const(operand->shape().dimensions(i)))); } llvm_ir::LlvmIfData if_data = @@ -432,12 +429,11 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( operand_to_generator.at(operand)(input_index)); TF_ASSIGN_OR_RETURN( llvm::Value * accum_value, - compute_nested_(*hlo->to_apply(), - {b_->CreateLoad(accum_ptr), input_value})); - b_->CreateStore(accum_value, accum_ptr); + compute_nested_(*hlo->to_apply(), {Load(accum_ptr), input_value})); + Store(accum_value, accum_ptr); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); - return b_->CreateLoad(accum_ptr); + return Load(accum_ptr); }; case HloOpcode::kReduce: // TODO(b/112040122): This should be supported. 
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 84454d31bb..91942785d2 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -48,50 +48,50 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { llvm_ir::ElementGenerator MakeElementGenerator( const HloInstruction* hlo, - const HloToElementGeneratorMap& operand_to_generator) const override; + const HloToElementGeneratorMap& operand_to_generator) override; protected: - StatusOr<llvm::Value*> EmitFloatBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const override; + StatusOr<llvm::Value*> EmitFloatBinaryOp(const HloInstruction* op, + llvm::Value* lhs_value, + llvm::Value* rhs_value) override; StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr<llvm::Value*> EmitLog(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr<llvm::Value*> EmitLog1p(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr<llvm::Value*> EmitSin(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr<llvm::Value*> EmitCos(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr<llvm::Value*> EmitExp(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr<llvm::Value*> EmitExpm1(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; StatusOr<llvm::Value*> EmitPow(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) const override; + llvm::Value* rhs) override; StatusOr<llvm::Value*> EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, - llvm::Value* rhs) 
const override; + llvm::Value* rhs) override; StatusOr<llvm::Value*> EmitTanh(PrimitiveType prim_type, - llvm::Value* value) const override; + llvm::Value* value) override; - llvm::Value* EmitThreadId() const override; + llvm::Value* EmitThreadId() override; private: // Emits IR for op, which must have opcode kPower. StatusOr<llvm::Value*> EmitPowerOp(const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const; + llvm::Value* rhs_value); // Emits IR to call a device function named "callee_name" on the given // operand. Returns the IR value that represents the return value. @@ -100,7 +100,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { tensorflow::gtl::ArraySlice<llvm::Value*> operands, tensorflow::gtl::ArraySlice<PrimitiveType> input_type, PrimitiveType output_type, - tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) const; + tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes); // Emits IR to call an LLVM intrinsic of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the @@ -109,7 +109,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const string& callee_name, tensorflow::gtl::ArraySlice<llvm::Value*> operands, tensorflow::gtl::ArraySlice<PrimitiveType> input_types, - PrimitiveType output_type) const; + PrimitiveType output_type); // Emits IR to call a libdevice function of type [T] -> T. Adjusts // callee_name according to T. Returns the IR value that represents the @@ -118,7 +118,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const string& callee_name, tensorflow::gtl::ArraySlice<llvm::Value*> operands, tensorflow::gtl::ArraySlice<PrimitiveType> input_types, - PrimitiveType output_type) const; + PrimitiveType output_type); // Emits IR to call a function of type [T] -> T. Does not munge callee_name. // Returns the IR value that represents the return value of the function. 
@@ -126,7 +126,7 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const string& callee_name, tensorflow::gtl::ArraySlice<llvm::Value*> operands, tensorflow::gtl::ArraySlice<PrimitiveType> input_types, - PrimitiveType output_type) const; + PrimitiveType output_type); const HloModuleConfig& hlo_module_config_; NestedComputer compute_nested_; diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc index 0cdddf8bcf..11549cdac5 100644 --- a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -17,11 +17,11 @@ limitations under the License. #include <string> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -43,8 +43,8 @@ StatusOr<se::DeviceMemory<uint8>> FftScratchAllocator::AllocateBytes( if (byte_size > GetMemoryLimitInBytes(stream)) { return se::port::Status( se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + absl::StrFormat( + "Allocating %d bytes exceeds the memory limit of %d bytes.", byte_size, GetMemoryLimitInBytes(stream))); } @@ -213,7 +213,7 @@ Status FftThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, return Status::OK(); } return InternalError("Unable to launch fft for thunk %p with type %s", this, - FftTypeToString(fft_type_).c_str()); + FftTypeToString(fft_type_)); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h index 8c53be5077..4adec7ee54 100644 --- 
a/tensorflow/compiler/xla/service/gpu/fft_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc index 2fd2206324..88f0b4d71c 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -28,7 +28,7 @@ ForThunk::ForThunk(const int64 loop_limit, const HloInstruction* hlo) : Thunk(Kind::kWhile, hlo), loop_limit_(loop_limit), - body_thunk_sequence_(MakeUnique<SequentialThunk>( + body_thunk_sequence_(absl::make_unique<SequentialThunk>( // Pass nullptr as the HloInstruction* to the body_thunk_sequence_ // constructor because this SequentialThunk is logically "part of" // this ForThunk, and shouldn't be profiled separately from it. 
diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 3cd30b754c..1bd88233e1 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -18,12 +18,13 @@ limitations under the License. #include <algorithm> #include <vector> +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace gpu { @@ -64,10 +65,11 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) { // Slice for a more accurate estimate of bytes read. double bytes = 0.0; for (auto& instruction : instructions) { - if (c_all_of(instruction->users(), [](const HloInstruction* instruction) { - return instruction->opcode() == HloOpcode::kSlice || - instruction->opcode() == HloOpcode::kDynamicSlice; - })) { + if (absl::c_all_of( + instruction->users(), [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSlice || + instruction->opcode() == HloOpcode::kDynamicSlice; + })) { // All users are slice: accumulate bytes of all user slice instructions. for (auto& user : instruction->users()) { bytes += ShapeUtil::ByteSizeOf(user->shape()); @@ -223,7 +225,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // Skip 'fusion' instruction if we cannot merge into all of its users. // Merging into all users enables the removal of 'fusion' from the // computation. 
- if (!c_all_of(fusion->users(), [](const HloInstruction* user) { + if (!absl::c_all_of(fusion->users(), [](const HloInstruction* user) { return user->opcode() == HloOpcode::kFusion && (user->fusion_kind() == HloInstruction::FusionKind::kLoop || user->fusion_kind() == HloInstruction::FusionKind::kInput); @@ -241,11 +243,11 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // If 'fusion' has just one user, then an earlier fusion pass chose not to // fuse this producer/comsumer pair (likely because of expensive instruction // re-use by the consumer), and so we honor that choice here as well. - if (c_any_of(fusion->fused_instructions(), - [](const HloInstruction* instruction) { - return instruction->opcode() != HloOpcode::kParameter && - GpuInstructionFusion::IsExpensive(*instruction); - })) { + if (absl::c_any_of(fusion->fused_instructions(), + [](const HloInstruction* instruction) { + return instruction->opcode() != HloOpcode::kParameter && + GpuInstructionFusion::IsExpensive(*instruction); + })) { VLOG(3) << "Not merging " << fusion->name() << ": Contains one or more expensive instructions."; ++num_fail_expensive_fused_instruction_; @@ -287,11 +289,10 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { << " flops_to_bytes_ratio: " << CalculateFlopsToBytesRatio(fusion) << " merged_to_current_bytes_ratio: " << merged_to_current_bytes_ratio << " into users { " - << tensorflow::str_util::Join(users, ", ", - [](string* out, HloInstruction* user) { - tensorflow::strings::StrAppend( - out, user->name()); - }) + << absl::StrJoin(users, ", ", + [](string* out, HloInstruction* user) { + absl::StrAppend(out, user->name()); + }) << " }"; // Remove 'fusion' instruction. 
CHECK_EQ(0, fusion->user_count()); diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h index 4c523a66de..7e3f5775b8 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h @@ -34,7 +34,7 @@ namespace gpu { // class FusionMerger : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "fusion merger"; } + absl::string_view name() const override { return "fusion merger"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 74282c568c..9c4a490366 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -17,8 +17,8 @@ limitations under the License. #include <functional> +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -186,7 +186,7 @@ StatusOr<se::blas::AlgorithmType> DoGemmAutotune( } return InternalError( - "Unable to autotune cuBLAS gemm on stream %p; none of the %zu algorithms " + "Unable to autotune cuBLAS gemm on stream %p; none of the %u algorithms " "ran successfully", stream, algorithms.size()); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 0c6f9b511f..8ffae18fe8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -27,7 +27,7 @@ namespace gpu { // inserting kCopy instructions. 
class GpuCopyInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } + absl::string_view name() const override { return "copy-insertion"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 7060837904..71a02e70df 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -19,8 +19,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -112,7 +112,7 @@ Status GpuExecutable::ExecuteThunks( // // TODO(jlebar): Should we cache the results of HloInstruction::ToString(), // since we expect it to be an expensive call? 
- tensorflow::gtl::optional<ScopedAnnotation> op_annotation; + absl::optional<ScopedAnnotation> op_annotation; if (top_level_annotation.IsEnabled()) { op_annotation.emplace( thunk->hlo_instruction() != nullptr @@ -144,7 +144,7 @@ Status GpuExecutable::ExecuteThunks( TF_RETURN_IF_ERROR( thunk->ExecuteOnStream(buffer_allocations, stream, &profiler)); if (thunk_schedule_->Depended(thunk)) { - auto finish_event = MakeUnique<se::Event>(main_stream->parent()); + auto finish_event = absl::make_unique<se::Event>(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); @@ -160,7 +160,7 @@ Status GpuExecutable::ExecuteThunks( if (!block_status.ok()) { return InternalError( "Failed to complete all kernels launched on stream %p: %s", - main_stream, block_status.error_message().c_str()); + main_stream, block_status.error_message()); } } @@ -260,10 +260,9 @@ StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream( if (buffer.is_null() && buffer.size() > 0) { return FailedPrecondition( "Cannot run XLA computation because pointer to (sub-)buffer at " - "index %s of parameter %lld was null. All pointers to " - "(sub-)buffers must not be null, unless the (sub-)buffer has zero " - "elements.", - allocation.param_shape_index().ToString().c_str(), param_no); + "index %s of parameter %d was null. All pointers to (sub-)buffers " + "must not be null, unless the (sub-)buffer has zero elements.", + allocation.param_shape_index().ToString(), param_no); } buffer_allocations_builder.RegisterBuffer(i, buffer); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index c7ce6d0acb..627a05e240 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -19,6 +19,8 @@ limitations under the License. 
#include <memory> #include <string> +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" @@ -32,10 +34,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc index 4944c41f7d..4268fb2c7a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc @@ -34,9 +34,8 @@ StatusOr<bool> GpuHloSupportChecker::Run(HloModule* module) { return xla::Unimplemented( "GPU backend does not support HLO instruction %s with shape " "containing a sparse layout: %s", - instruction->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction->shape()) - .c_str()); + instruction->ToString(), + ShapeUtil::HumanStringWithLayout(instruction->shape())); } return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h index d63e213d2b..bbb3340760 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -28,9 +28,7 @@ class GpuHloSupportChecker : public HloPassInterface { GpuHloSupportChecker() = default; ~GpuHloSupportChecker() override = 
default; - tensorflow::StringPiece name() const override { - return "gpu_hlo_support_checker"; - } + absl::string_view name() const override { return "gpu_hlo_support_checker"; } // Note: always returns false (no instructions are ever modified by this // pass). diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 286547ebae..fbc8ddf599 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -119,7 +120,7 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); @@ -192,7 +193,7 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { // Enumerate all combinations of shapes. 
for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); @@ -265,7 +266,7 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { for (const Shape& input_shape : AllLayoutsOf(shape)) { for (const Shape& result_shape : AllLayoutsOf(shape)) { for (int constrained_param_no : {0, 4}) { - SCOPED_TRACE(tensorflow::strings::StrCat( + SCOPED_TRACE(absl::StrCat( "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index a2f53f8446..f3c2744292 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include <utility> #include <vector> +#include "absl/memory/memory.h" #include "llvm/IR/DataLayout.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -83,7 +84,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } infeed_manager->EnqueueDestination(std::move(buffers)); @@ -96,7 +97,7 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( StatusOr<InfeedBuffer> GpuTransferManager::TransferBufferToInfeedInternal( se::StreamExecutor* executor, int64 size, const void* source) { if (size > std::numeric_limits<int32>::max()) { - return InvalidArgument("Infeed shape is too large: needs %lld bytes", size); + return InvalidArgument("Infeed shape is too large: needs %d bytes", size); } if (size == 0) { @@ -160,9 +161,10 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( if (ShapeUtil::IsTuple(shape)) { return; } - *buffer = MakeUnique<gpu::OutfeedBuffer>(GetByteSizeRequirement(shape)); + *buffer = absl::make_unique<gpu::OutfeedBuffer>( + GetByteSizeRequirement(shape)); (*buffer)->set_destination( - MakeUnique<MutableBorrowingLiteral>(literal, index)); + absl::make_unique<MutableBorrowingLiteral>(literal, index)); }); // Give the tree of buffers to the outfeed mananger. 
The device will fill it @@ -179,7 +181,7 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( } // namespace xla static std::unique_ptr<xla::TransferManager> CreateNVPTXTransferManager() { - return xla::MakeUnique<xla::gpu::GpuTransferManager>( + return absl::make_unique<xla::gpu::GpuTransferManager>( /*id=*/stream_executor::cuda::kCudaPlatformId, /*pointer_size=*/llvm::DataLayout(xla::gpu::NVPTXCompiler::kDataLayout) .getPointerSize(0 /* default address space */)); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index 7929042869..fa88816bc8 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_ #include <vector> @@ -61,4 +61,4 @@ class GpuTransferManager : public GenericTransferManager { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TRANSFER_MANAGER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc index 1722676930..b9c21e8edb 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include <unordered_set> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -33,7 +34,7 @@ namespace gpu { namespace { void InitAndStartTimer(std::stack<std::unique_ptr<se::Timer>>* timers, se::Stream* stream) { - timers->push(MakeUnique<se::Timer>(stream->parent())); + timers->push(absl::make_unique<se::Timer>(stream->parent())); stream->InitTimer(timers->top().get()).ThenStartTimer(timers->top().get()); } @@ -115,7 +116,7 @@ HloExecutionProfiler::MakeScopedInstructionProfiler( CHECK(hlo_instructions_.insert(hlo_instruction).second) << hlo_instruction->name(); } - return MakeUnique<ScopedInstructionProfiler>(this, hlo_instruction); + return absl::make_unique<ScopedInstructionProfiler>(this, hlo_instruction); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index 19de37b0fb..76055ff009 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" @@ -59,8 +59,8 @@ GpuHloOrdering::GpuHloOrdering( : PredecessorHloOrdering(module) { // The entry computation has a total order when there's only one stream. 
if (stream_assignment.StreamCount() == 1) { - entry_sequence_ = - MakeUnique<std::vector<const HloInstruction*>>(thunk_launch_order); + entry_sequence_ = absl::make_unique<std::vector<const HloInstruction*>>( + thunk_launch_order); } // The ordering of instructions for the entry computation is determined by the @@ -75,7 +75,7 @@ GpuHloOrdering::GpuHloOrdering( // same-stream predecessors of each instruction. // Compute the set of all instructions we will want to set reachability on. - auto predecessor_map = MakeUnique<HloReachabilityMap>( + auto predecessor_map = absl::make_unique<HloReachabilityMap>( module->entry_computation()->MakeInstructionPostOrder()); // The most recently visited instruction per stream. @@ -208,7 +208,7 @@ StatusOr<std::unique_ptr<HloSchedule>> HloSchedule::Build( BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_); } - schedule->hlo_ordering_ = MakeUnique<GpuHloOrdering>( + schedule->hlo_ordering_ = absl::make_unique<GpuHloOrdering>( &module, stream_assignment, schedule->thunk_launch_order_); return std::move(schedule); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index 45f0a1c645..bb147c8d98 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include <algorithm> #include <unordered_set> +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -47,7 +49,7 @@ class HloScheduleTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique<HloModule>("test_module", config); + return absl::make_unique<HloModule>("test_module", config); } HloVec RemoveHlo(const HloVec& input, @@ -265,7 +267,7 @@ TEST_F(HloScheduleTest, LatticeMatMul) { params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( - i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); + i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 8c11cd0541..0e205b9c02 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -24,16 +25,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace gpu { -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; void HloToIrBindings::EmitBasePointersForHlos( tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos, diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index c5f0cdf6cd..a4364b0deb 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" namespace xla { namespace gpu { @@ -24,7 +24,7 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { tensorflow::mutex_lock l(host_to_device_stream_mu_); if (host_to_device_executor_ == nullptr) { host_to_device_executor_ = executor; - host_to_device_stream_ = MakeUnique<se::Stream>(executor); + host_to_device_stream_ = absl::make_unique<se::Stream>(executor); host_to_device_stream_->Init(); } diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index fee6d2af3b..8c3a026740 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -96,7 +96,7 @@ Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return 
InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } VLOG(2) << "Infeeding to GPU complete"; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 0f2c83aeb2..0bcaaee2b7 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -26,7 +26,7 @@ namespace gpu { namespace { -bool IsFusile(const HloInstruction& hlo) { +bool IsFusible(const HloInstruction& hlo) { // Don't fuse get-tuple-element on GPU: We can, but it's slower than not // fusing. We never generate kernels for unfused GTEs. Instead, if an // unfused GTE is an input to a kernel (including a fusion kernel), we @@ -245,7 +245,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return true; } - if (!IsFusile(*producer) || !IsFusile(*consumer) || + if (!IsFusible(*producer) || !IsFusible(*consumer) || !InstructionFusion::ShouldFuse(consumer, operand_index)) { return false; } diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 8d0522bd8f..f53dfaee3d 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -365,7 +365,7 @@ static StatusOr<const HloInstruction*> FindHloInstruction( } return NotFound( "Computation '%s' does not contain an instruction with op code '%s'.", - computation.name().c_str(), HloOpcodeString(op).c_str()); + computation.name(), HloOpcodeString(op)); } TEST_F(InstructionFusionTest, MultiOutputFusion) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index c349063c71..f544bcc919 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ 
b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -215,7 +215,7 @@ bool IsReductionToVector(const HloInstruction& reduce) { // This emits a device-side call to // "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see // http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls -llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, +llvm::Value* EmitPrintf(absl::string_view fmt, tensorflow::gtl::ArraySlice<llvm::Value*> arguments, llvm::IRBuilder<>* builder) { std::vector<llvm::Type*> argument_types; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 5d23a3d018..a35e250101 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -126,7 +126,7 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo); bool IsReductionToVector(const HloInstruction& reduce); // Emits call to "vprintf" with given format and arguments. -llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, +llvm::Value* EmitPrintf(absl::string_view fmt, tensorflow::gtl::ArraySlice<llvm::Value*> arguments, llvm::IRBuilder<>* builder); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 6675dbd3f9..bdf6aadde6 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/algorithm/container.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" @@ -155,7 +156,7 @@ Status IrEmitter::EmitCallToNestedComputation( std::vector<llvm::Value*> arguments(operands.begin(), operands.end()); arguments.push_back(output); arguments.push_back(bindings_.GetTempBufferBase()); - b_.CreateCall(emitted_function, arguments); + Call(emitted_function, arguments); return Status::OK(); } @@ -177,7 +178,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( computation.root_instruction()->shape().element_type(); bool is_atomic_integral = element_type == S32 || element_type == U32 || element_type == S64 || element_type == U64; - llvm::Value* source = b_.CreateLoad(source_address, "source"); + llvm::Value* source = Load(source_address, "source"); if (root_opcode == HloOpcode::kAdd) { // NVPTX supports atomicAdd on F32 and integer types. if (element_type == F32) { @@ -189,8 +190,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( } if (is_atomic_integral) { // integral + integral - b_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } } @@ -201,8 +202,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( auto opcode = primitive_util::IsSignedIntegralType(element_type) ? llvm::AtomicRMWInst::Max : llvm::AtomicRMWInst::UMax; - b_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } @@ -211,8 +212,8 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation( auto opcode = primitive_util::IsSignedIntegralType(element_type) ? 
llvm::AtomicRMWInst::Min : llvm::AtomicRMWInst::UMin; - b_.CreateAtomicRMW(opcode, output_address, source, - llvm::AtomicOrdering::SequentiallyConsistent); + AtomicRMW(opcode, output_address, source, + llvm::AtomicOrdering::SequentiallyConsistent); return true; } @@ -291,10 +292,10 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, // cas_old_output_address and cas_new_output_address point to the scratch // memory where we store the old and new values for the repeated atomicCAS // operations. - llvm::Value* cas_old_output_address = b_.CreateAlloca( - atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); - llvm::Value* cas_new_output_address = b_.CreateAlloca( - atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); + llvm::Value* cas_old_output_address = + Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); + llvm::Value* cas_new_output_address = + Alloca(atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); // Emit preparation code to the preheader. 
llvm::BasicBlock* loop_preheader_bb = b_.GetInsertBlock(); @@ -308,29 +309,26 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, CHECK_EQ((element_size % sizeof(char)), 0); llvm::Type* address_int_type = module_->getDataLayout().getIntPtrType(output_address_type); - atomic_memory_address = b_.CreatePtrToInt(output_address, address_int_type); + atomic_memory_address = PtrToInt(output_address, address_int_type); llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3); - llvm::Value* offset = b_.CreateAnd(atomic_memory_address, mask); + llvm::Value* offset = And(atomic_memory_address, mask); mask = llvm::ConstantInt::get(address_int_type, -4); - atomic_memory_address = b_.CreateAnd(atomic_memory_address, mask); + atomic_memory_address = And(atomic_memory_address, mask); atomic_memory_address = - b_.CreateIntToPtr(atomic_memory_address, atomic_address_type); - binop_output_address = b_.CreateAdd( - b_.CreatePtrToInt(cas_new_output_address, address_int_type), offset); + IntToPtr(atomic_memory_address, atomic_address_type); binop_output_address = - b_.CreateIntToPtr(binop_output_address, element_address_type); + Add(PtrToInt(cas_new_output_address, address_int_type), offset); + binop_output_address = IntToPtr(binop_output_address, element_address_type); } else { - atomic_memory_address = - b_.CreateBitCast(output_address, atomic_address_type); + atomic_memory_address = BitCast(output_address, atomic_address_type); binop_output_address = - b_.CreateBitCast(cas_new_output_address, element_address_type); + BitCast(cas_new_output_address, element_address_type); } // Use the value from the memory that atomicCAS operates on to initialize // cas_old_output. 
- llvm::Value* cas_old_output = - b_.CreateLoad(atomic_memory_address, "cas_old_output"); - b_.CreateStore(cas_old_output, cas_old_output_address); + llvm::Value* cas_old_output = Load(atomic_memory_address, "cas_old_output"); + Store(cas_old_output, cas_old_output_address); llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock( b_.GetInsertPoint(), "atomic_op_loop_exit"); @@ -343,32 +341,29 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, // Emit the body of the loop that repeatedly invokes atomicCAS. // // Use cas_old_output to initialize cas_new_output. - cas_old_output = b_.CreateLoad(cas_old_output_address, "cas_old_output"); - b_.CreateStore(cas_old_output, cas_new_output_address); + cas_old_output = Load(cas_old_output_address, "cas_old_output"); + Store(cas_old_output, cas_new_output_address); // Emits code to calculate new_output = operation(old_output, source); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( computation, {binop_output_address, source_address}, binop_output_address)); - llvm::Value* cas_new_output = - b_.CreateLoad(cas_new_output_address, "cas_new_output"); + llvm::Value* cas_new_output = Load(cas_new_output_address, "cas_new_output"); // Emit code to perform the atomicCAS operation // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output, // cas_new_output); - llvm::Value* ret_value = b_.CreateAtomicCmpXchg( - atomic_memory_address, cas_old_output, cas_new_output, - llvm::AtomicOrdering::SequentiallyConsistent, - llvm::AtomicOrdering::SequentiallyConsistent); + llvm::Value* ret_value = + AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output, + llvm::AtomicOrdering::SequentiallyConsistent, + llvm::AtomicOrdering::SequentiallyConsistent); // Extract the memory value returned from atomicCAS and store it as // cas_old_output. 
- b_.CreateStore(b_.CreateExtractValue(ret_value, 0, "cas_old_output"), - cas_old_output_address); + Store(ExtractValue(ret_value, 0, "cas_old_output"), cas_old_output_address); // Extract the success bit returned from atomicCAS and generate a // conditional branch on the success bit. - b_.CreateCondBr(b_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, - loop_body_bb); + CondBr(ExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb); // Set the insertion point to the exit basic block so that the caller of // this method can continue emitting code to the right place. @@ -383,8 +378,8 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation( // TODO(b/30258929): We only accept binary computations so far. return Unimplemented( "We only support atomic functions with exactly two parameters, but " - "computation %s has %lld.", - computation.name().c_str(), computation.num_parameters()); + "computation %s has %d.", + computation.name(), computation.num_parameters()); } if (MaybeEmitDirectAtomicOperation(computation, output_address, @@ -471,10 +466,10 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { if (ShapeUtil::ElementIsComplex(lhs_shape)) { auto value = MultiplyComplex(lhs_value, rhs_value, &b_); result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); - result = b_.CreateInsertValue(result, value.first, {0}); - result = b_.CreateInsertValue(result, value.second, {1}); + result = InsertValue(result, value.first, {0}); + result = InsertValue(result, value.second, {1}); } else { - result = b_.CreateFMul(lhs_value, rhs_value); + result = FMul(lhs_value, rhs_value); } target_array.EmitWriteArrayElement(/*index=*/element_index, result, &b_); return Status::OK(); @@ -518,7 +513,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // We don't have to iterate over the batch dimensions in both arrays, simplify // the loop nest of the rhs. 
for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) { - DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i)); + DCHECK(absl::c_linear_search(dnums.lhs_batch_dimensions(), i)); rhs_index[i] = lhs_index[i]; } @@ -558,21 +553,21 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt()); llvm::Value* lhs_element = lhs_array.EmitReadArrayElement(lhs_index, &b_); llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &b_); - llvm::Value* accum = b_.CreateLoad(accum_address); + llvm::Value* accum = Load(accum_address); llvm::Value* updated_accum; if (ShapeUtil::ElementIsComplex(lhs_shape)) { auto value = MultiplyComplex(lhs_element, rhs_element, &b_); llvm::Value* accum_real = Real(accum, &b_); - llvm::Value* real_sum = b_.CreateFAdd(accum_real, value.first); - updated_accum = b_.CreateInsertValue(accum, real_sum, {0}); + llvm::Value* real_sum = FAdd(accum_real, value.first); + updated_accum = InsertValue(accum, real_sum, {0}); llvm::Value* accum_imag = Imag(accum, &b_); - llvm::Value* imag_sum = b_.CreateFAdd(accum_imag, value.second); - updated_accum = b_.CreateInsertValue(updated_accum, imag_sum, {1}); + llvm::Value* imag_sum = FAdd(accum_imag, value.second); + updated_accum = InsertValue(updated_accum, imag_sum, {1}); } else { - llvm::Value* product = b_.CreateFMul(lhs_element, rhs_element); - updated_accum = b_.CreateFAdd(accum, product); + llvm::Value* product = FMul(lhs_element, rhs_element); + updated_accum = FAdd(accum, product); } - b_.CreateStore(updated_accum, accum_address); + Store(updated_accum, accum_address); // After the reduction loop exits, store the accumulator into the target // address. 
The index into the target address is the concatenation of the rhs @@ -594,7 +589,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &b_); target_array.EmitWriteArrayElement( target_index, - b_.CreateLoad(accum_address), // The value written to the target array. + Load(accum_address), // The value written to the target array. &b_); // Set the IR builder insert point to the exit basic block of the outer most @@ -645,10 +640,9 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> { // Initialize an accumulator with init_value. llvm::AllocaInst* accumulator_addr = - b_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType( + Alloca(llvm_ir::PrimitiveTypeToIrType( reduce->shape().element_type(), module_)); - b_.CreateStore(b_.CreateLoad(GetBasePointer(*init_value)), - accumulator_addr); + Store(Load(GetBasePointer(*init_value)), accumulator_addr); // The enclosing loops go over all the target elements. Now we have to // compute the actual target element. For this, we build a new loop nest @@ -685,7 +679,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { *function, {accumulator_addr, input_address}, accumulator_addr)); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &b_); - return b_.CreateLoad(accumulator_addr); + return Load(accumulator_addr); }); } @@ -752,11 +746,6 @@ Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { "to a cudnn CustomCall using CudnnBatchNormRewriter."); } -Status IrEmitter::HandleIota(HloInstruction*) { - // TODO(b/64798317): implement iota on GPU. 
- return Unimplemented("Iota is not implemented on GPU."); -} - StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement( const HloComputation& computation, tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) { @@ -768,11 +757,11 @@ StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement( for (llvm::Value* parameter_element : parameter_elements) { parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry( parameter_element->getType(), "parameter_buffer", &b_)); - b_.CreateStore(parameter_element, parameter_buffers.back()); + Store(parameter_element, parameter_buffers.back()); } TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers, return_buffer)); - return b_.CreateLoad(return_buffer); + return Load(return_buffer); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 561c683879..3673b9f58d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -22,6 +22,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/string_view.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" @@ -35,12 +36,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" @@ -64,7 +65,8 @@ namespace gpu { // IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is // not a subclass of gpu::IrEmitter, and in fact is better understood as an IR // generator generator. See comments on that class. -class IrEmitter : public DfsHloVisitorWithDefault { +class IrEmitter : public DfsHloVisitorWithDefault, + public IrBuilderMixin<IrEmitter> { public: IrEmitter(const IrEmitter&) = delete; IrEmitter& operator=(const IrEmitter&) = delete; @@ -95,10 +97,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleBatchNormInference(HloInstruction* batch_norm) override; Status HandleBatchNormTraining(HloInstruction* batch_norm) override; Status HandleBatchNormGrad(HloInstruction* batch_norm) override; - Status HandleIota(HloInstruction* iota) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } + llvm::IRBuilder<>* builder() { return &b_; } + protected: // Constructs an IrEmitter with the given IrEmitter context. 
// ir_emitter_context is owned by the caller and should outlive the IrEmitter diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 1e81cbde35..c0c8ae181a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -21,6 +21,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/types/optional.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -29,7 +34,6 @@ limitations under the License. #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" @@ -77,7 +81,6 @@ limitations under the License. #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -85,13 +88,13 @@ namespace gpu { namespace { +using absl::InlinedVector; +using absl::nullopt; +using absl::optional; +using absl::StrCat; using llvm_ir::IrArray; using llvm_ir::IrName; using tensorflow::gtl::ArraySlice; -using tensorflow::gtl::InlinedVector; -using tensorflow::gtl::nullopt; -using tensorflow::gtl::optional; -using tensorflow::strings::StrCat; // If a dimensions is smaller than this, untiled transposition may be more // efficient. 
@@ -314,13 +317,13 @@ llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size, }; // Check the size of input tensors - if (!c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { + if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) { return i64_ty; } // Check the size of the internal result tensors if (unnested_hlo->opcode() == HloOpcode::kFusion) { - if (!c_all_of( + if (!absl::c_all_of( unnested_hlo->fused_instructions_computation()->instructions(), hlo_shape_in_range)) { return i64_ty; @@ -383,7 +386,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { int64 feature_index_value = feature_index->literal().Get<int64>({}); thunk_sequence_->emplace_back( - MakeUnique<CudnnBatchNormForwardInferenceThunk>( + absl::make_unique<CudnnBatchNormForwardInferenceThunk>( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -413,7 +416,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); thunk_sequence_->emplace_back( - MakeUnique<CudnnBatchNormForwardTrainingThunk>( + absl::make_unique<CudnnBatchNormForwardTrainingThunk>( /*operand=*/GetAllocationSlice(*custom_call->operand(0)), /*scale=*/GetAllocationSlice(*custom_call->operand(1)), /*offset=*/GetAllocationSlice(*custom_call->operand(2)), @@ -443,19 +446,20 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); auto output_grad_offset = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); - thunk_sequence_->emplace_back(MakeUnique<CudnnBatchNormBackwardThunk>( - /*operand=*/GetAllocationSlice(*custom_call->operand(0)), - 
/*scale=*/GetAllocationSlice(*custom_call->operand(1)), - /*mean=*/GetAllocationSlice(*custom_call->operand(2)), - /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), - /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), - /*epsilon=*/epsilon_value, - /*feature_index=*/feature_index_value, - /*output_grad_data=*/output_grad_data, - /*output_grad_scale=*/output_grad_scale, - /*output_grad_offset=*/output_grad_offset, - /*output_tuple=*/GetAllocationSlice(*custom_call), - /*hlo=*/custom_call)); + thunk_sequence_->emplace_back( + absl::make_unique<CudnnBatchNormBackwardThunk>( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*mean=*/GetAllocationSlice(*custom_call->operand(2)), + /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), + /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output_grad_data=*/output_grad_data, + /*output_grad_scale=*/output_grad_scale, + /*output_grad_offset=*/output_grad_offset, + /*output_tuple=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); return Status::OK(); } @@ -475,7 +479,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { const auto& target = custom_call->custom_call_target(); std::unique_ptr<ConvolutionThunk> thunk; if (target == kCudnnConvForwardCallTarget) { - thunk = MakeUnique<ConvolutionThunk>( + thunk = absl::make_unique<ConvolutionThunk>( CudnnConvKind::kForward, /*input_buffer=*/lhs_slice, /*filter_buffer=*/rhs_slice, @@ -489,7 +493,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { backend_config.algorithm(), backend_config.tensor_ops_enabled(), custom_call); } else if (target == kCudnnConvBackwardInputCallTarget) { - thunk = MakeUnique<ConvolutionThunk>( + thunk = absl::make_unique<ConvolutionThunk>( CudnnConvKind::kBackwardInput, /*input_buffer=*/conv_result_slice, 
/*filter_buffer=*/rhs_slice, @@ -503,7 +507,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { backend_config.algorithm(), backend_config.tensor_ops_enabled(), custom_call); } else if (target == kCudnnConvBackwardFilterCallTarget) { - thunk = MakeUnique<ConvolutionThunk>( + thunk = absl::make_unique<ConvolutionThunk>( CudnnConvKind::kBackwardFilter, /*input_buffer=*/lhs_slice, /*filter_buffer=*/conv_result_slice, @@ -576,7 +580,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunks.push_back( BuildKernelThunk(fusion, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( - MakeUnique<SequentialThunk>(std::move(thunks), fusion)); + absl::make_unique<SequentialThunk>(std::move(thunks), fusion)); std::vector<IrArray> parameter_arrays; for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArray(*operand, *fusion)); @@ -725,7 +729,7 @@ Status IrEmitterUnnested::EmitExtraOutputsForReduce( "extra_output_element_address"); TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, extra_output_gens[i].first(index)); - b_.CreateStore(extra_output_ir_value, extra_output_address); + Store(extra_output_ir_value, extra_output_address); } return Status::OK(); } @@ -798,8 +802,7 @@ Status IrEmitterUnnested::EmitReductionToScalar( // // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize), // // // // and threads_per_block is a multiple of warpSize. 
- // reduce_kernel<<<num_blocks, threads_per_block>>>(); - // + // reduce_kernel // auto loop_body_emitter = [=](const IrArray::Index& tile_index) -> Status { const int num_reduces = reducers.size(); llvm::Type* element_ir_type = @@ -807,17 +810,17 @@ Status IrEmitterUnnested::EmitReductionToScalar( std::vector<llvm::Value*> partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; - x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); + x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); // Emit an inner for-loop that reduces the elements in the tile. auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { @@ -829,15 +832,14 @@ Status IrEmitterUnnested::EmitReductionToScalar( // Emit the body of the partial reduction loop. llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &b_); - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize)), - tile_element_loop->GetIndVarValue()); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileSize)), + tile_element_loop->GetIndVarValue()); // Unless we know the tile is entirely in bounds, we have to emit a // x-in-bounds check before reading from the input. 
if (!tile_in_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", - &b_); + ICmpULT(x, index_typed_constant(num_elems)), "x_in_bounds", &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. @@ -846,11 +848,11 @@ Status IrEmitterUnnested::EmitReductionToScalar( IrArray::Index input_index( /*linear=*/x, input_shape, &b_); - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -861,14 +863,14 @@ Status IrEmitterUnnested::EmitReductionToScalar( // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's // immediately beyond the tile. - llvm::Value* x_end = b_.CreateNSWAdd( - index_typed_constant(kTileSize), - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileSize))); + llvm::Value* x_end = + NSWAdd(index_typed_constant(kTileSize), + NSWMul(x_in_tiles, index_typed_constant(kTileSize))); // The tile is entirely in bound if all_threads_in_bounds or // x_end <= num_elems. 
llvm::Value* tile_in_bounds = - b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(num_elems)), - b_.getInt1(all_threads_in_bounds)); + Or(ICmpULE(x_end, index_typed_constant(num_elems)), + b_.getInt1(all_threads_in_bounds)); llvm_ir::LlvmIfData if_tile_in_bounds_data = llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, &b_); @@ -889,20 +891,18 @@ Status IrEmitterUnnested::EmitReductionToScalar( for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* result_from_other_lane = - b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); + Alloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = b_.CreateLoad( - b_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); + llvm::Value* partial_reduction_result = + Load(BitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 0) << "Requires block size a multiple of the warp size, otherwise we " "will read undefined elements."; - b_.CreateStore( - EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - b_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + Store(EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -917,10 +917,9 @@ Status IrEmitterUnnested::EmitReductionToScalar( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. 
llvm::Value* lane_id = - b_.CreateURem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); + URem(x_in_tiles, index_typed_constant(kWarpSize), "lane_id"); llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", - &b_); + ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { @@ -1040,12 +1039,12 @@ Status IrEmitterUnnested::EmitColumnReduction( for (int i = 0; i != num_reduces; ++i) { for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + - llvm::Twine(i * kTileWidth + x_offset)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + + llvm::Twine(i * kTileWidth + x_offset)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } @@ -1056,8 +1055,8 @@ Status IrEmitterUnnested::EmitColumnReduction( llvm::Value* y_in_tiles = tile_index[0]; llvm::Value* x_in_tiles = tile_index[1]; - y_in_tiles = b_.CreateZExtOrTrunc(y_in_tiles, index_ty); - x_in_tiles = b_.CreateZExtOrTrunc(x_in_tiles, index_ty); + y_in_tiles = ZExtOrTrunc(y_in_tiles, index_ty); + x_in_tiles = ZExtOrTrunc(x_in_tiles, index_ty); auto emit_tile_element_loop = [=](bool tile_in_y_bounds, bool tile_in_x_bounds) -> Status { @@ -1069,34 +1068,32 @@ Status IrEmitterUnnested::EmitColumnReduction( // Emit the body of the partial reduction loop. 
llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), &b_); - llvm::Value* y = b_.CreateNSWAdd( - b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight)), - tile_element_loop->GetIndVarValue()); + llvm::Value* y = + NSWAdd(NSWMul(y_in_tiles, index_typed_constant(kTileHeight)), + tile_element_loop->GetIndVarValue()); // Unless we know that y is in bounds, we have to emit a check before // reading from the input. if (!tile_in_y_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(y, index_typed_constant(height)), "y_in_bounds", - &b_); + ICmpULT(y, index_typed_constant(height)), "y_in_bounds", &b_); // Emit code that reads the input element and accumulates it to // the partial reduction result. llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); // Unless we know that x is in bounds, we have to emit a check before // reading from the input. if (!tile_in_x_bounds) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(width)), "x_in_bounds", - &b_); + ICmpULT(x, index_typed_constant(width)), "x_in_bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_data.true_block, &b_); } - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); // {y,x} is an index to input_matrix_shape [height,width]. We need to // convert that to an index to input_shape (the shape of the operand of // "reduce"). 
This conversion is composed of a transposition from @@ -1123,7 +1120,7 @@ Status IrEmitterUnnested::EmitColumnReduction( for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i * kTileWidth + x_offset], @@ -1138,20 +1135,20 @@ Status IrEmitterUnnested::EmitColumnReduction( // y_end = kTileHeight + y_in_tiles * kTileHeight, i.e., the y location // that's immediately beyond the tile. - llvm::Value* y_end = b_.CreateNSWAdd( - index_typed_constant(kTileHeight), - b_.CreateNSWMul(y_in_tiles, index_typed_constant(kTileHeight))); + llvm::Value* y_end = + NSWAdd(index_typed_constant(kTileHeight), + NSWMul(y_in_tiles, index_typed_constant(kTileHeight))); // x_end = kTileWidth + x_in_tiles * kTileWidth, i.e., the x location // that's immediately beyond the tile. - llvm::Value* x_end = b_.CreateNSWAdd( - index_typed_constant(kTileWidth), - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth))); + llvm::Value* x_end = + NSWAdd(index_typed_constant(kTileWidth), + NSWMul(x_in_tiles, index_typed_constant(kTileWidth))); llvm::Value* tile_in_y_bounds = - b_.CreateOr(b_.CreateICmpULE(y_end, index_typed_constant(height)), - b_.getInt1(height % kTileHeight == 0)); + Or(ICmpULE(y_end, index_typed_constant(height)), + b_.getInt1(height % kTileHeight == 0)); llvm::Value* tile_in_x_bounds = - b_.CreateOr(b_.CreateICmpULE(x_end, index_typed_constant(width)), - b_.getInt1(width % kTileWidth == 0)); + Or(ICmpULE(x_end, index_typed_constant(width)), + b_.getInt1(width % kTileWidth == 0)); // The tile is in y bounds if "height" is a multiple of kTileHeight or // y_end <= height. llvm_ir::LlvmIfData if_tile_in_y_bounds_data = @@ -1185,9 +1182,9 @@ Status IrEmitterUnnested::EmitColumnReduction( reduce->IsFused() ? 
reduce->parent()->FusionInstruction() : reduce; for (int i = 0; i != num_reduces; ++i) { for (int x_offset = 0; x_offset < kTileWidth; ++x_offset) { - llvm::Value* x = b_.CreateNSWAdd( - b_.CreateNSWMul(x_in_tiles, index_typed_constant(kTileWidth)), - index_typed_constant(x_offset)); + llvm::Value* x = + NSWAdd(NSWMul(x_in_tiles, index_typed_constant(kTileWidth)), + index_typed_constant(x_offset)); llvm::Value* output_address = GetIrArray(*output, *output, reduce_output_shapes[i]) .EmitArrayElementAddress( @@ -1376,11 +1373,11 @@ Status IrEmitterUnnested::EmitRowReduction( std::vector<llvm::Value*> partial_reduction_result_addresses; for (int i = 0; i != num_reduces; ++i) { llvm::Value* partial_reduction_result_address = - b_.CreateAlloca(element_ir_type, /*ArraySize=*/nullptr, - "partial_reduction_result." + llvm::Twine(i)); + Alloca(element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, init_value_gens[i](IrArray::Index(index_ty))); - b_.CreateStore(init_ir_value, partial_reduction_result_address); + Store(init_ir_value, partial_reduction_result_address); partial_reduction_result_addresses.push_back( partial_reduction_result_address); } @@ -1389,22 +1386,20 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; - x_tile = b_.CreateZExtOrTrunc(x_tile, index_ty); + x_tile = ZExtOrTrunc(x_tile, index_ty); llvm::Value* warp_id = - b_.CreateUDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); + UDiv(x_tile, index_typed_constant(kWarpSize), "warp_id"); llvm::Value* lane_id = - b_.CreateURem(x_tile, index_typed_constant(kWarpSize), "lane_id"); + URem(x_tile, index_typed_constant(kWarpSize), "lane_id"); // The x-location of the last element in this z-x-tile. 
// last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * x_tile_size); - llvm::Value* last_x = b_.CreateNSWAdd( + llvm::Value* last_x = NSWAdd( lane_id, - b_.CreateNSWMul( - index_typed_constant(kWarpSize), - b_.CreateNSWAdd( - index_typed_constant(x_tile_size - 1), - b_.CreateNSWMul(warp_id, index_typed_constant(x_tile_size))))); + NSWMul(index_typed_constant(kWarpSize), + NSWAdd(index_typed_constant(x_tile_size - 1), + NSWMul(warp_id, index_typed_constant(x_tile_size))))); KernelSupportLibrary ksl( &b_, @@ -1416,9 +1411,8 @@ Status IrEmitterUnnested::EmitRowReduction( auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, int64 x_tile_loop_bound) -> Status { auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { - llvm::Value* z = b_.CreateNSWAdd( - z_indvar, - b_.CreateNSWMul(index_typed_constant(z_tile_size), z_tile)); + llvm::Value* z = + NSWAdd(z_indvar, NSWMul(index_typed_constant(z_tile_size), z_tile)); TF_RETURN_IF_ERROR(ksl.For( "x_tile", /*start=*/index_typed_constant(0), @@ -1426,22 +1420,20 @@ Status IrEmitterUnnested::EmitRowReduction( /*step=*/1, [&](llvm::Value* x_indvar) -> Status { // x = lane_id + // warpSize * (element_id_in_x_tile + warp_id * x_tile_size); - llvm::Value* x = b_.CreateNSWAdd( + llvm::Value* x = NSWAdd( lane_id, - b_.CreateNSWMul( - index_typed_constant(kWarpSize), - b_.CreateNSWAdd( - x_indvar, b_.CreateNSWMul( - warp_id, llvm::ConstantInt::get( - index_ty, x_tile_size))))); + NSWMul(index_typed_constant(kWarpSize), + NSWAdd(x_indvar, + NSWMul(warp_id, llvm::ConstantInt::get( + index_ty, x_tile_size))))); // Unless we know the x-tile is entirely in bounds, we have to // emit a x-in-bounds check before reading from the input. if (!x_tile_in_bounds) { llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpULT(x, index_typed_constant(width)), - "x_in_bounds", &b_); + ICmpULT(x, index_typed_constant(width)), "x_in_bounds", + &b_); // Points b_ to the then-block. 
llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, &b_); @@ -1449,7 +1441,7 @@ Status IrEmitterUnnested::EmitRowReduction( // Emit code that reads the input element and accumulates it // to the partial reduction result. - llvm::Value* input_address = b_.CreateAlloca(element_ir_type); + llvm::Value* input_address = Alloca(element_ir_type); { // {z,y,x} is an index to input_3d_tensor_shape // [depth,height,width]. We need to convert that to an index @@ -1480,7 +1472,7 @@ Status IrEmitterUnnested::EmitRowReduction( for (int i = 0; i != num_reduces; ++i) { TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, input_gens[i](input_index)); - b_.CreateStore(input_ir_value, input_address); + Store(input_ir_value, input_address); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], input_address}, @@ -1500,8 +1492,8 @@ Status IrEmitterUnnested::EmitRowReduction( }; llvm::Value* tile_in_bounds = - b_.CreateOr(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), - b_.CreateICmpULT(last_x, index_typed_constant(width))); + Or(b_.getInt1(width % (x_tile_size * kWarpSize) == 0), + ICmpULT(last_x, index_typed_constant(width))); TF_RETURN_IF_ERROR( ksl.If(tile_in_bounds, @@ -1529,20 +1521,18 @@ Status IrEmitterUnnested::EmitRowReduction( for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* result_from_other_lane = - b_.CreateAlloca(element_ir_type, nullptr, "result_from_other_lane"); + Alloca(element_ir_type, nullptr, "result_from_other_lane"); for (int i = 0; i != num_reduces; ++i) { - llvm::Value* partial_reduction_result = b_.CreateLoad( - b_.CreateBitCast(partial_reduction_result_addresses[i], - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); + llvm::Value* partial_reduction_result = + Load(BitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); CHECK_EQ(launch_dimensions.threads_per_block() % kWarpSize, 
0) << "Requires block size a multiple of the warp size, otherwise we " "will read undefined elements."; - b_.CreateStore( - EmitFullWarpShuffleDown(partial_reduction_result, - b_.getInt32(shuffle_distance), &b_), - b_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); + Store(EmitFullWarpShuffleDown(partial_reduction_result, + b_.getInt32(shuffle_distance), &b_), + BitCast(result_from_other_lane, shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducers[i], {partial_reduction_result_addresses[i], result_from_other_lane}, @@ -1557,8 +1547,7 @@ Status IrEmitterUnnested::EmitRowReduction( // lane 0 (which holds the partially accumulated result for its warp) to the // output element. llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( - b_.CreateICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", - &b_); + ICmpEQ(lane_id, index_typed_constant(0)), "lane_id_is_zero", &b_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_); for (int i = 0; i != num_reduces; ++i) { llvm::Value* output_address = @@ -1718,7 +1707,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { thunks.push_back( BuildKernelThunk(reduce, /*implements_whole_instruction=*/false)); thunk_sequence_->emplace_back( - MakeUnique<SequentialThunk>(std::move(thunks), reduce)); + absl::make_unique<SequentialThunk>(std::move(thunks), reduce)); return EmitReductionToVector( reduce, input->shape(), {[&](const IrArray::Index& index) { @@ -1738,7 +1727,7 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { bool all_tuple_elements_have_buffer = - c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { + absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { return ir_emitter_context_->buffer_assignment() .GetUniqueTopLevelSlice(tuple_element) .ok(); @@ -1760,7 +1749,7 @@ Status 
IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { for (const HloInstruction* tuple_element : tuple->operands()) { tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); } - thunk_sequence_->emplace_back(MakeUnique<TupleThunk>( + thunk_sequence_->emplace_back(absl::make_unique<TupleThunk>( tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); return Status::OK(); } @@ -1792,8 +1781,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( thunks.push_back(std::move(initializer_thunk)); thunks.push_back(BuildKernelThunk(select_and_scatter, /*implements_whole_instruction=*/false)); - thunk_sequence_->emplace_back( - MakeUnique<SequentialThunk>(std::move(thunks), select_and_scatter)); + thunk_sequence_->emplace_back(absl::make_unique<SequentialThunk>( + std::move(thunks), select_and_scatter)); // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { @@ -1842,7 +1831,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( &b_); llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry( b_.getInt1Ty(), "initialized_flag_address", &b_); - b_.CreateStore(b_.getInt1(false), initialized_flag_address); + Store(b_.getInt1(false), initialized_flag_address); // Create the inner loop to iterate over the window. 
llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_, @@ -1863,15 +1852,15 @@ Status IrEmitterUnnested::HandleSelectAndScatter( IrArray::Index operand_index(index_type, source_index.size()); llvm::Value* in_bounds_condition = b_.getInt1(true); for (int64 i = 0; i < rank; ++i) { - llvm::Value* strided_index = b_.CreateNSWMul( + llvm::Value* strided_index = NSWMul( source_index[i], index_typed_constant(window.dimensions(i).stride())); - operand_index[i] = b_.CreateNSWSub( - b_.CreateNSWAdd(strided_index, window_index[i]), - index_typed_constant(window.dimensions(i).padding_low())); - llvm::Value* index_condition = b_.CreateICmpULT( + operand_index[i] = + NSWSub(NSWAdd(strided_index, window_index[i]), + index_typed_constant(window.dimensions(i).padding_low())); + llvm::Value* index_condition = ICmpULT( operand_index[i], index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i))); - in_bounds_condition = b_.CreateAnd(in_bounds_condition, index_condition); + in_bounds_condition = And(in_bounds_condition, index_condition); } CHECK(in_bounds_condition != nullptr); @@ -1881,7 +1870,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_); llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_); llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse( - b_.CreateLoad(initialized_flag_address), "initialized", &b_); + Load(initialized_flag_address), "initialized", &b_); // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. 
@@ -1889,16 +1878,16 @@ Status IrEmitterUnnested::HandleSelectAndScatter( const auto save_operand_index = [&](const IrArray::Index& operand_index) { for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - b_.CreateStore(operand_index[i], selected_index_address_slot); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + Store(operand_index[i], selected_index_address_slot); } }; IrArray operand_array = GetIrArray(*operand, *select_and_scatter); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &b_); - b_.CreateStore(operand_data, selected_value_address); + Store(operand_data, selected_value_address); save_operand_index(operand_index); - b_.CreateStore(b_.getInt1(true), initialized_flag_address); + Store(b_.getInt1(true), initialized_flag_address); // If the initialized_flag is true, call the `select` function to // potentially update the selected value and index with the currently @@ -1914,11 +1903,11 @@ Status IrEmitterUnnested::HandleSelectAndScatter( TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *select_and_scatter->select(), {selected_value_address, operand_address}, select_return_buffer)); - llvm::Value* result = b_.CreateLoad(select_return_buffer); + llvm::Value* result = Load(select_return_buffer); // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. 
- llvm::Value* cond = b_.CreateICmpNE( + llvm::Value* cond = ICmpNE( result, llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType( PRED, ir_emitter_context_->llvm_module()), @@ -1927,7 +1916,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_); llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_); - b_.CreateStore(b_.CreateLoad(operand_address), selected_value_address); + Store(Load(operand_address), selected_value_address); save_operand_index(operand_index); // After iterating over the window elements, scatter the source element to @@ -1939,8 +1928,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( IrArray::Index selected_index(operand_index.GetType()); for (int64 i = 0; i < rank; ++i) { llvm::Value* selected_index_address_slot = - b_.CreateInBoundsGEP(selected_index_address, {b_.getInt32(i)}); - selected_index.push_back(b_.CreateLoad(selected_index_address_slot)); + InBoundsGEP(selected_index_address, {b_.getInt32(i)}); + selected_index.push_back(Load(selected_index_address_slot)); } llvm::Value* source_value_address = GetIrArray(*source, *select_and_scatter) @@ -2018,7 +2007,7 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) { thunks.push_back(std::move(rng_thunk)); thunks.push_back(std::move(increment_seed_thunk)); thunk_sequence_->emplace_back( - MakeUnique<SequentialThunk>(std::move(thunks), rng)); + absl::make_unique<SequentialThunk>(std::move(thunks), rng)); return Status::OK(); } @@ -2043,7 +2032,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { auto values_destination = GetAllocationSlice(*sort, values_shape_index); if (keys_destination != GetAllocationSlice(*keys)) { - thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>( + thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( /*source_address=*/GetAllocationSlice(*keys), /*destination_buffer=*/keys_destination, /*mem_size=*/ShapeUtil::ByteSizeOf(keys->shape()), 
nullptr)); @@ -2051,7 +2040,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { if (values != nullptr && values_destination != GetAllocationSlice(*values)) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. - thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>( + thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( /*source_address=*/GetAllocationSlice(*values), /*destination_buffer=*/values_destination, /*mem_size=*/ShapeUtil::ByteSizeOf(values->shape()), nullptr)); @@ -2095,15 +2084,15 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace( dimension_to_sort, GetIrArray(*sort, *sort, keys_shape_index), - values != nullptr ? tensorflow::gtl::make_optional<IrArray>( + values != nullptr ? absl::make_optional<IrArray>( GetIrArray(*sort, *sort, values_shape_index)) - : tensorflow::gtl::nullopt, + : absl::nullopt, IrName(sort), xor_mask, &b_, &launch_dimensions)); } } thunk_sequence_->emplace_back( - MakeUnique<SequentialThunk>(std::move(thunks), sort)); + absl::make_unique<SequentialThunk>(std::move(thunks), sort)); return Status::OK(); } @@ -2130,7 +2119,7 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { if (crs->operand_count() == 1) { CHECK(ShapeUtil::IsArray(crs->operand(0)->shape())) << "Operands to cross-replica-sum must be arrays: " << crs->ToString(); - thunk_sequence_->push_back(MakeUnique<DeviceToDeviceCopyThunk>( + thunk_sequence_->push_back(absl::make_unique<DeviceToDeviceCopyThunk>( /*source_address=*/GetAllocationSlice(*crs->operand(0)), /*destination_buffer=*/GetAllocationSlice(*crs), /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs)); @@ -2145,17 +2134,17 @@ Status IrEmitterUnnested::HandleCrossReplicaSum(HloInstruction* crs) { tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment() .GetUniqueSlice(crs, {i}) .ValueOrDie()); - thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>( + 
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>( /*source_address=*/GetAllocationSlice(*crs->operand(i)), /*destination_buffer=*/tuple_element_buffers.back(), /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr)); } // Output a tuple of the buffers above. - thunks.push_back(MakeUnique<TupleThunk>(tuple_element_buffers, - GetAllocationSlice(*crs), nullptr)); + thunks.push_back(absl::make_unique<TupleThunk>( + tuple_element_buffers, GetAllocationSlice(*crs), nullptr)); thunk_sequence_->push_back( - MakeUnique<SequentialThunk>(std::move(thunks), crs)); + absl::make_unique<SequentialThunk>(std::move(thunks), crs)); return Status::OK(); } @@ -2305,7 +2294,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk( for (const auto& kv : hlo_slices) { buffers_needed.insert(kv.second.first.allocation()); } - tensorflow::gtl::optional<const BufferAllocation*> temp_buffer; + absl::optional<const BufferAllocation*> temp_buffer; for (const BufferAllocation& alloc : buffer_assn.Allocations()) { if (alloc.IsPreallocatedTempBuffer()) { if (!temp_buffer.has_value()) { @@ -2322,10 +2311,10 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk( // We'll pass a pointer to each of the elements of `buffers` to our kernel, in // this order. 
std::vector<const BufferAllocation*> non_constant_buffers; - c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers), - [](const BufferAllocation* allocation) { - return !allocation->is_constant(); - }); + absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers), + [](const BufferAllocation* allocation) { + return !allocation->is_constant(); + }); std::sort(non_constant_buffers.begin(), non_constant_buffers.end(), [](const BufferAllocation* a, const BufferAllocation* b) { @@ -2364,8 +2353,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk( *slice.allocation()))); CHECK_NE(loc, nullptr); } else { - loc = b_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), - {b_.getInt64(slice.offset())}); + loc = InBoundsGEP(kernel_args.at(slice.allocation()), + {b_.getInt64(slice.offset())}); } // If gte_index is nonempty, we have to dereference `loc` to get to the @@ -2373,8 +2362,8 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk( llvm::Type* int8_double_pointer = llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0); for (int64 idx : gte_index) { - loc = b_.CreateBitCast(loc, int8_double_pointer); - loc = b_.CreateLoad(b_.CreateInBoundsGEP(loc, {b_.getInt64(idx)})); + loc = BitCast(loc, int8_double_pointer); + loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)})); } bindings_.BindHloToIrValue(*instr, loc, index); @@ -2389,7 +2378,7 @@ std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk( llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } - return MakeUnique<KernelThunk>( + return absl::make_unique<KernelThunk>( non_constant_buffers, llvm_ir::AsString(kernel->getName()), implements_whole_instruction ? 
inst : nullptr, unroll_factor); } @@ -2398,7 +2387,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); CHECK_EQ(HloOpcode::kConstant, operand->opcode()); - return MakeUnique<HostToDeviceCopyThunk>( + return absl::make_unique<HostToDeviceCopyThunk>( /*source_address=*/operand->literal().untyped_data(), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -2410,7 +2399,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildHostToDeviceCopyThunk( std::unique_ptr<Thunk> IrEmitterUnnested::BuildDeviceToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); - return MakeUnique<DeviceToDeviceCopyThunk>( + return absl::make_unique<DeviceToDeviceCopyThunk>( /*source_address=*/GetAllocationSlice(*operand), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -2430,7 +2419,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk( .GetUniqueSlice(inst, index) .ConsumeValueOrDie(); }); - return MakeUnique<InfeedThunk>(slices, inst); + return absl::make_unique<InfeedThunk>(slices, inst); } std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk( @@ -2447,7 +2436,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk( *slice = status_or_slice.ConsumeValueOrDie(); } }); - return MakeUnique<OutfeedThunk>(std::move(slices), inst); + return absl::make_unique<OutfeedThunk>(std::move(slices), inst); } namespace { @@ -2470,7 +2459,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk( if (inst->opcode() == HloOpcode::kDot) { const HloInstruction* lhs = inst->operand(0); const HloInstruction* rhs = inst->operand(1); - return MakeUnique<GemmThunk>( + return absl::make_unique<GemmThunk>( GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. 
@@ -2512,7 +2501,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk( const HloInstruction* rhs = inst->operand(rhs_parameter->parameter_number()); - return MakeUnique<GemmThunk>( + return absl::make_unique<GemmThunk>( GetAllocationSlice(*lhs), // The buffer assigned to LHS. GetAllocationSlice(*rhs), // The buffer assigned to RHS. GetAllocationSlice(*inst), // The output buffer. @@ -2529,11 +2518,12 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk( std::unique_ptr<Thunk> IrEmitterUnnested::BuildFftThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); - return MakeUnique<FftThunk>(inst->fft_type(), inst->fft_length(), - /*input_buffer=*/GetAllocationSlice(*operand), - /*output_buffer=*/GetAllocationSlice(*inst), - /*input_shape=*/operand->shape(), - /*output_shape=*/inst->shape(), inst); + return absl::make_unique<FftThunk>( + inst->fft_type(), inst->fft_length(), + /*input_buffer=*/GetAllocationSlice(*operand), + /*output_buffer=*/GetAllocationSlice(*inst), + /*input_shape=*/operand->shape(), + /*output_shape=*/inst->shape(), inst); } StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( @@ -2582,9 +2572,9 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( // MemzeroThunk. 
ArraySlice<uint8> literal_bytes( reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes); - if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return { - MakeUnique<MemzeroThunk>(GetAllocationSlice(*hlo, index), nullptr)}; + if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { + return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index), + nullptr)}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by @@ -2601,7 +2591,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16)); } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); - return {MakeUnique<Memset32BitValueThunk>( + return {absl::make_unique<Memset32BitValueThunk>( pattern32, GetAllocationSlice(*hlo, index), nullptr)}; } @@ -2612,7 +2602,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk( literal_bytes.size() - 4) == 0) { uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); - return {MakeUnique<Memset32BitValueThunk>( + return {absl::make_unique<Memset32BitValueThunk>( word, GetAllocationSlice(*hlo, index), nullptr)}; } } @@ -2670,8 +2660,7 @@ Status CheckHloBuffersShareAllocation( if (slice_a != slice_b) { return InternalError( "instruction %s %s does not share allocation with instruction %s %s", - a->ToString().c_str(), slice_a.ToString().c_str(), - b->ToString().c_str(), slice_b.ToString().c_str()); + a->ToString(), slice_a.ToString(), b->ToString(), slice_b.ToString()); } return Status::OK(); } @@ -2764,7 +2753,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk( ir_emitter_context_); TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return MakeUnique<WhileThunk>( + return absl::make_unique<WhileThunk>( GetAllocationSlice(*condition->root_instruction()), // cond result ir_emitter_condition.ConsumeThunkSequence(), ir_emitter_body.ConsumeThunkSequence(), hlo); @@ -2782,8 +2771,8 
@@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk( ir_emitter_context_); TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return MakeUnique<ForThunk>(loop_limit, - ir_emitter_body.ConsumeThunkSequence(), hlo); + return absl::make_unique<ForThunk>( + loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo); } std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk( @@ -2803,7 +2792,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk( ir_emitter_context_); TF_CHECK_OK(false_computation->Accept(&ir_emitter_false)); - return MakeUnique<ConditionalThunk>( + return absl::make_unique<ConditionalThunk>( GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo->operand(1)), GetAllocationSlice(*hlo->operand(2)), @@ -3105,7 +3094,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( CeilOfRatio<int64>(output_dims_in_tiles[i], kTileSize); } const int64 num_tiles = - c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>()); + absl::c_accumulate(output_dims_in_tiles, 1, std::multiplies<int64>()); LaunchDimensions launch_dimensions(num_tiles, kThreadsPerTile); llvm::Type* index_ty = @@ -3151,9 +3140,8 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( const IrArray::Index output_tile_origin = [&] { IrArray::Index index = output_tile_index; for (int i = 1; i < 3; ++i) { - index[i] = - b_.CreateMul(output_tile_index[i], index_typed_constant(kTileSize), - "tile_origin." + std::to_string(i)); + index[i] = Mul(output_tile_index[i], index_typed_constant(kTileSize), + "tile_origin." + std::to_string(i)); } return index; }(); @@ -3166,12 +3154,12 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( std::vector<llvm::Value*> output_tile_bounds(3); for (int i = 1; i < 3; ++i) { // Only last row or column may not have full size. 
- output_tile_bounds[i] = b_.CreateSelect( - b_.CreateICmpEQ(output_tile_index[i], - index_typed_constant(output_dims_in_tiles[i] - 1)), - index_typed_constant(reduced_output_dims[i] - - (output_dims_in_tiles[i] - 1) * kTileSize), - index_typed_constant(kTileSize), "kTileSize"); + output_tile_bounds[i] = + Select(ICmpEQ(output_tile_index[i], + index_typed_constant(output_dims_in_tiles[i] - 1)), + index_typed_constant(reduced_output_dims[i] - + (output_dims_in_tiles[i] - 1) * kTileSize), + index_typed_constant(kTileSize), "kTileSize"); } KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll); @@ -3189,7 +3177,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( // Adds `addend` to the given `dim` of `index`. auto offset_dim = [&](IrArray::Index index, llvm::Value* addend, int64 dim) { - index[dim] = b_.CreateAdd(index[dim], addend); + index[dim] = Add(index[dim], addend); return index; }; const IrArray::Index input_index = @@ -3205,10 +3193,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( llvm::Value* shmem_buffer = param_shmem_buffers[id]; // TODO(jlebar): Add AA metadata to this store. Tile buffers are // global variables, so LLVM can't infer much about it. - b_.CreateStore( - input_in_logical_shape.EmitReadArrayElement(index, &b_, - "input_element"), - b_.CreateGEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); + Store(input_in_logical_shape.EmitReadArrayElement(index, &b_, + "input_element"), + GEP(shmem_buffer, {index_typed_constant(0), y_loc, x})); } }); @@ -3229,9 +3216,9 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( output_index, "output", output_tile_bounds[2], output_tile_bounds[1], [&](const IrArray::Index& index, llvm::Value* y_loc) { // TODO(jlebar): Add AA metadata to this load. 
- llvm::Instruction* load_from_shmem_buffer = b_.CreateLoad( - b_.CreateGEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), - "output_element"); + llvm::Instruction* load_from_shmem_buffer = + Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x, y_loc}), + "output_element"); output_in_reduced_shape_arrays[0].EmitWriteArrayElement( index, load_from_shmem_buffer, &b_); }); @@ -3259,7 +3246,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile( output_in_reduced_shape_arrays.size()); for (int64 i = 0; i < output_in_reduced_shape_arrays.size(); ++i) { output_in_reduced_shape_arrays[i].EmitWriteArrayElement( - index, b_.CreateExtractValue(output_value, i), &b_); + index, ExtractValue(output_value, i), &b_); } } else { output_in_reduced_shape_arrays[0].EmitWriteArrayElement( @@ -3341,7 +3328,7 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) { // if there's a Right Choice. // // This is only sound if tiled transposes are the only place where we use - // shared memory in fusions. If in the future other fusile ops use shared + // shared memory in fusions. If in the future other fusible ops use shared // memory, we'll have to adjust this heuristic. constexpr int kMinBlocksPerCore = 3; constexpr int64 kShmemPerCore = 48 * 1024; diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index e76823ad10..3259eaa2a2 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -15,12 +15,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -41,8 +41,8 @@ Status KernelThunk::Initialize(const GpuExecutable& executable, tensorflow::mutex_lock lock(mutex_); if (!loader_spec_) { loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); - tensorflow::StringPiece ptx = executable.ptx(); - // Convert tensorflow::StringPiece to se::port::StringPiece because + absl::string_view ptx = executable.ptx(); + // Convert absl::string_view to se::port::StringPiece because // StreamExecutor uses the latter. loader_spec_->AddCudaPtxInMemory( se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); @@ -63,7 +63,7 @@ Status KernelThunk::Initialize(const GpuExecutable& executable, if (kernel_cache_.end() == it) { it = kernel_cache_.emplace(executor, se::KernelBase(executor)).first; if (!executor->GetKernel(*loader_spec_, &it->second)) { - return InternalError("Unable to load kernel %s", kernel_name_.c_str()); + return InternalError("Unable to load kernel %s", kernel_name_); } } @@ -95,7 +95,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, VLOG(3) << "Launching " << kernel->name(); // Launch the kernel with potentially multiple blocks and threads. 
static constexpr int kKernelArgsLimit = 1024; - auto kernel_args = MakeUnique<se::KernelArgsArray<kKernelArgsLimit>>(); + auto kernel_args = absl::make_unique<se::KernelArgsArray<kKernelArgsLimit>>(); for (const BufferAllocation* arg : args_) { const auto& buf = buffer_allocations.GetDeviceAddress(arg->index()); kernel_args->add_device_memory_argument(buf); @@ -107,7 +107,7 @@ Status KernelThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, stream, se::ThreadDim(launch_dimensions.threads_per_block()), se::BlockDim(launch_dimensions.block_count()), *kernel, *kernel_args)) { - return InternalError("Unable to launch kernel %s", kernel_name_.c_str()); + return InternalError("Unable to launch kernel %s", kernel_name_); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index eb93efc560..698d2d51cc 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -34,6 +34,9 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@llvm//:amdgpu_code_gen", "@llvm//:analysis", "@llvm//:bit_reader", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc index 12a8a59488..85bc58cb44 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc @@ -15,14 +15,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "llvm/IR/Module.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -86,10 +86,11 @@ void IrDumpingPassManager::run(llvm::Module &module) { const llvm::PassInfo *PI = llvm::PassRegistry::getPassRegistry()->getPassInfo(P->getPassID()); const string basename = ReplaceFilenameExtension( - tensorflow::io::Basename(input_filename_), - tensorflow::strings::Printf( + absl::string_view(tensorflow::io::Basename(input_filename_)), + absl::StrFormat( "pass-%02d.before.%s.ll", i, - (PI == nullptr ? "unknown" : PI->getPassArgument().data()))); + absl::string_view(PI == nullptr ? "unknown" + : PI->getPassArgument().data()))); llvm::legacy::PassManager::add( new DumpIrPass(tensorflow::io::JoinPath(output_dir_, basename))); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index ff4ae1f9ef..8751e3a9c2 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -20,13 +20,15 @@ limitations under the License. 
#include <string> #include <utility> -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" @@ -54,10 +56,7 @@ limitations under the License. #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/Scalar.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/tracing.h" @@ -107,8 +106,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, << ", " << compute_capability.second << ") ." << "Defaulting to libdevice for compute_" << libdevice_version; } - return tensorflow::strings::StrCat("libdevice.compute_", libdevice_version, - ".10.bc"); + return absl::StrCat("libdevice.compute_", libdevice_version, ".10.bc"); } // Gets the GPU name as it's known to LLVM for a given compute capability. If @@ -138,15 +136,16 @@ static string GetSmName(std::pair<int, int> compute_capability) { << "Defaulting to telling LLVM that we're compiling for sm_" << sm_version; } - return tensorflow::strings::StrCat("sm_", sm_version); + return absl::StrCat("sm_", sm_version); } // Convenience function for producing a name of a temporary compilation product // from the input filename. 
string MakeNameForTempProduct(const std::string& input_filename, - tensorflow::StringPiece extension) { - return ReplaceFilenameExtension( - tensorflow::io::Basename(llvm_ir::AsString(input_filename)), extension); + absl::string_view extension) { + return ReplaceFilenameExtension(absl::string_view(tensorflow::io::Basename( + llvm_ir::AsString(input_filename))), + extension); } // Initializes LLVM passes. Uses the PassRegistry mechanism. @@ -167,7 +166,7 @@ void InitializePasses(llvm::PassRegistry* pass_registry) { // Returns the TargetMachine, given a triple. std::unique_ptr<llvm::TargetMachine> GetTargetMachine( - llvm::Triple triple, tensorflow::StringPiece cpu_name, + llvm::Triple triple, absl::string_view cpu_name, const HloModuleConfig& hlo_module_config) { std::string error; const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error); @@ -205,7 +204,7 @@ std::unique_ptr<llvm::TargetMachine> GetTargetMachine( default: codegen_opt_level = CodeGenOpt::None; } - return WrapUnique(target->createTargetMachine( + return absl::WrapUnique(target->createTargetMachine( triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx60", target_options, Optional<Reloc::Model>(RelocModel), Optional<CodeModel::Model>(CMModel), codegen_opt_level)); @@ -243,9 +242,9 @@ void AddOptimizationPasses(unsigned opt_level, unsigned size_level, } // Emits the given module to a bit code file. -void EmitBitcodeToFile(const Module& module, tensorflow::StringPiece filename) { +void EmitBitcodeToFile(const Module& module, absl::string_view filename) { std::error_code error_code; - llvm::ToolOutputFile outfile(filename.ToString().c_str(), error_code, + llvm::ToolOutputFile outfile(string(filename).c_str(), error_code, llvm::sys::fs::F_None); if (error_code) { LOG(FATAL) << "opening bitcode file for writing: " << error_code.message(); @@ -266,8 +265,9 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { // get creative to add a suffix. 
string module_id(llvm_ir::AsString(module->getModuleIdentifier())); IrDumpingPassManager codegen_passes( - ReplaceFilenameExtension(tensorflow::io::Basename(module_id), - "-nvptx.dummy"), + ReplaceFilenameExtension( + absl::string_view(tensorflow::io::Basename(module_id)), + "-nvptx.dummy"), "", false); codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); @@ -332,8 +332,8 @@ Status LinkLibdeviceIfNecessary(llvm::Module* module, return !GV.hasName() || (GVS.count(GV.getName()) == 0); }); })) { - return tensorflow::errors::Internal(tensorflow::strings::StrCat( - "Error linking libdevice from ", libdevice_path)); + return tensorflow::errors::Internal( + absl::StrCat("Error linking libdevice from ", libdevice_path)); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h index 54e0e140de..9654175bfa 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.h @@ -20,11 +20,11 @@ limitations under the License. #include <string> #include <utility> +#include "absl/strings/string_view.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { namespace gpu { diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc index 9ef9bc3a50..3b2c3591d9 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc @@ -17,13 +17,13 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Support/SourceMgr.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace { @@ -52,14 +52,13 @@ std::unique_ptr<llvm::Module> LoadIRModule(const string& filename, return module; } -string ReplaceFilenameExtension(tensorflow::StringPiece filename, - tensorflow::StringPiece new_extension) { +string ReplaceFilenameExtension(absl::string_view filename, + absl::string_view new_extension) { auto pos = filename.rfind('.'); - tensorflow::StringPiece stem = - pos == tensorflow::StringPiece::npos - ? filename - : tensorflow::StringPiece(filename.data(), pos); - return tensorflow::strings::StrCat(stem, ".", new_extension); + absl::string_view stem = pos == absl::string_view::npos + ? filename + : absl::string_view(filename.data(), pos); + return absl::StrCat(stem, ".", new_extension); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h index a6daeca95a..60f4926849 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h @@ -18,8 +18,8 @@ limitations under the License. 
#include <memory> #include <string> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace llvm { class LLVMContext; @@ -41,8 +41,8 @@ std::unique_ptr<llvm::Module> LoadIRModule(const string& filename, // // For example: // ReplaceFilenameExtension("/foo/baz.txt", "cc") --> "/foo/baz.cc" -string ReplaceFilenameExtension(tensorflow::StringPiece filename, - tensorflow::StringPiece new_extension); +string ReplaceFilenameExtension(absl::string_view filename, + absl::string_view new_extension); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index c62bae0628..7a43f0be54 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -23,6 +23,7 @@ limitations under the License. #include <string> #include <utility> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -48,7 +49,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, // If possible, we want to pick a reduce operand of the fusion root, // because it has the most constraints. for (const auto* inst : fused_expression_root->operands()) { - if (inst->opcode() == HloOpcode::kReduce) { + if (IsReductionToVector(*inst)) { return inst; } } @@ -63,7 +64,7 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, auto get_element_shape = [&](const HloInstruction* element_instr) { // Special handling of kReduce instructions -- the fusion // applies to the first operand. 
- if (element_instr->opcode() == HloOpcode::kReduce) { + if (IsReductionToVector(*element_instr)) { return element_instr->operand(0)->shape(); } return element_instr->shape(); @@ -131,7 +132,7 @@ bool ReduceFriendlyInputLayouts(HloInstruction* instr) { max_rank_layout = ¶m->shape().layout(); } } - return c_all_of(params, [&](HloInstruction* param) { + return absl::c_all_of(params, [&](HloInstruction* param) { return (ShapeUtil::Rank(param->shape()) < max_rank) || (LayoutUtil::Equal(param->shape().layout(), *max_rank_layout)); }); @@ -140,10 +141,15 @@ bool ReduceFriendlyInputLayouts(HloInstruction* instr) { } // namespace bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { - // We can fuse reduces and loop fusions. - return IsInputFusibleReduction(instr) || - (instr->opcode() == HloOpcode::kFusion && - instr->fusion_kind() == HloInstruction::FusionKind::kLoop); + // We can fuse reduces and loop fusions. Elementwise instructions can be fused + // with any other instruction. + // TODO(b/112957171): This should use the same isFusible logic as + // instruction_fusion. + return instr->IsFusible() && + (IsInputFusibleReduction(instr) || + (instr->opcode() == HloOpcode::kFusion && + instr->fusion_kind() == HloInstruction::FusionKind::kLoop) || + instr->IsElementwise()); } int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, @@ -177,11 +183,12 @@ bool GpuMultiOutputFusion::LegalToFuse(HloInstruction* instr1, // merge into bigger loop fusions and input (reduce) fusions become fusions // with multiple reduce outputs. We could fuse reduce and loop fusions // together too (the result being an input fusion) if we find cases where this - // improves things. + // improves things. Also disable fusing standalone input-fusible reduces into + // loop fusions. 
CHECK(instr1->opcode() == HloOpcode::kFusion); if ((instr2->opcode() == HloOpcode::kFusion && instr1->fusion_kind() != instr2->fusion_kind()) || - (instr2->opcode() != HloOpcode::kFusion && + (IsReductionToVector(*instr2) && instr1->fusion_kind() == HloInstruction::FusionKind::kLoop)) { return false; } @@ -197,7 +204,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { tensorflow::gtl::FlatSet<HloInstruction*> to_fuse; // Keep a list of the instructions to fuse after making all the fusion // decisions. We first aggressively add instructions to potential_fusion_list, - // then filter out instructions that will be no longer fusable because of + // then filter out instructions that will be no longer fusible because of // reachability change. This avoids recalculating reachability on a large set // of instructions. std::vector<std::pair<HloInstruction*, HloInstruction*>> @@ -213,7 +220,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { continue; } if (!IsInputFusibleReduction(consumer)) { - VLOG(3) << consumer->name() << " is not an input-fusable reduction."; + VLOG(3) << consumer->name() << " is not an input-fusible reduction."; continue; } VLOG(3) << consumer->name() @@ -222,8 +229,8 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { auto consumer_operands = consumer->operands(); for (size_t i = 0; i < consumer_operands.size(); ++i) { HloInstruction* producer = consumer_operands[i]; - if (!producer->IsFusable()) { - VLOG(3) << producer->name() << " is not fusable."; + if (!producer->IsFusible()) { + VLOG(3) << producer->name() << " is not fusible."; continue; } const bool is_loop_fusion = @@ -248,7 +255,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { } // Do not fuse a producer if the other operands of the fusion are // reachable from the producer, this would create a cycle. 
- if (c_any_of(consumer_operands, [&](HloInstruction* operand) { + if (absl::c_any_of(consumer_operands, [&](HloInstruction* operand) { return producer != operand && reachability()->IsReachable(producer, operand); })) { @@ -263,12 +270,12 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { } } - // Filter out pairs that will be no longer fusable because of reachability + // Filter out pairs that will be no longer fusible because of reachability // change. for (auto& fusion_pair : potential_fusion_list) { HloInstruction* producer = fusion_pair.first; HloInstruction* consumer = fusion_pair.second; - if (!c_any_of(consumer->operands(), [&](HloInstruction* operand) { + if (!absl::c_any_of(consumer->operands(), [&](HloInstruction* operand) { return producer != operand && reachability()->IsReachable(producer, operand); })) { diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h index 67ca5d49ee..f0b4d67ab8 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -22,7 +22,7 @@ namespace xla { namespace gpu { // Multi-output fusion of sibling and producer-consumer instructions for the -// Jellyfish backend. +// GPU backend. class GpuMultiOutputFusion : public MultiOutputFusion { public: GpuMultiOutputFusion(); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index 14f157a5e5..c822c94f1b 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -15,19 +15,19 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" - -namespace op = xla::testing::opcode_matchers; namespace xla { namespace gpu { +namespace op = xla::testing::opcode_matchers; + using MultiOutputFusionTest = HloTestBase; const char kModulePrefix[] = R"( @@ -47,7 +47,7 @@ const char kModulePrefix[] = R"( TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { // Fusion with reduce instruction root and a sibling reduce instruction // sharing the same input param. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation { p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) @@ -74,7 +74,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[6400]{0} parameter(1) mul = f32[6400]{0} multiply(p1.1, p1.1) @@ -101,7 +101,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceInputShapes) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[10,10]{1,0} parameter(1) mul = 
f32[10,10]{1,0} multiply(p1.1, p1.1) @@ -130,7 +130,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceFusions) { // Two sibling fusions with reduce instruction roots sharing the same input // param. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) @@ -165,7 +165,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) { // Multi-output fusion with two reduce instructions root and a sibling reduce // instruction sharing the same input param. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) { const.1 = f32[] constant(1) p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0) @@ -198,7 +198,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) { // Verify that if we already have a multi-output fusion that we prefer to pick // a reduce op from its operands for checking shape compatibility. 
- auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p1.1 = f32[10,10]{1,0} parameter(1) mul = f32[10,10]{1,0} multiply(p1.1, p1.1) @@ -228,7 +228,7 @@ TEST_F(MultiOutputFusionTest, } TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[6400]{0} parameter(0) ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) @@ -256,8 +256,136 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { op::Tuple(op::Multiply(), op::Divide())); } -TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { +TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) { + // Fusing a reduce into a loop fusion would require changing the fusion kind. + // That's not supported yet. auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(0) + reduce = f32[] reduce(p0, const.2), dimensions={0}, to_apply=scalar_add_computation + ROOT root = (f32[6400]{0}, f32[]) tuple(fusion.1, reduce) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[6400]{0} parameter(0) + ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) + } + + ENTRY entry { + p0 = f32[6400]{0} parameter(0) + fusion.1 = f32[6400]{0} fusion(p0), kind=kLoop, calls=fused_computation_1 + const.2 = f32[] constant(1) + div = 
f32[6400]{0} divide(p0, const.2) + ROOT root = (f32[6400]{0}, f32[6400]{0}) tuple(fusion.1, div) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Multiply(), op::Divide())); +} + +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + } + + fused_computation_2 { + p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + const.2 = f32[] constant(0) + ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + fusion.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1) + ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp) + } + + fused_computation_2 { + p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + 
const.2 = f32[] constant(0) + ROOT add = f32[8,1,5,16,1,1]{5,4,3,2,1,0} add(p0.2, const.2) + } + + ENTRY entry { + p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} fusion(p0), kind=kLoop, calls=fused_computation_2 + gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0 + gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1 + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(gte0, gte1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Multiply(), op::Exp(), op::Add())); +} + +TEST_F(MultiOutputFusionTest, + MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) + exp = f32[8,1,5,16,1,1]{5,4,3,2,1,0} exponential(p0.1) + ROOT tuple = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) tuple(mul, exp) + } + + fused_computation_2 { + p0.2 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + const.2 = f32[] constant(0) + ROOT reduce = f32[8,1,5,1,1]{4,3,2,1,0} reduce(p0.2, const.2), dimensions={3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) + fusion.1 = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}) fusion(p0), kind=kLoop, calls=fused_computation_1 + fusion.2 = f32[8,1,5,1,1]{4,3,2,1,0} 
fusion(p0), kind=kLoop, calls=fused_computation_2 + gte0 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=0 + gte1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} get-tuple-element(fusion.1), index=1 + ROOT root = (f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,16,1,1]{5,4,3,2,1,0}, f32[8,1,5,1,1]{4,3,2,1,0}) tuple(gte0, gte1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( ENTRY reduce { p0 = f32[2,2,2]{2,1,0} parameter(0) c0 = f32[] constant(0) @@ -277,7 +405,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionElementwiseAndReduce) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_add { p0.1 = f32[2,2,2]{2,1,0} parameter(0) p1.1 = f32[2,2,2]{2,1,0} parameter(1) @@ -304,7 +432,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduce) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_select { p1.1 = f32[2,2,2]{2,1,0} parameter(1) c0 = f32[] constant(0) @@ -345,7 +473,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionLoopFusionAndReduceFusion) { } TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_element_wise { p0.1 = f32[2,2,2]{2,1,0} parameter(0) p1.1 = f32[2,2,2]{2,1,0} parameter(1) @@ -372,7 +500,7 @@ TEST_F(MultiOutputFusionTest, ProducerConsumerFusionDoNotFuseLoopReduceFusion) { 
TEST_F(MultiOutputFusionTest, ProducerConsumerFusionFp16LoopFusionAndReduceFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_select { p1.1 = f16[2,2,2]{2,1,0} parameter(1) c0 = f16[] constant(0) @@ -413,7 +541,7 @@ TEST_F(MultiOutputFusionTest, TEST_F(MultiOutputFusionTest, ProducerConsumerFusionReduceUnfriendlyLoopFusion) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( mixed_input_layouts_computation { p0.1 = f16[128,1024,32,32]{1,3,2,0} parameter(0) p1.1 = f16[128,1024,32,32]{3,2,1,0} parameter(1) diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 6c1eab4f8c..8e4a8e5f54 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -21,13 +21,15 @@ limitations under the License. #include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex. #include <utility> +#include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -85,7 +87,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cuda_libdevice_path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -140,7 +141,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, Compiler* compiler) { { HloPassPipeline pipeline("optimization"); - pipeline.AddInvariantChecker<HloVerifier>(); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); pipeline.AddPass<GpuHloSupportChecker>(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), @@ -156,7 +158,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification"); - pass.AddInvariantChecker<HloVerifier>(); + pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls // where possible. Not every batchnorm op can be implemented as a call to @@ -203,10 +206,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // Convert convolutions into CustomCalls to cudnn, then canonicalize them // (PadInsertion). HloPassPipeline pipeline("conv_canonicalization"); - pipeline.AddInvariantChecker<HloVerifier>(); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); // TODO(b/31709653): Directly use the grouped convolution support of Cudnn. pipeline.AddPass<ConvolutionFeatureGroupConverter>(); pipeline.AddPass<CudnnConvolutionRewriter>(); + // CudnnConvolutionRewriter may add instructions of the form + // reverse(constant), which it expects will be simplified by constant + // folding. 
+ pipeline.AddPass<HloConstantFolding>(); pipeline.AddPass<PadInsertion>(); if (IsVoltaOrLater(*stream_exec)) { pipeline.AddPass<PadForTensorCores>(); @@ -218,9 +226,22 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, } { - HloPassPipeline pipeline("layout_assignment"); + // Run layout assignment in a separate pipeline from + // "post-layout-assignment" because we want everything after layout + // assignment to have a layout-sensitive invariant-checker, but + // HloPassPipeline also runs its invariant checker before any passes are + // run, meaning, the pipeline that contains layout assignment cannot contain + // a layout-sensitive verifier! + HloPassPipeline pipeline("layout assignment"); pipeline.AddPass<GpuLayoutAssignment>( hlo_module->mutable_entry_computation_layout(), stream_exec); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + { + HloPassPipeline pipeline("post-layout_assignment"); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. 
@@ -266,17 +287,20 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassFix<HloPassPipeline> fusion("fusion"); - fusion.AddInvariantChecker<HloVerifier>(); + fusion.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false); fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true); fusion.AddPass<FusionMerger>(); fusion.AddPass<GpuMultiOutputFusion>(); fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); + fusion.AddPass<HloDCE>(); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); - reduce_pipeline.AddInvariantChecker<HloVerifier>(); + reduce_pipeline.AddInvariantChecker<HloVerifier>( + /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -302,7 +326,8 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. 
HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker<HloVerifier>(); + pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -352,9 +377,9 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { string vmaj_str, vmin_str, vdot_str; if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str, &vmin_str, &vdot_str) || - !tensorflow::strings::safe_strto64(vmaj_str, &vmaj) || - !tensorflow::strings::safe_strto64(vmin_str, &vmin) || - !tensorflow::strings::safe_strto64(vdot_str, &vdot)) { + !absl::SimpleAtoi(vmaj_str, &vmaj) || + !absl::SimpleAtoi(vmin_str, &vmin) || + !absl::SimpleAtoi(vdot_str, &vdot)) { LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path << " --version:\n" << out; @@ -466,7 +491,7 @@ StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major, tensorflow::SubProcess ptxas_info_dumper; std::vector<string> ptxas_args = { ptxas_path, ptx_path, "-o", cubin_path, - tensorflow::strings::StrCat("-arch=sm_", cc_major, cc_minor)}; + absl::StrCat("-arch=sm_", cc_major, cc_minor)}; if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } @@ -674,7 +699,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend( // Write PTX to IR dump directory, if IR dumping was requested. 
if (!ir_dump_directory.empty()) { const string ptx_outfile = tensorflow::io::JoinPath( - ir_dump_directory, tensorflow::strings::StrCat(module->name(), ".ptx")); + ir_dump_directory, absl::StrCat(module->name(), ".ptx")); auto status = [&] { auto* env = tensorflow::Env::Default(); TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory)); @@ -690,7 +715,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend( const std::vector<uint8> cubin = CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); - auto thunk_schedule = MakeUnique<ThunkSchedule>( + auto thunk_schedule = absl::make_unique<ThunkSchedule>( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), hlo_schedule->ThunkLaunchOrder()); VLOG(2) << "Printing the thunk schedule..."; @@ -704,7 +729,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend( cost_analysis.set_bytes_per_second( stream_exec->GetDeviceDescription().memory_bandwidth()); TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); - profile_index_map = MakeUnique<HloProfileIndexMap>(*module); + profile_index_map = absl::make_unique<HloProfileIndexMap>(*module); profile_printer = CreateHloProfilePrinterData(*profile_index_map, cost_analysis); } @@ -813,7 +838,7 @@ se::Platform::Id NVPTXCompiler::PlatformId() const { static bool InitModule() { xla::Compiler::RegisterCompilerFactory( stream_executor::cuda::kCudaPlatformId, - []() { return xla::MakeUnique<xla::gpu::NVPTXCompiler>(); }); + []() { return absl::make_unique<xla::gpu::NVPTXCompiler>(); }); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index d4d2909f1b..08ef6ef56c 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -20,13 +20,13 @@ limitations under the License. 
#include <string> #include <vector> +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_compiler.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc index 4aaf0c9e14..2fa170964e 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/platform/logging.h" diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index b99d998c4d..e0f3e84a4c 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -96,7 +96,7 @@ Status OutfeedThunk::ExecuteOnStream( Status block_status = stream->BlockHostUntilDone(); if (!block_status.ok()) { return InternalError("Failed to complete data transfer on stream %p: %s", - stream, block_status.error_message().c_str()); + stream, block_status.error_message()); } VLOG(2) << "Outfeeding from GPU complete"; diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h index 192359f026..11dc56a64f 100644 --- 
a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h @@ -32,9 +32,7 @@ namespace gpu { // TODO(jlebar): Also pad dots. class PadForTensorCores : public HloPassInterface { public: - tensorflow::StringPiece name() const override { - return "pad for tensor cores"; - } + absl::string_view name() const override { return "pad for tensor cores"; } StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc index 99e7580b82..104af48c82 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores_test.cc @@ -29,7 +29,12 @@ namespace { namespace op = xla::testing::opcode_matchers; using ::testing::_; -using PadForTensorCoresTest = HloVerifiedTestBase; +class PadForTensorCoresTest : public HloVerifiedTestBase { + public: + PadForTensorCoresTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; TEST_F(PadForTensorCoresTest, PadF16ForwardConvInputChannels) { ParseAndVerifyModule(R"( diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index b22040eee1..98cc21ccac 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -69,7 +70,7 @@ HloInstruction* MaybePaddedAndSlicedInput( PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique<Literal>(LiteralUtil::Zero(element_type)))); + absl::make_unique<Literal>(LiteralUtil::Zero(element_type)))); input = MakePadHlo(input, padding, padding_config).ValueOrDie(); } @@ -126,7 +127,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, PrimitiveType element_type = kernel->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique<Literal>(LiteralUtil::Zero(element_type)))); + absl::make_unique<Literal>(LiteralUtil::Zero(element_type)))); return MakePadHlo(kernel, padding, padding_config).ValueOrDie(); } } // namespace @@ -236,7 +237,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique<Literal>( + HloInstruction::CreateConstant(absl::make_unique<Literal>( LiteralUtil::Zero(input->shape().element_type())))); HloInstruction* padded_input = MakePadHlo(input, padding, input_padding_config).ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h index 67e51509e4..a622e894ed 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -26,7 +26,7 @@ namespace gpu { // padding, so that they can be lowered to cuDNN convolution. 
class PadInsertion : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "pad insertion"; } + absl::string_view name() const override { return "pad insertion"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 3838fee674..ca57cacb98 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -57,8 +57,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( unroll_factor_(unroll_factor) {} std::vector<llvm_ir::IrArray::Index> -ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { +ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, + llvm::Type* index_type) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; // if (linear_index < num_elements) { diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index b82a23419d..cc7da2e73b 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -58,7 +58,7 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type) override; private: // The thread and block dimension to parallelize the loop on. 
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index d3fd0544fb..cf9f102d31 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -18,15 +18,15 @@ limitations under the License. #include <ostream> #include <string> +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/bits.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -34,9 +34,8 @@ namespace gpu { std::ostream& operator<<(std::ostream& out, const LaunchDimensions& launch_dims) { - out << tensorflow::strings::Printf("[block: %lld, thread: %lld]", - launch_dims.block_count(), - launch_dims.threads_per_block()); + out << absl::StrFormat("[block: %d, thread: %d]", launch_dims.block_count(), + launch_dims.threads_per_block()); return out; } @@ -91,9 +90,9 @@ LaunchDimensions CalculateLaunchDimensions( } int64 block_count = CeilOfRatio(num_elements, threads_per_block); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << absl::StrFormat( "Initialized the block count to ceil(# of elements / threads per " - "block) = ceil(%lld/%lld) = %lld", + "block) = ceil(%d/%d) = %d", num_elements, threads_per_block, block_count); return LaunchDimensions(block_count, threads_per_block); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc index 0806dd5161..5b6cf2c04d 100644 --- 
a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" @@ -119,7 +119,7 @@ int ComputeStreamToAssign( } // namespace std::unique_ptr<StreamAssignment> AssignStreams(const HloModule& module) { - auto stream_assignment = MakeUnique<StreamAssignment>(); + auto stream_assignment = absl::make_unique<StreamAssignment>(); const HloComputation& computation = *module.entry_computation(); std::unique_ptr<HloReachabilityMap> reachability = computation.ComputeReachability(); diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 6f4bb0580e..091aca23e5 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -15,13 +15,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace gpu { @@ -33,7 +34,7 @@ class StreamAssignmentTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique<HloModule>("test_module", config); + return absl::make_unique<HloModule>("test_module", config); } // Pre-canned shapes. @@ -97,7 +98,7 @@ TEST_F(StreamAssignmentTest, LatticeMatMul) { params.reserve(6); for (int i = 0; i < 6; ++i) { params.push_back(builder.AddInstruction(HloInstruction::CreateParameter( - i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i)))); + i, f32_2x2_, /*name=*/absl::StrFormat("param%d", i)))); } HloInstruction* d00 = builder.AddInstruction( HloInstruction::CreateCanonicalDot(f32_2x2_, params[2], params[3])); diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc index 05b305ea4c..08ff52211a 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -16,6 +16,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace gpu { @@ -53,8 +55,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, input_layout.push_back(dnums.input_feature_dimension()); break; default: - return tensorflow::errors::Internal("Invalid input layout: ", - DataLayoutString(input)); + return InternalError("Invalid input layout %s for conv with dnums %s", + DataLayoutString(input), + ConvolutionDimensionNumbersToString(dnums)); } std::vector<int64> filter_layout; @@ -74,8 +77,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, filter_layout.push_back(dnums.kernel_input_feature_dimension()); break; default: - return tensorflow::errors::Internal("Invalid filter layout: ", - FilterLayoutString(filter)); + return InternalError("Invalid filter layout %s for conv with dnums %s", + FilterLayoutString(filter), + ConvolutionDimensionNumbersToString(dnums)); } std::vector<int64> output_layout; @@ -95,8 +99,9 @@ StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, output_layout.push_back(dnums.output_feature_dimension()); break; default: - return tensorflow::errors::Internal("Invalid output layout: ", - DataLayoutString(output)); + return InternalError("Invalid output layout %s for conv with dnums %s", + DataLayoutString(output), + ConvolutionDimensionNumbersToString(dnums)); } return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout), @@ -128,8 +133,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(input, nhwc_input)) { input_layout = DataLayout::kBatchYXDepth; } else { - return tensorflow::errors::Internal("Invalid input layout: ", - input.ShortDebugString()); + return InternalError("Invalid input layout 
%s for conv with dnums %s", + LayoutUtil::HumanString(input), + ConvolutionDimensionNumbersToString(dnums)); } FilterLayout filter_layout; @@ -138,8 +144,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(filter, nhwc_filter)) { filter_layout = FilterLayout::kOutputYXInput; } else { - return tensorflow::errors::Internal("Invalid filter layout: ", - filter.ShortDebugString()); + return InternalError("Invalid filter layout %s for conv with dnums %s", + LayoutUtil::HumanString(filter), + ConvolutionDimensionNumbersToString(dnums)); } DataLayout output_layout; @@ -148,8 +155,9 @@ XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, } else if (LayoutUtil::Equal(output, nhwc_output)) { output_layout = DataLayout::kBatchYXDepth; } else { - return tensorflow::errors::Internal("Invalid output layout: ", - output.ShortDebugString()); + return InternalError("Invalid output layout %s for conv with dnums %s", + LayoutUtil::HumanString(output), + ConvolutionDimensionNumbersToString(dnums)); } return std::make_tuple(input_layout, filter_layout, output_layout); diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 4fad3f46cf..db4a33dc56 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -35,13 +35,13 @@ cc_library( "requires-gpu-sm35", ], deps = [ - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service/gpu:gpu_executable", "//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], ) @@ -60,6 +60,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + 
"@com_google_absl//absl/memory", ], ) @@ -94,6 +95,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -150,6 +152,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) @@ -168,6 +171,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/memory", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc index 4b8415fe91..79e77d4c4d 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/tests/filecheck.h" #include "tensorflow/core/platform/logging.h" @@ -32,15 +32,14 @@ std::unique_ptr<HloModule> GpuCodegenTest::CreateNewModuleWithFTZ(bool ftz) { debug_options.add_xla_disable_hlo_passes("constant_folding"); config.set_debug_options(debug_options); - return MakeUnique<HloModule>(TestName(), config); + return absl::make_unique<HloModule>(TestName(), config); } void GpuCodegenTest::CompileAndVerifyPtx(std::unique_ptr<HloModule> hlo_module, const string& pattern) { std::unique_ptr<Executable> executable = std::move(CompileToExecutable(std::move(hlo_module)).ValueOrDie()); - string ptx_str = - std::string(static_cast<GpuExecutable*>(executable.get())->ptx()); + 
string ptx_str(static_cast<GpuExecutable*>(executable.get())->ptx()); StatusOr<bool> filecheck_result = RunFileCheck(ptx_str, pattern); ASSERT_TRUE(filecheck_result.ok()); EXPECT_TRUE(filecheck_result.ValueOrDie()); diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc index ce69e058e6..4550f36fdf 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include <memory> #include <utility> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc index e5958165ef..a06576df7b 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_index_test.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include <memory> #include <utility> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index cca35316f0..15d1e269cc 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -27,13 +27,22 @@ namespace { class GpuKernelTilingTest : public GpuCodegenTest { protected: - GpuKernelTilingTest() { + GpuKernelTilingTest() {} + + // Most tests in this file want to skip layout assignment, but a few need it + // enabled. + HloModuleConfig ConfigWithLayoutAssignment() { + return GetModuleConfigForTest(); + } + + HloModuleConfig ConfigWithoutLayoutAssignment() { + HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); - config_.set_debug_options(debug_options); // Disable layout_assignment to use the preassigned layouts. - debug_options.add_xla_disable_hlo_passes("layout_assignment"); + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + config.set_debug_options(debug_options); + return config; } - HloModuleConfig config_; }; TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { @@ -46,7 +55,13 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithProperDimensionsTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + // + // We must enable layout assignment in order for this test to work correctly. 
+ // AlgebraicSimplifier removes copy1; it's added back by layout assignment, + // which respects the module's entry computation layout. But if we don't run + // layout assignment...well, nobody else adds the copy back. + auto hlo_module = + ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy @@ -68,8 +83,11 @@ TEST_F(GpuKernelTilingTest, UnnestedTransposeWithSmallDimensionsNotTiled) { ROOT copy1 = f16[2,3,64]{1,0,2} copy(para0) })"; - // Check that a call to llvm.nvvm.barrier0 is not generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + // Check that a call to llvm.nvvm.barrier0 is not generated. As in + // UnnestedTransposeWithProperDimensionsTiled, we must run layout assignment + // here. + auto hlo_module = + ParseHloString(kHloString, ConfigWithLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @copy @@ -95,7 +113,8 @@ TEST_F(GpuKernelTilingTest, SimpleFusionWithTransposeTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion @@ -128,7 +147,8 @@ TEST_F(GpuKernelTilingTest, MultipleOutputFusionWithOnePossibleTransposeTiled) { })"; // Check that a call to llvm.nvvm.barrier0 is generated. - auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion @@ -162,7 +182,8 @@ TEST_F(GpuKernelTilingTest, })"; // Check that a call to llvm.nvvm.barrier0 is not generated. 
- auto hlo_module = ParseHloString(kHloString, config_).ValueOrDie(); + auto hlo_module = + ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie(); CompileAndVerifyIr(std::move(hlo_module), R"( ; CHECK-LABEL: define void @fusion diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc index 6c9ae7bada..6a9ecd9dae 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ldg_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include <memory> #include <utility> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc index c42e5704a4..15198865bd 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_noalias_test.cc @@ -16,8 +16,8 @@ limitations under the License. 
#include <memory> #include <utility> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc index 9622936306..0f2d5568ca 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc @@ -138,6 +138,9 @@ TEST_F(GpuUnrollingTest, UnrollMultiOutputFusion) { HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_max_kernel_unroll_factor(2); + // Disable layout assignment for this test. Layout assignment does not expect + // fusions to be present, and so it does the wrong thing. 
+ debug_options.add_xla_disable_hlo_passes("layout-assignment"); config.set_debug_options(debug_options); const char *const kMultiOutputFusionModule = R"( diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index bdb062837c..141f321938 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -144,16 +144,15 @@ const std::list<const Thunk*>& ThunkSchedule::DependsOn( string ThunkSchedule::ToString() const { string result = "Total order:\n"; for (Thunk* thunk : thunk_total_order_) { - tensorflow::strings::StrAppend(&result, "\t", - thunk->hlo_instruction()->ToString(), "\n"); + absl::StrAppend(&result, "\t", thunk->hlo_instruction()->ToString(), "\n"); } - tensorflow::strings::StrAppend(&result, "Dependencies:\n"); + absl::StrAppend(&result, "Dependencies:\n"); for (const auto& entry : depends_on_) { const Thunk* dependent = entry.first; for (const Thunk* dependency : entry.second) { - tensorflow::strings::StrAppend( - &result, "\t", dependent->hlo_instruction()->name(), " depends on ", - dependency->hlo_instruction()->name(), "\n"); + absl::StrAppend(&result, "\t", dependent->hlo_instruction()->name(), + " depends on ", dependency->hlo_instruction()->name(), + "\n"); } } return result; diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc index 8579b1545f..989b542ff4 100644 --- a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" @@ -25,7 +26,7 @@ Status TupleThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { auto size = tuple_element_buffers_.size(); - auto tuple_element_buffer_addresses = MakeUnique<void*[]>(size); + auto tuple_element_buffer_addresses = absl::make_unique<void*[]>(size); for (int i = 0; i != size; ++i) { tuple_element_buffer_addresses[i] = buffer_allocations.GetDeviceAddress(tuple_element_buffers_[i]).opaque(); diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index d81d87e7dc..c4754fe378 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" @@ -34,9 +34,9 @@ WhileThunk::WhileThunk( // and body_thunk_sequence_ constructors because these SequentialThunks // are logically "part of" this WhileThunk, and shouldn't be profiled // separately from it. 
- condition_thunk_sequence_(MakeUnique<SequentialThunk>( + condition_thunk_sequence_(absl::make_unique<SequentialThunk>( std::move(*condition_thunk_sequence), nullptr)), - body_thunk_sequence_(MakeUnique<SequentialThunk>( + body_thunk_sequence_(absl::make_unique<SequentialThunk>( std::move(*body_thunk_sequence), nullptr)) {} Status WhileThunk::Initialize(const GpuExecutable& executable, @@ -70,7 +70,7 @@ Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, if (!block_status.ok()) { return InternalError( "Failed to complete all kernels launched on stream %p: %s", stream, - block_status.error_message().c_str()); + block_status.error_message()); } if (!condition_result) { diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index c5f3906356..40183de96e 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -118,7 +118,8 @@ class WhileTransformerTest : public HloTestBase { } void RunCopyInsertionPass() { - HloVerifier verifier; + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false); TF_ASSERT_OK(verifier.Run(module_.get()).status()); CopyInsertion copy_insertion; TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index aa89567ee8..a2be89511b 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -22,9 +22,10 @@ limitations under the License. 
#include <memory> #include <string> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/types.h" @@ -43,8 +43,7 @@ namespace { // Adds a computation to the given HLO module which adds a scalar constant to // its parameter and returns the result. HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { - auto builder = - HloComputation::Builder(tensorflow::strings::StrCat("add_", addend)); + auto builder = HloComputation::Builder(absl::StrCat("add_", addend)); auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "x_value")); auto half = builder.AddInstruction( @@ -84,7 +83,7 @@ HloComputation* CallForwardingComputation(HloComputation* computation, // the module. std::unique_ptr<HloModule> MakeBigGraph() { HloModuleConfig config; - auto module = MakeUnique<HloModule>("BigGraph", config); + auto module = absl::make_unique<HloModule>("BigGraph", config); auto builder = HloComputation::Builder("TestBigGraphvizGraph"); diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 4005fc0d11..38c3982ebf 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include <algorithm> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/util.h" @@ -45,7 +46,7 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForModule( // bound, by minimizing the liveness of sub-computations. TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module, + HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), *module, module_sequence, *points_to_analysis, size_function)); return result.heap_size; } @@ -60,9 +61,10 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation( memory_by_computation) { TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation, - sequence, points_to_analysis, size_function, - HeapSimulator::Options(), memory_by_computation)); + HeapSimulator::Run(absl::make_unique<NoFragmentationStatsHeap>(), + computation, sequence, points_to_analysis, + size_function, HeapSimulator::Options(), + memory_by_computation)); return result.heap_size; } @@ -142,7 +144,7 @@ Status HeapSimulator::RunComputation( } } else { // A GetTupleElement doesn't need to keep all of its operand's buffers - // alive. It only needs the buffers that relate to the element its + // alive. It only needs the buffers that relate to the element it's // extracting, and the tuple it's extracting from, but not the buffers // for the other elements. for (const BufferValue* buffer : points_to.element({})) { @@ -275,13 +277,13 @@ Status HeapSimulator::RunComputation( *memory_by_computation_); } - // If the whole module is sequential, we can save memory by running the - // heap-simulation for sub-computations inline. E.g. the buffers for the - // condition and body of a kWhile instruction are only live for the duration - // of the instruction itself. 
+ // If all computations in the module have been scheduled, we can save memory + // by running the heap-simulation for sub-computations inline. E.g. the + // buffers for the condition and body of a kWhile instruction are only live + // for the duration of the instruction itself. // // The order that the sub-computations are simulated does not affect - // correctness; since the whole module is sequential, we know that the + // correctness; since the whole module has been scheduled, we know that the // sub-computations will never be run concurrently. if (module_sequence_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || @@ -344,7 +346,7 @@ HeapSimulator::HeapSimulator( const SequentialHloOrdering::HloModuleSequence* module_sequence, const tensorflow::gtl::FlatMap<const HloComputation*, int64>* memory_by_computation) - : no_fragmentation_stats_(MakeUnique<NoFragmentationStatsHeap>()), + : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), @@ -378,9 +380,10 @@ void HeapSimulator::Alloc(const BufferValue* buffer, allocated_buffers_.insert(buffer); const int64 size = size_fn_(*buffer); - algorithm_->Alloc(buffer, size); - no_fragmentation_stats_->Alloc(buffer, size); - + const HloInstruction* instruction_to_calc_aliasing = + memory_by_computation_ == nullptr ? nullptr : instruction; + algorithm_->Alloc(buffer, size, instruction_to_calc_aliasing); + no_fragmentation_stats_->Alloc(buffer, size, instruction_to_calc_aliasing); FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction, nullptr); } @@ -518,6 +521,18 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { } } +void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size, + const HloInstruction* instruction) { + // The output buffer of while/call/conditional is always aliased with the + // output buffer of the root instruction in the body. Don't double count. 
+ if (instruction == nullptr || + (instruction->opcode() != HloOpcode::kWhile && + instruction->opcode() != HloOpcode::kCall && + instruction->opcode() != HloOpcode::kConditional)) { + Alloc(buffer, size); + } +} + void NoFragmentationStatsHeap::AccountForSubcomputationMemory( const HloInstruction* instruction, const tensorflow::gtl::FlatMap<const HloComputation*, int64>& diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 811a6042df..af05bedee7 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -36,6 +36,7 @@ namespace xla { // Forward declare classes defined below. class HeapAlgorithm; +class NoFragmentationStatsHeap; // HeapSimulator assigns buffer offsets by running a simulation of a regular // memory heap with Alloc and Free calls. It only works for completely @@ -161,7 +162,10 @@ class HeapSimulator { const HloInstruction* instruction, const BufferValue* shared_with_canonical); - const std::unique_ptr<HeapAlgorithm> no_fragmentation_stats_; + // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap, + // in which case we are calculating the same allocs/frees twice in the + // simulation. + const std::unique_ptr<NoFragmentationStatsHeap> no_fragmentation_stats_; const std::unique_ptr<HeapAlgorithm> algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; @@ -216,6 +220,21 @@ class HeapAlgorithm { // Alloc allocates a buffer of 'size' bytes. virtual void Alloc(const BufferValue* buffer, int64 size) = 0; + // NoFragmentationStatsHeap overrides this method. + virtual void Alloc(const BufferValue* buffer, int64 size, + const HloInstruction* instruction) { + Alloc(buffer, size); + } + + // Takes memory usage of subcomputations into account when calculating the + // memory usage of a computation. Currently, we don't handle buffer aliasing + // between computations entirely correctly. 
We are careful to not double count + // for the output buffers of whiles/conds/calls. But we don't take into + // account other aliases, such as for the while init. A more thorough solution + // would require something like BufferAssignment::BuildColocatedBufferSets. + // TODO(b/65835246): + // Since TuplePointsToAnalysis is being replaced with a module-aware alias + // analysis, it's not worth making major changes to HeapSimulator now. virtual void AccountForSubcomputationMemory( const HloInstruction* instruction, const tensorflow::gtl::FlatMap<const HloComputation*, int64>& @@ -240,6 +259,9 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { void Alloc(const BufferValue* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size, + const HloInstruction* instruction) override; + void AccountForSubcomputationMemory( const HloInstruction* instruction, const tensorflow::gtl::FlatMap<const HloComputation*, int64>& diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index b41dc66fe9..5f85f14565 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include <utility> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -137,7 +138,7 @@ class HeapSimulatorTracker { const string& name, std::unique_ptr<HloComputation> computation, const std::vector<const HloInstruction*>& instruction_sequence) { HloModuleConfig config; - module_ = MakeUnique<HloModule>(name, config); + module_ = absl::make_unique<HloModule>(name, config); module_->AddEntryComputation(std::move(computation)); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); @@ -146,8 +147,8 @@ class HeapSimulatorTracker { // the secondary sorting criteria of DecreasingSizeRunsHeap to sort calls by // buffer id, for determinism in the tests. auto zero_size = [](const BufferValue& buffer) { return 0; }; - auto algorithm = MakeUnique<DecreasingSizeRunsHeap>( - MakeUnique<HeapCallRecorder>(&actual_calls_)); + auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>( + absl::make_unique<HeapCallRecorder>(&actual_calls_)); result_ = HeapSimulator::Run( std::move(algorithm), *module_->entry_computation(), instruction_sequence, *points_to_analysis_, zero_size) @@ -156,7 +157,7 @@ class HeapSimulatorTracker { explicit HeapSimulatorTracker(const string& name) { HloModuleConfig config; - module_ = MakeUnique<HloModule>(name, config); + module_ = absl::make_unique<HloModule>(name, config); } // Similar to the single entry computation constructor above, but runs the @@ -182,8 +183,8 @@ class HeapSimulatorTracker { auto size_fn = [&reverse_position](const BufferValue& buffer) { return reverse_position[buffer.instruction()]; }; - auto algorithm = MakeUnique<DecreasingSizeRunsHeap>( - MakeUnique<HeapCallRecorder>(&actual_calls_)); + auto algorithm = absl::make_unique<DecreasingSizeRunsHeap>( + absl::make_unique<HeapCallRecorder>(&actual_calls_)); result_ = 
HeapSimulator::Run(std::move(algorithm), *module_, module_sequence, *points_to_analysis_, size_fn) .ConsumeValueOrDie(); @@ -675,7 +676,8 @@ class HeapAlgorithmTestBase : public ::testing::Test { const BufferValue::Id id = buffers_.size(); auto const0 = builder_.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); - buffers_.emplace_back(MakeUnique<HloValue>(id, const0, ShapeIndex{})); + buffers_.emplace_back( + absl::make_unique<HloValue>(id, const0, ShapeIndex{})); return buffers_.back().get(); } @@ -724,7 +726,8 @@ class DecreasingSizeRunsHeapTest : public HeapAlgorithmTestBase {}; TEST_F(DecreasingSizeRunsHeapTest, Empty) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique<HeapCallRecorder>(&call_sequence)); heap.Finish(); EXPECT_EQ(call_sequence, CallSequence({ {kFinish, nullptr}, @@ -733,7 +736,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Empty) { TEST_F(DecreasingSizeRunsHeapTest, Simple) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique<HeapCallRecorder>(&call_sequence)); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Alloc(buffer_c_, 30); @@ -760,7 +764,8 @@ TEST_F(DecreasingSizeRunsHeapTest, Simple) { TEST_F(DecreasingSizeRunsHeapTest, Mixed) { CallSequence call_sequence; - DecreasingSizeRunsHeap heap(MakeUnique<HeapCallRecorder>(&call_sequence)); + DecreasingSizeRunsHeap heap( + absl::make_unique<HeapCallRecorder>(&call_sequence)); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Free(buffer_b_, 20); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index fa218657fe..58b7af93eb 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; 
option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 51 +// Next ID: 53 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -46,6 +46,8 @@ message HloInstructionProto { reserved "control_predecessor_names"; reserved 6; reserved "called_computation_names"; + reserved 44; + reserved "replica_group_ids"; string name = 1; string opcode = 2; @@ -158,9 +160,6 @@ message HloInstructionProto { string backend_config = 43; // Cross replica op fields. - // TODO(b/112107579): remove replica_group_ids field and always use - // replica_groups. - repeated int64 replica_group_ids = 44; repeated ReplicaGroup replica_groups = 49; int64 all_reduce_id = 45; string cross_replica_sum_barrier = 46; @@ -171,6 +170,12 @@ message HloInstructionProto { bool is_host_transfer = 47; xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; + + // Precision configuration for the instruction. Has backend-specific meaning. + xla.PrecisionConfigProto precision_config = 51; + + // Collective permute field. + repeated SourceTarget source_target_pairs = 52; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index e8a4b034b4..0986da65cb 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -20,6 +20,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -28,15 +30,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; // Data structure used to construct the alias analysis. Thrown away after alias // analysis is complete. This data structure keeps track of which sets of @@ -414,7 +412,7 @@ Status HloAliasAnalysis::Verify() const { } string HloAliasAnalysis::ToString() const { - string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n"); + string out = absl::StrCat("HloAliasAnalysis, module ", module_->name(), "\n"); StrAppend(&out, " Buffers at each position:\n"); for (const HloComputation* computation : module_->computations()) { for (const HloInstruction* instruction : computation->instructions()) { @@ -457,7 +455,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run( VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); - auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); + auto alias_analysis = absl::WrapUnique(new HloAliasAnalysis(module)); TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_, HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, /*bitcast_defines_value=*/false, @@ -537,10 +535,10 @@ bool HloAliasAnalysis::HasLiveRangeInterference( if (ordering.MayInterfere(*values[i - 1], *values[i], dataflow_analysis())) { VLOG(1) << "In buffer " << buffer.id() << " containing values:\n " - << Join(values, ", ", - [](string* out, const HloValue* value) { - StrAppend(out, value->ToShortString()); - }) + << absl::StrJoin(values, ", ", + [](string* out, const HloValue* value) { + StrAppend(out, value->ToShortString()); + }) << "\nValue " << values[i - 
1]->ToShortString() << " may interfere with value " << values[i]->ToShortString(); diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index e16413f361..6c11a073b7 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -20,6 +20,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -27,15 +29,10 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrCat; - bool HloBuffer::operator==(const HloBuffer& other) const { bool equal = id() == other.id(); if (equal) { @@ -59,10 +56,11 @@ std::vector<HloPosition> HloBuffer::ComputePositions() const { } string HloBuffer::ToString() const { - return StrCat("HloBuffer ", id_, ", values: ", - Join(values_, ", ", [](string* result, const HloValue* value) { - result->append(value->ToShortString()); - })); + return absl::StrCat( + "HloBuffer ", id_, ", values: ", + absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) { + result->append(value->ToShortString()); + })); } std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer) { diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 441288da1a..c2d0673f49 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -23,9 +23,13 @@ 
limitations under the License. #include <set> #include <sstream> +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -36,13 +40,11 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::strings::StrCat; +using absl::StrCat; std::unique_ptr<HloComputation> HloComputation::Builder::Build( HloInstruction* root_instruction) { @@ -56,8 +58,8 @@ std::unique_ptr<HloComputation> HloComputation::Builder::Build( HloInstruction* root = root_instruction ? 
root_instruction : last_added_instruction_; CHECK_NE(nullptr, root); - return WrapUnique(new HloComputation(name_, parameter_count, &instructions_, - root, fusion_instruction_)); + return absl::WrapUnique(new HloComputation( + name_, parameter_count, &instructions_, root, fusion_instruction_)); } HloComputation::HloComputation( @@ -135,7 +137,7 @@ string RenameFusionParameter(const string& original_name, int64 new_param_no) { } string after_param = original_name.substr(index + param_underscore.size()); int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) { + if (absl::SimpleAtoi(after_param, &numeric_suffix)) { return StrCat(original_name.substr(0, index + param_underscore.size()), new_param_no); } @@ -317,11 +319,12 @@ void ComputeComputationPostOrder( } } -enum State { kVisiting, kVisited }; +} // namespace -void ComputeInstructionPostOrder( +void HloComputation::ComputeInstructionPostOrder( + const HloComputation::ChannelDependencyMap& channel_dependency_map, std::vector<HloInstruction*>* post_order, HloInstruction* root, - tensorflow::gtl::FlatMap<HloInstruction*, State>* visited) { + tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const { std::vector<HloInstruction*> dfs_stack; dfs_stack.push_back(root); while (!dfs_stack.empty()) { @@ -354,16 +357,71 @@ void ComputeInstructionPostOrder( for (HloInstruction* op : current->control_predecessors()) { dfs_stack.emplace_back(op); } + + // Add inputs for send->recv_done dependencies and cross-replica-sum + // dependencies. 
+ switch (current->opcode()) { + case HloOpcode::kRecvDone: { + auto it = channel_dependency_map.find(current->channel_id()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } + } + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = current->all_reduce_id(); + if (all_reduce_id) { + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + for (HloInstruction* op : it->second) { + dfs_stack.emplace_back(op); + } + } + } + break; + } + default: + break; + } } } -} // namespace +HloComputation::ChannelDependencyMap +HloComputation::ComputeChannelDependencies() const { + ChannelDependencyMap channel_dependency_map; + for (const auto& instruction : instructions_) { + switch (instruction->opcode()) { + case HloOpcode::kSend: { + channel_dependency_map[instruction->channel_id()].push_back( + instruction.get()); + break; + } + case HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = instruction->all_reduce_id(); + if (all_reduce_id) { + auto& dependencies = channel_dependency_map[all_reduce_id.value()]; + absl::c_copy(instruction->operands(), + std::back_inserter(dependencies)); + absl::c_copy(instruction->control_predecessors(), + std::back_inserter(dependencies)); + } + break; + } + default: + break; + } + } + return channel_dependency_map; +} std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { + auto channel_dependency_map = ComputeChannelDependencies(); std::vector<HloInstruction*> post_order; post_order.reserve(instruction_count()); std::vector<HloInstruction*> trace_instructions; - tensorflow::gtl::FlatMap<HloInstruction*, State> visited; + tensorflow::gtl::FlatMap<HloInstruction*, VisitState> visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. 
Add trace @@ -371,7 +429,8 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - ComputeInstructionPostOrder(&post_order, instruction.get(), &visited); + ComputeInstructionPostOrder(channel_dependency_map, &post_order, + instruction.get(), &visited); } } post_order.insert(post_order.end(), trace_instructions.begin(), @@ -493,9 +552,9 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); - return WrapUnique(new HloComputation(proto.name(), parameter_count, - &instructions, root, - /*fusion_instruction=*/nullptr)); + return absl::WrapUnique(new HloComputation(proto.name(), parameter_count, + &instructions, root, + /*fusion_instruction=*/nullptr)); } void HloComputation::FuseInstructionsInto( @@ -566,16 +625,15 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction( if (instruction->parent() != this) { return FailedPrecondition( "Can't deep copy instruction %s: instruction is not in computation %s", - instruction->name().c_str(), name().c_str()); + instruction->name(), name()); } if (indices_to_copy != nullptr && !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) { return FailedPrecondition( "Can't deep copy instruction %s: given shape tree of indices to copy " "has incompatible shapes: %s vs. 
%s", - instruction->name().c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - ShapeUtil::HumanString(indices_to_copy->shape()).c_str()); + instruction->name(), ShapeUtil::HumanString(instruction->shape()), + ShapeUtil::HumanString(indices_to_copy->shape())); } ShapeIndex index; @@ -605,7 +663,7 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier( if (instruction->parent() != this) { return FailedPrecondition( "Can't deep copy instruction %s: instruction is not in computation %s", - instruction->name().c_str(), name().c_str()); + instruction->name(), name()); } ShapeIndex index; return DeepCopyHelper(instruction, &index, copy_leaf); @@ -624,6 +682,9 @@ ProgramShape HloComputation::ComputeProgramShape() const { } bool HloComputation::operator==(const HloComputation& other) const { + if (this == &other) { + return true; + } std::set<std::pair<const HloInstruction*, const HloInstruction*>> visited; std::function<bool(const HloInstruction*, const HloInstruction*)> eq = [&visited, &eq](const HloInstruction* a, const HloInstruction* b) { @@ -674,13 +735,37 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability() const { const auto& all = MakeInstructionPostOrder(); - auto result = MakeUnique<HloReachabilityMap>(all); + auto result = absl::make_unique<HloReachabilityMap>(all); + auto channel_dependency_map = ComputeChannelDependencies(); std::vector<HloInstruction*> inputs; for (const HloInstruction* hlo : all) { inputs.assign(hlo->operands().begin(), hlo->operands().end()); inputs.insert(inputs.end(), hlo->control_predecessors().begin(), hlo->control_predecessors().end()); + + switch (hlo->opcode()) { + case HloOpcode::kRecvDone: { + auto it = channel_dependency_map.find(hlo->channel_id()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + break; + } + case 
HloOpcode::kCrossReplicaSum: { + auto all_reduce_id = hlo->all_reduce_id(); + if (all_reduce_id) { + auto it = channel_dependency_map.find(all_reduce_id.value()); + if (it != channel_dependency_map.end()) { + absl::c_copy(it->second, std::back_inserter(inputs)); + } + } + break; + } + default: + break; + } + result->FastSetReachabilityToUnion(inputs, hlo); } return result; @@ -723,11 +808,10 @@ std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const { } } VLOG(3) << "Unreachable roots:" - << tensorflow::str_util::Join( - unreachable_roots, "\n\t", - [](string* out, const HloInstruction* hlo) { - tensorflow::strings::StrAppend(out, hlo->ToString()); - }); + << absl::StrJoin(unreachable_roots, "\n\t", + [](string* out, const HloInstruction* hlo) { + absl::StrAppend(out, hlo->ToString()); + }); return unreachable_roots; } @@ -829,7 +913,7 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements( HloCloneContext* context, const string& suffix) { std::unique_ptr<HloCloneContext> context_ptr; if (context == nullptr) { - context_ptr = MakeUnique<HloCloneContext>(parent(), suffix); + context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix); context = context_ptr.get(); } @@ -898,12 +982,11 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) { name_ = name_uniquer->GetUniqueName(name_); } -HloInstruction* HloComputation::GetInstructionWithName( - tensorflow::StringPiece name) { +HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) { auto instructions_in_computation = instructions(); - auto it = c_find_if(instructions_in_computation, [&](HloInstruction* instr) { - return instr->name() == name; - }); + auto it = absl::c_find_if( + instructions_in_computation, + [&](HloInstruction* instr) { return instr->name() == name; }); return it == instructions_in_computation.end() ? 
nullptr : *it; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 49ed65910f..59016624f7 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -367,7 +367,7 @@ class HloComputation { // Returns the instruction in this computation that has name `name`. Returns // null if there is no such computation. - HloInstruction* GetInstructionWithName(tensorflow::StringPiece name); + HloInstruction* GetInstructionWithName(absl::string_view name); int64 unique_id() const { return unique_id_; } @@ -399,6 +399,20 @@ class HloComputation { // Internal helper to collect unreachable roots. std::vector<HloInstruction*> CollectUnreachableRoots() const; + // Returns a map from channel-id to directed dependencies of the channel + // instructions. For send&recv pairs it means the send instruction and for + // cross-replica-sum the union of the dependencies for all participating + // instructions. 
+ using ChannelDependencyMap = + tensorflow::gtl::FlatMap<int64, absl::InlinedVector<HloInstruction*, 1>>; + ChannelDependencyMap ComputeChannelDependencies() const; + + enum VisitState { kVisiting, kVisited }; + void ComputeInstructionPostOrder( + const HloComputation::ChannelDependencyMap& channel_dependency_map, + std::vector<HloInstruction*>* post_order, HloInstruction* root, + tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const; + string name_; int64 unique_id_; HloInstruction* root_instruction_; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index e4c5470331..f7ed1b0316 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -691,6 +691,27 @@ TEST_F(HloComputationTest, StringificationCanonical) { EXPECT_EQ(computation->ToString(options), expected_computation2); } -} // namespace +TEST_F(HloComputationTest, ChannelReachability) { + const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}); + HloComputation::Builder builder("ChannelReachability"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto token0 = builder.AddInstruction(HloInstruction::CreateToken()); + auto send = + builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + auto token1 = builder.AddInstruction(HloInstruction::CreateToken()); + auto recv = + builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build(recv_done)); + auto reachability = computation->ComputeReachability(); + EXPECT_TRUE(reachability->IsReachable(param, recv_done)); + EXPECT_FALSE(reachability->IsReachable(send, recv)); + 
EXPECT_FALSE(reachability->IsReachable(send_done, recv)); +} + +} // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 7229031c0c..2ed645c3ae 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -20,6 +20,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -38,7 +39,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) { // Limit the constant folding to 0 iterations to skip folding loops. This // retains the behavior from before while loop support in HloEvaluator and may // be revised. - auto evaluator = MakeUnique<HloEvaluator>(/*max_loop_iterations=*/0); + auto evaluator = absl::make_unique<HloEvaluator>(/*max_loop_iterations=*/0); XLA_VLOG_LINES(2, "HloConstantFolding::Run(), before:\n" + module->ToString()); @@ -51,9 +52,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) { computation->root_instruction() != instruction) { continue; } - // Skip Constant, Parameter, Reduce, and AfterAll operation. - // TODO(b/35975797): Enable Reduce operation once arbitrary computation - // are supported by the evaluator. + // Skip Constant, Parameter, and AfterAll operation. // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. 
// TODO(b/110532604): Enable AfterAll once AfterAll requires at least one // operand in which case constant folding will be impossible and this @@ -61,7 +60,6 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) { if (instruction->opcode() == HloOpcode::kParameter || instruction->opcode() == HloOpcode::kConstant || instruction->opcode() == HloOpcode::kTuple || - instruction->opcode() == HloOpcode::kReduce || instruction->opcode() == HloOpcode::kAfterAll) { continue; } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h index 331480bd02..4557983a9c 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.h +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h @@ -25,7 +25,7 @@ namespace xla { // computation on constants. class HloConstantFolding : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "constant_folding"; } + absl::string_view name() const override { return "constant_folding"; } // Run constant folding operations on the given module. Returns whether the // module was changed (constant expressions folded). diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 64a42c1efc..7cd1481a8a 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -202,5 +203,45 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { EXPECT_TRUE(matched); } +const char* const kConstantFoldReduce = R"( + HloModule ConstantFoldReduce + + add { + a = s32[] parameter(0) + b = s32[] parameter(1) + ROOT add = s32[] add(a, b) + } + + ENTRY r { + x = s32[3] constant({1, 2, 3}) + init = s32[] constant(0) + ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add + })"; + +TEST_F(HloConstantFoldingTest, ConstantFoldReduce) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(kConstantFoldReduce)); + HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + EXPECT_TRUE(result); + + EXPECT_EQ(6, module->entry_computation() + ->root_instruction() + ->literal() + .GetFirstElement<int32>()); +} + +TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(kConstantFoldReduce)); + HloInstruction* add = module->computations().begin()->root_instruction(); + LayoutUtil::ClearLayout(add->mutable_shape()); + HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + EXPECT_FALSE(result); + + EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 1bbb0ff08e..0e12a1ee03 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ 
b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -258,10 +258,6 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) { - return Status::OK(); -} - Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, @@ -544,15 +540,10 @@ Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { } Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { - // TODO(b/110096724): Compute correct cost here. - double flops = 0.0; - ShapeUtil::ForEachSubshape(hlo->shape(), - [&](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsArray(subshape)) { - flops += ShapeUtil::ElementsIn(subshape); - } - }); - current_properties_[kFlopsKey] = flops; + return Status::OK(); +} + +Status HloCostAnalysis::HandleCollectivePermute(const HloInstruction* /*hlo*/) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 193a04bea0..c6a2007904 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -72,9 +72,9 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; + Status HandleCollectivePermute(const HloInstruction* hlo) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; - Status HandleHostCompute(const HloInstruction* host_compute) override; Status HandleRng(const HloInstruction* random) override; Status HandleReverse(const HloInstruction* reverse) override; Status HandleSort(const HloInstruction* sort) override; diff --git 
a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 858992a326..131846794d 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -14,15 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" namespace xla { +using absl::StrCat; using tensorflow::gtl::ArraySlice; -using tensorflow::strings::StrCat; StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) { @@ -149,13 +151,13 @@ StatusOr<HloInstruction*> MakeConcatHlo(ArraySlice<HloInstruction*> operands, CHECK_GT(operands.size(), 0); HloComputation* computation = operands[0]->parent(); - CHECK(c_all_of(operands, [&](HloInstruction* instr) { + CHECK(absl::c_all_of(operands, [&](HloInstruction* instr) { return instr->parent() == computation; })); std::vector<const Shape*> operand_shapes; - c_transform(operands, std::back_inserter(operand_shapes), - [](HloInstruction* instr) { return &instr->shape(); }); + absl::c_transform(operands, std::back_inserter(operand_shapes), + [](HloInstruction* instr) { return &instr->shape(); }); TF_ASSIGN_OR_RETURN(Shape concat_shape, ShapeInference::InferConcatOpShape( operand_shapes, dimension)); @@ -228,7 +230,7 @@ StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand, const Shape& operand_shape = operand->shape(); new_shape_dims.reserve(n + operand_shape.dimensions_size()); 
new_shape_dims.insert(new_shape_dims.begin(), n, 1); - c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); + absl::c_copy(operand_shape.dimensions(), std::back_inserter(new_shape_dims)); return MakeReshapeHlo(new_shape_dims, operand); } @@ -240,7 +242,7 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims( std::vector<int64> expanded_shape_dim_bounds; expanded_shape_dim_bounds.reserve(expanded_dims.size() + operand->shape().dimensions_size() - 1); - c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); + absl::c_copy(expanded_dims, std::back_inserter(expanded_shape_dim_bounds)); std::copy(operand->shape().dimensions().begin() + 1, operand->shape().dimensions().end(), std::back_inserter(expanded_shape_dim_bounds)); @@ -251,7 +253,7 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims( StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand, ArraySlice<int64> dims_to_elide) { - CHECK(c_is_sorted(dims_to_elide)); + CHECK(absl::c_is_sorted(dims_to_elide)); const Shape& input_shape = operand->shape(); // First accumulate in reverse @@ -268,7 +270,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand, } } - c_reverse(new_shape_dim_bounds); + absl::c_reverse(new_shape_dim_bounds); Shape output_shape = ShapeUtil::MakeShape(input_shape.element_type(), new_shape_dim_bounds); return MakeReshapeHlo(output_shape, operand); @@ -276,7 +278,7 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand, StatusOr<HloInstruction*> InsertDegenerateDims( HloInstruction* operand, ArraySlice<int64> dims_to_insert) { - CHECK(c_is_sorted(dims_to_insert)); + CHECK(absl::c_is_sorted(dims_to_insert)); const Shape& operand_shape = operand->shape(); int64 output_shape_rank = @@ -318,7 +320,7 @@ StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand, *padding_config.add_dimensions() = padding_config_dim; HloInstruction* zero = computation->AddInstruction( - 
HloInstruction::CreateConstant(MakeUnique<Literal>( + HloInstruction::CreateConstant(absl::make_unique<Literal>( LiteralUtil::Zero(operand->shape().element_type())))); return MakePadHlo(operand, zero, padding_config); } @@ -328,15 +330,15 @@ StatusOr<HloInstruction*> BroadcastZeros( ArraySlice<int64> broadcast_dimensions) { HloInstruction* zero = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique<Literal>(LiteralUtil::Zero(element_type)))); + absl::make_unique<Literal>(LiteralUtil::Zero(element_type)))); return MakeBroadcastHlo(zero, /*broadcast_dimensions=*/{}, /*result_shape_bounds=*/broadcast_dimensions); } StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature( ArraySlice<const Shape*> domain, const Shape& range, - tensorflow::StringPiece name) { - HloComputation::Builder b{std::string(name)}; + absl::string_view name) { + HloComputation::Builder b{string(name)}; int64 param_idx = 0; for (const Shape* param_shape : domain) { b.AddInstruction(HloInstruction::CreateParameter( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 5ff8946fb0..1bc6d09b45 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -177,7 +177,7 @@ StatusOr<HloInstruction*> BroadcastZeros( // a value of type `range`. StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature( tensorflow::gtl::ArraySlice<const Shape*> domain, const Shape& range, - tensorflow::StringPiece name); + absl::string_view name); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc index 60d3e71757..a8de285d16 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils_test.cc @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -28,7 +28,7 @@ using tensorflow::gtl::ArraySlice; class HloCreationUtilsTest : public HloTestBase { protected: - static std::unique_ptr<HloModule> CreateModuleWithProgramShape( + std::unique_ptr<HloModule> CreateModuleWithProgramShape( PrimitiveType primitive_type, ArraySlice<int64> input_shape_dims, ArraySlice<int64> output_shape_dims, HloInstruction** param, HloComputation** entry_computation) { diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 06484f4012..cb367adf5e 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/hash/hash.h" namespace xla { @@ -103,6 +104,9 @@ int64 CseHash(const HloInstruction* instruction) { for (auto operand : instruction->operands()) { hash = tensorflow::Hash64Combine(hash, operand->unique_id()); } + if (instruction->opcode() == HloOpcode::kConstant) { + hash = tensorflow::Hash64Combine(hash, instruction->literal().Hash()); + } return hash; } diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h index 5e2b348bdd..a28c03599a 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.h +++ b/tensorflow/compiler/xla/service/hlo_cse.h @@ -34,7 +34,7 @@ class HloCSE : public HloPassInterface { : is_layout_sensitive_(is_layout_sensitive), only_fusion_computations_(only_fusion_computations) {} ~HloCSE() override = default; - tensorflow::StringPiece name() const override { return "cse"; } + absl::string_view name() const override { return "cse"; } // Run CSE on the given module. Returns whether the module was changed (common // subexpressions were found and eliminated). diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 90fbaa37c5..406d712ec6 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -20,9 +20,9 @@ limitations under the License. 
#include <utility> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index bbfb0c253f..3376d170e6 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -19,8 +19,10 @@ limitations under the License. #include <queue> #include <vector> +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -29,8 +31,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -78,8 +78,8 @@ bool MultiDynamicSliceUseShareSameIndices( } // namespace -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; HloDataflowAnalysis::HloDataflowAnalysis( const HloModule& module, bool ssa_form, bool bitcast_defines_value, @@ -93,7 +93,7 @@ HloDataflowAnalysis::HloDataflowAnalysis( bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( const HloInstruction* inst) { tensorflow::gtl::FlatSet<const HloInstruction*> visited; - tensorflow::gtl::InlinedVector<const HloInstruction*, 4> stack; + absl::InlinedVector<const HloInstruction*, 4> stack; stack.push_back(inst); while (!stack.empty()) { const HloInstruction* current = stack.back(); @@ -837,7 +837,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { return Unimplemented( "Computation %s is called in both a parallel (eg, kMap) and " "sequential (eg, kCall) context", - computation->name().c_str()); + computation->name()); } if (call_graph_node.caller_callsites().empty() || call_graph_node.context() == CallContext::kParallel) { @@ -886,7 +886,7 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run( VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto dataflow_analysis = WrapUnique(new HloDataflowAnalysis( + auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis( module, ssa_form, bitcast_defines_value, fusion_can_share_buffer)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); @@ -976,28 +976,22 @@ Status HloDataflowAnalysis::Verify() const { bool HloDataflowAnalysis::DoesNotUseOperandBuffer( const 
HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user) const { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - if (user->opcode() == HloOpcode::kFusion && - user->fusion_kind() == HloInstruction::FusionKind::kLoop) { - // Find fusion parameter associated with 'operand'. - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - // Iterate through all users of all uses of the fusion parameter value. - // Return false if any uses are detected, returns true otherwise. - const HloValue& value = GetValueDefinedAt(fusion_param, index); - return value.uses().empty(); - } else { - // Return false if no value at 'operand' and 'index' is used at 'user'. - for (const HloValue* value : GetValueSet(operand, index).values()) { - for (const HloUse& use : value->uses()) { - if (use.instruction == user) { - return false; + // Return false if no value at 'operand' and 'index' is used at 'user'. + for (const HloValue* value : GetValueSet(operand, index).values()) { + for (const HloUse& use : value->uses()) { + if (use.instruction == user) { + if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + HloInstruction* fusion_param = + user->fused_parameter(use.operand_number); + const HloValue& value = + GetValueDefinedAt(fusion_param, use.operand_index); + return value.uses().empty(); } + return false; } } } - return true; } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index f4abc7a7c7..a1678d4943 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -138,7 +138,8 @@ class HloDataflowAnalysis { // Returns true if 'user' cannot possibly use the buffer at 'index' in // 'operand'. Returns false otherwise. // - // REQUIRES: 'operand' is an operand of 'user'. 
+ // 'operand' does not have to be an operand of 'user'. This can be the case + // with indirect uses. bool DoesNotUseOperandBuffer(const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user) const; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 4755c4a0cf..d1a96c10f8 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1963,6 +1963,54 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); } +// Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the +// parameter tuple. +TEST_F(DoesNotUseOperandBufferTest, IndirectUses) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto t0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0)); + auto t1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1)); + // Swap the tuple elements. + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({t1, t0})); + + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. 
+ auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape, gte1, update, starts)); + builder.AddInstruction( + HloInstruction::CreateTuple({gte0, dynamic_update_slice})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {dynamic_update_slice, starts, update, gte1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction never uses tuple element 0, but does use element 1. + EXPECT_TRUE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {0}, fusion)); + EXPECT_FALSE(dataflow_analysis_->DoesNotUseOperandBuffer(tuple, {1}, fusion)); + // The same holds for the parameter tuple, except that the tuple elements are + // swapped in 'tuple'. + EXPECT_TRUE( + dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {1}, fusion)); + EXPECT_FALSE( + dataflow_analysis_->DoesNotUseOperandBuffer(tuple_param, {0}, fusion)); +} + class CanShareOperandBufferWithUserTest : public HloDataflowAnalysisTestBase {}; TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h index 4e244494d6..1fe69b1395 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.h +++ b/tensorflow/compiler/xla/service/hlo_dce.h @@ -36,7 +36,7 @@ namespace xla { class HloDCE : public HloPassInterface { public: ~HloDCE() override {} - tensorflow::StringPiece name() const override { return "dce"; } + absl::string_view name() const override { return "dce"; } // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). 
diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 26e3736e01..3b5cde2996 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include <memory> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc index 78955db0da..72185698c9 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -31,31 +31,10 @@ class HloDomainIsolator::RunContext { StatusOr<bool> Run(); private: - // Inserts a kDomain instruction between parent and operand, in case - // the attribute (ie, sharding) values change between instruction and operand. - // Returns the newly inserted kDomain instruction, or nullptr if no kDomain - // instruction was necessary. 
- StatusOr<HloInstruction*> CreateDomain(HloInstruction* instruction, - HloInstruction* parent, - HloInstruction* operand); - HloModule* module_; HloDomainIsolator* isolator_; }; -StatusOr<HloInstruction*> HloDomainIsolator::RunContext::CreateDomain( - HloInstruction* instruction, HloInstruction* parent, - HloInstruction* operand) { - HloInstruction* domain = nullptr; - std::unique_ptr<HloInstruction> domain_instruction = - isolator_->creator_(instruction, operand); - if (domain_instruction != nullptr) { - domain = operand->parent()->AddInstruction(std::move(domain_instruction)); - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain)); - } - return domain; -} - StatusOr<bool> HloDomainIsolator::RunContext::Run() { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator"); @@ -71,16 +50,16 @@ StatusOr<bool> HloDomainIsolator::RunContext::Run() { // When applying multiple domains, we could end up stacking more than // one in one edge, so here we want to build the effective // (kDomain-less) instruction->operand edge. - HloInstruction* parent = instruction; - while (operand->opcode() == HloOpcode::kDomain) { - parent = operand; - operand = operand->mutable_operand(0); + HloInstruction* root = operand; + while (root->opcode() == HloOpcode::kDomain) { + root = root->mutable_operand(0); } // Check whether a kDomain is necessary between instruction and operand. 
- TF_ASSIGN_OR_RETURN(HloInstruction * domain, - CreateDomain(instruction, parent, operand)); + HloInstruction* domain = + isolator_->creator_(instruction, root, operand); if (domain != nullptr) { VLOG(4) << "New domain: " << domain->ToString(); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain)); ++added_domains; } } diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index eded3e78ee..d36631fc2f 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -34,14 +34,16 @@ class HloDomainIsolator : public HloPassInterface { public: // Creates a new kDomain instruction for the edge between the use instruction // (the first HloInstruction argument), and the operand instruction (the - // second HloInstruction argument). + // third HloInstruction argument) if the interesting attribute of the + // instruction differes from the attribute of the root (the second + // HloInstruction argument). // Returns nullptr in case no domain separation is necessary. - using DomainCreator = std::function<std::unique_ptr<HloInstruction>( - HloInstruction*, HloInstruction*)>; + using DomainCreator = std::function<HloInstruction*( + HloInstruction*, HloInstruction*, HloInstruction*)>; explicit HloDomainIsolator(DomainCreator creator); - tensorflow::StringPiece name() const override { return "domain_isolator"; } + absl::string_view name() const override { return "domain_isolator"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 9e096320db..8b2846e0c2 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include <algorithm> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" @@ -25,14 +26,14 @@ namespace xla { /* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create( HloComputation* computation, string domain_kind) { - auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); TF_RETURN_IF_ERROR(domain_map->Populate(computation)); return std::move(domain_map); } /* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create( HloModule* module, string domain_kind) { - auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind))); for (HloComputation* computation : module->computations()) { TF_RETURN_IF_ERROR(domain_map->Populate(computation)); } @@ -56,14 +57,14 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { // both sides. 
for (HloInstruction* operand : instruction->unique_operands()) { if (IsDomainInstruction(operand)) { - auto domain = MakeUnique<DomainMetadata::Domain>(); + auto domain = absl::make_unique<DomainMetadata::Domain>(); domain->enter_domains.insert(operand); domain->exit_domains.insert(instruction); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } } if (instruction == instruction->parent()->root_instruction()) { - auto domain = MakeUnique<DomainMetadata::Domain>(); + auto domain = absl::make_unique<DomainMetadata::Domain>(); domain->enter_domains.insert(instruction); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } @@ -71,6 +72,11 @@ Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { } Status HloDomainMap::Populate(HloComputation* computation) { + InstructionOrderMap instructions_post_order; + int64 count = 0; + for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { + instructions_post_order.insert(std::make_pair(instruction, count++)); + } for (HloInstruction* instruction : computation->instructions()) { if (IsDomainInstruction(instruction)) { // If this is a kDomain of the kind we are currently processing, check @@ -84,7 +90,7 @@ Status HloDomainMap::Populate(HloComputation* computation) { continue; } TF_ASSIGN_OR_RETURN(std::unique_ptr<DomainMetadata::Domain> domain, - CreateDomain(instruction)); + CreateDomain(instruction, instructions_post_order)); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } return Status::OK(); @@ -142,10 +148,12 @@ Status HloDomainMap::ExpandDomain(HloInstruction* instruction, } StatusOr<std::unique_ptr<DomainMetadata::Domain>> HloDomainMap::CreateDomain( - HloInstruction* instruction) const { - auto domain = MakeUnique<DomainMetadata::Domain>(); + HloInstruction* instruction, + const InstructionOrderMap& instructions_order) const { + auto domain = absl::make_unique<DomainMetadata::Domain>(); TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); - domain->instructions = 
MakeNonDomainInstructions(domain->reach_set); + domain->instructions = + MakeNonDomainInstructions(domain->reach_set, instructions_order); return std::move(domain); } @@ -167,7 +175,8 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { /* static */ std::vector<HloInstruction*> HloDomainMap::MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set) { + const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set, + const InstructionOrderMap& instructions_order) { std::vector<HloInstruction*> instructions; instructions.reserve(instruction_set.size()); for (HloInstruction* instruction : instruction_set) { @@ -175,9 +184,10 @@ HloDomainMap::MakeNonDomainInstructions( instructions.push_back(instruction); } } + // sort instructions according to instructions_order std::sort(instructions.begin(), instructions.end(), - [](HloInstruction* a, HloInstruction* b) { - return a->unique_id() < b->unique_id(); + [&instructions_order](HloInstruction* a, HloInstruction* b) { + return instructions_order.at(a) < instructions_order.at(b); }); return instructions; } diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index 1ca7159725..633109249a 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -70,6 +70,11 @@ class HloDomainMap { int64 GetDomainId(HloInstruction* instruction) const; private: + // Map used for representing instruction ordering, i.e. + // order_map[a] < order_map[b] means a must be ordered before b. + using InstructionOrderMap = + tensorflow::gtl::FlatMap<const HloInstruction*, int64>; + HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} // Check if the kDomain instruction is facing (via its operand link) another @@ -95,12 +100,14 @@ class HloDomainMap { // Creates a domain data structure using the ExpandDomain() API. 
StatusOr<std::unique_ptr<DomainMetadata::Domain>> CreateDomain( - HloInstruction* instruction) const; + HloInstruction* instruction, + const InstructionOrderMap& instructions_order) const; // Out of an instruction set, returns a vector of all the ones which are not // a kDomain kind. static std::vector<HloInstruction*> MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set); + const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set, + const InstructionOrderMap& instructions_order); string domain_kind_; std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_; diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index f855f2a1fc..6c142ee474 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -20,10 +20,10 @@ limitations under the License. #include <string> #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -44,7 +44,10 @@ class DomainMetadata { // two domains of different kind intersect each other. tensorflow::gtl::FlatSet<HloInstruction*> reach_set; - // The same instructions in reach_set, but purged from kDomain instructions. + // The same instructions in reach_set, but purged from kDomain instructions + // and ordered according to their computation graph post-order, i.e. + // if instructions[pos_a] depends on instructions[pos_b], then pos_a > + // pos_b. std::vector<HloInstruction*> instructions; // If we consider a graph edge as an arrow oriented from the operand to the @@ -63,7 +66,7 @@ class DomainMetadata { // Returns the metadata type. A unique identifier which describes the real // metadata type. 
- virtual tensorflow::StringPiece Kind() const = 0; + virtual absl::string_view Kind() const = 0; // Compares the metadata object with another one and returns true if the // two matches. diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h index c859e05f02..97bc8ef604 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.h +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -35,13 +35,13 @@ class HloDomainRemover : public HloPassInterface { // instructions in it with the same attributes (ie, sharding), a normalizer // function is tasked at applying attribute normalization on the instructions // within such domain. - HloDomainRemover(tensorflow::StringPiece kind, + HloDomainRemover(absl::string_view kind, std::function<Status(const DomainMetadata::Domain&, const DomainMetadata* metadata)> normalizer) - : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {} + : kind_(kind), normalizer_(std::move(normalizer)) {} - tensorflow::StringPiece name() const override { return "domain_remover"; } + absl::string_view name() const override { return "domain_remover"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 70271be304..c8e0a9e289 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" @@ -28,6 +29,11 @@ namespace xla { namespace { class HloDomainTest : public HloVerifiedTestBase { + public: + HloDomainTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: bool FindUserViaDomainPath(HloInstruction* instruction, HloInstruction* operand) const { @@ -45,9 +51,8 @@ class HloDomainTest : public HloVerifiedTestBase { // Checks whether there is a kDomain instruction in the edge between the // instruction and the operand. - bool HasDomainEdge(HloModule* module, - tensorflow::StringPiece instruction_name, - tensorflow::StringPiece operand_name) { + bool HasDomainEdge(HloModule* module, absl::string_view instruction_name, + absl::string_view operand_name) { HloInstruction* instruction = FindInstruction(module, instruction_name); HloInstruction* operand = FindInstruction(module, operand_name); CHECK_NE(instruction, nullptr); @@ -65,7 +70,7 @@ class HloDomainTest : public HloVerifiedTestBase { return false; } - StatusOr<HloModule*> ParseModule(tensorflow::StringPiece hlo_string) { + StatusOr<HloModule*> ParseModule(absl::string_view hlo_string) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); ParseAndVerifyModule(hlo_string, config); @@ -80,10 +85,10 @@ class OpNameMetadata : public DomainMetadata { explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {} std::unique_ptr<DomainMetadata> Clone() const override { - return MakeUnique<OpNameMetadata>(opname_); + return absl::make_unique<OpNameMetadata>(opname_); } - tensorflow::StringPiece Kind() const override { return KindName(); } + absl::string_view Kind() const override { return 
KindName(); } bool Matches(const DomainMetadata& other) const override { const OpNameMetadata* other_ptr = @@ -97,25 +102,26 @@ class OpNameMetadata : public DomainMetadata { string ToString() const override { return opname_; } - static tensorflow::StringPiece KindName() { return "opname"; } + static absl::string_view KindName() { return "opname"; } private: string opname_; }; // Creator function for OpNameMetadata domains. -std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction, - HloInstruction* operand) { - if (instruction->metadata().op_name() == operand->metadata().op_name()) { +HloInstruction* OpNameDomainCreator(HloInstruction* instruction, + HloInstruction* root, + HloInstruction* operand) { + if (instruction->metadata().op_name() == root->metadata().op_name()) { return nullptr; } std::unique_ptr<DomainMetadata> operand_side_metadata = - MakeUnique<OpNameMetadata>(operand->metadata().op_name()); + absl::make_unique<OpNameMetadata>(root->metadata().op_name()); std::unique_ptr<DomainMetadata> user_side_metadata = - MakeUnique<OpNameMetadata>(instruction->metadata().op_name()); - return HloInstruction::CreateDomain(operand->shape(), operand, - std::move(operand_side_metadata), - std::move(user_side_metadata)); + absl::make_unique<OpNameMetadata>(instruction->metadata().op_name()); + return operand->parent()->AddInstruction(HloInstruction::CreateDomain( + operand->shape(), operand, std::move(operand_side_metadata), + std::move(user_side_metadata))); } Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain, @@ -142,7 +148,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -184,7 +190,7 @@ ENTRY entry { 
TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(!isolator_changed); } @@ -211,7 +217,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -248,7 +254,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_FALSE(isolator_changed); } @@ -302,7 +308,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator sharding_isolator(CreateShardingDomain); + HloDomainIsolator sharding_isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, sharding_isolator.Run(module)); EXPECT_TRUE(sharding_isolator_changed); @@ -344,7 +350,8 @@ ENTRY entry { token = token[] after-all() infeed = ((f32[4], f32[4]), token[]) infeed(token), sharding={{maximal device=1}, {maximal device=0}, {maximal device=0}} - infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0 + infeed.data = (f32[4], f32[4]) get-tuple-element(infeed), index=0, + sharding={{maximal device=1}, {maximal device=0}} gte0 = f32[4] get-tuple-element(infeed.data), index=0 gte1 = f32[4] get-tuple-element(infeed.data), index=1 copy0 = f32[4] copy(gte0) 
@@ -356,7 +363,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); LOG(INFO) << "Original module:\n" << module->ToString(); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -378,11 +385,8 @@ ENTRY entry { // \ / // TUPLE // | - HloInstruction* infeed = FindInstruction(module, "infeed"); - ASSERT_NE(infeed, nullptr); - HloInstruction* infeed_data = - infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + HloInstruction* infeed_data = FindInstruction(module, "infeed.data"); + ASSERT_NE(infeed_data, nullptr); auto infeed_data_users = infeed_data->users(); HloInstruction* new_gte0 = infeed_data->parent()->AddInstruction( @@ -445,7 +449,7 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -474,8 +478,8 @@ ENTRY entry { TEST_F(HloDomainTest, DumpParseNullSharding) { auto builder = HloComputation::Builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {}); - auto sharding_md_0 = MakeUnique<ShardingMetadata>(nullptr); - auto sharding_md_1 = MakeUnique<ShardingMetadata>(nullptr); + auto sharding_md_0 = absl::make_unique<ShardingMetadata>(nullptr); + auto sharding_md_1 = absl::make_unique<ShardingMetadata>(nullptr); HloInstruction* param = builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); HloInstruction* domain = builder.AddInstruction(HloInstruction::CreateDomain( @@ -490,6 +494,7 @@ TEST_F(HloDomainTest, DumpParseNullSharding) { ASSERT_TRUE(ParseModule(hlo_string).status().ok()); } +// Tuple inputs are domain instructions. 
TEST_F(HloDomainTest, DomainTuple) { const char* const hlo_string = R"( HloModule Module @@ -497,14 +502,15 @@ HloModule Module ENTRY entry { p0 = f32[4] parameter(0), sharding={maximal device=0} cst = u32[] constant(0), sharding={maximal device=1} - tpl = (u32[], f32[4]) tuple(cst, p0), sharding={{maximal device=1}, {maximal device=0}} + tpl = (u32[], f32[4]) tuple(cst, p0), + sharding={{maximal device=1}, {maximal device=0}} ROOT gte = f32[4] get-tuple-element(tpl), index=1, sharding={maximal device=0} } )"; TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); - HloDomainIsolator isolator(CreateShardingDomain); + HloDomainIsolator isolator(ShardingDomainCreator{}); TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); EXPECT_TRUE(isolator_changed); @@ -523,5 +529,168 @@ ENTRY entry { tpl->sharding()); } +TEST_F(HloDomainTest, MultiDomainMultiUser) { + const char* const hlo_string = R"( + HloModule Module + +ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) { + %p0 = (f32[4], f32[4]) parameter(0) + %a = f32[4]{0} get-tuple-element(%p0), index=0 + %domain = f32[4] domain(%a), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %b = f32[4] get-tuple-element(%p0), index=1 + %domain.1 = f32[4] domain(%b), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %c = f32[4] add(%domain, %domain.1), sharding={maximal device=1} + %domain.2 = f32[4] domain(%c), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %d = f32[4] subtract(%domain, %c), + sharding={maximal device=1}, metadata={op_name="D"} + %domain.3 = f32[4] domain(%d), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %e = f32[4] multiply(%c, %d), + sharding={maximal device=1}, metadata={op_name="D"} + %f = f32[4] add(f32[4]{0} %e, f32[4]{0} %c), sharding={maximal device=1} + %domain.4 = f32[4]{0} domain(%f), + domain={kind="sharding", entry={maximal 
device=0}, exit={maximal device=1}} + ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4) +})"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator opname_isolator(OpNameDomainCreator); + TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, + opname_isolator.Run(module)); + EXPECT_TRUE(opname_isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module, "c", "a")); + EXPECT_TRUE(HasDomainEdge(module, "c", "b")); + EXPECT_TRUE(HasDomainEdge(module, "d", "a")); + EXPECT_TRUE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + + HloDomainRemover sharding_remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, + sharding_remover.Run(module)); + EXPECT_TRUE(sharding_remover_changed); + + HloDomainRemover opname_remover(OpNameMetadata::KindName(), + OpNameDomainNormalizer); + TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, + opname_remover.Run(module)); + EXPECT_TRUE(opname_remover_changed); + + EXPECT_FALSE(HasDomainEdge(module, "c", "a")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "d", "a")); + EXPECT_FALSE(HasDomainEdge(module, "d", "c")); +} + +// Emulate instructions inserted at top and bottom within nested tuple domain. 
+TEST_F(HloDomainTest, DomainTupleTopBottomInsert) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = f32[4] parameter(0), sharding={maximal device=1} + p1 = (f32[5], f32[6]) parameter(1), + sharding={{maximal device=1}, {maximal device=0}} + tuple.0 = (f32[4], (f32[5], f32[6])) tuple(p0, p1), + sharding={{maximal device=1}, {maximal device=1}, {maximal device=0}} + ROOT res = (f32[5], f32[6]) get-tuple-element(tuple.0), index=1, + sharding={{maximal device=1}, {maximal device=0}} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + + HloDomainIsolator isolator(ShardingDomainCreator{}); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module)); + EXPECT_TRUE(isolator_changed); + + // Clear sharding of tuple.0 instruction, in order to test domain sharding + // application. + auto tuple0 = FindInstruction(module, "tuple.0"); + tuple0->clear_sharding(); + + // Insert the following instructons above and below tuple.0, to emulate other + // passes effects: + // COPY.0 + // \ / + // TUPLE.0 + // / \ + // COPY.1 \ + // / \ + // GTE.0 GTE.1 + // | | + // | COPY.2 + // \ / + // \ / + // TUPLE.1 + // | + auto tuple0_users = tuple0->users(); + auto computation = tuple0->parent(); + HloInstruction* copy0 = computation->AddInstruction( + HloInstruction::CreateUnary(tuple0->operand(1)->shape(), HloOpcode::kCopy, + tuple0->mutable_operand(1))); + TF_EXPECT_OK(tuple0->ReplaceOperandWith(1, copy0)); + + HloInstruction* copy1 = computation->AddInstruction( + HloInstruction::CreateUnary(tuple0->shape(), HloOpcode::kCopy, tuple0)); + HloInstruction* gte0 = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(copy1->shape(), 0), copy1, 0)); + HloInstruction* gte1 = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(tuple0->shape(), 1), tuple0, 1)); + HloInstruction* copy2 = computation->AddInstruction( + 
HloInstruction::CreateUnary(gte1->shape(), HloOpcode::kCopy, gte1)); + HloInstruction* tuple1 = + computation->AddInstruction(HloInstruction::CreateTuple({gte0, copy2})); + + for (HloInstruction* user : tuple0_users) { + TF_EXPECT_OK(tuple0->ReplaceUseWith(user, tuple1)); + } + + HloDomainRemover remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module)); + EXPECT_TRUE(remover_changed); + + EXPECT_TRUE(tuple0->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(tuple0->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + tuple0->sharding()); + + EXPECT_TRUE(copy0->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(copy0->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + copy0->sharding()); + + // copy1 has partial information only from gte.0, so in the end it gets no + // sharding at all. During propagation it does propagate the information from + // gte.0 though, enabling Tuple.0 to be fully sharded. 
+ EXPECT_FALSE(copy1->has_sharding()); + + EXPECT_TRUE(gte0->has_sharding()); + EXPECT_EQ(HloSharding::AssignDevice(1), gte0->sharding()); + + EXPECT_TRUE(gte1->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(gte1->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + gte1->sharding()); + + EXPECT_TRUE(copy2->has_sharding()); + EXPECT_EQ(HloSharding::Tuple(copy2->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)}), + copy2->sharding()); + + EXPECT_TRUE(tuple1->has_sharding()); + EXPECT_EQ(tuple0->sharding(), tuple1->sharding()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc index 751fc677e2..dc514ae3e5 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.cc @@ -52,7 +52,7 @@ Status HloDomainVerifier::RunContext::PopulateDomainKinds() { TF_RET_CHECK(instruction->user_side_metadata().Kind() == instruction->operand_side_metadata().Kind()) << instruction->ToString(); - kinds.insert(instruction->user_side_metadata().Kind().ToString()); + kinds.insert(string(instruction->user_side_metadata().Kind())); } } } diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h index 8e53cf97f8..81d6d69a8c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h @@ -33,7 +33,7 @@ class HloDomainVerifier : public HloPassInterface { public: HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {} - tensorflow::StringPiece name() const override { return "domain_verifier"; } + absl::string_view name() const override { return "domain_verifier"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h 
b/tensorflow/compiler/xla/service/hlo_element_type_converter.h index 2b109225d0..44ded2c2fa 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -32,9 +32,7 @@ class HloElementTypeConverter : public HloPassInterface { HloElementTypeConverter(PrimitiveType eliminate_type, PrimitiveType replace_with_type); - tensorflow::StringPiece name() const override { - return "element_type_converter"; - } + absl::string_view name() const override { return "element_type_converter"; } // Returns the pass on the module and returns whether the module was modified. StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 36d6a2eed6..71f91fde93 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -23,13 +23,15 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/index_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -43,7 +45,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -95,7 +96,7 @@ StatusOr<std::unique_ptr<Literal>> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = MakeUnique<Literal>(shape); + auto result = absl::make_unique<Literal>(shape); TF_RETURN_IF_ERROR(result->Populate<bool>([&](ArraySlice<int64> multi_index) { return compare_op(lhs_literal.Get<OperandT>(multi_index), rhs_literal.Get<OperandT>(multi_index)); @@ -125,7 +126,7 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>( << HloOpcodeString(opcode); } - auto result = MakeUnique<Literal>(shape); + auto result = absl::make_unique<Literal>(shape); TF_RETURN_IF_ERROR(result->Populate<bool>([&](ArraySlice<int64> multi_index) { return compare_op(lhs_literal.Get<complex64>(multi_index), rhs_literal.Get<complex64>(multi_index)); @@ -138,44 +139,57 @@ StatusOr<std::unique_ptr<Literal>> Compare<complex64>( HloEvaluator::HloEvaluator(int64 max_loop_iterations) : max_loop_iterations_(max_loop_iterations) { - typed_visitors_[PRED] = MakeUnique<HloEvaluatorTypedVisitor<bool>>(this); - typed_visitors_[U8] = MakeUnique<HloEvaluatorTypedVisitor<uint8>>(this); - typed_visitors_[U16] = MakeUnique<FunctionVisitor>([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " - "U16."); - }); - typed_visitors_[U32] = MakeUnique<HloEvaluatorTypedVisitor<uint32>>(this); - typed_visitors_[U64] = MakeUnique<HloEvaluatorTypedVisitor<uint64>>(this); - typed_visitors_[S8] = MakeUnique<HloEvaluatorTypedVisitor<int8>>(this); - typed_visitors_[S16] = MakeUnique<FunctionVisitor>([](HloInstruction*) { - return Unimplemented( - "HloEvaluator::HloEvaluatorTypedVisitor: unhandled 
primitive type: " - "S16."); - }); - typed_visitors_[S32] = MakeUnique<HloEvaluatorTypedVisitor<int32>>(this); - typed_visitors_[S64] = MakeUnique<HloEvaluatorTypedVisitor<int64>>(this); + typed_visitors_[PRED] = + absl::make_unique<HloEvaluatorTypedVisitor<bool>>(this); + typed_visitors_[U8] = + absl::make_unique<HloEvaluatorTypedVisitor<uint8>>(this); + typed_visitors_[U16] = + absl::make_unique<FunctionVisitor>([](HloInstruction*) { + return Unimplemented( + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "U16."); + }); + typed_visitors_[U32] = + absl::make_unique<HloEvaluatorTypedVisitor<uint32>>(this); + typed_visitors_[U64] = + absl::make_unique<HloEvaluatorTypedVisitor<uint64>>(this); + typed_visitors_[S8] = absl::make_unique<HloEvaluatorTypedVisitor<int8>>(this); + typed_visitors_[S16] = + absl::make_unique<FunctionVisitor>([](HloInstruction*) { + return Unimplemented( + "HloEvaluator::HloEvaluatorTypedVisitor: unhandled primitive type: " + "S16."); + }); + typed_visitors_[S32] = + absl::make_unique<HloEvaluatorTypedVisitor<int32>>(this); + typed_visitors_[S64] = + absl::make_unique<HloEvaluatorTypedVisitor<int64>>(this); typed_visitors_[F16] = - MakeUnique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this); - typed_visitors_[F32] = MakeUnique<HloEvaluatorTypedVisitor<float>>(this); - typed_visitors_[F64] = MakeUnique<HloEvaluatorTypedVisitor<double>>(this); - typed_visitors_[C64] = MakeUnique<HloEvaluatorTypedVisitor<complex64>>(this); + absl::make_unique<HloEvaluatorTypedVisitor<Eigen::half, float>>(this); + typed_visitors_[F32] = + absl::make_unique<HloEvaluatorTypedVisitor<float>>(this); + typed_visitors_[F64] = + absl::make_unique<HloEvaluatorTypedVisitor<double>>(this); + typed_visitors_[C64] = + absl::make_unique<HloEvaluatorTypedVisitor<complex64>>(this); // Most of the evaluator computations we use don't support BF16 (e.g., // std::ceil, std::tanh). 
To make evaluator work with BF16, we set all // elementwise computations to be done in F32 and do BF16<->F32 conversion // around the input and the output of the computations. typed_visitors_[BF16] = - MakeUnique<HloEvaluatorTypedVisitor<bfloat16, float>>(this); - - typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) { - return Unimplemented( - "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); - }); - typed_visitors_[OPAQUE] = MakeUnique<FunctionVisitor>([](HloInstruction*) { - return Unimplemented( - "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); - }); + absl::make_unique<HloEvaluatorTypedVisitor<bfloat16, float>>(this); + + typed_visitors_[TUPLE] = + absl::make_unique<FunctionVisitor>([](HloInstruction*) { + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: TUPLE."); + }); + typed_visitors_[OPAQUE] = + absl::make_unique<FunctionVisitor>([](HloInstruction*) { + return Unimplemented( + "HloEvaluatorTypedVisitor: unhandled primitive type: OPAQUE."); + }); } template <typename LiteralPtr> @@ -216,7 +230,6 @@ template <typename LiteralPtr> StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( HloInstruction* instruction, ArraySlice<LiteralPtr> arg_literals) { TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); evaluated_.clear(); arg_literals_.clear(); @@ -253,7 +266,6 @@ StatusOr<std::unique_ptr<Literal>> HloEvaluator::Evaluate( return tensorflow::errors::FailedPrecondition( "Not all operands are constants."); } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); arg_literals_.clear(); evaluated_.clear(); @@ -423,7 +435,7 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { if (!ShapeUtil::ElementIsFloating(operand->shape())) { return InvalidArgument( "expected element type in shape to be float for IsFinite op, got: %s", - 
PrimitiveType_Name(operand->shape().element_type()).c_str()); + PrimitiveType_Name(operand->shape().element_type())); } switch (operand->shape().element_type()) { @@ -464,9 +476,9 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare) { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s vs %s", - ShapeUtil::HumanString(compare->shape()).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); + ShapeUtil::HumanString(compare->shape()), + ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape())); } TF_RET_CHECK(lhs->shape().element_type() == rhs->shape().element_type()); @@ -564,7 +576,8 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputBatchIndices( std::vector<int64> index_count; index_count.reserve(output_rank); for (int64 i = 0; i < output_rank; i++) { - bool is_output_batch_dim = !c_binary_search(dim_numbers.offset_dims(), i); + bool is_output_batch_dim = + !absl::c_binary_search(dim_numbers.offset_dims(), i); index_count.push_back(is_output_batch_dim ? 
output_shape.dimensions(i) : 1); } @@ -581,10 +594,11 @@ ShapeUtil::IndexIterationSpace IterationSpaceForOutputOffsetIndices( std::vector<int64> index_count(output_rank, 1); int64 slice_sizes_idx = 0; for (int64 i = 0; i < output_rank; i++) { - bool is_output_window_dim = c_binary_search(dim_numbers.offset_dims(), i); + bool is_output_window_dim = + absl::c_binary_search(dim_numbers.offset_dims(), i); if (is_output_window_dim) { - while (c_binary_search(dim_numbers.collapsed_slice_dims(), - slice_sizes_idx)) { + while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), + slice_sizes_idx)) { slice_sizes_idx++; } index_count[i] = slice_sizes[slice_sizes_idx++]; @@ -610,13 +624,13 @@ class OutputBatchIndexToInputIndex { : dim_numbers_(*dim_numbers), start_indices_(*start_indices) { for (int64 i = 0; i < output_shape.dimensions_size(); i++) { output_dim_is_batch_dims_.push_back( - !c_binary_search(dim_numbers_.offset_dims(), i)); + !absl::c_binary_search(dim_numbers_.offset_dims(), i)); } for (int64 i = 0; i < input_shape.dimensions_size(); i++) { int64 index_of_input_dim_in_index_vector = std::distance(dim_numbers_.start_index_map().begin(), - c_find(dim_numbers_.start_index_map(), i)); + absl::c_find(dim_numbers_.start_index_map(), i)); if (index_of_input_dim_in_index_vector == dim_numbers_.start_index_map_size()) { input_dim_value_to_index_vector_.push_back(-1); @@ -736,7 +750,7 @@ class OutputOffsetIndexToInputIndex { std::vector<int64> window_index_to_output_index; int64 output_index_count = 0; for (int64 i = 0; i < output_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.offset_dims(), i)) { + if (absl::c_binary_search(dim_numbers.offset_dims(), i)) { window_index_to_output_index.push_back(output_index_count++); } else { output_index_count++; @@ -745,7 +759,7 @@ class OutputOffsetIndexToInputIndex { int64 window_dim_count = 0; for (int64 i = 0; i < input_shape.dimensions_size(); i++) { - if 
(c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { + if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { input_dim_value_to_output_index_.push_back(-1); } else { input_dim_value_to_output_index_.push_back( @@ -953,7 +967,7 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand); - evaluated_[get_tuple_element] = MakeUnique<Literal>( + evaluated_[get_tuple_element] = absl::make_unique<Literal>( ShapeUtil::GetTupleElementShape(operand->shape(), index)); return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal, /*dest_shape_index=*/{}, @@ -1091,8 +1105,8 @@ Status HloEvaluator::HandleWhile(HloInstruction* while_hlo) { HloEvaluator loop_body_evaluator(max_loop_iterations_); while (keep_going) { if (max_loop_iterations_ >= 0 && iteration_count++ > max_loop_iterations_) { - return InvalidArgument("Loop %s exceeded loop iteration limit (%lld).", - while_hlo->name().c_str(), max_loop_iterations_); + return InvalidArgument("Loop %s exceeded loop iteration limit (%d).", + while_hlo->name(), max_loop_iterations_); } TF_ASSIGN_OR_RETURN(auto cond_val, cond_evaluator.Evaluate<Literal*>( *cond_comp, {lcv.get()})); @@ -1155,10 +1169,11 @@ StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal( result_keys.push_back(key_value.first); result_values.push_back(key_value.second); } - auto result_keys_literal = MakeUnique<Literal>(keys_literal.shape()); + auto result_keys_literal = absl::make_unique<Literal>(keys_literal.shape()); result_keys_literal->PopulateR1( tensorflow::gtl::ArraySlice<KeyType>(result_keys)); - auto result_values_literal = MakeUnique<Literal>(values_literal.shape()); + auto result_values_literal = + absl::make_unique<Literal>(values_literal.shape()); result_values_literal->PopulateR1( tensorflow::gtl::ArraySlice<ValueType>(result_values)); return std::make_pair(std::move(result_keys_literal), @@ -1173,8 +1188,9 @@ 
StatusOr<std::unique_ptr<Literal>> EvaluateSortInternal( } else { // For R2 sort, the desired semantics are to sort each matrix row // independently. - auto keys_result_literal = MakeUnique<Literal>(keys_literal.shape()); - auto values_result_literal = MakeUnique<Literal>(values_literal.shape()); + auto keys_result_literal = absl::make_unique<Literal>(keys_literal.shape()); + auto values_result_literal = + absl::make_unique<Literal>(values_literal.shape()); int64 r1_length = keys_literal.shape().dimensions(1); for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto keys_r1_slice, @@ -1246,7 +1262,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape()); if (sort_dim != rank - 1) { return Unimplemented( - "Trying to support along dimension %lld, which is not the last " + "Trying to support along dimension %d, which is not the last " "dimension", sort_dim); } @@ -1267,7 +1283,7 @@ Status HloEvaluator::HandleSort(HloInstruction* sort) { Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); - return Status::OK(); + return ShapeUtil::ValidateShape(hlo->shape()); } Status HloEvaluator::Postprocess(HloInstruction* hlo) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index a4c37ef328..0ea7089552 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,7 +18,7 @@ limitations under the License. 
#include <memory> -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -222,11 +222,11 @@ class HloEvaluator : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(operand->shape()).c_str()); + ShapeUtil::HumanString(shape), + ShapeUtil::HumanString(operand->shape())); } - auto result = MakeUnique<Literal>(shape); + auto result = absl::make_unique<Literal>(shape); TF_RETURN_IF_ERROR(result->Populate<ReturnT>( [&](tensorflow::gtl::ArraySlice<int64> multi_index) { return unary_op(operand_literal.Get<NativeT>(multi_index)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 1394be68e4..c3af15c6a8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include <utility> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -51,8 +52,11 @@ static std::array<bool, 2> use_bf16_params{true, false}; class HloEvaluatorTest : public ::testing::WithParamInterface<bool>, public HloVerifiedTestBase { protected: - HloEvaluatorTest() : use_bfloat16_(GetParam()) { - evaluator_ = MakeUnique<HloEvaluator>(); + HloEvaluatorTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false), + use_bfloat16_(GetParam()) { + evaluator_ = absl::make_unique<HloEvaluator>(); } std::unique_ptr<Literal> Evaluate( @@ -523,7 +527,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { std::unique_ptr<Literal> result = Evaluate(); - auto expected_array = MakeUnique<Array4D<float>>(8, 5, 1, 1); + auto expected_array = absl::make_unique<Array4D<float>>(8, 5, 1, 1); expected_array->Fill(kPadValue); (*expected_array)(1, 0, 0, 0) = 1.0f; (*expected_array)(1, 2, 0, 0) = 2.0f; @@ -547,7 +551,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto input_array = MakeUnique<Array2D<float>>(4, 3); + auto input_array = absl::make_unique<Array2D<float>>(4, 3); input_array->FillUnique(1.0f); auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array); HloInstruction* input_instruction = @@ -568,7 +572,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { std::unique_ptr<Literal> result = Evaluate(); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } - auto expected_array = MakeUnique<Array2D<float>>(1, 5); + auto expected_array = absl::make_unique<Array2D<float>>(1, 5); (*expected_array)(0, 0) = 7.0f; (*expected_array)(0, 1) = 2.718f; (*expected_array)(0, 2) = 2.718f; @@ -588,7 +592,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto input_array = 
MakeUnique<Array2D<float>>(4, 3); + auto input_array = absl::make_unique<Array2D<float>>(4, 3); input_array->FillUnique(1.0f); auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array); HloInstruction* input_instruction = @@ -612,7 +616,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { std::unique_ptr<Literal> result = Evaluate(); - auto expected_array = MakeUnique<Array2D<float>>(0, 9); + auto expected_array = absl::make_unique<Array2D<float>>(0, 9); auto expected = LiteralUtil::CreateR2FromArray2D<float>(*expected_array); EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); @@ -628,7 +632,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { // { 3 }, // { 4 }, // } - auto lhs_array = MakeUnique<Array2D<float>>(4, 1); + auto lhs_array = absl::make_unique<Array2D<float>>(4, 1); lhs_array->FillUnique(1.0f); auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array); HloInstruction* lhs_instruction = @@ -679,7 +683,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { // { 3, 4 }, // { 5, 6 }, // } - auto rhs_array = MakeUnique<Array2D<float>>(3, 2); + auto rhs_array = absl::make_unique<Array2D<float>>(3, 2); rhs_array->FillUnique(1.0f); auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array); HloInstruction* rhs_instruction = @@ -710,7 +714,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // { 9, 10, 11 }, // { 13, 14, 15 }, // } - auto lhs_array = MakeUnique<Array2D<float>>(4, 3); + auto lhs_array = absl::make_unique<Array2D<float>>(4, 3); lhs_array->FillUnique(1.0f); auto lhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*lhs_array); HloInstruction* lhs_instruction = @@ -722,7 +726,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { // { 3, 4 }, // { 5, 6 }, // } - auto rhs_array = MakeUnique<Array2D<float>>(3, 2); + auto rhs_array = absl::make_unique<Array2D<float>>(3, 2); rhs_array->FillUnique(1.0f); auto rhs_literal = LiteralUtil::CreateR2FromArray2D<float>(*rhs_array); HloInstruction* rhs_instruction = @@ -1215,7 
+1219,12 @@ TEST_P(HloEvaluatorTest, EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } -class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; +class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase { + public: + HloEvaluatorPreciseReduceTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; // Tests that Reduce doesn't lose precision when adding many numbers (because // it accumulates its result in a double). @@ -1297,7 +1306,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique<Array2D<float>>(2, 3); + auto arg_array = absl::make_unique<Array2D<float>>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array); @@ -1339,7 +1348,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique<Array2D<float>>(2, 3); + auto arg_array = absl::make_unique<Array2D<float>>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array); @@ -1390,7 +1399,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto arg_array = MakeUnique<Array2D<float>>(2, 3); + auto arg_array = absl::make_unique<Array2D<float>>(2, 3); arg_array->FillUnique(1.0f); auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array); @@ -1511,7 +1520,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) { // { 9, 10, 11, 12, 13 }, // { 17, 18, 19, 20, 21 }, // } - auto operand_array = MakeUnique<Array2D<float>>(3, 5); + auto operand_array = absl::make_unique<Array2D<float>>(3, 5); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D<float>(*operand_array); @@ -1544,7 +1553,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) { // { 1, 2, 3, 4 }, // { 5, 6, 7, 8 }, // } - auto operand_array = MakeUnique<Array2D<float>>(2, 4); + auto operand_array = 
absl::make_unique<Array2D<float>>(2, 4); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D<float>(*operand_array); @@ -1580,7 +1589,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { // { 1, 2, 3, 4 }, // { 5, 6, 7, 8 }, // } - auto operand_array = MakeUnique<Array2D<float>>(2, 4); + auto operand_array = absl::make_unique<Array2D<float>>(2, 4); operand_array->FillUnique(1.0f); auto operand_literal = LiteralUtil::CreateR2FromArray2D<float>(*operand_array); @@ -1614,7 +1623,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique<Array2D<double>>(2, 3); + auto operand_array = absl::make_unique<Array2D<double>>(2, 3); operand_array->FillUnique(1.0); auto operand_literal = LiteralUtil::CreateR2FromArray2D<double>(*operand_array); @@ -1651,7 +1660,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique<Array2D<double>>(2, 3); + auto operand_array = absl::make_unique<Array2D<double>>(2, 3); operand_array->FillUnique(1.0); auto operand_literal2 = LiteralUtil::CreateR2FromArray2D<double>(*operand_array); @@ -1687,7 +1696,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { // { 1, 2, 3 }, // { 5, 6, 7 }, // } - auto operand_array = MakeUnique<Array2D<double>>(2, 3); + auto operand_array = absl::make_unique<Array2D<double>>(2, 3); operand_array->FillUnique(1.0); HloInstruction* operand2 = b.AddInstruction(HloInstruction::CreateConstant( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 7fdf4521de..f682e69ee9 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,11 +16,16 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/core/lib/core/casts.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -105,7 +110,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> double GetAsDouble(const Literal& literal, tensorflow::gtl::ArraySlice<int64> input_index) { - CHECK(false); + LOG(FATAL) << "Trying to get complex literal as double: " + << literal.ToString(); } public: @@ -139,7 +145,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", - HloOpcodeString(hlo_instruction->opcode()).c_str()); + HloOpcodeString(hlo_instruction->opcode())); } // TODO(b/35950897): many of the stl functions used in the handlers are not @@ -547,7 +553,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleDivide(HloInstruction* divide) override { + template < + typename NativeT, + typename std::enable_if<std::is_floating_point<NativeT>::value || + is_complex_t<NativeT>::value>::type* = nullptr> + Status HandleDivide(HloInstruction* divide) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { @@ -557,6 +567,46 @@ class 
HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } template <typename NativeT, + typename std::enable_if<std::is_signed<NativeT>::value && + std::is_integral<NativeT>::value>::type* = + nullptr> + Status HandleDivide(HloInstruction* divide) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[divide], + ElementWiseBinaryOp( + divide, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) -> ElementwiseT { + if (rhs_elem == 0) { + return static_cast<ElementwiseT>(-1); + } + if (rhs_elem == -1 && + lhs_elem == std::numeric_limits<ElementwiseT>::min()) { + return lhs_elem; + } + return lhs_elem / rhs_elem; + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if<std::is_unsigned<NativeT>::value>::type* = + nullptr> + Status HandleDivide(HloInstruction* divide) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], + ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return rhs_elem == 0 + ? std::numeric_limits<ElementwiseT>::max() + : (lhs_elem / rhs_elem); + })); + return Status::OK(); + } + + Status HandleDivide(HloInstruction* divide) { + return HandleDivide<ElementwiseT>(divide); + } + + template <typename NativeT, typename std::enable_if<std::is_integral<NativeT>::value>::type* = nullptr> Status HandleMaximum(HloInstruction* maximum) { @@ -642,9 +692,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - template < - typename NativeT, - typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr> + template <typename NativeT, typename std::enable_if<std::is_floating_point< + NativeT>::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, @@ -654,6 +703,40 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template <typename NativeT, + typename 
std::enable_if<std::is_unsigned<NativeT>::value>::type* = + nullptr> + Status HandleRemainder(HloInstruction* remainder) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], + ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return rhs_el == 0 ? lhs_el : (lhs_el % rhs_el); + })); + return Status::OK(); + } + + template <typename NativeT, + typename std::enable_if<std::is_signed<NativeT>::value && + std::is_integral<NativeT>::value>::type* = + nullptr> + Status HandleRemainder(HloInstruction* remainder) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[remainder], + ElementWiseBinaryOp( + remainder, + [](ElementwiseT lhs_el, ElementwiseT rhs_el) -> ElementwiseT { + if (rhs_el == 0) { + return lhs_el; + } + if (rhs_el == -1 && + lhs_el == std::numeric_limits<ElementwiseT>::min()) { + return 0; + } + return lhs_el % rhs_el; + })); + return Status::OK(); + } + template < typename NativeT, typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr> @@ -895,7 +978,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); - auto result = MakeUnique<Literal>(result_shape); + auto result = absl::make_unique<Literal>(result_shape); TF_RETURN_IF_ERROR(result->Populate<ReturnT>( [&](tensorflow::gtl::ArraySlice<int64> out_index) { @@ -1052,7 +1135,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return static_cast<ReturnT>(result_val); }; - auto result = MakeUnique<Literal>(result_shape); + auto result = absl::make_unique<Literal>(result_shape); TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func)); parent_->evaluated_[conv] = std::move(result); @@ -1100,7 +1183,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // result_index_locations[i] contains one or two pointers to the locations // in lhs_index or rhs_index where the i'th result index should 
go. - tensorflow::gtl::InlinedVector<std::pair<int64*, int64*>, kInlineRank> + absl::InlinedVector<std::pair<int64*, int64*>, kInlineRank> result_index_locations; result_index_locations.reserve(lhs_rank + rhs_rank - 2); @@ -1126,7 +1209,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } } - auto result = MakeUnique<Literal>(dot->shape()); + auto result = absl::make_unique<Literal>(dot->shape()); TF_RETURN_IF_ERROR(result->Populate<ReturnT>( [&](tensorflow::gtl::ArraySlice<int64> result_index) { ElementwiseT result_val = static_cast<ElementwiseT>(0); @@ -1175,7 +1258,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Create new HLO of padded shape with padding value. ReturnT scalar = parent_->GetEvaluatedLiteralFor(pad->operand(1)).Get<ReturnT>({}); - auto result = MakeUnique<Literal>(pad->shape()); + auto result = absl::make_unique<Literal>(pad->shape()); TF_RETURN_IF_ERROR(result->Populate<ReturnT>( [&scalar](tensorflow::gtl::ArraySlice<int64> multi_index) { return scalar; @@ -1340,7 +1423,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto operands = map->operands(); HloComputation* computation = map->to_apply(); - auto result = MakeUnique<Literal>(map->shape()); + auto result = absl::make_unique<Literal>(map->shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); TF_RETURN_IF_ERROR(result->Populate<ReturnT>( @@ -1454,7 +1537,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { [](const ReturnT& a, const ReturnT& b) { return SafeLess<ReturnT>(a, b); }); - auto result_literal = MakeUnique<Literal>(keys_literal.shape()); + auto result_literal = absl::make_unique<Literal>(keys_literal.shape()); result_literal->PopulateR1( tensorflow::gtl::ArraySlice<ReturnT>(result_data)); VLOG(3) << "HandleSort result_literal: " << result_literal->ToString(); @@ -1466,7 +1549,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } else { // For R2 sort, 
the desired semantics are to sort each matrix row // independently. - auto result_literal = MakeUnique<Literal>(keys_literal.shape()); + auto result_literal = absl::make_unique<Literal>(keys_literal.shape()); int64 r1_length = keys->shape().dimensions(1); for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { TF_ASSIGN_OR_RETURN(auto r1_slice, @@ -1540,11 +1623,15 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = MakeUnique<Literal>(reduce->shape()); + auto result = absl::make_unique<Literal>(reduce->shape()); + Status eval_status; // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate<ReturnT>( [&](tensorflow::gtl::ArraySlice<int64> multi_index) { ReturnT result_val = init_scalar; + if (!eval_status.ok()) { + return result_val; + } std::vector<int64> base(arg_dimensions.size()); for (int64 i = 0; i < multi_index.size(); ++i) { @@ -1565,7 +1652,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { arg_dim_steps, func); return static_cast<ReturnT>(computed_result); } - auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) { + auto func = [&](tensorflow::gtl::ArraySlice<int64> input_index) + -> StatusOr<bool> { auto curr_val = arg_literal.Get<ReturnT>(input_index); // Evaluate computation with specified literal operands. 
@@ -1573,12 +1661,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto result_val_literal = LiteralUtil::CreateR0<ReturnT>(result_val); - std::unique_ptr<Literal> computed_result = - embedded_evaluator - .Evaluate<const Literal*>( - *function, - {result_val_literal.get(), curr_val_literal.get()}) - .ConsumeValueOrDie(); + TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> computed_result, + embedded_evaluator.Evaluate<const Literal*>( + *function, {result_val_literal.get(), + curr_val_literal.get()})); // Clear visit states so that we can use the evaluator again on // the same computation. embedded_evaluator.ResetVisitStates(); @@ -1588,13 +1674,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { }; // Computes one element of the result, reducing all dimensions that // contribute to that element. - ShapeUtil::ForEachIndex(arg_literal.shape(), base, arg_dim_counts, - arg_dim_steps, func); + eval_status = ShapeUtil::ForEachIndexWithStatus( + arg_literal.shape(), base, arg_dim_counts, arg_dim_steps, func); return result_val; })); parent_->evaluated_[reduce] = std::move(result); - return Status::OK(); + return eval_status; } bool IsScalarAdd(HloComputation* computation) { @@ -1621,7 +1707,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); auto init_scalar = init_literal.Get<ReturnT>({}); - auto result = MakeUnique<Literal>(select_and_scatter->shape()); + auto result = absl::make_unique<Literal>(select_and_scatter->shape()); // Initialize result array with the init value. TF_RETURN_IF_ERROR(result->Populate<ReturnT>( @@ -1665,8 +1751,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // 2. Using the selected index, scatter value from `source` to result. We // do this by iterating through the window, and compare each index with // the selected index. 
- tensorflow::gtl::optional<ReturnT> selected_val; - tensorflow::gtl::optional<std::vector<int64>> selected_index; + absl::optional<ReturnT> selected_val; + absl::optional<std::vector<int64>> selected_index; IterateThroughWindow( window_shape, window, operand_literal.shape(), source_index, @@ -1757,7 +1843,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - auto result = MakeUnique<Literal>(reduce_window->shape()); + auto result = absl::make_unique<Literal>(reduce_window->shape()); // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate<ReturnT>( [&](tensorflow::gtl::ArraySlice<int64> output_index) { @@ -1824,7 +1910,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector<int64> index_count(updates_rank, 1); for (int64 i = 0; i < updates_rank; i++) { bool is_update_scatter_dim = - !c_binary_search(dim_numbers.update_window_dims(), i); + !absl::c_binary_search(dim_numbers.update_window_dims(), i); if (is_update_scatter_dim) { index_count[i] = updates_shape.dimensions(i); } @@ -1843,7 +1929,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector<int64> index_count(updates_rank, 1); for (int64 i = 0; i < updates_rank; i++) { bool is_update_window_dim = - c_binary_search(dim_numbers.update_window_dims(), i); + absl::c_binary_search(dim_numbers.update_window_dims(), i); if (is_update_window_dim) { index_count[i] = updates_shape.dimensions(i); } @@ -1870,7 +1956,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { : dim_numbers_(*dim_numbers), scatter_indices_(*scatter_indices) { for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { update_dim_is_scatter_dims_.push_back( - !c_binary_search(dim_numbers_.update_window_dims(), i)); + 
!absl::c_binary_search(dim_numbers_.update_window_dims(), i)); } for (int64 i = 0; i < input_shape.dimensions_size(); i++) { @@ -2000,7 +2086,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector<int64> window_index_to_update_index; int64 update_index_count = 0; for (int64 i = 0; i < updates_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.update_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) { window_index_to_update_index.push_back(update_index_count++); } else { update_index_count++; @@ -2009,7 +2095,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { int64 window_dim_count = 0; for (int64 i = 0; i < input_shape.dimensions_size(); i++) { - if (c_binary_search(dim_numbers.inserted_window_dims(), i)) { + if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) { input_dim_value_to_update_index_.push_back(-1); } else { input_dim_value_to_update_index_.push_back( @@ -2409,11 +2495,21 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::is_same<NativeT, float>::value || std::is_same<NativeT, int32>::value || std::is_same<NativeT, uint32>::value>::type* = nullptr> - Status HandleIota(HloInstruction* iota) { - auto result = MakeUnique<Literal>(iota->shape()); - auto data = result->data<ReturnT>(); + Status HandleIota(HloInstruction* instruction) { + auto* iota = Cast<HloIotaInstruction>(instruction); + std::vector<NativeT> data(iota->shape().dimensions(iota->iota_dimension())); std::iota(data.begin(), data.end(), 0); - parent_->evaluated_[iota] = std::move(result); + auto result = LiteralUtil::CreateR1<NativeT>(data); + + if (ShapeUtil::Rank(iota->shape()) > 1) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[iota], + result->Broadcast(iota->shape(), {iota->iota_dimension()})); + } else { + TF_RET_CHECK(ShapeUtil::Rank(iota->shape()) == 1); + parent_->evaluated_[iota] = std::move(result); + } + return Status::OK(); } template <typename 
NativeT, @@ -2492,7 +2588,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { } std::vector<int64> operand_indices(start.size()); - auto result = MakeUnique<Literal>(result_shape); + auto result = absl::make_unique<Literal>(result_shape); TF_RETURN_IF_ERROR(result->Populate<ReturnT>( [&](tensorflow::gtl::ArraySlice<int64> multi_index) { for (int64 i = 0; i < operand_indices.size(); ++i) { @@ -2570,15 +2666,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str()); + ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape())); } const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = MakeUnique<Literal>(shape); + auto result = absl::make_unique<Literal>(shape); TF_RETURN_IF_ERROR(result->Populate<ReturnT>( [&](tensorflow::gtl::ArraySlice<int64> multi_index) { @@ -2606,17 +2701,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Unimplemented( "Implicit broadcasting is currently unsupported in HLO evaluator " "Shape Mismatch: %s vs %s vs %s vs %s: ", - ShapeUtil::HumanString(shape).c_str(), - ShapeUtil::HumanString(lhs->shape()).c_str(), - ShapeUtil::HumanString(rhs->shape()).c_str(), - ShapeUtil::HumanString(ehs->shape()).c_str()); + ShapeUtil::HumanString(shape), ShapeUtil::HumanString(lhs->shape()), + ShapeUtil::HumanString(rhs->shape()), + ShapeUtil::HumanString(ehs->shape())); } const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = 
MakeUnique<Literal>(shape); + auto result = absl::make_unique<Literal>(shape); TF_RETURN_IF_ERROR(result->Populate<ReturnT>( [&](tensorflow::gtl::ArraySlice<int64> multi_index) { diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index c3ccbf0f0c..de3d7a1677 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -19,6 +19,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" @@ -49,7 +51,7 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData( size_t profile_counters_size = hlo_profile_index_map.total_count(); std::unique_ptr<HloProfilePrinterData> profile_printer_data = - MakeUnique<HloProfilePrinterData>(); + absl::make_unique<HloProfilePrinterData>(); profile_printer_data->set_profile_counters_size(profile_counters_size); profile_printer_data->mutable_computation_infos()->Reserve( hlo_profile_index_map.computation_count()); @@ -67,11 +69,11 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData( // The profile indices were computed deterministically in // HloProfileIndexMap::HloProfileIndexMap. 
- c_sort(computation_and_profile_idx_list, - [](const std::pair<const HloComputation*, int64>& left, - const std::pair<const HloComputation*, int64>& right) { - return left.second < right.second; - }); + absl::c_sort(computation_and_profile_idx_list, + [](const std::pair<const HloComputation*, int64>& left, + const std::pair<const HloComputation*, int64>& right) { + return left.second < right.second; + }); for (const auto& pair : computation_and_profile_idx_list) { CHECK_LT(pair.second, profile_counters_size); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index eba80c0f19..460ae2b5ec 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -14,15 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -using tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::AllOf; using ::testing::ContainsRegex; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 1efa6eb5bd..3041d94fa9 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -26,6 +26,12 @@ limitations under the License. 
#include <unordered_map> #include <vector> +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -37,50 +43,25 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" -using ::tensorflow::Env; -using ::tensorflow::WriteStringToFile; -using ::tensorflow::gtl::nullopt; -using ::tensorflow::gtl::optional; -using ::tensorflow::io::JoinPath; -using ::tensorflow::str_util::Join; -using ::tensorflow::str_util::StringReplace; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace xla { namespace hlo_graph_dumper { namespace { -// Helpers for Printf and Appendf. -template <typename T> -struct PrintfConvert { - const T& operator()(const T& t) const { return t; } -}; -template <> -struct PrintfConvert<string> { - const char* operator()(const string& s) const { return s.c_str(); } -}; - -// Like tensorflow::strings::Printf/Appendf, but you don't need to call c_str() -// on strings. -template <typename... Ts> -string Printf(const char* fmt, const Ts&... ts) { - return tensorflow::strings::Printf(fmt, PrintfConvert<Ts>()(ts)...); -} -template <typename... 
Ts> -void Appendf(string* s, const char* fmt, const Ts&... ts) { - tensorflow::strings::Appendf(s, fmt, PrintfConvert<Ts>()(ts)...); -} +using absl::nullopt; +using absl::optional; +using absl::StrAppend; +using absl::StrCat; +using absl::StrFormat; +using absl::StrJoin; +using tensorflow::Env; +using tensorflow::WriteStringToFile; +using tensorflow::io::JoinPath; // Used to indicate how we should treat a given HLOInstruction in the graph. // should we treat it like normal, hide it, and so on? @@ -209,17 +190,15 @@ NodeColors NodeColorsForScheme(ColorScheme color) { string NodeColorAttributes(ColorScheme color) { NodeColors node_colors = NodeColorsForScheme(color); - return Printf( - R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", - node_colors.style, node_colors.font_color, node_colors.stroke_color, - node_colors.fill_color); + return StrFormat(R"(style="%s", fontcolor="%s", color="%s", fillcolor="%s")", + node_colors.style, node_colors.font_color, + node_colors.stroke_color, node_colors.fill_color); } // Replaces <> with <>, so that this string is safe(er) for use in a // graphviz HTML-like string. -string HtmlLikeStringSanitize(tensorflow::StringPiece s) { - return StringReplace(StringReplace(s, "<", "<", /*replace_all=*/true), ">", - ">", /*replace_all=*/true); +string HtmlLikeStringSanitize(absl::string_view s) { + return absl::StrReplaceAll(s, {{"<", "<"}, {">", ">"}}); } // Tries to generates a human-readable one-word description of the given @@ -322,11 +301,11 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) { // Encapsulates logic for dumping an HLO module to DOT (i.e. graphviz syntax). 
class HloDotDumper { public: - HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, + HloDotDumper(const HloComputation* computation, absl::string_view label, const DebugOptions& debug_options, bool show_backend_config, const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), - label_(std::string(label)), + label_(label), debug_options_(debug_options), show_backend_config_(show_backend_config), profile_(profile), @@ -448,7 +427,7 @@ string HloDotDumper::Dump() { } string HloDotDumper::Header() { - const char* fmt = R"(digraph G { + constexpr char fmt[] = R"(digraph G { rankdir = TB; compound = true; label = <<b>%s</b>>; @@ -457,7 +436,7 @@ labelloc = t; tooltip = " "; // DOT graphs accept a stylesheet as a URI. So naturally, an inline // stylesheet is a data URI! -stylesheet=" +stylesheet=< data:text/css, @import url(https://fonts.googleapis.com/css?family=Roboto:400,700); svg text { @@ -466,7 +445,7 @@ stylesheet=" } %s -" +> )"; @@ -481,8 +460,8 @@ stylesheet=" } if (profile_ != nullptr) { auto cycles = profile_->total_cycles_executed(*computation_); - Appendf(&graph_label, "<br/>total cycles = %lld (%s)", cycles, - tensorflow::strings::HumanReadableNum(cycles)); + absl::StrAppendFormat(&graph_label, "<br/>total cycles = %d (%s)", cycles, + tensorflow::strings::HumanReadableNum(cycles)); } // Create CSS rules that say, when you hover over the given node or cluster, @@ -509,14 +488,14 @@ stylesheet=" // One could imagine other ways of writing this CSS rule that involve // less duplication, but this way seems to be relatively performant. 
edge_css_rules.push_back( - Printf(" #%s%d:hover ~ #edge%lld text { fill: %s; }\n" - " #%s%d:hover ~ #edge%lld path { " - "stroke: %s; stroke-width: .2em; }\n" - " #%s%d:hover ~ #edge%lld polygon { " - "fill: %s; stroke: %s; stroke-width: .2em; }\n", - elem_type, elem_id, edge_id, color, // - elem_type, elem_id, edge_id, color, // - elem_type, elem_id, edge_id, color, color)); + StrFormat(" #%s%d:hover ~ #edge%d text { fill: %s; }\n" + " #%s%d:hover ~ #edge%d path { " + "stroke: %s; stroke-width: .2em; }\n" + " #%s%d:hover ~ #edge%d polygon { " + "fill: %s; stroke: %s; stroke-width: .2em; }\n", + elem_type, elem_id, edge_id, color, // + elem_type, elem_id, edge_id, color, // + elem_type, elem_id, edge_id, color, color)); }; // The "to_node" value may be a NULL, indicating that this points to the @@ -559,10 +538,10 @@ stylesheet=" } } - return Printf(fmt, graph_label, Join(edge_css_rules, "\n")); + return StrFormat(fmt, graph_label, StrJoin(edge_css_rules, "\n")); } -string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); } +string HloDotDumper::Footer() { return StrCat(StrJoin(edges_, "\n"), "\n}"); } bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) { CHECK_EQ(instr->opcode(), HloOpcode::kFusion); @@ -600,9 +579,9 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() << " as " << next_edge_id_; edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); - const char* edge_fmt = + constexpr char edge_fmt[] = R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back(Printf( + edges_.push_back(StrFormat( edge_fmt, InstructionId(from), InstructionId(parent_instr), SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); } @@ -619,9 +598,10 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, string subcomp_label, style; if (parent_instr->opcode() == HloOpcode::kFusion) { - 
subcomp_label = Printf("Fused expression for <b>%s</b><br/>%s", - HtmlLikeStringSanitize(parent_instr->name()), - HtmlLikeStringSanitize(parent_instr->ToCategory())); + subcomp_label = + StrFormat("Fused expression for <b>%s</b><br/>%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(parent_instr->ToCategory())); string extra_info = GetInstructionNodeExtraInfo(parent_instr); if (!extra_info.empty()) { StrAppend(&subcomp_label, "<br/>", extra_info); @@ -647,18 +627,18 @@ string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, strokecolor = highlight ? "#b71c1c" : "#c2c2c2"; } style = - Printf(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", - fillcolor, strokecolor); + StrFormat(R"(style="rounded,filled,bold"; fillcolor="%s"; color="%s;")", + fillcolor, strokecolor); } else { - subcomp_label = Printf("Subcomputation for <b>%s</b><br/>%s", - HtmlLikeStringSanitize(parent_instr->name()), - HtmlLikeStringSanitize(subcomp->name())); + subcomp_label = StrFormat("Subcomputation for <b>%s</b><br/>%s", + HtmlLikeStringSanitize(parent_instr->name()), + HtmlLikeStringSanitize(subcomp->name())); style = "style=rounded; color=black;"; } string comp_body = DumpComputation(subcomp); - const char* computation_fmt = R"(subgraph %s { + constexpr char computation_fmt[] = R"(subgraph %s { %s label = <%s>; labelloc = t; @@ -667,7 +647,7 @@ tooltip = " "; } // %s )"; - return Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + return StrFormat(computation_fmt, id, style, subcomp_label, comp_body, id); } string HloDotDumper::DumpComputation(const HloComputation* comp) { @@ -718,11 +698,11 @@ string HloDotDumper::DumpRootTag() { VLOG(2) << "Adding edge from " << from->name() << " to root tag as " << next_edge_id_; edge_ids_.insert({{from, to}, next_edge_id_++}); - edges_.push_back(Printf(R"(%s -> %s [tooltip=" "];)", from_id, to_id)); + edges_.push_back(StrFormat(R"(%s -> %s [tooltip=" "];)", from_id, to_id)); - return 
Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" - "\n", - to_id, node_body, node_shape, NodeColorAttributes(color)); + return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" + "\n", + to_id, node_body, node_shape, NodeColorAttributes(color)); } static const HloConstantInstruction* TryGetFusionParameterConstant( @@ -817,10 +797,10 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } } - return Printf(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" - "\n", - InstructionId(instr), node_body, node_shape, node_metadata, - NodeColorAttributes(color)); + return StrFormat(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" + "\n", + InstructionId(instr), node_body, node_shape, node_metadata, + NodeColorAttributes(color)); } string HloDotDumper::GetInstructionNodeInlinedOperands( @@ -833,7 +813,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which // is just noise. if (ShapeUtil::IsZeroElementArray(shape)) { - return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape())); + return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape())); } // Print the literal value of constants with <= K elements. @@ -848,19 +828,19 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // collected from profiling tools. Those constants may not have a valid // literal. if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { - return Printf("%s (%s)", constant->literal().ToString(), - ShapeUtil::HumanString(constant->shape())); + return StrFormat("%s (%s)", constant->literal().ToString(), + ShapeUtil::HumanString(constant->shape())); } // Otherwise, print e.g. "%constant.42 (s32[100])". 
string constant_name; - if (tensorflow::str_util::StartsWith(constant->name(), "constant")) { + if (absl::StartsWith(constant->name(), "constant")) { constant_name = constant->name(); } else { constant_name = StrCat("constant ", constant->name()); } - return Printf("%s %s", constant_name, - ShapeUtil::HumanString(constant->shape())); + return StrFormat("%s %s", constant_name, + ShapeUtil::HumanString(constant->shape())); }; std::vector<string> lines; @@ -881,7 +861,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( TryGetFusionParameterConstant(operand)) { operand_str = stringify_constant(constant); } else { - operand_str = Printf("Parameter %lld", operand->parameter_number()); + operand_str = StrFormat("Parameter %d", operand->parameter_number()); } } else { operand_str = operand->name(); @@ -890,13 +870,13 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( if (operand_str) { if (instr->operand_count() > 1) { - lines.push_back(Printf("<b>operand %lld</b> = %s", i, *operand_str)); + lines.push_back(StrFormat("<b>operand %d</b> = %s", i, *operand_str)); } else { - lines.push_back(Printf("<b>operand</b> = %s", *operand_str)); + lines.push_back(StrFormat("<b>operand</b> = %s", *operand_str)); } } } - return Join(lines, "<br/>"); + return StrJoin(lines, "<br/>"); } ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { @@ -1049,6 +1029,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kGray; case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kRecv: @@ -1059,7 +1040,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: - case HloOpcode::kHostCompute: case HloOpcode::kWhile: return kDarkGreen; case HloOpcode::kConstant: @@ -1080,14 +1060,13 @@ string 
HloDotDumper::GetInstructionNodeShape(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { // If we have a parameter, put the param number in the name. if (instr->opcode() == HloOpcode::kParameter) { - return Printf("<b>Parameter %lld</b>", instr->parameter_number()); + return StrFormat("<b>Parameter %d</b>", instr->parameter_number()); } // The HLO instruction name contains usually the opcode, e.g. "%add.42" is // an add instruction. In this case we render just the name. - if (tensorflow::str_util::StartsWith(instr->name(), - HloOpcodeString(instr->opcode()))) { - return Printf("<b>%s</b>", HtmlLikeStringSanitize(instr->name())); + if (absl::StartsWith(instr->name(), HloOpcodeString(instr->opcode()))) { + return StrFormat("<b>%s</b>", HtmlLikeStringSanitize(instr->name())); } string extended_opcode = StrCat(HloOpcodeString(instr->opcode()), @@ -1095,8 +1074,8 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { ? "" : StrCat(":", xla::ToString(instr->fusion_kind()))); // If the name does not contain the opcode, render both. 
- return Printf("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode), - HtmlLikeStringSanitize(instr->name())); + return StrFormat("<b>%s</b><br/>%s", HtmlLikeStringSanitize(extended_opcode), + HtmlLikeStringSanitize(instr->name())); } string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { @@ -1105,16 +1084,16 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name())); } if (!instr->metadata().op_type().empty()) { - lines.push_back(Printf( + lines.push_back(StrFormat( "op_type: %s", HtmlLikeStringSanitize(instr->metadata().op_type()))); } if (!instr->metadata().source_file().empty() && instr->metadata().source_line() != 0) { - lines.push_back(Printf("op_type: %s", instr->metadata().source_file(), - instr->metadata().source_line())); + lines.push_back(StrFormat("op_type: %s:%d", instr->metadata().source_file(), + instr->metadata().source_line())); } - return Join(lines, "<br/>"); + return StrJoin(lines, "<br/>"); } string HloDotDumper::GetInstructionNodeBackendConfig( @@ -1161,13 +1140,12 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { constexpr int kMaxShapeLen = 64; if (instr_shape.length() > kMaxShapeLen) { instr_shape = StrCat( - tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3), - "..."); + absl::string_view(instr_shape).substr(0, kMaxShapeLen - 3), "..."); } lines.push_back(instr_shape); } if (debug_options_.xla_hlo_graph_addresses()) { - lines.push_back(Printf("[%p]", instr)); + lines.push_back(StrFormat("[%p]", instr)); } if (profile_ != nullptr) { double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr); @@ -1175,11 +1153,11 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { profile_->total_cycles_executed(*instr->parent()); if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { lines.push_back( - Printf("%% of cycles executed=%.2f", - 100 * 
hlo_cycles_executed / total_cycles_executed)); + StrFormat("%% of cycles executed=%.2f", + 100 * hlo_cycles_executed / total_cycles_executed)); } } - return Join(lines, "<br/>"); + return StrJoin(lines, "<br/>"); } // Gets the total number of array elements in the given shape. For tuples, this @@ -1211,7 +1189,8 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { string edge_label; if (instr->operand_count() > 1 && !control_edge) { - edge_label = Printf(R"( headlabel="%lld", labeldistance=2)", operand_num); + edge_label = + StrFormat(R"( headlabel="%d", labeldistance=2)", operand_num); } else if (control_edge) { edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\""; } @@ -1221,10 +1200,11 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { // means. bool is_big_array = TotalElementsInShape(from->shape()) >= 4096; - const char* kEdgeFmt = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; - edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to), - (is_big_array ? "normal" : "empty"), from->name(), - to->name(), edge_label)); + constexpr char kEdgeFmt[] = + R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; + edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to), + (is_big_array ? "normal" : "empty"), + from->name(), to->name(), edge_label)); }; // Add edges from instr's operands to instr. 
Parameters within fusion @@ -1265,14 +1245,14 @@ string HloDotDumper::GetInstructionTrivialComputationStr( continue; } if (instr->called_computations().size() == 1) { - lines.push_back(Printf("Subcomputation: <b>%s</b>", - HtmlLikeStringSanitize(*computation_type))); + lines.push_back(StrFormat("Subcomputation: <b>%s</b>", + HtmlLikeStringSanitize(*computation_type))); } else { - lines.push_back(Printf("Subcomputation %lld: <b>%s</b>", i, - HtmlLikeStringSanitize(*computation_type))); + lines.push_back(StrFormat("Subcomputation %d: <b>%s</b>", i, + HtmlLikeStringSanitize(*computation_type))); } } - return Join(lines, "<br/>"); + return StrJoin(lines, "<br/>"); } const HloInstruction* HloDotDumper::GetNodeForEdge( diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 1d7a062c55..064c53252c 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -23,12 +24,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace { -using ::tensorflow::strings::StrCat; +using absl::StrCat; using ::testing::HasSubstr; string TestName() { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 57e75cf931..ed4e159910 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -21,10 +21,17 @@ limitations under the License. #include <unordered_set> #include <utility> +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -39,17 +46,15 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using tensorflow::str_util::CEscape; -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::CEscape; +using absl::StrAppend; +using absl::StrCat; +using absl::StrJoin; /* static */ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( @@ -224,7 +229,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( Literal::CreateFromProto(proto.literal())); instruction = CreateConstant(std::move(literal)); } else { - instruction = MakeUnique<HloConstantInstruction>(proto.shape()); + instruction = absl::make_unique<HloConstantInstruction>(proto.shape()); } break; } @@ -294,15 +299,15 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "CrossReplicaSum should have 1 called computation but sees " << proto.called_computation_ids_size(); - tensorflow::gtl::optional<int64> all_reduce_id; + absl::optional<int64> all_reduce_id; if (proto.all_reduce_id() > 0) { all_reduce_id = proto.all_reduce_id(); } instruction = CreateCrossReplicaSum( proto.shape(), all_operands(), computations(0), - /*replica_group_ids=*/ - std::vector<int64>(proto.replica_group_ids().begin(), - proto.replica_group_ids().end()), + /*replica_groups=*/ + std::vector<ReplicaGroup>(proto.replica_groups().begin(), + proto.replica_groups().end()), /*barrier=*/proto.cross_replica_sum_barrier(), /*all_reduce_id=*/all_reduce_id); break; @@ -312,8 +317,18 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( proto.shape(), all_operands(), 
/*replica_groups=*/ std::vector<ReplicaGroup>(proto.replica_groups().begin(), - proto.replica_groups().end()), - /*barrier=*/proto.cross_replica_sum_barrier()); + proto.replica_groups().end())); + break; + } + case HloOpcode::kCollectivePermute: { + std::vector<std::pair<int64, int64>> source_target_pairs( + proto.source_target_pairs_size()); + for (int i = 0; i < source_target_pairs.size(); i++) { + source_target_pairs[i].first = proto.source_target_pairs(i).source(); + source_target_pairs[i].second = proto.source_target_pairs(i).target(); + } + instruction = CreateCollectivePermute(proto.shape(), operands(0), + source_target_pairs); break; } case HloOpcode::kConvolution: @@ -361,11 +376,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( proto.convolution_dimension_numbers()); } break; - case HloOpcode::kHostCompute: - instruction = - CreateHostCompute(proto.shape(), all_operands(), proto.channel_name(), - proto.cost_estimate_ns()); - break; case HloOpcode::kPad: TF_RET_CHECK(proto.operand_ids_size() == 2) << "Pad instruction should have 2 operands but sees " @@ -379,7 +389,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( << "DynamicSlice instruction should have 2 operands but sees " << proto.operand_ids_size(); std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size()); - c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); + absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin()); instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1), slice_sizes); break; @@ -391,7 +401,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.has_gather_dimension_numbers()) << "Gather instruction should have GatherDimensionNumbers set."; std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers = - MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers()); + absl::make_unique<GatherDimensionNumbers>( + 
proto.gather_dimension_numbers()); std::vector<int64> gather_slice_sizes; for (int64 bound : proto.gather_slice_sizes()) { gather_slice_sizes.push_back(bound); @@ -409,15 +420,22 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.called_computation_ids_size() == 1) << "Scatter instruction should have 1 called computation but sees " << proto.called_computation_ids_size(); - auto scatter_dimension_numbers = MakeUnique<ScatterDimensionNumbers>( - proto.scatter_dimension_numbers()); + auto scatter_dimension_numbers = + absl::make_unique<ScatterDimensionNumbers>( + proto.scatter_dimension_numbers()); instruction = CreateScatter(proto.shape(), operands(0), operands(1), operands(2), computations(0), *scatter_dimension_numbers); break; } + case HloOpcode::kIota: + TF_RET_CHECK(proto.dimensions_size() <= 1) + << "Iota instruction should have at most 1 dimension but sees " + << proto.dimensions_size(); + instruction = CreateIota(proto.shape(), proto.dimensions(0)); + break; default: { - instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); + instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) << "No instruction with id " << operand_id; @@ -445,10 +463,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); + instruction->precision_config_ = proto.precision_config(); if (proto.has_dot_dimension_numbers()) { instruction->dot_dimension_numbers_ = - MakeUnique<DotDimensionNumbers>(proto.dot_dimension_numbers()); + absl::make_unique<DotDimensionNumbers>(proto.dot_dimension_numbers()); } if (proto.has_sharding()) { @@ -462,34 +481,36 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( /* static */ 
std::unique_ptr<HloInstruction> HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { - return MakeUnique<HloParameterInstruction>(parameter_number, shape, name); + return absl::make_unique<HloParameterInstruction>(parameter_number, shape, + name); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace( const string& tag, HloInstruction* operand) { - return MakeUnique<HloTraceInstruction>(tag, operand); + return absl::make_unique<HloTraceInstruction>(tag, operand); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant( std::unique_ptr<Literal> literal) { - return MakeUnique<HloConstantInstruction>(std::move(literal)); + return absl::make_unique<HloConstantInstruction>(std::move(literal)); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota( - const Shape& shape) { - return WrapUnique(new HloInstruction(HloOpcode::kIota, shape)); + const Shape& shape, int64 iota_dimension) { + return absl::make_unique<HloIotaInstruction>(shape, iota_dimension); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { - return MakeUnique<HloGetTupleElementInstruction>(shape, operand, index); + return absl::make_unique<HloGetTupleElementInstruction>(shape, operand, + index); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng( const Shape& shape, RandomDistribution distribution, tensorflow::gtl::ArraySlice<HloInstruction*> parameters) { - return MakeUnique<HloRngInstruction>(shape, distribution, parameters); + return absl::make_unique<HloRngInstruction>(shape, distribution, parameters); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary( @@ -499,7 +520,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, // It is impossible to copy an opaque shape, we don't know how big it is. 
CHECK(!ShapeUtil::IsOpaque(shape)); } - auto instruction = WrapUnique(new HloInstruction(opcode, shape)); + auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); for (auto operand : operands) { instruction->AppendOperand(operand); } @@ -604,31 +625,33 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* map_computation) { - return MakeUnique<HloMapInstruction>(shape, operands, map_computation); + return absl::make_unique<HloMapInstruction>(shape, operands, map_computation); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count) { - return MakeUnique<HloConvolutionInstruction>( + return absl::make_unique<HloConvolutionInstruction>( shape, lhs, rhs, window, dimension_numbers, feature_group_count); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, tensorflow::gtl::ArraySlice<int64> fft_length) { - return MakeUnique<HloFftInstruction>(shape, operand, fft_type, fft_length); + return absl::make_unique<HloFftInstruction>(shape, operand, fft_type, + fft_length); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); instruction->AppendOperand(lhs); instruction->AppendOperand(rhs); instruction->dot_dimension_numbers_ = - MakeUnique<DotDimensionNumbers>(dimension_numbers); + absl::make_unique<DotDimensionNumbers>(dimension_numbers); 
return instruction; } @@ -637,10 +660,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); instruction->AppendOperand(lhs); instruction->AppendOperand(rhs); - instruction->dot_dimension_numbers_ = MakeUnique<DotDimensionNumbers>(); + instruction->dot_dimension_numbers_ = + absl::make_unique<DotDimensionNumbers>(); instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1); instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0); return instruction; @@ -651,7 +676,7 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits) { - return MakeUnique<HloReducePrecisionInstruction>( + return absl::make_unique<HloReducePrecisionInstruction>( shape, operand, exponent_bits, mantissa_bits); } @@ -659,40 +684,47 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction::CreateCrossReplicaSum( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids, - tensorflow::StringPiece barrier, - const tensorflow::gtl::optional<int64>& all_reduce_id) { - return MakeUnique<HloAllReduceInstruction>( - shape, operands, reduce_computation, replica_group_ids, barrier, + const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier, + const absl::optional<int64>& all_reduce_id) { + return absl::make_unique<HloAllReduceInstruction>( + shape, operands, reduce_computation, replica_groups, barrier, all_reduce_id); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - const std::vector<ReplicaGroup>& 
replica_groups, - tensorflow::StringPiece barrier) { - return MakeUnique<HloAllToAllInstruction>(shape, operands, replica_groups, - barrier); + const std::vector<ReplicaGroup>& replica_groups) { + return absl::make_unique<HloAllToAllInstruction>(shape, operands, + replica_groups); +} + +/* static */ std::unique_ptr<HloInstruction> +HloInstruction::CreateCollectivePermute( + const Shape& shape, HloInstruction* operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs) { + return absl::make_unique<HloCollectivePermuteInstruction>( + shape, operand, source_target_pairs); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed( const Shape& infeed_shape, HloInstruction* token_operand, const string& config) { - return MakeUnique<HloInfeedInstruction>(infeed_shape, token_operand, config); + return absl::make_unique<HloInfeedInstruction>(infeed_shape, token_operand, + config); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) { - return MakeUnique<HloOutfeedInstruction>(outfeed_shape, operand, - token_operand, outfeed_config); + HloInstruction* token_operand, absl::string_view outfeed_config) { + return absl::make_unique<HloOutfeedInstruction>( + outfeed_shape, operand, token_operand, outfeed_config); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend( HloInstruction* operand, HloInstruction* token, int64 channel_id, bool is_host_transfer) { - return MakeUnique<HloSendInstruction>(operand, token, channel_id, - is_host_transfer); + return absl::make_unique<HloSendInstruction>(operand, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone( @@ -700,14 +732,15 @@ HloInstruction::CreateCrossReplicaSum( auto send_operand = DynCast<HloSendInstruction>(operand); CHECK(send_operand != nullptr) << "SendDone 
must take the context operand from Send"; - return MakeUnique<HloSendDoneInstruction>(send_operand, is_host_transfer); + return absl::make_unique<HloSendDoneInstruction>(send_operand, + is_host_transfer); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv( const Shape& shape, HloInstruction* token, int64 channel_id, bool is_host_transfer) { - return MakeUnique<HloRecvInstruction>(shape, token, channel_id, - is_host_transfer); + return absl::make_unique<HloRecvInstruction>(shape, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone( @@ -715,19 +748,20 @@ HloInstruction::CreateCrossReplicaSum( auto recv_operand = DynCast<HloRecvInstruction>(operand); CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - return MakeUnique<HloRecvDoneInstruction>(recv_operand, is_host_transfer); + return absl::make_unique<HloRecvDoneInstruction>(recv_operand, + is_host_transfer); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions) { - return MakeUnique<HloReverseInstruction>(shape, operand, dimensions); + return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll( tensorflow::gtl::ArraySlice<HloInstruction*> operands) { CHECK(!operands.empty()); - auto instruction = WrapUnique( + auto instruction = absl::WrapUnique( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); for (auto operand : operands) { instruction->AppendOperand(operand); @@ -736,14 +770,15 @@ HloInstruction::CreateCrossReplicaSum( } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() { - return WrapUnique( + return absl::WrapUnique( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); } /* static */ std::unique_ptr<HloInstruction> 
HloInstruction::CreateWhile( const Shape& shape, HloComputation* condition, HloComputation* body, HloInstruction* init) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape)); instruction->AppendOperand(init); // Body comes before condition computation in the vector. instruction->called_computations_.push_back(body); @@ -756,7 +791,7 @@ HloInstruction::CreateCrossReplicaSum( HloInstruction* true_computation_arg, HloComputation* true_computation, HloInstruction* false_computation_arg, HloComputation* false_computation) { auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); instruction->AppendOperand(pred); instruction->AppendOperand(true_computation_arg); instruction->AppendOperand(false_computation_arg); @@ -773,15 +808,15 @@ HloInstruction::CreateCrossReplicaSum( tensorflow::gtl::ArraySlice<int64> start_indices, tensorflow::gtl::ArraySlice<int64> limit_indices, tensorflow::gtl::ArraySlice<int64> strides) { - return MakeUnique<HloSliceInstruction>(shape, operand, start_indices, - limit_indices, strides); + return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices, + limit_indices, strides); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice( const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, tensorflow::gtl::ArraySlice<int64> slice_sizes) { - return MakeUnique<HloDynamicSliceInstruction>(shape, operand, start_indices, - slice_sizes); + return absl::make_unique<HloDynamicSliceInstruction>( + shape, operand, start_indices, slice_sizes); } /* static */ std::unique_ptr<HloInstruction> @@ -789,8 +824,8 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction* operand, HloInstruction* update, HloInstruction* start_indices) { - auto instruction = - WrapUnique(new 
HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); + auto instruction = absl::WrapUnique( + new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape)); instruction->AppendOperand(operand); instruction->AppendOperand(update); instruction->AppendOperand(start_indices); @@ -800,12 +835,14 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, int64 dimension) { - return MakeUnique<HloConcatenateInstruction>(shape, operands, dimension); + return absl::make_unique<HloConcatenateInstruction>(shape, operands, + dimension); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert( const Shape& shape, HloInstruction* operand) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvert, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape)); instruction->AppendOperand(operand); return instruction; } @@ -814,7 +851,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, HloInstruction::CreateBitcastConvert(const Shape& shape, HloInstruction* operand) { auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape)); instruction->AppendOperand(operand); return instruction; } @@ -823,7 +860,7 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, const Shape& shape, HloInstruction* operand, HloInstruction* init_value, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, HloComputation* reduce_computation) { - auto instruction = WrapUnique(new HloReduceInstruction( + auto instruction = absl::WrapUnique(new HloReduceInstruction( shape, {operand, init_value}, dimensions_to_reduce, reduce_computation)); return std::move(instruction); } @@ -837,15 +874,15 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, 
all_args.reserve(operands.size() * 2); all_args.insert(all_args.end(), operands.begin(), operands.end()); all_args.insert(all_args.end(), init_values.begin(), init_values.end()); - return MakeUnique<HloReduceInstruction>(shape, all_args, dimensions_to_reduce, - reduce_computation); + return absl::make_unique<HloReduceInstruction>( + shape, all_args, dimensions_to_reduce, reduce_computation); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow( const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation) { - return MakeUnique<HloReduceWindowInstruction>(shape, operand, init_value, - window, reduce_computation); + return absl::make_unique<HloReduceWindowInstruction>( + shape, operand, init_value, window, reduce_computation); } /* static */ std::unique_ptr<HloInstruction> @@ -854,7 +891,7 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape, HloInstruction* scale, HloInstruction* offset, float epsilon, int64 feature_index) { - return MakeUnique<HloBatchNormTrainingInstruction>( + return absl::make_unique<HloBatchNormTrainingInstruction>( shape, operand, scale, offset, epsilon, feature_index); } @@ -863,7 +900,7 @@ HloInstruction::CreateBatchNormInference( const Shape& shape, HloInstruction* operand, HloInstruction* scale, HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, float epsilon, int64 feature_index) { - return MakeUnique<HloBatchNormInferenceInstruction>( + return absl::make_unique<HloBatchNormInferenceInstruction>( shape, operand, scale, offset, mean, variance, epsilon, feature_index); } @@ -873,9 +910,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand, HloInstruction* variance, HloInstruction* grad_output, float epsilon, int64 feature_index) { - return MakeUnique<HloBatchNormGradInstruction>(shape, operand, scale, mean, - variance, grad_output, epsilon, - feature_index); + return 
absl::make_unique<HloBatchNormGradInstruction>( + shape, operand, scale, mean, variance, grad_output, epsilon, + feature_index); } /* static */ std::unique_ptr<HloInstruction> @@ -883,15 +920,15 @@ HloInstruction::CreateSelectAndScatter( const Shape& shape, HloInstruction* operand, HloComputation* select, const Window& window, HloInstruction* source, HloInstruction* init_value, HloComputation* scatter) { - return MakeUnique<HloSelectAndScatterInstruction>( + return absl::make_unique<HloSelectAndScatterInstruction>( shape, operand, select, window, source, init_value, scatter); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { - return MakeUnique<HloBroadcastInstruction>(shape, operand, - broadcast_dimensions); + return absl::make_unique<HloBroadcastInstruction>(shape, operand, + broadcast_dimensions); } /* static */ std::unique_ptr<HloInstruction> @@ -949,8 +986,8 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad( const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { - return MakeUnique<HloPadInstruction>(shape, operand, padding_value, - padding_config); + return absl::make_unique<HloPadInstruction>(shape, operand, padding_value, + padding_config); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape( @@ -959,7 +996,8 @@ HloInstruction::CreateBroadcastSequence( ShapeUtil::ElementsIn(operand->shape())) << "shape: " << ShapeUtil::HumanString(shape) << " operand: " << ShapeUtil::HumanString(operand->shape()); - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kReshape, shape)); instruction->AppendOperand(operand); return instruction; } @@ -967,26 +1005,27 @@ 
HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions) { - return MakeUnique<HloTransposeInstruction>(shape, operand, dimensions); + return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort( const Shape& shape, int64 dimension, HloInstruction* keys, HloInstruction* values) { - return MakeUnique<HloSortInstruction>(shape, dimension, keys, values); + return absl::make_unique<HloSortInstruction>(shape, dimension, keys, values); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { - return MakeUnique<HloFusionInstruction>(shape, fusion_kind, fused_root); + return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, + fused_root); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* fusion_computation) { - return MakeUnique<HloFusionInstruction>(shape, fusion_kind, operands, - fusion_computation); + return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands, + fusion_computation); } void HloInstruction::set_single_sharding(const HloSharding& sharding) { @@ -1006,6 +1045,7 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->clear_sharding(); } derived_instruction->set_metadata(metadata_); + derived_instruction->set_precision_config(precision_config_); } bool HloInstruction::HasSideEffectNoRecurse() const { @@ -1018,7 +1058,6 @@ bool HloInstruction::HasSideEffectNoRecurse() const { case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: - case HloOpcode::kHostCompute: return true; case HloOpcode::kCrossReplicaSum: return 
all_reduce_id().has_value(); @@ -1044,7 +1083,7 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* computation) { std::unique_ptr<HloInstruction> instruction = - WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); + absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape)); for (auto operand : operands) { instruction->AppendOperand(operand); } @@ -1054,16 +1093,9 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - tensorflow::StringPiece custom_call_target) { - return MakeUnique<HloCustomCallInstruction>(shape, operands, - custom_call_target); -} - -/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateHostCompute( - const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { - return MakeUnique<HloHostComputeInstruction>(shape, operands, channel_name, - cost_estimate_ns); + absl::string_view custom_call_target) { + return absl::make_unique<HloCustomCallInstruction>(shape, operands, + custom_call_target); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple( @@ -1080,8 +1112,8 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, HloInstruction* operand, HloInstruction* start_indices, const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice<int64> slice_sizes) { - return MakeUnique<HloGatherInstruction>(shape, operand, start_indices, - gather_dim_numbers, slice_sizes); + return absl::make_unique<HloGatherInstruction>( + shape, operand, start_indices, gather_dim_numbers, slice_sizes); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter( @@ -1089,16 +1121,17 @@ bool HloInstruction::HasSideEffect() const { HloInstruction* scatter_indices, 
HloInstruction* updates, HloComputation* update_computation, const ScatterDimensionNumbers& scatter_dim_numbers) { - return MakeUnique<HloScatterInstruction>(shape, operand, scatter_indices, - updates, update_computation, - scatter_dim_numbers); + return absl::make_unique<HloScatterInstruction>( + shape, operand, scatter_indices, updates, update_computation, + scatter_dim_numbers); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain( const Shape& shape, HloInstruction* operand, std::unique_ptr<DomainMetadata> operand_side_metadata, std::unique_ptr<DomainMetadata> user_side_metadata) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); + auto instruction = + absl::WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); instruction->operand_side_metadata_ = std::move(operand_side_metadata); instruction->user_side_metadata_ = std::move(user_side_metadata); instruction->AppendOperand(operand); @@ -1146,13 +1179,13 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kReducePrecision: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: - case HloOpcode::kHostCompute: case HloOpcode::kPad: case HloOpcode::kDynamicSlice: case HloOpcode::kSort: @@ -1274,6 +1307,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( } break; } + // SetupDerivedInstruction will setup the precision_config_ field. SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); clone->set_raw_backend_config_string(backend_config_); @@ -1339,7 +1373,7 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone( // If names ends with .suffix[0-9]+ then replace with a suffix with the // numeric value incremented. 
int64 numeric_suffix; - if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) { clone->name_ = StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1); } else { @@ -1614,11 +1648,11 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kOutfeed: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kConvolution: case HloOpcode::kCustomCall: case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: - case HloOpcode::kHostCompute: case HloOpcode::kPad: case HloOpcode::kDynamicSlice: case HloOpcode::kGather: @@ -1812,7 +1846,7 @@ void HloInstruction::set_false_computation(HloComputation* false_computation) { string HloInstruction::SignatureString() const { string operands = - Join(operands_, ", ", [](string* out, HloInstruction* operand) { + StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) { StrAppend(out, ShapeUtil::HumanString(operand->shape())); }); return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); @@ -1832,7 +1866,7 @@ string HloInstruction::ToString(const HloPrintOptions& options) const { } bool HloInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional<int64>& operand_idx) const { + const absl::optional<int64>& operand_idx) const { switch (opcode_) { // Unary elementwise operations. case HloOpcode::kAbs: @@ -1959,7 +1993,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( slice.size() > kMaxOperandsToShowIfCompact) { slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } - operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) { // If operand is already been deleted, put `null` to the string output. 
if (operand == nullptr) { StrAppend(out, "null "); @@ -1979,7 +2013,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( } else if (!options.compact_operands()) { str.push_back(PrintName(operand->name(), options)); } - StrAppend(out, Join(str, " ")); + StrAppend(out, StrJoin(str, " ")); }); const int64 remaining = operands_.size() - slice.size(); if (slice.size() != operands_.size()) { @@ -1996,6 +2030,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString( extra.push_back(DotDimensionNumbersToString()); } + string precision_config_string = PrecisionConfigToString(); + if (!precision_config_string.empty()) { + extra.push_back(precision_config_string); + } + if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { if (opcode() == HloOpcode::kWhile) { @@ -2021,11 +2060,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { extra.push_back(StrCat( - "calls=", Join(called_computations(), ", ", - [&](string* out, const HloComputation* computation) { - StrAppend(out, - PrintName(computation->name(), options)); - }))); + "calls=", + StrJoin(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, PrintName(computation->name(), options)); + }))); } } else if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kFullBodies) { @@ -2058,12 +2097,12 @@ std::vector<string> HloInstruction::ExtraAttributesToString( break; default: if (!called_computations().empty()) { - extra.push_back( - StrCat("calls=\n", - Join(called_computations(), ", ", - [&](string* out, const HloComputation* computation) { - StrAppend(out, computation->ToString(new_options)); - }))); + extra.push_back(StrCat( + "calls=\n", + StrJoin(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, 
computation->ToString(new_options)); + }))); } break; } @@ -2074,11 +2113,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString( } if (!control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", - Join(control_predecessors_, ", ", - [&](string* out, HloInstruction* pre) { - StrAppend(out, - PrintName(pre->name(), options)); - }), + StrJoin(control_predecessors_, ", ", + [&](string* out, HloInstruction* pre) { + StrAppend(out, + PrintName(pre->name(), options)); + }), "}")); } if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { @@ -2092,10 +2131,10 @@ std::vector<string> HloInstruction::ExtraAttributesToString( string HloInstruction::ToShortString() const { return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(", - Join(operands_, ", ", - [](string* out, HloInstruction* operand) { - StrAppend(out, "%", operand->name()); - }), + StrJoin(operands_, ", ", + [](string* out, HloInstruction* operand) { + StrAppend(out, "%", operand->name()); + }), ")"); } @@ -2117,6 +2156,7 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_metadata() = metadata_; proto.set_backend_config(backend_config_); + *proto.mutable_precision_config() = precision_config_; if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { proto.add_called_computation_ids(computation->unique_id()); @@ -2155,7 +2195,7 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } -bool HloInstruction::IsFusable() const { +bool HloInstruction::IsFusible() const { // Instructions which are traced should not be fused. 
if (tracing()) { return false; @@ -2261,6 +2301,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleCrossReplicaSum(this); case HloOpcode::kAllToAll: return visitor->HandleAllToAll(this); + case HloOpcode::kCollectivePermute: + return visitor->HandleCollectivePermute(this); case HloOpcode::kTuple: return visitor->HandleTuple(this); case HloOpcode::kMap: @@ -2329,8 +2371,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleInfeed(this); case HloOpcode::kOutfeed: return visitor->HandleOutfeed(this); - case HloOpcode::kHostCompute: - return visitor->HandleHostCompute(this); case HloOpcode::kRng: return visitor->HandleRng(this); case HloOpcode::kWhile: @@ -2369,15 +2409,14 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return InternalError( "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - " "please file a bug for XLA.", - HloOpcodeString(opcode_).c_str()); + HloOpcodeString(opcode_)); } // Explicit instantiations. template Status HloInstruction::Visit(DfsHloVisitor* visitor); template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor); -using DFSStack = - tensorflow::gtl::InlinedVector<std::pair<int, HloInstruction*>, 16>; +using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>; // Push "child" onto the dfs_stack if not already visited. Returns false if a // cycle was detected, and true otherwise. 
@@ -2453,7 +2492,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) { return FailedPrecondition( "A cycle is detected while visiting instruction %s", - current_node->ToString().c_str()); + current_node->ToString()); } } @@ -2462,7 +2501,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) { return FailedPrecondition( "A cycle is detected while visiting instruction %s", - current_node->ToString().c_str()); + current_node->ToString()); } } } @@ -2622,7 +2661,7 @@ bool HloInstruction::IsElementwiseBinary() const { } bool HloInstruction::IsElementwise() const { - return IsElementwiseImpl(tensorflow::gtl::nullopt); + return IsElementwiseImpl(absl::nullopt); } bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { @@ -2778,7 +2817,7 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind( if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } - return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str()); + return InvalidArgument("Unknown fusion kind: %s", kind_name); } string PaddingConfigToString(const PaddingConfig& padding) { @@ -2787,7 +2826,7 @@ string PaddingConfigToString(const PaddingConfig& padding) { [](const PaddingConfig::PaddingConfigDimension& dim) { return dim.interior_padding() != 0; }); - return Join( + return StrJoin( padding.dimensions(), "x", [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) { StrAppend( @@ -2811,11 +2850,15 @@ string OpMetadataToString(const OpMetadata& metadata) { if (metadata.source_line() != 0) { result.push_back(StrCat("source_line=", metadata.source_line())); } - return Join(result, " "); + return StrJoin(result, " "); } string RandomDistributionToString(const RandomDistribution& distribution) { - return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); + return 
absl::AsciiStrToLower(RandomDistribution_Name(distribution)); +} + +string PrecisionToString(const PrecisionConfigProto::Precision& precision) { + return absl::AsciiStrToLower(PrecisionConfigProto::Precision_Name(precision)); } string ConvolutionDimensionNumbersToString( @@ -2843,8 +2886,8 @@ string ConvolutionDimensionNumbersToString( output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } - return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->", - Join(output_dims, "")); + return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->", + StrJoin(output_dims, "")); } string HloInstruction::DotDimensionNumbersToString() const { @@ -2855,19 +2898,21 @@ string HloInstruction::DotDimensionNumbersToString() const { const DotDimensionNumbers& dnums = *dot_dimension_numbers_; if (!dnums.lhs_batch_dimensions().empty()) { result.push_back(StrCat("lhs_batch_dims={", - Join(dnums.lhs_batch_dimensions(), ","), "}")); + StrJoin(dnums.lhs_batch_dimensions(), ","), "}")); } result.push_back(StrCat("lhs_contracting_dims={", - Join(dnums.lhs_contracting_dimensions(), ","), "}")); + StrJoin(dnums.lhs_contracting_dimensions(), ","), + "}")); if (!dnums.rhs_batch_dimensions().empty()) { result.push_back(StrCat("rhs_batch_dims={", - Join(dnums.rhs_batch_dimensions(), ","), "}")); + StrJoin(dnums.rhs_batch_dimensions(), ","), "}")); } result.push_back(StrCat("rhs_contracting_dims={", - Join(dnums.rhs_contracting_dimensions(), ","), "}")); + StrJoin(dnums.rhs_contracting_dimensions(), ","), + "}")); - return Join(result, ", "); + return StrJoin(result, ", "); } StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { @@ -2881,7 +2926,44 @@ StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { } return map; }(); - auto found = map->find(tensorflow::str_util::Lowercase(name)); + auto found = map->find(absl::AsciiStrToLower(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return 
found->second; +} + +string HloInstruction::PrecisionConfigToString() const { + if (precision_config_.operand_precision().empty()) { + return ""; + } + return StrCat( + "operand_precision={", + StrJoin(precision_config_.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfigProto::Precision_IsValid(precision)) + << precision; + StrAppend(out, PrecisionToString( + static_cast<PrecisionConfigProto::Precision>( + precision))); + }), + "}"); +} + +StatusOr<PrecisionConfigProto::Precision> StringToPrecision( + const string& name) { + static std::unordered_map<string, PrecisionConfigProto::Precision>* map = [] { + static auto* map = + new std::unordered_map<string, PrecisionConfigProto::Precision>; + for (int i = 0; i < PrecisionConfigProto::Precision_ARRAYSIZE; i++) { + if (PrecisionConfigProto::Precision_IsValid(i)) { + auto value = static_cast<PrecisionConfigProto::Precision>(i); + (*map)[PrecisionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(absl::AsciiStrToLower(name)); if (found == map->end()) { return InvalidArgument("Unknown distribution"); } @@ -3131,31 +3213,25 @@ const string& HloInstruction::outfeed_config() const { return Cast<HloOutfeedInstruction>(this)->outfeed_config(); } -const std::vector<int64>& HloInstruction::replica_group_ids() const { - return Cast<HloAllReduceInstruction>(this)->replica_group_ids(); +const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const { + return Cast<HloCollectiveInstruction>(this)->replica_groups(); } -const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const { - return Cast<HloAllToAllInstruction>(this)->replica_groups(); +const std::vector<std::pair<int64, int64>>& +HloInstruction::source_target_pairs() const { + return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs(); } string HloInstruction::cross_replica_sum_barrier() const { - if (opcode() == HloOpcode::kCrossReplicaSum) { - return 
Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier(); - } - return Cast<HloAllToAllInstruction>(this)->cross_replica_sum_barrier(); + return Cast<HloAllReduceInstruction>(this)->cross_replica_sum_barrier(); } void HloInstruction::set_cross_replica_sum_barrier(const string& barrier) { - if (opcode() == HloOpcode::kCrossReplicaSum) { - return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier( - barrier); - } - return Cast<HloAllToAllInstruction>(this)->set_cross_replica_sum_barrier( + return Cast<HloAllReduceInstruction>(this)->set_cross_replica_sum_barrier( barrier); } -tensorflow::gtl::optional<int64> HloInstruction::all_reduce_id() const { +absl::optional<int64> HloInstruction::all_reduce_id() const { return Cast<HloAllReduceInstruction>(this)->all_reduce_id(); } @@ -3205,10 +3281,6 @@ const string& HloInstruction::custom_call_target() const { return Cast<HloCustomCallInstruction>(this)->custom_call_target(); } -const string& HloInstruction::channel_name() const { - return Cast<HloHostComputeInstruction>(this)->channel_name(); -} - const PaddingConfig& HloInstruction::padding_config() const { return Cast<HloPadInstruction>(this)->padding_config(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 8d8f149ee3..4a424cebc0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -32,6 +32,10 @@ limitations under the License. #include <unordered_set> #include <vector> +#include "absl/container/inlined_vector.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" @@ -45,10 +49,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -101,6 +103,7 @@ class HloPrintOptions { return HloPrintOptions() .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies) .set_print_metadata(false) + .set_print_backend_config(false) .set_compact_operands(true) .set_print_operand_shape(true) .set_print_program_shape(false) @@ -182,7 +185,7 @@ class HloPrintOptions { return print_subcomputation_mode_; } bool print_metadata() const { return print_metadata_; } - bool print_backend_config() const { return print_metadata_; } + bool print_backend_config() const { return print_backend_config_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } bool print_program_shape() const { return print_program_shape_; } @@ -220,7 +223,7 @@ class CanonicalNameMap { return iter->second; } - string new_name = tensorflow::strings::StrCat("tmp_", index++); + string new_name = absl::StrCat("tmp_", index++); canonical_name_map[old_name] = new_name; return new_name; } @@ -347,7 +350,8 @@ class HloInstruction { std::unique_ptr<Literal> literal); // Creates an Iota instruction. - static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape); + static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape, + int64 iota_dimension); // Creates a get tuple element instruction. static std::unique_ptr<HloInstruction> CreateGetTupleElement( @@ -433,9 +437,10 @@ class HloInstruction { // // `reduction_computation`: the reduction function. 
// - // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all - // replicas belong to one group. Allreduce will be applied within subgroups. - // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // `replica_groups`: each ReplicaGroup contains a list of replica id. If + // empty, all replicas belong to one group in the order of 0 - (n-1). + // Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // // `all_reduce_id`: for Allreduce nodes from different modules, if they have @@ -446,9 +451,8 @@ class HloInstruction { static std::unique_ptr<HloInstruction> CreateCrossReplicaSum( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids, - tensorflow::StringPiece barrier, - const tensorflow::gtl::optional<int64>& all_reduce_id); + const std::vector<ReplicaGroup>& replica_groups, + absl::string_view barrier, const absl::optional<int64>& all_reduce_id); // This op handles the communication of an Alltoall operation. On each core, // the operands are N ops in the same shape, where N is the number of cores @@ -463,12 +467,18 @@ class HloInstruction { // within replica 1, 2, 3, and in the gather phase, the received blocks will // be concatenated in the order of 1, 2, 3; another Alltoall will be applied // within replica 4, 5, 0, and the concatenation order is 4, 5, 0. - // - // TODO(b/110096724): This is NOT YET ready to use. static std::unique_ptr<HloInstruction> CreateAllToAll( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - const std::vector<ReplicaGroup>& replica_groups, - tensorflow::StringPiece barrier); + const std::vector<ReplicaGroup>& replica_groups); + + // Creates a communitation instructions that permutes data cross replicas. 
+ // Data is sent/received according to the (source_replica_id, + // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a + // target_replica_id in any pair, the output on that replica is a tensor + // conssits of 0(s) in `shape`. + static std::unique_ptr<HloInstruction> CreateCollectivePermute( + const Shape& shape, HloInstruction* operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -493,7 +503,7 @@ class HloInstruction { // which is a TOKEN. static std::unique_ptr<HloInstruction> CreateOutfeed( const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config); + HloInstruction* token_operand, absl::string_view outfeed_config); // Creates an asynchronous send instruction with the given channel id, which // initiates sending the operand data to a unique receive instruction in @@ -706,13 +716,7 @@ class HloInstruction { // to the given operands. "shape" is the resultant shape. static std::unique_ptr<HloInstruction> CreateCustomCall( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - tensorflow::StringPiece custom_call_target); - - // Creates a HostCompute instruction, which records host-side control and - // data dependencies for use in instruction scheduling. - static std::unique_ptr<HloInstruction> CreateHostCompute( - const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); + absl::string_view custom_call_target); // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. @@ -766,7 +770,7 @@ class HloInstruction { int64 operand_count() const { return operands_.size(); } // Returns the vector of operands of this instruction. 
- using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>; + using InstructionVector = absl::InlinedVector<HloInstruction*, 2>; const InstructionVector& operands() const { return operands_; } // Returns the vector of unique operands, in the same order they are found @@ -863,6 +867,11 @@ class HloInstruction { return false; } + if (!ContainersEqual(precision_config_.operand_precision(), + other.precision_config_.operand_precision())) { + return false; + } + return IdenticalSlowPath(other, eq_computations); } @@ -1030,7 +1039,7 @@ class HloInstruction { // Returns true if this instruction can be legally fused into a fusion // instruction. - bool IsFusable() const; + bool IsFusible() const; // Returns the sharding applied to this operator. // REQUIRES: has_sharding() is true. @@ -1038,21 +1047,26 @@ class HloInstruction { CHECK(has_sharding()); return *sharding_; } + std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; } + // Returns the sharding applied to this operator, or default_ if none exists. const HloSharding& sharding_or_default(const HloSharding& default_) const { return sharding_ ? *sharding_ : default_; } // Returns the sharding unique device, if any. - tensorflow::gtl::optional<int64> sharding_unique_device() const { + absl::optional<int64> sharding_unique_device() const { if (sharding_ == nullptr) { - return tensorflow::gtl::optional<int64>(); + return absl::optional<int64>(); } return sharding_->UniqueDevice(); } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { - sharding_ = MakeUnique<HloSharding>(sharding); + sharding_ = std::make_shared<const HloSharding>(sharding); + } + void set_sharding(std::shared_ptr<const HloSharding> sharding) { + sharding_ = std::move(sharding); } void set_single_sharding(const HloSharding& sharding); // Sets a sharding that assigns the current instruction to device. 
@@ -1088,19 +1102,6 @@ class HloInstruction { // instruction. void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // TODO(b/80249101): Remove these methods once HLO scheduling and copy - // insertion are integrated, and we don't need to run a separate pass - // of copy elision anymore. - bool CopyElisionAllowed() const { - CHECK_EQ(HloOpcode::kCopy, opcode_); - return copy_elision_allowed_; - } - - void SetCopyElisionAllowed(bool value) { - CHECK_EQ(HloOpcode::kCopy, opcode_); - copy_elision_allowed_ = value; - } - // Returns data on the dimension numbers used for a dot operation. const DotDimensionNumbers& dot_dimension_numbers() const { CHECK(dot_dimension_numbers_ != nullptr); @@ -1110,6 +1111,9 @@ class HloInstruction { // Returns the dump string of the dot dimension numbers. string DotDimensionNumbersToString() const; + // Returns the dump string of the precision configuration. + string PrecisionConfigToString() const; + // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1253,6 +1257,20 @@ class HloInstruction { static StatusOr<string> BackendConfigToRawString( const tensorflow::protobuf::Message& proto); + // Returns the information used to tell the implementation information about + // what sort of precision is requested. The meaning of the field is backend + // specific. At the moment, it is only supported for kConvolution and kDot. + // Transformations on one kDot or kConvolution to another will preserve this + // information. Transformations to other HLOs will not preserve this + // information but it is presumed that the alternate lowering is strictly + // superior. 
+ const PrecisionConfigProto& precision_config() const { + return precision_config_; + } + void set_precision_config(const PrecisionConfigProto& precision_config) { + precision_config_ = precision_config; + } + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1421,18 +1439,18 @@ class HloInstruction { // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const; - // Delegates to HloAllReduceInstruction::replica_group_ids. - const std::vector<int64>& replica_group_ids() const; - - // Delegates to HloAllToAllInstruction::replica_groups. + // Delegates to HloCollectiveInstruction::replica_groups. const std::vector<ReplicaGroup>& replica_groups() const; + // Delegates to HloCollectivePermuteInstruction::source_target_pairs. + const std::vector<std::pair<int64, int64>>& source_target_pairs() const; + // Delegates to HloAllReduceInstruction::cross_replica_sum_barrier. string cross_replica_sum_barrier() const; void set_cross_replica_sum_barrier(const string& barrier); // Delegates to HloAllReduceInstruction::all_reduce_id. - tensorflow::gtl::optional<int64> all_reduce_id() const; + absl::optional<int64> all_reduce_id() const; // Returns data on the window in a windowed operation such as // convolution. @@ -1475,9 +1493,6 @@ class HloInstruction { // Delegates to HloCustomCallInstruction::custom_call_target. const string& custom_call_target() const; - // Delegates to HloHostComputeInstruction::channel_name. - const string& channel_name() const; - // Delegates to HloPadInstruction::padding_config. const PaddingConfig& padding_config() const; @@ -1565,7 +1580,7 @@ class HloInstruction { // NOTE: For all instructions other than kFusion, being elementwise on one of // the operands is equivalent to being elementwise on all the operands. 
virtual bool IsElementwiseImpl( - const tensorflow::gtl::optional<int64>& operand_idx) const; + const absl::optional<int64>& operand_idx) const; // Prints an instruction to a string. // // The canonical string representation needs to name operands and instruction @@ -1642,7 +1657,10 @@ class HloInstruction { bool copy_elision_allowed_ = true; // The sharding, if one exists. - std::unique_ptr<HloSharding> sharding_; + // Uses std::shared_ptr to allow reuse of the same sharding object between + // HloInstructions and other components as HloSharding can be very large for + // many element tuples. + std::shared_ptr<const HloSharding> sharding_; // Fields used by the kDomain instruction. std::unique_ptr<DomainMetadata> operand_side_metadata_; @@ -1661,6 +1679,10 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; + // Information used to communicate to the implementation about the algorithm + // used to produce results. See the documentation on precision_config(). + PrecisionConfigProto precision_config_; + // String identifier for instruction. 
string name_; @@ -1683,10 +1705,12 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); +string PrecisionToString(const PrecisionConfigProto::Precision& precision); string ConvolutionDimensionNumbersToString( const ConvolutionDimensionNumbers& dnums); StatusOr<RandomDistribution> StringToRandomDistribution(const string& name); +StatusOr<PrecisionConfigProto::Precision> StringToPrecision(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 504b13043f..8b0b90dfb3 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -53,7 +53,7 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { public: Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("not implemented %s", - HloOpcodeString(hlo_instruction->opcode()).c_str()); + HloOpcodeString(hlo_instruction->opcode())); } Status HandleParameter(HloInstruction* parameter) override { diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 4fdf4360e6..ffc74cfedd 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -17,6 +17,12 @@ limitations under the License. 
#include <deque> +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -27,10 +33,10 @@ limitations under the License. namespace xla { namespace { -using ::tensorflow::str_util::CEscape; -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::CEscape; +using absl::StrAppend; +using absl::StrCat; +using absl::StrJoin; bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, const HloInstruction* operand) { @@ -89,7 +95,7 @@ HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique<HloBatchNormTrainingInstruction>( + return absl::make_unique<HloBatchNormTrainingInstruction>( shape, new_operands[0], new_operands[1], new_operands[2], epsilon(), feature_index()); } @@ -111,7 +117,7 @@ HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); - return MakeUnique<HloBatchNormInferenceInstruction>( + return absl::make_unique<HloBatchNormInferenceInstruction>( shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); } @@ -133,7 +139,7 @@ HloBatchNormGradInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 5); - return MakeUnique<HloBatchNormGradInstruction>( + return 
absl::make_unique<HloBatchNormGradInstruction>( shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); } @@ -158,7 +164,7 @@ HloInstructionProto HloFftInstruction::ToProto() const { std::vector<string> HloFftInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return {StrCat("fft_type=", FftType_Name(fft_type())), - StrCat("fft_length={", Join(fft_length(), ","), "}")}; + StrCat("fft_length={", StrJoin(fft_length(), ","), "}")}; } bool HloFftInstruction::IdenticalSlowPath( @@ -175,8 +181,8 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique<HloFftInstruction>(shape, new_operands[0], fft_type_, - fft_length_); + return absl::make_unique<HloFftInstruction>(shape, new_operands[0], fft_type_, + fft_length_); } HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, @@ -230,8 +236,8 @@ std::unique_ptr<HloInstruction> HloSendInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique<HloSendInstruction>(new_operands[0], new_operands[1], - channel_id(), is_host_transfer()); + return absl::make_unique<HloSendInstruction>( + new_operands[0], new_operands[1], channel_id(), is_host_transfer()); } HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand, @@ -248,7 +254,7 @@ HloSendDoneInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique<HloSendDoneInstruction>( + return absl::make_unique<HloSendDoneInstruction>( Cast<HloSendInstruction>(new_operands[0]), is_host_transfer()); } @@ -269,7 +275,7 @@ std::unique_ptr<HloInstruction> 
HloRecvInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique<HloRecvInstruction>( + return absl::make_unique<HloRecvInstruction>( ShapeUtil::GetTupleElementShape(shape, 0), new_operands[0], channel_id(), is_host_transfer()); } @@ -291,31 +297,67 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique<HloRecvDoneInstruction>( + return absl::make_unique<HloRecvDoneInstruction>( Cast<HloRecvInstruction>(new_operands[0]), is_host_transfer()); } +HloCollectiveInstruction::HloCollectiveInstruction( + HloOpcode opcode, const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> operands, + const std::vector<ReplicaGroup>& replica_groups) + : HloInstruction(opcode, shape), replica_groups_(replica_groups) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloCollectiveInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_replica_groups() = {replica_groups_.begin(), + replica_groups_.end()}; + return proto; +} + +std::vector<string> HloCollectiveInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector<string> result; + std::vector<string> replica_group_str; + for (const ReplicaGroup& group : replica_groups()) { + replica_group_str.push_back( + StrCat("{", StrJoin(group.replica_ids(), ","), "}")); + } + result.push_back( + StrCat("replica_groups={", StrJoin(replica_group_str, ","), "}")); + return result; +} + +bool HloCollectiveInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + /*eq_computations*/) const { + const auto& casted_other = + static_cast<const 
HloCollectiveInstruction&>(other); + return ContainersEqual(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return ContainersEqual(a.replica_ids(), + b.replica_ids()); + }); +} + HloAllReduceInstruction::HloAllReduceInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids, - tensorflow::StringPiece barrier, - const tensorflow::gtl::optional<int64>& all_reduce_id) - : HloInstruction(HloOpcode::kCrossReplicaSum, shape), - replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()), - cross_replica_sum_barrier_(barrier.begin(), barrier.end()), + const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier, + const absl::optional<int64>& all_reduce_id) + : HloCollectiveInstruction(HloOpcode::kCrossReplicaSum, shape, operands, + replica_groups), + cross_replica_sum_barrier_(barrier), all_reduce_id_(all_reduce_id) { - for (auto operand : operands) { - AppendOperand(operand); - } AppendComputation(reduce_computation); } HloInstructionProto HloAllReduceInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); - for (int64 i : replica_group_ids_) { - proto.add_replica_group_ids(i); - } + HloInstructionProto proto = HloCollectiveInstruction::ToProto(); // Proto3 is so sad. 
if (all_reduce_id_) { proto.set_all_reduce_id(*all_reduce_id_); @@ -325,9 +367,9 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const { } std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl( - const HloPrintOptions& /*options*/) const { - std::vector<string> result = { - StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")}; + const HloPrintOptions& options) const { + std::vector<string> result = + HloCollectiveInstruction::ExtraAttributesToStringImpl(options); if (!cross_replica_sum_barrier().empty()) { result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); } @@ -342,7 +384,7 @@ bool HloAllReduceInstruction::IdenticalSlowPath( const std::function<bool(const HloComputation*, const HloComputation*)>& eq_computations) const { const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other); - return replica_group_ids() == casted_other.replica_group_ids() && + return HloCollectiveInstruction::IdenticalSlowPath(other, eq_computations) && eq_computations(to_apply(), casted_other.to_apply()) && cross_replica_sum_barrier() == casted_other.cross_replica_sum_barrier() && @@ -354,70 +396,76 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* /*context*/) const { - return MakeUnique<HloAllReduceInstruction>( - shape, new_operands, to_apply(), replica_group_ids(), + return absl::make_unique<HloAllReduceInstruction>( + shape, new_operands, to_apply(), replica_groups(), cross_replica_sum_barrier(), all_reduce_id()); } HloAllToAllInstruction::HloAllToAllInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - const std::vector<ReplicaGroup>& replica_groups, - tensorflow::StringPiece barrier) - : HloInstruction(HloOpcode::kAllToAll, shape), - replica_groups_(replica_groups), - cross_replica_sum_barrier_(barrier.begin(), barrier.end()) { - for (auto operand : operands) { - 
AppendOperand(operand); - } -} - -bool HloAllToAllInstruction::IdenticalSlowPath( - const HloInstruction& other, - const std::function<bool(const HloComputation*, const HloComputation*)>& - eq_computations) const { - const auto& casted_other = static_cast<const HloAllToAllInstruction&>(other); - return ContainersEqual(replica_groups(), casted_other.replica_groups(), - [](const ReplicaGroup& a, const ReplicaGroup& b) { - return ContainersEqual(a.replica_ids(), - b.replica_ids()); - }) && - cross_replica_sum_barrier() == - casted_other.cross_replica_sum_barrier(); -} + const std::vector<ReplicaGroup>& replica_groups) + : HloCollectiveInstruction(HloOpcode::kAllToAll, shape, operands, + replica_groups) {} std::unique_ptr<HloInstruction> HloAllToAllInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* /*context*/) const { - return MakeUnique<HloAllToAllInstruction>( - shape, new_operands, replica_groups(), cross_replica_sum_barrier()); + return absl::make_unique<HloAllToAllInstruction>(shape, new_operands, + replica_groups()); } -std::vector<string> HloAllToAllInstruction::ExtraAttributesToStringImpl( - const HloPrintOptions& options) const { - std::vector<string> result; - std::vector<string> replica_group_str; - for (const ReplicaGroup& group : replica_groups()) { - replica_group_str.push_back( - StrCat("{", Join(group.replica_ids(), ","), "}")); - } - result.push_back( - StrCat("replica_groups={", Join(replica_group_str, ","), "}")); +HloCollectivePermuteInstruction::HloCollectivePermuteInstruction( + const Shape& shape, HloInstruction* operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs) + : HloInstruction(HloOpcode::kCollectivePermute, shape), + source_target_pairs_(source_target_pairs) { + AppendOperand(operand); +} - if (!cross_replica_sum_barrier().empty()) { - result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); +HloInstructionProto 
HloCollectivePermuteInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (const auto& pair : source_target_pairs()) { + auto* proto_pair = proto.add_source_target_pairs(); + proto_pair->set_source(pair.first); + proto_pair->set_target(pair.second); } + return proto; +} +std::vector<string> +HloCollectivePermuteInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& /*options*/) const { + std::vector<string> result; + std::vector<string> strs; + for (const auto& pair : source_target_pairs()) { + strs.push_back(StrCat("{", pair.first, ",", pair.second, "}")); + } + result.push_back(StrCat("source_target_pairs={", StrJoin(strs, ","), "}")); return result; } -HloInstructionProto HloAllToAllInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); - *proto.mutable_replica_groups() = {replica_groups_.begin(), - replica_groups_.end()}; - proto.set_cross_replica_sum_barrier(cross_replica_sum_barrier_); - return proto; +bool HloCollectivePermuteInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + /*eq_computations*/) const { + const auto& casted_other = + static_cast<const HloCollectivePermuteInstruction&>(other); + return ContainersEqual( + source_target_pairs(), casted_other.source_target_pairs(), + [](const std::pair<int64, int64>& a, const std::pair<int64, int64>& b) { + return a == b; + }); +} + +std::unique_ptr<HloInstruction> +HloCollectivePermuteInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* /*context*/) const { + return absl::make_unique<HloCollectivePermuteInstruction>( + shape, new_operands[0], source_target_pairs()); } HloReverseInstruction::HloReverseInstruction( @@ -438,7 +486,7 @@ HloInstructionProto HloReverseInstruction::ToProto() const { std::vector<string> 
HloReverseInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloReverseInstruction::IdenticalSlowPath( @@ -454,8 +502,8 @@ std::unique_ptr<HloInstruction> HloReverseInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique<HloReverseInstruction>(shape, new_operands[0], - dimensions()); + return absl::make_unique<HloReverseInstruction>(shape, new_operands[0], + dimensions()); } HloConcatenateInstruction::HloConcatenateInstruction( @@ -477,7 +525,7 @@ HloInstructionProto HloConcatenateInstruction::ToProto() const { std::vector<string> HloConcatenateInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloConcatenateInstruction::IdenticalSlowPath( @@ -494,8 +542,8 @@ HloConcatenateInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { - return MakeUnique<HloConcatenateInstruction>(shape, new_operands, - dimensions(0)); + return absl::make_unique<HloConcatenateInstruction>(shape, new_operands, + dimensions(0)); } HloReduceInstruction::HloReduceInstruction( @@ -520,7 +568,7 @@ HloInstructionProto HloReduceInstruction::ToProto() const { std::vector<string> HloReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloReduceInstruction::IdenticalSlowPath( @@ -539,8 +587,8 @@ std::unique_ptr<HloInstruction> 
HloReduceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique<HloReduceInstruction>(shape, new_operands, dimensions(), - to_apply()); + return absl::make_unique<HloReduceInstruction>(shape, new_operands, + dimensions(), to_apply()); } HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension, @@ -563,7 +611,7 @@ HloInstructionProto HloSortInstruction::ToProto() const { std::vector<string> HloSortInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloSortInstruction::IdenticalSlowPath( @@ -580,7 +628,8 @@ std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { HloInstruction* keys = new_operands[0]; HloInstruction* values = new_operands.size() == 2 ? 
new_operands[1] : nullptr; - return MakeUnique<HloSortInstruction>(shape, dimensions(0), keys, values); + return absl::make_unique<HloSortInstruction>(shape, dimensions(0), keys, + values); } HloTransposeInstruction::HloTransposeInstruction( @@ -595,7 +644,7 @@ HloTransposeInstruction::HloTransposeInstruction( Permute(dimensions, shape.dimensions()).begin())) << "shape: " << ShapeUtil::HumanString(shape) << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << Join(dimensions, ", ") << "}"; + << ", dimensions: {" << StrJoin(dimensions, ", ") << "}"; AppendOperand(operand); } @@ -616,7 +665,7 @@ HloInstructionProto HloTransposeInstruction::ToProto() const { std::vector<string> HloTransposeInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloTransposeInstruction::IdenticalSlowPath( @@ -633,8 +682,8 @@ HloTransposeInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique<HloTransposeInstruction>(shape, new_operands[0], - dimensions()); + return absl::make_unique<HloTransposeInstruction>(shape, new_operands[0], + dimensions()); } HloBroadcastInstruction::HloBroadcastInstruction( @@ -655,7 +704,7 @@ HloInstructionProto HloBroadcastInstruction::ToProto() const { std::vector<string> HloBroadcastInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloBroadcastInstruction::IdenticalSlowPath( @@ -672,8 +721,8 @@ HloBroadcastInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); 
- return MakeUnique<HloBroadcastInstruction>(shape, new_operands[0], - dimensions()); + return absl::make_unique<HloBroadcastInstruction>(shape, new_operands[0], + dimensions()); } HloMapInstruction::HloMapInstruction( @@ -699,7 +748,7 @@ HloInstructionProto HloMapInstruction::ToProto() const { } bool HloMapInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional<int64>& operand_idx) const { + const absl::optional<int64>& operand_idx) const { if (!dimensions().empty()) { // Check that the map is executed in elementwise compatible dimensions. if (dimensions().size() != shape().dimensions_size()) { @@ -716,7 +765,7 @@ bool HloMapInstruction::IsElementwiseImpl( std::vector<string> HloMapInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; + return {StrCat("dimensions={", StrJoin(dimensions(), ","), "}")}; } bool HloMapInstruction::IdenticalSlowPath( @@ -730,7 +779,7 @@ std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { - return MakeUnique<HloMapInstruction>(shape, new_operands, to_apply()); + return absl::make_unique<HloMapInstruction>(shape, new_operands, to_apply()); } HloSliceInstruction::HloSliceInstruction( @@ -774,7 +823,7 @@ std::vector<string> HloSliceInstruction::ExtraAttributesToStringImpl( bounds.push_back( StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]")); } - return {StrCat("slice={", Join(bounds, ", "), "}")}; + return {StrCat("slice={", StrJoin(bounds, ", "), "}")}; } bool HloSliceInstruction::IdenticalSlowPath( @@ -792,8 +841,8 @@ std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique<HloSliceInstruction>(shape, 
new_operands[0], slice_starts_, - slice_limits_, slice_strides_); + return absl::make_unique<HloSliceInstruction>( + shape, new_operands[0], slice_starts_, slice_limits_, slice_strides_); } HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal) @@ -812,7 +861,7 @@ HloInstructionProto HloConstantInstruction::ToProto() const { } bool HloConstantInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional<int64>& operand_idx) const { + const absl::optional<int64>& operand_idx) const { return true; } @@ -845,7 +894,7 @@ HloConstantInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { - return MakeUnique<HloConstantInstruction>(literal_->CloneToUnique()); + return absl::make_unique<HloConstantInstruction>(literal_->CloneToUnique()); } string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( @@ -860,7 +909,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( // lines. Compact this into one line by stripping out white space. string tmp = literal().ToString(); std::replace(tmp.begin(), tmp.end(), '\n', ' '); - std::vector<string> v = tensorflow::str_util::Split(tmp, ' '); + std::vector<string> v = absl::StrSplit(tmp, ' '); bool first = true; // Concatenate elements in "v" with spaces separating them, but ignoring // empty entries. 
@@ -952,7 +1001,7 @@ HloInstructionProto HloFusionInstruction::ToProto() const { } bool HloFusionInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional<int64>& operand_idx) const { + const absl::optional<int64>& operand_idx) const { if (!operand_idx.has_value()) { for (auto* fused : fused_instructions()) { if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) { @@ -1155,7 +1204,7 @@ HloInstruction* HloFusionInstruction::FuseInstructionInternal( HloInstruction* HloFusionInstruction::CloneAndFuseInternal( HloInstruction* instruction_to_fuse, bool add_output) { - CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); + CHECK(instruction_to_fuse->IsFusible()) << instruction_to_fuse->ToString(); VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); HloInstruction* clone = nullptr; if (called_computations().empty()) { @@ -1339,8 +1388,8 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl( new_fused_computation = module->AddEmbeddedComputation( fused_instructions_computation()->Clone("clone", context)); } - return MakeUnique<HloFusionInstruction>(shape, fusion_kind(), new_operands, - new_fused_computation); + return absl::make_unique<HloFusionInstruction>( + shape, fusion_kind(), new_operands, new_fused_computation); } Status HloFusionInstruction::DeduplicateFusionOperands() { @@ -1384,7 +1433,7 @@ std::vector<string> HloRngInstruction::ExtraAttributesToStringImpl( } bool HloRngInstruction::IsElementwiseImpl( - const tensorflow::gtl::optional<int64>& operand_idx) const { + const absl::optional<int64>& operand_idx) const { return true; } @@ -1399,7 +1448,8 @@ std::unique_ptr<HloInstruction> HloRngInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { - return MakeUnique<HloRngInstruction>(shape, distribution_, new_operands); + return absl::make_unique<HloRngInstruction>(shape, 
distribution_, + new_operands); } HloParameterInstruction::HloParameterInstruction(int64 parameter_number, @@ -1435,7 +1485,8 @@ HloParameterInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { - return MakeUnique<HloParameterInstruction>(parameter_number_, shape, name()); + return absl::make_unique<HloParameterInstruction>(parameter_number_, shape, + name()); } HloGetTupleElementInstruction::HloGetTupleElementInstruction( @@ -1471,8 +1522,8 @@ HloGetTupleElementInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique<HloGetTupleElementInstruction>(shape, new_operands[0], - tuple_index()); + return absl::make_unique<HloGetTupleElementInstruction>( + shape, new_operands[0], tuple_index()); } HloReducePrecisionInstruction::HloReducePrecisionInstruction( @@ -1514,7 +1565,7 @@ HloReducePrecisionInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique<HloReducePrecisionInstruction>( + return absl::make_unique<HloReducePrecisionInstruction>( shape, new_operands[0], exponent_bits(), mantissa_bits()); } @@ -1555,16 +1606,17 @@ std::unique_ptr<HloInstruction> HloInfeedInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 1); - return MakeUnique<HloInfeedInstruction>(infeed_shape(), new_operands[0], - infeed_config()); + return absl::make_unique<HloInfeedInstruction>( + infeed_shape(), new_operands[0], infeed_config()); } -HloOutfeedInstruction::HloOutfeedInstruction( - const Shape& outfeed_shape, HloInstruction* operand, - HloInstruction* token_operand, tensorflow::StringPiece outfeed_config) 
+HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape, + HloInstruction* operand, + HloInstruction* token_operand, + absl::string_view outfeed_config) : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), outfeed_shape_(outfeed_shape), - outfeed_config_(outfeed_config.begin(), outfeed_config.end()) { + outfeed_config_(outfeed_config) { CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) << "Outfeed shape " << outfeed_shape << " must be compatible with operand shape " << operand->shape(); @@ -1600,8 +1652,8 @@ std::unique_ptr<HloInstruction> HloOutfeedInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique<HloOutfeedInstruction>(outfeed_shape(), new_operands[0], - new_operands[1], outfeed_config()); + return absl::make_unique<HloOutfeedInstruction>( + outfeed_shape(), new_operands[0], new_operands[1], outfeed_config()); } HloConvolutionInstruction::HloConvolutionInstruction( @@ -1671,7 +1723,7 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique<HloConvolutionInstruction>( + return absl::make_unique<HloConvolutionInstruction>( shape, new_operands[0], new_operands[1], window(), convolution_dimension_numbers_, feature_group_count_); } @@ -1716,7 +1768,7 @@ HloReduceWindowInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique<HloReduceWindowInstruction>( + return absl::make_unique<HloReduceWindowInstruction>( shape, new_operands[0], new_operands[1], window(), to_apply()); } @@ -1765,14 +1817,14 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, 
HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique<HloSelectAndScatterInstruction>( + return absl::make_unique<HloSelectAndScatterInstruction>( shape, new_operands[0], select(), window(), new_operands[1], new_operands[2], scatter()); } HloCustomCallInstruction::HloCustomCallInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - tensorflow::StringPiece custom_call_target) + absl::string_view custom_call_target) : HloInstruction(HloOpcode::kCustomCall, shape), custom_call_target_(custom_call_target.begin(), custom_call_target.end()) { @@ -1840,8 +1892,8 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { - auto cloned = MakeUnique<HloCustomCallInstruction>(shape, new_operands, - custom_call_target()); + auto cloned = absl::make_unique<HloCustomCallInstruction>( + shape, new_operands, custom_call_target()); if (window_ != nullptr) { cloned->set_window(*window_); } @@ -1851,41 +1903,6 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( return std::move(cloned); } -HloHostComputeInstruction::HloHostComputeInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) - : HloInstruction(HloOpcode::kHostCompute, shape), - channel_name_(channel_name.begin(), channel_name.end()), - cost_estimate_ns_(cost_estimate_ns) { - for (auto operand : operands) { - AppendOperand(operand); - } -} - -HloInstructionProto HloHostComputeInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); - proto.set_channel_name(channel_name_); - proto.set_cost_estimate_ns(cost_estimate_ns_); - return proto; -} - -bool HloHostComputeInstruction::IdenticalSlowPath( - const HloInstruction& other, - const std::function<bool(const HloComputation*, const HloComputation*)>& - eq_computations) 
const { - // Not yet supported. - return false; -} - -std::unique_ptr<HloInstruction> -HloHostComputeInstruction::CloneWithNewOperandsImpl( - const Shape& shape, - tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, - HloCloneContext* context) const { - return MakeUnique<HloHostComputeInstruction>( - shape, new_operands, channel_name_, cost_estimate_ns_); -} - HloPadInstruction::HloPadInstruction(const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, @@ -1920,8 +1937,8 @@ std::unique_ptr<HloInstruction> HloPadInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique<HloPadInstruction>(shape, new_operands[0], new_operands[1], - padding_config_); + return absl::make_unique<HloPadInstruction>(shape, new_operands[0], + new_operands[1], padding_config_); } HloDynamicSliceInstruction::HloDynamicSliceInstruction( @@ -1943,8 +1960,8 @@ HloInstructionProto HloDynamicSliceInstruction::ToProto() const { std::vector<string> HloDynamicSliceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { - return { - StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")}; + return {StrCat("dynamic_slice_sizes={", StrJoin(dynamic_slice_sizes(), ","), + "}")}; } bool HloDynamicSliceInstruction::IdenticalSlowPath( @@ -1960,7 +1977,7 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique<HloDynamicSliceInstruction>( + return absl::make_unique<HloDynamicSliceInstruction>( shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); } @@ -1972,25 +1989,25 @@ HloGatherInstruction::HloGatherInstruction( AppendOperand(operand); AppendOperand(start_indices); gather_dimension_numbers_ = - MakeUnique<GatherDimensionNumbers>(gather_dim_numbers); - 
c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_)); + absl::make_unique<GatherDimensionNumbers>(gather_dim_numbers); + absl::c_copy(slice_sizes, std::back_inserter(gather_slice_sizes_)); } string HloGatherInstruction::GatherDimensionNumbersToString() const { CHECK(gather_dimension_numbers_ != nullptr); string offset_dims = StrCat("offset_dims={", - Join(gather_dimension_numbers_->offset_dims(), ","), "}"); - string collapsed_slice_dims = - StrCat("collapsed_slice_dims={", - Join(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}"); + StrJoin(gather_dimension_numbers_->offset_dims(), ","), "}"); + string collapsed_slice_dims = StrCat( + "collapsed_slice_dims={", + StrJoin(gather_dimension_numbers_->collapsed_slice_dims(), ","), "}"); string start_index_map = StrCat("start_index_map={", - Join(gather_dimension_numbers_->start_index_map(), ","), "}"); + StrJoin(gather_dimension_numbers_->start_index_map(), ","), "}"); string index_vector_dim = StrCat( "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); - return Join<std::initializer_list<string>>( + return StrJoin<std::initializer_list<string>>( {offset_dims, collapsed_slice_dims, start_index_map, index_vector_dim}, ", "); } @@ -2027,7 +2044,7 @@ HloInstructionProto HloGatherInstruction::ToProto() const { std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return {GatherDimensionNumbersToString(), - StrCat("slice_sizes={", Join(gather_slice_sizes(), ","), "}")}; + StrCat("slice_sizes={", StrJoin(gather_slice_sizes(), ","), "}")}; } bool HloGatherInstruction::IdenticalSlowPath( @@ -2046,7 +2063,7 @@ std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); - return MakeUnique<HloGatherInstruction>( + return absl::make_unique<HloGatherInstruction>( shape, new_operands[0], 
new_operands[1], gather_dimension_numbers(), gather_slice_sizes()); } @@ -2062,24 +2079,24 @@ HloScatterInstruction::HloScatterInstruction( AppendOperand(updates); AppendComputation(update_computation); scatter_dimension_numbers_ = - MakeUnique<ScatterDimensionNumbers>(scatter_dim_numbers); + absl::make_unique<ScatterDimensionNumbers>(scatter_dim_numbers); } string HloScatterInstruction::ScatterDimensionNumbersToString() const { - string update_window_dims = - StrCat("update_window_dims={", - Join(scatter_dimension_numbers().update_window_dims(), ","), "}"); + string update_window_dims = StrCat( + "update_window_dims={", + StrJoin(scatter_dimension_numbers().update_window_dims(), ","), "}"); string inserted_window_dims = StrCat( "inserted_window_dims={", - Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); + StrJoin(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); string scatter_dims_to_operand_dims = StrCat( "scatter_dims_to_operand_dims={", - Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), + StrJoin(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), "}"); string index_vector_dim = StrCat( "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); - return Join<std::initializer_list<string>>( + return StrJoin<std::initializer_list<string>>( {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, index_vector_dim}, ", "); @@ -2133,9 +2150,39 @@ std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 3); - return MakeUnique<HloScatterInstruction>( + return absl::make_unique<HloScatterInstruction>( shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), scatter_dimension_numbers()); } +HloIotaInstruction::HloIotaInstruction(const Shape& shape, int64 iota_dimension) + : HloInstruction(HloOpcode::kIota, shape), + 
iota_dimension_(iota_dimension) {} + +HloInstructionProto HloIotaInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.add_dimensions(iota_dimension()); + return proto; +} + +std::vector<string> HloIotaInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("iota_dimension=", iota_dimension())}; +} + +bool HloIotaInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloIotaInstruction&>(other); + return iota_dimension() == casted_other.iota_dimension(); +} + +std::unique_ptr<HloInstruction> HloIotaInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const { + return absl::make_unique<HloIotaInstruction>(shape, iota_dimension()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 803dbeabeb..ee6e337b6a 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -18,6 +18,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -217,19 +218,37 @@ class HloRecvDoneInstruction : public HloSendRecvInstruction { HloCloneContext* context) const override; }; -class HloAllReduceInstruction : public HloInstruction { +class HloCollectiveInstruction : public HloInstruction { + public: + const std::vector<ReplicaGroup>& replica_groups() const { + return replica_groups_; + } + + protected: + explicit HloCollectiveInstruction( + HloOpcode opcode, const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> operands, + const std::vector<ReplicaGroup>& replica_groups); + + HloInstructionProto ToProto() const override; + + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + + std::vector<ReplicaGroup> replica_groups_; +}; + +class HloAllReduceInstruction : public HloCollectiveInstruction { public: explicit HloAllReduceInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids, - tensorflow::StringPiece barrier, - const tensorflow::gtl::optional<int64>& all_reduce_id); - - // Returns the group ids of each replica for CrossReplicaSum op. - const std::vector<int64>& replica_group_ids() const { - return replica_group_ids_; - } + const std::vector<ReplicaGroup>& replica_groups, + absl::string_view barrier, const absl::optional<int64>& all_reduce_id); // Returns the barrier config used for the CrossReplicaSum implementation of // each backend. 
@@ -240,9 +259,7 @@ class HloAllReduceInstruction : public HloInstruction { cross_replica_sum_barrier_ = barrier; } - tensorflow::gtl::optional<int64> all_reduce_id() const { - return all_reduce_id_; - } + absl::optional<int64> all_reduce_id() const { return all_reduce_id_; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; @@ -261,37 +278,40 @@ class HloAllReduceInstruction : public HloInstruction { tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const override; - // The group id of each replica for CrossReplicaSum. - std::vector<int64> replica_group_ids_; - // The string representation of the barrier config used for CrossReplicaSum. string cross_replica_sum_barrier_; // For Allreduce nodes from different modules, if they have the same // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be // applied cross modules. - tensorflow::gtl::optional<int64> all_reduce_id_; + absl::optional<int64> all_reduce_id_; }; -class HloAllToAllInstruction : public HloInstruction { +class HloAllToAllInstruction : public HloCollectiveInstruction { public: explicit HloAllToAllInstruction( - const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operand, - const std::vector<ReplicaGroup>& replica_groups, - tensorflow::StringPiece barrier); + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, + const std::vector<ReplicaGroup>& replica_groups); - const std::vector<ReplicaGroup>& replica_groups() const { - return replica_groups_; - } + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const override; +}; - // TODO(b/110096724): rename this. 
- void set_cross_replica_sum_barrier(string barrier) { - cross_replica_sum_barrier_ = barrier; - } - string cross_replica_sum_barrier() const { - return cross_replica_sum_barrier_; +class HloCollectivePermuteInstruction : public HloInstruction { + public: + explicit HloCollectivePermuteInstruction( + const Shape& shape, HloInstruction* operand, + const std::vector<std::pair<int64, int64>>& source_target_pairs); + + const std::vector<std::pair<int64, int64>>& source_target_pairs() const { + return source_target_pairs_; } + // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; private: @@ -308,10 +328,7 @@ class HloAllToAllInstruction : public HloInstruction { tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const override; - std::vector<ReplicaGroup> replica_groups_; - - // The string representation of the barrier config. - string cross_replica_sum_barrier_; + const std::vector<std::pair<int64, int64>> source_target_pairs_; }; class HloReverseInstruction : public HloInstruction { @@ -507,7 +524,7 @@ class HloMapInstruction : public HloInstruction { private: bool IsElementwiseImpl( - const tensorflow::gtl::optional<int64>& operand_idx) const override; + const absl::optional<int64>& operand_idx) const override; std::vector<string> ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; bool IdenticalSlowPath( @@ -600,7 +617,7 @@ class HloConstantInstruction : public HloInstruction { private: bool IsElementwiseImpl( - const tensorflow::gtl::optional<int64>& operand_idx) const override; + const absl::optional<int64>& operand_idx) const override; bool IdenticalSlowPath( const HloInstruction& other, const std::function<bool(const HloComputation*, const HloComputation*)>& @@ -751,7 +768,7 @@ class HloFusionInstruction : public HloInstruction { bool add_output = false); bool IsElementwiseImpl( - const tensorflow::gtl::optional<int64>& operand_idx) const override; + 
const absl::optional<int64>& operand_idx) const override; std::vector<string> ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; bool IdenticalSlowPath( @@ -780,7 +797,7 @@ class HloRngInstruction : public HloInstruction { private: bool IsElementwiseImpl( - const tensorflow::gtl::optional<int64>& operand_idx) const override; + const absl::optional<int64>& operand_idx) const override; std::vector<string> ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; bool IdenticalSlowPath( @@ -920,7 +937,7 @@ class HloOutfeedInstruction : public HloInstruction { explicit HloOutfeedInstruction(const Shape& outfeed_shape, HloInstruction* operand, HloInstruction* token_operand, - tensorflow::StringPiece outfeed_config); + absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); @@ -1073,14 +1090,14 @@ class HloCustomCallInstruction : public HloInstruction { public: explicit HloCustomCallInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, - tensorflow::StringPiece custom_call_target); + absl::string_view custom_call_target); const Window& window() const override { CHECK(window_ != nullptr); return *window_; } void set_window(const Window& window) override { - window_ = MakeUnique<Window>(window); + window_ = absl::make_unique<Window>(window); } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { @@ -1091,7 +1108,7 @@ class HloCustomCallInstruction : public HloInstruction { void set_convolution_dimension_numbers( const ConvolutionDimensionNumbers& dnums) { convolution_dimension_numbers_ = - MakeUnique<ConvolutionDimensionNumbers>(dnums); + absl::make_unique<ConvolutionDimensionNumbers>(dnums); } const string& custom_call_target() const { return custom_call_target_; } // Returns a serialized representation of this instruction. 
@@ -1117,33 +1134,6 @@ class HloCustomCallInstruction : public HloInstruction {
   std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
 };
 
-class HloHostComputeInstruction : public HloInstruction {
- public:
-  explicit HloHostComputeInstruction(
-      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
-      tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
-  // Returns the channel name associated with the instruction. The name is
-  // used to identify host Send/Recv operations.
-  const string& channel_name() const { return channel_name_; }
-  // Returns a serialized representation of this instruction.
-  HloInstructionProto ToProto() const override;
-
- private:
-  bool IdenticalSlowPath(
-      const HloInstruction& other,
-      const std::function<bool(const HloComputation*, const HloComputation*)>&
-          eq_computations) const override;
-  // Implementation for non-common logic of CloneWithNewOperands.
-  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
-      const Shape& shape,
-      tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
-      HloCloneContext* context) const override;
-  // Name to use for host send/recv channels.
-  string channel_name_;
-  // Estimate of the duration of a host computation in nanoseconds.
-  int64 cost_estimate_ns_ = 0;
-};
-
 class HloPadInstruction : public HloInstruction {
  public:
   explicit HloPadInstruction(const Shape& shape, HloInstruction* operand,
@@ -1289,6 +1279,30 @@ class HloScatterInstruction : public HloInstruction {
   std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
 };
 
+class HloIotaInstruction : public HloInstruction {
+ public:
+  explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension);
+  // Returns the dimension along which this Iota's values increase.
+  int64 iota_dimension() const { return iota_dimension_; }
+  // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override; + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const override; + + const int64 iota_dimension_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 8e0d38b6a6..8350285e67 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -17,20 +17,20 @@ limitations under the License. #include <unordered_map> +#include "absl/strings/escaping.h" +#include "absl/strings/numbers.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" namespace xla { - -using ::tensorflow::StringPiece; - namespace { +using absl::string_view; + constexpr int kEOF = -1; constexpr int kError = -2; @@ -66,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const { return ptr < buf_.end() && ptr >= buf_.begin(); } -tensorflow::StringPiece HloLexer::StringPieceFromPointers( - const char* begin, const char* end) const { +absl::string_view HloLexer::StringPieceFromPointers(const char* begin, + const char* end) const { CHECK(begin <= end); CHECK(begin == buf_.end() || 
CanDereference(begin)); CHECK(end == buf_.end() || CanDereference(end)); - return tensorflow::StringPiece(begin, end - begin); + return absl::string_view(begin, end - begin); } tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( @@ -235,7 +235,7 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kAttributeName; } - tensorflow::StringPiece identifier = + absl::string_view identifier = StringPieceFromPointers(token_start_, current_ptr_); // See if this is a keyword. @@ -269,7 +269,7 @@ TokKind HloLexer::LexIdentifier() { } } - str_val_ = std::string(identifier); + str_val_ = string(identifier); return TokKind::kIdent; } @@ -306,8 +306,7 @@ TokKind HloLexer::LexNumberOrPattern() { R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), - &decimal_val_); + CHECK(absl::SimpleAtod(string(token_start_, current_ptr_), &decimal_val_)); return TokKind::kDecimal; } @@ -339,7 +338,7 @@ TokKind HloLexer::LexNumberOrPattern() { if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); auto slice = StringPieceFromPointers(token_start_, current_ptr_); - if (tensorflow::strings::safe_strto64(slice, &int64_val_)) { + if (absl::SimpleAtoi(slice, &int64_val_)) { return TokKind::kInt; } LOG(ERROR) << "Failed to parse int literal: " << slice; @@ -365,6 +364,7 @@ std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const { line_no = line_no_cache_.line_no_of_query; } for (; ptr != location; ptr++) { + CHECK_LT(ptr, buf_.end()); if (*ptr == '\n') { line_no++; } @@ -374,24 +374,24 @@ std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const { line_no_cache_.last_query = ptr; line_no_cache_.line_no_of_query = line_no; size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); - if (line_offset == 
tensorflow::StringPiece::npos) { + if (line_offset == absl::string_view::npos) { line_offset = 0; } return {line_no, ptr - start - line_offset}; } -tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { +absl::string_view HloLexer::GetLine(LocTy loc) const { if (!CanDereference(loc)) { return "LINE OUT OF RANGE"; } size_t line_start = StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); - const char* start = line_start == tensorflow::StringPiece::npos + const char* start = line_start == absl::string_view::npos ? buf_.begin() : buf_.begin() + line_start + 1; size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); const char* end = - line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end; + line_end == absl::string_view::npos ? buf_.end() : loc + line_end; return StringPieceFromPointers(start, end); } @@ -403,10 +403,14 @@ TokKind HloLexer::LexString() { static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); - tensorflow::StringPiece raw = + absl::string_view raw = StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); string error; - if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { + // TODO(b/113077997): Change to absl::CUnescape once it works properly with + // copy-on-write std::string implementations. + if (!tensorflow::str_util::CUnescape( // non-absl ok + tensorflow::StringPiece(raw.data(), raw.size()), // non-absl ok + &str_val_, &error)) { LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error; return TokKind::kError; } diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index 003ac34ace..3e2f8bcd52 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -18,10 +18,10 @@ limitations under the License. 
#include <string> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/types.h" @@ -34,7 +34,7 @@ namespace xla { // it directly. class HloLexer { public: - explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { + explicit HloLexer(absl::string_view buf) : buf_(buf) { current_ptr_ = buf_.begin(); } @@ -77,7 +77,7 @@ class HloLexer { std::pair<unsigned, unsigned> GetLineAndColumn(LocTy location) const; // Returns the whole line given the location. - tensorflow::StringPiece GetLine(LocTy loc) const; + absl::string_view GetLine(LocTy loc) const; private: // Returns the current character. If it's neither the end of input buffer nor @@ -89,8 +89,8 @@ class HloLexer { // Creates StringPiece with the given begin and end. Exits if the begin > end, // or it's out of the range of the current buffer. - tensorflow::StringPiece StringPieceFromPointers(const char* begin, - const char* end) const; + absl::string_view StringPieceFromPointers(const char* begin, + const char* end) const; tensorflow::RegexpStringPiece RegexpStringPieceFromPointers( const char* begin, const char* end) const; @@ -107,11 +107,11 @@ class HloLexer { TokKind LexNumberOrPattern(); TokKind LexString(); - const tensorflow::StringPiece buf_; + const absl::string_view buf_; const char* current_ptr_; // Information about the current token. 
- const char* token_start_; + const char* token_start_ = nullptr; TokKind current_kind_; string str_val_; Shape shape_val_; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 43c41ece6e..3a1dd471c6 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -17,8 +17,9 @@ limitations under the License. #include <deque> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -29,17 +30,14 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { using Worklist = std::deque<const HloInstruction*>; using Workset = std::unordered_set<const HloInstruction*>; -namespace { - void AddToWorklist(const HloInstruction* instruction, Worklist* worklist, Workset* workset) { if (workset->count(instruction) == 0) { @@ -296,7 +294,7 @@ StatusOr<std::unique_ptr<HloLivenessAnalysis>> HloLivenessAnalysis::Run( VLOG(1) << "HloLivenessAnalysis::Run on module " << module.name(); XLA_VLOG_LINES(2, module.ToString()); - auto liveness_analysis = WrapUnique(new HloLivenessAnalysis(module)); + auto liveness_analysis = absl::WrapUnique(new HloLivenessAnalysis(module)); liveness_analysis->RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 7e4b883435..5269cad94d 100644 --- 
a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -15,15 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace testing { -using ::tensorflow::str_util::Join; - bool HloMatcher::MatchAndExplain( const HloInstruction* instruction, ::testing::MatchResultListener* listener) const { @@ -210,8 +208,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) { *listener << instruction->ToString() << " has wrong lhs_contracting_dimensions (got {" - << Join(dim_nums.lhs_contracting_dimensions(), ",") << "} want {" - << lhs_contracting_dim_ << "})"; + << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",") + << "} want {" << lhs_contracting_dim_ << "})"; return false; } @@ -219,8 +217,8 @@ bool HloDotWithContractingDimsMatcher::MatchAndExplain( dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) { *listener << instruction->ToString() << " has wrong rhs_contracting_dimensions (got {" - << Join(dim_nums.rhs_contracting_dimensions(), ",") << "} want {" - << rhs_contracting_dim_ << "})"; + << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",") + << "} want {" << rhs_contracting_dim_ << "})"; return false; } diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index c577b4359a..5502e565b6 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -16,10 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace xla { namespace testing { @@ -120,8 +120,7 @@ class HloShapeAndLayoutMatcher class HloShardingMatcher : public ::testing::MatcherInterface<const HloInstruction*> { public: - explicit HloShardingMatcher( - const tensorflow::gtl::optional<HloSharding>& sharding) + explicit HloShardingMatcher(const absl::optional<HloSharding>& sharding) : sharding_(sharding) {} bool MatchAndExplain(const HloInstruction* instruction, @@ -129,7 +128,7 @@ class HloShardingMatcher void DescribeTo(std::ostream* os) const override; private: - tensorflow::gtl::optional<HloSharding> sharding_; + absl::optional<HloSharding> sharding_; }; // Matches a Dot HLO instruction with specific LHS and RHS contracting @@ -189,6 +188,7 @@ HLO_MATCHER(Fusion); HLO_MATCHER(Ge); HLO_MATCHER(AfterAll); HLO_MATCHER(Gt); +HLO_MATCHER(Iota); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); HLO_MATCHER(Le); @@ -307,7 +307,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Shape( return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape)); } inline ::testing::Matcher<const ::xla::HloInstruction*> Shape( - tensorflow::StringPiece shape) { + absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher( ShapeUtil::ParseShapeString(shape).ValueOrDie())); } @@ -317,7 +317,7 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout( new ::xla::testing::HloShapeAndLayoutMatcher(shape)); } inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout( - tensorflow::StringPiece shape) { + absl::string_view shape) { return ::testing::MakeMatcher(new ::xla::testing::HloShapeAndLayoutMatcher( 
ShapeUtil::ParseShapeString(shape).ValueOrDie())); } @@ -330,14 +330,14 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding( } // Matcher for Sharding from sharding string inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding( - tensorflow::StringPiece sharding) { + absl::string_view sharding) { return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher( ParseSharding(sharding).ValueOrDie())); } // Verifies that no HloSharding is set for an HLO instruction. inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() { return ::testing::MakeMatcher( - new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt)); + new ::xla::testing::HloShardingMatcher(absl::nullopt)); } inline ::testing::Matcher<const ::xla::HloInstruction*> Dot( diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 55ff073d3f..78167335c8 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -22,12 +22,13 @@ limitations under the License. #include <unordered_set> #include <utility> +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -274,7 +275,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( } TF_RET_CHECK(entry != nullptr); - auto module = MakeUnique<HloModule>(proto.name(), module_config); + auto module = absl::make_unique<HloModule>(proto.name(), module_config); // Sort the computations in the proto id's order. 
std::sort(computations.begin(), computations.end(), @@ -409,7 +410,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( string error_message = "The subcomputation to outline has multiple outputs:\n"; for (HloInstruction* output : outputs) { - tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n"); + absl::StrAppend(&error_message, output->ToString(), "\n"); } LOG(FATAL) << error_message; } @@ -507,7 +508,7 @@ std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const { std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; - auto module = MakeUnique<HloModule>(name_ + "-" + suffix, config_); + auto module = absl::make_unique<HloModule>(name_ + "-" + suffix, config_); HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); @@ -535,12 +536,11 @@ uint64 HloModule::RandomNew64() const { return rng_(); } -HloComputation* HloModule::GetComputationWithName( - tensorflow::StringPiece name) { +HloComputation* HloModule::GetComputationWithName(absl::string_view name) { auto computations_in_module = computations(); - auto it = c_find_if(computations_in_module, [&](HloComputation* computation) { - return computation->name() == name; - }); + auto it = absl::c_find_if( + computations_in_module, + [&](HloComputation* computation) { return computation->name() == name; }); return it == computations_in_module.end() ? nullptr : *it; } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index d2e726a0db..cf129b835d 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -24,6 +24,7 @@ limitations under the License. 
#include <unordered_map> #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_clone_context.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" @@ -142,7 +142,7 @@ class HloModule { // Returns the computation in this module that has the name `name`. Returns // null if there is no such computation. - HloComputation* GetComputationWithName(tensorflow::StringPiece name); + HloComputation* GetComputationWithName(absl::string_view name); // Gets the number of computations in this module. int64 computation_count() const { return computations_.size(); } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index 07a8c798db..9bfa3a5f45 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -18,15 +18,15 @@ limitations under the License. 
#include <atomic> #include <vector> -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { -using tensorflow::strings::StrAppend; +using absl::StrAppend; HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape, bool ignore_layouts) @@ -39,15 +39,14 @@ void HloModuleConfig::SetDefaultComputationLayout( } string HloModuleConfig::compilation_cache_key() const { - string key = - tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled()); + string key = absl::StrCat("profiling=", hlo_profiling_enabled()); StrAppend(&key, "::("); std::vector<string> params; for (const ShapeLayout& param_layout : entry_computation_layout_->parameter_layouts()) { params.push_back(param_layout.shape().DebugString()); } - StrAppend(&key, tensorflow::str_util::Join(params, ", "), ") => ", + StrAppend(&key, absl::StrJoin(params, ", "), ") => ", entry_computation_layout_->result_shape().SerializeAsString()); if (seed() != 0) { // TODO(b/32083678): force recompilation to reset global state. diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 074e9c9070..3f1e1cc73e 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -18,11 +18,11 @@ limitations under the License. 
#include <string> +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -72,15 +72,6 @@ class HloModuleConfig { return debug_options_.xla_hlo_profile(); } - // Sets/returns whether this is a "host module". Host modules are used to - // record the data- and control-flow dependencies of host side computation - // that communicates with compiled code. They are used for analysis and - // scheduling purposes, but no code is generated. - bool is_host_module() const { return is_host_module_; } - void set_is_host_module(bool is_host_module) { - is_host_module_ = is_host_module; - } - // Sets/returns the module seed set during execution. void set_seed(uint64 seed) { seed_ = seed; } uint64 seed() const { return seed_; } @@ -113,7 +104,7 @@ class HloModuleConfig { private: // If you add new members, be sure to update compilation_cache_key. - tensorflow::gtl::optional<ComputationLayout> entry_computation_layout_; + absl::optional<ComputationLayout> entry_computation_layout_; // Whether this is a 'host module'. bool is_host_module_ = false; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h index 29024085c1..12ca2340a6 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.h +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -31,7 +31,7 @@ namespace xla { class HloModuleDCE : public HloPassInterface { public: ~HloModuleDCE() override {} - tensorflow::StringPiece name() const override { return "hlo-module-dce"; } + absl::string_view name() const override { return "hlo-module-dce"; } // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). 
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 10bf9ffd6c..9c01862a4b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -19,9 +19,10 @@ limitations under the License. #include <string> #include <utility> -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -59,7 +60,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { /* static */ StatusOr<std::unique_ptr<HloModuleGroupMetadata>> HloModuleGroupMetadata::Build(const std::vector<HloModule*>& modules) { - auto metadata = MakeUnique<HloModuleGroupMetadata>(modules); + auto metadata = absl::make_unique<HloModuleGroupMetadata>(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); } @@ -131,6 +132,14 @@ Status HloModuleGroupMetadata::Build() { if (VLOG_IS_ON(4)) { DumpCollectedStats(); } + + for (HloModule* module : modules_) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, + TuplePointsToAnalysis::Run(module)); + points_to_analyses_[module] = std::move(points_to_analysis); + } + return Status::OK(); } @@ -163,7 +172,7 @@ Status HloModuleGroupMetadata::VerifyCompanionSets() const { ss << " " << hlo->name() << std::endl; } ss << "has multiple instructions on the same device"; - return FailedPrecondition("%s", ss.str().c_str()); + return FailedPrecondition("%s", ss.str()); } } } @@ -204,6 +213,10 @@ const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel( return 
channels_[channel_id_map_.at(channel_id)]; } +bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const { + return channel_id_map_.find(channel_id) != channel_id_map_.end(); +} + HloComputation* HloModuleGroupMetadata::PeerComputation( const HloInstruction* instruction) const { CHECK(IsChannelInstruction(instruction)); @@ -267,15 +280,14 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const { LOG(FATAL) << "unknown module"; } -tensorflow::gtl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice( +absl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice( const HloInstruction& instruction) const { // The module group metadata can be created in both "single module, multiple // devices" and "multiple modules, no explicit devices" fashions. // The API returns an optional even though the current implementation always // returns a device, to account for cases where we cannot guess a device. // In such cases the VerifyChannelInstructions() will return proper errors. 
- tensorflow::gtl::optional<int64> device = - instruction.sharding_unique_device(); + absl::optional<int64> device = instruction.sharding_unique_device(); if (!device) { device = GetModuleId(instruction.parent()->parent()); } @@ -283,10 +295,7 @@ tensorflow::gtl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice( } int64 HloModuleGroupMetadata::GetDeviceModulesCount() const { - return std::count_if(modules_.begin(), modules_.end(), - [](const HloModule* module) { - return !module->config().is_host_module(); - }); + return modules_.size(); } Status HloModuleGroupMetadata::RecordInstructions() { @@ -383,7 +392,7 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, if (!ContainsKey(companion_set_index_, instruction1) && !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( - tensorflow::MakeUnique<std::unordered_set<HloInstruction*>>()); + absl::make_unique<std::unordered_set<HloInstruction*>>()); auto companion_set = companion_sets_.back().get(); companion_set->insert(instruction1); companion_set->insert(instruction2); @@ -411,16 +420,16 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, Status HloModuleGroupMetadata::VerifyChannelInstructions() { for (const Channel& channel : channels_) { if (channel.send == nullptr) { - return FailedPrecondition("missing send for id : %lld", channel.id); + return FailedPrecondition("missing send for id : %d", channel.id); } if (channel.recv == nullptr) { - return FailedPrecondition("missing recv for id : %lld", channel.id); + return FailedPrecondition("missing recv for id : %d", channel.id); } if (channel.send_done == nullptr) { - return FailedPrecondition("missing send-done for id : %lld", channel.id); + return FailedPrecondition("missing send-done for id : %d", channel.id); } if (channel.recv_done == nullptr) { - return FailedPrecondition("missing recv-done for id : %lld", channel.id); + return FailedPrecondition("missing recv-done for id : %d", 
channel.id); } } @@ -436,33 +445,33 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { auto send_done_device = GetInstructionDevice(*channel.send_done); if (!send_device) { return FailedPrecondition("send instruction must have a device: %s", - channel.send->ToString().c_str()); + channel.send->ToString()); } if (!send_done_device) { return FailedPrecondition("send_done instruction must have a device: %s", - channel.send_done->ToString().c_str()); + channel.send_done->ToString()); } if (*send_device != *send_done_device) { return FailedPrecondition( - "send and send-done (channel=%lld) must be on the same device: %lld " - "vs. %lld", + "send and send-done (channel=%d) must be on the same device: %d " + "vs. %d", channel.id, *send_device, *send_done_device); } auto recv_device = GetInstructionDevice(*channel.recv); auto recv_done_device = GetInstructionDevice(*channel.recv_done); if (!recv_done_device) { return FailedPrecondition("recv_done instruction must have a device: %s", - channel.recv_done->ToString().c_str()); + channel.recv_done->ToString()); } if (*recv_device != *recv_done_device) { return FailedPrecondition( - "recv and recv-done (channel=%lld) must be on the same device: %lld " - "vs. %lld", + "recv and recv-done (channel=%d) must be on the same device: %d " + "vs. 
%d", channel.id, *recv_device, *recv_done_device); } if (*send_device == *recv_device) { return FailedPrecondition( - "send and recv (channel=%lld) must be on different devices: %lld", + "send and recv (channel=%d) must be on different devices: %d", channel.id, *send_device); } } @@ -483,7 +492,7 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { !CheckCompanionPathsCompatibility( path, GetCompanionsPath(channel.recv_done))) { return FailedPrecondition( - "Nest companion paths do not match for channel %lld", channel.id); + "Nest companion paths do not match for channel %d", channel.id); } } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 1b256cd00e..768b0c7eb3 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -22,14 +22,15 @@ limitations under the License. #include <unordered_set> #include <vector> +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -125,6 +126,9 @@ class HloModuleGroupMetadata { // Returns the Channel instance for the given channel id. const Channel& GetChannel(int64 channel_id) const; + // Returns if the given channel id exists in metadata. + bool HasChannel(int64 channel_id) const; + // Returns the all-reduce instructions with the same all_reduce_id. 
const std::vector<HloInstruction*>& GetAllReduceGroup( int64 all_reduce_id) const; @@ -156,7 +160,7 @@ class HloModuleGroupMetadata { // Retrieves the device an instruction is assigned to. Either from the // sharding information, or from the ordinal of the module the instruction // is in. - tensorflow::gtl::optional<int64> GetInstructionDevice( + absl::optional<int64> GetInstructionDevice( const HloInstruction& instruction) const; // Returns the number of modules for devices (excluding the host module). @@ -194,6 +198,10 @@ class HloModuleGroupMetadata { // Returns the maximum channel id or all_reduce_id used in the module group. int64 max_channel_id() const { return max_channel_id_; } + TuplePointsToAnalysis* points_to_analysis(HloModule* module) const { + return points_to_analyses_.at(module).get(); + } + private: Status Build(); @@ -268,6 +276,9 @@ class HloModuleGroupMetadata { // The modules that this metadata was built from. const std::vector<HloModule*>& modules_; + + tensorflow::gtl::FlatMap<HloModule*, std::unique_ptr<TuplePointsToAnalysis>> + points_to_analyses_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 0dc5676148..d70328c8a3 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,7 +22,10 @@ limitations under the License. #include <string> #include <utility> -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -30,7 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -94,12 +96,14 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors( add_unique_predecessor(control_predecessor); } } - if (instruction->opcode() == HloOpcode::kRecvDone) { + if (instruction->opcode() == HloOpcode::kRecvDone && + !DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) { // Send is a remote predecessor of RecvDone. HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; add_unique_predecessor(send); } - if (instruction->opcode() == HloOpcode::kSend) { + if (instruction->opcode() == HloOpcode::kSend && + !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) { // Recv is a remote predecessor of Send. HloInstruction* recv_done = metadata_.GetChannel(instruction->channel_id()).recv_done; @@ -170,14 +174,16 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors( add_unique_successor(control_successor); } } - if (instruction->opcode() == HloOpcode::kRecv) { + if (instruction->opcode() == HloOpcode::kRecv && + !DynCast<HloRecvInstruction>(instruction)->is_host_transfer()) { // Send is a remote successor of Recv. const HloInstruction* recv_done = instruction->users().front(); CHECK(recv_done->opcode() == HloOpcode::kRecvDone); HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send; add_unique_successor(send); } - if (instruction->opcode() == HloOpcode::kSend) { + if (instruction->opcode() == HloOpcode::kSend && + !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) { // RecvDone is a remote successor of Send. 
HloInstruction* recv_done = metadata_.GetChannel(instruction->channel_id()).recv_done; @@ -264,8 +270,8 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( string cyclic_instructions; for (const auto& state : *visit_state) { if (state.second == VisitState::kVisiting) { - tensorflow::strings::StrAppend(&cyclic_instructions, - state.first->ToString(), "\n"); + absl::StrAppend(&cyclic_instructions, state.first->ToString(), + "\n"); } } // TODO(b/64305524): Improve the error message to print out the @@ -276,7 +282,7 @@ Status HloModuleGroupUtil::VisitTopologicalOrder( "following nodes. Note that the order of the nodes is arbitrary " "and that the list may include nodes that are not part of the " "cycle.\n%s", - predecessor->ToString().c_str(), cyclic_instructions.c_str()); + predecessor->ToString(), cyclic_instructions); } stack.push(predecessor); } @@ -332,7 +338,7 @@ HloModuleGroupUtil::ComputeReachability( TF_RETURN_IF_ERROR( VisitTopologicalOrder(&visit_states, visit_function, root)); } - auto reachability = MakeUnique<HloReachabilityMap>(post_order); + auto reachability = absl::make_unique<HloReachabilityMap>(post_order); for (HloInstruction* hlo : post_order) { reachability->FastSetReachabilityToUnion(GlobalPredecessors(hlo), hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 236f450086..209ad5e58c 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -15,8 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index d1eaf35785..2d4e38589f 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -39,7 +39,7 @@ StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) { }); auto it = opcode_map->find(opcode_name); if (it == opcode_map->end()) { - return InvalidArgument("Unknown opcode: %s", opcode_name.c_str()); + return InvalidArgument("Unknown opcode: %s", opcode_name); } return it->second; } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index ec279867e5..e6bfb8025d 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -58,6 +58,7 @@ namespace xla { V(kCall, "call", kHloOpcodeIsVariadic) \ V(kCeil, "ceil") \ V(kClamp, "clamp") \ + V(kCollectivePermute, "collective-permute") \ V(kClz, "count-leading-zeros") \ V(kComplex, "complex") \ V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ @@ -85,7 +86,6 @@ namespace xla { V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ - V(kHostCompute, "host-compute") \ V(kImag, "imag") \ V(kInfeed, "infeed") \ V(kIota, "iota") \ @@ -156,7 +156,7 @@ enum HloOpcodeProperty { // Returns a string representation of the opcode. string HloOpcodeString(HloOpcode opcode); -// Returns a string representation of the opcode. +// Retrieves the opcode enum by name if the opcode exists. 
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name); inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 6c1e015f77..0581d5c404 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -18,6 +18,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -25,8 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -254,6 +254,10 @@ bool HloOrdering::LiveRangeStrictlyBefore( } // All uses of 'a' must be before 'b' is defined. 
for (const HloUse& use : a.uses()) { + if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), + use.instruction)) { + continue; + } if (!UseIsBeforeValueDefinition(use, b, dataflow)) { VLOG(4) << "use of " << a << " (" << use << ") not before " << b << " is defined"; @@ -302,22 +306,20 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const { std::vector<string> pieces; pieces.push_back(name); for (auto* computation : module_->MakeNonfusionComputations()) { - pieces.push_back(tensorflow::strings::Printf("computation %s:", - computation->name().c_str())); + pieces.push_back(absl::StrFormat("computation %s:", computation->name())); const auto all = computation->MakeInstructionPostOrder(); for (auto instruction : all) { - pieces.push_back(tensorflow::strings::Printf( - " %s predecessors:", instruction->name().c_str())); + pieces.push_back( + absl::StrFormat(" %s predecessors:", instruction->name())); for (auto predecessor : all) { if (predecessors_.at(computation) ->IsReachable(predecessor, instruction)) { - pieces.push_back( - tensorflow::strings::Printf(" %s", predecessor->name().c_str())); + pieces.push_back(absl::StrFormat(" %s", predecessor->name())); } } } } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) @@ -368,8 +370,8 @@ string SequentialHloOrdering::ToString() const { std::vector<string> pieces; pieces.push_back("SequentialHloOrdering"); for (auto* computation : module_->computations()) { - pieces.push_back(tensorflow::strings::Printf("computation %s order:", - computation->name().c_str())); + pieces.push_back( + absl::StrFormat("computation %s order:", computation->name())); // Gather all instructions in the module sequence for this computation and // sort them by their position. 
std::vector<const HloInstruction*> instructions; @@ -384,11 +386,10 @@ string SequentialHloOrdering::ToString() const { return order_position_.at(a) < order_position_.at(b); }); for (auto instruction : instructions) { - pieces.push_back( - tensorflow::strings::Printf(" %s", instruction->name().c_str())); + pieces.push_back(absl::StrFormat(" %s", instruction->name())); } } - return tensorflow::str_util::Join(pieces, "\n"); + return absl::StrJoin(pieces, "\n"); } std::ostream& operator<<( diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index ab57a8b07f..eae4508b24 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -15,6 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" @@ -24,21 +30,17 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace { -using ::tensorflow::StringPiece; -using ::tensorflow::gtl::optional; -using ::tensorflow::str_util::Join; -using ::tensorflow::str_util::Split; -using ::tensorflow::str_util::SplitAndParseAsInts; -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::nullopt; +using absl::optional; +using absl::StrAppend; +using absl::StrCat; +using absl::StrFormat; +using absl::StrJoin; const double kF16max = 65504; @@ -47,7 +49,7 @@ class HloParser { public: using LocTy = HloLexer::LocTy; - explicit HloParser(StringPiece str, const HloModuleConfig& config) + explicit HloParser(absl::string_view str, const HloModuleConfig& config) : lexer_(str), config_(config) {} // Runs the parser. Returns false if an error occurred. @@ -57,14 +59,28 @@ class HloParser { std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); } // Returns the error information. - string GetError() const { return Join(error_, "\n"); } + string GetError() const { return StrJoin(error_, "\n"); } // Stand alone parsing utils for various aggregate data types. StatusOr<HloSharding> ParseShardingOnly(); StatusOr<Window> ParseWindowOnly(); StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly(); + // Stand-alone parsing utility for a single instruction worth of text. + Status ParseSingleInstruction(HloComputation::Builder* builder, + string* root_name); + private: + // Locates an instruction with the given name in the instruction_pool_ or + // returns nullptr. + // + // If the missing_instruction_hook_ is registered and a "shape" is provided, + // the hook will be called and may satisfy the request for the given + // instruction. 
This is useful when we reify parameters as they're resolved; + // i.e. for ParseSingleInstruction. + std::pair<HloInstruction*, LocTy>* FindInstruction( + const string& name, const optional<Shape>& shape = nullopt); + // ParseXXX returns false if an error occurred. bool ParseHloModule(); bool ParseComputations(); @@ -138,6 +154,7 @@ class HloParser { kFusionKind, kDistribution, kDomain, + kPrecisionList, }; struct AttrConfig { @@ -203,6 +220,7 @@ class HloParser { bool ParseWindowPad(std::vector<std::vector<tensorflow::int64>>* pad); bool ParseSliceRanges(SliceRanges* result); + bool ParsePrecisionList(std::vector<PrecisionConfigProto::Precision>* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector<tensorflow::int64>* result); @@ -221,6 +239,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); + bool ParsePrecision(PrecisionConfigProto::Precision* result); bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); @@ -233,8 +252,8 @@ class HloParser { bool CanBeParamListToShape(); // Logs the current parsing line and the given message. Always returns false. - bool TokenError(StringPiece msg); - bool Error(LocTy loc, StringPiece msg); + bool TokenError(absl::string_view msg); + bool Error(LocTy loc, absl::string_view msg); // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. @@ -265,24 +284,55 @@ class HloParser { std::vector<std::unique_ptr<HloComputation>> computations_; const HloModuleConfig config_; std::vector<string> error_; + + // Function that gets invoked when we try to resolve an instruction + // instruction_pool_ but fail to do so. 
+ std::function<std::pair<HloInstruction*, LocTy>*(string, + const optional<Shape>&)> + missing_instruction_hook_; }; -bool HloParser::Error(LocTy loc, StringPiece msg) { +bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) { + for (const auto& split : absl::StrSplit(s, delim)) { + int64 val; + if (!absl::SimpleAtoi(split, &val)) { + return false; + } + out->push_back(val); + } + return true; +} + +// Creates replica groups from the provided nested array. groups[i] represents +// the replica ids for group 'i'. +std::vector<ReplicaGroup> CreateReplicaGroups( + tensorflow::gtl::ArraySlice<std::vector<int64>> groups) { + std::vector<ReplicaGroup> replica_groups; + absl::c_transform(groups, std::back_inserter(replica_groups), + [](const std::vector<int64>& ids) { + ReplicaGroup group; + *group.mutable_replica_ids() = {ids.begin(), ids.end()}; + return group; + }); + return replica_groups; +} + +bool HloParser::Error(LocTy loc, absl::string_view msg) { auto line_col = lexer_.GetLineAndColumn(loc); const unsigned line = line_col.first; const unsigned col = line_col.second; std::vector<string> error_lines; error_lines.push_back( StrCat("was parsing ", line, ":", col, ": error: ", msg)); - error_lines.push_back(std::string(lexer_.GetLine(loc))); + error_lines.emplace_back(lexer_.GetLine(loc)); error_lines.push_back(col == 0 ? 
"" : StrCat(string(col - 1, ' '), "^")); - error_.push_back(Join(error_lines, "\n")); + error_.push_back(StrJoin(error_lines, "\n")); VLOG(1) << "Error: " << error_.back(); return false; } -bool HloParser::TokenError(StringPiece msg) { +bool HloParser::TokenError(absl::string_view msg) { return Error(lexer_.GetLoc(), msg); } @@ -291,6 +341,17 @@ bool HloParser::Run() { return ParseHloModule(); } +std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction( + const string& name, const optional<Shape>& shape) { + std::pair<HloInstruction*, LocTy>* instr = + tensorflow::gtl::FindOrNull(instruction_pool_, name); + // Potentially call the missing instruction hook. + if (instr == nullptr && missing_instruction_hook_ != nullptr) { + return missing_instruction_hook_(name, shape); + } + return instr; +} + // ::= 'HloModule' name computations bool HloParser::ParseHloModule() { if (lexer_.GetKind() != TokKind::kw_HloModule) { @@ -304,7 +365,7 @@ bool HloParser::ParseHloModule() { return false; } - module_ = MakeUnique<HloModule>(name, config_); + module_ = absl::make_unique<HloModule>(name, config_); return ParseComputations(); } @@ -357,7 +418,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { if (!ParseName(&name)) { return false; } - auto builder = MakeUnique<HloComputation::Builder>(name); + auto builder = absl::make_unique<HloComputation::Builder>(name); LocTy shape_loc = nullptr; Shape shape; @@ -370,8 +431,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { return false; } - std::pair<HloInstruction*, LocTy>* root_node = - tensorflow::gtl::FindOrNull(instruction_pool_, root_name); + std::pair<HloInstruction*, LocTy>* root_node = FindInstruction(root_name); // This means some instruction was marked as ROOT but we didn't find it in the // pool, which should not happen. 
if (!root_name.empty() && root_node == nullptr) { @@ -469,6 +529,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, attrs["backend_config"] = {/*required=*/false, AttrTy::kString, &backend_config}; + optional<std::vector<PrecisionConfigProto::Precision>> operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -498,11 +562,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kIota: { + optional<tensorflow::int64> iota_dimension; + attrs["iota_dimension"] = {/*required=*/true, AttrTy::kInt64, + &iota_dimension}; if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateIota(shape)); + instruction = builder->AddInstruction( + HloInstruction::CreateIota(shape, *iota_dimension)); break; } // Unary ops. 
@@ -597,31 +665,29 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { + optional<std::vector<std::vector<int64>>> tmp_groups; optional<HloComputation*> to_apply; optional<std::vector<int64>> replica_group_ids; optional<string> barrier; optional<int64> all_reduce_id; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; - attrs["replica_group_ids"] = { - /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids}; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64, &all_reduce_id}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - if (replica_group_ids) { - instruction = - builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, *replica_group_ids, - barrier ? *barrier : "", all_reduce_id)); - } else { - instruction = - builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, {}, barrier ? *barrier : "", - all_reduce_id)); + std::vector<ReplicaGroup> replica_groups; + if (tmp_groups) { + replica_groups = CreateReplicaGroups(*tmp_groups); } + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, replica_groups, + barrier ? 
*barrier : "", all_reduce_id)); break; } case HloOpcode::kAllToAll: { @@ -629,21 +695,36 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional<string> barrier; attrs["replica_groups"] = {/*required=*/false, AttrTy::kBracedInt64ListList, &tmp_groups}; - attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } std::vector<ReplicaGroup> replica_groups; if (tmp_groups) { - c_transform(*tmp_groups, std::back_inserter(replica_groups), - [](const std::vector<int64>& ids) { - ReplicaGroup group; - *group.mutable_replica_ids() = {ids.begin(), ids.end()}; - return group; - }); + replica_groups = CreateReplicaGroups(*tmp_groups); } - instruction = builder->AddInstruction(HloInstruction::CreateAllToAll( - shape, operands, replica_groups, barrier ? *barrier : "")); + instruction = builder->AddInstruction( + HloInstruction::CreateAllToAll(shape, operands, replica_groups)); + break; + } + case HloOpcode::kCollectivePermute: { + optional<std::vector<std::vector<int64>>> source_targets; + attrs["source_target_pairs"] = { + /*required=*/true, AttrTy::kBracedInt64ListList, &source_targets}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + std::vector<std::pair<int64, int64>> pairs(source_targets->size()); + for (int i = 0; i < pairs.size(); i++) { + if ((*source_targets)[i].size() != 2) { + return TokenError( + "expects 'source_target_pairs=' to be a list of pairs"); + } + pairs[i].first = (*source_targets)[i][0]; + pairs[i].second = (*source_targets)[i][1]; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCollectivePermute(shape, operands[0], pairs)); break; } case HloOpcode::kReshape: { @@ -1177,20 +1258,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } break; } - case HloOpcode::kHostCompute: { - optional<string> channel_name; - optional<tensorflow::int64> 
cost_estimate_ns; - attrs["channel_name"] = {/*required=*/true, AttrTy::kString, - &channel_name}; - attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64, - &cost_estimate_ns}; - if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction(HloInstruction::CreateHostCompute( - shape, operands, *channel_name, *cost_estimate_ns)); - break; - } case HloOpcode::kDot: { optional<std::vector<tensorflow::int64>> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { @@ -1346,6 +1413,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (backend_config) { instruction->set_raw_backend_config_string(std::move(*backend_config)); } + if (operand_precision) { + PrecisionConfigProto precision_config; + *precision_config.mutable_operand_precision() = {operand_precision->begin(), + operand_precision->end()}; + instruction->set_precision_config(precision_config); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1509,14 +1582,14 @@ bool HloParser::ParseDomain(DomainData* domain) { return false; } if (*kind == ShardingMetadata::KindName()) { - auto entry_sharding_ptr = MakeUnique<HloSharding>( + auto entry_sharding_ptr = absl::make_unique<HloSharding>( HloSharding::FromProto(*entry_sharding).ValueOrDie()); - auto exit_sharding_ptr = MakeUnique<HloSharding>( + auto exit_sharding_ptr = absl::make_unique<HloSharding>( HloSharding::FromProto(*exit_sharding).ValueOrDie()); domain->entry_metadata = - MakeUnique<ShardingMetadata>(std::move(entry_sharding_ptr)); + absl::make_unique<ShardingMetadata>(std::move(entry_sharding_ptr)); domain->exit_metadata = - MakeUnique<ShardingMetadata>(std::move(exit_sharding_ptr)); + absl::make_unique<ShardingMetadata>(std::move(exit_sharding_ptr)); } else { return TokenError(StrCat("unsupported domain kind: ", *kind)); } @@ -1536,11 +1609,9 @@ bool HloParser::ParseInstructionNames( if (!ParseName(&name)) { return Error(loc, 
"expects a instruction name"); } - std::pair<HloInstruction*, LocTy>* instr = - tensorflow::gtl::FindOrNull(instruction_pool_, name); + std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name); if (!instr) { - return TokenError( - Printf("instruction '%s' is not defined", name.c_str())); + return TokenError(StrFormat("instruction '%s' is not defined", name)); } instructions->push_back(instr->first); } while (EatIfPresent(TokKind::kComma)); @@ -1769,10 +1840,10 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, std::vector<tensorflow::int64> elems_seen_until_dim( elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", - Join(elems_seen_until_dim, ",", - [](string* out, const tensorflow::int64& num_elems) { - StrAppend(out, num_elems - 1); - }), + StrJoin(elems_seen_until_dim, ",", + [](string* out, const tensorflow::int64& num_elems) { + StrAppend(out, num_elems - 1); + }), "]"); }; do { @@ -1782,17 +1853,17 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, case TokKind::kLbrace: { nest_level++; if (nest_level > rank) { - return TokenError(Printf( - "expects nested array in rank %lld, but sees larger", rank)); + return TokenError(absl::StrFormat( + "expects nested array in rank %d, but sees larger", rank)); } if (nest_level > 1) { elems_seen_per_dim[nest_level - 2]++; if (elems_seen_per_dim[nest_level - 2] > shape.dimensions(nest_level - 2)) { - return TokenError(Printf( - "expects %lld elements in the %sth element, but sees more", + return TokenError(absl::StrFormat( + "expects %d elements in the %sth element, but sees more", shape.dimensions(nest_level - 2), - get_index_str(nest_level - 2).c_str())); + get_index_str(nest_level - 2))); } } lexer_.Lex(); @@ -1801,9 +1872,9 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, case TokKind::kRbrace: { nest_level--; if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) { - return TokenError(Printf( - "expects 
%lld elements in the %sth element, but sees %lld", - shape.dimensions(nest_level), get_index_str(nest_level).c_str(), + return TokenError(absl::StrFormat( + "expects %d elements in the %sth element, but sees %d", + shape.dimensions(nest_level), get_index_str(nest_level), elems_seen_per_dim[nest_level])); } elems_seen_per_dim[nest_level] = 0; @@ -1824,15 +1895,15 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, if (rank > 0) { if (nest_level != rank) { return TokenError( - Printf("expects nested array in rank %lld, but sees %lld", rank, - nest_level)); + absl::StrFormat("expects nested array in rank %d, but sees %d", + rank, nest_level)); } elems_seen_per_dim[rank - 1]++; if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { - return TokenError( - Printf("expects %lld elements on the minor-most dimension, but " - "sees more", - shape.dimensions(rank - 1))); + return TokenError(absl::StrFormat( + "expects %d elements on the minor-most dimension, but " + "sees more", + shape.dimensions(rank - 1))); } } if (lexer_.GetKind() == TokKind::kw_true || @@ -1925,7 +1996,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, tensorflow::int64 rank = ShapeUtil::Rank(shape); - *literal = MakeUnique<Literal>(shape); + *literal = absl::make_unique<Literal>(shape); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of a sparse literal")) { @@ -1959,7 +2030,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, return Error( index_loc, StrCat("invalid multi-dimension index for shape with rank ", rank, - ": [", Join(index, ", "), "]")); + ": [", StrJoin(index, ", "), "]")); } } if (!ParseToken(TokKind::kColon, @@ -2020,6 +2091,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, // ::= operand (, operand)* // operand ::= (shape)? 
name bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { + CHECK(operands != nullptr); if (!ParseToken(TokKind::kLparen, "expects '(' at the beginning of operands")) { return false; @@ -2030,9 +2102,10 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { do { LocTy loc = lexer_.GetLoc(); string name; + optional<Shape> shape; if (CanBeShape()) { - Shape shape; - if (!ParseShape(&shape)) { + shape.emplace(); + if (!ParseShape(&shape.value())) { return false; } } @@ -2040,8 +2113,8 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { return false; } std::pair<HloInstruction*, LocTy>* instruction = - tensorflow::gtl::FindOrNull(instruction_pool_, name); - if (!instruction) { + FindInstruction(name, shape); + if (instruction == nullptr) { return Error(loc, StrCat("instruction does not exist: ", name)); } operands->push_back(instruction->first); @@ -2052,6 +2125,7 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands, const int expected_size) { + CHECK(operands != nullptr); LocTy loc = lexer_.GetLoc(); if (!ParseOperands(operands)) { return false; @@ -2085,8 +2159,8 @@ bool HloParser::ParseSubAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return Error(loc, Printf("sub-attribute %s is expected but not seen", - attr_it.first.c_str())); + return Error(loc, StrFormat("sub-attribute %s is expected but not seen", + attr_it.first)); } } return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes"); @@ -2106,8 +2180,8 @@ bool HloParser::ParseAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return Error(loc, Printf("attribute %s is expected but not seen", - attr_it.first.c_str())); + return Error(loc, StrFormat("attribute %s is expected but not 
seen", + attr_it.first)); } } return true; @@ -2123,7 +2197,7 @@ bool HloParser::ParseAttributeHelper( } VLOG(1) << "Parsing attribute " << name; if (!seen_attrs->insert(name).second) { - return Error(loc, Printf("attribute %s already exists", name.c_str())); + return Error(loc, StrFormat("attribute %s already exists", name)); } auto attr_it = attrs.find(name); if (attr_it == attrs.end()) { @@ -2133,13 +2207,13 @@ bool HloParser::ParseAttributeHelper( } else { allowed_attrs = StrCat( "Allowed attributes: ", - Join(attrs, ", ", - [&](string* out, const std::pair<string, AttrConfig>& kv) { - StrAppend(out, kv.first); - })); + StrJoin(attrs, ", ", + [&](string* out, const std::pair<string, AttrConfig>& kv) { + StrAppend(out, kv.first); + })); } - return Error(loc, Printf("unexpected attribute \"%s\". %s", name.c_str(), - allowed_attrs.c_str())); + return Error(loc, StrFormat("unexpected attribute \"%s\". %s", name, + allowed_attrs)); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; @@ -2321,10 +2395,20 @@ bool HloParser::ParseAttributeHelper( case AttrTy::kDomain: { return ParseDomain(static_cast<DomainData*>(attr_out_ptr)); } + case AttrTy::kPrecisionList: { + std::vector<PrecisionConfigProto::Precision> result; + if (!ParsePrecisionList(&result)) { + return false; + } + static_cast<optional<std::vector<PrecisionConfigProto::Precision>>*>( + attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { - return Error(loc, Printf("error parsing attribute %s", name.c_str())); + return Error(loc, StrFormat("error parsing attribute %s", name)); } return true; } @@ -2439,20 +2523,24 @@ bool HloParser::ParseConvolutionDimensionNumbers( } string str = lexer_.GetStrVal(); - // The str is expected to have 3 items, lhs, rhs, out, and it must looks like + // The str is expected to have 3 items, lhs, rhs, out, and it must look like // lhs_rhs->out, that is, the first separator is "_" and the second is "->". 
- // So we replace the "->" with "_" and then split on "_". - str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->", - /*newsub=*/"_", - /*replace_all=*/false); - std::vector<string> lhs_rhs_out = Split(str, "_"); - if (lhs_rhs_out.size() != 3) { + std::vector<string> split1 = absl::StrSplit(str, "_"); + if (split1.size() != 2) { + LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " + << str; + } + std::vector<string> split2 = absl::StrSplit(split1[1], "->"); + if (split2.size() != 2) { LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " << str; } + absl::string_view lhs = split1[0]; + absl::string_view rhs = split2[0]; + absl::string_view out = split2[1]; - const tensorflow::int64 rank = lhs_rhs_out[0].length(); - if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { + const tensorflow::int64 rank = lhs.length(); + if (rank != rhs.length() || rank != out.length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); } @@ -2467,8 +2555,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( // lhs { - const string& lhs = lhs_rhs_out[0]; - if (!is_unique(lhs)) { + if (!is_unique(string(lhs))) { return TokenError( StrCat("expects unique lhs dimension numbers, but sees ", lhs)); } @@ -2485,14 +2572,13 @@ bool HloParser::ParseConvolutionDimensionNumbers( dnums->set_input_spatial_dimensions(c - '0', i); } else { return TokenError( - Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1)); + StrFormat("expects [0-%dbf] in lhs dimension numbers", rank - 1)); } } } // rhs { - const string& rhs = lhs_rhs_out[1]; - if (!is_unique(rhs)) { + if (!is_unique(string(rhs))) { return TokenError( StrCat("expects unique rhs dimension numbers, but sees ", rhs)); } @@ -2509,14 +2595,13 @@ bool HloParser::ParseConvolutionDimensionNumbers( dnums->set_kernel_spatial_dimensions(c - '0', i); } else { return TokenError( - Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1)); + 
StrFormat("expects [0-%dio] in rhs dimension numbers", rank - 1)); } } } // output { - const string& out = lhs_rhs_out[2]; - if (!is_unique(out)) { + if (!is_unique(string(out))) { return TokenError( StrCat("expects unique output dimension numbers, but sees ", out)); } @@ -2532,8 +2617,8 @@ bool HloParser::ParseConvolutionDimensionNumbers( } else if (c < '0' + rank && c >= '0') { dnums->set_output_spatial_dimensions(c - '0', i); } else { - return TokenError( - Printf("expects [0-%lldbf] in output dimension numbers", rank - 1)); + return TokenError(StrFormat( + "expects [0-%dbf] in output dimension numbers", rank - 1)); } } } @@ -2579,9 +2664,10 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { } const auto& range = ranges.back(); if (range.size() != 2 && range.size() != 3) { - return Error(loc, Printf("expects [start:limit:step] or [start:limit], " - "but sees %ld elements.", - range.size())); + return Error(loc, + StrFormat("expects [start:limit:step] or [start:limit], " + "but sees %d elements.", + range.size())); } } while (EatIfPresent(TokKind::kComma)); @@ -2593,6 +2679,24 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); } +// precisionlist ::= start precision_elements end +// precision_elements +// ::= /*empty*/ +// ::= precision_val (delim precision_val)* +bool HloParser::ParsePrecisionList( + std::vector<PrecisionConfigProto::Precision>* result) { + auto parse_and_add_item = [&]() { + PrecisionConfigProto::Precision item; + if (!ParsePrecision(&item)) { + return false; + } + result->push_back(item); + return true; + }; + return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item); +} + // int64list ::= start int64_elements end // int64_elements // ::= /*empty*/ @@ -2749,14 +2853,13 @@ bool HloParser::ParseDxD(const string& name, std::vector<tensorflow::int64>* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { - return Error(loc, 
- Printf("sub-attribute '%s=' already exists", name.c_str())); + return Error(loc, StrFormat("sub-attribute '%s=' already exists", name)); } // 1D if (lexer_.GetKind() == TokKind::kInt) { tensorflow::int64 number; if (!ParseInt64(&number)) { - return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); + return Error(loc, StrFormat("expects sub-attribute '%s=i'", name)); } result->push_back(number); return true; @@ -2764,9 +2867,8 @@ bool HloParser::ParseDxD(const string& name, // 2D or higher. if (lexer_.GetKind() == TokKind::kDxD) { string str = lexer_.GetStrVal(); - if (!SplitAndParseAsInts(str, 'x', result)) { - return Error(loc, - Printf("expects sub-attribute '%s=ixj...'", name.c_str())); + if (!SplitToInt64s(str, 'x', result)) { + return Error(loc, StrFormat("expects sub-attribute '%s=ixj...'", name)); } lexer_.Lex(); return true; @@ -2784,10 +2886,9 @@ bool HloParser::ParseWindowPad( return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); } string str = lexer_.GetStrVal(); - std::vector<string> padding_str = Split(str, 'x'); - for (int i = 0; i < padding_str.size(); i++) { + for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { std::vector<tensorflow::int64> low_high; - if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || + if (!SplitToInt64s(padding_dim_str, '_', &low_high) || low_high.size() != 2) { return Error(loc, "expects padding_low and padding_high separated by '_'"); @@ -2808,10 +2909,9 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { } LocTy loc = lexer_.GetLoc(); string str = lexer_.GetStrVal(); - std::vector<string> padding_str = Split(str, 'x'); - for (const auto& padding_dim_str : padding_str) { + for (const auto& padding_dim_str : absl::StrSplit(str, 'x')) { std::vector<tensorflow::int64> padding_dim; - if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || + if (!SplitToInt64s(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { return 
Error(loc, "expects padding config pattern like 'low_high_interior' or " @@ -2863,9 +2963,8 @@ bool HloParser::ParseOpcode(HloOpcode* result) { string val = lexer_.GetStrVal(); auto status_or_result = StringToHloOpcode(val); if (!status_or_result.ok()) { - return TokenError( - Printf("expects opcode but sees: %s, error: %s", val.c_str(), - status_or_result.status().error_message().c_str())); + return TokenError(StrFormat("expects opcode but sees: %s, error: %s", val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -2879,7 +2978,7 @@ bool HloParser::ParseFftType(FftType* result) { } string val = lexer_.GetStrVal(); if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) { - return TokenError(Printf("expects fft type but sees: %s", val.c_str())); + return TokenError(StrFormat("expects fft type but sees: %s", val)); } lexer_.Lex(); return true; @@ -2893,9 +2992,9 @@ bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { string val = lexer_.GetStrVal(); auto status_or_result = StringToFusionKind(val); if (!status_or_result.ok()) { - return TokenError( - Printf("expects fusion kind but sees: %s, error: %s", val.c_str(), - status_or_result.status().error_message().c_str())); + return TokenError(StrFormat("expects fusion kind but sees: %s, error: %s", + val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -2911,8 +3010,25 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { auto status_or_result = StringToRandomDistribution(val); if (!status_or_result.ok()) { return TokenError( - Printf("expects random distribution but sees: %s, error: %s", - val.c_str(), status_or_result.status().error_message().c_str())); + StrFormat("expects random distribution but sees: %s, error: %s", val, + status_or_result.status().error_message())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + +bool 
HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) { + VLOG(1) << "ParsePrecision"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects random distribution"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToPrecision(val); + if (!status_or_result.ok()) { + return TokenError(StrFormat("expects precision but sees: %s, error: %s", + val, + status_or_result.status().error_message())); } *result = status_or_result.ValueOrDie(); lexer_.Lex(); @@ -3006,7 +3122,7 @@ StatusOr<HloSharding> HloParser::ParseShardingOnly() { lexer_.Lex(); OpSharding op_sharding; if (!ParseSharding(&op_sharding)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument("Syntax error:\nExtra content after sharding"); @@ -3018,7 +3134,7 @@ StatusOr<Window> HloParser::ParseWindowOnly() { lexer_.Lex(); Window window; if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument("Syntax error:\nExtra content after window"); @@ -3031,7 +3147,7 @@ HloParser::ParseConvolutionDimensionNumbersOnly() { lexer_.Lex(); ConvolutionDimensionNumbers dnums; if (!ParseConvolutionDimensionNumbers(&dnums)) { - return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument( @@ -3040,37 +3156,83 @@ HloParser::ParseConvolutionDimensionNumbersOnly() { return dnums; } +Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, + string* root_name) { + TF_RET_CHECK(missing_instruction_hook_ == nullptr); + + // The missing instruction hook we register creates the shaped instruction on + // the fly 
as a parameter and returns it. + int64 parameter_count = 0; + missing_instruction_hook_ = + [this, builder, ¶meter_count]( + string name, + const optional<Shape>& shape) -> std::pair<HloInstruction*, LocTy>* { + if (!shape.has_value()) { + Error(lexer_.GetLoc(), + StrCat("Operand ", name, + " had no shape in HLO text; cannot create parameter for " + "single-instruction module.")); + return nullptr; + } + HloInstruction* parameter = builder->AddInstruction( + HloInstruction::CreateParameter(parameter_count++, *shape, name)); + instruction_pool_[name] = {parameter, lexer_.GetLoc()}; + return tensorflow::gtl::FindOrNull(instruction_pool_, name); + }; + + // Prime the lexer. + lexer_.Lex(); + + // Parse the instruction with the registered hook. + if (!ParseInstruction(builder, root_name)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + return Status::OK(); +} + } // namespace StatusOr<std::unique_ptr<HloModule>> ParseHloString( - tensorflow::StringPiece str, const HloModuleConfig& config) { + absl::string_view str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { - return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", parser.GetError()); } return parser.ConsumeHloModule(); } -StatusOr<std::unique_ptr<HloModule>> ParseHloString( - tensorflow::StringPiece str) { +StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) { HloModuleConfig config; return ParseHloString(str, config); } -StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str) { +StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule( + absl::string_view str, absl::string_view name) { + HloModuleConfig config; + HloParser parser(str, config); + auto builder = absl::make_unique<HloComputation::Builder>(string(name)); + string root_name; + TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name)); + std::unique_ptr<HloComputation> computation = 
builder->Build(); + auto module = absl::make_unique<HloModule>(string(name), config); + module->AddEntryComputation(std::move(computation)); + return std::move(module); +} + +StatusOr<HloSharding> ParseSharding(absl::string_view str) { HloModuleConfig config; HloParser parser(str, config); return parser.ParseShardingOnly(); } -StatusOr<Window> ParseWindow(tensorflow::StringPiece str) { +StatusOr<Window> ParseWindow(absl::string_view str) { HloModuleConfig config; HloParser parser(str, config); return parser.ParseWindowOnly(); } StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers( - tensorflow::StringPiece str) { + absl::string_view str) { HloModuleConfig config; HloParser parser(str, config); return parser.ParseConvolutionDimensionNumbersOnly(); diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 3f3a51215e..0c64b50481 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -16,7 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_lexer.h" @@ -32,27 +33,31 @@ namespace xla { // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with the given config. StatusOr<std::unique_ptr<HloModule>> ParseHloString( - tensorflow::StringPiece str, const HloModuleConfig& config); + absl::string_view str, const HloModuleConfig& config); + +// Parses the text for a single HLO operation into an HLO module with a function +// that runs that operation (with the same parameters) as its entry computation. 
+StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule( + absl::string_view str, absl::string_view name = "single_op"); // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with default config. -StatusOr<std::unique_ptr<HloModule>> ParseHloString( - tensorflow::StringPiece str); +StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str); // Parses the result of HloSharding::ToString(), e.g. "{replicated}". -StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str); +StatusOr<HloSharding> ParseSharding(absl::string_view str); // Parses the result of window_util::ToString(const Window&). -StatusOr<Window> ParseWindow(tensorflow::StringPiece str); +StatusOr<Window> ParseWindow(absl::string_view str); // Parses the result of ConvolutionDimensionNumbersToString(), e.g. // "b0f_0io->b0f". StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers( - tensorflow::StringPiece str); + absl::string_view str); // ParseHloString sharding from str. str is supposed to contain the body of the // sharding, i.e. just the rhs of the "sharding={...}" attribute string. -StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str); +StatusOr<HloSharding> ParseSharding(absl::string_view str); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 0d7919346b..ba07ec432e 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -16,17 +16,19 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_parser.h" #include <string> +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { - namespace { -using ::tensorflow::StringPiece; +namespace op = ::xla::testing::opcode_matchers; +using absl::string_view; struct TestData { string test_name; @@ -1049,7 +1051,7 @@ add { ENTRY CRS { input = f32[8]{0} parameter(0) - ROOT crs = f32[8]{0} cross-replica-sum(input), replica_group_ids={}, to_apply=add + ROOT crs = f32[8]{0} cross-replica-sum(input), replica_groups={}, to_apply=add } )" @@ -1067,7 +1069,7 @@ add { ENTRY CrossReplicaSumWithSubgroups { input = f32[128,32]{0,1} parameter(0) - ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add + ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add } )" @@ -1091,7 +1093,19 @@ R"(HloModule AllToAllWithSubgroups ENTRY AllToAllWithSubgroups { input = f32[128,32]{0,1} parameter(0) - ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}, barrier="abc" + ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}} +} + +)" +}, +// collective-permute +{ +"CollectivePermute", +R"(HloModule CollectivePermute + +ENTRY CollectivePermute { + input = f32[128,32]{0,1} parameter(0) + ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} } )" @@ -1102,7 +1116,7 @@ ENTRY AllToAllWithSubgroups { R"(HloModule iota ENTRY Iota { - ROOT iota = f32[100]{0} iota() + ROOT iota = f32[100]{0} iota(), iota_dimension=0 } 
)" @@ -1125,8 +1139,8 @@ ENTRY Computation { class HloParserTest : public ::testing::Test, public ::testing::WithParamInterface<TestData> { protected: - static void ExpectHasSubstr(StringPiece s, StringPiece expected) { - EXPECT_TRUE(tensorflow::str_util::StrContains(s, expected)) + static void ExpectHasSubstr(string_view s, string_view expected) { + EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } @@ -1390,15 +1404,14 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 )"; - ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat( - prefix, ",dim_labels=00_01_10", suffix)) - .status() - .error_message(), - "expects dim labels pattern"); + ExpectHasSubstr( + ParseHloString(absl::StrCat(prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); ExpectHasSubstr( - ParseHloString(tensorflow::strings::StrCat( - prefix, ",dim_labels=010_1100->010", suffix)) + ParseHloString(absl::StrCat(prefix, ",dim_labels=010_1100->010", suffix)) .status() .error_message(), "must have the same rank"); @@ -1722,5 +1735,26 @@ ENTRY nontuple_infeed { "infeed must have a non-empty tuple shape"); } +TEST(HloParserSingleOpTest, SingleOp) { + const string text = + "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, " + "f32[2,4]{1,0} %x)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Parameter(0), op::Parameter(1))); +} + +TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) { + const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)"; + StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text); + ASSERT_TRUE(!module.status().ok()); + LOG(INFO) << "Status: " << module.status(); + EXPECT_THAT( + module.status().ToString(), + 
::testing::HasSubstr("Operand broadcast had no shape in HLO text")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h index 0cddf8fb8f..f1ad0f9b01 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_interface.h +++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h @@ -29,7 +29,7 @@ namespace xla { class HloPassInterface { public: virtual ~HloPassInterface() = default; - virtual tensorflow::StringPiece name() const = 0; + virtual absl::string_view name() const = 0; // Run the pass on the given HLO module. Return whether it modified the // module. diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index d8f1ab916b..6e4ed0de62 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -17,22 +17,23 @@ limitations under the License. #include <functional> +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace xla { - namespace { + +using absl::StrAppend; +using absl::StrCat; + void DumpModuleGraph(const HloModule& module, const string& message) { hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(3) << "HLO " << message << ":"; @@ -48,9 +49,9 @@ void DumpModuleProto(const HloModule& module, const string& dump_to, 
tensorflow::mutex_lock lock(mu); const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; - const string mod_name = SanitizeFileName(tensorflow::strings::Printf( - "module_%04d.%04lld.%s.after_%s", module.unique_id(), pass_number, - pipeline_name.c_str(), pass_name.c_str())); + const string mod_name = SanitizeFileName( + absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(), + pass_number, pipeline_name, pass_name)); TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module), dump_to, mod_name)); @@ -68,7 +69,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) { repeated_field.end()); if (!disabled_passes.empty()) { VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " - << tensorflow::str_util::Join(disabled_passes, ", "); + << absl::StrJoin(disabled_passes, ", "); } auto run_invariant_checkers = [this, @@ -90,7 +91,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) { return Status::OK(); }; - string prefix = std::string(name()) + ": pipeline start"; + string prefix = StrCat(name(), ": pipeline start"); bool changed = false; string message; TF_RETURN_IF_ERROR( @@ -98,12 +99,12 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) { const string xla_dump_per_pass_hlo_proto_to = module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - std::string(name()), "pipeline_start"); + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()), + "pipeline_start"); } for (auto& pass : passes_) { - if (disabled_passes.count(std::string(pass->name())) > 0) { + if (disabled_passes.count(string(pass->name())) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() << ", disabled by --xla_disable_hlo_passes"; continue; @@ -120,8 +121,8 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) { TF_RETURN_IF_ERROR( run_invariant_checkers(StrCat("after running pass: ", pass->name()))); 
if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, - std::string(name()), std::string(pass->name())); + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()), + string(pass->name())); } changed |= changed_this_pass; diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index a42d7e59fe..1d41a4dac1 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -21,7 +21,7 @@ limitations under the License. #include <string> #include <vector> -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,7 +34,7 @@ namespace xla { class HloPassPipeline : public HloPassInterface { public: explicit HloPassPipeline(const string& name) : name_(name) {} - tensorflow::StringPiece name() const override { return name_; } + absl::string_view name() const override { return name_; } // Add a pass to the pipeline. It should be called with the arguments for the // pass constructor: diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc index b9cca13870..c3cacd7ce6 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index cf0be30c7a..569d2e5d2d 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -20,6 +20,10 @@ limitations under the License. #include <set> #include <string> +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" @@ -37,17 +41,13 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::HumanReadableNumBytes; - namespace xla { - namespace { +using ::tensorflow::strings::HumanReadableNumBytes; + // Potential optimizations: // . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue // of candidates. @@ -88,7 +88,7 @@ bool CanBeRematerialized( // Type holding a unique identifier for each Buffer object. using BufferId = int64; -using BufferIdList = tensorflow::gtl::InlinedVector<BufferId, 3>; +using BufferIdList = absl::InlinedVector<BufferId, 3>; // We wrap HloInstruction* with an Item that holds auxiliary // per-instruction state. 
@@ -123,7 +123,7 @@ struct Item { int64 position; }; -using ItemList = tensorflow::gtl::InlinedVector<Item*, 3>; +using ItemList = absl::InlinedVector<Item*, 3>; // Class which maintains an ordered list of instructions with fast insertion // before arbitrary elements. @@ -206,11 +206,10 @@ class InstructionList { Item* to_insert, tensorflow::gtl::ArraySlice<Item*> before_instructions) { VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name() << " before {" - << tensorflow::str_util::Join(before_instructions, ", ", - [](string* out, Item* item) { - tensorflow::strings::StrAppend( - out, item->instruction->name()); - }) + << absl::StrJoin(before_instructions, ", ", + [](string* out, Item* item) { + absl::StrAppend(out, item->instruction->name()); + }) << "}"; // Find the minimal position number of any instruction in @@ -393,10 +392,9 @@ class MemoryUsageTracker { int64 unfinished_user_count; string ToString() const { - return tensorflow::strings::StrCat( - "Buffer ", id, " (defined by ", - defining_instruction->instruction->name(), ", size ", size, - " bytes)"); + return absl::StrCat("Buffer ", id, " (defined by ", + defining_instruction->instruction->name(), ", size ", + size, " bytes)"); } }; @@ -740,29 +738,27 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item, } string MemoryUsageTracker::ToString() const { - string output = tensorflow::strings::StrCat("MemoryUsageTracker for ", - computation_->name(), "\n"); - tensorflow::strings::StrAppend( - &output, "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", - memory_usage(), " bytes)"); + string output = + absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n"); + absl::StrAppend(&output, + "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (", + memory_usage(), " bytes)"); for (auto* item = instruction_list_.first(); item != nullptr; item = instruction_list_.next(item)) { const HloInstruction* instruction = item->instruction; string 
inprogress = item == in_progress_item_ ? " in-progress" : ""; string placed = item->placed ? " placed" : ""; - tensorflow::strings::StrAppend(&output, " ", instruction->name(), - inprogress, placed, "\n Defines:\n"); + absl::StrAppend(&output, " ", instruction->name(), inprogress, placed, + "\n Defines:\n"); for (BufferId buffer_id : item->buffers_defined) { const Buffer& buffer = buffers_[buffer_id]; string live = IsCurrentlyLive(buffer_id) ? " live" : ""; - tensorflow::strings::StrAppend(&output, " ", buffer.ToString(), live, - ", ", buffer.unfinished_user_count, - " unfinished uses\n"); + absl::StrAppend(&output, " ", buffer.ToString(), live, ", ", + buffer.unfinished_user_count, " unfinished uses\n"); } - tensorflow::strings::StrAppend(&output, " Uses:\n"); + absl::StrAppend(&output, " Uses:\n"); for (BufferId buffer_id : item->buffers_used) { - tensorflow::strings::StrAppend(&output, " ", - buffers_[buffer_id].ToString(), "\n"); + absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n"); } } return output; @@ -780,10 +776,9 @@ bool MemoryUsageTracker::Check() const { CHECK(elements_are_unique(defined_buffers)) << "Instruction " << instruction->name() << " does not have unique defined buffers: " - << tensorflow::str_util::Join( + << absl::StrJoin( defined_buffers, ", ", [this](string* out, BufferId buffer_id) { - tensorflow::strings::StrAppend( - out, buffers_.at(buffer_id).ToString()); + absl::StrAppend(out, buffers_.at(buffer_id).ToString()); }); for (const Buffer& buffer : buffers_) { @@ -803,10 +798,9 @@ bool MemoryUsageTracker::Check() const { CHECK(elements_are_unique(used_buffers)) << "Instruction " << instruction->name() << " does not have unique used buffers: " - << tensorflow::str_util::Join( + << absl::StrJoin( used_buffers, ", ", [this](string* out, BufferId buffer_id) { - tensorflow::strings::StrAppend( - out, buffers_.at(buffer_id).ToString()); + absl::StrAppend(out, buffers_.at(buffer_id).ToString()); }); } for (const Buffer& buffer : 
buffers_) { @@ -1209,6 +1203,49 @@ StatusOr<bool> HloRematerialization::Run( VLOG(1) << "HloRematerialization() with memory limit of " << HumanReadableNumBytes(memory_limit_bytes); + XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); + + // Create initial sequence of HLO instructions. + TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( + *module, + [this](const BufferValue& buffer) { + return size_function_(buffer.shape()); + }, + scheduler_algorithm_)); + if (copy_insertion) { + // We run a separate pass of copy elision here because the sequential + // ordering from the HLO schedule allows for more copies to be eliminated. + // TODO(b/80249101): Instead of a separate copy elision pass, use the + // ordering from the HLO schedule directly for copy insertion. + + // First create a copy of the schedule which contains HloInstruction unique + // ids instead of HloInstruction*. This is necessary for updating the + // schedule below. + // TODO(b/113175018): Remove this when the HLO schedule is self-contained + // and can update itself. + tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> + id_sequence = ComputeIdSchedule(*sequence); + + SequentialHloOrdering ordering(module, *sequence); + TF_RETURN_IF_ERROR( + copy_insertion->RemoveUnnecessaryCopies(ordering, module)); + + // RemoveUnnecessaryCopies only considers interference when determining + // whether it is legal to remove a copy. However, copies in the graph may be + // necessary for other reason such as preventing a constant from being live + // out of the graph. So run AddSpecialCaseCopies to re-insert these copies. + // TODO(b/80249101): Break copy insertion into several passes and run each + // one once in the regular HLO pipeline. + TF_RETURN_IF_ERROR(copy_insertion->AddSpecialCaseCopies(module)); + + // The passes above can add and remove copies, update the schedule to + // account for these transformations. 
Newly added instructions will be + // placed ASAP in the schedule. + TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, sequence)); + + TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference( + SequentialHloOrdering(module, *sequence), module)); + } TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); @@ -1230,24 +1267,6 @@ StatusOr<bool> HloRematerialization::Run( << HumanReadableNumBytes(module_output_size) << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes); - XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); - // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( - *module, - [this](const BufferValue& buffer) { - return size_function_(buffer.shape()); - }, - scheduler_algorithm_)); - if (copy_insertion) { - // We run a separate pass of copy elision here because the sequential - // ordering from the HLO schedule allows for more copies to be eliminated. - // TODO(b/80249101): Instead of a separate copy elision pass, use the - // ordering from the HLO schedule directly for copy insertion. - SequentialHloOrdering ordering(module, *sequence); - TF_RETURN_IF_ERROR( - copy_insertion->RemoveUnnecessaryCopies(ordering, module)); - } - // Compute peak memory usage of all computations in the module called in a // sequential context. 
call_graph_ = CallGraph::Build(module); @@ -1334,12 +1353,11 @@ StatusOr<bool> HloRematerialization::Run( XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString()); if (current_peak_memory > memory_limit_bytes) { - LOG(WARNING) << tensorflow::strings::Printf( - "Can't reduce memory use below %s (%lld bytes) by rematerialization; " - "only reduced to %s (%lld bytes)", - HumanReadableNumBytes(memory_limit_bytes).c_str(), memory_limit_bytes, - HumanReadableNumBytes(current_peak_memory).c_str(), - current_peak_memory); + LOG(WARNING) << absl::StrFormat( + "Can't reduce memory use below %s (%d bytes) by rematerialization; " + "only reduced to %s (%d bytes)", + HumanReadableNumBytes(memory_limit_bytes), memory_limit_bytes, + HumanReadableNumBytes(current_peak_memory), current_peak_memory); } return changed; diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index b2725e2918..7bd8a4a544 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -19,9 +19,9 @@ limitations under the License. #include <string> #include <utility> +#include "absl/memory/memory.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -32,7 +32,7 @@ limitations under the License. 
namespace xla { /*static*/ StatusOr<std::unique_ptr<HloModule>> -HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, +HloRunner::CreateModuleFromString(const absl::string_view hlo_string, const DebugOptions& debug_options) { HloModuleConfig config; config.set_debug_options(debug_options); @@ -233,7 +233,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated( int64 device = device_assignment(i, 0); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device)); - streams.push_back(MakeUnique<se::Stream>(executor)); + streams.push_back(absl::make_unique<se::Stream>(executor)); streams.back()->Init(); service_run_options.emplace_back(GetServiceRunOptionsForDevice( device, streams.back().get(), &device_assignment)); @@ -260,7 +260,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated( num_threads += options.num_replicas; } if (num_threads > 0) { - pool = MakeUnique<tensorflow::thread::ThreadPool>( + pool = absl::make_unique<tensorflow::thread::ThreadPool>( tensorflow::Env::Default(), "infeed_outfeed", /*num_threads=*/num_threads); } @@ -291,7 +291,7 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated( VLOG(1) << "Starting outfeed on device " << device; for (int64 step = 1; options.infeed_steps < 0 || step <= options.infeed_steps; ++step) { - auto literal = MakeUnique<Literal>(); + auto literal = absl::make_unique<Literal>(); TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromOutfeed( executor, options.outfeed_shape, literal.get())); if (options.outfeed_values != nullptr) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 65537f07f5..cfc519063e 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -87,8 +87,7 @@ class HloRunner { // Converts an HloModule from the given hlo textual IR string (in // 
HloModule::ToString format). static StatusOr<std::unique_ptr<HloModule>> CreateModuleFromString( - const tensorflow::StringPiece hlo_string, - const DebugOptions& debug_options); + const absl::string_view hlo_string, const DebugOptions& debug_options); // Reads the proto file in xla.HloProto format, creates and returns the // HloModule. diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 27cc5361cd..0fc3b268c0 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include <map> +#include <queue> #include <utility> #include <vector> @@ -28,16 +29,14 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" -using ::tensorflow::strings::HumanReadableNumBytes; - namespace xla { - namespace { +using ::tensorflow::strings::HumanReadableNumBytes; + // Class implementing a list scheduler of HLO instructions which produces a // sequence which minimizes memory usage by preferring to schedule the node that // frees bigger buffer and defines smaller outputs. 
@@ -582,4 +581,187 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation( size_function, nullptr, empty_map); } +tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> +ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence) { + tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> id_sequence; + for (const auto& computation_sequence : sequence) { + for (const HloInstruction* instruction : computation_sequence.second) { + id_sequence[computation_sequence.first].push_back( + instruction->unique_id()); + } + } + return id_sequence; +} + +Status UpdateSchedule( + const HloModule& module, + const tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>& + id_sequence, + SequentialHloOrdering::HloModuleSequence* sequence) { + // Map from unique ID to HloInstruction pointer for instructions in the + // module. + tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction; + // Set of all HloInstructions in the schedule. + tensorflow::gtl::FlatSet<int> ids_in_schedule; + std::vector<HloComputation*> nonfusion_computations = + module.MakeNonfusionComputations(); + for (const HloComputation* computation : nonfusion_computations) { + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK( + id_to_instruction.insert({instruction->unique_id(), instruction}) + .second); + } + for (int id : id_sequence.at(computation)) { + ids_in_schedule.insert(id); + } + } + + // Map from HloInstruction X to newly added instructions (instruction is in + // module, but not in schedule) which use X. If an instruction is not in the + // map, then it has no users which are newly added instructions. + tensorflow::gtl::FlatMap<const HloInstruction*, + std::vector<const HloInstruction*>> + new_instruction_uses; + + // For each newly added instruction, this is the count of the instruction's + // operands that have not yet been scheduled. 
When this value reaches zero, + // then the instruction may be placed in the schedule. + tensorflow::gtl::FlatMap<const HloInstruction*, int> + unscheduled_operand_count; + // For each computation, this is the set of newly added instructions which + // have no operands. These must be handled specially and are added to the + // beginning of the schedule. + tensorflow::gtl::FlatMap<const HloComputation*, + std::vector<const HloInstruction*>> + new_zero_operand_instructions; + for (const HloComputation* computation : nonfusion_computations) { + new_zero_operand_instructions[computation] = {}; + for (const HloInstruction* instruction : computation->instructions()) { + if (ids_in_schedule.count(instruction->unique_id()) == 0) { + // This is a newly added instruction which is not in the schedule. + for (const HloInstruction* operand : instruction->operands()) { + new_instruction_uses[operand].push_back(instruction); + } + if (instruction->operands().empty()) { + new_zero_operand_instructions[computation].push_back(instruction); + } + unscheduled_operand_count[instruction] = instruction->operand_count(); + } + } + } + + // Update the schedule with the newly added instructions, and remove any + // instructions no longer in the graph. + for (const HloComputation* computation : nonfusion_computations) { + std::vector<const HloInstruction*> old_computation_sequence = + std::move(sequence->at(computation)); + sequence->at(computation).clear(); + + // Create a worklist of newly added instructions which are ready to be added + // to the schedule. Initialize worklist with those that have zero operands. + std::queue<const HloInstruction*> worklist; + for (const HloInstruction* instruction : + new_zero_operand_instructions.at(computation)) { + worklist.push(instruction); + } + + // Lambda which schedules all instructions on the worklist. 
+ auto schedule_worklist = [&]() { + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop(); + sequence->at(computation).push_back(instruction); + std::vector<const HloInstruction*>* new_users = + tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); + if (new_users != nullptr) { + // This just-scheduled instruction has users which are newly added to + // the module. Update the number of unscheduled operands and push the + // newly added instruction to the worklist if it is ready to + // schedule. + for (const HloInstruction* new_user : *new_users) { + unscheduled_operand_count.at(new_user)--; + CHECK_GE(unscheduled_operand_count.at(new_user), 0); + if (unscheduled_operand_count.at(new_user) == 0) { + worklist.push(new_user); + } + } + } + } + }; + + schedule_worklist(); + for (int id : id_sequence.at(computation)) { + auto it = id_to_instruction.find(id); + if (it == id_to_instruction.end()) { + // This instruction in the schedule is no longer in the module. + continue; + } + const HloInstruction* instruction = it->second; + worklist.push(instruction); + schedule_worklist(); + } + } + + TF_RETURN_IF_ERROR(VerifySchedule(module, *sequence)); + return Status::OK(); +} + +Status VerifySchedule( + const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& sequence) { + VLOG(2) << "VerifySchedule()"; + XLA_VLOG_LINES(2, module.ToString()); + VLOG(2) << sequence; + + // Verify the set of computations in the sequence is exactly the set of + // computations in the module. 
+ std::vector<HloComputation*> nonfusion_computations = + module.MakeNonfusionComputations(); + TF_RET_CHECK(nonfusion_computations.size() == sequence.size()); + tensorflow::gtl::FlatSet<const HloComputation*> computations_in_module( + module.computations().begin(), module.computations().end()); + for (const auto& computation_sequence : sequence) { + TF_RET_CHECK(computations_in_module.count(computation_sequence.first) == 1); + } + + // For each computation verify the set of instructions is the same and that + // each dependency and control edge is honored. + for (const HloComputation* computation : nonfusion_computations) { + tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position; + int pos = 0; + for (const HloInstruction* instruction : sequence.at(computation)) { + TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) + << "Instruction " << instruction->name() + << " appears more than once in the schedule"; + pos++; + } + + TF_RET_CHECK(instruction_position.size() == + computation->instruction_count()); + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(instruction_position.count(instruction) == 1) + << "Instruction " << instruction->name() << " is not in schedule"; + } + + for (const HloInstruction* instruction : computation->instructions()) { + for (const HloInstruction* operand : instruction->operands()) { + TF_RET_CHECK(instruction_position.at(operand) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its operand " << operand->name(); + } + + for (const HloInstruction* pred : instruction->control_predecessors()) { + TF_RET_CHECK(instruction_position.at(pred) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its control predecessor " + << pred->name(); + } + } + } + + return Status::OK(); +} + } // namespace xla diff --git 
a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h
index 2b33ccc8bf..d06b8d9a5c 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.h
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.h
@@ -85,6 +85,43 @@ StatusOr<std::vector<const HloInstruction*>> ScheduleOneComputation(
     const HloComputation& computation,
     const LogicalBuffer::SizeFunction& size_function);
 
+// Transforms the given schedule such that it is (again) a valid schedule for
+// the module. This is used to update a schedule after the HLO module has been
+// transformed in some way. In general, the only transformations to the module
+// for which a schedule can be updated are the addition or removal of
+// instructions to/from the module. Updating the schedule after new dependencies
+// between existing instructions in the module is not supported and may result
+// in an error status returned.
+//
+// Instructions in the module which also exist in the given schedule will remain
+// in the same order in the updated schedule. Instructions which exist in the
+// module but not in the given schedule will be placed as early as possible in
+// the updated schedule.
+//
+// 'id_sequence' is a mirror of the given schedule 'sequence' but with
+// HloInstruction ids rather than HloInstruction pointers. This should be
+// constructed using ComputeIdSchedule below after the schedule is constructed
+// but before the HLO module is transformed.
+Status UpdateSchedule(
+    const HloModule& module,
+    const tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>>&
+        id_sequence,
+    SequentialHloOrdering::HloModuleSequence* sequence);
+
+// Constructs a copy of the given schedule but with HloInstruction unique ids
+// rather than HloInstruction pointers. This is necessary for updating a
+// schedule as HloInstruction pointers in the schedule may become invalid if
+// instructions are removed from the module. Used by UpdateSchedule above.
+// TODO(b/113175018): Remove this function when HLO schedule is its own class. +tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> +ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence); + +// Verifies that the given schedule is valid for the given module. Specifically, +// the schedule contains exactly the instructions in the module and every +// dependency in the module is satisfied in the schedule. +Status VerifySchedule(const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& sequence); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 9ec983c2bc..930801288a 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -28,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -244,9 +246,9 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); - // HeapSimulator accounts for subcomputations. The max mem doesn't change - // because the while body isn't live during the peak. 
- EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( + // HeapSimulator accounts for subcomputations. The output buffer is aliased, + // so we don't double count. + EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); @@ -350,7 +352,6 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { auto module = CreateNewModule(); const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); - const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); // param != 0 // Needs 17 bytes @@ -408,12 +409,259 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); - // HeapSimulator accounts for subcomputations - EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation( + // HeapSimulator accounts for subcomputations. Cond is the largest one. + // The output buffer of the while is aliased. + EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation( *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } +TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) { + // Updating the schedule of an unchanged HLO module should not affect the + // schedule at all. 
+ const string module_str = R"( +HloModule UpdateScheduleUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> + id_sequence = ComputeIdSchedule(sequence); + std::vector<const HloInstruction*> entry_schedule = sequence.begin()->second; + + EXPECT_EQ(entry_schedule.size(), 6); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(entry_schedule, sequence.begin()->second); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) { + // Add some additional instructions to a module and verify the schedule can be + // updated. 
+ const string module_str = R"( +HloModule UpdateScheduleWithNewInstructions + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> + id_sequence = ComputeIdSchedule(sequence); + + HloComputation* entry = module->entry_computation(); + const Shape shape = entry->root_instruction()->shape(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); + HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, constant, entry->root_instruction())); + entry->set_root_instruction(sub); + + auto in_schedule = [&](const HloInstruction* hlo) { + return std::find(sequence.at(entry).begin(), sequence.at(entry).end(), + hlo) != sequence.at(entry).end(); + }; + + EXPECT_EQ(sequence.at(entry).size(), 6); + EXPECT_FALSE(in_schedule(constant)); + EXPECT_FALSE(in_schedule(sub)); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(entry).size(), 8); + EXPECT_TRUE(in_schedule(constant)); + EXPECT_TRUE(in_schedule(sub)); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) { + // Add and delete some instructions from a module and verify that the schedule + // can be updated successfully. 
+ const string module_str = R"( +HloModule UpdateScheduleWithAddedAndDeletedInstruction + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> + id_sequence = ComputeIdSchedule(sequence); + + // Set the entry root to some expression containing just a parameter and a + // constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); + HloInstruction* new_root = entry->AddInstruction( + HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, + constant, entry->parameter_instruction(0))); + entry->set_root_instruction(new_root); + + // DCE should remove everything but the parameters and the newly added code. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(sequence.at(entry).size(), 6); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(entry).size(), 4); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) { + // Completely replace a module with an entirely new set of instructions and + // verify that the schedule can be updated successfully. 
+ const string module_str = R"( +HloModule UpdateScheduleWithCompletelyReplacedModule + +ENTRY main { + a = f32[] constant(42.0) + b = f32[] constant(123.0) + ROOT sum = f32[] add(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> + id_sequence = ComputeIdSchedule(sequence); + + // Replace the entry computation with the negation of a constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); + HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNegate, constant)); + entry->set_root_instruction(new_root); + + // DCE the old instructions. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(sequence.at(entry).size(), 3); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(entry).size(), 2); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) { + // Create changes to more than one computation in an HLO module and verify + // that the schedule can be updated. 
+ const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + tensorflow::gtl::FlatMap<const HloComputation*, std::vector<int>> + id_sequence = ComputeIdSchedule(sequence); + + const HloInstruction* xla_while = + module->entry_computation()->root_instruction()->operand(0); + HloComputation* body = xla_while->while_body(); + HloComputation* cond = xla_while->while_condition(); + + // Negate the root of the cond. 
+ cond->set_root_instruction(cond->AddInstruction( + HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kNot, cond->root_instruction()))); + + // Replace the body with a computation which just passes through its + // parameter. + body->set_root_instruction(body->parameter_instruction(0)); + + // DCE the dead code in the body. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(sequence.at(body).size(), 7); + EXPECT_EQ(sequence.at(cond).size(), 4); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(body).size(), 1); + EXPECT_EQ(sequence.at(cond).size(), 5); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 0cba9ebbcb..980dae07ce 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -15,13 +15,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrCat; +using absl::StrCat; +using absl::StrJoin; HloSharding HloSharding::AssignDevice(int64 device_id) { return HloSharding(device_id); @@ -71,12 +72,9 @@ HloSharding HloSharding::SingleTuple(const Shape& tuple_shape, const HloSharding& sharding) { CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); CHECK(!sharding.IsTuple()) << sharding.ToString(); - int64 leaf_count = ShapeUtil::GetLeafCount(tuple_shape); + int64 leaf_count = RequiredLeaves(tuple_shape); std::vector<HloSharding> flattened_list; - flattened_list.reserve(leaf_count); - for (int64 i = 0; i < leaf_count; ++i) { - flattened_list.push_back(sharding); - } + flattened_list.resize(leaf_count, sharding); return HloSharding(flattened_list); } @@ -92,7 +90,7 @@ string HloSharding::ToString() const { for (const HloSharding& element : tuple_elements_) { parts.push_back(element.ToString()); } - return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); + return StrCat("{", absl::StrJoin(parts, ", "), "}"); } if (replicated_) { @@ -101,8 +99,8 @@ string HloSharding::ToString() const { return StrCat( "{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}"); } else { - return StrCat("{devices=[", Join(tile_assignment_.dimensions(), ","), "]", - Join(tile_assignment_, ","), "}"); + return StrCat("{devices=[", StrJoin(tile_assignment_.dimensions(), ","), + "]", StrJoin(tile_assignment_, ","), "}"); } } @@ -244,16 +242,16 @@ StatusOr<HloSharding> HloSharding::GetTupleSharding(const Shape& shape) const { return Tuple(ShapeTree<HloSharding>(shape, *this)); } -tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const { +absl::optional<int64> 
HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } - tensorflow::gtl::optional<int64> unique_device; + absl::optional<int64> unique_device; for (auto& tuple_sharding : tuple_elements_) { auto device = tuple_sharding.UniqueDevice(); if (!device || (unique_device && *device != *unique_device)) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } unique_device = device; } @@ -262,7 +260,7 @@ tensorflow::gtl::optional<int64> HloSharding::UniqueDevice() const { if (!replicated_ && maximal_) { return static_cast<int64>(*tile_assignment_.begin()); } - return tensorflow::gtl::nullopt; + return absl::nullopt; } int64 HloSharding::GetUniqueDevice() const { @@ -439,14 +437,13 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, : sub_shape_tree.element(ShapeIndex({})); } -tensorflow::gtl::optional<HloSharding> HloSharding::ExtractSingleSharding() - const { +absl::optional<HloSharding> HloSharding::ExtractSingleSharding() const { if (!IsTuple()) { return *this; } for (int64 i = 1; i < tuple_elements_.size(); ++i) { if (tuple_elements_[0] != tuple_elements_[i]) { - return tensorflow::gtl::optional<HloSharding>(); + return absl::nullopt; } } return tuple_elements_.front(); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 894783e5d1..be51c3f55b 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -151,7 +151,7 @@ class HloSharding { // span a single device, the return value will be empty. // In order for a sharding to span a single device, every leaf sharding must // be maximal and not replicated, and the used device must match. - tensorflow::gtl::optional<int64> UniqueDevice() const; + absl::optional<int64> UniqueDevice() const; // Retrieves the unique device or fails with a CHECK. 
int64 GetUniqueDevice() const; @@ -182,7 +182,7 @@ class HloSharding { // be returned. If it is a tuple, and all the tuple elements are common, the // common element will be returned. Otherwise the optional will contain no // value. - tensorflow::gtl::optional<HloSharding> ExtractSingleSharding() const; + absl::optional<HloSharding> ExtractSingleSharding() const; bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && @@ -260,9 +260,9 @@ class HloSharding { bool maximal_; bool tuple_; Array<int64> tile_assignment_; - // Only non-empty when tuple_ is true, but because empty tuples are allowed - // may also be empty even then. This is a flattened list of all the leaf - // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order). + // Only non-empty when tuple_ is true. If a tuple is empty then one entry is + // present for the root. This is a flattened list of all the leaf shardings in + // a tuple shape, by pre-order walk (ShapeTree iterator order). std::vector<HloSharding> tuple_elements_; }; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index a2c1d39d0d..6e9b96488c 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -23,6 +24,23 @@ namespace xla { namespace { +// AssignmentKind and kUnassignedDevice are used during tuple domain sharding +// propagation in order to distinguish among three cases: +// kUnassigned: no assignment has occurred +// kAssigned: at least an assignment has occurred +// kConflict: no assignment has occurred because of conflicting propagations, +// which occurs when multiple users of an instruction have different +// shardings. +enum class AssignmentKind { kUnassigned, kAssigned, kConflict }; + +// kUnassignedDevice can only be assigned to tuple leaf shardings to indicate +// absence of sharding information for that particular sub-sharding during +// sharding propagation. It is used to be able to express tuple shardings with +// partial information. At the end of the propagation the sharding of +// tuple-shaped instructions using kUnassignedDevice's is cleared. +// TODO(b/112883246): Centralized enum of reserved devices. +constexpr int64 kUnassignedDevice = -2; + struct PassThrough { PassThrough(HloInstruction* user, HloInstruction* operand) : user(user), operand(operand) {} @@ -117,13 +135,17 @@ Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, return Status::OK(); } -std::unique_ptr<HloSharding> CloneShardingForDomain( - const HloSharding& sharding) { - auto single_sharding = sharding.ExtractSingleSharding(); +// For tuple shardings if every element have the same sharsing then we want to +// treat them as single element sharsings to insert less domain separation as a +// domain can prevent some optimizations and we want to minimize that from +// happening. 
+std::shared_ptr<const HloSharding> CloneShardingForDomain( + std::shared_ptr<const HloSharding> sharding) { + auto single_sharding = sharding->ExtractSingleSharding(); if (!single_sharding) { - return MakeUnique<HloSharding>(sharding); + return sharding; } - return MakeUnique<HloSharding>(*single_sharding); + return std::make_shared<const HloSharding>(*single_sharding); } Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, @@ -142,108 +164,174 @@ Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, return Status::OK(); } -// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree. -// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate() -// sharding will be returned. -ShapeTree<HloSharding> GetTupleSharding(HloInstruction* tuple) { - if (tuple->has_sharding()) { - return tuple->sharding().GetAsShapeTree(tuple->shape()); +// Return the ShapeTree<HloSharding> of the user argument. The user argument +// is assumed to be a user of the instruction argument. +// If user is a tuple instruction, return the tuple subsharding corresponding to +// the operand matching the instruction argument, because that is the +// subsharding corresponding to instruction. +ShapeTree<HloSharding> GetShardingTreeFromUser( + const HloInstruction& instruction, const HloInstruction& user) { + if (user.opcode() == HloOpcode::kTuple) { + return user.sharding() + .GetSubSharding(user.shape(), {user.operand_index(&instruction)}) + .GetAsShapeTree(instruction.shape()); + } + return user.sharding().GetAsShapeTree(user.shape()); +} + +// Assign rhs to lhs. If rhs is unassigned (assigned to kUnassignedDevice) +// then no assignment is made. Therefore kUnassignedDevice is never propagated. +// kConflict is returned if lhs is already assigned and rhs is assigned to a +// different device. 
+StatusOr<AssignmentKind> AssignLeafSharding(HloSharding* lhs, + const HloSharding& rhs) { + TF_RET_CHECK(!lhs->IsTuple() && !rhs.IsTuple()); + if (rhs.UsesDevice(kUnassignedDevice)) { + return AssignmentKind::kUnassigned; + } + if (lhs->UsesDevice(kUnassignedDevice)) { + *lhs = rhs; + return AssignmentKind::kAssigned; + } + return lhs->UniqueDevice() != rhs.UniqueDevice() + ? AssignmentKind::kConflict + : AssignmentKind::kUnassigned; +} + +// Assigns the whole rhs tree to lhs_tree, starting at lhs_it. +// In case of conflicting assignment AssignmentKind::kConflict is returned. In +// this case lhs_tree is partially assigned, up to the conflicting leaf. It is +// up to the caller to discard the partial assignment in case of conflict. +StatusOr<AssignmentKind> AssignTreeSharding( + ShapeTree<HloSharding>* lhs_tree, ShapeTree<HloSharding>::iterator lhs_it, + const ShapeTree<HloSharding>& rhs_tree) { + AssignmentKind assigned = AssignmentKind::kUnassigned; + auto rhs_it = rhs_tree.begin(); + for (; lhs_it != lhs_tree->end() && rhs_it != rhs_tree.end(); + ++lhs_it, ++rhs_it) { + // TODO(b/112885211): Add ShapeTree::IsLeaf(const ShapeTreeIterator &it) + if (rhs_tree.IsLeaf(rhs_it->first)) { + TF_RET_CHECK(lhs_tree->IsLeaf(lhs_it->first)); + TF_ASSIGN_OR_RETURN(AssignmentKind sub_assigned, + AssignLeafSharding(&lhs_it->second, rhs_it->second)); + if (sub_assigned == AssignmentKind::kConflict) { + // In case of conflict we return conflict to the caller. At this point + // partial assignments to lhs_tree may have been made already. It is up + // to the caller to discard the partial assignment in case of conflict. + return AssignmentKind::kConflict; + } else if (sub_assigned == AssignmentKind::kAssigned) { + assigned = sub_assigned; + } + } } - return ShapeTree<HloSharding>(tuple->shape(), HloSharding::Replicate()); + TF_RET_CHECK(rhs_it == rhs_tree.end()); + return assigned; } -// Retrieves the sharding of operand, asked from a user instruction which is -// within domain. 
If operand is a kDomain, it means that sharding argument is -// the operand sharding, otherwise the operand's own sharding will be returned. -const HloSharding* GetOperandSharding(const HloInstruction* operand, +StatusOr<bool> ApplyShardingFromUsers(HloInstruction* instruction, const DomainMetadata::Domain& domain, - const HloSharding& sharding) { - // Here the user of operand is within the domain instruction set, and since it - // is user of operand, we need to look into the enter_domains set. If this is - // not a kDomain within the user domains set, then return the operand - // sharding, if any. - if (operand->opcode() != HloOpcode::kDomain || - domain.enter_domains.count(const_cast<HloInstruction*>(operand)) == 0) { - return operand->has_sharding() ? &operand->sharding() : nullptr; + const HloSharding& domain_sharding) { + if (instruction->users().empty()) { + // No sharding from users, use domain_sharding, after checking + // compatibility. + TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()) && + ShapeUtil::GetLeafCount(instruction->shape()) == + domain_sharding.tuple_elements().size()); + instruction->set_sharding(domain_sharding); + return true; + } + AssignmentKind assigned = AssignmentKind::kUnassigned; + // The sharding_tree leaves are initialized to kUnassignedDevice. Only Tuple + // subshardings can result in a final sharding assignment containing + // kUnassignedDevice leaves, in case some tuple indexes are not used, or are + // used by users that don't have a sharding. + // Non-tuple shardings are either assigned to a real sharding, or are not + // assigned at all. As such they will never get assigned to kUnassignedDevice. + // In any case, kUnassignedDevice is never propagated, from the implementation + // of AssignLeafSharding. 
+ ShapeTree<HloSharding> sharding_tree( + instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice)); + for (HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kDomain && + domain.exit_domains.count(const_cast<HloInstruction*>(user)) > 0) { + // If a user is a domain and it is registered in the domain exits, then + // the instruction sharding is taken directly from the domain, and no + // further users need to be visited. + instruction->set_sharding(domain_sharding); + return true; + } + if (!user->has_sharding()) { + continue; + } + AssignmentKind sub_assigned = AssignmentKind::kUnassigned; + ShapeTree<HloSharding> user_sharding_tree = + GetShardingTreeFromUser(*instruction, *user); + if (ShapeUtil::IsTuple(instruction->shape())) { + // For tuple-shaped instructions collect individual tuple subshardings + // from the uses, and then combine them into the tuple sharding. + // If the user is a GTE its sharding concerns only the subtree of + // sharding_tree at index user->tuple_index, otherwise the whole + // sharding_tree is affected. + ShapeTree<HloSharding>::iterator sharding_tree_begin = + user->opcode() == HloOpcode::kGetTupleElement + ? sharding_tree.find({user->tuple_index()}) + : sharding_tree.begin(); + TF_ASSIGN_OR_RETURN( + sub_assigned, AssignTreeSharding(&sharding_tree, sharding_tree_begin, + user_sharding_tree)); + } else { + // Non-tuple shape: assign common users sharding. + TF_RET_CHECK(user_sharding_tree.leaf_count() == 1) + << "Expected non-tuple user sharding"; + TF_ASSIGN_OR_RETURN( + sub_assigned, + AssignTreeSharding(&sharding_tree, sharding_tree.begin(), + user_sharding_tree)); + } + + if (sub_assigned == AssignmentKind::kConflict) { + // In case of conflict we don't assign any sharding. 
+ return false; + } else if (sub_assigned == AssignmentKind::kAssigned) { + assigned = sub_assigned; + } + } + + if (assigned == AssignmentKind::kAssigned) { + if (ShapeUtil::IsTuple(instruction->shape())) { + instruction->set_sharding(HloSharding::Tuple(sharding_tree)); + } else { + TF_RET_CHECK(sharding_tree.leaf_count() == 1); + instruction->set_sharding(sharding_tree.leaf_begin()->second); + } + return true; } - // At this point operand is a kDomain of the currently processed domain, so we - // can refer to sharding as the domain sharding. - return &sharding; + return false; } // Tries to propagate the sharding information into the instructions that are -// part of the domain, in a post order manner (operand propagate to user). +// part of the domain, in a reverse post order manner (users propoagate to +// instruction). StatusOr<int64> ApplyDomainShardingPass(const DomainMetadata::Domain& domain, - const HloSharding& sharding) { + const HloSharding& domain_sharding) { int64 assigned = 0; - for (HloInstruction* instruction : domain.instructions) { + // domain.instructions are ordered in a post-order manner. As we do + // user->operand propagation we process instructions in reverse order. In so + // doing we are guaranteed to process all users before their operands. 
+ for (auto it = domain.instructions.rbegin(); it != domain.instructions.rend(); + ++it) { + HloInstruction* instruction = *it; if (instruction->has_sharding()) { continue; } - if (instruction->opcode() == HloOpcode::kGetTupleElement) { - HloInstruction* tuple = instruction->mutable_operand(0); - const HloSharding* tuple_sharding = - GetOperandSharding(tuple, domain, sharding); - if (tuple_sharding != nullptr) { - if (tuple_sharding->IsTuple()) { - HloSharding sub_sharding = tuple_sharding->GetSubSharding( - tuple->shape(), {instruction->tuple_index()}); - VLOG(4) << " " << instruction->name() << " to sharding " - << sub_sharding; - instruction->set_sharding(sub_sharding); - } else { - SetSingleSharding(instruction, *tuple_sharding); - } - ++assigned; - } - } else if (instruction->opcode() == HloOpcode::kTuple) { - int64 tuple_assigned = 0; - ShapeTree<HloSharding> shape_tree = GetTupleSharding(instruction); - for (int64 i = 0; i < instruction->operand_count(); ++i) { - const HloSharding* operand_sharding = - GetOperandSharding(instruction->operand(i), domain, sharding); - if (operand_sharding != nullptr) { - HloSharding operand_subsharding = HloSharding::Replicate(); - if (operand_sharding == &sharding) { - operand_subsharding = - sharding.GetSubSharding(instruction->shape(), {i}); - operand_sharding = &operand_subsharding; - } - if (shape_tree.element({i}) != *operand_sharding) { - *shape_tree.mutable_element({i}) = *operand_sharding; - ++tuple_assigned; - } - } - } - if (tuple_assigned > 0) { - HloSharding tuple_sharding = HloSharding::Tuple(shape_tree); - VLOG(4) << " " << instruction->name() << " to sharding " - << tuple_sharding; - instruction->set_sharding(tuple_sharding); - ++assigned; - } - } else { - // If all the operand of the given instruction has the same single device - // assignment, assign that device to this instruction as well. 
- const HloSharding* common_sharding = nullptr; - for (const HloInstruction* operand : instruction->operands()) { - const HloSharding* operand_sharding = - GetOperandSharding(operand, domain, sharding); - if (operand_sharding != nullptr) { - if (common_sharding != nullptr && - *common_sharding != *operand_sharding) { - common_sharding = nullptr; - break; - } - common_sharding = operand_sharding; - } - } - if (common_sharding != nullptr) { - VLOG(4) << " " << instruction->name() << " to sharding " - << *common_sharding; - instruction->set_sharding(*common_sharding); - ++assigned; - } + // Take the sharding from the users. + TF_ASSIGN_OR_RETURN( + bool instruction_assigned, + ApplyShardingFromUsers(instruction, domain, domain_sharding)); + if (instruction_assigned) { + ++assigned; + VLOG(4) << " " << instruction->name() << " to sharding " + << instruction->sharding(); } } return assigned; @@ -261,83 +349,40 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, return ApplyDomainSingleSharding(domain, *single_sharding); } VLOG(1) << "Assigning non-trivial sharding " << sharding; - for (;;) { - TF_ASSIGN_OR_RETURN(int64 assigned, - ApplyDomainShardingPass(domain, sharding)); - if (assigned == 0) { - break; - } - } + TF_RETURN_IF_ERROR(ApplyDomainShardingPass(domain, sharding).status()); + int64 unassigned = 0; for (HloInstruction* instruction : domain.instructions) { if (!instruction->has_sharding()) { LOG(WARNING) << "Unassigned instruction: " << instruction->ToString(); ++unassigned; + } else { + // Un-set sharding of tuples whose sub-sgardings are assigned to + // kUnassignedDevice. Indeed in case of doubt it is better to leave the + // entire tuple unassigned, and let the device placer decide for it. 
+ if (instruction->sharding().UsesDevice(kUnassignedDevice)) { + TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape())) + << "Only tuples can have kUnassignedDevice sub shardings"; + instruction->clear_sharding(); + } } } // Should we error out if unassigned > 0? return Status::OK(); } -// Creates a kDomain instruction to be placed between instruction and operand. -// The kDomain instruction will be created only if the sharding differ between -// the instruction and the operand. -std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction, - HloInstruction* operand) { - const HloSharding* instruction_sharding = - instruction->has_sharding() ? &instruction->sharding() : nullptr; - const HloSharding* operand_sharding = - operand->has_sharding() ? &operand->sharding() : nullptr; - // No need for domain if they both have no sharding. - if (instruction_sharding == nullptr && operand_sharding == nullptr) { - return nullptr; - } - // No need for domain if they match. - if (instruction_sharding != nullptr && operand_sharding != nullptr && - ShardingMatches(*instruction_sharding, *operand_sharding)) { - return nullptr; - } - std::unique_ptr<HloSharding> real_instruction_sharding; - std::unique_ptr<HloSharding> real_operand_sharding; - if (instruction_sharding != nullptr) { - real_instruction_sharding = CloneShardingForDomain(*instruction_sharding); - } - if (operand_sharding != nullptr) { - real_operand_sharding = CloneShardingForDomain(*operand_sharding); - } - VLOG(3) << "Creating domain:"; - VLOG(3) << " Instruction: " << instruction->name(); - VLOG(3) << " Operand: " << operand->name(); - VLOG(3) << " User side sharding: " - << (real_instruction_sharding != nullptr - ? real_instruction_sharding->ToString() - : "None"); - VLOG(3) << " Operand side sharding: " - << (real_operand_sharding != nullptr - ? 
real_operand_sharding->ToString() - : "None"); - - std::unique_ptr<DomainMetadata> operand_side_metadata = - MakeUnique<ShardingMetadata>(std::move(real_operand_sharding)); - std::unique_ptr<DomainMetadata> user_side_metadata = - MakeUnique<ShardingMetadata>(std::move(real_instruction_sharding)); - return HloInstruction::CreateDomain(operand->shape(), operand, - std::move(operand_side_metadata), - std::move(user_side_metadata)); -} - -StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding( +StatusOr<std::shared_ptr<const HloSharding>> ExtractOriginalCommonSharding( tensorflow::gtl::ArraySlice<HloInstruction*> instructions) { // If we are here, all the instructions being passed had the same sharding // (or no sharding), by the means of the ShardingMatches() API. // As such, no kDomain was inserted, and here we are asked to extract the // original common sharding. // All the instructions passed to this API are part of the same computation. - const HloSharding* sharding = nullptr; + std::shared_ptr<const HloSharding> sharding; for (HloInstruction* instruction : instructions) { if (instruction->has_sharding()) { if (sharding == nullptr) { - sharding = &instruction->sharding(); + sharding = instruction->sharding_ptr(); } else { TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding())) << "Sharding " << *sharding << " does not match the one in " @@ -346,10 +391,10 @@ StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding( } } if (sharding == nullptr) { - return std::unique_ptr<HloSharding>(); + return std::shared_ptr<const HloSharding>(); } VLOG(4) << "Extracted sharding is " << *sharding; - return CloneShardingForDomain(*sharding); + return CloneShardingForDomain(sharding); } } // namespace @@ -357,9 +402,9 @@ StatusOr<std::unique_ptr<HloSharding>> ExtractOriginalCommonSharding( std::unique_ptr<DomainMetadata> ShardingMetadata::Clone() const { std::unique_ptr<HloSharding> sharding; if (sharding_ != nullptr) { - sharding = 
MakeUnique<HloSharding>(*sharding_); + sharding = absl::make_unique<HloSharding>(*sharding_); } - return MakeUnique<ShardingMetadata>(std::move(sharding)); + return absl::make_unique<ShardingMetadata>(std::move(sharding)); } bool ShardingMetadata::Matches(const DomainMetadata& other) const { @@ -403,7 +448,7 @@ Status ShardingMetadata::NormalizeShardingDomain( TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding)); } } else { - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloSharding> sharding, + TF_ASSIGN_OR_RETURN(std::shared_ptr<const HloSharding> sharding, ExtractOriginalCommonSharding(domain.instructions)); if (sharding != nullptr) { VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString(); @@ -415,9 +460,75 @@ Status ShardingMetadata::NormalizeShardingDomain( return Status::OK(); } -std::unique_ptr<HloInstruction> CreateShardingDomain( - HloInstruction* instruction, HloInstruction* operand) { - return CreateDomain(instruction, operand); +// Creates a kDomain instruction to be placed between instruction and operand. +// The kDomain instruction will be created only if the sharding differ between +// the instruction and the operand. +HloInstruction* ShardingDomainCreator::operator()(HloInstruction* instruction, + HloInstruction* root, + HloInstruction* operand) { + auto instruction_sharding = instruction->sharding_ptr(); + auto root_sharding = root->sharding_ptr(); + // No need for domain if they both have no sharding. + if (instruction_sharding == nullptr && root_sharding == nullptr) { + return nullptr; + } + // No need for domain if they match. 
+ if (instruction_sharding != nullptr && root_sharding != nullptr && + ShardingMatches(*instruction_sharding, *root_sharding)) { + return nullptr; + } + + if (instruction_sharding != nullptr) { + instruction_sharding = CloneShardingForDomain(instruction_sharding); + } + if (root_sharding != nullptr) { + root_sharding = CloneShardingForDomain(root_sharding); + } + + auto it = domain_cse_map_.find({operand, instruction_sharding}); + if (it != domain_cse_map_.end()) { + return it->second; + } + + VLOG(3) << "Creating domain:"; + VLOG(3) << " Instruction: " << instruction->name(); + VLOG(3) << " Operand: " << operand->name(); + VLOG(3) << " User side sharding: " + << (instruction_sharding != nullptr ? instruction_sharding->ToString() + : "None"); + VLOG(3) << " Operand side sharding: " + << (root_sharding != nullptr ? root_sharding->ToString() : "None"); + + HloInstruction* domain = + operand->parent()->AddInstruction(HloInstruction::CreateDomain( + operand->shape(), operand, + absl::make_unique<ShardingMetadata>(root_sharding), + absl::make_unique<ShardingMetadata>(instruction_sharding))); + domain_cse_map_.emplace(DomainCseMapKey{operand, instruction_sharding}, + domain); + return domain; +} + +bool ShardingDomainCreator::DomainCseMapKey::operator==( + const ShardingDomainCreator::DomainCseMapKey& other) const { + if (instruction != other.instruction) { + return false; + } + if (sharding == nullptr && other.sharding == nullptr) { + return true; + } + if (sharding == nullptr || other.sharding == nullptr) { + return false; + } + return *sharding == *other.sharding; +} + +size_t ShardingDomainCreator::DomainCseMapHasher::operator()( + const ShardingDomainCreator::DomainCseMapKey& key) const { + return tensorflow::Hash64Combine( + std::hash<const HloInstruction*>{}(key.instruction), + key.sharding ? 
key.sharding->Hash() + : static_cast<size_t>(0x297814aaad196e6dULL)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index 5e01fc0e22..7a6b0d9abc 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -27,12 +27,12 @@ namespace xla { // A DomainMetadata implementation that internally wraps a sharding attribute. class ShardingMetadata : public DomainMetadata { public: - explicit ShardingMetadata(std::unique_ptr<HloSharding> sharding) + explicit ShardingMetadata(std::shared_ptr<const HloSharding> sharding) : sharding_(std::move(sharding)) {} std::unique_ptr<DomainMetadata> Clone() const override; - tensorflow::StringPiece Kind() const override { return KindName(); } + absl::string_view Kind() const override { return KindName(); } bool Matches(const DomainMetadata& other) const override; @@ -40,7 +40,7 @@ class ShardingMetadata : public DomainMetadata { const HloSharding* sharding() const { return sharding_.get(); } - static tensorflow::StringPiece KindName() { return "sharding"; } + static absl::string_view KindName() { return "sharding"; } static StatusOr<const ShardingMetadata*> ToShardingMetadata( const DomainMetadata* metadata); @@ -55,15 +55,33 @@ class ShardingMetadata : public DomainMetadata { const DomainMetadata* metadata); private: - std::unique_ptr<HloSharding> sharding_; + std::shared_ptr<const HloSharding> sharding_; }; -// Given an HLO graph edge between instruction and one of its operands, creates -// a ShardingMetadata based kDomain instruction if the sharding between -// instruction and operand changes. Returns nullptr if there is no need for a -// domain separation. 
-std::unique_ptr<HloInstruction> CreateShardingDomain( - HloInstruction* instruction, HloInstruction* operand); +// If the sharding between root and instruction changes then returns a +// ShardingMetadata based kDomain instruction what can be used to separate +// operand and instruction. +// Returns nullptr if there is no need for a domain separation. +class ShardingDomainCreator { + public: + HloInstruction* operator()(HloInstruction* instruction, HloInstruction* root, + HloInstruction* operand); + + private: + // Map from instruction and user sharding to domain users to CSE identical + // domains. + struct DomainCseMapKey { + const HloInstruction* instruction; + std::shared_ptr<const HloSharding> sharding; + + bool operator==(const DomainCseMapKey& other) const; + }; + struct DomainCseMapHasher { + size_t operator()(const DomainCseMapKey& key) const; + }; + std::unordered_map<DomainCseMapKey, HloInstruction*, DomainCseMapHasher> + domain_cse_map_; +}; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 45fc300fca..2341f8ada0 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -115,6 +115,13 @@ TEST_F(HloShardingTest, Tile) { } } +// Tests that empty tuple is supported. 
+TEST_F(HloShardingTest, EmptySingleTuple) { + HloSharding sharding = HloSharding::SingleTuple(ShapeUtil::MakeTupleShape({}), + HloSharding::AssignDevice(0)); + EXPECT_TRUE(sharding.ExtractSingleSharding()); +} + TEST_F(HloShardingTest, NestedTuple) { // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6]) Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({ diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h index 2ef38821af..d1cf644f82 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h @@ -24,7 +24,7 @@ namespace xla { // one arbitrarily to use and delete the others. class HloSubcomputationUnification : public HloPassInterface { public: - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "subcomputation-unification"; } diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index b78bfa0cdf..4876533449 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -21,28 +23,25 @@ limitations under the License. 
#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" - -using ::tensorflow::GraphDef; -using ::tensorflow::NodeDef; -using ::tensorflow::TensorShapeProto; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; -using ::tensorflow::str_util::Join; namespace xla { namespace hlo_graph_dumper { namespace { +using absl::StrAppend; +using absl::StrCat; +using tensorflow::GraphDef; +using tensorflow::NodeDef; +using tensorflow::TensorShapeProto; + string GetOpDefName(const HloInstruction* instruction) { string name = StrCat("hlo-", HloOpcodeString(instruction->opcode())); - tensorflow::str_util::TitlecaseString(&name, "-"); + tensorflow::str_util::TitlecaseString(&name, "-"); // non-absl ok name.erase(std::remove(name.begin(), name.end(), '-'), name.end()); if (instruction->opcode() == HloOpcode::kFusion) { string fusion_name = ToString(instruction->fusion_kind()); - StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1)); + StrAppend(&name, absl::string_view(fusion_name).substr(1)); } return name; } @@ -166,7 +165,9 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); } else { layout_string = StrCat( - "{", Join(LayoutUtil::MinorToMajor(instruction->shape()), ","), "}"); + "{", + absl::StrJoin(LayoutUtil::MinorToMajor(instruction->shape()), ","), + "}"); } attrs["layout"].set_s(layout_string); } diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 7fd99fc930..e0c1326177 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -18,8 +18,10 @@ limitations under the License. 
#include <algorithm> #include <utility> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -30,16 +32,13 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; +using absl::StrAppend; +using absl::StrCat; const Shape& HloPosition::shape() const { return ShapeUtil::GetSubshape(instruction->shape(), index); @@ -216,10 +215,11 @@ void HloValueSet::SortAndUniquifyValues() { } string HloValueSet::ToString() const { - return StrCat("HloValueSet: ", - Join(values_, ", ", [](string* result, const HloValue* value) { - result->append(value->ToShortString()); - })); + return StrCat( + "HloValueSet: ", + absl::StrJoin(values_, ", ", [](string* result, const HloValue* value) { + result->append(value->ToShortString()); + })); } bool HloValueSet::AssignUnionOf( diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index ac1a663633..f1b29c2559 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,11 +15,13 @@ limitations under the License. 
#include <set> +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -115,6 +117,11 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { ShapeInference::InferAllToAllTupleShape(operand_shapes)); } +Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { + return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( + hlo->operand(0)->shape())); +} + Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), @@ -122,39 +129,32 @@ Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { reduce_precision->mantissa_bits())); } -namespace { - -Status CheckIsTokenOperand(const HloInstruction* instruction, - int64 operand_no) { +Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no) { const HloInstruction* token = instruction->operand(operand_no); if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) { return InternalError( - "Expected operand %lld to be token-shaped, actual shape is " + "Expected operand %d to be token-shaped, actual shape is " "%s:\n%s", - operand_no, ShapeUtil::HumanString(token->shape()).c_str(), - instruction->ToString().c_str()); + operand_no, StringifyShape(token->shape()), instruction->ToString()); } return Status::OK(); } -Status CheckOperandAndParameter(const HloInstruction* instruction, - int64 operand_number, - const HloComputation* computation, - int64 parameter_number) { +Status 
ShapeVerifier::CheckOperandAndParameter( + const HloInstruction* instruction, int64 operand_number, + const HloComputation* computation, int64 parameter_number) { const HloInstruction* operand = instruction->operand(operand_number); const HloInstruction* parameter = computation->parameter_instruction(parameter_number); - if (!ShapeUtil::Compatible(operand->shape(), parameter->shape())) { + if (!ShapesSame(operand->shape(), parameter->shape())) { return InternalError("Operand %s shape does not match parameter's %s in %s", - operand->ToString().c_str(), - parameter->ToString().c_str(), - instruction->ToString().c_str()); + operand->ToString(), parameter->ToString(), + instruction->ToString()); } return Status::OK(); } -} // namespace - Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); @@ -171,22 +171,16 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { // Outfeed has a separate shape field for the value which is outfed to the // host. The shape of the instruction itself is always a token. 
- if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), - outfeed->operand(0)->shape())) { + if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) { return InternalError( - "Expected outfeed shape to be compatible with operand's shape %s, " + "Expected outfeed shape to be equal to operand's shape %s, " "actual shape is %s:\n%s", - ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), - ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(), - outfeed->ToString().c_str()); + StringifyShape(outfeed->operand(0)->shape()), + StringifyShape(outfeed->outfeed_shape()), outfeed->ToString()); } return CheckShape(outfeed, ShapeUtil::MakeTokenShape()); } -Status ShapeVerifier::HandleHostCompute(HloInstruction*) { - return Status::OK(); -} - bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1, const Shape& result_shape) { @@ -200,7 +194,7 @@ bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, Status ShapeVerifier::HandleRng(HloInstruction* instruction) { if (instruction->operand_count() != 2) { return InternalError("Expected two operands for Rng instruction: %s", - instruction->ToString().c_str()); + instruction->ToString()); } const Shape& shape_0 = instruction->operand(0)->shape(); @@ -208,14 +202,14 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) { return InternalError( "Expected scalar types for the two operands of Rng instruction: %s", - instruction->ToString().c_str()); + instruction->ToString()); } if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) { return InternalError( "Expected compatible element types for the result and the two operands" " of Rng instruction: %s", - instruction->ToString().c_str()); + instruction->ToString()); } PrimitiveType element_type = shape_0.element_type(); @@ -228,7 +222,7 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { "Element type not 
supported." " Expected element to be of floating point type, integral type or" " predicate type for RngUniform: %s", - instruction->ToString().c_str()); + instruction->ToString()); } break; @@ -237,13 +231,13 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { return InternalError( "Element type not supported." " Expected element to be FloatingPointType for RngNormal: %s", - instruction->ToString().c_str()); + instruction->ToString()); } break; default: return InternalError( "Invalid Rng distribution %s", - RandomDistribution_Name(instruction->random_distribution()).c_str()); + RandomDistribution_Name(instruction->random_distribution())); } return Status::OK(); @@ -262,8 +256,8 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { return InternalError( "Expected sort to have to have the same dimensions for the keys and " "the values. Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(sort->operand(0)->shape()).c_str(), - ShapeUtil::HumanString(sort->operand(1)->shape()).c_str()); + StringifyShape(sort->operand(0)->shape()), + StringifyShape(sort->operand(1)->shape())); } return CheckVariadicShape(sort); } @@ -272,10 +266,18 @@ Status ShapeVerifier::HandleConstant(HloInstruction* constant) { return CheckShape(constant, constant->literal().shape()); } -Status ShapeVerifier::HandleIota(HloInstruction* iota) { - return ShapeUtil::Rank(iota->shape()) == 1 - ? 
Status::OK() - : InternalError("Iota only supports arrays of rank 1."); +Status ShapeVerifier::HandleIota(HloInstruction* instruction) { + auto* iota = Cast<HloIotaInstruction>(instruction); + const int64 rank = ShapeUtil::Rank(iota->shape()); + if (rank == 0) { + return InternalError("Iota does not support scalars."); + } + int64 iota_dimension = iota->iota_dimension(); + if (iota_dimension >= rank) { + return InternalError( + "The iota dimension cannot go beyond the operation rank."); + } + return Status::OK(); } Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { @@ -337,7 +339,18 @@ Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { return Status::OK(); } -Status ShapeVerifier::HandleFusion(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleFusion(HloInstruction* fusion) { + for (HloInstruction* fused_param : fusion->fused_parameters()) { + int64 param_no = fused_param->parameter_number(); + if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) { + return InternalError( + "Shape mismatch between parameter number %d and its operand in " + "%s.", + param_no, fusion->ToString().c_str()); + } + } + return Status::OK(); +} Status ShapeVerifier::HandleCall(HloInstruction* call) { for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) { @@ -419,12 +432,11 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0)); const Shape& conditional_shape = xla_while->while_condition()->root_instruction()->shape(); - if (!ShapeUtil::Compatible(conditional_shape, - ShapeUtil::MakeShape(PRED, {}))) { + if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) { return InternalError( "Conditional computation shape does not lead to a scalar predicate " "shape: %s", - ShapeUtil::HumanString(conditional_shape).c_str()); + StringifyShape(conditional_shape)); } // The shape of kWhile should match the shape of 
the body computation it // calls. @@ -555,7 +567,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { return InternalError( "Seen floating point types of different precisions in " "%s, but mixed precision is disallowed.", - instruction->ToString().c_str()); + instruction->ToString()); } return Status::OK(); })); @@ -602,53 +614,51 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } // Check if the output shape matches the expected shape. - bool compatible; + // // We treat BF16 and F32 as compatible types if mixed precision is allowed, // but only when the instruction defines the BF16/F32 buffer. - switch (instruction->opcode()) { - case HloOpcode::kTupleSelect: - // TupleSelect only defines the top-level buffer, which in this case is - // the tuple, so we cannot allow mixed precision. - compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); - break; - case HloOpcode::kGetTupleElement: - case HloOpcode::kTuple: - // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed - // precision is disallowed. - case HloOpcode::kConstant: - case HloOpcode::kBitcast: - case HloOpcode::kBitcastConvert: - case HloOpcode::kCall: - case HloOpcode::kConditional: - case HloOpcode::kConvert: - case HloOpcode::kCustomCall: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kParameter: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kWhile: - // The above opcodes should match the expected shapes exactly. 
- compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); - break; - default: - if (allow_mixed_precision_) { - compatible = ShapeUtil::CompatibleIgnoringFpPrecision( - instruction->shape(), inferred_shape); - } else { - compatible = - ShapeUtil::Compatible(instruction->shape(), inferred_shape); - } - } - if (!compatible) { + bool equal = [&] { + switch (instruction->opcode()) { + // The opcodes below can't have implicit layout conversions, nor can they + // implicitly transform f32 -> bf16. Fundamentally these are either + // reinterpreting existing data (e.g. kBitcast) or shuffling data around + // without modifying it (e.g. kGetTupleElement, kTupleSelect). + case HloOpcode::kBitcast: + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: + case HloOpcode::kWhile: + return ShapesSame(instruction->shape(), inferred_shape); + + // We allow arbitrary layout and f32->bf16 transformations on all other + // instructions, although this may be made more strict pending discussion + // in b/112709536. 
+ default: + if (allow_mixed_precision_) { + return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(), + inferred_shape); + } else { + return ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } + } + }(); + if (!equal) { return InternalError( - "Expected instruction to have shape compatible with %s, actual " + "Expected instruction to have shape equal to %s, actual " "shape is %s:\n%s", - ShapeUtil::HumanString(inferred_shape).c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - instruction->ToString().c_str()); + StringifyShape(inferred_shape), StringifyShape(instruction->shape()), + instruction->ToString()); } return Status::OK(); } @@ -692,10 +702,10 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { string ComputationsToString( tensorflow::gtl::ArraySlice<HloComputation*> computations) { - return tensorflow::str_util::Join( - computations, ",", [](string* s, const HloComputation* computation) { - s->append(computation->name()); - }); + return absl::StrJoin(computations, ",", + [](string* s, const HloComputation* computation) { + s->append(computation->name()); + }); } // Verifies various invariants about the structure of the HLO: @@ -713,23 +723,23 @@ Status VerifyHloStructure(HloModule* module) { for (const HloComputation* computation : module->computations()) { if (computation->parent() == nullptr) { return InternalError("Computation %s has a null parent pointer", - computation->name().c_str()); + computation->name()); } if (computation->parent() != module) { return InternalError( "Computation %s parent() does not point to parent module", - computation->name().c_str()); + computation->name()); } for (const HloInstruction* instruction : computation->instructions()) { if (instruction->parent() == nullptr) { return InternalError("Instruction %s has a null parent pointer", - instruction->name().c_str()); + instruction->name()); } if (instruction->parent() != computation) { return InternalError( 
"Instruction %s parent() does not point to parent computation", - instruction->name().c_str()); + instruction->name()); } } } @@ -746,9 +756,8 @@ Status VerifyHloStructure(HloModule* module) { return InternalError( "Operand %d (%s) of instruction %s is in a different " "computation: %s vs %s", - i, operand->name().c_str(), instruction->name().c_str(), - operand->parent()->name().c_str(), - instruction->parent()->name().c_str()); + i, operand->name(), instruction->name(), + operand->parent()->name(), instruction->parent()->name()); } } } @@ -764,7 +773,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { "Instruction of fused computation does not match expected " "instruction " "%s.", - fusion->ToString().c_str()); + fusion->ToString()); } // Fused root instruction and fused parameters must all be owned by the @@ -778,7 +787,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (fused_root == instruction) { if (root_owned) { return InternalError("Root appears more than once in %s.", - fusion->ToString().c_str()); + fusion->ToString()); } root_owned = true; } @@ -786,7 +795,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { return InternalError("Parameter appears more than once in %s.", - fusion->ToString().c_str()); + fusion->ToString()); } parameter_owned[i] = true; } @@ -794,20 +803,19 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { } if (!root_owned) { return InternalError("Root not found in computation of %s.", - fusion->ToString().c_str()); + fusion->ToString()); } // Make sure all the parameter_owned entries are set for (int i = 0; i < parameter_owned.size(); i++) { if (!parameter_owned[i]) { return InternalError("Parameter %d not found in computation of %s.", i, - fusion->ToString().c_str()); + fusion->ToString()); } } // Fused root must have no users. 
if (fused_root->user_count() != 0) { - return InternalError("Root of %s may not have users.", - fusion->ToString().c_str()); + return InternalError("Root of %s may not have users.", fusion->ToString()); } // All uses of fused instructions must be in the fusion computation, and @@ -817,54 +825,46 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { if (instruction != fused_root) { if (instruction->user_count() == 0) { return InternalError("Non-root instruction %s in %s must have users.", - instruction->ToString().c_str(), - fusion->ToString().c_str()); + instruction->ToString(), fusion->ToString()); } for (auto& user : instruction->users()) { if (fused_computation != user->parent()) { return InternalError( "Non-root instruction %s in %s may not have external users.", - instruction->ToString().c_str(), fusion->ToString().c_str()); + instruction->ToString(), fusion->ToString()); } } } } // Fused parameter instructions must be numbered contiguously and match up - // (shapes compatible) with their respective operand. + // (shapes equal) with their respective operand. 
CHECK_EQ(fusion->operands().size(), fused_parameters.size()); std::vector<bool> parameter_numbers(fused_parameters.size(), false); for (auto fused_param : fused_parameters) { int64 param_no = fused_param->parameter_number(); if (param_no < 0) { - return InternalError("Unexpected negative parameter number %lld in %s.", - param_no, fusion->ToString().c_str()); + return InternalError("Unexpected negative parameter number %d in %s.", + param_no, fusion->ToString()); } if (param_no >= fused_parameters.size()) { return InternalError( - "Unexpected parameter number %lld in %s: higher then number of " + "Unexpected parameter number %d in %s: higher then number of " "parameters %lu.", - param_no, fusion->ToString().c_str(), fused_parameters.size()); + param_no, fusion->ToString(), fused_parameters.size()); } if (parameter_numbers[param_no]) { return InternalError( - "Did not expect parameter number %lld more than once in %s.", - param_no, fusion->ToString().c_str()); + "Did not expect parameter number %d more than once in %s.", param_no, + fusion->ToString()); } parameter_numbers[param_no] = true; - if (!ShapeUtil::Compatible(fused_param->shape(), - fusion->operand(param_no)->shape())) { - return InternalError( - "Shape mismatch between parameter number %lld and its operand in " - "%s.", - param_no, fusion->ToString().c_str()); - } } // Make sure all the parameter_numbers entries were seen. 
for (int i = 0; i < parameter_numbers.size(); i++) { if (!parameter_numbers[i]) { return InternalError("Did not see parameter number %d in %s.", i, - fusion->ToString().c_str()); + fusion->ToString()); } } @@ -879,18 +879,18 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { auto* while_body = instruction->while_body(); if (while_cond->num_parameters() != 1) { return FailedPrecondition( - "While condition must have exactly 1 parameter; had %lld : %s", - while_cond->num_parameters(), while_cond->ToString().c_str()); + "While condition must have exactly 1 parameter; had %d : %s", + while_cond->num_parameters(), while_cond->ToString()); } if (while_body->num_parameters() != 1) { return FailedPrecondition( - "While body must have exactly 1 parameter; had %lld : %s", - while_body->num_parameters(), while_body->ToString().c_str()); + "While body must have exactly 1 parameter; had %d : %s", + while_body->num_parameters(), while_body->ToString()); } if (instruction->operand_count() != 1) { return FailedPrecondition( - "While loop must have exactly one operand; had %lld : %s", - instruction->operand_count(), instruction->ToString().c_str()); + "While loop must have exactly one operand; had %d : %s", + instruction->operand_count(), instruction->ToString()); } return Status::OK(); } @@ -898,16 +898,14 @@ Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) { if (instruction->true_computation()->num_parameters() != 1) { return FailedPrecondition( - "True computation %s of %s must have 1 parameter insted of %lld", - instruction->true_computation()->name().c_str(), - instruction->ToString().c_str(), + "True computation %s of %s must have 1 parameter insted of %d", + instruction->true_computation()->name(), instruction->ToString(), instruction->true_computation()->num_parameters()); } if (instruction->false_computation()->num_parameters() != 1) { return 
FailedPrecondition( - "False computation %s of %s must have 1 parameter insted of %lld", - instruction->false_computation()->name().c_str(), - instruction->ToString().c_str(), + "False computation %s of %s must have 1 parameter insted of %d", + instruction->false_computation()->name(), instruction->ToString(), instruction->false_computation()->num_parameters()); } return Status::OK(); @@ -920,11 +918,11 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) { return FailedPrecondition( "Implicit broadcast is not allowed in HLO." - "Found non-compatible shapes for instruction %s.\n" + "Found different shapes for instruction %s.\n" "output: %s\noperand: %s\n", - HloOpcodeString(instruction->opcode()).c_str(), - ShapeUtil::HumanString(out_shape).c_str(), - ShapeUtil::HumanString(operand_shape).c_str()); + HloOpcodeString(instruction->opcode()), + ShapeUtil::HumanString(out_shape), + ShapeUtil::HumanString(operand_shape)); } } return Status::OK(); @@ -955,7 +953,7 @@ Status VerifyEntryAndExitShapes(const HloModule& module) { if (ShapeContainsToken(param->shape())) { return InternalError( "Entry parameter %d is or contains a token shape: %s", i, - ShapeUtil::HumanString(param->shape()).c_str()); + ShapeUtil::HumanString(param->shape())); } } return Status::OK(); @@ -967,9 +965,9 @@ Status CheckSameChannel(const HloInstruction* instr1, if (instr1->channel_id() != instr2->channel_id()) { return InternalError( "Expected to have the same channel id, actual channel ids are: %s " - "(%lld), %s (%lld)", - instr1->ToString().c_str(), instr1->channel_id(), - instr2->ToString().c_str(), instr2->channel_id()); + "(%d), %s (%d)", + instr1->ToString(), instr1->channel_id(), instr2->ToString(), + instr2->channel_id()); } return Status::OK(); } @@ -990,7 +988,7 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1, "Expected instructions to have the same is-host-transfer 
property: " "%s, " "%s ", - instr1->ToString().c_str(), instr2->ToString().c_str()); + instr1->ToString(), instr2->ToString()); } return Status::OK(); } @@ -1007,12 +1005,12 @@ Status VerifySendsAndRecvs(const HloModule& module) { host_channels.insert({sendrecv->channel_id(), sendrecv}); if (!it_inserted.second) { return FailedPrecondition( - "Channel %lld is used for multiple host send/recv instructions: " + "Channel %d is used for multiple host send/recv instructions: " "%s " "and " "%s", - sendrecv->channel_id(), sendrecv->ToString().c_str(), - it_inserted.first->second->ToString().c_str()); + sendrecv->channel_id(), sendrecv->ToString(), + it_inserted.first->second->ToString()); } } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index c942fab08e..42e3027bf1 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/shape_inference.h" namespace xla { @@ -27,9 +28,9 @@ namespace xla { // TODO(b/26024837): Check output shape for all instruction types. 
class ShapeVerifier : public DfsHloVisitor { public: - explicit ShapeVerifier() : allow_mixed_precision_(false) {} - explicit ShapeVerifier(bool allow_mixed_precision) - : allow_mixed_precision_(allow_mixed_precision) {} + explicit ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision) + : layout_sensitive_(layout_sensitive), + allow_mixed_precision_(allow_mixed_precision) {} Status HandleElementwiseUnary(HloInstruction* hlo) override; Status HandleElementwiseBinary(HloInstruction* hlo) override; @@ -46,6 +47,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; + Status HandleCollectivePermute(HloInstruction* hlo) override; Status HandleReducePrecision(HloInstruction* reduce_precision) override; Status HandleInfeed(HloInstruction*) override; Status HandleOutfeed(HloInstruction*) override; @@ -63,7 +65,6 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleFusion(HloInstruction*) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction*) override; - Status HandleHostCompute(HloInstruction*) override; Status HandleSlice(HloInstruction* slice) override; Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; Status HandleDynamicUpdateSlice( @@ -106,13 +107,42 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckVariadicShape(const HloInstruction* instruction); private: - // Return true if the shapes of the two operands have the same element type, - // and the result shape either has the same element type as the operand - // shapes or mixed precision is allowed and the result shape and the operand - // shapes have floating point element types. + // Helpers that switch on layout_sensitive_. + bool ShapesSame(const Shape& a, const Shape& b) { + return layout_sensitive_ ? 
ShapeUtil::Equal(a, b) + : ShapeUtil::Compatible(a, b); + } + bool ShapesSameIgnoringFpPrecision(const Shape& a, const Shape& b) { + return layout_sensitive_ ? ShapeUtil::EqualIgnoringFpPrecision(a, b) + : ShapeUtil::CompatibleIgnoringFpPrecision(a, b); + } + string StringifyShape(const Shape& s) { + return layout_sensitive_ ? ShapeUtil::HumanStringWithLayout(s) + : ShapeUtil::HumanString(s); + } + + // Checks that the given operand of the given instruction is of type TOKEN. + Status CheckIsTokenOperand(const HloInstruction* instruction, + int64 operand_no); + + // Checks that the shape of the given operand of the given instruction matches + // the given parameter of the given computation. + Status CheckOperandAndParameter(const HloInstruction* instruction, + int64 operand_number, + const HloComputation* computation, + int64 parameter_number); + + // Returns true if the shapes of the two operands have the same element type, + // and the result shape either has the same element type as the operand shapes + // or mixed precision is allowed and the result shape and the operand shapes + // have floating point element types. bool HasCompatibleElementTypes(const Shape& shape_0, const Shape& shape_1, const Shape& result_shape); + // If the verifier is layout-sensitive, shapes must be equal to what's + // expected. Otherwise, the shapes must simply be compatible. + bool layout_sensitive_; + // Whether the inputs and output of an instruction can contain both F32s and // BF16s. Tuples that include both F32s and BF16s are allowed regardless of // this flag. @@ -125,14 +155,10 @@ class HloVerifier : public HloPassInterface { public: using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>; - // Uses standard shape inference. 
- explicit HloVerifier() - : shape_verifier_factory_( - [] { return MakeUnique<ShapeVerifier>(false); }) {} - - explicit HloVerifier(bool allow_mixed_precision) - : shape_verifier_factory_([allow_mixed_precision] { - return MakeUnique<ShapeVerifier>(allow_mixed_precision); + explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision) + : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] { + return absl::make_unique<ShapeVerifier>(layout_sensitive, + allow_mixed_precision); }) {} // Uses custom shape verification. @@ -140,10 +166,9 @@ class HloVerifier : public HloPassInterface { : shape_verifier_factory_(std::move(shape_verifier_factory)) {} ~HloVerifier() override = default; - tensorflow::StringPiece name() const override { return "verifier"; } + absl::string_view name() const override { return "verifier"; } - // Note: always returns false (no instructions are ever modified by this - // pass). + // Never returns true; no instructions are ever modified by this pass. 
StatusOr<bool> Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index d764964f3c..fc1f81bdd2 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -37,13 +37,15 @@ using ::testing::HasSubstr; class HloVerifierTest : public HloTestBase { public: HloVerifierTest() - : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/false) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/false) {} }; class HloVerifierTestAllowMixedPrecision : public HloTestBase { public: HloVerifierTestAllowMixedPrecision() - : HloTestBase(/*allow_mixed_precision_in_hlo_verifier=*/true) {} + : HloTestBase(/*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} }; TEST_F(HloVerifierTest, NullInstructionParent) { @@ -275,5 +277,84 @@ TEST_F(HloVerifierTest, RngElementTypeNotSupported) { EXPECT_THAT(status.error_message(), HasSubstr("Element type not supported")); } +TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) { + // This testcase can't be written using textual HLO, because it doesn't parse + // negative interior padding. That's probably a feature. 
:) + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {100}), "param")); + PaddingConfig padding_config; + padding_config.add_dimensions()->set_interior_padding(-1); + builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {100}), param, + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(F32).CloneToUnique())), + padding_config)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Interior padding cannot be negative")); +} + +TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) { + // This testcase can't be written using textual HLO, because it doesn't parse + // negative interior padding. That's probably a feature. :) + HloComputation::Builder builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {100}), "param")); + PaddingConfig padding_config; + padding_config.add_dimensions()->set_interior_padding(-1); + builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {100}), param, + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(F32).CloneToUnique())), + padding_config)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("Interior padding cannot be negative")); +} + +// Simple module containing a convolution as the root. 
+static const char* const kConvHloString = R"( +HloModule module +ENTRY entry_computation { + param0 = f16[128,128,56,56] parameter(0) + param1 = f16[3,3,128,128] parameter(1) + zero_f16 = f16[] constant(0) + ROOT conv = f16[128,128,28,28] convolution(param0, param1), + window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01 +})"; + +TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString)); + auto* conv = module->entry_computation()->root_instruction(); + Window w = conv->window(); + w.mutable_dimensions(0)->set_window_dilation(-1); + conv->set_window(w); + + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("non-positive window dilation factor")); +} + +TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kConvHloString)); + auto* conv = module->entry_computation()->root_instruction(); + Window w = conv->window(); + w.mutable_dimensions(0)->set_base_dilation(-1); + conv->set_window(w); + + EXPECT_THAT(verifier().Run(module.get()).status().error_message(), + HasSubstr("non-positive base area dilation factor")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index bb5b40a8a8..e76b93107c 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -14,27 +14,27 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/metric_table_report.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { -using tensorflow::strings::Appendf; +using absl::StrAppend; +using absl::StrAppendFormat; +using absl::StrCat; +using absl::StrFormat; using tensorflow::strings::HumanReadableElapsedTime; using tensorflow::strings::HumanReadableNumBytes; -using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; string HumanReadableProfileBuilder::ToString() const { string s; - Appendf(&s, "Execution profile for %s: (%s @ f_nom)\n", - computation_name_.c_str(), - HumanReadableElapsedTime(CyclesToSeconds(total_cycles_)).c_str()); + StrAppendFormat(&s, "Execution profile for %s: (%s @ f_nom)\n", + computation_name_, + HumanReadableElapsedTime(CyclesToSeconds(total_cycles_))); int64 cumulative_cycles = 0; auto print_op = [&](const OpInfo& op, bool is_total = false) { @@ -56,7 +56,7 @@ string HumanReadableProfileBuilder::ToString() const { if (op.bytes_accessed > op.cycles) { bytes_per_cycle = StrCat(HumanReadableNumBytes(bpc), "/cycle"); } else { - bytes_per_cycle = Printf("%.3fB/cycle", bpc); + bytes_per_cycle = StrFormat("%.3fB/cycle", bpc); } } @@ -77,27 +77,24 @@ string HumanReadableProfileBuilder::ToString() const { // columns in the output. 
cycles_percent_str = "100.% 100Σ"; } else { - cycles_percent_str = - Printf("%5.2f%% %2.0fΣ", cycles_percent, cumulative_cycles_percent); + cycles_percent_str = StrFormat("%5.2f%% %2.0fΣ", cycles_percent, + cumulative_cycles_percent); } double nsecs = op.cycles / clock_rate_ghz_; - Appendf( + StrAppendFormat( &s, - "%15lld cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " + "%15d cycles (%s) :: %12.1f usec %22s :: %18s :: %18s :: %14s :: " "%16s :: %s\n", - op.cycles, cycles_percent_str.c_str(), CyclesToMicroseconds(op.cycles), + op.cycles, cycles_percent_str, CyclesToMicroseconds(op.cycles), op.optimal_seconds < 0 ? "" - : Printf("(%12.1f optimal)", op.optimal_seconds * 1e6).c_str(), - op.flop_count <= 0 - ? "" - : HumanReadableNumFlops(op.flop_count, nsecs).c_str(), + : StrFormat("(%12.1f optimal)", op.optimal_seconds * 1e6), + op.flop_count <= 0 ? "" : HumanReadableNumFlops(op.flop_count, nsecs), op.transcendental_count <= 0 ? "" - : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs) - .c_str(), - bytes_per_sec.c_str(), bytes_per_cycle.c_str(), op.name.c_str()); + : HumanReadableNumTranscendentalOps(op.transcendental_count, nsecs), + bytes_per_sec, bytes_per_cycle, op.name); }; float optimal_seconds_sum = 0.0; diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.h b/tensorflow/compiler/xla/service/human_readable_profile_builder.h index 6f56c3aa82..925111fa1f 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.h +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.h @@ -18,8 +18,8 @@ limitations under the License. #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -29,10 +29,10 @@ namespace xla { // computation, suitable for consumption by humans. 
class HumanReadableProfileBuilder { public: - explicit HumanReadableProfileBuilder(tensorflow::StringPiece computation_name, + explicit HumanReadableProfileBuilder(absl::string_view computation_name, int64 total_cycles, double clock_rate_ghz) - : computation_name_(std::string(computation_name)), + : computation_name_(computation_name), total_cycles_(total_cycles), clock_rate_ghz_(clock_rate_ghz) { CHECK_GE(clock_rate_ghz, 1e-9); @@ -43,15 +43,13 @@ class HumanReadableProfileBuilder { // Adds an operation to the profile. If you don't know the number of // floating-point ops or bytes touched by the op, or if you don't know how // fast it would run optimally, pass -1 for that param. - void AddOp(tensorflow::StringPiece op_name, - tensorflow::StringPiece short_name, - tensorflow::StringPiece category, int64 cycles, int64 flop_count, + void AddOp(absl::string_view op_name, absl::string_view short_name, + absl::string_view category, int64 cycles, int64 flop_count, int64 transcendental_count, int64 bytes_accessed, float optimal_seconds) { - op_infos_.push_back({std::string(op_name), std::string(short_name), - std::string(category), cycles, flop_count, - transcendental_count, bytes_accessed, - optimal_seconds}); + op_infos_.push_back({string(op_name), string(short_name), string(category), + cycles, flop_count, transcendental_count, + bytes_accessed, optimal_seconds}); } // Gets the human-readable profile. 
diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h index aa325dc8a3..85bb4a8b24 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h @@ -30,7 +30,7 @@ class ImplicitBroadcastRemover : public HloPassInterface { ImplicitBroadcastRemover() {} ~ImplicitBroadcastRemover() override {} - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "implicit-broadcast-remover"; } diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc index f85d31d522..df88587492 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -26,6 +26,11 @@ namespace xla { namespace { class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase { + public: + ImplicitBroadcastRemoverTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: ImplicitBroadcastRemover remover_; }; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 8d17c03afc..43ef30d1eb 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -14,13 +14,16 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" + +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/gtl/optional.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace gtl = ::tensorflow::gtl; @@ -31,32 +34,30 @@ using UnknownArray = Analysis::UnknownArray; using ConstantArray = Analysis::ConstantArray; using ReshapedArray = Analysis::ReshapedArray; using ScalarIndexedArray = Analysis::ScalarIndexedArray; +using absl::StrJoin; using tensorflow::gtl::ArraySlice; -using tensorflow::str_util::Join; } // namespace string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { switch (root->kind()) { case Array::kUnknown: { auto* unknown_tensor = root->as<UnknownArray>(); - return tensorflow::strings::StrCat("%", - unknown_tensor->instruction().name()); + return absl::StrCat("%", unknown_tensor->instruction().name()); } case Array::kConstant: { if (print_constants) { string contents = root->as<ConstantArray>()->literal()->ToString(); - return tensorflow::strings::StrCat( - "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents, - ")"); + return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()), + " ", contents, ")"); } - return tensorflow::strings::StrCat( - "(constant ", ShapeUtil::HumanString(root->shape()), ")"); + return absl::StrCat("(constant ", ShapeUtil::HumanString(root->shape()), + ")"); } case Array::kReshaped: { ReshapedArray* reshaped_array = root->as<ReshapedArray>(); - return 
tensorflow::strings::StrCat( + return absl::StrCat( "(reshape ", ToString(reshaped_array->operand(), print_constants), " to ", ShapeUtil::HumanString(reshaped_array->shape()), ")"); } @@ -67,11 +68,11 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { string name = root->kind() == Array::kScalarIndexedConstant ? "scalar-indexed-const" : "scalar-indexed"; - return tensorflow::strings::StrCat( + return absl::StrCat( "(", name, " ", ToString(indexed_array->source(), print_constants), " ", ToString(indexed_array->indices(), print_constants), " ", indexed_array->source_dim(), "->[", - Join(indexed_array->output_dims(), ","), "])"); + StrJoin(indexed_array->output_dims(), ","), "])"); } } } @@ -92,7 +93,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( // Depth first search over the DAG, invoking ComputeArrayFor in post order. // The HLO instructions already in the cache are considered leaves. - gtl::InlinedVector<const HloInstruction*, 4> stack; + absl::InlinedVector<const HloInstruction*, 4> stack; enum DfsState { kDiscovered, kVisited }; gtl::FlatMap<const HloInstruction*, DfsState> dfs_state_map; @@ -290,13 +291,13 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather( int64 source_dim = dim_numbers.start_index_map(0); std::vector<int64> output_dims; for (int64 i = 0, e = shape.dimensions_size(); i < e; i++) { - if (!c_binary_search(dim_numbers.offset_dims(), i)) { + if (!absl::c_binary_search(dim_numbers.offset_dims(), i)) { output_dims.push_back(i); } } if (auto* indexed = dynamic_cast<ScalarIndexedArray*>(source)) { - if (c_linear_search(indexed->output_dims(), source_dim)) { + if (absl::c_linear_search(indexed->output_dims(), source_dim)) { return FoldGatherOfGather(indexed, indices, source_dim, output_dims, shape); } @@ -314,7 +315,7 @@ namespace { // [values.begin()+index, values.end()) is equal to `product`. If there is no // such index, return -1. All integers in `values` must be positive. 
int64 FindSuffixWithProduct(ArraySlice<int64> values, int64 product) { - DCHECK(c_all_of(values, [](int64 value) { return value > 0; })); + DCHECK(absl::c_all_of(values, [](int64 value) { return value > 0; })); int64 current_product = 1; int64 i; @@ -377,8 +378,8 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs( CHECK_NE(candidate_operand_dim, 0) << "result_dim = " << result_dim << ", result_subarray_size = " << result_subarray_size - << ", result_shape = [" << Join(result_shape, ",") << "]" - << ", operand_shape = [" << Join(operand_shape, ",") << "]"; + << ", result_shape = [" << StrJoin(result_shape, ",") << "]" + << ", operand_shape = [" << StrJoin(operand_shape, ",") << "]"; if (candidate_operand_dim != -1 && result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { @@ -388,26 +389,27 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs( result_subarray_size *= result_shape[result_dim]; } - c_reverse(result); + absl::c_reverse(result); if (VLOG_IS_ON(3)) { std::vector<string> result_strings; - c_transform(result, std::back_inserter(result_strings), - [](ReshapePassthroughDimPair value) { - return tensorflow::strings::StrCat(value.result_dim, "->", - value.operand_dim); - }); - VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" - << Join(result_shape, ",") << "] passthrough indices are [" - << Join(result_strings, ",") << "] (legend: `result`->`operand`)"; + absl::c_transform(result, std::back_inserter(result_strings), + [](ReshapePassthroughDimPair value) { + return absl::StrCat(value.result_dim, "->", + value.operand_dim); + }); + VLOG(3) << "For a reshape from [" << StrJoin(operand_shape, ",") << "] to [" + << StrJoin(result_shape, ",") << "] passthrough indices are [" + << StrJoin(result_strings, ",") + << "] (legend: `result`->`operand`)"; } - DCHECK(c_is_sorted( + DCHECK(absl::c_is_sorted( result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { return 
lhs.result_dim < rhs.result_dim; })); - DCHECK(c_is_sorted( + DCHECK(absl::c_is_sorted( result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { return lhs.operand_dim < rhs.operand_dim; })); @@ -419,20 +421,20 @@ std::vector<ReshapePassthroughDimPair> ComputeReshapePassthroughDimPairs( // `passthrough_dims`. bool IsReshapePassthroughOperandDim( ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 dim) { - return c_any_of(passthrough_dims, - [&](ReshapePassthroughDimPair passthrough_dim_pair) { - return passthrough_dim_pair.operand_dim == dim; - }); + return absl::c_any_of(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == dim; + }); } // Maps `operand_dim` which must be an passthrough operand dimension to its // corresponding passthrough result dimension based on `passthrough_dims`. int64 MapPassthroughOperandDimToResultDim( ArraySlice<ReshapePassthroughDimPair> passthrough_dims, int64 operand_dim) { - auto it = c_find_if(passthrough_dims, - [&](ReshapePassthroughDimPair passthrough_dim_pair) { - return passthrough_dim_pair.operand_dim == operand_dim; - }); + auto it = absl::c_find_if( + passthrough_dims, [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == operand_dim; + }); CHECK(it != passthrough_dims.end()); return it->result_dim; } @@ -441,7 +443,7 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape, ArraySlice<int64> result_shape, int64 source_passthrough_dim) { VLOG(3) << "FindSourcePositionForPassthroughResultDim([" - << Join(operand_shape, ",") << "], [" << Join(result_shape, ",") + << StrJoin(operand_shape, ",") << "], [" << StrJoin(result_shape, ",") << "], " << source_passthrough_dim << ")"; int64 indexed_source_subarray_size = @@ -453,8 +455,8 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape, Shape StripDegenerateDimensions(const Shape& shape) { 
DimensionVector new_dims; - c_copy_if(shape.dimensions(), std::back_inserter(new_dims), - [](int64 dim) { return dim != 1; }); + absl::c_copy_if(shape.dimensions(), std::back_inserter(new_dims), + [](int64 dim) { return dim != 1; }); return ShapeUtil::MakeShape(shape.element_type(), new_dims); } }; // namespace @@ -530,7 +532,7 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims( // element is true iff the i'th component of the result index is an output // index. - gtl::InlinedVector<bool, 6> output_dims_bitvector( + absl::InlinedVector<bool, 6> output_dims_bitvector( operand->shape().dimensions_size()); for (int64 output_dim : operand->output_dims()) { output_dims_bitvector[output_dim] = true; @@ -552,8 +554,8 @@ StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::ReshapeToAddDegenerateDims( }(); DimensionVector new_result_shape_dims; - c_copy(operand->shape().dimensions(), - std::back_inserter(new_result_shape_dims)); + absl::c_copy(operand->shape().dimensions(), + std::back_inserter(new_result_shape_dims)); for (int64 degenerate_dim : degenerate_dims) { InsertAt(&new_result_shape_dims, degenerate_dim, 1); } @@ -694,8 +696,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( operand_dim); }; - if (!c_all_of(scalar_indexed->output_dims(), - is_reshape_passthrough_operand_dim)) { + if (!absl::c_all_of(scalar_indexed->output_dims(), + is_reshape_passthrough_operand_dim)) { VLOG(3) << "Not all output dims are passthrough dims " << ToString(scalar_indexed); return nullptr; @@ -753,9 +755,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( if (source_dim_for_new_scalar_indexed_node == -1) { VLOG(3) << "Could not compute the source dim for the new scalar indexed " "node: scalar_indexed_source_shape = [" - << Join(scalar_indexed_source_shape.dimensions(), ",") + << StrJoin(scalar_indexed_source_shape.dimensions(), ",") << "] and new_scalar_indexed_source_shape = [" - << Join(new_scalar_indexed_source_shape, ",") << 
"]"; + << StrJoin(new_scalar_indexed_source_shape, ",") << "]"; return nullptr; } @@ -763,8 +765,8 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); - CHECK_EQ(c_accumulate(new_scalar_indexed_source_shape, 1LL, - std::multiplies<int64>()), + CHECK_EQ(absl::c_accumulate(new_scalar_indexed_source_shape, 1LL, + std::multiplies<int64>()), ShapeUtil::ElementsIn(scalar_indexed_source_shape)); CHECK(IsReshapePassthroughOperandDim( @@ -780,9 +782,9 @@ IndexedArrayAnalysis::FoldReshapeOfGatherNoDegenerateDims( }; std::vector<int64> output_dims_for_new_scalar_indexed_node; - c_transform(scalar_indexed->output_dims(), - std::back_inserter(output_dims_for_new_scalar_indexed_node), - map_passthrough_operand_dim_to_result_dim); + absl::c_transform(scalar_indexed->output_dims(), + std::back_inserter(output_dims_for_new_scalar_indexed_node), + map_passthrough_operand_dim_to_result_dim); TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal, TakeOwnership(scalar_indexed->literal().Reshape( @@ -873,11 +875,12 @@ IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, ArraySlice<int64> broadcast_dims = broadcast_instr->dimensions(); auto is_broadcasted_dim = [&](int64 output_dim) { - return c_find(broadcast_dims, output_dim) == broadcast_dims.end(); + return absl::c_find(broadcast_dims, output_dim) == broadcast_dims.end(); }; // All of the output dims must be "broadcasted" dims for the other operand. - if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) { + if (!absl::c_all_of(scalar_indexed_const->output_dims(), + is_broadcasted_dim)) { return nullptr; } @@ -969,15 +972,15 @@ namespace { // Returns the non-contracting non-batch dimension (as per `contracting_dims` // and `batch_dims`) if there is exactly one, otherwise returns nullopt. 
-gtl::optional<int64> GetOnlyNonContractingNonBatchDim( +absl::optional<int64> GetOnlyNonContractingNonBatchDim( int64 rank, ArraySlice<int64> contracting_dims, ArraySlice<int64> batch_dims) { - gtl::optional<int64> result; + absl::optional<int64> result; for (int64 dim = 0; dim < rank; dim++) { if (!ArrayContains(contracting_dims, dim) && !ArrayContains(batch_dims, dim)) { if (result.has_value()) { - return gtl::nullopt; + return absl::nullopt; } result = dim; } @@ -994,10 +997,9 @@ gtl::optional<int64> GetOnlyNonContractingNonBatchDim( // `contracting_dims` and `batch_dims` are the contracting and batch dimensions // of whatever operand `indexed_array` is to the dot (LHS or RHS). bool CanFoldDotIntoIndexedArray( - tensorflow::StringPiece tag, - Analysis::ScalarIndexedConstantArray* indexed_array, + absl::string_view tag, Analysis::ScalarIndexedConstantArray* indexed_array, ArraySlice<int64> contracting_dims, ArraySlice<int64> batch_dims) { - gtl::optional<int64> non_contracting_non_batch_dim = + absl::optional<int64> non_contracting_non_batch_dim = GetOnlyNonContractingNonBatchDim(ShapeUtil::Rank(indexed_array->shape()), contracting_dims, batch_dims); if (!non_contracting_non_batch_dim.has_value()) { @@ -1132,7 +1134,7 @@ StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForDot( return nullptr; } -tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { +absl::string_view IndexedArrayAnalysisPrinterPass::name() const { return "indexed-array-analysis-printer-pass"; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index 675eb31d26..3fa7d749e1 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -371,7 +371,7 @@ class IndexedArrayAnalysis { // unconditionally add to the regular HLO pass pipeline. 
class IndexedArrayAnalysisPrinterPass : public HloPassInterface { public: - tensorflow::StringPiece name() const override; + absl::string_view name() const override; StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 97052edf7d..c34c32f7d3 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -22,6 +22,11 @@ limitations under the License. namespace xla { namespace { class IndexedArrayAnalysisTest : public HloVerifiedTestBase { + public: + IndexedArrayAnalysisTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: void AssertArrayForRootExpressionIs(const string& hlo_text, const string& root_expression) { @@ -634,9 +639,9 @@ ENTRY main { AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( (scalar-indexed-const (constant f32[3,4] f32[3,4] { - { 0.761594176, 0.964027584, 0.995054781, 0.999329329 }, - { 0.761594176, 0.995054781, 0.964027584, 0.999329329 }, - { 0.999329329, 0.995054781, 0.964027584, 0.761594176 } + { 0.761594, 0.964028, 0.995055, 0.999329 }, + { 0.761594, 0.995055, 0.964028, 0.999329 }, + { 0.999329, 0.995055, 0.964028, 0.761594 } }) %indices 0->[0]))"); } diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h index a523811f6c..efa8ed3abc 100644 --- a/tensorflow/compiler/xla/service/inliner.h +++ b/tensorflow/compiler/xla/service/inliner.h @@ -27,7 +27,7 @@ namespace xla { class Inliner : public HloPassInterface { public: ~Inliner() override = default; - tensorflow::StringPiece name() const override { return "inline"; } + absl::string_view name() const override { return "inline"; } // Run inlining on the given computation. Returns whether the computation was // changed. 
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 32937b33b3..5695bc2420 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include <memory> #include <utility> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index f33942d679..83313c7ec1 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -21,6 +21,7 @@ limitations under the License. #include <numeric> #include <vector> +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" @@ -121,6 +122,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDomain: @@ -130,7 +132,6 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kGather: - case HloOpcode::kHostCompute: case HloOpcode::kLog: case HloOpcode::kLog1p: case HloOpcode::kMap: @@ -189,13 +190,13 @@ bool InstructionFusion::CanFuseOnAllPaths( if (consumer == producer) { return true; } - if (!consumer->IsFusable()) { + if (!consumer->IsFusible()) { return false; } for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { auto* 
consumer_operand = consumer->mutable_operand(i); // If the operand is not on a path to the producer, it doesn't matter - // whether it's fusable. + // whether it's fusible. if (!reachability_->IsReachable(producer, consumer_operand)) { continue; } @@ -205,7 +206,7 @@ bool InstructionFusion::CanFuseOnAllPaths( } // The producer is reachable from consumer_operand which means we need // to be able to fuse consumer_operand into consumer in order for - // producer to be fusable into consumer on all paths. + // producer to be fusible into consumer on all paths. // Perform the recursive step: make sure producer can be fused into // consumer_operand on all paths. if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) { @@ -216,7 +217,7 @@ bool InstructionFusion::CanFuseOnAllPaths( } InstructionFusion::HloInstructionSet -InstructionFusion::ComputeGloballyUnfusable( +InstructionFusion::ComputeGloballyUnfusible( tensorflow::gtl::ArraySlice<HloInstruction*> post_order) { // Forbid fusion of producers that: // a) Need to be duplicated, unless they can be fused into all consumers @@ -270,19 +271,19 @@ InstructionFusion::ComputeGloballyUnfusable( // all of its consumers on all paths. // // That means, that for: - // A --> B (fusable) - // \-> C (non-fusable) + // A --> B (fusible) + // \-> C (non-fusible) // A will be not allowed to be fused into B, as it cannot be fused into C. // // Similarly, for: // A -------------> B // \-> C -> D -/ // If: - // - A is fusable into B and C, and D is fusable into B - // - C is *not* fusable into D + // - A is fusible into B and C, and D is fusible into B + // - C is *not* fusible into D // A will be not allowed to be fused into B, as it cannot be fused via // all paths. 
- if (producer->IsFusable() && + if (producer->IsFusible() && CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) { continue; } @@ -318,7 +319,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) { InsertOrDie(&post_order_index, post_order[i], i); } - HloInstructionSet do_not_duplicate = ComputeGloballyUnfusable(post_order); + HloInstructionSet do_not_duplicate = ComputeGloballyUnfusible(post_order); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all @@ -341,7 +342,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) { // consistent. post_order_index.erase(instruction); - if (!instruction->IsFusable() && + if (!instruction->IsFusible() && instruction->opcode() != HloOpcode::kFusion) { continue; } @@ -413,7 +414,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) { for (int64 i : sorted_operand_numbers) { HloInstruction* operand = instruction->mutable_operand(i); - if (!operand->IsFusable()) { + if (!operand->IsFusible()) { continue; } @@ -497,7 +498,7 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( bool InstructionFusion::MultiOutputFusionCreatesCycle( HloInstruction* producer, HloInstruction* consumer) { - return c_any_of( + return absl::c_any_of( consumer->operands(), [&](const HloInstruction* consumer_operand) { // The fusion algorithm traverses the HLO graph in reverse post order. 
// Thus `cosumers` is visited before its operands (including diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index f73ca9adf7..9802d4cfc1 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -36,7 +36,7 @@ class InstructionFusion : public HloPassInterface { bool may_duplicate = true) : is_expensive_(is_expensive), may_duplicate_(may_duplicate) {} ~InstructionFusion() override = default; - tensorflow::StringPiece name() const override { return "fusion"; } + absl::string_view name() const override { return "fusion"; } // Run instruction fusion on the given computation. Returns whether the // computation was changed (instructions were fused). @@ -122,7 +122,7 @@ class InstructionFusion : public HloPassInterface { // Computes the set of nodes that we do not want to fuse into any of their // consumers based on a global analysis of the HLO graph. - HloInstructionSet ComputeGloballyUnfusable( + HloInstructionSet ComputeGloballyUnfusible( tensorflow::gtl::ArraySlice<HloInstruction*> post_order); // Used to determine if an HLO is expensive. 
Expensive operations will not be diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index 9e7a15f033..da1ad90959 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -158,7 +158,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) { .ValueOrDie()); } -TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusable) { +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusible) { HloComputation::Builder builder(TestName()); auto shape = ShapeUtil::MakeShape(F32, {16, 16}); auto param0 = @@ -216,7 +216,7 @@ TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { EXPECT_EQ(Count(*module, HloOpcode::kAdd), 1) << module->ToString(); } -TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { +TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { // Make sure we do not duplicate the add, as we cannot fuse through the rng. // // p0 -> add -------------------------> sub @@ -309,7 +309,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { EXPECT_EQ(Count(*module, HloOpcode::kAdd), 2) << module->ToString(); // A variant of the above that allows the algorithm to put add2 into the set - // of unfusable ops to short-circuit the decision whether add1 should be fused + // of unfusible ops to short-circuit the decision whether add1 should be fused // into sub2. 
// // /---------------\ diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 8652599dc6..581f8d2e92 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -12,12 +12,11 @@ cc_library( srcs = ["interpreter_transfer_manager.cc"], hdrs = ["interpreter_transfer_manager.h"], deps = [ - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/service/interpreter:platform_id", "//tensorflow/core:lib", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains per-platform transfer manager registration ) @@ -32,8 +31,6 @@ cc_library( "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", @@ -54,6 +51,7 @@ cc_library( "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/core:lib", "//tensorflow/stream_executor", + "@com_google_absl//absl/memory", ], alwayslink = True, # Contains compiler registration ) @@ -79,7 +77,6 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", @@ -91,6 +88,7 @@ cc_library( "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/memory", ], ) diff --git 
a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 9f8f4bda87..bb69cb9c47 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -18,7 +18,7 @@ limitations under the License. #include <string> #include <utility> -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -69,8 +69,8 @@ StatusOr<std::unique_ptr<Executable>> InterpreterCompiler::RunBackend( // Create executable from only the Hlo module. std::unique_ptr<Executable> executable = - xla::MakeUnique<InterpreterExecutable>(std::move(hlo_module), - xla::MakeUnique<HloEvaluator>()); + absl::make_unique<InterpreterExecutable>( + std::move(hlo_module), absl::make_unique<HloEvaluator>()); return std::move(executable); } @@ -103,11 +103,11 @@ HloCostAnalysis::ShapeSizeFunction InterpreterCompiler::ShapeSizeBytesFunction() static bool InitModule() { xla::Compiler::RegisterCompilerFactory( se::interpreter::kXlaInterpreterPlatformId, []() { - return xla::MakeUnique<xla::interpreter::InterpreterCompiler>(); + return absl::make_unique<xla::interpreter::InterpreterCompiler>(); }); xla::ComputationPlacer::RegisterComputationPlacer( se::interpreter::kXlaInterpreterPlatformId, - []() { return xla::MakeUnique<xla::ComputationPlacer>(); }); + []() { return absl::make_unique<xla::ComputationPlacer>(); }); return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 8d40c08d55..2259dc1083 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include <utility> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc index d27cd7502f..7955ee5cf3 100644 --- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc @@ -17,7 +17,7 @@ limitations under the License. #include <memory> -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -31,7 +31,7 @@ InterpreterTransferManager::InterpreterTransferManager() static std::unique_ptr<xla::TransferManager> CreateInterpreterTransferManager() { - return xla::MakeUnique<xla::InterpreterTransferManager>(); + return absl::make_unique<xla::InterpreterTransferManager>(); } static bool InitModule() { diff --git a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h index 2b44f30821..b732230fdd 100644 --- a/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h +++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_ #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" #include "tensorflow/core/platform/macros.h" @@ -33,4 +33,4 @@ class InterpreterTransferManager : public GenericTransferManager { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_TRANSFER_MANAGER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERPRETER_INTERPRETER_TRANSFER_MANAGER_H_ diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc index 42c2c28997..c9b40d3c61 100644 --- a/tensorflow/compiler/xla/service/interpreter/platform.cc +++ b/tensorflow/compiler/xla/service/interpreter/platform.cc @@ -17,13 +17,14 @@ limitations under the License. 
#include <utility> +#include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" #include "tensorflow/stream_executor/device_options.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/ptr_util.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/status_macros.h" -#include "tensorflow/stream_executor/lib/stringprintf.h" #include "tensorflow/stream_executor/multi_platform_manager.h" #include "tensorflow/stream_executor/platform.h" @@ -70,15 +71,15 @@ port::StatusOr<StreamExecutor*> XlaInterpreterPlatform::GetExecutor( port::StatusOr<std::unique_ptr<StreamExecutor>> XlaInterpreterPlatform::GetUncachedExecutor( const StreamExecutorConfig& config) { - auto executor = MakeUnique<StreamExecutor>( - this, MakeUnique<XlaInterpreterExecutor>(config.plugin_config)); + auto executor = absl::make_unique<StreamExecutor>( + this, absl::make_unique<XlaInterpreterExecutor>(config.plugin_config)); auto init_status = executor->Init(config.ordinal, config.device_options); if (!init_status.ok()) { return port::Status{ port::error::INTERNAL, - port::Printf( + absl::StrFormat( "failed initializing StreamExecutor for device ordinal %d: %s", - config.ordinal, init_status.ToString().c_str())}; + config.ordinal, init_status.ToString())}; } return std::move(executor); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 805fdb2d5b..5e5c93e3a2 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -26,9 +26,12 @@ limitations under the License. 
#include <string> #include <tuple> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -49,20 +52,11 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" namespace xla { -// For now moving only one API here, but we should have a single top level -// anonymous namespace, instead of three or four spread all over this file. 
-namespace { - -} // namespace - std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint) { out << constraint.ToString(); @@ -77,9 +71,8 @@ BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout, } string BufferLayoutConstraint::ToString() const { - return tensorflow::strings::Printf("BufferLayoutConstraint %s: %s", - buffer_->ToString().c_str(), - LayoutUtil::HumanString(layout_).c_str()); + return absl::StrFormat("BufferLayoutConstraint %s: %s", buffer_->ToString(), + LayoutUtil::HumanString(layout_)); } OperandLayoutConstraint::OperandLayoutConstraint( @@ -98,15 +91,14 @@ OperandLayoutConstraint::OperandLayoutConstraint( } string OperandLayoutConstraint::ToString() const { - return tensorflow::strings::Printf( - "OperandLayoutConstraint %s, operand %lld: %s", - instruction_->name().c_str(), operand_no_, - shape_layout_.ToString().c_str()); + return absl::StrFormat("OperandLayoutConstraint %s, operand %d: %s", + instruction_->name(), operand_no_, + shape_layout_.ToString()); } string ResultLayoutConstraint::ToString() const { - return tensorflow::strings::Printf("ResultLayoutConstraint: %s", - shape_layout_.ToString().c_str()); + return absl::StrFormat("ResultLayoutConstraint: %s", + shape_layout_.ToString()); } LayoutConstraints::LayoutConstraints( @@ -137,7 +129,7 @@ PointsToSet::BufferSet* LayoutConstraints::GetBufferSet( } auto& buffer_set = buffer_sets_cache_ - .emplace(instruction, MakeUnique<PointsToSet::BufferSet>()) + .emplace(instruction, absl::make_unique<PointsToSet::BufferSet>()) .first->second; const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction); points_to_set.ForEachElement( @@ -174,8 +166,7 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, return FailedPrecondition( "Layout of buffer %s cannot be constrained because buffer is not " "array-shaped, has shape: %s", - buffer.ToString().c_str(), - ShapeUtil::HumanString(buffer.shape()).c_str()); + buffer.ToString(), 
ShapeUtil::HumanString(buffer.shape())); } TF_RETURN_IF_ERROR( LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); @@ -191,9 +182,8 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, return FailedPrecondition( "Buffer %s already has the layout constraint %s, cannot add " "incompatible constraint %s", - buffer.ToString().c_str(), - LayoutUtil::HumanString(curr_constraint.layout()).c_str(), - LayoutUtil::HumanString(layout).c_str()); + buffer.ToString(), LayoutUtil::HumanString(curr_constraint.layout()), + LayoutUtil::HumanString(layout)); } iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); } else { @@ -227,11 +217,11 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, } if (curr_shape_layout->mandatory()) { return FailedPrecondition( - "Operand %lld of instruction %s already has a layout constraint " + "Operand %d of instruction %s already has a layout constraint " "%s, cannot add incompatible constraint %s", - operand_no, instruction->name().c_str(), - curr_shape_layout->shape_layout().ToString().c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + operand_no, instruction->name(), + curr_shape_layout->shape_layout().ToString(), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } } @@ -240,9 +230,9 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, // layouts beyond this immediate use and is complicated to handle. 
if (OperandBufferForwarded(instruction, operand_no)) { return FailedPrecondition( - "Cannot constraint layout of operand %lld of instruction %s " + "Cannot constraint layout of operand %d of instruction %s " "because instruction forwards operand's LogicalBuffer(s)", - operand_no, instruction->name().c_str()); + operand_no, instruction->name()); } auto key = std::make_pair(instruction, operand_no); @@ -284,8 +274,8 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout, return FailedPrecondition( "Result of computation %s already has the layout constraint %s, " "cannot add incompatible constraint %s", - computation_->name().c_str(), curr_shape_layout->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + computation_->name(), curr_shape_layout->ToString(), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } // New constraint matches existing constraint. Nothing to do. return Status::OK(); @@ -307,9 +297,8 @@ Status LayoutConstraints::SetInstructionLayout( if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) { return FailedPrecondition( "Instruction %s of shape %s cannot be assigned incompatible layout %s", - instruction->name().c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str()); + instruction->name(), ShapeUtil::HumanString(instruction->shape()), + ShapeUtil::HumanStringWithLayout(shape_with_layout)); } // Create a BufferLayoutConstraint for each array shape in the output of the @@ -368,31 +357,27 @@ const ShapeLayout* LayoutConstraints::ResultLayout() const { string LayoutConstraints::ToString() const { string output; - tensorflow::strings::StrAppend(&output, "LayoutConstraints for computation ", - computation_->name(), ":\n"); + absl::StrAppend(&output, "LayoutConstraints for computation ", + computation_->name(), ":\n"); for (auto* instruction : computation_->MakeInstructionPostOrder()) { - 
tensorflow::strings::StrAppend(&output, " ", instruction->ToShortString(), - "\n"); + absl::StrAppend(&output, " ", instruction->ToShortString(), "\n"); for (int64 i = 0; i < instruction->operand_count(); ++i) { if (OperandLayout(instruction, i) != nullptr) { - tensorflow::strings::StrAppend( - &output, " operand (", i, - "): ", OperandLayout(instruction, i)->ToString(), "\n"); + absl::StrAppend(&output, " operand (", i, + "): ", OperandLayout(instruction, i)->ToString(), "\n"); } } for (const LogicalBuffer* buffer : points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { if (BufferLayout(*buffer) != nullptr) { - tensorflow::strings::StrAppend( - &output, " ", buffer->ToString(), " : ", - LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n"); + absl::StrAppend(&output, " ", buffer->ToString(), " : ", + LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n"); } } } if (ResultLayout() != nullptr) { - tensorflow::strings::StrAppend(&output, " => ", ResultLayout()->ToString(), - "\n"); + absl::StrAppend(&output, " => ", ResultLayout()->ToString(), "\n"); } return output; } @@ -763,7 +748,7 @@ Status CheckParameterLayout(HloInstruction* parameter, return InternalError( "parameter instruction %s does not match layout of computation " "shape: %s", - parameter->ToString().c_str(), parameter_layout.ToString().c_str()); + parameter->ToString(), parameter_layout.ToString()); } return Status::OK(); } @@ -774,8 +759,8 @@ Status CheckConstantLayout(HloInstruction* constant) { constant->shape())) { return InternalError( "constant instruction %s does not match the layout of its literal %s", - constant->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(constant->literal().shape()).c_str()); + constant->ToString(), + ShapeUtil::HumanStringWithLayout(constant->literal().shape())); } return Status::OK(); } @@ -908,13 +893,10 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { return InternalError( "Layout of instruction %s at index {%s} does not match " 
"source LogicalBuffer %s: %s vs %s", - instruction->name().c_str(), - tensorflow::str_util::Join(index, ",").c_str(), - buffer->ToString().c_str(), - ShapeUtil::HumanStringWithLayout(instruction_subshape) - .c_str(), - ShapeUtil::HumanStringWithLayout(buffer->shape()) - .c_str()); + instruction->name(), absl::StrJoin(index, ","), + buffer->ToString(), + ShapeUtil::HumanStringWithLayout(instruction_subshape), + ShapeUtil::HumanStringWithLayout(buffer->shape())); } } } @@ -998,17 +980,18 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout( CHECK(ShapeUtil::IsArray(instruction->shape())); CHECK(ShapeUtil::IsArray(operand->shape())); - if (instruction->IsElementwiseOnOperand(operand_no) && - !ShapeUtil::IsScalar(operand->shape()) && + if (!ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == - ShapeUtil::Rank(instruction->shape())) { - // Assign operands the same layout as the instruction, so that + ShapeUtil::Rank(instruction->shape()) && + InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) { + // Propagate the result layout to the operand layout if the instruction + // requires the same layout out for the result and the operand. + // + // For elementwise operations, using the same layout for the operands and + // the result also has the following benefits: // 1) the elementwise operation can reuse its operand's buffer, and // 2) the input and output elements can reuse the same linear index. - // - // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit - // from assigning the same layout to input and output. 
- return MakeUnique<Layout>(output_layout); + return absl::make_unique<Layout>(output_layout); } if (instruction->opcode() == HloOpcode::kReshape) { @@ -1031,13 +1014,13 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout( *operand_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(operand_shape); if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { - return MakeUnique<Layout>(operand_shape.layout()); + return absl::make_unique<Layout>(operand_shape.layout()); } if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) { *operand_shape.mutable_layout() = output_layout; if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { - return MakeUnique<Layout>(output_layout); + return absl::make_unique<Layout>(output_layout); } } auto aligned_operand_shape = @@ -1046,7 +1029,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout( auto operand_layout = aligned_operand_shape.value().layout(); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(operand_layout, operand_shape)); - return MakeUnique<Layout>(operand_layout); + return absl::make_unique<Layout>(operand_layout); } } @@ -1062,7 +1045,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout( Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape())); - return MakeUnique<Layout>(operand_layout); + return absl::make_unique<Layout>(operand_layout); } return nullptr; @@ -1076,11 +1059,11 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout( CHECK(ShapeUtil::IsArray(user->shape()) && ShapeUtil::IsArray(operand->shape())); - if (user->IsElementwiseOnOperand(operand_no) && - !ShapeUtil::IsScalar(operand->shape()) && - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) { + if (!ShapeUtil::IsScalar(operand->shape()) && + 
ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) && + InstructionRequiresInputLayoutEqualToOutputLayout(user)) { // Assign users the same layout as the operand. - return MakeUnique<Layout>(operand_layout); + return absl::make_unique<Layout>(operand_layout); } if (user->opcode() == HloOpcode::kReshape) { @@ -1103,13 +1086,13 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout( *output_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(output_shape); if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { - return MakeUnique<Layout>(output_shape.layout()); + return absl::make_unique<Layout>(output_shape.layout()); } if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) { *output_shape.mutable_layout() = operand_layout; if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { - return MakeUnique<Layout>(operand_layout); + return absl::make_unique<Layout>(operand_layout); } } auto aligned_user_shape = @@ -1118,7 +1101,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout( auto user_layout = aligned_user_shape.value().layout(); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(user_layout, output_shape)); - return MakeUnique<Layout>(user_layout); + return absl::make_unique<Layout>(user_layout); } } @@ -1134,7 +1117,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout( } Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape())); - return MakeUnique<Layout>(user_layout); + return absl::make_unique<Layout>(user_layout); } return nullptr; @@ -1385,7 +1368,7 @@ StatusOr<Layout> InferArrayLayout( // This should not happen because we've assigned layouts to all // instructions preceding this one. 
return InternalError("LogicalBuffer %s does not have a layout", - source_buffer->ToString().c_str()); + source_buffer->ToString()); } if (first_buffer_layout == nullptr) { @@ -1400,9 +1383,8 @@ StatusOr<Layout> InferArrayLayout( return FailedPrecondition( "Array at index {%s} in instruction %s aliases buffers %s " "and %s which have different layouts", - tensorflow::str_util::Join(index, ",").c_str(), - instruction->name().c_str(), source_buffers[0]->ToString().c_str(), - source_buffer->ToString().c_str()); + absl::StrJoin(index, ","), instruction->name(), + source_buffers[0]->ToString(), source_buffer->ToString()); } } @@ -1570,7 +1552,7 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // present in the IR before layout assignment is a bug. return InternalError( "Unexpected bitcast operation seen during layout assignment: %s.", - instruction->ToString().c_str()); + instruction->ToString()); } if (instruction->opcode() != HloOpcode::kInfeed) { LayoutUtil::ClearLayout(instruction->mutable_shape()); @@ -1822,6 +1804,107 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) { return true; } +bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( + const HloInstruction* instruction) { + switch (instruction->opcode()) { + case HloOpcode::kAbs: + case HloOpcode::kAdd: + case HloOpcode::kAnd: + case HloOpcode::kAtan2: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCeil: + case HloOpcode::kClamp: + case HloOpcode::kClz: + case HloOpcode::kComplex: + case HloOpcode::kConcatenate: + case HloOpcode::kConditional: + case HloOpcode::kConvert: + case HloOpcode::kCos: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kAllToAll: + case HloOpcode::kCollectivePermute: + case HloOpcode::kCustomCall: + case HloOpcode::kDivide: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kEq: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFft: + case 
HloOpcode::kFloor: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLe: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kLt: + case HloOpcode::kMap: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kNegate: + case HloOpcode::kNot: + case HloOpcode::kOr: + case HloOpcode::kXor: + case HloOpcode::kPad: + case HloOpcode::kPower: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kReduceWindow: + case HloOpcode::kRemainder: + case HloOpcode::kReverse: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kSelect: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kSubtract: + case HloOpcode::kTanh: + case HloOpcode::kTupleSelect: + case HloOpcode::kWhile: + return true; + case HloOpcode::kBatchNormGrad: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBitcast: + case HloOpcode::kBroadcast: + case HloOpcode::kCall: + case HloOpcode::kConstant: + case HloOpcode::kConvolution: + case HloOpcode::kCopy: + case HloOpcode::kDomain: + case HloOpcode::kDot: + case HloOpcode::kFusion: + case HloOpcode::kGather: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kIota: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReduce: + case HloOpcode::kReshape: + case HloOpcode::kRng: + case HloOpcode::kScatter: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kAfterAll: + case HloOpcode::kTrace: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + return false; + } +} + Status LayoutAssignment::Init() { 
computation_layouts_.clear(); *entry_computation_layout_ = saved_entry_computation_layout_; diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index f9e8dbea2f..cf545031d3 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -297,12 +297,17 @@ class LayoutAssignment : public HloPassInterface { ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} - tensorflow::StringPiece name() const override { return "layout-assignment"; } + absl::string_view name() const override { return "layout-assignment"; } // Assign layouts to the given module. Returns whether the module was changed // (any layouts were changed). StatusOr<bool> Run(HloModule* module) override; + // Returns true if the instruction requires that operands with the same rank + // as the output have to have the same layout as the output. + virtual bool InstructionRequiresInputLayoutEqualToOutputLayout( + const HloInstruction* instruction); + protected: // These methods, invoked by PropagateConstraints, propagate a layout // constraint to its neighbors (i.e. 
operands and users) in order to minimize diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index a16fa75e30..7505d7a5b3 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -59,7 +59,7 @@ class LayoutAssignmentTest : public HloTestBase { EXPECT_IS_OK(layout_assignment.Run(module).status()); } - std::vector<int64> LayoutOf(HloModule* module, tensorflow::StringPiece name) { + std::vector<int64> LayoutOf(HloModule* module, absl::string_view name) { auto minor_to_major = FindInstruction(module, name)->shape().layout().minor_to_major(); return std::vector<int64>(minor_to_major.begin(), minor_to_major.end()); @@ -861,5 +861,115 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } +TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopySliceOperandToAvoidImplicitLayoutChange + + ENTRY CopySliceOperandToAvoidImplicitLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[4,5]{0,1} parameter(1) + slice0 = f32[3,4] slice(par1), slice={[1:4],[1:5]} + ROOT add0 = f32[3,4]{1,0} add(par0,slice0) + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + auto copy = FindInstruction(module.get(), "copy.1"); + auto slice = FindInstruction(module.get(), "slice0"); + EXPECT_EQ(slice->operand(0), copy); + EXPECT_TRUE( + LayoutUtil::Equal(slice->shape().layout(), copy->shape().layout())); +} + +TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopyDSliceOperandToAvoidImplicitLayoutChange + + ENTRY CopyDSliceOperandToAvoidImplicitLayoutChange { 
+ par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[4,5]{0,1} parameter(1) + par2 = s32[2] parameter(2) + dslice0 = f32[3,4] dynamic-slice(par1, par2), dynamic_slice_sizes={3,4} + ROOT add0 = f32[3,4]{1,0} add(par0,dslice0) + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + auto copy = FindInstruction(module.get(), "copy.1"); + auto dslice = FindInstruction(module.get(), "dslice0"); + EXPECT_EQ(dslice->operand(0), copy); + EXPECT_TRUE( + LayoutUtil::Equal(dslice->shape().layout(), copy->shape().layout())); +} + +TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { + const char* module_str = R"( + HloModule CopyConcatOperandToAvoidImplicitLayoutChange + + ENTRY CopyConcatOperandToAvoidImplicitLayoutChange { + par0 = f32[3,8]{1,0} parameter(0) + par1 = f32[3,5]{0,1} parameter(1) + par2 = f32[3,3]{1,0} parameter(2) + concat0 = f32[3,8] concatenate(f32[3,5] par1, f32[3,3] par2), + dimensions={1} + ROOT add0 = f32[3,8]{1,0} add(par0,concat0) + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + auto copy = FindInstruction(module.get(), "copy.1"); + auto concat = FindInstruction(module.get(), "concat0"); + EXPECT_EQ(concat->operand(0), copy); + EXPECT_TRUE( + LayoutUtil::Equal(concat->shape().layout(), copy->shape().layout())); +} + +TEST_F(LayoutAssignmentTest, + ConvolutionOperandWithImplicitLayoutChangeNotCopied) { + const char* module_str = R"( + HloModule ConvolutionOperandWithImplicitLayoutChangeNotCopied + + ENTRY ConvolutionOperandWithImplicitLayoutChangeNotCopied { + par0 = f32[128,3,230,230]{2,3,1,0} parameter(0) + par1 = f32[7,7,3,64]{3,2,0,1} parameter(1) + ROOT 
convolution0 = f32[128,64,112,112]{3,2,1,0} convolution(par0, par1), + window={size=7x7 stride=2x2}, dim_labels=bf01_01io->bf01, + feature_group_count=1 + } + )"; + + auto module = ParseHloString(module_str).ValueOrDie(); + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + auto copy = FindInstruction(module.get(), "copy.1"); + EXPECT_EQ(copy, nullptr); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index cdd3daf73b..be12d7c90c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -38,6 +38,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -69,6 +70,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", "@llvm//:support", "@llvm//:target", @@ -88,6 +90,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -103,6 +107,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -120,6 +125,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/strings:str_format", "@llvm//:core", ], ) @@ -133,9 +139,7 @@ cc_library( ":llvm_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", - 
"//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "@llvm//:core", @@ -193,6 +197,8 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", "//tensorflow/compiler/xla/service/gpu:partition_assignment", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@llvm//:core", ], ) @@ -219,7 +225,7 @@ cc_library( deps = [ ":llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@llvm//:core", ], ) @@ -230,6 +236,7 @@ cc_library( hdrs = ["buffer_assignment_util.h"], deps = [ "//tensorflow/compiler/xla/service:buffer_assignment", + "@com_google_absl//absl/strings", ], ) @@ -242,3 +249,12 @@ cc_library( "@llvm//:core", ], ) + +cc_library( + name = "ir_builder_mixin", + srcs = [], + hdrs = ["ir_builder_mixin.h"], + deps = [ + "@llvm//:core", + ], +) diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index fe9eab93aa..8d9fa99d82 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ +#include "absl/strings/str_cat.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -23,7 +24,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace llvm_ir { diff --git a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc index 4eb5d9fb47..bdce4a171b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h" +#include "absl/strings/str_cat.h" namespace xla { namespace llvm_ir { @@ -48,7 +49,7 @@ string ConstantBufferAllocationToGlobalName( c = '_'; } } - return tensorflow::strings::StrCat("buffer_for_", instr_name); + return absl::StrCat("buffer_for_", instr_name); } const Literal& LiteralForConstantAllocation( diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc index 27fbb11e2e..ad350613dd 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc @@ -40,7 +40,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( const Shape& update_shape, const ElementGenerator& start_indices_generator, bool is_signed, ElementGenerator update_array_generator, const IrArray& output_array, const gpu::LaunchDimensions* launch_dimensions, - tensorflow::StringPiece name, llvm::IRBuilder<>* b) { + absl::string_view name, llvm::IRBuilder<>* b) { const Shape& output_shape = output_array.GetShape(); // Read start indices from start_indices_generator. 
@@ -101,8 +101,7 @@ static Status EmitDynamicUpdateSliceInPlaceImpl( Status EmitDynamicUpdateSliceInPlace( tensorflow::gtl::ArraySlice<IrArray> operand_arrays, - const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* b) { + const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b) { VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name; // No need to use operand_arrays[0], the input array of the diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h index 3502577d23..e1631a62ae 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h @@ -65,8 +65,7 @@ inline bool CanEmitFusedDynamicUpdateSliceInPlace( // modify the input/output buffer without touching any of the other elements. Status EmitDynamicUpdateSliceInPlace( tensorflow::gtl::ArraySlice<IrArray> operand_arrays, - const IrArray& output_array, tensorflow::StringPiece name, - llvm::IRBuilder<>* b); + const IrArray& output_array, absl::string_view name, llvm::IRBuilder<>* b); // Given a loop-fusion node whose root is a dynamic-update-slice op whose // array-to-be-updated and output share the same buffer slice, emits diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 72ede377e1..6d637cad6d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -98,7 +98,7 @@ Status FusedIrEmitter::HandleGetTupleElement( return Unimplemented( "GetTupleElement fusion currently only supports" " parameter operands, but found operand: %s", - operand->name().c_str()); + operand->name()); } // Emit code to lookup tuple element pointer, and store it in 'gte_values_'. 
llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement( diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 2b6caee6aa..6971220022 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -342,9 +342,9 @@ llvm::Value* IrArray::Index::Linearize( return logical_linear_index; } -llvm::Value* IrArray::EmitArrayElementAddress( - const IrArray::Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name) const { +llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, + llvm::IRBuilder<>* b, + absl::string_view name) const { if (ShapeUtil::IsScalar(*shape_)) { // Special handling of scalars: a scalar pretends to have the same value for // every index, thus effectively implementing broadcasting of its value @@ -402,7 +402,7 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata( llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name) const { + absl::string_view name) const { llvm::Value* element_address = EmitArrayElementAddress(index, b, name); llvm::LoadInst* load = b->CreateLoad(element_address); AnnotateLoadStoreInstructionWithMetadata(load); diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 28ca793e3e..e913c109b3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -19,12 +19,13 @@ limitations under the License. 
#include <map> #include <vector> +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -81,7 +82,7 @@ class IrArray { } } CHECK_NE(index_type_, nullptr); - CHECK(c_all_of(multidim, [&](llvm::Value* v) { + CHECK(absl::c_all_of(multidim, [&](llvm::Value* v) { return index_type_ == v->getType(); })); } @@ -240,7 +241,7 @@ class IrArray { // The optional name is useful for debugging when looking at // the emitted LLVM IR. llvm::Value* EmitArrayElementAddress(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name = "") const; + absl::string_view name = "") const; // Attach metadata this IrArray instance knows about to "instruction". void AnnotateLoadStoreInstructionWithMetadata( @@ -254,7 +255,7 @@ class IrArray { // The optional name is useful for debugging when looking at // the emitted LLVM IR. llvm::Value* EmitReadArrayElement(const Index& index, llvm::IRBuilder<>* b, - tensorflow::StringPiece name = "") const; + absl::string_view name = "") const; // Emit IR to write the given value to the array element at the given index. void EmitWriteArrayElement(const Index& index, llvm::Value* value, diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h new file mode 100644 index 0000000000..abc06fb7b4 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h @@ -0,0 +1,400 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ + +#include "llvm/IR/IRBuilder.h" + +namespace xla { + +// Mixin class that injects more ergonomic versions of llvm::IRBuilder methods +// into a class. Intended to be used as a CRTP base class, like: +// +// class MyIrEmitter : public IrBuilderMixin<MyIrEmitter> { +// llvm::IRBuilder<>* builder() { return builder_; } +// +// void EmitFoo(HloInstruction* foo) { +// Add(Mul(...), FPToUI(...)); +// } +// }; + +template <typename Derived> +class IrBuilderMixin { + protected: + template <class... Args> + llvm::Value* Add(Args&&... args) { + return mixin_builder()->CreateAdd(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::LoadInst* AlignedLoad(Args&&... args) { + return mixin_builder()->CreateAlignedLoad(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::StoreInst* AlignedStore(Args&&... args) { + return mixin_builder()->CreateAlignedStore(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::AllocaInst* Alloca(Args&&... args) { + return mixin_builder()->CreateAlloca(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* And(Args&&... args) { + return mixin_builder()->CreateAnd(std::forward<Args>(args)...); + } + + template <class... 
Args> + llvm::Value* AtomicCmpXchg(Args&&... args) { + return mixin_builder()->CreateAtomicCmpXchg(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* AtomicRMW(Args&&... args) { + return mixin_builder()->CreateAtomicRMW(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* BitCast(Args&&... args) { + return mixin_builder()->CreateBitCast(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* Br(Args&&... args) { + return mixin_builder()->CreateBr(std::forward<Args>(args)...); + } + + llvm::CallInst* Call(llvm::Value* callee, + llvm::ArrayRef<llvm::Value*> args = llvm::None, + const llvm::Twine& name = "", + llvm::MDNode* fp_math_tag = nullptr) { + return mixin_builder()->CreateCall(callee, args, name, fp_math_tag); + } + + template <class... Args> + llvm::BranchInst* CondBr(Args&&... args) { + return mixin_builder()->CreateCondBr(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* ConstInBoundsGEP1_32(Args&&... args) { + return mixin_builder()->CreateConstInBoundsGEP1_32( + std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* FAdd(Args&&... args) { + return mixin_builder()->CreateFAdd(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* FMul(Args&&... args) { + return mixin_builder()->CreateFMul(std::forward<Args>(args)...); + } + + llvm::Value* GEP(llvm::Value* ptr, llvm::ArrayRef<llvm::Value*> idx_list, + const llvm::Twine& name = "") { + return mixin_builder()->CreateGEP(ptr, idx_list, name); + } + + template <class... Args> + llvm::Value* ICmpEQ(Args&&... args) { + return mixin_builder()->CreateICmpEQ(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* ICmpNE(Args&&... args) { + return mixin_builder()->CreateICmpNE(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* ICmpULE(Args&&... 
args) { + return mixin_builder()->CreateICmpULE(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* ICmpULT(Args&&... args) { + return mixin_builder()->CreateICmpULT(std::forward<Args>(args)...); + } + + llvm::Value* InBoundsGEP(llvm::Value* ptr, + llvm::ArrayRef<llvm::Value*> idx_list, + const llvm::Twine& name = "") { + return mixin_builder()->CreateInBoundsGEP(ptr, idx_list, name); + } + + llvm::Value* ExtractValue(llvm::Value* agg, llvm::ArrayRef<unsigned> idxs, + const llvm::Twine& name = "") { + return mixin_builder()->CreateExtractValue(agg, idxs, name); + } + + llvm::Value* InsertValue(llvm::Value* agg, llvm::Value* val, + llvm::ArrayRef<unsigned> idxs, + const llvm::Twine& name = "") { + return mixin_builder()->CreateInsertValue(agg, val, idxs, name); + } + + template <class... Args> + llvm::Value* IntToPtr(Args&&... args) { + return mixin_builder()->CreateIntToPtr(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::LoadInst* Load(Args&&... args) { + return mixin_builder()->CreateLoad(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::CallInst* MemCpy(Args&&... args) { + return mixin_builder()->CreateMemCpy(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* Mul(Args&&... args) { + return mixin_builder()->CreateMul(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* NSWAdd(Args&&... args) { + return mixin_builder()->CreateNSWAdd(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* NSWMul(Args&&... args) { + return mixin_builder()->CreateNSWMul(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* NSWSub(Args&&... args) { + return mixin_builder()->CreateNSWSub(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* Or(Args&&... args) { + return mixin_builder()->CreateOr(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* PointerCast(Args&&... 
args) { + return mixin_builder()->CreatePointerCast(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* PtrToInt(Args&&... args) { + return mixin_builder()->CreatePtrToInt(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* SDiv(Args&&... args) { + return mixin_builder()->CreateSDiv(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* Select(Args&&... args) { + return mixin_builder()->CreateSelect(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* SRem(Args&&... args) { + return mixin_builder()->CreateSRem(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::StoreInst* Store(Args&&... args) { + return mixin_builder()->CreateStore(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* UDiv(Args&&... args) { + return mixin_builder()->CreateUDiv(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* URem(Args&&... args) { + return mixin_builder()->CreateURem(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* VectorSplat(Args&&... args) { + return mixin_builder()->CreateVectorSplat(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* ZExtOrTrunc(Args&&... args) { + return mixin_builder()->CreateZExtOrTrunc(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* AShr(Args&&... args) { + return mixin_builder()->CreateAShr(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* FCmpOEQ(Args&&... args) { + return mixin_builder()->CreateFCmpOEQ(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* FCmpOLT(Args&&... args) { + return mixin_builder()->CreateFCmpOLT(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* FCmpONE(Args&&... args) { + return mixin_builder()->CreateFCmpONE(std::forward<Args>(args)...); + } + + template <class... 
Args> + llvm::Value* FCmpUNE(Args&&... args) { + return mixin_builder()->CreateFCmpUNE(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* FDiv(Args&&... args) { + return mixin_builder()->CreateFDiv(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* FNeg(Args&&... args) { + return mixin_builder()->CreateFNeg(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* FPCast(Args&&... args) { + return mixin_builder()->CreateFPCast(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* FPToSI(Args&&... args) { + return mixin_builder()->CreateFPToSI(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* FPToUI(Args&&... args) { + return mixin_builder()->CreateFPToUI(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* FPTrunc(Args&&... args) { + return mixin_builder()->CreateFPTrunc(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* FRem(Args&&... args) { + return mixin_builder()->CreateFRem(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* FSub(Args&&... args) { + return mixin_builder()->CreateFSub(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* ICmpSGE(Args&&... args) { + return mixin_builder()->CreateICmpSGE(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* ICmpSLT(Args&&... args) { + return mixin_builder()->CreateICmpSLT(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* IntCast(Args&&... args) { + return mixin_builder()->CreateIntCast(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* LShr(Args&&... args) { + return mixin_builder()->CreateLShr(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* MemSet(Args&&... args) { + return mixin_builder()->CreateMemSet(std::forward<Args>(args)...); + } + + template <class... 
Args> + llvm::Value* Neg(Args&&... args) { + return mixin_builder()->CreateNeg(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* Not(Args&&... args) { + return mixin_builder()->CreateNot(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::PHINode* PHI(Args&&... args) { + return mixin_builder()->CreatePHI(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* RetVoid(Args&&... args) { + return mixin_builder()->CreateRetVoid(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* SExtOrTrunc(Args&&... args) { + return mixin_builder()->CreateSExtOrTrunc(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* Shl(Args&&... args) { + return mixin_builder()->CreateShl(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* SIToFP(Args&&... args) { + return mixin_builder()->CreateSIToFP(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* Sub(Args&&... args) { + return mixin_builder()->CreateSub(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* Trunc(Args&&... args) { + return mixin_builder()->CreateTrunc(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* UIToFP(Args&&... args) { + return mixin_builder()->CreateUIToFP(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* Unreachable(Args&&... args) { + return mixin_builder()->CreateUnreachable(std::forward<Args>(args)...); + } + + template <class... Args> + llvm::Value* Xor(Args&&... 
args) { + return mixin_builder()->CreateXor(std::forward<Args>(args)...); + } + + private: + llvm::IRBuilder<>* mixin_builder() { + return static_cast<Derived*>(this)->builder(); + } +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_BUILDER_MIXIN_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index b79567369a..bd0139f85b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -19,7 +19,7 @@ limitations under the License. namespace xla { Status KernelSupportLibrary::For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function<Status(llvm::Value*, bool)>& for_body_generator) { return If(b_->CreateICmpSLT(start, end), [&]() -> Status { @@ -30,7 +30,7 @@ Status KernelSupportLibrary::For( } Status KernelSupportLibrary::For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function<Status(llvm::Value*, llvm::Value*)>& for_body_generator) { @@ -56,7 +56,7 @@ Status KernelSupportLibrary::For( } Status KernelSupportLibrary::If( - tensorflow::StringPiece name, llvm::Value* condition, + absl::string_view name, llvm::Value* condition, const std::function<Status()>& true_block_generator, const std::function<Status()>& false_block_generator) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, name, b_); @@ -70,7 +70,7 @@ Status KernelSupportLibrary::If( void KernelSupportLibrary::EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, + absl::string_view kernel_name, KernelSupportLibrary::ArgumentVector 
arguments, const std::function<void(KernelSupportLibrary::ArgumentVector)>& kernel_body_generator) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index b00f903d56..b152cf9275 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ #include <string> +#include "absl/strings/string_view.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { // A thin wrapper around llvm_loop.h to make code generating structured control @@ -49,13 +49,13 @@ class KernelSupportLibrary { // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function<Status(llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator); void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator) { @@ 
-67,7 +67,7 @@ class KernelSupportLibrary { })); } - Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step, + Status For(absl::string_view name, int64 start, int64 end, int64 step, const std::function<Status(llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator) { @@ -77,7 +77,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, int64 start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>& for_body_generator) { ForReturnVoid(name, /*start=*/b_->getInt64(start), @@ -99,13 +99,13 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function<Status(llvm::Value* ind_var, llvm::Value* is_first_iteration)>& for_body_generator); - void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + void ForReturnVoid(absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, const std::function<void(llvm::Value* ind_var, @@ -119,7 +119,7 @@ class KernelSupportLibrary { })); } - Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + Status For(absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, bool peel_first_iteration, const std::function<Status(llvm::Value* ind_var, llvm::Value* is_first_iteration)>& @@ -129,7 +129,7 @@ class KernelSupportLibrary { peel_first_iteration, for_body_generator); } - void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + void ForReturnVoid(absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, bool 
peel_first_iteration, const std::function<void(llvm::Value* ind_var, llvm::Value* is_first_iteration)>& @@ -140,7 +140,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { return For(name, start, end, step, @@ -151,7 +151,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + absl::string_view name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function<void(llvm::Value* ind_var)>& for_body_generator) { ForReturnVoid(name, start, end, step, @@ -162,8 +162,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { return For(name, start, end, llvm::ConstantInt::get(start->getType(), step), /*peel_first_iteration=*/false, @@ -173,8 +172,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, + absl::string_view name, llvm::Value* start, llvm::Value* end, int64 step, const std::function<void(llvm::Value* ind_var)>& for_body_generator) { ForReturnVoid(name, start, end, llvm::ConstantInt::get(start->getType(), step), @@ -182,7 +180,7 @@ class KernelSupportLibrary { } Status For( - tensorflow::StringPiece name, int64 start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function<Status(llvm::Value* ind_var)>& for_body_generator) { return For(name, /*start=*/b_->getInt64(start), /*end=*/b_->getInt64(end), @@ -190,7 +188,7 @@ class KernelSupportLibrary { } void ForReturnVoid( - tensorflow::StringPiece name, int64 
start, int64 end, int64 step, + absl::string_view name, int64 start, int64 end, int64 step, const std::function<void(llvm::Value* ind_var)>& for_body_generator) { ForReturnVoid(name, /*start=*/b_->getInt64(start), /*end=*/b_->getInt64(end), @@ -203,7 +201,7 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - Status If(tensorflow::StringPiece name, llvm::Value* condition, + Status If(absl::string_view name, llvm::Value* condition, const std::function<Status()>& true_block_generator, const std::function<Status()>& false_block_generator = []() -> Status { return Status::OK(); }); @@ -222,7 +220,7 @@ class KernelSupportLibrary { IfReturnVoid("", condition, true_block_generator, false_block_generator); } - void IfReturnVoid(tensorflow::StringPiece name, llvm::Value* condition, + void IfReturnVoid(absl::string_view name, llvm::Value* condition, const std::function<void()>& true_block_generator, const std::function<void()>& false_block_generator = []() { }) { @@ -259,13 +257,13 @@ class KernelSupportLibrary { // Currently we only support at most one nullptr value in `arguments`. static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, ArgumentVector arguments, + absl::string_view kernel_name, ArgumentVector arguments, const std::function<void(ArgumentVector)>& kernel_body_generator); // Thin wrappers around the more general EmitAndCallOutlinedKernel above. 
static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>& kernel_body_generator) { @@ -278,7 +276,7 @@ class KernelSupportLibrary { static void EmitAndCallOutlinedKernel( bool enable_fast_math, bool optimize_for_size, llvm::IRBuilder<>* b, - tensorflow::StringPiece kernel_name, llvm::Value* arg0, llvm::Value* arg1, + absl::string_view kernel_name, llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, llvm::Value* arg3, const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*, llvm::Value*)>& kernel_body_generator) { @@ -296,4 +294,4 @@ class KernelSupportLibrary { }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_KERNEL_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc index 35b3941272..cb4d1db997 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.cc @@ -55,10 +55,10 @@ Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs, } } // namespace -tensorflow::gtl::optional<std::vector<int64> > FindTranspose021( - const Shape& a, const Shape& b) { +absl::optional<std::vector<int64> > FindTranspose021(const Shape& a, + const Shape& b) { if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } std::vector<int64> perm(a.dimensions().size()); @@ -88,7 +88,7 @@ tensorflow::gtl::optional<std::vector<int64> > FindTranspose021( return dims_021; } - return tensorflow::gtl::nullopt; + return absl::nullopt; } IrArray::Index GetUnreducedOutputIndex( diff 
--git a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h index ccb9b8ba3e..8bd06c42c3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h @@ -36,8 +36,8 @@ namespace llvm_ir { // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the // reduced shape of `b` or the 0-2-1 shape. -tensorflow::gtl::optional<std::vector<int64> > FindTranspose021(const Shape& a, - const Shape& b); +absl::optional<std::vector<int64> > FindTranspose021(const Shape& a, + const Shape& b); // Return the unreduced output index corresponding to the given reduced output // index. diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index ba7f94834c..9f3329e7f0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -18,6 +18,7 @@ limitations under the License. #include <numeric> #include <vector> +#include "absl/strings/str_cat.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" @@ -25,19 +26,17 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace llvm_ir { -ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, +ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization) - : prefix_(std::string(prefix)), - suffix_(std::string(suffix)), + : prefix_(prefix), + suffix_(suffix), start_index_(start_index), end_index_(end_index), step_(step), @@ -46,9 +45,9 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop( - tensorflow::StringPiece prefix, llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, - UnrollMode unroll_mode, bool prevent_vectorization) { + absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index, + llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode, + bool prevent_vectorization) { std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index, end_index, step, unroll_mode, prevent_vectorization)); @@ -168,16 +167,16 @@ std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) { return result; } -string ForLoop::GetQualifiedName(tensorflow::StringPiece name) { +string ForLoop::GetQualifiedName(absl::string_view name) { return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_)); } -llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, +llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name, llvm::IRBuilder<>* b) { return CreateBasicBlock(insert_before_bb_, 
GetQualifiedName(name), b); } -std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, +std::unique_ptr<ForLoop> ForLoopNest::AddLoop(absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode, @@ -186,12 +185,9 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, unroll_mode, prevent_vectorization); } -std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, - llvm::Value* stride, - UnrollMode unroll_mode, - bool prevent_vectorization) { +std::unique_ptr<ForLoop> ForLoopNest::AddLoop( + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, + llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); @@ -216,7 +212,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, int64 end_index, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); @@ -227,7 +223,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); @@ -238,7 +234,7 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index, } IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, - tensorflow::StringPiece suffix) { + absl::string_view suffix) { std::vector<int64> dimensions(ShapeUtil::Rank(shape)); std::iota(dimensions.begin(), dimensions.end(), 0); return AddLoopsForShapeOnDimensions(shape, 
dimensions, suffix); @@ -246,14 +242,14 @@ IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions, - tensorflow::StringPiece suffix) { + absl::string_view suffix) { llvm_ir::IrArray::Index index(index_type_, shape.dimensions_size()); for (int64 dimension : dimensions) { std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop( /*start_index=*/0, /*end_index=*/shape.dimensions(dimension), /*suffix=*/ - llvm_ir::IrName(suffix, tensorflow::strings::StrCat(dimension))); + llvm_ir::IrName(suffix, absl::StrCat(dimension))); index[dimension] = loop->GetIndVarValue(); } return index; @@ -261,7 +257,7 @@ IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions( IrArray::Index ForLoopNest::EmitOperandArrayLoopNest( const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, - tensorflow::StringPiece name_suffix) { + absl::string_view name_suffix) { // Prepares the dimension list we will use to emit the loop nest. Outermost // loops are added first. Add loops in major-to-minor order, and skip the // 'dimension_to_skip' dimension. diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index a4fed5c8dc..0a406bd90b 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -19,15 +19,15 @@ limitations under the License. 
#include <memory> #include <string> +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -78,7 +78,7 @@ class ForLoop { // `unroll_mode` specifies the desired LLVM unrolling behavior for generated // loop. static std::unique_ptr<ForLoop> EmitForLoop( - tensorflow::StringPiece prefix, llvm::Value* start_index, + absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -133,19 +133,18 @@ class ForLoop { // Allow ForLoopNest to call this private constructor. friend class ForLoopNest; - ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, + ForLoop(absl::string_view prefix, absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization); // Emit the loop at the insert point of the builder. void Emit(llvm::IRBuilder<>* b); - llvm::BasicBlock* CreateLoopBB(tensorflow::StringPiece name, - llvm::IRBuilder<>* b); + llvm::BasicBlock* CreateLoopBB(absl::string_view name, llvm::IRBuilder<>* b); // Creates a name for an LLVM construct, appending prefix_ and suffix_, if // they are set. - string GetQualifiedName(tensorflow::StringPiece name); + string GetQualifiedName(absl::string_view name); // Return a list of metadata nodes that should be associated with the // llvm::Loop for this `ForLoop`. 
@@ -182,9 +181,9 @@ class ForLoopNest { SetIndexType(index_ty); } - ForLoopNest(tensorflow::StringPiece name, llvm::IRBuilder<>* b, + ForLoopNest(absl::string_view name, llvm::IRBuilder<>* b, llvm::Type* index_ty = nullptr) - : name_(std::string(name)), + : name_(name), outer_loop_preheader_bb_(nullptr), outer_loop_exit_bb_(nullptr), inner_loop_body_bb_(nullptr), @@ -197,14 +196,14 @@ class ForLoopNest { // been added then emit loop inside the body of the last added loop. // unroll_mode is used to emit metadata that controls LLVM unrolling. std::unique_ptr<ForLoop> AddLoop( - tensorflow::StringPiece suffix, llvm::Value* start_index, + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr<ForLoop> AddLoop( - tensorflow::StringPiece suffix, llvm::Value* start_index, + absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -213,13 +212,13 @@ class ForLoopNest { // end index are constant. std::unique_ptr<ForLoop> AddLoop( int64 start_index, int64 end_index, int64 stride, - tensorflow::StringPiece suffix, + absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr<ForLoop> AddLoop( - int64 start_index, int64 end_index, tensorflow::StringPiece suffix, + int64 start_index, int64 end_index, absl::string_view suffix, UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, bool prevent_vectorization = false); @@ -234,8 +233,7 @@ class ForLoopNest { // within the shape. 
One possible order for that sequence would be: // // (0,0), (0,1), (0,2), (1,0), (1,1), (1,2) - IrArray::Index AddLoopsForShape(const Shape& shape, - tensorflow::StringPiece suffix); + IrArray::Index AddLoopsForShape(const Shape& shape, absl::string_view suffix); // Add a loop for each dimension in "dimensions". "suffix" is the // name suffix of the indvar and basic blocks in this new loop nest. @@ -245,7 +243,7 @@ class ForLoopNest { // dimension that is not in "dimensions". IrArray::Index AddLoopsForShapeOnDimensions( const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions, - tensorflow::StringPiece suffix); + absl::string_view suffix); // Emits a series of nested loops for iterating over an operand array. Loops // are constructed in major to minor dimension layout order. No loop is @@ -256,7 +254,7 @@ class ForLoopNest { // basic blocks) constructed by this method. IrArray::Index EmitOperandArrayLoopNest(const llvm_ir::IrArray& operand_array, int64 dimension_to_skip, - tensorflow::StringPiece name_suffix); + absl::string_view name_suffix); // Convenience methods which return particular basic blocks of the outermost // or innermost loops. These methods return nullptr if no loops have been diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index e6126881af..f0db2a3761 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -19,6 +19,8 @@ limitations under the License. #include <memory> #include <vector> +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/MDBuilder.h" @@ -34,8 +36,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/byte_order.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -61,7 +61,7 @@ string AsString(const std::string& str) { return string(str.data(), str.length()); } -llvm::StringRef AsStringRef(tensorflow::StringPiece str) { +llvm::StringRef AsStringRef(absl::string_view str) { return llvm::StringRef(str.data(), str.size()); } @@ -262,15 +262,17 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b, int alignment) { return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, b, alignment); } -llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( - llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* b, int alignment) { +llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, + llvm::Value* element_count, + absl::string_view name, + llvm::IRBuilder<>* b, + int alignment) { llvm::IRBuilder<>::InsertPoint insert_point = b->saveIP(); llvm::Function* function = b->GetInsertBlock()->getParent(); b->SetInsertPoint(&function->getEntryBlock(), @@ -285,7 +287,7 @@ llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( } llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b) { return llvm::BasicBlock::Create( /*Context=*/b->getContext(), @@ -294,27 +296,25 @@ llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, /*InsertBefore*/ insert_before); } -LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, +LlvmIfData 
EmitIfThenElse(llvm::Value* condition, absl::string_view name, llvm::IRBuilder<>* b, bool emit_else) { llvm_ir::LlvmIfData if_data; if_data.if_block = b->GetInsertBlock(); if_data.true_block = - CreateBasicBlock(nullptr, tensorflow::strings::StrCat(name, "-true"), b); + CreateBasicBlock(nullptr, absl::StrCat(name, "-true"), b); if_data.false_block = - emit_else ? CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-false"), b) + emit_else ? CreateBasicBlock(nullptr, absl::StrCat(name, "-false"), b) : nullptr; // Add a terminator to the if block, if necessary. if (if_data.if_block->getTerminator() == nullptr) { b->SetInsertPoint(if_data.if_block); - if_data.after_block = CreateBasicBlock( - nullptr, tensorflow::strings::StrCat(name, "-after"), b); + if_data.after_block = + CreateBasicBlock(nullptr, absl::StrCat(name, "-after"), b); b->CreateBr(if_data.after_block); } else { if_data.after_block = if_data.if_block->splitBasicBlock( - b->GetInsertPoint(), - AsStringRef(tensorflow::strings::StrCat(name, "-after"))); + b->GetInsertPoint(), AsStringRef(absl::StrCat(name, "-after"))); } // Our basic block should now end with an unconditional branch. 
Remove it; @@ -413,14 +413,14 @@ string IrName(string a) { return a; } -string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b) { +string IrName(absl::string_view a, absl::string_view b) { if (!a.empty() && !b.empty()) { - return IrName(tensorflow::strings::StrCat(a, ".", b)); + return IrName(absl::StrCat(a, ".", b)); } - return IrName(tensorflow::strings::StrCat(a, b)); + return IrName(absl::StrCat(a, b)); } -string IrName(const HloInstruction* a, tensorflow::StringPiece b) { +string IrName(const HloInstruction* a, absl::string_view b) { return IrName(a->name(), b); } @@ -556,7 +556,7 @@ std::map<int, llvm::MDNode*> MergeMetadata( return result; } -static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) { +static string GetProcessUniqueIrFileName(absl::string_view prefix) { static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-"); @@ -584,18 +584,16 @@ Status DumpIRToDirectory(const string& directory_name, // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously // dumped from the same process in such cases. string unique_and_safe_file_name = GetProcessUniqueIrFileName( - tensorflow::strings::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", - optimized ? "with" : "no", "-opt")); + absl::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", + optimized ? "with" : "no", "-opt")); string ir_file_name = tensorflow::io::JoinPath( - directory_name, - tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll")); + directory_name, absl::StrCat(unique_and_safe_file_name, ".ll")); // For some models the embedded constants can be huge, so also dump the module // with the constants stripped to get IR that is easier to manipulate. 
string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath( - directory_name, - tensorflow::strings::StrCat(unique_and_safe_file_name, "-noconst.ll")); + directory_name, absl::StrCat(unique_and_safe_file_name, "-noconst.ll")); TF_RETURN_IF_ERROR(CreateAndWriteStringToFile( directory_name, ir_file_name, DumpModuleToString(llvm_module))); @@ -607,8 +605,7 @@ Status DumpIRToDirectory(const string& directory_name, llvm::Function* CreateFunction(llvm::FunctionType* function_type, llvm::GlobalValue::LinkageTypes linkage, bool enable_fast_math, bool optimize_for_size, - tensorflow::StringPiece name, - llvm::Module* module) { + absl::string_view name, llvm::Module* module) { llvm::Function* function = llvm::Function::Create(function_type, linkage, AsStringRef(name), module); function->setCallingConv(llvm::CallingConv::C); @@ -638,7 +635,7 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { fake_argv_storage.push_back(""); for (const auto& it : options) { // Skip options the XLA backend itself consumes. - if (!tensorflow::str_util::StartsWith(it.first, "xla_")) { + if (!absl::StartsWith(it.first, "xla_")) { if (it.second.empty()) { fake_argv_storage.push_back(it.first); } else { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 0958398534..dde50e19d1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -20,6 +20,7 @@ limitations under the License. #include <string> #include <vector> +#include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" @@ -32,7 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" @@ -47,11 +47,11 @@ namespace llvm_ir { // Convert a std::string (used by LLVM's interfaces) to string. string AsString(const std::string& str); -// Convert a tensorflow::StringPiece to a llvm::StringRef. Note: both -// tensorflow::StringPiece and llvm::StringRef are non-owning pointers into a +// Convert a absl::string_view to a llvm::StringRef. Note: both +// absl::string_view and llvm::StringRef are non-owning pointers into a // string in memory. This method is used to feed strings to LLVM // & Clang APIs that expect llvm::StringRef. -llvm::StringRef AsStringRef(tensorflow::StringPiece str); +llvm::StringRef AsStringRef(absl::string_view str); template <typename T> llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) { @@ -88,8 +88,8 @@ string DumpModuleToString(const llvm::Module& module); // - removing all '%'s. // string IrName(string a); -string IrName(tensorflow::StringPiece a, tensorflow::StringPiece b); -string IrName(const HloInstruction* a, tensorflow::StringPiece b = ""); +string IrName(absl::string_view a, absl::string_view b); +string IrName(const HloInstruction* a, absl::string_view b = ""); // Removes special characters from a function name. // @@ -164,21 +164,23 @@ llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, // This can be useful to avoid e.g. executing an alloca every time // through a loop. llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b, int alignment = 0); // As EmitAllocaAtFunctionEntry, but allocates element_count entries // instead of a single element. 
-llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount( - llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name, - llvm::IRBuilder<>* b, int alignment = 0); +llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type, + llvm::Value* element_count, + absl::string_view name, + llvm::IRBuilder<>* b, + int alignment = 0); // Creates a basic block with the same context and function as for the // builder. Inserts at the end of the function if insert_before is // null. llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before, - tensorflow::StringPiece name, + absl::string_view name, llvm::IRBuilder<>* b); // Struct with data on a conditional branch in a diamond shape created @@ -210,7 +212,7 @@ struct LlvmIfData { // Currently the insertion point of the builder must be a well-formed // block with a terminator. If you need to use this for a // non-terminated block, just make the function able to do that too. -LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, +LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name, llvm::IRBuilder<>* b, bool emit_else = true); // Emits a compare operation between "lhs" and "rhs" with the given predicate, @@ -285,8 +287,7 @@ Status DumpIRToDirectory(const string& directory_name, llvm::Function* CreateFunction(llvm::FunctionType* function_type, llvm::GlobalValue::LinkageTypes linkage, bool enable_fast_math, bool optimize_for_size, - tensorflow::StringPiece name, - llvm::Module* module); + absl::string_view name, llvm::Module* module); // Extracts the xla_backend_extra_options from `config` and passes those that // don't start with xla_ to LLVM. 
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 36f5fa1952..1553b4fc91 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -18,13 +18,13 @@ limitations under the License. #include <memory> #include <utility> +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -86,7 +86,7 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, } std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type) { + absl::string_view loop_name, llvm::Type* index_type) { CHECK_NE(index_type, nullptr); if (ShapeUtil::IsScalar(shape_)) { // No loop needed, so set exit_bb_ to nullptr. 
@@ -105,7 +105,7 @@ std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock( std::unique_ptr<ForLoop> loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + /*suffix=*/absl::StrFormat("dim.%d", dimension)); array_index[dimension] = loop->GetIndVarValue(); } @@ -122,7 +122,7 @@ std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock( return {array_index}; } -Status LoopEmitter::EmitLoop(tensorflow::StringPiece loop_name, +Status LoopEmitter::EmitLoop(absl::string_view loop_name, llvm::Type* index_type) { if (index_type == nullptr) { index_type = b_->getInt64Ty(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index c4f5c82086..57d9d8bbc6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -69,10 +69,10 @@ class LoopEmitter { } virtual std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock( - tensorflow::StringPiece loop_name, llvm::Type* index_type); + absl::string_view loop_name, llvm::Type* index_type); // Emits a complete loop nest for every element in the given shape. - Status EmitLoop(tensorflow::StringPiece loop_name = "", + Status EmitLoop(absl::string_view loop_name = "", llvm::Type* index_type = nullptr); protected: diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc index e546f5cc4a..00dd3f1638 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.cc @@ -16,6 +16,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" @@ -29,8 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -42,7 +42,7 @@ namespace { void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, const IrArray::Index& compare_keys_index, const IrArray& keys_array, - const tensorflow::gtl::optional<IrArray>& values_array, + const absl::optional<IrArray>& values_array, llvm::IRBuilder<>* b) { // if (is_smaller_index && // compare_keys[dimension_to_sort] < dimension_to_sort_bound) @@ -87,8 +87,8 @@ void EmitCompareLoop(int64 dimension_to_sort, const IrArray::Index& keys_index, } // namespace Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const tensorflow::gtl::optional<IrArray>& values_array, - tensorflow::StringPiece name, llvm::Value* xor_mask, + const absl::optional<IrArray>& values_array, + absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions) { const Shape& keys_shape = keys_array.GetShape(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h index 8458744c6b..527ed10374 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/sort_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/sort_util.h @@ -16,12 +16,12 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_ +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -31,8 +31,8 @@ namespace llvm_ir { // implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr, // the inner compare loop will not be parallelized. Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array, - const tensorflow::gtl::optional<IrArray>& values_array, - tensorflow::StringPiece name, llvm::Value* xor_mask, + const absl::optional<IrArray>& values_array, + absl::string_view name, llvm::Value* xor_mask, llvm::IRBuilder<>* b, const gpu::LaunchDimensions* launch_dimensions); } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 5e02096ee5..768105d9e1 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -19,10 +19,12 @@ limitations under the License. 
#include <utility> #include <vector> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/execution_options_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/executable.h" @@ -37,7 +39,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -73,7 +74,7 @@ namespace { // If the parameter number is invalid for this computation, nullopt is // returned. When the return value has_value(), nullptr will never be // the held value. -tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata( +absl::optional<const OpMetadata*> ParameterMetadata( const XlaComputation& computation, int parameter_number) { for (const HloComputationProto& comp : computation.proto().computations()) { if (comp.id() == computation.proto().entry_computation_id()) { @@ -81,14 +82,14 @@ tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata( if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) && instr.parameter_number() == parameter_number) { if (!instr.has_metadata()) { - return tensorflow::gtl::nullopt; + return absl::nullopt; } return &instr.metadata(); } } } } - return tensorflow::gtl::nullopt; + return absl::nullopt; } ExecutionOptions CreateExecutionOptions( @@ -149,7 +150,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable( // Validate incoming layouts. 
if (argument_layouts.size() != program_shape.parameters_size()) { return InvalidArgument( - "Invalid number of arguments for computation: expected %d, got %zu.", + "Invalid number of arguments for computation: expected %d, got %u.", program_shape.parameters_size(), argument_layouts.size()); } @@ -158,7 +159,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable( TF_RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape)); if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) { - tensorflow::gtl::optional<const OpMetadata*> metadata = + absl::optional<const OpMetadata*> metadata = ParameterMetadata(computation, /*parameter_number=*/i); auto metadata_string = [&metadata]() -> string { if (!metadata.has_value()) { @@ -167,16 +168,15 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable( CHECK(metadata.value() != nullptr); const OpMetadata& m = *metadata.value(); if (!m.source_file().empty()) { - return tensorflow::strings::Printf( - " (%s:%d)", m.source_file().c_str(), m.source_line()); + return absl::StrFormat(" (%s:%d)", m.source_file(), m.source_line()); } return ""; }; return InvalidArgument( "Invalid argument shape for argument %d%s, expected %s, got %s.", i, - metadata_string().c_str(), - ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(argument_shape).c_str()); + metadata_string(), + ShapeUtil::HumanString(program_shape.parameters(i)), + ShapeUtil::HumanString(argument_shape)); } } if (build_options.result_layout() != nullptr) { @@ -214,7 +214,7 @@ StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer( TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); if (replica_number >= buffers.size()) { return InvalidArgument( - "replica_number %d out of range; must be less than num_replicas = %zu.", + "replica_number %d out of range; must be less than num_replicas = %u.", replica_number, buffers.size()); } return 
buffers[replica_number]; diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc index c742d35a7b..e1f56727bd 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.cc +++ b/tensorflow/compiler/xla/service/logical_buffer.cc @@ -15,11 +15,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -34,11 +34,10 @@ LogicalBuffer::~LogicalBuffer() {} string LogicalBuffer::ToString() const { string color_string; if (has_color()) { - color_string = tensorflow::strings::StrCat(" @", color().value()); + color_string = absl::StrCat(" @", color().value()); } - return tensorflow::strings::StrCat(instruction_->name(), "[", - tensorflow::str_util::Join(index_, ","), - "](#", id(), color_string, ")"); + return absl::StrCat(instruction_->name(), "[", absl::StrJoin(index_, ","), + "](#", id(), color_string, ")"); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index d631fb5ee4..eaa09591b7 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include <utility> +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" @@ -89,7 +90,7 @@ void LogicalBufferAnalysis::NewLogicalBuffer(HloInstruction* instruction, const ShapeIndex& index) { CHECK_EQ(logical_buffers_.size(), next_buffer_id_); logical_buffers_.emplace_back( - MakeUnique<LogicalBuffer>(instruction, index, next_buffer_id_)); + absl::make_unique<LogicalBuffer>(instruction, index, next_buffer_id_)); output_buffers_[std::make_pair(instruction, index)] = logical_buffers_.back().get(); diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index 0019cd7254..4c8cb7d379 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -19,10 +19,10 @@ limitations under the License. #include <queue> #include <vector> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -48,9 +48,7 @@ class MultiOutputFusion : public HloPassInterface { public: MultiOutputFusion(int64 fuel) : fuel_(fuel) {} - tensorflow::StringPiece name() const override { - return "multi_output_fusion"; - } + absl::string_view name() const override { return "multi_output_fusion"; } // Run multi-output fusion on the given module. Returns whether the module // was changed. @@ -104,17 +102,17 @@ class MultiOutputFusion : public HloPassInterface { // InstructionFusion instead. virtual bool DoProducerConsumerMultiOutputFusion(); - private: - // Update the internal data structures after instr1 and instr2 are fused into - // one fusion instruction. 
- void Update(HloInstruction* instr1, HloInstruction* instr2); - // Optimization fuel is a compiler debugging technique that makes an // optimization pass stop what it is doing after having made N changes to the // program, where N is the fuel. By varying N, this can be used to find the // first single change that makes a test fail. int64 fuel_; + private: + // Update the internal data structures after instr1 and instr2 are fused into + // one fusion instruction. + void Update(HloInstruction* instr1, HloInstruction* instr2); + // Computation for the pass. HloComputation* computation_; diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index f6e7578a89..bd8fb17a23 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -15,8 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/name_uniquer.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -52,8 +53,8 @@ NameUniquer::NameUniquer(const string& separator) { return result; } -string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { - string root = GetSanitizedName(prefix.empty() ? "name" : std::string(prefix)); +string NameUniquer::GetUniqueName(absl::string_view prefix) { + string root = GetSanitizedName(prefix.empty() ? "name" : string(prefix)); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. 
@@ -63,20 +64,22 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { if (separator_index != string::npos && (separator_index > 0) && (separator_index < root.size() - 1)) { string after_suffix = root.substr(separator_index + 1); - if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) { has_numeric_suffix = true; // Remove numeric suffix from root. root = root.substr(0, separator_index); + } else { + // absl::SimpleAtoi may modify numeric_suffix even if it returns false. + numeric_suffix = 0; } } SequentialIdGenerator& id_generator = generated_names_[root]; numeric_suffix = id_generator.RegisterId(numeric_suffix); if (numeric_suffix == 0) { - return has_numeric_suffix ? tensorflow::strings::StrCat(root, separator_, 0) - : root; + return has_numeric_suffix ? absl::StrCat(root, separator_, 0) : root; } - tensorflow::strings::StrAppend(&root, separator_, numeric_suffix); + absl::StrAppend(&root, separator_, numeric_suffix); return root; } diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 4423d61069..6dd89c240f 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -18,8 +18,8 @@ limitations under the License. #include <string> +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" @@ -38,7 +38,7 @@ class NameUniquer { // Get a sanitized unique name in a string, with an optional prefix for // convenience. - string GetUniqueName(tensorflow::StringPiece prefix = ""); + string GetUniqueName(absl::string_view prefix = ""); // Sanitizes and returns the name. Unallowed characters will be replaced with // '_'. 
The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index ac6ea4c72f..4869db79e7 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace xla { @@ -622,7 +622,7 @@ template <typename Previous> class HloInstructionPatternNameImpl { public: explicit HloInstructionPatternNameImpl(const Previous& previous, - tensorflow::StringPiece name) + absl::string_view name) : previous_(previous), name_(name) {} bool Match(const ::xla::HloInstruction* inst) const { @@ -631,7 +631,7 @@ class HloInstructionPatternNameImpl { private: Previous previous_; - tensorflow::StringPiece name_; + absl::string_view name_; }; // An HloInstructionPattern implementation that matches only if the instruction @@ -784,7 +784,7 @@ class HloInstructionPattern { // Modifies the pattern to match only if the instruction has the given name. 
HloInstructionPattern<HloInstructionType, HloInstructionPatternNameImpl<Impl>> - WithName(tensorflow::StringPiece name) const { + WithName(absl::string_view name) const { return HloInstructionPattern<HloInstructionType, HloInstructionPatternNameImpl<Impl>>( HloInstructionPatternNameImpl<Impl>(impl_, name), matched_inst_); @@ -918,6 +918,7 @@ Op(::xla::HloInstruction** matched_inst) { } XLA_NULLOP_PATTERN(Constant) XLA_NULLOP_PATTERN(Parameter) +XLA_NULLOP_PATTERN(Iota) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 39fe3c7835..ae1e13d8a6 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -19,20 +19,19 @@ limitations under the License. #include <string> #include <utility> +#include "absl/strings/ascii.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { -using tensorflow::str_util::Lowercase; - // Minimum supported CUDA compute capability is 3.5. constexpr int kMinCudaComputeCapabilityMajor = 3; constexpr int kMinCudaComputeCapabilityMinor = 5; @@ -43,7 +42,7 @@ constexpr char kInterpreter[] = "interpreter"; namespace { string CanonicalPlatformName(const string& name) { - string platform_str = Lowercase(name); + string platform_str = absl::AsciiStrToLower(name); // "cpu" and "host" mean the same thing. 
if (platform_str == "cpu") { platform_str = "host"; @@ -94,12 +93,12 @@ PlatformUtil::GetSupportedPlatforms() { } // Multiple platforms present and we can't pick a reasonable default. - string platforms_string = tensorflow::str_util::Join( + string platforms_string = absl::StrJoin( platforms, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform found: %s", - platforms_string.c_str()); + platforms_string); } /* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() { @@ -110,21 +109,21 @@ PlatformUtil::GetSupportedPlatforms() { return platforms[0]; } else if (platforms.size() == 2) { for (int i = 0; i < 2; i++) { - if (Lowercase(platforms[i]->Name()) == kInterpreter && - Lowercase(platforms[1 - i]->Name()) != kInterpreter) { + if (absl::AsciiStrToLower(platforms[i]->Name()) == kInterpreter && + absl::AsciiStrToLower(platforms[1 - i]->Name()) != kInterpreter) { return platforms[1 - i]; } } } // Multiple platforms present and we can't pick a reasonable default. 
- string platforms_string = tensorflow::str_util::Join( + string platforms_string = absl::StrJoin( platforms, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform (except for the " "interpreter platform) found: %s", - platforms_string.c_str()); + platforms_string); } /*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatform( @@ -132,11 +131,11 @@ PlatformUtil::GetSupportedPlatforms() { string platform_str = CanonicalPlatformName(platform_name); TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); for (se::Platform* platform : platforms) { - if (Lowercase(platform->Name()) == platform_str) { + if (absl::AsciiStrToLower(platform->Name()) == platform_str) { return platform; } } - return InvalidArgument("platform %s not found", platform_name.c_str()); + return InvalidArgument("platform %s not found", platform_name); } /*static*/ StatusOr<se::Platform*> PlatformUtil::GetPlatformExceptFor( @@ -146,23 +145,23 @@ PlatformUtil::GetSupportedPlatforms() { TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); std::vector<se::Platform*> matched; for (se::Platform* platform : platforms) { - if (Lowercase(platform->Name()) != platform_name) { + if (absl::AsciiStrToLower(platform->Name()) != platform_name) { matched.push_back(platform); } } if (matched.empty()) { return InvalidArgument("unable to find platform that is not %s", - platform_name.c_str()); + platform_name); } if (matched.size() == 1) { return matched[0]; } - string matched_string = tensorflow::str_util::Join( + string matched_string = absl::StrJoin( matched, ", ", [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "found multiple platforms %s, but expected one platform except for %s", - matched_string.c_str(), platform_name.c_str()); + matched_string, platform_name); } // Returns whether the device underlying the given 
StreamExecutor is supported @@ -193,7 +192,7 @@ static bool IsDeviceSupported(se::StreamExecutor* executor) { PlatformUtil::GetStreamExecutors(se::Platform* platform) { int device_count = platform->VisibleDeviceCount(); if (device_count <= 0) { - return NotFound("no %s devices found", platform->Name().c_str()); + return NotFound("no %s devices found", platform->Name()); } if (platform->id() == se::host::kHostPlatformId) { // On host "devices", StreamExecutor exports a device for each hardware @@ -232,7 +231,7 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { if (std::all_of(stream_executors.begin(), stream_executors.end(), [](se::StreamExecutor* s) { return s == nullptr; })) { return InternalError("no supported devices found for platform %s", - platform->Name().c_str()); + platform->Name()); } return stream_executors; } diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index afde3cf95c..256b231e3a 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -59,7 +59,7 @@ class ReducePrecisionInsertion : public HloPassInterface { ~ReducePrecisionInsertion() override{}; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "reduce-precision-insertion"; } diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index ca86c5d13e..4df746fca9 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -38,6 +38,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/reshape_mover.h" #include <algorithm> + +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -374,7 +376,7 @@ StatusOr<bool> TryReshapeMoveOnCandidates( removed = false; for (auto operand : nontrivial_operands) { - if (c_any_of(operand->users(), [&](HloInstruction* user) { + if (absl::c_any_of(operand->users(), [&](HloInstruction* user) { return !reshape_candidates->count(user); })) { for (auto* user : operand->users()) { diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h index 1f59e3b314..1e86a0823a 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.h +++ b/tensorflow/compiler/xla/service/reshape_mover.h @@ -26,7 +26,7 @@ namespace xla { // them inputward also. class ReshapeMover : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "reshape-mover"; } + absl::string_view name() const override { return "reshape-mover"; } StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index ccb9fb3e3a..a395dd5333 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" @@ -28,13 +28,18 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/str_util.h" - -namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using ReshapeMoverTest = HloVerifiedTestBase; + +namespace op = xla::testing::opcode_matchers; + +class ReshapeMoverTest : public HloVerifiedTestBase { + public: + ReshapeMoverTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} +}; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { HloComputation::Builder builder(TestName()); diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index 45ca731153..2077b57c05 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/scatter_expander.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" @@ -92,7 +93,7 @@ static StatusOr<HloInstruction*> PermuteScatterAndWindowDims( permutation.reserve(updates_rank); for (int64 i = 0; i < updates_rank; ++i) { - bool is_scatter_dim = !c_binary_search(update_window_dims, i); + bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i); if (is_scatter_dim) { permutation.push_back(i); } @@ -290,7 +291,7 @@ StatusOr<HloInstruction*> ScatterExpander::ExpandScatter( return Unimplemented( "Scatter operations with more than 2147483647 scatter indices are not " "supported. 
This error occurred for %s.", - scatter->ToString().c_str()); + scatter->ToString()); } // Canonicalize the scatter_indices, after which the size of its most-major diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h index 8f735e877d..14f062c89c 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.h +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -22,7 +22,7 @@ namespace xla { class ScatterExpander : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "scatter_expander"; } + absl::string_view name() const override { return "scatter_expander"; } StatusOr<bool> Run(HloModule* module) override; private: diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 1dbf540d13..e10c1d9927 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -20,10 +20,12 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -46,8 +48,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -55,13 +55,12 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/ptr_util.h" -using ::tensorflow::strings::Printf; -using ::tensorflow::strings::StrCat; - namespace xla { - namespace { +using absl::StrCat; +using absl::StrFormat; + // Records the arguments used to invoke a computation in an HloSnapshot proto. Status RecordArguments( const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, @@ -148,19 +147,19 @@ Service::Service(const ServiceOptions& options, CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas()) << "Requested more replicas than there are devices."; } - LOG(INFO) << Printf( + LOG(INFO) << StrFormat( "XLA service %p executing computations on platform %s. 
Devices:", this, - execute_backend_->platform()->Name().c_str()); + execute_backend_->platform()->Name()); for (int i = 0; i < execute_backend_->device_count(); ++i) { if (execute_backend_->device_ordinal_supported(i)) { se::StreamExecutor* executor = execute_backend_->stream_executor(i).ValueOrDie(); const auto& description = executor->GetDeviceDescription(); - LOG(INFO) << Printf(" StreamExecutor device (%d): %s, %s", i, - description.name().c_str(), - description.platform_version().c_str()); + LOG(INFO) << StrFormat(" StreamExecutor device (%d): %s, %s", i, + description.name(), + description.platform_version()); } else { - LOG(INFO) << Printf(" StreamExecutor device (%d) not supported", i); + LOG(INFO) << StrFormat(" StreamExecutor device (%d) not supported", i); } } } else { @@ -200,8 +199,8 @@ Status Service::ValidateResultShape(const Shape& client_shape, return InvalidArgument( "Shape used to set computation result layout %s is not compatible " "with result shape %s", - ShapeUtil::HumanStringWithLayout(client_shape).c_str(), - ShapeUtil::HumanString(result_shape).c_str()); + ShapeUtil::HumanStringWithLayout(client_shape), + ShapeUtil::HumanString(result_shape)); } return Status::OK(); } @@ -231,9 +230,9 @@ Service::ResolveAndValidateArguments( return InvalidArgument( "argument %lu is on device %s:%d but computation will be executed " "on device %s", - i, shaped_buffer->platform()->Name().c_str(), + i, shaped_buffer->platform()->Name(), shaped_buffer->device_ordinal(), - execute_backend_->device_name(replica_device_ordinal).c_str()); + execute_backend_->device_name(replica_device_ordinal)); } replicated_arguments[replica].push_back(shaped_buffer); } @@ -245,11 +244,11 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice<const Shape*> argument_shapes, const ExecutionOptions* execution_options) { - auto config = MakeUnique<HloModuleConfig>(program_shape); + auto config = 
absl::make_unique<HloModuleConfig>(program_shape); ComputationLayout* computation_layout = config->mutable_entry_computation_layout(); if (program_shape.parameters_size() != argument_shapes.size()) { - return InvalidArgument("computation takes %d parameters, but %zu given", + return InvalidArgument("computation takes %d parameters, but %u given", program_shape.parameters_size(), argument_shapes.size()); } @@ -261,8 +260,8 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig( return InvalidArgument( "Argument does not match shape of computation parameter %d: want " "%s, got %s", - i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*argument_shapes[i]).c_str()); + i, ShapeUtil::HumanString(program_shape.parameters(i)), + ShapeUtil::HumanString(*argument_shapes[i])); } TF_RETURN_IF_ERROR( computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( @@ -314,7 +313,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables( std::vector<std::unique_ptr<HloModuleConfig>> module_configs, Backend* backend, std::vector<std::vector<se::StreamExecutor*>> executors, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p", this); + VLOG(1) << StrFormat("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. 
std::vector<std::unique_ptr<HloSnapshot>> hlo_snapshots; @@ -326,12 +325,11 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables( if (directory_path.empty() && execution_directory_path.empty()) { continue; } - auto hlo_snapshot = MakeUnique<HloSnapshot>(); + auto hlo_snapshot = absl::make_unique<HloSnapshot>(); *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = *module_protos[i]; if (!directory_path.empty()) { - string filename = - Printf("computation_%lld__%s", module_protos[i]->id(), - module_protos[i]->entry_computation_name().c_str()); + string filename = StrFormat("computation_%d__%s", module_protos[i]->id(), + module_protos[i]->entry_computation_name()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); } @@ -409,7 +407,8 @@ Service::ExecuteParallelAndRegisterResult( streams.push_back(std::move(stream)); if (replica == 0 && profile != nullptr) { - timers.push_back(MakeUnique<se::Timer>(streams.back()->parent())); + timers.push_back( + absl::make_unique<se::Timer>(streams.back()->parent())); streams.back() ->InitTimer(timers.back().get()) .ThenStartTimer(timers.back().get()); @@ -453,8 +452,8 @@ Service::ExecuteParallelAndRegisterResult( for (int64 i = 0; i < streams.size(); ++i) { Status block_status = streams[i]->BlockHostUntilDone(); if (!block_status.ok()) { - return InternalError("failed to complete execution for stream %lld: %s", - i, block_status.error_message().c_str()); + return InternalError("failed to complete execution for stream %d: %s", i, + block_status.error_message()); } } @@ -579,7 +578,7 @@ StatusOr<std::vector<se::StreamExecutor*>> Service::GetExecutors( if (requests_size > 1 && execution_options.device_handles_size() > 1) { return InvalidArgument( "Parallel requests with multiple device handles is not supported. 
" - "Found %lld parallel requests, with request %lld containing %d device " + "Found %d parallel requests, with request %d containing %d device " "handles.", requests_size, request_index, execution_options.device_handles_size()); } @@ -744,8 +743,8 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, } if (available_device_count < arg->device_count() * replica_count) { return ResourceExhausted( - "Requested device count (%lld) exceeds the number of available devices " - "on the target (%lld)", + "Requested device count (%d) exceeds the number of available devices " + "on the target (%d)", arg->device_count(), available_device_count); } @@ -795,12 +794,12 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr<HloModuleConfig> module_config, Backend* backend, se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf( + VLOG(1) << StrFormat( "BuildExecutable on service %p with serialized module proto: %s", this, - module_proto.name().c_str()); + module_proto.name()); // Dump computation proto state if flag is set. 
- auto hlo_snapshot = MakeUnique<HloSnapshot>(); + auto hlo_snapshot = absl::make_unique<HloSnapshot>(); const string& directory_path = module_config->debug_options().xla_dump_computations_to(); const string& execution_directory_path = @@ -808,8 +807,8 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable( if (!directory_path.empty() || !execution_directory_path.empty()) { *hlo_snapshot->mutable_hlo()->mutable_hlo_module() = module_proto; if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s", module_proto.id(), - module_proto.entry_computation_name().c_str()); + string filename = StrFormat("computation_%d__%s", module_proto.id(), + module_proto.entry_computation_name()); TF_RETURN_IF_ERROR( Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot)); } @@ -954,7 +953,7 @@ namespace { // shape and DeviceMemoryBase values of the clone are identical to the original. std::unique_ptr<ShapedBuffer> CloneShapedBufferOnDevice( const ShapedBuffer& shaped_buffer, int device_ordinal) { - auto clone = MakeUnique<ShapedBuffer>( + auto clone = absl::make_unique<ShapedBuffer>( shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(), shaped_buffer.platform(), device_ordinal); clone->buffers() = shaped_buffer.buffers(); @@ -1009,8 +1008,7 @@ Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, "%s", StrCat("The replica_id=", arg->replica_id(), " on TransferToInfeedRequest not in range [0, replica_count=", - replica_count, ").") - .c_str()); + replica_count, ").")); } se::StreamExecutor* executor; @@ -1036,8 +1034,7 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg, const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( - "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, " - "%lld)", + "The replica_id=%d on TransferFromOutfeedRequest not in range [0, %d)", 
arg->replica_id(), replica_count); } diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index cc1ec1704e..f5217c5a11 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -21,6 +21,11 @@ limitations under the License. #include <set> #include <string> +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -28,32 +33,26 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" -using tensorflow::str_util::Join; -using tensorflow::strings::Printf; - namespace xla { - namespace { +using absl::StrFormat; +using absl::StrJoin; + // Returns true if no element is present in slice more than once. 
bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) { return std::set<int64>(slice.begin(), slice.end()).size() == slice.size(); } -Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) { +Status ExpectArray(const Shape& shape, absl::string_view op_type) { if (!ShapeUtil::IsArray(shape)) { return InvalidArgument("Expected array argument for %s, but got %s.", - std::string(op_type).c_str(), - ShapeUtil::HumanString(shape).c_str()); + string(op_type), ShapeUtil::HumanString(shape)); } return Status::OK(); } @@ -65,7 +64,7 @@ Status VerifyReducerShape( int64 inputs) { if (reducer_shape.parameters_size() != inputs * 2) { return InvalidArgument( - "Reduction function must take %lld parameters, but " + "Reduction function must take %d parameters, but " "takes %d parameter(s).", inputs * 2, reducer_shape.parameters_size()); } @@ -75,7 +74,7 @@ Status VerifyReducerShape( if (ShapeUtil::IsArray(accumulator_shape)) { if (inputs != 1) { return InvalidArgument( - "Reduction function must produce a tuple with %lld elements, but " + "Reduction function must produce a tuple with %d elements, but " "produces a scalar", inputs); } @@ -83,8 +82,8 @@ Status VerifyReducerShape( } else if (ShapeUtil::IsTuple(accumulator_shape)) { if (ShapeUtil::TupleElementCount(accumulator_shape) != inputs) { return InvalidArgument( - "Reduction function must produce a tuple with %lld elements, but has " - "%lld elements", + "Reduction function must produce a tuple with %d elements, but has " + "%d elements", inputs, ShapeUtil::TupleElementCount(accumulator_shape)); } for (const Shape& element_shape : accumulator_shape.tuple_shapes()) { @@ -94,7 +93,7 @@ Status VerifyReducerShape( return InvalidArgument( "Reduction function must produce a scalar or tuple of scalars, but has " "shape: %s", - ShapeUtil::HumanString(accumulator_shape).c_str()); + ShapeUtil::HumanString(accumulator_shape)); } for (const Shape* element_shape : accumulator_subshapes) { @@ -102,7 +101,7 @@ Status 
VerifyReducerShape( return InvalidArgument( "Reduction function must return a scalar or tuple of scalars but " "returns shape: %s", - ShapeUtil::HumanString(accumulator_shape).c_str()); + ShapeUtil::HumanString(accumulator_shape)); } } @@ -113,19 +112,19 @@ Status VerifyReducerShape( if (!ShapeUtil::Compatible(*accumulator_subshapes[i], reducer_shape.parameters(i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape differs from the " + "Reduction function's %d-th parameter shape differs from the " "result shape: %s vs %s", - i, ShapeUtil::HumanString(reducer_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + i, ShapeUtil::HumanString(reducer_shape.parameters(i)), + ShapeUtil::HumanString(*accumulator_subshapes[i])); } // Check that init_value's shapes are suitable for reducer_shape. if (!ShapeUtil::CompatibleIgnoringFpPrecision(*accumulator_subshapes[i], *init_value_shapes[i])) { return InvalidArgument( - "Reduction function's accumulator shape at index %lld differs from " + "Reduction function's accumulator shape at index %d differs from " "the init_value shape: %s vs %s", - i, ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str(), - ShapeUtil::HumanString(*init_value_shapes[i]).c_str()); + i, ShapeUtil::HumanString(*accumulator_subshapes[i]), + ShapeUtil::HumanString(*init_value_shapes[i])); } // Check that the inputs can be passed in as the non-accumulator arguments. 
const Shape input_element_shape = @@ -133,11 +132,11 @@ Status VerifyReducerShape( if (!ShapeUtil::CompatibleIgnoringFpPrecision( input_element_shape, reducer_shape.parameters(inputs + i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape differs from the " + "Reduction function's %d-th parameter shape differs from the " "input type element type: %s vs %s", inputs + i, - ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), - ShapeUtil::HumanString(input_element_shape).c_str()); + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)), + ShapeUtil::HumanString(input_element_shape)); } // Check that the accumulator and inputs to the reducer function match. // If the accumulator is scalar, it must have the same type as the inputs @@ -147,11 +146,11 @@ Status VerifyReducerShape( if (!ShapeUtil::CompatibleIgnoringFpPrecision( *accumulator_subshapes[i], reducer_shape.parameters(inputs + i))) { return InvalidArgument( - "Reduction function's %lld-th parameter shape must " + "Reduction function's %d-th parameter shape must " "match the result shape, but got %s vs %s.", inputs + i, - ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)).c_str(), - ShapeUtil::HumanString(*accumulator_subshapes[i]).c_str()); + ShapeUtil::HumanString(reducer_shape.parameters(inputs + i)), + ShapeUtil::HumanString(*accumulator_subshapes[i])); } } @@ -164,7 +163,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, bool allow_negative_padding) { if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) { return InvalidArgument( - "Window has dimension %d but base shape has dimension %lld.", + "Window has dimension %d but base shape has dimension %d.", window.dimensions_size(), ShapeUtil::Rank(base_shape)); } @@ -173,29 +172,29 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, const auto& dim = window.dimensions(i); if (dim.size() <= 0) { return InvalidArgument("Window %s has a non-positive dimension.", - 
window.DebugString().c_str()); + window.DebugString()); } if (dim.stride() <= 0) { return InvalidArgument("Window %s has a non-positive stride.", - window.DebugString().c_str()); + window.DebugString()); } if (!allow_negative_padding && dim.padding_low() < 0) { return InvalidArgument("Window %s has a negative low padding.", - window.DebugString().c_str()); + window.DebugString()); } if (!allow_negative_padding && dim.padding_high() < 0) { return InvalidArgument("Window %s has a negative high padding.", - window.DebugString().c_str()); + window.DebugString()); } if (dim.base_dilation() < 1) { return InvalidArgument( "Window %s has a non-positive base area dilation factor.", - window.DebugString().c_str()); + window.DebugString()); } if (dim.window_dilation() < 1) { return InvalidArgument( "Window %s has a non-positive window dilation factor.", - window.DebugString().c_str()); + window.DebugString()); } const int64 dilated_base = window_util::DilatedBound( @@ -233,11 +232,12 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, switch (opcode) { case HloOpcode::kFloor: case HloOpcode::kCeil: + case HloOpcode::kRoundNearestAfz: if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( - "Expected element type in shape to be floating for floor/ceil " - "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + "Expected element type in shape to be floating for %s operation; " + "got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } return shape; case HloOpcode::kCos: @@ -250,9 +250,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, if (!ShapeUtil::ElementIsFloating(shape) && !ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( - "Expected element type in shape to be floating or complex for " - "sin/cos/exp/log/tanh operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + "Expected element type in shape to be floating or complex for %s " + "operation; got 
%s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } return shape; case HloOpcode::kReal: @@ -264,19 +264,47 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, } else { return InvalidArgument( "Expected element type in shape to be floating or complex for " - "real/imag operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + "%s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } case HloOpcode::kAbs: if (ShapeUtil::ElementIsComplex(shape)) { return ShapeUtil::ChangeElementType( shape, primitive_util::ComplexComponentType(shape.element_type())); + } else if (ShapeUtil::ElementIsSigned(shape)) { + return shape; + } else { + return InvalidArgument( + "Expected element type in shape to be floating or complex for " + "%s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); } - return shape; case HloOpcode::kClz: + if (!ShapeUtil::ElementIsIntegral(shape)) { + return InvalidArgument( + "Expected an integral element type in argument to Clz " + "operation; got %s.", + PrimitiveType_Name(shape.element_type())); + } + return shape; case HloOpcode::kNegate: - case HloOpcode::kRoundNearestAfz: + if (!ShapeUtil::ElementIsIntegral(shape) && + !ShapeUtil::ElementIsFloating(shape) && + !ShapeUtil::ElementIsComplex(shape)) { + return InvalidArgument( + "Expected element type in shape to be integral, floating or " + "complex for %s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); + } + return shape; case HloOpcode::kSign: + if (!ShapeUtil::ElementIsSigned(shape) && + !ShapeUtil::ElementIsComplex(shape)) { + return InvalidArgument( + "Expected element type in shape to be signed or complex for " + "%s operation; got %s.", + HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); + } return shape; case HloOpcode::kNot: @@ -285,7 +313,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& 
base_shape, return InvalidArgument( "Expected pred or an integral element type in argument to Not " "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return shape; @@ -295,14 +323,14 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, "Expected element type in shape to be floating " "point for IsFinite " "operation; got %s.", - PrimitiveType_Name(shape.element_type()).c_str()); + PrimitiveType_Name(shape.element_type())); } return ShapeUtil::ChangeElementType(shape, PRED); default: return InvalidArgument( "Unknown operation for unary shape inference: \"%s\".", - HloOpcodeString(opcode).c_str()); + HloOpcodeString(opcode)); } } @@ -313,7 +341,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, return InvalidArgument("Concatenate expects at least one argument."); } if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) { - return InvalidArgument("Concatenate dimension out of bounds: %lld.", + return InvalidArgument("Concatenate dimension out of bounds: %d.", dimension); } const Shape* arg_shape = nullptr; @@ -327,17 +355,16 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, } if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { return InvalidArgument( - "Cannot concatenate arrays with different ranks: %lld (%s) vs %lld " + "Cannot concatenate arrays with different ranks: %d (%s) vs %d " "(%s).", - ShapeUtil::Rank(*arg_shape), - ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), - ShapeUtil::HumanString(*shape).c_str()); + ShapeUtil::Rank(*arg_shape), ShapeUtil::HumanString(*arg_shape), + ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( "Cannot concatenate arrays with different element types: %s vs %s.", - PrimitiveType_Name(arg_shape->element_type()).c_str(), - 
PrimitiveType_Name(shape->element_type()).c_str()); + PrimitiveType_Name(arg_shape->element_type()), + PrimitiveType_Name(shape->element_type())); } for (int64 dimension_number = 0; dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) { @@ -350,9 +377,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Cannot concatenate arrays that differ in dimensions other than " "the one being concatenated (the other array dimensions must be " - "the same): %s vs %s in dimension %lld.", - ShapeUtil::HumanString(*arg_shape).c_str(), - ShapeUtil::HumanString(*shape).c_str(), dimension); + "the same): %s vs %s in dimension %d.", + ShapeUtil::HumanString(*arg_shape), ShapeUtil::HumanString(*shape), + dimension); } } element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape); @@ -384,8 +411,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, !primitive_util::IsComplexType(new_element_type)) { return Unimplemented( "Conversion from complex to real type %s => %s is not implemented.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (!ShapeUtil::IsArray(operand_shape) || !primitive_util::IsArrayType(new_element_type)) { @@ -394,8 +421,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, // are valid. For now we just reject them, though. 
return InvalidArgument( "Convert does not allow non-arrays, so cannot convert from %s to %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } return ShapeUtil::ChangeElementType(operand_shape, new_element_type); @@ -407,8 +434,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, if (primitive_util::IsComplexType(old_element_type) != primitive_util::IsComplexType(new_element_type)) { return InvalidArgument("Conversion from complex to real type %s => %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (!ShapeUtil::IsArray(operand_shape) || !primitive_util::IsArrayType(new_element_type)) { @@ -417,15 +444,15 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, // are valid. For now we just reject them, though. 
return InvalidArgument( "Cannot convert from or to tuple type; requested conversion: %s => %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + ShapeUtil::HumanString(operand_shape), + PrimitiveType_Name(new_element_type)); } if (primitive_util::BitWidth(old_element_type) != primitive_util::BitWidth(new_element_type)) { return InvalidArgument( "Cannot bitcast types with different bit-widths: %s => %s.", - PrimitiveType_Name(old_element_type).c_str(), - PrimitiveType_Name(new_element_type).c_str()); + PrimitiveType_Name(old_element_type), + PrimitiveType_Name(new_element_type)); } return ShapeUtil::ChangeElementType(operand_shape, new_element_type); @@ -438,7 +465,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "Expected element type in shape to be floating point for " "ReducePrecision operation; got %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (exponent_bits < 1) { // One exponent bit is necessary to distinguish 0 from infinity. 
Having @@ -470,21 +497,29 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, return InvalidArgument( "The rank of the operand and the padding configuration do not match: " "%s vs %s.", - ShapeUtil::HumanString(operand_shape).c_str(), - padding_config.ShortDebugString().c_str()); + ShapeUtil::HumanString(operand_shape), + padding_config.ShortDebugString()); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, padding_value_shape)) { return InvalidArgument( "The element types of the operands to Pad do not match."); } + if (absl::c_any_of(padding_config.dimensions(), + [](const PaddingConfig::PaddingConfigDimension& p) { + return p.interior_padding() < 0; + })) { + return InvalidArgument("Interior padding cannot be negative: %s", + padding_config.ShortDebugString()); + } + std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape)); for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) { - dimensions[i] = operand_shape.dimensions(i) + - padding_config.dimensions(i).edge_padding_low() + - padding_config.dimensions(i).edge_padding_high() + + const auto& p = padding_config.dimensions(i); + dimensions[i] = operand_shape.dimensions(i) + p.edge_padding_low() + + p.edge_padding_high() + std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) * - padding_config.dimensions(i).interior_padding(); + p.interior_padding(); } return ShapeUtil::MakeShape( ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), @@ -538,7 +573,7 @@ Status ValidateDotDimensionNumbers( !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is out of range in Dot: %s.", - dimension_numbers.DebugString().c_str()); + dimension_numbers.DebugString()); } // Check that dimension numbers are unique. 
@@ -556,7 +591,7 @@ Status ValidateDotDimensionNumbers( if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { return InvalidArgument("A dimension number is not unique in Dot: %s.", - dimension_numbers.DebugString().c_str()); + dimension_numbers.DebugString()); } // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. @@ -601,14 +636,13 @@ Status ValidateDotDimensionNumbers( TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot")); auto fail = [lhs, rhs](const string& addendum) -> Status { - string message = tensorflow::strings::Printf( - "Cannot infer shape for dot operation: %s <dot> %s.", - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + string message = + StrFormat("Cannot infer shape for dot operation: %s <dot> %s.", + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs)); if (!addendum.empty()) { message += " " + addendum; } - return InvalidArgument("%s", message.c_str()); + return InvalidArgument("%s", message); }; // Check if both element types are the same. @@ -704,9 +738,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", - HloOpcodeString(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + HloOpcodeString(operation), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs)); } } return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), @@ -721,14 +754,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // the user to provide an explicit broadcast dimension in this case. // See b/25177275 for more details. 
return InvalidArgument("Automatic shape inference not supported: %s and %s", - ShapeUtil::HumanString(smaller_shape).c_str(), - ShapeUtil::HumanString(larger_shape).c_str()); + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) { return InvalidArgument( "Size of broadcast_dimensions has to match lower-rank operand's " "rank; " - " lower-rank operand's rank is %lld, size of broadcast_dimensions is " - "%zu.", + " lower-rank operand's rank is %d, size of broadcast_dimensions is " + "%u.", ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size()); } @@ -778,12 +811,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, int64 dimension_to_match = broadcast_dimensions.at(i); if (dimension_to_match < 0) { return InvalidArgument( - "Broadcast dimension number (%lld) cannot be negative.", + "Broadcast dimension number (%d) cannot be negative.", dimension_to_match); } if (dimension_to_match >= larger_shape.dimensions_size()) { return InvalidArgument( - "Broadcast dimension number (%lld) too large; higher-rank " + "Broadcast dimension number (%d) too large; higher-rank " "operand has rank %d.", dimension_to_match, larger_shape.dimensions_size()); } @@ -795,16 +828,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (small_dimension_size != large_dimension_size && small_dimension_size != 1 && large_dimension_size != 1) { return InvalidArgument( - "Broadcast dimension %d mismatch: %lld != %lld; %s and %s.", i, + "Broadcast dimension %d mismatch: %d != %d; %s and %s.", i, small_dimension_size, large_dimension_size, - ShapeUtil::HumanString(smaller_shape).c_str(), - ShapeUtil::HumanString(larger_shape).c_str()); + ShapeUtil::HumanString(smaller_shape), + ShapeUtil::HumanString(larger_shape)); } // Make sure the broadcast dimensions are listed in a strictly increasing // order. 
if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) { return InvalidArgument( - "Broadcast dimensions order is wrong: %lld comes after %lld.", + "Broadcast dimensions order is wrong: %d comes after %d.", dimension_to_match, broadcast_dimensions.at(i - 1)); } @@ -823,8 +856,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Binary op %s with different element types: %s and %s.", - HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + HloOpcodeString(operation), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs)); } if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { @@ -874,20 +907,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << StrFormat( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", - HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str(), - Join(broadcast_dimensions, ", ").c_str()); + HloOpcodeString(opcode), ShapeUtil::HumanString(lhs), + ShapeUtil::HumanString(rhs), StrJoin(broadcast_dimensions, ", ")); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); - TF_RETURN_IF_ERROR( - ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ", - HloOpcodeString(opcode)))); - TF_RETURN_IF_ERROR( - ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ", - HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR(ExpectArray( + lhs, absl::StrCat("lhs of binary operation ", HloOpcodeString(opcode)))); + TF_RETURN_IF_ERROR(ExpectArray( + rhs, 
absl::StrCat("rhs of binary operation ", HloOpcodeString(opcode)))); switch (opcode) { case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -909,7 +939,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected element type in shape to be floating for complex compose " "operation; got %s.", - PrimitiveType_Name(lhs.element_type()).c_str()); + PrimitiveType_Name(lhs.element_type())); } TF_ASSIGN_OR_RETURN(const Shape& shape, InferElementwiseBinaryOpShape(opcode, lhs, rhs, @@ -928,7 +958,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected pred or integral type in argument to and/or operation; " "got %s.", - PrimitiveType_Name(lhs.element_type()).c_str()); + PrimitiveType_Name(lhs.element_type())); } return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); @@ -946,8 +976,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, default: return Unimplemented( "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.", - HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(), - rhs.ShortDebugString().c_str()); + HloOpcodeString(opcode), lhs.ShortDebugString(), + rhs.ShortDebugString()); } } @@ -970,8 +1000,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case HloOpcode::kTupleSelect: return InferTupleSelectShape(lhs, rhs, ehs); default: - return InvalidArgument("Unknown operation %s.", - HloOpcodeString(opcode).c_str()); + return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode)); } } @@ -1010,8 +1039,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Sort keys and values dimensions must match. 
" "Keys shape is: %s\n, Values shape is: %s", - ShapeUtil::HumanString(*operand_shapes[0]).c_str(), - ShapeUtil::HumanString(*operand_shapes[1]).c_str()); + ShapeUtil::HumanString(*operand_shapes[0]), + ShapeUtil::HumanString(*operand_shapes[1])); } return ShapeUtil::MakeTupleShape( {*operand_shapes[0], *operand_shapes[1]}); @@ -1019,8 +1048,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument("Unexpected number of operands for sort"); } default: - return InvalidArgument("Unknown operation %s.", - HloOpcodeString(opcode).c_str()); + return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode)); } } @@ -1058,7 +1086,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Map operation requires all operands to have the same shape; got: " "%s.", - Join(pieces, ", ").c_str()); + StrJoin(pieces, ", ")); } // Check that dimensions.size == arg_shape.dimensions_size() (we currently @@ -1066,7 +1094,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (dimensions.size() != arg_shape->dimensions_size()) { return InvalidArgument( "Map applied to a subset of dimensions currently not supported: " - "arg_dimension_size: %d, requested_map_dimensions_size: %zu.", + "arg_dimension_size: %d, requested_map_dimensions_size: %u.", arg_shape->dimensions_size(), dimensions.size()); } @@ -1075,7 +1103,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (dimensions[i] != i) { return InvalidArgument( "Map requires monotonically increasing dimension numbers; got: %s.", - Join(dimensions, ", ").c_str()); + StrJoin(dimensions, ", ")); } } @@ -1083,7 +1111,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (arg_shapes.size() != to_apply.parameters_size()) { return InvalidArgument( "Map applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu.", + "arity: 
%d, arguments: %u.", to_apply.parameters_size(), arg_shapes.size()); } @@ -1092,7 +1120,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::IsScalar(output_shape)) { return InvalidArgument( "Mapped computation's result has to be a scalar; got: %s.", - ShapeUtil::HumanString(output_shape).c_str()); + ShapeUtil::HumanString(output_shape)); } for (int i = 0; i < to_apply.parameters_size(); ++i) { @@ -1102,7 +1130,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Mapped computation's parameter has to be a scalar; " "got parameter %d shape: %s.", - i, ShapeUtil::HumanString(parameter_shape).c_str()); + i, ShapeUtil::HumanString(parameter_shape)); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape, @@ -1110,8 +1138,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Mapped computation's parameter type has to match argument element " "type; got parameter %d shape: %s, argument shape: %s.", - i, ShapeUtil::HumanString(parameter_shape).c_str(), - ShapeUtil::HumanString(*arg_shape).c_str()); + i, ShapeUtil::HumanString(parameter_shape), + ShapeUtil::HumanString(*arg_shape)); } } @@ -1140,35 +1168,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-training to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-training to " - "be a non-negative number, got %lld.", + "be a non-negative number, got %d.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-training to be at least 1; got %lld.", + "batch-norm-training to be at 
least 1; got %d.", ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(offset_shape) != 1) { return InvalidArgument( "Offset input of batch-norm-training must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(offset_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-training must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } @@ -1176,7 +1204,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-training must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, @@ -1185,8 +1213,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-training, " "but the shape of offset factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(offset_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(offset_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1195,8 +1223,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-training, " "but the shape of scale factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1206,16 +1234,16 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { return InvalidArgument( "The size of offset factor should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } @@ -1250,35 +1278,35 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-inference to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (feature_index < 0) { return InvalidArgument( "Expected feature_index of batch-norm-inference to " - "be a non-negative number, got %lld.", + "be a non-negative number, got %d.", feature_index); } if (ShapeUtil::Rank(operand_shape) < 1) { return InvalidArgument( "Expected the rank of operand to " - "batch-norm-inference to be at least 1; got %lld.", + "batch-norm-inference to be at least 1; got %d.", ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(offset_shape) != 1) { return InvalidArgument( "Offset input of batch-norm-inference must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(offset_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-inference must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", 
ShapeUtil::Rank(scale_shape)); } @@ -1286,7 +1314,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-inference must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, @@ -1296,8 +1324,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of offset factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(offset_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(offset_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1307,8 +1335,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of scale factor is %s " "and the shape of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, @@ -1318,8 +1346,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "batch-norm-inference, " "but the shape of mean is %s " "and the shape of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape, @@ -1329,8 +1357,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, 
"batch-norm-inference, " "but the shape of variance is %s " "and the shape of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(variance_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(variance_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1340,32 +1368,32 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { return InvalidArgument( "The size of offset factor should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(offset_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) { return InvalidArgument( "The size of mean should be the same as feature count," - "but the size of mean is %lld " - "and the feature count is %lld.", + "but the size of mean is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } if (ShapeUtil::GetDimension(variance_shape, 0) != feature_count) { return InvalidArgument( "The size of variance should be the same as feature count," - "but the size of variance is %lld " - "and the feature count is %lld.", + "but the size of variance is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(variance_shape, 0), feature_count); } @@ -1395,36 +1423,36 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Expected feature_index of batch-norm-grad to be " "smaller than the rank of operand_shape; " - "got feature_index %lld, and rank %lld.", + "got feature_index %d, and rank %d.", feature_index, ShapeUtil::Rank(operand_shape)); } if (ShapeUtil::Rank(operand_shape) != ShapeUtil::Rank(output_grad_shape)) { return InvalidArgument( "Expected operand_shape of batch-norm-grad to have the same rank as" - " output_grad_shape; got rank(oprand_shape) %lld, and" - " rank(output_grad_shape) %lld.", + " output_grad_shape; got rank(oprand_shape) %d, and" + " rank(output_grad_shape) %d.", ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(output_grad_shape)); } if (ShapeUtil::Rank(mean_shape) != 1) { return InvalidArgument( "Mean input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(mean_shape)); } if (ShapeUtil::Rank(scale_shape) != 1) { return InvalidArgument( "Scale input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(scale_shape)); } if (ShapeUtil::Rank(var_shape) != 1) { return InvalidArgument( "Var input of batch-norm-grad must have" - " rank 1, but has rank %lld.", + " rank 1, but has rank %d.", ShapeUtil::Rank(var_shape)); } @@ -1432,14 +1460,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The operand to batch-norm-grad must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::ElementIsFloating(output_grad_shape)) { return InvalidArgument( "The output_grad to batch-norm-grad must have a floating point " "element type, but the shape is %s.", - PrimitiveType_Name(output_grad_shape.element_type()).c_str()); + PrimitiveType_Name(output_grad_shape.element_type())); 
} if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape, @@ -1448,8 +1476,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of output_grad is %s " "and the element type of operand is %s.", - PrimitiveType_Name(output_grad_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(output_grad_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, @@ -1458,8 +1486,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of scale factor is %s " "and the element type of operand is %s.", - PrimitiveType_Name(scale_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(scale_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, @@ -1468,8 +1496,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " "and the element type of operand is %s.", - PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape, @@ -1478,8 +1506,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " "and the element type of operand is %s.", - 
PrimitiveType_Name(mean_shape.element_type()).c_str(), - PrimitiveType_Name(operand_shape.element_type()).c_str()); + PrimitiveType_Name(mean_shape.element_type()), + PrimitiveType_Name(operand_shape.element_type())); } const int64 feature_count = operand_shape.dimensions(feature_index); @@ -1490,24 +1518,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::GetDimension(mean_shape, 0) != feature_count) { return InvalidArgument( "The size of mean should be the same as feature count," - "but the size of offset factor is %lld " - "and the feature count is %lld.", + "but the size of offset factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(mean_shape, 0), feature_count); } if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { return InvalidArgument( "The size of scale factor should be the same as feature count," - "but the size of scale factor is %lld " - "and the feature count is %lld.", + "but the size of scale factor is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(scale_shape, 0), feature_count); } if (ShapeUtil::GetDimension(var_shape, 0) != feature_count) { return InvalidArgument( "The size of variance should be the same as feature count," - "but the size of variance is %lld " - "and the feature count is %lld.", + "but the size of variance is %d " + "and the feature count is %d.", ShapeUtil::GetDimension(var_shape, 0), feature_count); } @@ -1517,8 +1545,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::GetDimension(output_grad_shape, i)) { return InvalidArgument( "The bounds of operand shape should be the same as output_grad's," - "but the bound of operand_shape at dimension %lld is %lld " - "and the bound of output_grad_shape is %lld.", + "but the bound of operand_shape at dimension %d is %d " + "and the bound of output_grad_shape is %d.", i, ShapeUtil::GetDimension(operand_shape, i), ShapeUtil::GetDimension(output_grad_shape, 
i)); } @@ -1537,15 +1565,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s.", - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str()); + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs)); } if (dnums.input_spatial_dimensions_size() != dnums.kernel_spatial_dimensions_size()) { return InvalidArgument( "Both arguments to convolution must have same number of dimensions.\n" "Window: %s", - window.DebugString().c_str()); + window.DebugString()); } const int num_spatial_dims = dnums.input_spatial_dimensions_size(); @@ -1553,19 +1580,19 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Window must have same number of dimensions as dimension numbers.\n" "Window: %s\nDimension numbers: %s.", - window.DebugString().c_str(), dnums.DebugString().c_str()); + window.DebugString(), dnums.DebugString()); } const int num_dims = num_spatial_dims + 2; if (ShapeUtil::Rank(lhs) != num_dims) { return InvalidArgument( "The LHS argument to a convolution should have rank %d; lhs: %s.", - num_dims, ShapeUtil::HumanString(lhs).c_str()); + num_dims, ShapeUtil::HumanString(lhs)); } if (ShapeUtil::Rank(rhs) != num_dims) { return InvalidArgument( "The RHS argument to a convolution should have rank %d; lhs: %s.", - num_dims, ShapeUtil::HumanString(lhs).c_str()); + num_dims, ShapeUtil::HumanString(lhs)); } TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); @@ -1602,26 +1629,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { return InvalidArgument( "A dimension number is out of range in convolution: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if 
(input_dnums != expected_dnums) { return InvalidArgument( "Input dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (window_dnums != expected_dnums) { return InvalidArgument( "Window dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } if (output_dnums != expected_dnums) { return InvalidArgument( "Output dimensions of convolution must contain each dimension exactly " "once: %s.", - dnums.DebugString().c_str()); + dnums.DebugString()); } std::vector<int64> input_spatial_dims(num_spatial_dims); @@ -1642,13 +1669,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (input_features != kernel_input_features * feature_group_count) { return InvalidArgument( - "Expected LHS feature dimension (value %lld) to match RHS " - "input feature dimension * feature_group_count (value %lld); " + "Expected LHS feature dimension (value %d) to match RHS " + "input feature dimension * feature_group_count (value %d); " "got <conv>(%s, %s)\n" "Dimension numbers: {%s}.", input_features, kernel_input_features * feature_group_count, - ShapeUtil::HumanString(lhs).c_str(), - ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str()); + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), + dnums.DebugString()); } std::vector<int64> window_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -1660,8 +1687,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, "RHS shape: %s\n\t" "Window: {%s}\n\t" "Dimension numbers: {%s}.", - ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(), - dnums.ShortDebugString().c_str()); + ShapeUtil::HumanString(rhs), window.ShortDebugString(), + dnums.ShortDebugString()); } Shape base_shape = @@ -1687,29 +1714,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, 
const tensorflow::gtl::ArraySlice<int64> fft_length) { const int64 fft_rank = fft_length.size(); if (fft_rank < 1 || fft_rank > 3) { - return InvalidArgument("FFT only supports ranks 1-3; got %lld.", fft_rank); + return InvalidArgument("FFT only supports ranks 1-3; got %d.", fft_rank); } -#define RET_CHECK_RANK(x) \ - if (x.dimensions_size() < fft_rank) { \ - return InvalidArgument( \ - "FFT of rank %lld requires input of at least " \ - "same rank; got input of rank %d", \ - fft_rank, x.dimensions_size()); \ +#define RET_CHECK_RANK(x) \ + if (x.dimensions_size() < fft_rank) { \ + return InvalidArgument( \ + "FFT of rank %d requires input of at least " \ + "same rank; got input of rank %d", \ + fft_rank, x.dimensions_size()); \ } switch (fft_type) { case FFT: case IFFT: if (in.element_type() != C64) { return InvalidArgument("%s requires C64 input type, found %s.", - FftType_Name(fft_type).c_str(), - PrimitiveType_Name(in.element_type()).c_str()); + FftType_Name(fft_type), + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); return in; case RFFT: { if (in.element_type() != F32) { return InvalidArgument("RFFT requires F32 input type, found %s.", - PrimitiveType_Name(in.element_type()).c_str()); + PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); for (int i = 0; i < fft_rank; i++) { @@ -1717,7 +1744,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]) { return InvalidArgument( "RFFT requires innermost dimensions match fft_length but " - "dimension %lld is %lld and should be %lld.", + "dimension %d is %d and should be %d.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1731,7 +1758,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, case IRFFT: { if (in.element_type() != C64) { return InvalidArgument("IRFFT requires C64 input type, found %s.", - PrimitiveType_Name(in.element_type()).c_str()); + 
PrimitiveType_Name(in.element_type())); } RET_CHECK_RANK(in); Shape result = ShapeUtil::ComplexComponentShape(in); @@ -1740,7 +1767,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[i]) { return InvalidArgument( "IRFFT requires all but one innermost dimensions match " - "fft_length, but dimension %lld is %lld and should be %lld.", + "fft_length, but dimension %d is %d and should be %d.", in.dimensions_size() - fft_rank + i, in.dimensions(in.dimensions_size() - fft_rank + i), fft_length[i]); @@ -1750,7 +1777,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, fft_length[fft_rank - 1] / 2 + 1) { return InvalidArgument( "IRFFT requires innermost dimension matches fft_length/2+1, but " - "dimension %d is %lld and should be %lld.", + "dimension %d is %d and should be %d.", in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1), fft_length[fft_rank - 1] / 2 + 1); } @@ -1786,18 +1813,18 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RET_CHECK(split_count > 0); if (split_dimension >= ShapeUtil::Rank(shape) || split_dimension < 0) { return InvalidArgument( - "AllToAll split_dimension %lld is out-of-bounds in shape %s.", - split_dimension, ShapeUtil::HumanString(shape).c_str()); + "AllToAll split_dimension %d is out-of-bounds in shape %s.", + split_dimension, ShapeUtil::HumanString(shape)); } if (concat_dimension >= ShapeUtil::Rank(shape) || concat_dimension < 0) { return InvalidArgument( - "AllToAll concat_dimension %lld is out-of-bounds in shape %s.", - concat_dimension, ShapeUtil::HumanString(shape).c_str()); + "AllToAll concat_dimension %d is out-of-bounds in shape %s.", + concat_dimension, ShapeUtil::HumanString(shape)); } if (shape.dimensions(split_dimension) % split_count != 0) { return InvalidArgument( - "AllToAll split dimension size %lld must be dividable by split_count " - "%lld.", + "AllToAll split dimension size %d must be dividable by split_count 
" + "%d.", shape.dimensions(split_dimension), split_count); } std::vector<int64> new_dimensions(shape.dimensions().begin(), @@ -1817,14 +1844,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "HLO all-to-all has operands with different shapes: the 0th " "operand shape %s, but the %dth operand has shape %s.", - ShapeUtil::HumanString(*operand_shapes[0]).c_str(), i, - ShapeUtil::HumanString(*operand_shapes[i]).c_str()); + ShapeUtil::HumanString(*operand_shapes[0]), i, + ShapeUtil::HumanString(*operand_shapes[i])); } } return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); } +/* static */ StatusOr<Shape> ShapeInference::InferCollectivePermuteShape( + const Shape& shape) { + TF_RET_CHECK(ShapeUtil::IsArray(shape)); + return shape; +} + /* static */ StatusOr<Shape> ShapeInference::InferReduceShape( tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce, @@ -1847,9 +1880,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::SameDimensions(*reduced_args[0], *reduced_args[i])) { return InvalidArgument( "All reduced tensors must have the sime dimension. 
Tensor 0 has " - "shape %s, Tensor %lld has shape %s", - ShapeUtil::HumanString(*reduced_args[0]).c_str(), i, - ShapeUtil::HumanString(*reduced_args[i]).c_str()); + "shape %s, Tensor %d has shape %s", + ShapeUtil::HumanString(*reduced_args[0]), i, + ShapeUtil::HumanString(*reduced_args[i])); } } @@ -1859,9 +1892,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const Shape& arg = *reduced_args[0]; for (int64 dimension : dimensions_to_reduce) { if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) { - return InvalidArgument( - "Reducing out-of-bounds dimension %lld in shape %s.", dimension, - ShapeUtil::HumanString(arg).c_str()); + return InvalidArgument("Reducing out-of-bounds dimension %d in shape %s.", + dimension, ShapeUtil::HumanString(arg)); } } @@ -1934,16 +1966,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Select function's first parameter shape currently must " "match the operand element shape, but got %s vs %s.", - ShapeUtil::HumanString(select_shape.parameters(0)).c_str(), - ShapeUtil::HumanString(operand_element_shape).c_str()); + ShapeUtil::HumanString(select_shape.parameters(0)), + ShapeUtil::HumanString(operand_element_shape)); } if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, select_shape.parameters(1))) { return InvalidArgument( "Select function's second parameter shape currently must " "match the operand element shape, but got %s vs %s.", - ShapeUtil::HumanString(select_shape.parameters(1)).c_str(), - ShapeUtil::HumanString(operand_element_shape).c_str()); + ShapeUtil::HumanString(select_shape.parameters(1)), + ShapeUtil::HumanString(operand_element_shape)); } // Check if the scatter function has a proper shape as a reduction. 
@@ -1961,8 +1993,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Source shape does not match the shape of window-reduced operand: " "source(%s), window-reduced operand(%s).", - ShapeUtil::HumanString(source_shape).c_str(), - ShapeUtil::HumanString(window_result_shape).c_str()); + ShapeUtil::HumanString(source_shape), + ShapeUtil::HumanString(window_result_shape)); } return operand_shape; } @@ -1975,29 +2007,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "%s in slice operation; argument shape: %s; starts: {%s}; limits: " "{%s}; strides: {%s}.", - message.c_str(), ShapeUtil::HumanString(arg).c_str(), - Join(starts, ",").c_str(), Join(limits, ",").c_str(), - Join(strides, ",").c_str()); + message, ShapeUtil::HumanString(arg), StrJoin(starts, ","), + StrJoin(limits, ","), StrJoin(strides, ",")); }; TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice")); - VLOG(2) << tensorflow::strings::Printf( - "slicing shape %s starts={%s} limits={%s}", - ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(), - Join(limits, ", ").c_str()); + VLOG(2) << StrFormat("slicing shape %s starts={%s} limits={%s}", + ShapeUtil::HumanString(arg), StrJoin(starts, ", "), + StrJoin(limits, ", ")); if (starts.size() != limits.size()) { - return error(Printf("slice start and limit sizes differ: %zu vs %zu", - starts.size(), limits.size())); + return error(StrFormat("slice start and limit sizes differ: %u vs %u", + starts.size(), limits.size())); } if (starts.size() != strides.size()) { - return error(Printf("slice start and strides sizes differ: %zu vs %zu", - starts.size(), strides.size())); + return error(StrFormat("slice start and strides sizes differ: %u vs %u", + starts.size(), strides.size())); } if (starts.size() != ShapeUtil::Rank(arg)) { return InvalidArgument( - "Slice index count does not match argument rank: %zu vs %lld.", + "Slice index count does not match 
argument rank: %u vs %d.", starts.size(), ShapeUtil::Rank(arg)); } @@ -2007,27 +2037,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, int64 limit_index = limits[dimension]; int64 stride = strides[dimension]; if (start_index < 0) { - return InvalidArgument("Negative start index to slice: %lld.", - start_index); + return InvalidArgument("Negative start index to slice: %d.", start_index); } if (limit_index > arg.dimensions(dimension)) { return error( - Printf("limit index (%lld) must be less than or equal to dimension " - "size (%lld)", - limit_index, arg.dimensions(dimension))); - } - VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension, - start_index); - VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension, - limit_index); + StrFormat("limit index (%d) must be less than or equal to dimension " + "size (%d)", + limit_index, arg.dimensions(dimension))); + } + VLOG(2) << StrFormat("starts[%d] = %d", dimension, start_index); + VLOG(2) << StrFormat("limits[%d] = %d", dimension, limit_index); if (start_index > limit_index) { return error( - Printf("limit index (%lld) must be greater or equal to " - "start index (%lld) in slice with positive stride", - limit_index, start_index)); + StrFormat("limit index (%d) must be greater or equal to " + "start index (%d) in slice with positive stride", + limit_index, start_index)); } if (stride <= 0) { - return InvalidArgument("Stride (%lld) must be positive.", stride); + return InvalidArgument("Stride (%d) must be positive.", stride); } sizes.push_back((limit_index - start_index + stride - 1) / stride); } @@ -2042,15 +2069,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR( ExpectArray(start_indices_shape, "start indices of dynamic slice")); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << StrFormat( "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", - 
ShapeUtil::HumanString(operand_shape).c_str(), - ShapeUtil::HumanString(start_indices_shape).c_str(), - Join(slice_sizes, ", ").c_str()); + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), StrJoin(slice_sizes, ", ")); if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "Dynamic slice start indices of rank %lld must be rank1.", + "Dynamic slice start indices of rank %d must be rank1.", ShapeUtil::Rank(start_indices_shape)); } @@ -2062,16 +2088,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "Dynamic slice start number of dimensions %lld (%s) must match rank " - "%lld of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::Rank(operand_shape), - ShapeUtil::HumanString(operand_shape).c_str()); + "Dynamic slice start number of dimensions %d (%s) must match rank " + "%d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::Rank(operand_shape), ShapeUtil::HumanString(operand_shape)); } if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( - "Dynamic slice index count does not match argument rank: %zu vs %lld.", + "Dynamic slice index count does not match argument rank: %u vs %d.", slice_sizes.size(), ShapeUtil::Rank(operand_shape)); } @@ -2079,16 +2104,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 input_dim_size = operand_shape.dimensions(dim); const int64 slice_dim_size = slice_sizes[dim]; if (slice_dim_size < 0) { - return InvalidArgument("Negative size index to dynamic slice: %lld.", + return InvalidArgument("Negative size index to dynamic slice: %d.", slice_dim_size); } if (slice_dim_size > input_dim_size) { return InvalidArgument( - "Slice dim size %lld 
greater than dynamic slice dimension: %lld.", + "Slice dim size %d greater than dynamic slice dimension: %d.", slice_dim_size, input_dim_size); } - VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim, - slice_dim_size); + VLOG(2) << StrFormat("slice_sizes[%d] = %d", dim, slice_dim_size); } return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes); @@ -2104,16 +2128,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape, "start indices of dynamic update slice")); - VLOG(2) << tensorflow::strings::Printf( + VLOG(2) << StrFormat( "updating slice of shape %s at dynamic start_indices %s with update " "shape %s", - ShapeUtil::HumanString(operand_shape).c_str(), - ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::HumanString(update_shape).c_str()); + ShapeUtil::HumanString(operand_shape), + ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::HumanString(update_shape)); if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( - "Dynamic update slice start indices of rank %lld must be rank1.", + "Dynamic update slice start indices of rank %d must be rank1.", ShapeUtil::Rank(start_indices_shape)); } @@ -2125,17 +2149,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 start_num_dims = start_indices_shape.dimensions(0); if (ShapeUtil::Rank(operand_shape) != start_num_dims) { return InvalidArgument( - "Dynamic update slice start number of dimensions %lld (%s) must match " - "rank %lld of slice input (%s).", - start_num_dims, ShapeUtil::HumanString(start_indices_shape).c_str(), - ShapeUtil::Rank(operand_shape), - ShapeUtil::HumanString(operand_shape).c_str()); + "Dynamic update slice start number of dimensions %d (%s) must match " + "rank %d of slice input (%s).", + start_num_dims, ShapeUtil::HumanString(start_indices_shape), + ShapeUtil::Rank(operand_shape), 
ShapeUtil::HumanString(operand_shape)); } if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) { return InvalidArgument( "Dynamic update slice update rank does not match argument rank: " - "%lld vs %lld.", + "%d vs %d.", ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); } @@ -2144,8 +2167,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Dynamic update slice update element type does not match argument. " "operand.element_type: %s vs update.element_type: %s.", - PrimitiveType_Name(operand_shape.element_type()).c_str(), - PrimitiveType_Name(update_shape.element_type()).c_str()); + PrimitiveType_Name(operand_shape.element_type()), + PrimitiveType_Name(update_shape.element_type())); } for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) { @@ -2153,16 +2176,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const int64 update_dim_size = update_shape.dimensions(dim); if (update_dim_size < 0) { return InvalidArgument( - "Size index %lld to dynamic update slice must be >= 0.", + "Size index %d to dynamic update slice must be >= 0.", update_dim_size); } if (update_dim_size > input_dim_size) { return InvalidArgument( - "Update dim size %lld greater than dynamic slice dimension: %lld.", + "Update dim size %d greater than dynamic slice dimension: %d.", update_dim_size, input_dim_size); } - VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim, - update_dim_size); + VLOG(2) << StrFormat("update_sizes[%d] = %d", dim, update_dim_size); } return operand_shape; @@ -2177,8 +2199,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, for (int64 dimension : dimensions) { if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) { return InvalidArgument( - "One of the reverse dimensions (%lld) is out-of-bounds in shape %s.", - dimension, ShapeUtil::HumanString(operand_shape).c_str()); + "One of the reverse 
dimensions (%d) is out-of-bounds in shape %s.", + dimension, ShapeUtil::HumanString(operand_shape)); } } return operand_shape; @@ -2189,14 +2211,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::IsTuple(arg)) { return InvalidArgument( "Cannot infer shape: attempting to index into non-tuple: %s.", - ShapeUtil::HumanString(arg).c_str()); + ShapeUtil::HumanString(arg)); } if (index >= arg.tuple_shapes_size()) { return InvalidArgument( - "Cannot infer shape: attempt to index out of tuple bounds: %lld " + "Cannot infer shape: attempt to index out of tuple bounds: %d " ">= %d in shape %s.", - index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str()); + index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg)); } return arg.tuple_shapes(index); @@ -2216,17 +2238,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } auto shape_string = [&]() { - return tensorflow::strings::Printf( - "Condition: %s; body: %s; init: %s.", - ShapeUtil::HumanString(condition).c_str(), - ShapeUtil::HumanString(body).c_str(), - ShapeUtil::HumanString(init).c_str()); + return StrFormat( + "Condition: %s; body: %s; init: %s.", ShapeUtil::HumanString(condition), + ShapeUtil::HumanString(body), ShapeUtil::HumanString(init)); }; // Check the shapes of computation parameters and return types. 
if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) { return InvalidArgument("Condition must return a boolean; got %s.", - shape_string().c_str()); + shape_string()); } if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) || !ShapeUtil::Compatible(body.result(), body.parameters(0)) || @@ -2234,7 +2254,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "The parameter of condition and body, the result of the body, and init " "must all have the same shape; got %s.", - shape_string().c_str()); + shape_string()); } return init; @@ -2246,7 +2266,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, const ProgramShape& false_computation) { if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { return InvalidArgument("Predicate must be a boolean; got %s.", - ShapeUtil::HumanString(predicate).c_str()); + ShapeUtil::HumanString(predicate)); } if (true_computation.parameters_size() != 1) { @@ -2255,15 +2275,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) { auto true_shape_string = [&]() { - return tensorflow::strings::Printf( - "true_operand: %s; true_computation: %s", - ShapeUtil::HumanString(true_operand).c_str(), - ShapeUtil::HumanString(true_computation).c_str()); + return StrFormat("true_operand: %s; true_computation: %s", + ShapeUtil::HumanString(true_operand), + ShapeUtil::HumanString(true_computation)); }; return InvalidArgument( "true_operand must match the shape of the only parameter of " "true_computation: got %s.", - true_shape_string().c_str()); + true_shape_string()); } if (false_computation.parameters_size() != 1) { @@ -2272,28 +2291,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) { auto false_shape_string = [&]() { - return tensorflow::strings::Printf( 
- "false_operand: %s; false_computation: %s", - ShapeUtil::HumanString(false_operand).c_str(), - ShapeUtil::HumanString(false_computation).c_str()); + return StrFormat("false_operand: %s; false_computation: %s", + ShapeUtil::HumanString(false_operand), + ShapeUtil::HumanString(false_computation)); }; return InvalidArgument( "false_operand must match the shape of the only parameter of " "false_computation: got %s.", - false_shape_string().c_str()); + false_shape_string()); } if (!ShapeUtil::Compatible(true_computation.result(), false_computation.result())) { auto shape_string = [&]() { - return tensorflow::strings::Printf( + return StrFormat( "true_computation result: %s; false_computation result: %s.", - ShapeUtil::HumanString(true_computation.result()).c_str(), - ShapeUtil::HumanString(false_computation.result()).c_str()); + ShapeUtil::HumanString(true_computation.result()), + ShapeUtil::HumanString(false_computation.result())); }; return InvalidArgument( "the result of true_computation and false_computation must have the " "same shape: got %s.", - shape_string().c_str()); + shape_string()); } return true_computation.result(); } @@ -2303,7 +2321,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast")); for (int64 size : broadcast_sizes) { if (size < 0) { - return InvalidArgument("Broadcast with negative dimension size %lld.", + return InvalidArgument("Broadcast with negative dimension size %d.", size); } } @@ -2328,11 +2346,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( - "Reshape operation has mismatched element counts: from=%lld (%s) " - "to=%lld (%s).", - ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(), + "Reshape operation has mismatched element counts: from=%d (%s) " + "to=%d (%s).", + 
ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand), ShapeUtil::ElementsIn(inferred_shape), - ShapeUtil::HumanString(inferred_shape).c_str()); + ShapeUtil::HumanString(inferred_shape)); } std::vector<int64> indices(ShapeUtil::Rank(operand)); @@ -2343,7 +2361,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Reshape dimensions [%s] are not a permutation of the operand " "dimensions (operand shape is %s).", - Join(dimensions, ",").c_str(), ShapeUtil::HumanString(operand).c_str()); + StrJoin(dimensions, ","), ShapeUtil::HumanString(operand)); } return inferred_shape; @@ -2378,9 +2396,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { return InvalidArgument("Clamp with different operand types: %s, %s, %s.", - ShapeUtil::HumanString(min).c_str(), - ShapeUtil::HumanString(operand).c_str(), - ShapeUtil::HumanString(max).c_str()); + ShapeUtil::HumanString(min), + ShapeUtil::HumanString(operand), + ShapeUtil::HumanString(max)); } if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) || ShapeUtil::IsScalar(min)) && @@ -2397,9 +2415,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return ShapeUtil::ChangeElementType(min, operand.element_type()); } } - return Unimplemented( - "%s, %s <clamp> %s is not implemented.", min.ShortDebugString().c_str(), - max.ShortDebugString().c_str(), operand.ShortDebugString().c_str()); + return Unimplemented("%s, %s <clamp> %s is not implemented.", + min.ShortDebugString(), max.ShortDebugString(), + operand.ShortDebugString()); } // TODO(b/36794510): Make broadcast semantics more consistent, by supporting @@ -2410,13 +2428,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) { return 
InvalidArgument( "Operands to select must be the same shape; got %s and %s.", - ShapeUtil::HumanString(on_true).c_str(), - ShapeUtil::HumanString(on_false).c_str()); + ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false)); } if (pred.element_type() != PRED) { return InvalidArgument( "Select's pred operand must have PRED element type; got %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || ShapeUtil::IsScalar(pred)) { @@ -2429,7 +2446,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Select operation with non-scalar predicate with dimensionality " " different from the other operands: %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } } @@ -2440,18 +2457,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (!ShapeUtil::Compatible(on_true, on_false)) { return InvalidArgument( "Operands to tuple-select must be the same shape; got %s and %s.", - ShapeUtil::HumanString(on_true).c_str(), - ShapeUtil::HumanString(on_false).c_str()); + ShapeUtil::HumanString(on_true), ShapeUtil::HumanString(on_false)); } if (pred.element_type() != PRED) { return InvalidArgument( "TupleSelect's pred operand must have PRED element type; got %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } if (!ShapeUtil::IsScalar(pred)) { return InvalidArgument( "TupleSelect operation with non-scalar predicate: %s.", - ShapeUtil::HumanString(pred).c_str()); + ShapeUtil::HumanString(pred)); } return on_true; } @@ -2463,15 +2479,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (arg_shapes.size() != to_apply.parameters_size()) { string computation_signature = ShapeUtil::HumanString(to_apply); string argument_shapes = - Join(arg_shapes, ", ", [](string* out, const Shape* shape) { - tensorflow::strings::StrAppend(out, 
ShapeUtil::HumanString(*shape)); + StrJoin(arg_shapes, ", ", [](string* out, const Shape* shape) { + absl::StrAppend(out, ShapeUtil::HumanString(*shape)); }); return InvalidArgument( "Call applied function arity must match number of arguments; got: " - "arity: %d, arguments: %zu; computation signature: %s; argument " + "arity: %d, arguments: %u; computation signature: %s; argument " "shapes: [%s].", - to_apply.parameters_size(), arg_shapes.size(), - computation_signature.c_str(), argument_shapes.c_str()); + to_apply.parameters_size(), arg_shapes.size(), computation_signature, + argument_shapes); } // All arguments must be compatible with the program shape. @@ -2482,8 +2498,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InvalidArgument( "Call parameter must match argument; got parameter %d shape: %s, " "argument shape: %s.", - i, ShapeUtil::HumanString(param_shape).c_str(), - ShapeUtil::HumanString(arg_shape).c_str()); + i, ShapeUtil::HumanString(param_shape), + ShapeUtil::HumanString(arg_shape)); } } @@ -2494,17 +2510,17 @@ static Status ValidateGatherDimensionNumbers( const Shape& input_shape, tensorflow::gtl::ArraySlice<int64> start_indices_shape, const GatherDimensionNumbers& dim_numbers) { - if (!c_is_sorted(dim_numbers.offset_dims())) { + if (!absl::c_is_sorted(dim_numbers.offset_dims())) { return InvalidArgument( "Output window dimensions in gather op must be ascending; got: %s.", - Join(dim_numbers.offset_dims(), ", ").c_str()); + StrJoin(dim_numbers.offset_dims(), ", ")); } - if (c_adjacent_find(dim_numbers.offset_dims()) != + if (absl::c_adjacent_find(dim_numbers.offset_dims()) != dim_numbers.offset_dims().end()) { return InvalidArgument( "Output window dimensions in gather op must not repeat; got: %s.", - Join(dim_numbers.offset_dims(), ", ").c_str()); + StrJoin(dim_numbers.offset_dims(), ", ")); } const int64 output_offset_dim_count = dim_numbers.offset_dims_size(); @@ -2515,9 +2531,9 @@ static Status 
ValidateGatherDimensionNumbers( int64 offset_dim = dim_numbers.offset_dims(i); if (offset_dim < 0 || offset_dim >= output_shape_rank) { return InvalidArgument( - "Offset dimension %d in gather op is out of bounds; got %lld, but " + "Offset dimension %d in gather op is out of bounds; got %d, but " "should " - "have been in [0,%lld).", + "have been in [0,%d).", i, offset_dim, output_shape_rank); } } @@ -2526,8 +2542,8 @@ static Status ValidateGatherDimensionNumbers( start_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( "Gather op has %d elements in start_index_map and the " - "bound of dimension index_vector_dim=%lld of start_indices is " - "%lld. These two numbers must be equal.", + "bound of dimension index_vector_dim=%d of start_indices is " + "%d. These two numbers must be equal.", dim_numbers.start_index_map_size(), dim_numbers.index_vector_dim(), start_indices_shape[dim_numbers.index_vector_dim()]); } @@ -2537,7 +2553,7 @@ static Status ValidateGatherDimensionNumbers( if (operand_dim_for_start_index_i < 0 || operand_dim_for_start_index_i >= input_shape.dimensions_size()) { return InvalidArgument( - "Invalid start_index_map; domain is [0, %d), got: %d->%lld.", + "Invalid start_index_map; domain is [0, %d), got: %d->%d.", input_shape.dimensions_size(), i, operand_dim_for_start_index_i); } } @@ -2546,36 +2562,37 @@ static Status ValidateGatherDimensionNumbers( dim_numbers.start_index_map().begin(), dim_numbers.start_index_map().end()); - c_sort(sorted_start_index_map); + absl::c_sort(sorted_start_index_map); - if (c_adjacent_find(sorted_start_index_map) != sorted_start_index_map.end()) { + if (absl::c_adjacent_find(sorted_start_index_map) != + sorted_start_index_map.end()) { return InvalidArgument( "Repeated dimensions are not allowed in start_index_map; " "got: %s.", - Join(dim_numbers.start_index_map(), ", ").c_str()); + StrJoin(dim_numbers.start_index_map(), ", ")); } for (int64 collapsed_dim : dim_numbers.collapsed_slice_dims()) { if 
(collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) { return InvalidArgument( "Invalid collapsed_slice_dims set in gather op; valid range is [0, " - "%d), got: %lld.", + "%d), got: %d.", input_shape.dimensions_size(), collapsed_dim); } } - if (!c_is_sorted(dim_numbers.collapsed_slice_dims())) { + if (!absl::c_is_sorted(dim_numbers.collapsed_slice_dims())) { return InvalidArgument( "collapsed_slice_dims in gather op must be sorted; got: %s", - Join(dim_numbers.collapsed_slice_dims(), ", ").c_str()); + StrJoin(dim_numbers.collapsed_slice_dims(), ", ")); } - if (c_adjacent_find(dim_numbers.collapsed_slice_dims()) != + if (absl::c_adjacent_find(dim_numbers.collapsed_slice_dims()) != dim_numbers.collapsed_slice_dims().end()) { return InvalidArgument( "Repeated dimensions not allowed in collapsed_slice_dims in gather op; " "got: %s.", - Join(dim_numbers.collapsed_slice_dims(), ", ").c_str()); + StrJoin(dim_numbers.collapsed_slice_dims(), ", ")); } return Status::OK(); @@ -2593,7 +2610,7 @@ static Status ValidateGatherDimensionNumbers( if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) { return InvalidArgument( "Gather indices parameter must be an integral tensor; got %s.", - ShapeUtil::HumanString(start_indices_shape).c_str()); + ShapeUtil::HumanString(start_indices_shape)); } // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if @@ -2606,15 +2623,15 @@ static Status ValidateGatherDimensionNumbers( return InvalidArgument( "Gather index leaf dimension must be within [0, rank(start_indices) + " "1). 
rank(start_indices) is %d and gather index leaf dimension is " - "%lld.", + "%d.", start_indices_shape.dimensions_size(), gather_dim_numbers.index_vector_dim()); } std::vector<int64> expanded_start_indices_shape; expanded_start_indices_shape.reserve(start_indices_shape.dimensions_size()); - c_copy(start_indices_shape.dimensions(), - std::back_inserter(expanded_start_indices_shape)); + absl::c_copy(start_indices_shape.dimensions(), + std::back_inserter(expanded_start_indices_shape)); if (expanded_start_indices_shape.size() == gather_dim_numbers.index_vector_dim()) { expanded_start_indices_shape.push_back(1); @@ -2637,8 +2654,8 @@ static Status ValidateGatherDimensionNumbers( "All components of the offset index in a gather op must either be a " "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, " "output_slice_sizes=%s, collapsed_slice_dims=%s.", - slice_sizes.size(), Join(gather_dim_numbers.offset_dims(), ",").c_str(), - Join(gather_dim_numbers.collapsed_slice_dims(), ",").c_str()); + slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","), + StrJoin(gather_dim_numbers.collapsed_slice_dims(), ",")); } for (int i = 0; i < slice_sizes.size(); i++) { @@ -2647,7 +2664,7 @@ static Status ValidateGatherDimensionNumbers( if (slice_size < 0 || slice_size > corresponding_input_size) { return InvalidArgument( "Slice size at index %d in gather op is out of range, must be " - "within [0, %lld), got %lld.", + "within [0, %d), got %d.", i, corresponding_input_size + 1, slice_size); } } @@ -2656,7 +2673,7 @@ static Status ValidateGatherDimensionNumbers( if (slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)] != 1) { return InvalidArgument( "Gather op can only collapse slice dims with bound 1, but bound is " - "%lld for index %lld at position %d.", + "%d for index %d at position %d.", slice_sizes[gather_dim_numbers.collapsed_slice_dims(i)], gather_dim_numbers.collapsed_slice_dims(i), i); } @@ -2670,10 +2687,11 @@ static Status 
ValidateGatherDimensionNumbers( output_dim_bounds.reserve(result_rank); for (int64 i = 0; i < result_rank; i++) { int64 current_bound; - bool is_window_index = c_binary_search(gather_dim_numbers.offset_dims(), i); + bool is_window_index = + absl::c_binary_search(gather_dim_numbers.offset_dims(), i); if (is_window_index) { - while (c_binary_search(gather_dim_numbers.collapsed_slice_dims(), - offset_dims_seen)) { + while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(), + offset_dims_seen)) { offset_dims_seen++; } current_bound = slice_sizes[offset_dims_seen++]; @@ -2697,44 +2715,44 @@ Status ValidateScatterDimensionNumbers( tensorflow::gtl::ArraySlice<int64> scatter_indices_shape, const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { // Validate update_window_dims in ScatterDimensionNumbers. - if (!c_is_sorted(dim_numbers.update_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.update_window_dims())) { return InvalidArgument( "update_window_dims in scatter op must be sorted; got: %s.", - Join(dim_numbers.update_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.update_window_dims(), ", ")); } - if (c_adjacent_find(dim_numbers.update_window_dims()) != + if (absl::c_adjacent_find(dim_numbers.update_window_dims()) != dim_numbers.update_window_dims().end()) { return InvalidArgument( "update_window_dims in scatter op must not repeat; got: %s.", - Join(dim_numbers.update_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.update_window_dims(), ", ")); } const int64 updates_rank = ShapeUtil::Rank(updates_shape); for (int64 window_dim : dim_numbers.update_window_dims()) { if (window_dim < 0 || window_dim >= updates_rank) { return InvalidArgument( "Invalid update_window_dims set in scatter op; valid range is [0, " - "%lld). got: %lld.", + "%d). got: %d.", updates_rank, window_dim); } } // Validate inserted_window_dims in ScatterDimensionNumbers. 
- if (!c_is_sorted(dim_numbers.inserted_window_dims())) { + if (!absl::c_is_sorted(dim_numbers.inserted_window_dims())) { return InvalidArgument( "inserted_window_dims in scatter op must be sorted; got: %s.", - Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.inserted_window_dims(), ", ")); } - if (c_adjacent_find(dim_numbers.inserted_window_dims()) != + if (absl::c_adjacent_find(dim_numbers.inserted_window_dims()) != dim_numbers.inserted_window_dims().end()) { return InvalidArgument( "inserted_window_dims in scatter op must not repeat; got: %s.", - Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + StrJoin(dim_numbers.inserted_window_dims(), ", ")); } for (int64 inserted_dim : dim_numbers.inserted_window_dims()) { if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) { return InvalidArgument( "Invalid inserted_window_dims set in scatter op; valid range is [0, " - "%d), got: %lld.", + "%d), got: %d.", operand_shape.dimensions_size(), inserted_dim); } } @@ -2744,7 +2762,7 @@ Status ValidateScatterDimensionNumbers( scatter_indices_shape[dim_numbers.index_vector_dim()]) { return InvalidArgument( "Scatter op has %d elements in scatter_dims_to_operand_dims and the " - "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. " + "bound of dimension index_vector_dim=%d of scatter_indices is %d. 
" "These two numbers must be equal.", dim_numbers.scatter_dims_to_operand_dims_size(), dim_numbers.index_vector_dim(), @@ -2757,20 +2775,20 @@ Status ValidateScatterDimensionNumbers( scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) { return InvalidArgument( "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), " - "got: %d->%lld.", + "got: %d->%d.", operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim); } } std::vector<int64> sorted_scatter_dims_to_operand_dims( dim_numbers.scatter_dims_to_operand_dims().begin(), dim_numbers.scatter_dims_to_operand_dims().end()); - c_sort(sorted_scatter_dims_to_operand_dims); - if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) != + absl::c_sort(sorted_scatter_dims_to_operand_dims); + if (absl::c_adjacent_find(sorted_scatter_dims_to_operand_dims) != sorted_scatter_dims_to_operand_dims.end()) { return InvalidArgument( "Repeated dimensions not allowed in scatter_dims_to_operand_dims; " "got: %s.", - Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str()); + StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", ")); } return Status::OK(); @@ -2791,7 +2809,7 @@ Status ValidateScatterDimensionNumbers( if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { return InvalidArgument( "Scatter indices parameter must be an integral tensor; got %s.", - ShapeUtil::HumanString(scatter_indices_shape).c_str()); + ShapeUtil::HumanString(scatter_indices_shape)); } if (scatter_indices_shape.dimensions_size() < @@ -2800,7 +2818,7 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Scatter index leaf dimension must be within [0, rank(scatter_indices)" " + 1). 
rank(scatter_indices) is %d and scatter index leaf dimension " - "is %lld.", + "is %d.", scatter_indices_shape.dimensions_size(), scatter_dim_numbers.index_vector_dim()); } @@ -2822,7 +2840,7 @@ Status ValidateScatterDimensionNumbers( int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + scatter_dim_numbers.update_window_dims_size(); if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { - return InvalidArgument("Updates tensor must be of rank %lld; got %lld.", + return InvalidArgument("Updates tensor must be of rank %d; got %d.", expected_updates_rank, ShapeUtil::Rank(updates_shape)); } @@ -2848,7 +2866,7 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Bounds of the window dimensions of updates must not exceed the " "bounds of the corresponding dimensions of operand. For dimension " - "%lld, updates bound is %lld, operand bound is %lld.", + "%d, updates bound is %d, operand bound is %d.", update_window_dim, updates_shape.dimensions(update_window_dim), max_update_slice_sizes[i]); } @@ -2857,7 +2875,7 @@ Status ValidateScatterDimensionNumbers( int64 scatter_dims_seen = 0; for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { bool is_update_window_dim = - c_binary_search(scatter_dim_numbers.update_window_dims(), i); + absl::c_binary_search(scatter_dim_numbers.update_window_dims(), i); if (is_update_window_dim) { continue; } @@ -2869,8 +2887,8 @@ Status ValidateScatterDimensionNumbers( return InvalidArgument( "Bounds of the scatter dimensions of updates must be same as the " "bounds of the corresponding dimensions of scatter indices. 
For " - "scatter dimension %lld, updates bound is %lld, scatter_indices " - "bound is %lld.", + "scatter dimension %d, updates bound is %d, scatter_indices " + "bound is %d.", i, updates_shape.dimensions(i), expanded_scatter_indices_shape[scatter_dims_seen]); } diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 4974ac9916..235b1a4cf3 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -136,6 +136,9 @@ class ShapeInference { static StatusOr<Shape> InferAllToAllTupleShape( tensorflow::gtl::ArraySlice<const Shape*> operand_shapes); + // Infers the shape of a collective permute operation. + static StatusOr<Shape> InferCollectivePermuteShape(const Shape& shape); + // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. // diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 7d7dcac10b..921a984589 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -18,20 +18,19 @@ limitations under the License. 
#include <string> #include <utility> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::strings::Appendf; - ShapedBuffer::ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, const se::Platform* platform, int device_ordinal) @@ -76,7 +75,7 @@ void ShapedBuffer::clear() { } string ShapedBuffer::ToString() const { - string s = tensorflow::strings::StrCat( + string s = absl::StrCat( "ShapedBuffer(", platform_->Name(), ":", device_ordinal(), "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()), ", on-device shape=" + @@ -92,9 +91,9 @@ string ShapedBuffer::ToString() const { shape_str = ShapeUtil::HumanStringWithLayout(subshape); } const se::DeviceMemoryBase& memory = buffer(index); - Appendf(&s, " %s%p (%lld bytes) : %s\n", - string(index.size() * 2, ' ').c_str(), memory.opaque(), - memory.size(), shape_str.c_str()); + absl::StrAppendFormat(&s, " %s%p (%d bytes) : %s\n", + string(index.size() * 2, ' '), memory.opaque(), + memory.size(), shape_str); }); return s; } diff --git a/tensorflow/compiler/xla/service/shaped_buffer_test.cc b/tensorflow/compiler/xla/service/shaped_buffer_test.cc index 0fc2436679..d69e6362e9 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer_test.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -34,7 +35,7 @@ TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) { xla::StreamExecutorMemoryAllocator allocator(platform, executors); const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {}); const int kDeviceOrdinal = 0; - auto scoped_buffer = tensorflow::MakeUnique<xla::ScopedShapedBuffer>( + auto scoped_buffer = absl::make_unique<xla::ScopedShapedBuffer>( shape, shape, &allocator, kDeviceOrdinal); std::unique_ptr<xla::ShapedBuffer> buffer = std::move(scoped_buffer); buffer = nullptr; diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc index 8cbaac7b37..dd53c7531b 100644 --- a/tensorflow/compiler/xla/service/source_map_util.cc +++ b/tensorflow/compiler/xla/service/source_map_util.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/source_map_util.h" +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -26,11 +27,10 @@ Status InvalidParameterArgumentV(const OpMetadata& op_metadata, string message; tensorflow::strings::Appendv(&message, format, args); if (!op_metadata.source_file().empty()) { - tensorflow::strings::Appendf(&message, " (%s:%d)", - op_metadata.source_file().c_str(), - op_metadata.source_line()); + absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(), + op_metadata.source_line()); } - return InvalidArgument("%s", message.c_str()); + return InvalidArgument("%s", message); } } // namespace diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h index 18e2651abb..c5a7e17cb4 100644 --- a/tensorflow/compiler/xla/service/source_map_util.h +++ b/tensorflow/compiler/xla/service/source_map_util.h @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ -#define TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_ +#include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/platform/macros.h" @@ -24,23 +25,40 @@ namespace xla { namespace source_map_util { // Creates an INVALID_ARGUMENT status with the given format string. +template <typename... Args> +Status InvalidParameterArgument(const OpMetadata& op_metadata, + const absl::FormatSpec<Args...>& format, + const Args&... 
args) { + string message = absl::StrFormat(format, args...); + if (!op_metadata.source_file().empty()) { + absl::StrAppendFormat(&message, " (%s:%d)", op_metadata.source_file(), + op_metadata.source_line()); + } + return InvalidArgument("%s", message); +} + +// Creates an INVALID_ARGUMENT status with the given format string. // // Also, attempts to extract the OpMetadata for parameter_number on executable // and append it to the status message for source mapping to user code. // // executable may be nullptr, but parameter_number should not be out of bounds // or a CHECK-failure may occur. +template <typename... Args> Status InvalidParameterArgument(Executable* executable, int parameter_number, - const char* format, ...) - TF_PRINTF_ATTRIBUTE(3, 4); - -// As above, but takes the parameter metadata directly instead of extracting it -// from the executable. -Status InvalidParameterArgument(const OpMetadata& op_metadata, - const char* format, ...) - TF_PRINTF_ATTRIBUTE(2, 3); + const absl::FormatSpec<Args...>& format, + const Args&... args) { + if (executable != nullptr && executable->has_module()) { + const HloModule& module = executable->module(); + const HloComputation& computation = *module.entry_computation(); + HloInstruction* param = computation.parameter_instruction(parameter_number); + const OpMetadata& metadata = param->metadata(); + return InvalidParameterArgument(metadata, format, args...); + } + return InvalidArgument(format, args...); +} } // namespace source_map_util } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SOURCE_MAP_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc index c0582c6a2d..5d1cd1c442 100644 --- a/tensorflow/compiler/xla/service/stream_pool.cc +++ b/tensorflow/compiler/xla/service/stream_pool.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/stream_pool.h" -#include "tensorflow/compiler/xla/ptr_util.h" +#include "absl/memory/memory.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -35,7 +35,7 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) { if (!stream) { // Create a new stream. - stream = MakeUnique<se::Stream>(executor); + stream = absl::make_unique<se::Stream>(executor); stream->Init(); VLOG(1) << stream->DebugStreamPointers() << " StreamPool created new stream"; diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 32d368a904..b8d2d546e5 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -18,6 +18,8 @@ limitations under the License. #include <string> #include <utility> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -27,7 +29,7 @@ limitations under the License. 
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/notification.h" -using ::tensorflow::strings::StrCat; +using absl::StrCat; namespace xla { /* static */ tensorflow::mutex @@ -61,7 +63,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferLiteralFromDevice( if (!s.ok()) { return s; } - return MakeUnique<Literal>(std::move(literal)); + return absl::make_unique<Literal>(std::move(literal)); } Status TransferManager::TransferLiteralFromDevice( @@ -120,7 +122,7 @@ StatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice( if (!s.ok()) { return s; } - return MakeUnique<Literal>(std::move(literal)); + return absl::make_unique<Literal>(std::move(literal)); } Status TransferManager::TransferArrayToDevice( @@ -147,7 +149,7 @@ Status TransferManager::TransferArrayToDeviceAsync( if (dest.size() < GetByteSizeRequirement(on_device_shape)) { return FailedPrecondition( "Allocation on device not large enough for array: " - "%lld < %lld", + "%d < %d", dest.size(), GetByteSizeRequirement(on_device_shape)); } ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, @@ -164,12 +166,12 @@ void TransferManager::TransferArrayFromDevice( auto error = StrCat("Shape ", ShapeUtil::HumanString(shape), " has a differently shaped representation on-device: ", ShapeUtil::HumanString(HostShapeToDeviceShape(shape))); - return done(FailedPrecondition("%s", error.c_str())); + return done(FailedPrecondition("%s", error)); } if (source.size() < GetByteSizeRequirement(shape)) { return done( FailedPrecondition("Allocation on device not large enough for array: " - "%lld < %lld", + "%d < %d", source.size(), GetByteSizeRequirement(shape))); } ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, @@ -201,7 +203,7 @@ void TransferManager::TransferArrayFromDevice( return NotFound( "could not find registered transfer manager for platform %s -- check " "target linkage", - platform->Name().c_str()); + platform->Name()); } if 
(it->second.manager == nullptr) { @@ -252,7 +254,7 @@ Status TransferManager::TransferBufferFromDevice( if (source.size() < size) { return FailedPrecondition( "Source allocation on device not large enough for data tranfer: " - "%lld < %lld", + "%d < %d", source.size(), size); } stream->ThenMemcpy(destination, source, size); @@ -265,7 +267,7 @@ Status TransferManager::TransferBufferToDevice( if (destination->size() < size) { return FailedPrecondition( "Destination allocation on device not large enough for data tranfer: " - "%lld < %lld", + "%d < %d", destination->size(), size); } stream->ThenMemcpy(destination, source, size); @@ -276,9 +278,8 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer( const Shape& on_host_shape, DeviceMemoryAllocator* allocator, int device_ordinal) { if (!LayoutUtil::HasLayout(on_host_shape)) { - return InvalidArgument( - "Shape must have a layout: %s", - ShapeUtil::HumanStringWithLayout(on_host_shape).c_str()); + return InvalidArgument("Shape must have a layout: %s", + ShapeUtil::HumanStringWithLayout(on_host_shape)); } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 475a2e5c14..f77690a462 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -152,6 +152,26 @@ class TransferManager { const Shape& on_host_shape, DeviceMemoryAllocator* allocator, int device_ordinal); + // The given ShapedBuffer holds a handle to allocated memory, but it is not + // in the general case legal to immediately copy or access that allocated + // memory because queued operations on the device may alias that memory. 
+ // Memory ordering is enforced by the Stream's happens-before relationship + // which allows eager deallocation and reallocation of buffers host-side even + // if the device hasn't finished with them. + // + // In certain cases, it can be known that a ShapedBuffer does not have any + // conflicting accesses on the device and thus is eligible to be accessed at + // any time from the host. + // + // This function returns true if device_buffer can be accessed immediately + // without waiting for the Stream's previously enqueued items. This only + // returns true if all subbuffers in device_buffer can be accessed + // immediately. + virtual bool CanShapedBufferBeAccessedNow( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const { + return false; + } + ///// // The TransferManager class also serves as a point to register objects for // the various platforms. diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 49e1f87319..530f40e4b2 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -109,6 +109,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) { std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot( dot->shape(), new_lhs, new_rhs, new_dim_numbers); + new_dot->set_precision_config(dot->precision_config()); return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot)); } @@ -178,6 +179,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); + new_conv->set_precision_config(convolution.precision_config()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h index 
71e8446452..3e5aa2db60 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.h +++ b/tensorflow/compiler/xla/service/transpose_folding.h @@ -49,7 +49,7 @@ class TransposeFolding : public HloPassInterface { explicit TransposeFolding( TransposableGemmOperandsFn transposable_gemm_operands, TransposableConvOperandsFn transposable_conv_operands); - tensorflow::StringPiece name() const override { return "transpose-folding"; } + absl::string_view name() const override { return "transpose-folding"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 0447807a41..cf00ca102b 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -19,6 +19,10 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -26,17 +30,13 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { string BufferAlias::ToString() const { - return tensorflow::strings::StrCat("BufferAlias(", instruction_->name(), "[", - tensorflow::str_util::Join(index_, ","), - "])"); + return absl::StrCat("BufferAlias(", instruction_->name(), "[", + absl::StrJoin(index_, ","), "])"); } std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) { @@ -441,7 +441,7 @@ PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet( PerInstruction* pi = PerInst(instruction); CHECK(pi->points_to_set == nullptr) << "instruction should not have been present in the map."; - auto set = MakeUnique<PointsToSet>(&instruction->shape()); + auto set = absl::make_unique<PointsToSet>(&instruction->shape()); pi->points_to_set = std::move(set); // Return *set using the iterator returned by emplace. 
return *pi->points_to_set; @@ -462,21 +462,20 @@ Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const { return FailedPrecondition( "LogicalBuffer %s is ill-defined: instruction %s does not define a " "buffer at that index", - buffer.ToString().c_str(), buffer.instruction()->name().c_str()); + buffer.ToString(), buffer.instruction()->name()); } } if (buffer.id() < 0 || buffer.id() >= logical_buffer_analysis_->num_logical_buffers()) { - return FailedPrecondition( - "LogicalBuffer %s is ill-defined: invalid id %lld", - buffer.ToString().c_str(), buffer.id()); + return FailedPrecondition("LogicalBuffer %s is ill-defined: invalid id %d", + buffer.ToString(), buffer.id()); } if (GetBuffer(buffer.id()).instruction() != buffer.instruction() || GetBuffer(buffer.id()).index() != buffer.index()) { return FailedPrecondition( "LogicalBuffer %s is ill-defined: buffer with same id differs: %s", - buffer.ToString().c_str(), GetBuffer(buffer.id()).ToString().c_str()); + buffer.ToString(), GetBuffer(buffer.id()).ToString()); } return Status::OK(); @@ -495,8 +494,7 @@ StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt( if (buffers.size() != 1 || buffers[0]->instruction() != instruction) { return FailedPrecondition( "instruction %s does not define buffer at index {%s}", - instruction->name().c_str(), - tensorflow::str_util::Join(index, ",").c_str()); + instruction->name(), absl::StrJoin(index, ",")); } return buffers[0]; } @@ -557,13 +555,12 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( } string TuplePointsToAnalysis::ToString() const { - string output = tensorflow::strings::Printf( - "TuplePointsToSet for module %s:\n", module_->name().c_str()); + string output = + absl::StrFormat("TuplePointsToSet for module %s:\n", module_->name()); for (const auto* computation : module_->MakeNonfusionComputations()) { const char* entry = computation == module_->entry_computation() ? 
"entry " : ""; - tensorflow::strings::StrAppend(&output, entry, "computation ", - computation->name(), ":\n"); + absl::StrAppend(&output, entry, "computation ", computation->name(), ":\n"); for (const HloInstruction* instruction : computation->MakeInstructionPostOrder()) { InstructionToString(instruction, &output); @@ -575,12 +572,11 @@ string TuplePointsToAnalysis::ToString() const { } } - tensorflow::strings::StrAppend(&output, "LogicalBuffers:\n"); + absl::StrAppend(&output, "LogicalBuffers:\n"); for (const auto& b : logical_buffer_analysis_->logical_buffers()) { - tensorflow::strings::StrAppend(&output, " buffer ", b->ToString(), ":\n"); + absl::StrAppend(&output, " buffer ", b->ToString(), ":\n"); for (const BufferAlias& alias : logical_buffer_aliases_.at(b->id())) { - tensorflow::strings::StrAppend(&output, " alias ", alias.ToString(), - "\n"); + absl::StrAppend(&output, " alias ", alias.ToString(), "\n"); } } return output; @@ -589,20 +585,18 @@ string TuplePointsToAnalysis::ToString() const { void TuplePointsToAnalysis::InstructionToString( const HloInstruction* instruction, string* output) const { const string prefix = instruction->IsFused() ? 
" " : ""; - tensorflow::strings::StrAppend(output, prefix, " instruction ", - instruction->ToShortString(), ":\n"); + absl::StrAppend(output, prefix, " instruction ", + instruction->ToShortString(), ":\n"); const PointsToSet& points_to_set = GetPointsToSet(instruction); points_to_set.ForEachElement([&prefix, &output]( const ShapeIndex& index, const PointsToSet::BufferList& points_to) { - tensorflow::strings::StrAppend( - output, prefix, " {", tensorflow::str_util::Join(index, ","), "}: ", - tensorflow::str_util::Join( - points_to, ", ", - [](string* out, const LogicalBuffer* source) { - out->append(source->ToString()); - }), - "\n"); + absl::StrAppend(output, prefix, " {", absl::StrJoin(index, ","), "}: ", + absl::StrJoin(points_to, ", ", + [](string* out, const LogicalBuffer* source) { + out->append(source->ToString()); + }), + "\n"); }); } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 686bb05328..62c7bb685d 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -23,6 +23,7 @@ limitations under the License. #include <string> #include <vector> +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -109,7 +110,7 @@ class PointsToSet { // Add a tuple source instruction for the given index. void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple); - using BufferList = tensorflow::gtl::InlinedVector<const LogicalBuffer*, 1>; + using BufferList = absl::InlinedVector<const LogicalBuffer*, 1>; // Return the list of logical buffers for the subshape at index. 
const BufferList& element(const ShapeIndex& index) const { @@ -203,7 +204,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { // logical buffer The buffer alias set is the inverse of the points-to set. // That is, LogicalBuffer B is in the points-to set of instruction I at index // N iff instruction I, index N is a BufferAlias of B. - using BufferAliasVector = tensorflow::gtl::InlinedVector<BufferAlias, 1>; + using BufferAliasVector = absl::InlinedVector<BufferAlias, 1>; const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const; // Returns the number of logical buffers in the module @@ -226,8 +227,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { // instructions produce a single buffer (the top-level buffer), some produce // no buffers (eg bitcast), and some produce more than one buffer (eg, // tuple-shaped parameters). - using BufferDefinitionVector = - tensorflow::gtl::InlinedVector<const LogicalBuffer*, 1>; + using BufferDefinitionVector = absl::InlinedVector<const LogicalBuffer*, 1>; const BufferDefinitionVector& GetBuffersDefinedByInstruction( const HloInstruction* instruction) const; diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h index 7509501883..8c91d6e69d 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -30,7 +30,7 @@ class TupleSimplifier : public HloPassInterface { TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} explicit TupleSimplifier(bool exclude_entry_computation); ~TupleSimplifier() override {} - tensorflow::StringPiece name() const override { return "tuple-simplifier"; } + absl::string_view name() const override { return "tuple-simplifier"; } // Run tuple simplification on the given computation. Returns whether the // computation was changed. 
diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index af2cb6dc2a..7e4ac92a7c 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -18,8 +18,8 @@ limitations under the License. namespace xla { -using tensorflow::gtl::nullopt; -using tensorflow::gtl::optional; +using absl::nullopt; +using absl::optional; // Finds and returns the non-constant operand in instr. // diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h index bf59813e8c..bf497f4892 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.h +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/core/lib/gtl/optional.h" namespace xla { @@ -25,8 +25,8 @@ namespace xla { // nullopt otherwise. max_value_returned limits the number of steps that are // evaluated while trying to brute force a loop trip count, trip counts larger // than max_value_returned result in nullopt. -tensorflow::gtl::optional<int64> ComputeWhileLoopTripCount( - HloInstruction *while_op, int64 max_value_returned = 128); +absl::optional<int64> ComputeWhileLoopTripCount(HloInstruction *while_op, + int64 max_value_returned = 128); } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 62af45128a..aab1180662 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -32,7 +33,7 @@ static Status ReplaceUsesWhileKeepingLoopInvariance( std::vector<HloInstruction*> users; users.reserve(old_instr->user_count()); - c_copy(old_instr->users(), std::back_inserter(users)); + absl::c_copy(old_instr->users(), std::back_inserter(users)); for (auto* user : users) { for (int64 i = 0, e = user->operand_count(); i < e; i++) { @@ -108,10 +109,10 @@ StatusOr<bool> WhileLoopConstantSinking::Run(HloModule* module) { // // This will let us sink the constant into the outer while first and then // into the inner while in a single run of this pass. - c_copy_if(comp->instructions(), std::back_inserter(while_instrs), - [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kWhile; - }); + absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); } for (HloInstruction* while_instr : while_instrs) { diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h index 21fb8568a8..2dba7d7f75 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -54,7 +54,7 @@ class WhileLoopConstantSinking : public HloPassInterface { public: ~WhileLoopConstantSinking() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "while-loop-invariant-code-motion"; } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc 
b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 09ddcffb22..f4098f28b3 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -14,18 +14,19 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" +#include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { +using absl::InlinedVector; using tensorflow::gtl::FlatMap; using tensorflow::gtl::FlatSet; -using tensorflow::gtl::InlinedVector; // Copies `to_hoist` to the computation containing `while_instr`, hoisting its // operands as needed. 
All of its transitive operands are expected to be either @@ -65,8 +66,8 @@ static void CreateLoopInvariantCopy( }; InlinedVector<HloInstruction*, 4> new_operands; - c_transform(old_instruction->operands(), std::back_inserter(new_operands), - get_new_operand); + absl::c_transform(old_instruction->operands(), + std::back_inserter(new_operands), get_new_operand); HloInstruction* new_instruction = parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands( @@ -197,7 +198,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( op->opcode() == HloOpcode::kConstant; }; - if (!c_all_of(instruction->operands(), is_invariant)) { + if (!absl::c_all_of(instruction->operands(), is_invariant)) { continue; } @@ -257,10 +258,10 @@ StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) { bool changed = false; std::vector<HloInstruction*> while_instrs; for (auto* comp : module->computations()) { - c_copy_if(comp->instructions(), std::back_inserter(while_instrs), - [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kWhile; - }); + absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), + [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }); } for (HloInstruction* while_instr : while_instrs) { diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 8e6cc87875..2cdf20ce80 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -38,7 +38,7 @@ class WhileLoopInvariantCodeMotion : public HloPassInterface { : hoist_constants_(hoist_constants) {} ~WhileLoopInvariantCodeMotion() override = default; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "while-loop-invariant-code-motion"; } StatusOr<bool> Run(HloModule* module) 
override; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 32e69c335b..e14014b961 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -28,6 +28,10 @@ namespace op = xla::testing::opcode_matchers; class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase { public: + WhileLoopInvariantCodeMotionTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + // Makes a computation which has one parameter, of the given shape, and always // returns PRED[]{true}. This is useful as a dummy loop condition. HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index dd8697e680..6a7bfe3f12 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -14,17 +14,16 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/optional.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { -using tensorflow::gtl::nullopt; -using tensorflow::gtl::optional; +using absl::optional; // Determines whether the given instruction is a send/recv node, or has a // subcomputation which contains a send/recv node. 
@@ -237,12 +236,11 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { << "Instruction " << user->ToString(print_no_metadata) << " should be unused (except by root of while body), but has " "users: {" - << tensorflow::str_util::Join( - user->users(), ", ", - [&](string* out, const HloInstruction* instr) { - tensorflow::strings::StrAppend( - out, instr->ToString(print_no_metadata)); - }) + << absl::StrJoin(user->users(), ", ", + [&](string* out, const HloInstruction* instr) { + absl::StrAppend( + out, instr->ToString(print_no_metadata)); + }) << "}"; replacements.emplace(user, nullptr); diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 3d3e1d60f2..78024f14dc 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -33,9 +33,7 @@ namespace xla { class WhileLoopSimplifier : public HloPassInterface { public: ~WhileLoopSimplifier() override {} - tensorflow::StringPiece name() const override { - return "simplify-while-loops"; - } + absl::string_view name() const override { return "simplify-while-loops"; } StatusOr<bool> Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index 2e1571943e..cfe4104f6d 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -15,11 +15,12 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/strings/str_util.h" namespace xla { namespace { @@ -27,6 +28,11 @@ namespace { namespace op = xla::testing::opcode_matchers; class WhileLoopSimplifierTest : public HloVerifiedTestBase { + public: + WhileLoopSimplifierTest() + : HloVerifiedTestBase(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/false) {} + protected: // Makes an HloModule that contains a loop with `num_iters` iteration. void MakeModuleWithSimpleLoop(int num_iters); @@ -64,10 +70,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoop(int num_iters) { } )"; - string hlo_string = tensorflow::str_util::StringReplace( - hlo_string_template, "{{LOOP_BOUND}}", - tensorflow::strings::StrCat(42 + num_iters), - /*replace_all=*/true); + string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); ParseAndVerifyModule(hlo_string); } @@ -103,10 +107,8 @@ void WhileLoopSimplifierTest::MakeModuleWithSimpleLoopTupleElementLoopBound( } )"; - string hlo_string = tensorflow::str_util::StringReplace( - hlo_string_template, "{{LOOP_BOUND}}", - tensorflow::strings::StrCat(42 + num_iters), - /*replace_all=*/true); + string hlo_string = absl::StrReplaceAll( + hlo_string_template, {{"{{LOOP_BOUND}}", absl::StrCat(42 + num_iters)}}); ParseAndVerifyModule(hlo_string); } diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index 1ef17b9d7d..e8f76ff745 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -14,15 +14,16 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/while_util.h" +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_creation_utils.h" #include "tensorflow/compiler/xla/service/tuple_util.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace xla { -using tensorflow::strings::StrCat; +using absl::StrCat; static StatusOr<HloComputation*> WidenWhileCondition( HloComputation* narrow_condition, const Shape& wide_shape) { @@ -206,7 +207,7 @@ static StatusOr<HloInstruction*> MakeInitTupleFromInitValues( HloInstruction* zero = computation->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0))); init_values_with_indvar.push_back(zero); - c_copy(init_values, std::back_inserter(init_values_with_indvar)); + absl::c_copy(init_values, std::back_inserter(init_values_with_indvar)); return computation->AddInstruction( HloInstruction::CreateTuple(init_values_with_indvar)); } @@ -215,8 +216,9 @@ static Shape MakeLoopStateShape(const WhileUtil::LoopStateTy& init_values) { std::vector<Shape> loop_state_shape_components; loop_state_shape_components.reserve(init_values.size() + 1); loop_state_shape_components.push_back(ShapeUtil::MakeShape(S32, {})); - c_transform(init_values, std::back_inserter(loop_state_shape_components), - [](HloInstruction* instr) { return instr->shape(); }); + absl::c_transform(init_values, + std::back_inserter(loop_state_shape_components), + [](HloInstruction* instr) { return instr->shape(); }); return ShapeUtil::MakeTupleShape(loop_state_shape_components); } diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 2ccb919acf..5e69419333 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ 
b/tensorflow/compiler/xla/service/while_util_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" @@ -206,7 +207,7 @@ ENTRY main { auto is_while = [](const HloInstruction* instr) { return instr->opcode() == HloOpcode::kWhile; }; - EXPECT_EQ(c_count_if(main->instructions(), is_while), 1); + EXPECT_EQ(absl::c_count_if(main->instructions(), is_while), 1); } } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index 8763e588c4..a7f0e207eb 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -24,7 +24,7 @@ namespace xla { class ZeroSizedHloElimination : public HloPassInterface { public: StatusOr<bool> Run(HloModule* module) override; - tensorflow::StringPiece name() const override { + absl::string_view name() const override { return "zero_sized_hlo_elimination"; } }; |