Diffstat (limited to 'tensorflow/compiler')
360 files changed, 11553 insertions, 4862 deletions
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 7a0932d44d..10fa33ab5e 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -25,6 +25,7 @@ test_suite( ":test_graph_tfmatmul_test", ":test_graph_tfmatmulandadd_test", ":test_graph_tfsplits_test", + ":test_graph_tftop_k_test", ":tfcompile_test", ], ) @@ -42,6 +43,7 @@ py_binary( "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", "//tensorflow/python:platform", "//tensorflow/python:session", "//tensorflow/python:training", @@ -66,6 +68,7 @@ genrule( "test_graph_tfmatmul.pb", "test_graph_tfmatmulandadd.pb", "test_graph_tfsplits.pb", + "test_graph_tftop_k.pb", ], # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any # GPUs which might be present. This is important because builds may run @@ -208,6 +211,17 @@ tf_library( ], ) +tf_library( + name = "test_graph_tftop_k", + testonly = 1, + config = "test_graph_tftop_k.config.pbtxt", + cpp_class = "TopKComp", + graph = "test_graph_tftop_k.pb", + tags = [ + "manual", + ], +) + tf_cc_test( name = "tfcompile_test", srcs = ["tfcompile_test.cc"], @@ -226,6 +240,7 @@ tf_cc_test( ":test_graph_tfmatmulandadd", ":test_graph_tfmatmulandadd_with_profiling", ":test_graph_tfsplits", + ":test_graph_tftop_k", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:xla_data_proto", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 9ec7df163b..64b861a730 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variables from tensorflow.python.platform import app from tensorflow.python.training import saver as saver_lib @@ -46,7 +47,7 @@ def tfadd(_): def tfadd_with_ckpt(out_dir): x = array_ops.placeholder(dtypes.int32, name='x_hold') - y = variables.Variable(constant_op.constant([0]), name='y_saved') + y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') init_op = variables.initialize_all_variables() @@ -61,7 +62,7 @@ def tfadd_with_ckpt(out_dir): def tfadd_with_ckpt_saver(out_dir): x = array_ops.placeholder(dtypes.int32, name='x_hold') - y = variables.Variable(constant_op.constant([0]), name='y_saved') + y = variables.VariableV1(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') init_op = variables.initialize_all_variables() @@ -142,6 +143,12 @@ def tfsplits(_): array_ops.identity(y, name='result') +def tftop_k(_): + x = array_ops.placeholder(dtypes.int32, shape=[5], name='x') + output = nn_ops.top_k(x, 2, name='values') + array_ops.identity(output[1], name='indices') + + def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() @@ -163,6 +170,7 @@ def main(_): write_graph(tfmatmul, FLAGS.out_dir) write_graph(tfmatmulandadd, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir) + write_graph(tftop_k, FLAGS.out_dir) if __name__ == '__main__': diff --git a/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt 
b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt new file mode 100644 index 0000000000..6b4ac2d7cb --- /dev/null +++ b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt @@ -0,0 +1,13 @@ +# Text form of tensorflow.tf2xla.Config proto. +feed { + id { node_name: "x" } + shape { + dim { size: 5 } + } +} +fetch { + id { node_name: "values" } +} +fetch { + id { node_name: "indices" } +} diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 7ac90fb8a9..f10852c785 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" +#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h" #include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -448,6 +449,30 @@ TEST(TFCompileTest, Splits) { EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4); } +TEST(TFCompileTest, TopK) { + Eigen::ThreadPool tp(1); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + TopKComp fn; + + fn.set_thread_pool(&device); + // x = [4, 1, 4, 4, 3] + fn.arg0(0) = 4; + fn.arg0(1) = 1; + fn.arg0(2) = 4; + fn.arg0(3) = 4; + fn.arg0(4) = 3; + + EXPECT_TRUE(fn.Run()); + EXPECT_EQ(fn.error_msg(), ""); + const int32 expected_values[] = {4, 4}; + const int32 expected_indices[] = {0, 2}; + EXPECT_EQ(expected_values[0], fn.result0(0)); + EXPECT_EQ(expected_values[1], fn.result0(1)); + EXPECT_EQ(expected_indices[0], fn.result1(0)); + EXPECT_EQ(expected_indices[1], fn.result1(1)); +} + TEST(TFCompileTest, AssertEqAndReturnDiff) { // Assert is converted into a no-op in XLA, so there is no failure even if the // two args are different. diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 792b7fe14a..859c84bb91 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -273,6 +273,7 @@ def tf_library( "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort", "//tensorflow/compiler/xla/service/cpu:runtime_matmul", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 1001c57f3d..661b444a42 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -26,6 +26,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") # Target that bundles up the XLA CPU and GPU JIT devices. 
cc_library( @@ -50,7 +51,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":jit_compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", ], @@ -62,7 +63,7 @@ cc_library( visibility = ["//visibility:public"], deps = if_cuda([ ":jit_compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", ]), @@ -76,7 +77,7 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/jit/legacy_flags:xla_device_flags", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -94,7 +95,7 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep @@ -111,7 +112,7 @@ cc_library( deps = [ ":jit_compilation_passes", ":xla_device", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep @@ -257,6 +258,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -280,7 +282,7 @@ cc_library( deps = [ ":common", ":compilation_passes", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -322,6 +324,8 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -341,7 +345,7 @@ tf_cc_test( "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", @@ -359,7 +363,7 @@ tf_cc_test( cc_library( name = "compilation_passes", srcs = [ - "build_xla_launch_ops_pass.cc", + "build_xla_ops_pass.cc", "deadness_analysis.cc", "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", @@ -369,7 +373,7 @@ cc_library( "partially_decluster_pass.cc", ], hdrs = [ - "build_xla_launch_ops_pass.h", + "build_xla_ops_pass.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", "encapsulate_xla_computations_pass.h", @@ -382,12 +386,16 @@ cc_library( ":shape_inference_helpers", ":union_find", ":xla_cluster_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope_internal", "//tensorflow/compiler/jit/graphcycles", 
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", @@ -399,6 +407,8 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -459,7 +469,7 @@ tf_cc_test( "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", "//tensorflow/cc:sendrecv_ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", @@ -470,6 +480,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -477,6 +488,7 @@ tf_cc_test( name = "compilation_passes_test", size = "small", srcs = [ + "build_xla_ops_pass_test.cc", "encapsulate_subgraphs_pass_test.cc", "encapsulate_xla_computations_pass_test.cc", "mark_for_compilation_pass_test.cc", @@ -485,6 +497,7 @@ tf_cc_test( deps = [ ":common", ":compilation_passes", + ":node_matchers", ":xla_cluster_util", ":xla_gpu_device", "//tensorflow/cc:cc_ops", @@ -493,7 +506,7 @@ tf_cc_test( "//tensorflow/cc:ops", "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:sendrecv_ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:test_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", @@ -506,6 +519,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/grappler/optimizers/data:graph_utils", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -524,7 +538,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", - "//tensorflow/compiler/jit/kernels:xla_launch_op", + "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu", @@ -628,6 +642,15 @@ tf_cc_test( ], ) +tf_custom_op_py_library( + name = "xla_ops_py", + kernels = ["//tensorflow/compiler/jit/ops:xla_ops"], + visibility = [ + ":friends", + ], + deps = ["//tensorflow/compiler/jit/ops:xla_ops_wrapper_py"], +) + # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc deleted file mode 100644 index b17ff589e2..0000000000 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/tf2xla/dump_graph.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/common_runtime/optimization_registry.h" -#include "tensorflow/core/framework/graph_def_util.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/public/version.h" - -namespace tensorflow { - -static Status BuildLaunchNode( - const string& nodename, const string& function_name, - const AttrValueMap& function_attr, const string& device_name, - const DataTypeVector& constant_dtypes, int num_resources, - const DataTypeVector& arg_dtypes, const DataTypeVector& result_dtypes, - Graph* graph, Node** node) { - NodeDef def; - def.set_name(graph->NewName(nodename)); - def.set_op("XlaLaunch"); - def.set_device(device_name); - AddNodeAttr("Tconstants", constant_dtypes, &def); - AddNodeAttr("Targs", arg_dtypes, &def); - AddNodeAttr("Nresources", num_resources, &def); - AddNodeAttr("Tresults", result_dtypes, &def); - NameAttrList function; - function.set_name(function_name); - *function.mutable_attr() = function_attr; - AddNodeAttr("function", function, &def); - - Status status; - *node = graph->AddNode(def, &status); - return status; -} - -static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) { - VLOG(2) << "Replacing " << node->name() << " with XlaLaunch"; - - int num_constant_args, num_resource_args; - TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, &num_constant_args)); - TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, &num_resource_args)); - - if (num_constant_args < 0 || num_resource_args < 0 || - num_constant_args + num_resource_args > node->num_inputs()) { - return errors::InvalidArgument( - "Invalid number of constant/resource arguments to XLA kernel."); - } - const int num_nonconst_args = - node->num_inputs() - num_constant_args - num_resource_args; - - DataTypeVector const_dtypes(node->input_types().begin(), - node->input_types().begin() + num_constant_args); - DataTypeVector arg_dtypes( - node->input_types().begin() + num_constant_args, - node->input_types().begin() + num_constant_args + num_nonconst_args); - - // Build a XlaLaunch operator to execute the function body. 
- Node* launch_node; - TF_RETURN_IF_ERROR(BuildLaunchNode( - graph->NewName(node->name()), node->type_string(), node->def().attr(), - node->requested_device(), const_dtypes, num_resource_args, arg_dtypes, - node->output_types(), graph, &launch_node)); - launch_node->set_assigned_device_name(node->assigned_device_name()); - - // Copy incoming edges to the launch node. - for (const Edge* edge : node->in_edges()) { - if (edge->IsControlEdge()) { - graph->AddControlEdge(edge->src(), launch_node); - } else { - graph->AddEdge(edge->src(), edge->src_output(), launch_node, - edge->dst_input()); - } - } - - // Copy outgoing edges to the launch node. - std::vector<const Edge*> out_edges(node->out_edges().begin(), - node->out_edges().end()); - for (const Edge* edge : out_edges) { - Node* dst = edge->dst(); - int src_output = edge->src_output(); - int dst_input = edge->dst_input(); - graph->RemoveEdge(edge); - - if (edge->IsControlEdge()) { - graph->AddControlEdge(launch_node, dst); - } else { - graph->AddEdge(launch_node, src_output, dst, dst_input); - } - } - graph->RemoveNode(node); - - return Status::OK(); -} - -Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) { - Graph* graph = options.graph->get(); - - for (Node* n : graph->op_nodes()) { - // In all cases, only try to compile computational nodes. - if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { - continue; - } - - // Only compile nodes that are marked for compilation by the - // compilation-marking pass (via 'attr_name'). - if (IsXlaCompiledKernel(*n)) { - TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n)); - } - } - - if (VLOG_IS_ON(1)) { - dump_graph::DumpGraphToFile("build_xla_launch_ops", *graph, - options.flib_def); - } - return Status::OK(); -} -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc new file mode 100644 index 0000000000..5974696b77 --- /dev/null +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -0,0 +1,162 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/build_xla_ops_pass.h" +#include "absl/algorithm/container.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope_internal.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/graph_def_util.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { +namespace { +void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) { + std::vector<const Edge*> out_edges(old_node->out_edges().begin(), + old_node->out_edges().end()); + for (const Edge* edge : out_edges) { + // TODO(sanjoy): This does not update NodeDef inputs. To be able to update + // NodeDef inputs we first need to fix encapsulate_subgraphs_pass to fix up + // the NodeDef inputs to the function call nodes. + g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input()); + g->RemoveEdge(edge); + } +} + +struct XlaClusterInfo { + std::vector<Output> constant_inputs; + std::vector<Output> non_constant_inputs; + std::vector<Output> resource_inputs; + NameAttrList function; +}; + +Output IncomingEdgeAsOutput(const Edge* e) { + return Output(e->src(), e->src_output()); +} + +Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) { + int num_constant_inputs, num_resource_inputs; + TF_RETURN_IF_ERROR( + GetNodeAttr(n->attrs(), kXlaNumConstantArgsAttr, &num_constant_inputs)); + TF_RETURN_IF_ERROR( + GetNodeAttr(n->attrs(), kXlaNumResourceArgsAttr, &num_resource_inputs)); + + if (num_constant_inputs < 0 || num_resource_inputs < 0 || + num_constant_inputs + num_resource_inputs > n->num_inputs()) { + return errors::InvalidArgument( + "Invalid number of constant/resource arguments to XLA kernel."); + } + + int num_non_constant_inputs = + n->num_inputs() - num_constant_inputs - num_resource_inputs; + + std::vector<const Edge*> input_edges_vector; + TF_RETURN_IF_ERROR(n->input_edges(&input_edges_vector)); + absl::Span<const Edge*> input_edges(input_edges_vector); + + absl::c_transform(input_edges.subspan(0, num_constant_inputs), + std::back_inserter(result->constant_inputs), + IncomingEdgeAsOutput); + + absl::c_transform( + input_edges.subspan(num_constant_inputs, num_non_constant_inputs), + std::back_inserter(result->non_constant_inputs), IncomingEdgeAsOutput); + + absl::c_transform( + input_edges.subspan(num_constant_inputs + num_non_constant_inputs, + num_resource_inputs), + std::back_inserter(result->resource_inputs), IncomingEdgeAsOutput); + + result->function.set_name(n->type_string()); + *result->function.mutable_attr() = n->def().attr(); + return Status::OK(); +} + +Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) { + for (const Edge* e : from->in_edges()) { + if (e->IsControlEdge()) { + g->AddControlEdge(e->src(), to); + } + } + + 
return Status::OK(); +} + +Status ReplaceNodeWithXlaCompileAndXlaRun(Graph* g, Node* n) { + Status status; + Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr) + .NewSubScope(n->name()) + .WithDevice(n->requested_device()) + .WithAssignedDevice(n->assigned_device_name()); + + XlaClusterInfo cluster_info; + TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info)); + + ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"), + /*constants=*/cluster_info.constant_inputs, + /*args=*/cluster_info.non_constant_inputs, + /*resources=*/cluster_info.resource_inputs, + cluster_info.function); + TF_RETURN_IF_ERROR( + CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node())); + + std::vector<Output> xla_run_args = cluster_info.non_constant_inputs; + absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args)); + ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args, + xla_compile.key, n->output_types()); + + MoveOutgoingEdges(g, /*old_node=*/n, + /*new_node=*/xla_run.operation.node()); + g->RemoveNode(n); + + return Status::OK(); +} +} // namespace + +Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { + Graph* graph = options.graph->get(); + + for (Node* n : graph->op_nodes()) { + // In all cases, only try to compile computational nodes. + if (n->IsSend() || n->IsRecv() || n->IsControlFlow()) { + continue; + } + + // Only compile nodes that are marked for compilation by the + // compilation-marking pass (via 'attr_name'). + if (IsXlaCompiledKernel(*n)) { + TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(graph, n)); + } + } + + if (VLOG_IS_ON(1)) { + dump_graph::DumpGraphToFile("build_xla_ops", *graph, options.flib_def); + } + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h index 1dfea93f02..1dd38fa951 100644 --- a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h +++ b/tensorflow/compiler/jit/build_xla_ops_pass.h @@ -13,19 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_ -#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_ +#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -class BuildXlaLaunchOpsPass : public GraphOptimizationPass { +// Adds _XlaCompile and _XlaRun operations to the TF graph that compiles and +// executes (using XLA) TF function calls marked with "_XlaCompiledKernel". +class BuildXlaOpsPass : public GraphOptimizationPass { public: Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_ +#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_OPS_PASS_H_ diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc new file mode 100644 index 0000000000..9d56db7b6b --- /dev/null +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -0,0 +1,138 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/build_xla_ops_pass.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/node_matchers.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +using ::tensorflow::testing::FindNodeByName; +using ::tensorflow::testing::matchers::CtrlDeps; +using ::tensorflow::testing::matchers::NodeWith; +using ::tensorflow::testing::matchers::Op; + +Status BuildXlaOps(const Scope& s, std::unique_ptr<Graph>* result) { + auto graph = absl::make_unique<Graph>(OpRegistry::Global()); + TF_RETURN_IF_ERROR(s.ToGraph(graph.get())); + + // Assign all nodes to the CPU device. + static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : graph->nodes()) { + if (n->assigned_device_name().empty()) { + n->set_assigned_device_name(kCpuDevice); + } + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = &graph; + BuildXlaOpsPass pass; + TF_RETURN_IF_ERROR(pass.Run(opt_options)); + *result = std::move(graph); + return Status::OK(); +} + +Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, + const string& node_name, int num_constant_args, + int num_resource_args, Node** result) { + NodeDef call_node; + call_node.set_name(node_name); + call_node.set_op(callee_name); + AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node); + AddNodeAttr(kXlaNumConstantArgsAttr, num_constant_args, &call_node); + AddNodeAttr(kXlaNumResourceArgsAttr, num_resource_args, &call_node); + Status s; + *result = graph->AddNode(call_node, &s); + return s; +} + +Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, + const string& node_name, Node** result) { + return MakeXlaCompiledKernel(graph, callee_name, node_name, + /*num_constant_args=*/0, /*num_resource_args=*/0, + result); +} + +Node* MakeWrite(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = + ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); + ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle, + value_to_write); + return assign_op.operation.node(); +} + +FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"}, + /*attr_def*/ + {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)}, + /*ret_def=*/{{"out", "out:output:0"}}); + *flib_def.add_function() = std::move(func); + return flib_def; +} + +TEST(BuildXlaOps, 
ControlDepsPreserved) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + Node* call; + TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + Node* write_op = MakeWrite(root, "write"); + root.graph()->AddControlEdge(call, write_op); + + std::unique_ptr<Graph> graph; + TF_ASSERT_OK(BuildXlaOps(root, &graph)); + + Node* write_op_new = FindNodeByName(graph.get(), write_op->name()); + ASSERT_NE(write_op_new, nullptr); + EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun"))))); +} + +TEST(BuildXlaOps, CleanFailureOnBogusAttr) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + Node* call; + TF_ASSERT_OK( + MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", 100, 100, &call)); + Node* write_op = MakeWrite(root, "write"); + root.graph()->AddControlEdge(call, write_op); + + std::unique_ptr<Graph> graph; + Status failure_status = BuildXlaOps(root, &graph); + ASSERT_FALSE(failure_status.ok()); + EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/create_xla_launch_op.cc b/tensorflow/compiler/jit/create_xla_launch_op.cc index 56b034a30b..6f1ff85f24 100644 --- a/tensorflow/compiler/jit/create_xla_launch_op.cc +++ b/tensorflow/compiler/jit/create_xla_launch_op.cc @@ -16,7 +16,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 9128b48da3..b7ae7fbeb3 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -14,11 +14,14 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/jit/deadness_analysis.h" +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" // ALGORITHM OVERVIEW @@ -296,7 +299,7 @@ class SymbolPredicate : public Predicate { template <typename FunctionTy> /*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) { - gtl::FlatSet<Predicate*> visited; + absl::flat_hash_set<Predicate*> visited; std::vector<Predicate*> stack; stack.push_back(p); @@ -383,6 +386,8 @@ class PredicateFactory { } Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and); + Predicate* MakeInternedAndOr(std::vector<Predicate*> simplified_ops, + Predicate::Kind pred_kind); // Predicate instances are interned, meaning that there is only a single // instance of a Predicate object with a given content. This makes checking @@ -417,24 +422,53 @@ class PredicateFactory { } }; - gtl::FlatMap<SignatureForAndOr, std::unique_ptr<Predicate>, - HashSignatureForAndOr> + absl::flat_hash_map<SignatureForAndOr, std::unique_ptr<Predicate>, + HashSignatureForAndOr> interned_and_or_instances_; - gtl::FlatMap<SignatureForNot, std::unique_ptr<Predicate>> + absl::flat_hash_map<SignatureForNot, std::unique_ptr<Predicate>> interned_not_instances_; - gtl::FlatMap<SignatureForAndRec, std::unique_ptr<Predicate>> + absl::flat_hash_map<SignatureForAndRec, std::unique_ptr<Predicate>> interned_and_rec_instances_; - gtl::FlatMap<SignatureForSymbol, std::unique_ptr<Predicate>, - HashSignatureForSymbol> + absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>, + HashSignatureForSymbol> interned_symbol_instances_; }; +Predicate* PredicateFactory::MakeInternedAndOr( + std::vector<Predicate*> simplified_ops, Predicate::Kind pred_kind) { + std::stable_sort( + simplified_ops.begin(), simplified_ops.end(), + [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + + auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); + if (it != interned_and_or_instances_.end()) { + return it->second.get(); + } + + simplified_ops.shrink_to_fit(); + // NB! Because we'll use a non-owning reference to simplified_ops in the + // key for interned_and_or_instances_ we need to be careful to std::move() + // it all the way through. + absl::Span<Predicate* const> operands_slice = simplified_ops; + std::unique_ptr<Predicate> new_pred = + pred_kind == Predicate::Kind::kAnd + ? Make<AndPredicate>(std::move(simplified_ops)) + : Make<OrPredicate>(std::move(simplified_ops)); + + Predicate* new_pred_ptr = new_pred.get(); + interned_and_or_instances_.emplace( + SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred)); + return new_pred_ptr; +} + // Common code to create AndPredicate or OrPredicate instances. Predicate* PredicateFactory::MakeAndOrImpl( absl::Span<Predicate* const> operands, bool is_and) { Predicate::Kind pred_kind = is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr; - gtl::FlatSet<Predicate*> simplified_ops_set; + Predicate::Kind other_pred_kind = + is_and ? 
Predicate::Kind::kOr : Predicate::Kind::kAnd; + absl::flat_hash_set<Predicate*> simplified_ops_set; std::vector<Predicate*> simplified_ops; for (Predicate* op : operands) { // Simplify A&A => A and A|A => A. @@ -459,7 +493,7 @@ Predicate* PredicateFactory::MakeAndOrImpl( } // Simplify "A&~A=>False" and "A|~A=>True". - gtl::FlatSet<Predicate*> negated_ops; + absl::flat_hash_set<Predicate*> negated_ops; for (Predicate* op : simplified_ops) { if (op->kind() == Predicate::Kind::kNot) { negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand()); @@ -472,30 +506,63 @@ Predicate* PredicateFactory::MakeAndOrImpl( } } - std::stable_sort( - simplified_ops.begin(), simplified_ops.end(), - [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); }); + // If all ops contain the same subop, then factor it out thanks to the + // distributive property. Such as: + // - (A & B) | (A & C) | (A & D) => A & (B | C | D) + // - (A | B) & (A | C) & (A | D) => A | (B & C & D) + // + // First find any predicates contained in all subops. + std::vector<Predicate*> common_inner_operands; + absl::flat_hash_set<Predicate*> common_inner_operands_set; + for (Predicate* op : simplified_ops) { + if (op->kind() != other_pred_kind) { + common_inner_operands.clear(); + break; + } - auto it = interned_and_or_instances_.find({pred_kind, simplified_ops}); - if (it == interned_and_or_instances_.end()) { - simplified_ops.shrink_to_fit(); - // NB! Because we'll use a non-owning reference to simplified_ops in the - // key for interned_and_or_instances_ we need to be careful to std::move() - // it all the way through. - absl::Span<Predicate* const> operands_slice = simplified_ops; - std::unique_ptr<Predicate> new_pred = - is_and ? Make<AndPredicate>(std::move(simplified_ops)) - : Make<OrPredicate>(std::move(simplified_ops)); + if (common_inner_operands.empty()) { + common_inner_operands.insert(common_inner_operands.end(), + op->GetOperands().begin(), + op->GetOperands().end()); + } else { + std::vector<Predicate*> sub_ops_intersection; + common_inner_operands.clear(); + absl::c_copy_if(op->GetOperands(), + std::back_inserter(common_inner_operands), + [&](Predicate* sub_op) { + return common_inner_operands_set.count(sub_op) == 1; + }); + } + if (common_inner_operands.empty()) break; + common_inner_operands_set.clear(); + common_inner_operands_set.insert(common_inner_operands.begin(), + common_inner_operands.end()); + } - Predicate* new_pred_ptr = new_pred.get(); - CHECK(interned_and_or_instances_ - .emplace(SignatureForAndOr(pred_kind, operands_slice), - std::move(new_pred)) - .second); - return new_pred_ptr; - } else { - return it->second.get(); + if (common_inner_operands.empty()) { + return MakeInternedAndOr(std::move(simplified_ops), pred_kind); } + + // For all predicates that can be factored out, remove them and recreate the + // subops. 
+ std::vector<Predicate*> factored_ops; + for (Predicate* op : simplified_ops) { + std::vector<Predicate*> new_sub_op_ops; + absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops), + [&](Predicate* sub_op) { + return std::find(common_inner_operands.begin(), + common_inner_operands.end(), + sub_op) == common_inner_operands.end(); + }); + factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and)); + } + + Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and); + std::vector<Predicate*> outer_ops; + outer_ops.push_back(new_inner_op); + outer_ops.insert(outer_ops.end(), common_inner_operands.begin(), + common_inner_operands.end()); + return MakeAndOrImpl(outer_ops, !is_and); } class DeadnessAnalysisImpl : public DeadnessAnalysis { @@ -507,12 +574,14 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo); bool HasInputsWithMismatchingDeadness(const Node& node) override; void Print() const override; - gtl::FlatMap<TensorId, string, TensorId::Hasher> PredicateMapAsString() const; + absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString() + const; private: enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; - std::vector<Predicate*> GetIncomingPreds(Node* n, EdgeKind edge_kind); + Status GetInputPreds(Node* n, EdgeKind edge_kind, + std::vector<Predicate*>* result); // Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th // bit of `should_revisit` if `pred` is different from the current predicate @@ -549,7 +618,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { Status HandleNode(Node* n, std::vector<bool>* should_revisit); const Graph& graph_; - gtl::FlatMap<TensorId, Predicate*, TensorId::Hasher> predicate_map_; + absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_; PredicateFactory predicate_factory_; bool vlog_; }; @@ -558,9 +627,10 @@ TensorId InputEdgeToTensorId(const Edge* e) { return TensorId(e->src()->name(), e->src_output()); } -std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds( - Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) { - std::vector<Predicate*> incoming_preds; +Status DeadnessAnalysisImpl::GetInputPreds( + Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind, + std::vector<Predicate*>* result) { + result->clear(); for (const Edge* in_edge : n->in_edges()) { bool should_process = edge_kind == EdgeKind::kDataAndControl || @@ -569,17 +639,27 @@ std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds( if (should_process) { auto it = predicate_map_.find(InputEdgeToTensorId(in_edge)); - CHECK(it != predicate_map_.end()) << n->name(); - incoming_preds.push_back(it->second); + if (it == predicate_map_.end()) { + GraphCycles graph_cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles)); + + // If we didn't return with an error above then the graph is probably + // fine and we have a bug in deadness analysis. + return errors::Internal("Could not find input ", in_edge->DebugString(), + " to ", n->name(), + " when visiting the graph in post-order. 
Most " + "likely indicates a bug in deadness analysis."); + } + result->push_back(it->second); } } - return incoming_preds; + return Status::OK(); } Status DeadnessAnalysisImpl::HandleSwitch(Node* n, std::vector<bool>* should_revisit) { - std::vector<Predicate*> input_preds = - GetIncomingPreds(n, EdgeKind::kDataAndControl); + std::vector<Predicate*> input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); const Edge* pred_edge; TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge)); Predicate* true_switch = predicate_factory_.MakeSymbolPredicate( @@ -608,17 +688,31 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n, } namespace { -const Edge* FindUniqueBackedge(Node* merge) { +Status CreateMultipleNextIterationInputsError(Node* merge) { + std::vector<string> backedges; + for (const Edge* backedge : merge->in_edges()) { + if (backedge->src()->IsNextIteration()) { + backedges.push_back(absl::StrCat(" ", SummarizeNode(*backedge->src()))); + } + } + return errors::InvalidArgument( + "Multiple NextIteration inputs to merge node ", SummarizeNode(*merge), + ": \n", absl::StrJoin(backedges, "\n"), + "\nMerge nodes can have at most one incoming NextIteration edge."); +} + +Status FindUniqueBackedge(Node* merge, const Edge** result) { + *result = nullptr; CHECK(merge->IsMerge()); - const Edge* result = nullptr; for (const Edge* e : merge->in_edges()) { if (e->src()->IsNextIteration()) { - CHECK_EQ(result, nullptr) - << "Multiple backedges to " << merge->DebugString(); - result = e; + if (*result != nullptr) { + return CreateMultipleNextIterationInputsError(merge); + } + *result = e; } } - return result; + return Status::OK(); } // If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step @@ -697,9 +791,12 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, return Status::OK(); } + std::vector<Predicate*> input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds)); + // We're visiting this merge for the first time and it is a acyclic merge. - Predicate* input_data_pred = predicate_factory_.MakeOrPredicate( - GetIncomingPreds(n, EdgeKind::kDataOnly)); + Predicate* input_data_pred = + predicate_factory_.MakeOrPredicate(input_preds); SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); return Status::OK(); @@ -710,7 +807,9 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, // of an unvisited backedge. Try to pattern match the predicate expression // for that backedge (which should be visited now) into an and recurrence // for the merge node. - if (const Edge* unique_backedge = FindUniqueBackedge(n)) { + const Edge* unique_backedge; + TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge)); + if (unique_backedge) { if (Predicate* step = DeduceStepPredicate( &predicate_factory_, it->second, predicate_map_[InputEdgeToTensorId(unique_backedge)])) { @@ -741,8 +840,8 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, std::vector<bool>* should_revisit) { // In addition to being alive or dead based on the inputs, a _Recv can also // acquire a dead signal from a _Send. 
- std::vector<Predicate*> input_preds = - GetIncomingPreds(n, EdgeKind::kDataAndControl); + std::vector<Predicate*> input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); input_preds.push_back(predicate_factory_.MakeSymbolPredicate( TensorId(n->name(), 0), /*must_be_true=*/false)); SetPredicate(n, {0, Graph::kControlSlot}, @@ -754,8 +853,9 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, Status DeadnessAnalysisImpl::HandleGeneric(Node* n, std::vector<bool>* should_revisit) { // Generally nodes are alive iff all their inputs are alive. - Predicate* pred = predicate_factory_.MakeAndPredicate( - GetIncomingPreds(n, EdgeKind::kDataAndControl)); + std::vector<Predicate*> input_preds; + TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); + Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds); for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) { SetPredicate(n, output_idx, pred, should_revisit); } @@ -912,9 +1012,9 @@ DeadnessAnalysis::~DeadnessAnalysis() {} return Status::OK(); } -gtl::FlatMap<TensorId, string, TensorId::Hasher> +absl::flat_hash_map<TensorId, string, TensorId::Hasher> DeadnessAnalysisImpl::PredicateMapAsString() const { - gtl::FlatMap<TensorId, string, TensorId::Hasher> result; + absl::flat_hash_map<TensorId, string, TensorId::Hasher> result; std::vector<TensorId> tensor_ids; for (const auto& kv_pair : predicate_map_) { CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second); diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h index 3df2679c62..354782374a 100644 --- a/tensorflow/compiler/jit/deadness_analysis_internal.h +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -16,15 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ #define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { namespace deadness_analysis_internal { // Returns a map describing the predicate each Tensor was mapped to. For // testing purposes only. -using PredicateMapTy = gtl::FlatMap<TensorId, string, TensorId::Hasher>; +using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>; Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); // Returns a map describing the predicate each Tensor was mapped to. For diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 28a56044d5..617e31488c 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -384,10 +384,31 @@ TEST(DeadnessAnalysisTest, OrOfAnd) { EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node())); } -TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) { - // This demonstrates one of the weaknesses in the current approach -- since we - // only do some basic simplifications we can't see that "(A|B)&C" == - // "(A&C)|(B&C)". 
+TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) { + // (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) == #true + Scope root = Scope::NewRootScope().ExitOnError(); + + ops::Switch sw_0 = CreateSwitch(root, "A"); + ops::Switch sw_1 = CreateSwitch(root, "B"); + Output add0 = + ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true); + Output add1 = + ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false); + ops::Merge or2(root.WithOpName("or2"), {add0, add1}); + Output add3 = + ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false); + ops::Merge or4(root.WithOpName("or4"), {add3, sw_0.output_true}); + + std::unique_ptr<DeadnessAnalysis> result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true"); +} + +TEST(DeadnessAnalysisTest, AndOrDistributive) { + // (A|B)&C == (A&C)|(B&C) Scope root = Scope::NewRootScope().ExitOnError(); ops::Switch sw_0 = CreateSwitch(root, "0"); @@ -408,7 +429,7 @@ TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) { std::unique_ptr<DeadnessAnalysis> result; TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); - EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node())); + EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node())); } TEST(DeadnessAnalysisTest, Ternary) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index e0632ff7e4..da27f837e8 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,6 +22,7 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "absl/container/flat_hash_set.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/public/session_options.h" @@ -78,7 +78,8 @@ void SortControlInputs(GraphDef* gdef) { namespace { bool AreAllParentsGuaranteedConst( - const Node& n, const gtl::FlatSet<const Node*>& runtime_const_nodes) { + const Node& n, + const absl::flat_hash_set<const Node*>& runtime_const_nodes) { if (n.type_string() == "GuaranteeConst") { // If the current node is itself a cast-to-const, no need // to look at the incoming edges. @@ -101,7 +102,7 @@ bool AreAllParentsGuaranteedConst( void MarkGuaranteedConstants( const Graph& graph, const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) { - gtl::FlatSet<const Node*> guaranteed_const_nodes; + absl::flat_hash_set<const Node*> guaranteed_const_nodes; std::vector<const Node*> srcs; srcs.reserve(src_arg_pairs.size()); for (const auto& src_arg : src_arg_pairs) { @@ -748,6 +749,12 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { graph_->set_versions(graph_in->versions()); } + // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is + // determined. 
In case of hard placement, ensure all the encapsulated nodes + // have the same requested device, which in turn will be the requested device + // for the entire encapsulated subgraph. In case of soft placement, use a + // deterministic approach to fill in the requested device. Handle co-location + // constraints similarly if they exist. if (device_.empty()) { device_ = node->assigned_device_name().empty() ? node->requested_device() @@ -1357,28 +1364,31 @@ void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames( Status Encapsulator::GetFunctionNameAttr( Node const* node, string* attr, string* outside_compilation_attr) const { - Status s = GetNodeAttr(node->attrs(), group_attribute_, attr); - if (s.code() == error::Code::NOT_FOUND) { - // Return empty attr if there's no group_attribute. - attr->clear(); - } else { - TF_RETURN_IF_ERROR(s); - } - bool has_group_attr = s.ok(); - s = GetNodeAttr(node->attrs(), outside_compilation_attribute_, - outside_compilation_attr); - if (s.code() == error::Code::NOT_FOUND) { - // Return empty attr if there's no outside_compilation attribute. - outside_compilation_attr->clear(); - } else { - TF_RETURN_IF_ERROR(s); - if (!has_group_attr) { - return errors::InvalidArgument( - "Node ", node->name(), " has ", outside_compilation_attribute_, - " attribute but no ", group_attribute_, " attribute."); + AttrSlice attrs = node->attrs(); + attr->clear(); + outside_compilation_attr->clear(); + bool found_group_attribute = false; + bool found_outside_compilation_attribute = false; + for (const auto& node_attr : attrs) { + if (node_attr.first == group_attribute_) { + TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string")); + *attr = node_attr.second.s(); + found_group_attribute = true; + } else if (node_attr.first == outside_compilation_attribute_) { + TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string")); + *outside_compilation_attr = node_attr.second.s(); + found_outside_compilation_attribute = true; } + if (found_group_attribute && found_outside_compilation_attribute) break; + } + + if (found_outside_compilation_attribute && !found_group_attribute) { + return errors::InvalidArgument( + "Node ", node->name(), " has ", outside_compilation_attribute_, + " attribute but no ", group_attribute_, " attribute."); + } else { + return Status::OK(); } - return Status::OK(); } bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) { diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index 97ef8cd3cb..2ce6fa73fc 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -15,13 +15,13 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -62,7 +62,7 @@ DataType EdgeType(const Edge* edge) { } // Adds the control inputs of `node` to `*deps`. 
-void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) { +void AddControlInputs(const Node& node, absl::flat_hash_set<Node*>* deps) { for (const Edge* edge : node.in_edges()) { if (edge->IsControlEdge()) { deps->insert(edge->src()); @@ -71,7 +71,7 @@ void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) { } // Adds the control outputs of `node` to `*deps`. -void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) { +void AddControlOutputs(const Node& node, absl::flat_hash_set<Node*>* deps) { for (const Edge* edge : node.out_edges()) { if (edge->IsControlEdge()) { deps->insert(edge->dst()); @@ -246,7 +246,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors, // Data and control inputs to the new XlaLaunch node. std::vector<std::pair<Node*, int>> data_inputs(num_inputs); - gtl::FlatSet<Node*> control_inputs; + absl::flat_hash_set<Node*> control_inputs; DataTypeVector arg_types(num_args); AddControlInputs(*launch, &control_inputs); @@ -266,7 +266,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors, // Outputs. const int num_outputs = launch->output_types().size(); - gtl::FlatSet<Node*> control_outputs; + absl::flat_hash_set<Node*> control_outputs; std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs); DataTypeVector output_types(num_outputs); @@ -297,7 +297,9 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors, // Target the XLA CPU/GPU backends. VLOG(2) << "Replacing with XlaLaunch"; + VLOG(2) << "Device is " << launch->requested_device(); def.set_op("XlaLaunch"); + def.set_device(launch->requested_device()); AddNodeAttr("Tconstants", DataTypeVector{}, &def); AddNodeAttr("Targs", arg_types, &def); AddNodeAttr("Nresources", num_variables, &def); diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc index f643fb0cfe..22531a4ace 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" -#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h" +#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h" #include "tensorflow/compiler/tf2xla/test_util.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -55,6 +55,7 @@ static std::unique_ptr<Graph> MakeOuterGraph( .Input(u.node()->name(), 0, DT_RESOURCE) .Input(v.node()->name(), 0, DT_RESOURCE) .Input(w.node()->name(), 0, DT_RESOURCE) + .Device("/gpu:0") .Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0") .Attr("_variable_start_index", 4) .Finalize(&def)); @@ -107,10 +108,11 @@ static std::unique_ptr<Graph> MakeBodyGraph() { auto add_attrs = [](Node* node) { node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->set_requested_device("/gpu:0"); }; auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1); - + add_attrs(b_identity.node()); auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT); add_attrs(read_u.node()); auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT); @@ -215,6 +217,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) { auto add_attrs = [](Node* node) { node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0"); + node->set_requested_device("/gpu:0"); }; auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b); @@ -317,8 +320,8 @@ TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) { NameAttrList function; function.set_name("launch0"); auto launch = ops::XlaLaunch( - scope.WithOpName("launch0"), std::initializer_list<Input>{}, - std::initializer_list<Input>{a, b, c, d}, + scope.WithOpName("launch0").WithDevice("/gpu:0"), + std::initializer_list<Input>{}, std::initializer_list<Input>{a, b, c, d}, std::initializer_list<Input>{u, v, w}, DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function); diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 3770eea6d0..085c0e5adb 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" +#include "tensorflow/compiler/jit/build_xla_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" @@ -55,6 +55,6 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, // Must run after EncapsulateSubgraphsPass. 
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, - BuildXlaLaunchOpsPass); + BuildXlaOpsPass); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 253a5d2547..26cb3af9d6 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -7,9 +7,9 @@ package( ) cc_library( - name = "xla_launch_op", - srcs = ["xla_launch_op.cc"], - hdrs = ["xla_launch_op.h"], + name = "xla_ops", + srcs = ["xla_ops.cc"], + hdrs = ["xla_ops.h"], deps = [ "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:xla_compilation_cache", @@ -26,6 +26,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/kernels:variable_ops", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc deleted file mode 100644 index b6f2f632f7..0000000000 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ /dev/null @@ -1,276 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" - -#include "tensorflow/compiler/jit/defs.h" -#include "tensorflow/compiler/jit/xla_launch_util.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/common_runtime/function.h" -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/node_def_util.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/kernels/variable_ops.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/core/util/stream_executor_util.h" - -namespace tensorflow { - -XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, - const std::vector<int>& constants, - const std::vector<int>& resources, - const NameAttrList& function) - : OpKernel(ctx), - constants_(constants), - resources_(resources), - device_type_(ctx->device_type()), - function_(function) { - if (device_type_ == DeviceType(DEVICE_CPU)) { - platform_id_ = se::host::kHostPlatformId; - } else if (device_type_ == DeviceType(DEVICE_GPU)) { - platform_id_ = ctx->device() - ->tensorflow_gpu_device_info() - ->stream->parent() 
- ->platform() - ->id(); - } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata_).ok()) { - use_multiple_streams_ = xla_device_metadata_->UseMultipleStreams(); - platform_id_ = xla_device_metadata_->platform()->id(); - } -} - -Status XlaLocalLaunchBase::BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache) { - if (xla_device_metadata_) { - *cache = new XlaCompilationCache(xla_device_metadata_->client(), - xla_device_metadata_->jit_device_type()); - return Status::OK(); - } - - auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_); - if (!platform.ok()) { - return platform.status(); - } - xla::LocalClientOptions client_options; - client_options.set_platform(platform.ValueOrDie()); - client_options.set_intra_op_parallelism_threads( - ctx->device()->tensorflow_cpu_worker_threads()->num_threads); - auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); - if (!client.ok()) { - return client.status(); - } - const XlaOpRegistry::DeviceRegistration* registration; - if (!XlaOpRegistry::GetCompilationDevice(device_type_.type(), - ®istration)) { - return errors::InvalidArgument("No JIT device registered for ", - device_type_.type()); - } - *cache = new XlaCompilationCache( - client.ValueOrDie(), DeviceType(registration->compilation_device_name)); - return Status::OK(); -} - -void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { - VLOG(1) << "XlaLocalLaunchOpBase::Compute " - << Canonicalize(function_.name(), AttrSlice(&function_.attr())); - // We store information about the JIT-compiled XLA computation - // in the ResourceMgr. - ResourceMgr* rm = ctx->resource_manager(); - OP_REQUIRES(ctx, rm, errors::Internal("No resource manager.")); - - se::Stream* stream = - ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; - - XlaCompilationCache* cache; - OP_REQUIRES_OK(ctx, rm->LookupOrCreate<XlaCompilationCache>( - rm->default_container(), "xla_cache", &cache, - [this, ctx](XlaCompilationCache** cache) { - return BuildCompilationCache(ctx, cache); - })); - // Hold the reference to the JIT during evaluation. (We could probably - // free it sooner because the ResourceMgr will retain a reference, but - // this is more obviously correct.) - core::ScopedUnref cache_ref(cache); - - std::map<int, OptionalTensor> variables = - SnapshotResourceVariables(ctx, resources_); - - xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client()); - - XlaAllocator local_xla_allocator(client->backend().platform(), - ctx->device()->GetAllocator({})); - xla::DeviceMemoryAllocator* xla_allocator; - // If we are on an XlaDevice, use the underlying XLA platform's allocator - // directly. We could use the StreamExecutor's allocator which may - // theoretically be more correct, but XLA returns a nice OOM message in a - // Status and StreamExecutor does not. - // - // Importantly we can't use ctx->device()->GetAllocator() as the allocator - // (which local_xla_allocator above uses) as on an XlaDevice, this is a - // dummy allocator that returns XlaTensor objects. The XlaCompiler needs a - // real allocator to allocate real buffers. 
- if (xla_device_metadata_) { - xla_allocator = client->backend().memory_allocator(); - } else { - xla_allocator = &local_xla_allocator; - } - - XlaCompiler::Options options; - options.client = client; - if (ctx->op_device_context() != nullptr) { - options.device_ordinal = - ctx->op_device_context()->stream()->parent()->device_ordinal(); - } - options.device_type = cache->device_type(); - options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); - options.graph_def_version = ctx->function_library()->graph_def_version(); - options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); - options.device_allocator = xla_allocator; - if (xla_device_metadata_) { - options.shape_representation_fn = - xla_device_metadata_->shape_representation_fn(); - } - - const XlaCompiler::CompilationResult* kernel; - xla::LocalExecutable* executable; - - std::map<int, Tensor> constant_args; - for (int i : constants_) { - constant_args.insert({i, ctx->input(i)}); - } - XlaCompiler::CompileOptions compile_options; - compile_options.is_entry_computation = true; - // If we resolve constants we never emit them on the device, meaning that if - // they are needed by a following computation the host has to transfer - // them. Not resolving constants is expected to be faster than resolving - // constants. - compile_options.resolve_compile_time_constants = true; - // Optimization: where possible, have the computation return a naked array - // rather than a one-element tuple. - compile_options.always_return_tuple = false; - - OP_REQUIRES_OK( - ctx, cache->Compile(options, function_, constant_args, variables, ctx, - &kernel, &executable, compile_options)); - - VLOG(1) << "Executing XLA Computation..."; - - XlaComputationLaunchContext launch_context( - client, xla_allocator, - /*allocate_xla_tensors=*/xla_device_metadata_ != nullptr, - use_multiple_streams_); - launch_context.PopulateInputs(ctx, kernel, variables); - - // Execute the computation. - VLOG(2) << "Executing computation."; - xla::ExecutableRunOptions run_options; - run_options.set_stream(stream); - run_options.set_allocator(xla_allocator); - run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); - run_options.set_rng_seed(GetXLARandomSeed()); - Env* env = Env::Default(); - auto start_time = env->NowMicros(); - - auto run_result = executable->Run(launch_context.arguments(), run_options); - OP_REQUIRES(ctx, run_result.ok(), run_result.status()); - - auto elapsed = env->NowMicros() - start_time; - VLOG(2) << "Elapsed time: " << elapsed << "us"; - - OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( - ctx, kernel, run_result.ConsumeValueOrDie())); - VLOG(1) << "Done"; -} - -namespace { - -// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that -// in error case, it returns RET instead of void. -#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ - do { \ - ::tensorflow::Status _s(__VA_ARGS__); \ - if (!TF_PREDICT_TRUE(_s.ok())) { \ - (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ - return RET; \ - } \ - } while (0) - -// Helper static functions to construct parameters for -// XlaLocalLaunchBase constructor from OpKernelConstruction. 
-std::vector<int> ConstantsVector(OpKernelConstruction* ctx) { - DataTypeVector constant_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), - ctx->GetAttr("Tconstants", &constant_types)); - std::vector<int> constants(constant_types.size()); - std::iota(constants.begin(), constants.end(), 0); - return constants; -} - -std::vector<int> ResourcesVector(OpKernelConstruction* ctx) { - DataTypeVector constant_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), - ctx->GetAttr("Tconstants", &constant_types)); - - DataTypeVector arg_types; - OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), - ctx->GetAttr("Targs", &arg_types)); - - int num_resources; - OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), - ctx->GetAttr("Nresources", &num_resources)); - - std::vector<int> resources(num_resources); - std::iota(resources.begin(), resources.end(), - constant_types.size() + arg_types.size()); - return resources; -} - -NameAttrList FunctionAttr(OpKernelConstruction* ctx) { - const NameAttrList* func; - OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); - return *func; -} - -#undef OP_REQUIRES_OK_RETURN -} // namespace - -XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) - : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), - FunctionAttr(ctx)) {} - -XlaLocalLaunchOp::~XlaLocalLaunchOp() { - VLOG(1) << "XlaLocalLaunchOp destroyed"; -} - -REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); - -REGISTER_KERNEL_BUILDER(Name("XlaLaunch") - .Device(DEVICE_GPU) - .HostMemory("constants") - .HostMemory("resources"), - XlaLocalLaunchOp); - -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h deleted file mode 100644 index e0f10e9817..0000000000 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.h +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ -#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ - -#include "tensorflow/compiler/jit/xla_compilation_cache.h" -#include "tensorflow/compiler/jit/xla_device.h" -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/util/stream_executor_util.h" - -namespace tensorflow { - -// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp. -// The only difference is that it does not require arguments to follow -// the "constants, then regular args, then resources" order. -// It takes vectors of constant and resource arguments explicitly. -// It does not have corresponding OpDef because it is never present -// in the GraphDef. -// Currently, it is used by eager runtime. 
FunctionLibraryRuntime creates -// this kernel when asked to create a kernel for an XLA-compiled function. -class XlaLocalLaunchBase : public OpKernel { - public: - XlaLocalLaunchBase(OpKernelConstruction* ctx, - const std::vector<int>& constants, - const std::vector<int>& resources, - const NameAttrList& function); - XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete; - XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete; - ~XlaLocalLaunchBase() override = default; - - void Compute(OpKernelContext* ctx) override; - - protected: - // Builds a XlaCompilationCache class suitable for the current device. - Status BuildCompilationCache(OpKernelContext* ctx, - XlaCompilationCache** cache); - - // Indexes of compile-time constant inputs - std::vector<int> constants_; - // Indexes of resource inputs - std::vector<int> resources_; - - DeviceType device_type_; - NameAttrList function_; - se::Platform::Id platform_id_ = nullptr; - bool use_multiple_streams_ = false; - const XlaDevice::Metadata* xla_device_metadata_ = nullptr; -}; - -// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph -// which will be compiled and executed using XLA. The XlaLocalLaunchOp is -// responsible for handling interactions with the TensorFlow executor. -// Once all inputs are present, and their shapes are known, the op can -// use a 'XlaCompilationCache' to compile and execute code which is specific -// to the shapes of input Tensors. -// XlaLocalLaunchOp uses xla::LocalClient::Compile() and -// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device -// memory. -class XlaLocalLaunchOp : public XlaLocalLaunchBase { - public: - explicit XlaLocalLaunchOp(OpKernelConstruction* ctx); - ~XlaLocalLaunchOp() override; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_LAUNCH_OP_H_ diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc new file mode 100644 index 0000000000..accc86a86d --- /dev/null +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -0,0 +1,500 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/kernels/xla_ops.h" + +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/variable_ops.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace tensorflow { + +namespace { + +Status PlatformInfoFromContext(OpKernelConstruction* ctx, + XlaPlatformInfo* result) { + DeviceType device_type = ctx->device_type(); + se::Platform::Id platform_id = nullptr; + const XlaDevice::Metadata* xla_device_metadata = nullptr; + std::unique_ptr<XlaAllocator> xla_allocator; + xla::DeviceMemoryAllocator* device_allocator = nullptr; + + if (ctx->device_type() == DeviceType(DEVICE_CPU)) { + platform_id = se::host::kHostPlatformId; + } else if (ctx->device_type() == DeviceType(DEVICE_GPU)) { + platform_id = ctx->device() + ->tensorflow_gpu_device_info() + ->stream->parent() + ->platform() + ->id(); + } else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) { + // If we are on an XlaDevice, use the underlying XLA platform's allocator + // directly. We could use the StreamExecutor's allocator which may + // theoretically be more correct, but XLA returns a nice OOM message in a + // Status and StreamExecutor does not. + // + // Importantly we can't use ctx->device()->GetAllocator() as the allocator + // (which xla_allocator above uses) as on an XlaDevice, this is a dummy + // allocator that returns XlaTensor objects. The XlaCompiler needs a real + // allocator to allocate real buffers. + + platform_id = xla_device_metadata->platform()->id(); + device_allocator = + xla_device_metadata->client()->backend().memory_allocator(); + } + + if (!device_allocator) { + TF_ASSIGN_OR_RETURN(se::Platform* const platform, + se::MultiPlatformManager::PlatformWithId(platform_id)); + xla_allocator = absl::make_unique<XlaAllocator>( + platform, ctx->device()->GetAllocator({})); + } + + *result = XlaPlatformInfo(device_type, platform_id, xla_device_metadata, + std::move(xla_allocator), device_allocator); + + return Status::OK(); +} + +// A closure describing how to run a compiled version of a TensorFlow function. +// +// It may seem unusual to stick the resource variable snapshots in this class. +// This is necessary: we need to use the snapshots observed by the compiler as +// the initial values for the resource variables (and cannot snapshot them again +// during execution) because otherwise we risk observing a different snapshot +// with shapes different from what we compiled for. 
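The closure/key handoff implemented by the next two classes is easiest to see stripped of TensorFlow types. A minimal plain-C++ sketch of the same produce/consume pattern (simplified and hypothetical; the real XlaExecutableClosureStore below uses tensorflow::mutex, absl::flat_hash_map and DCHECKs, and stores the compiled executable together with the resource-variable snapshots rather than a string payload):

    #include <map>
    #include <mutex>
    #include <string>
    #include <utility>

    // Stand-in for what XlaExecutableClosure carries: the compiled executable
    // plus the argument/variable snapshots taken at compile time.
    struct CompiledArtifact {
      std::string payload;
    };

    class ArtifactStore {
     public:
      // Stores the artifact under a fresh string key; the compile step hands
      // this key to the run step (in the real code, as a DT_STRING tensor).
      std::string Produce(CompiledArtifact artifact) {
        std::lock_guard<std::mutex> lock(mu_);
        std::string key = std::to_string(next_key_++);
        artifacts_.emplace(key, std::move(artifact));
        return key;
      }

      // Claims and erases the artifact; each key is consumed exactly once.
      CompiledArtifact Consume(const std::string& key) {
        std::lock_guard<std::mutex> lock(mu_);
        auto it = artifacts_.find(key);
        if (it == artifacts_.end()) return CompiledArtifact();  // real code DCHECKs
        CompiledArtifact value = std::move(it->second);
        artifacts_.erase(it);
        return value;
      }

     private:
      std::mutex mu_;
      long long next_key_ = 0;
      std::map<std::string, CompiledArtifact> artifacts_;
    };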
+class XlaExecutableClosure { + public: + explicit XlaExecutableClosure( + xla::LocalClient* client, xla::LocalExecutable* executable, + const XlaCompiler::CompilationResult* compilation_result, + std::map<int, OptionalTensor> resource_var_snapshots, + int num_constant_args) + : client_(client), + executable_(executable), + compilation_result_(compilation_result), + resource_var_snapshots_(std::move(resource_var_snapshots)), + num_constant_args_(num_constant_args) {} + + XlaExecutableClosure(XlaExecutableClosure&&) = default; + XlaExecutableClosure& operator=(XlaExecutableClosure&&) = default; + + xla::LocalClient* client() const { return client_; } + xla::LocalExecutable* executable() const { return executable_; } + const XlaCompiler::CompilationResult* compilation_result() const { + return compilation_result_; + } + const std::map<int, OptionalTensor>& resource_var_snapshots() const { + return resource_var_snapshots_; + } + int num_constant_args() const { return num_constant_args_; } + + private: + xla::LocalClient* client_; + xla::LocalExecutable* executable_; + const XlaCompiler::CompilationResult* compilation_result_; + std::map<int, OptionalTensor> resource_var_snapshots_; + int num_constant_args_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure); +}; + +// This maintains a mapping from a globally unique ID to XlaExecutableClosure +// instances. +class XlaExecutableClosureStore { + public: + XlaExecutableClosureStore() : key_counter_(0) {} + + using KeyT = string; + + KeyT Produce(XlaExecutableClosure result) { + mutex_lock l(mutex_); + KeyT key = absl::StrCat(key_counter_++); + bool insert_successful = closures_.emplace(key, std::move(result)).second; + DCHECK(insert_successful); + (void)insert_successful; + return key; + } + + XlaExecutableClosure Consume(const KeyT& key) { + mutex_lock l(mutex_); + auto it = closures_.find(key); + DCHECK(it != closures_.end()); + XlaExecutableClosure value = std::move(it->second); + closures_.erase(it); + return value; + } + + static XlaExecutableClosureStore* Global() { + static XlaExecutableClosureStore* instance = new XlaExecutableClosureStore; + return instance; + } + + private: + mutex mutex_; + int64 key_counter_ GUARDED_BY(mutex_); + absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_); + + TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore); +}; + +} // namespace + +XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx, + const std::vector<int>& constants, + const std::vector<int>& resources, + const NameAttrList& function) + : OpKernel(ctx), + constants_(constants), + resources_(resources), + function_(function) { + OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); +} + +static Status BuildCompilationCache(OpKernelContext* ctx, + const XlaPlatformInfo& platform_info, + XlaCompilationCache** cache) { + if (platform_info.xla_device_metadata()) { + *cache = new XlaCompilationCache( + platform_info.xla_device_metadata()->client(), + platform_info.xla_device_metadata()->jit_device_type()); + return Status::OK(); + } + + auto platform = + se::MultiPlatformManager::PlatformWithId(platform_info.platform_id()); + if (!platform.ok()) { + return platform.status(); + } + xla::LocalClientOptions client_options; + client_options.set_platform(platform.ValueOrDie()); + client_options.set_intra_op_parallelism_threads( + ctx->device()->tensorflow_cpu_worker_threads()->num_threads); + auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options); + if (!client.ok()) { + return 
client.status(); + } + const XlaOpRegistry::DeviceRegistration* registration; + if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(), + ®istration)) { + return errors::InvalidArgument("No JIT device registered for ", + platform_info.device_type().type()); + } + *cache = new XlaCompilationCache( + client.ValueOrDie(), DeviceType(registration->compilation_device_name)); + return Status::OK(); +} + +static Status CompileToLocalExecutable( + OpKernelContext* ctx, const NameAttrList& function, + const XlaPlatformInfo& platform_info, absl::Span<const int> resources, + absl::Span<const int> constants, xla::LocalClient** client, + std::map<int, OptionalTensor>* variables, + const XlaCompiler::CompilationResult** kernel, + xla::LocalExecutable** executable) { + // We store information about the JIT-compiled XLA computation + // in the ResourceMgr. + ResourceMgr* rm = ctx->resource_manager(); + if (!rm) { + return errors::Internal("No resource manager."); + } + + XlaCompilationCache* cache; + TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>( + rm->default_container(), "xla_cache", &cache, + [&](XlaCompilationCache** cache) { + return BuildCompilationCache(ctx, platform_info, cache); + })); + // Hold the reference to the JIT during evaluation. (We could probably + // free it sooner because the ResourceMgr will retain a reference, but + // this is more obviously correct.) + core::ScopedUnref cache_ref(cache); + + *variables = SnapshotResourceVariables(ctx, resources); + *client = static_cast<xla::LocalClient*>(cache->client()); + + XlaCompiler::Options options; + options.client = *client; + if (ctx->op_device_context() != nullptr) { + options.device_ordinal = + ctx->op_device_context()->stream()->parent()->device_ordinal(); + } + options.device_type = cache->device_type(); + options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + options.graph_def_version = ctx->function_library()->graph_def_version(); + options.allow_cpu_custom_calls = + (platform_info.platform_id() == se::host::kHostPlatformId); + options.device_allocator = platform_info.allocator(); + if (platform_info.xla_device_metadata()) { + options.shape_representation_fn = + platform_info.xla_device_metadata()->shape_representation_fn(); + } + + std::map<int, Tensor> constant_args; + for (int i : constants) { + constant_args.insert({i, ctx->input(i)}); + } + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; + // If we resolve constants we never emit them on the device, meaning that if + // they are needed by a following computation the host has to transfer + // them. Not resolving constants is expected to be faster than resolving + // constants. + compile_options.resolve_compile_time_constants = true; + // Optimization: where possible, have the computation return a naked array + // rather than a one-element tuple. 
+ compile_options.always_return_tuple = false; + + return cache->Compile(options, function, constant_args, *variables, ctx, + compile_options, kernel, executable); +} + +void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { + VLOG(1) << "XlaLocalLaunchOpBase::Compute " + << Canonicalize(function_.name(), AttrSlice(&function_.attr())); + + xla::LocalClient* client; + const XlaCompiler::CompilationResult* kernel; + xla::LocalExecutable* executable; + std::map<int, OptionalTensor> variables; + + OP_REQUIRES_OK( + ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, + constants_, &client, &variables, &kernel, + &executable)); + + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + + VLOG(1) << "Executing XLA Computation..."; + + XlaComputationLaunchContext launch_context( + client, platform_info_.allocator(), + /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), + platform_info_.UseMultipleStreams()); + launch_context.PopulateInputs(ctx, kernel, variables, + /*missing_ctx_input_prefix=*/0); + + // Execute the computation. + VLOG(2) << "Executing computation."; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(platform_info_.allocator()); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + run_options.set_rng_seed(GetXLARandomSeed()); + Env* env = Env::Default(); + auto start_time = env->NowMicros(); + + auto run_result = executable->Run(launch_context.arguments(), run_options); + OP_REQUIRES(ctx, run_result.ok(), run_result.status()); + + auto elapsed = env->NowMicros() - start_time; + VLOG(2) << "Elapsed time: " << elapsed << "us"; + + OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( + ctx, kernel, run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/0)); + VLOG(1) << "Done"; +} + +namespace { + +// OP_REQUIRES_OK_RETURN is the same as OP_REQUIRES_OK except that +// in error case, it returns RET instead of void. +#define OP_REQUIRES_OK_RETURN(CTX, RET, ...) \ + do { \ + ::tensorflow::Status _s(__VA_ARGS__); \ + if (!TF_PREDICT_TRUE(_s.ok())) { \ + (CTX)->CtxFailureWithWarning(__FILE__, __LINE__, _s); \ + return RET; \ + } \ + } while (0) + +// Helper static functions to construct parameters for +// XlaLocalLaunchBase constructor from OpKernelConstruction. 
+std::vector<int> ConstantsVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), + ctx->GetAttr("Tconstants", &constant_types)); + std::vector<int> constants(constant_types.size()); + std::iota(constants.begin(), constants.end(), 0); + return constants; +} + +std::vector<int> ResourcesVector(OpKernelConstruction* ctx) { + DataTypeVector constant_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), + ctx->GetAttr("Tconstants", &constant_types)); + + DataTypeVector arg_types; + OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), + ctx->GetAttr("Targs", &arg_types)); + + int num_resources; + OP_REQUIRES_OK_RETURN(ctx, std::vector<int>(), + ctx->GetAttr("Nresources", &num_resources)); + + std::vector<int> resources(num_resources); + std::iota(resources.begin(), resources.end(), + constant_types.size() + arg_types.size()); + return resources; +} + +NameAttrList FunctionAttr(OpKernelConstruction* ctx) { + const NameAttrList* func; + OP_REQUIRES_OK_RETURN(ctx, NameAttrList(), ctx->GetAttr("function", &func)); + return *func; +} + +#undef OP_REQUIRES_OK_RETURN +} // namespace + +XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx) + : XlaLocalLaunchBase(ctx, ConstantsVector(ctx), ResourcesVector(ctx), + FunctionAttr(ctx)) {} + +XlaLocalLaunchOp::~XlaLocalLaunchOp() { + VLOG(1) << "XlaLocalLaunchOp destroyed"; +} + +XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx) + : OpKernel(ctx), + constants_(ConstantsVector(ctx)), + resources_(ResourcesVector(ctx)), + function_(FunctionAttr(ctx)) { + OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); +} + +void XlaCompileOp::Compute(OpKernelContext* ctx) { + xla::LocalClient* client; + const XlaCompiler::CompilationResult* kernel; + xla::LocalExecutable* executable; + std::map<int, OptionalTensor> variables; + + OP_REQUIRES_OK( + ctx, CompileToLocalExecutable(ctx, function_, platform_info_, resources_, + constants_, &client, &variables, &kernel, + &executable)); + + // Each execution of an XlaCompile op creates a new XlaExecutableClosure, even + // if it didn't have to compile the cluster because of a compilation-cache + // hit. This is because we at least need new snapshots of the resource + // variables. 
+ XlaExecutableClosureStore::KeyT key = + XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure( + client, executable, kernel, std::move(variables), constants_.size())); + + Allocator* cpu_allocator = [&] { + AllocatorAttributes host_alloc_attrs; + host_alloc_attrs.set_gpu_compatible(true); + host_alloc_attrs.set_on_host(true); + return ctx->device()->GetAllocator(host_alloc_attrs); + }(); + + Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({})); + compilation_key.flat<string>()(0) = key; + + Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({})); + compilation_successful.flat<bool>()(0) = true; + + ctx->set_output(0, compilation_key); + ctx->set_output(1, compilation_successful); +} + +XlaRunOp::XlaRunOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, PlatformInfoFromContext(ctx, &platform_info_)); +} + +void XlaRunOp::Compute(OpKernelContext* ctx) { + Tensor key_tensor = ctx->input(ctx->num_inputs() - 1); + const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0); + + XlaExecutableClosure closure = + XlaExecutableClosureStore::Global()->Consume(key); + + XlaComputationLaunchContext launch_context( + closure.client(), platform_info_.allocator(), + /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), + /*use_multiple_streams=*/platform_info_.UseMultipleStreams()); + + // We're missing the must-be-constant inputs, tell `PopulateInputs` + // about this. We don't actually need these inputs because they've + // already been baked into the compiled kernel. + launch_context.PopulateInputs( + ctx, closure.compilation_result(), closure.resource_var_snapshots(), + /*missing_ctx_input_prefix=*/closure.num_constant_args()); + + se::Stream* stream = + ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; + xla::ExecutableRunOptions run_options; + run_options.set_stream(stream); + run_options.set_allocator(platform_info_.allocator()); + run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device()); + run_options.set_rng_seed(GetXLARandomSeed()); + Env* env = Env::Default(); + auto start_time = env->NowMicros(); + + auto run_result = + closure.executable()->Run(launch_context.arguments(), run_options); + OP_REQUIRES(ctx, run_result.ok(), run_result.status()); + + auto elapsed = env->NowMicros() - start_time; + VLOG(2) << "Elapsed time in computation: " << elapsed << "us"; + + OP_REQUIRES_OK( + ctx, + launch_context.PopulateOutputs( + ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/closure.num_constant_args())); +} + +REGISTER_KERNEL_BUILDER(Name("XlaLaunch").Device(DEVICE_CPU), XlaLocalLaunchOp); + +REGISTER_KERNEL_BUILDER(Name("XlaLaunch") + .Device(DEVICE_GPU) + .HostMemory("constants") + .HostMemory("resources"), + XlaLocalLaunchOp); + +REGISTER_KERNEL_BUILDER(Name("_XlaCompile").Device(DEVICE_CPU), XlaCompileOp); +REGISTER_KERNEL_BUILDER(Name("_XlaCompile") + .Device(DEVICE_GPU) + .HostMemory("constants") + .HostMemory("resources"), + XlaCompileOp); + +REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_CPU), XlaRunOp); +REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE_GPU), XlaRunOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_ops.h b/tensorflow/compiler/jit/kernels/xla_ops.h new file mode 100644 index 0000000000..489d26eb30 --- /dev/null +++ b/tensorflow/compiler/jit/kernels/xla_ops.h @@ -0,0 +1,168 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ +#define TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_ + +#include "tensorflow/compiler/jit/xla_compilation_cache.h" +#include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/stream_executor_util.h" + +namespace tensorflow { + +// Holds some information about the platform on which an +// XlaLaunch/_XlaCompile/_XlaRun op must run on. +class XlaPlatformInfo { + public: + XlaPlatformInfo() : device_type_("") {} + explicit XlaPlatformInfo(const DeviceType device_type, + se::Platform::Id platform_id, + const XlaDevice::Metadata* xla_device_metadata, + std::unique_ptr<XlaAllocator> xla_allocator, + xla::DeviceMemoryAllocator* device_allocator) + : device_type_(device_type), + platform_id_(platform_id), + xla_device_metadata_(xla_device_metadata), + xla_allocator_(std::move(xla_allocator)), + device_allocator_(device_allocator) { + CHECK((device_allocator_ != nullptr) ^ (xla_allocator_.get() != nullptr)); + } + + XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default; + + bool UseMultipleStreams() const { + return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams(); + } + + xla::DeviceMemoryAllocator* allocator() const { + return device_allocator_ ? device_allocator_ : xla_allocator_.get(); + } + DeviceType device_type() const { return device_type_; } + + // This is equal to xla_device_metadata()->platform()->id() if + // xla_device_metadata() is not nullptr. + se::Platform::Id platform_id() const { return platform_id_; } + + // This may be null if the op this XlaPlatformInfo is for was not placed on an + // XLA device. + const XlaDevice::Metadata* xla_device_metadata() const { + return xla_device_metadata_; + } + bool is_on_xla_device() const { return xla_device_metadata() != nullptr; } + + private: + DeviceType device_type_; + se::Platform::Id platform_id_; + + // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the + // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the + // XlaLaunch/_XlaCompile/_XlaRun OpKernel. + const XlaDevice::Metadata* xla_device_metadata_; + + // If the op associated with this XlaPlatformInfo is placed on an XLA device + // then device_allocator_ is the xla::Backend's memory allocator and + // xla_allocator_ is null. If the op is placed on a regular CPU or GPU device + // then device_allocator_ is null and xla_allocator_ points to an appropriate + // XlaAllocator instance. 
+  std::unique_ptr<XlaAllocator> xla_allocator_;
+  xla::DeviceMemoryAllocator* device_allocator_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
+};
+
+// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
+// The only difference is that it does not require arguments to follow
+// the "constants, then regular args, then resources" order.
+// It takes vectors of constant and resource arguments explicitly.
+// It does not have corresponding OpDef because it is never present
+// in the GraphDef.
+// Currently, it is used by eager runtime. FunctionLibraryRuntime creates
+// this kernel when asked to create a kernel for an XLA-compiled function.
+class XlaLocalLaunchBase : public OpKernel {
+ public:
+  XlaLocalLaunchBase(OpKernelConstruction* ctx,
+                     const std::vector<int>& constants,
+                     const std::vector<int>& resources,
+                     const NameAttrList& function);
+  XlaLocalLaunchBase(const XlaLocalLaunchBase&) = delete;
+  XlaLocalLaunchBase& operator=(const XlaLocalLaunchBase&) = delete;
+  ~XlaLocalLaunchBase() override = default;
+
+  void Compute(OpKernelContext* ctx) override;
+
+ protected:
+  // Indexes of compile-time constant inputs
+  std::vector<int> constants_;
+  // Indexes of resource inputs
+  std::vector<int> resources_;
+
+  NameAttrList function_;
+  XlaPlatformInfo platform_info_;
+};
+
+// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
+// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
+// responsible for handling interactions with the TensorFlow executor.
+// Once all inputs are present, and their shapes are known, the op can
+// use a 'XlaCompilationCache' to compile and execute code which is specific
+// to the shapes of input Tensors.
+// XlaLocalLaunchOp uses xla::LocalClient::Compile() and
+// xla::LocalExecutable::Run(), and passes arguments into/out of XLA in device
+// memory.
+class XlaLocalLaunchOp : public XlaLocalLaunchBase {
+ public:
+  explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
+  ~XlaLocalLaunchOp() override;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
+};
+
+class XlaCompileOp : public OpKernel {
+ public:
+  explicit XlaCompileOp(OpKernelConstruction* ctx);
+
+  void Compute(OpKernelContext* ctx) override;
+
+ private:
+  // Indexes of compile-time constant inputs
+  std::vector<int> constants_;
+  // Indexes of resource inputs
+  std::vector<int> resources_;
+
+  NameAttrList function_;
+
+  XlaPlatformInfo platform_info_;
+};
+
+class XlaRunOp : public OpKernel {
+ public:
+  explicit XlaRunOp(OpKernelConstruction* ctx);
+
+  void Compute(OpKernelContext* ctx) override;
+
+ private:
+  XlaPlatformInfo platform_info_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_KERNELS_XLA_OPS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index e6cc6e52ae..4f0c370e65 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include <unordered_map>
 #include <unordered_set>
 
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/jit/deadness_analysis.h"
 #include "tensorflow/compiler/jit/defs.h"
 #include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -42,7 +43,6 @@ limitations under the License.
#include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/gtl/cleanup.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -365,10 +365,13 @@ bool IsXlaFusable(const NodeDef& node) { return elementwise_ops->count(node.op()) > 0; } +// Nodes that XLA can compile are put in `candidates`. Nodes put in +// `isolated_nodes` must either be unclustered or be put in trivial single-node +// clusters. Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn, - OrderedNodeSet* candidates) { + OrderedNodeSet* candidates, absl::flat_hash_set<Node*>* isolated_nodes) { OptimizerOptions opts; std::unique_ptr<ProcessFunctionLibraryRuntime> pflr( new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION, @@ -411,6 +414,8 @@ Status FindCompilationCandidates( DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceToDeviceType(node->assigned_device_name(), &device_type)); + VLOG(4) << "Device type for " << node->name() << ": " + << device_type.type_string(); if (is_compilable_fn && !is_compilable_fn(node, device_type)) { // is_compilable_fn has already logged the reason if it returned false. @@ -439,19 +444,56 @@ Status FindCompilationCandidates( << node->type_string(); continue; } - if (compile_time_const_nodes[node->id()] && - !registration->requires_compilation) { + if (compile_time_const_nodes[node->id()]) { const OpDef* op_def; TF_RETURN_IF_ERROR( graph.op_registry()->LookUpOpDef(node->type_string(), &op_def)); if (op_def->is_stateful()) { - // We need to be able to constant fold the nodes in - // compile_time_const_nodes given constant inputs (required by XLA) and - // therefore can't auto-cluster stateful ops since these can never be - // constant folded. - VLOG(2) << "Rejecting " << node->name() - << ": must-be-constant stateful op"; - continue; + // It is easiest to demonstrate the problem we're trying to solve with + // an example. Say we have this graph: + // + // shape = RandomUniformInt(); + // reshape = Reshape(input, shape) + // + // Both RandomUniformInt and Reshape are compilable by XLA so, absent + // any other reason, we will try to put both shape and reshape in the + // same cluster. However, since XLA only supports statically shaped + // values, it will expect to be able to constant fold `shape` to get a + // static shape for `reshape`. This is a problem because side-effecting + // ops like RandomUniformInt() cannot be constant folded. We fix this + // by putting `shape` and `reshape` in different clusters, which results + // in us recompiling `reshape`'s cluster for every new value of `shape`, + // making `reshape` statically sized within each compilation. We + // simplify the solution even further by disallowing operations like + // `shape` from being part of *any* non-trivial cluster. They're either + // not compiled by XLA altogether or, if assigned to an XLA_* device + // with "must compile" semantics, compiled into a trivial single-op + // cluster. This approach leaves some room for improvement, and we can + // consider implementing a more aggressive data-flow-analysis based + // solution in the future if needed. 
+ // + // One ugly problem we have to contend with: certain sets of ops *have* + // to be in the same cluster because values flowing between them have + // types that can't be live-in or live-out of a cluster. These ops are: + // + // - TensorArray ops operating on the same TensorArray instance. + // - Stack ops operating on the same Stack instance. + // + // To work around this we avoid isolating these specific ops. Because + // of this concession it is unsound to auto-cluster them because then + // we'd create clusters we could not compile (because we can't constant + // fold, say, a TensorArrayRead or a StackPopV2). But we don't + // auto-cluster these operations today so we're good for now. + const XlaResourceOpInfo* op_info = + GetResourceOpInfoForOp(node->type_string()); + bool is_tensor_array_or_stack_op = + op_info && op_info->resource_kind() != XlaResourceKind::kVariable; + if (!is_tensor_array_or_stack_op) { + VLOG(2) << "Isolating " << node->name() + << ": must-be-constant stateful op"; + isolated_nodes->insert(node); + // Keep going and execute all the other checks. + } } } // We don't auto-cluster functional control flow nodes containing resource @@ -807,11 +849,12 @@ Status MarkForCompilationPass::RunImpl( Graph* graph = options.graph->get(); OrderedNodeSet compilation_candidates; + absl::flat_hash_set<Node*> isolated_nodes; TF_RETURN_IF_ERROR(FindCompilationCandidates( *graph, options.flib_def, (options.session_options != nullptr) ? options.session_options->env : Env::Default(), - is_compilable_fn, &compilation_candidates)); + is_compilable_fn, &compilation_candidates, &isolated_nodes)); if (compilation_candidates.empty()) { VLOG(2) << "No compilable candidates"; @@ -856,6 +899,11 @@ Status MarkForCompilationPass::RunImpl( "Found control flow node in clustering worklist: ", node_from->type_string()); } + + if (isolated_nodes.count(node_from)) { + continue; + } + string from_scope; string to_scope; for (int to : cycles.Successors(from)) { @@ -873,6 +921,9 @@ Status MarkForCompilationPass::RunImpl( node_to->assigned_device_name()) { continue; } + if (isolated_nodes.count(node_to)) { + continue; + } // Look for an _XlaScope on both nodes. If both nodes have a // scope and the scopes do not match, do not cluster along this // edge. This restriction is overridden if the global_jit_level is ON. If @@ -931,6 +982,11 @@ Status MarkForCompilationPass::RunImpl( // Names for each cluster. std::unordered_map<int, string> cluster_names; + if (flags->tf_xla_clustering_debug) { + dump_graph::DumpGraphToFile("before_mark_for_compilation", **options.graph, + options.flib_def); + } + // Mark clusters for compilation that: // * are placed on a device that requires compilation (an XlaDevice), // * are explicitly marked for compilation (_XlaCompile=true), or diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index c59770a4c8..2a80c745e3 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" @@ -61,10 +62,10 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) { return ids; } -gtl::FlatMap<string, std::vector<string>> GetClusterSets( +absl::flat_hash_map<string, std::vector<string>> GetClusterSets( const Graph& g, std::vector<string>* cluster_names = nullptr) { CHECK(cluster_names == nullptr || cluster_names->empty()); - gtl::FlatMap<string, std::vector<string>> cluster_sets; + absl::flat_hash_map<string, std::vector<string>> cluster_sets; for (const auto& p : GetClusters(g)) { cluster_sets[p.second].push_back(p.first); } @@ -566,7 +567,7 @@ TEST(XlaCompilationTest, ResourcesClusteringAllowed) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - gtl::FlatMap<string, std::vector<string>> cluster_sets = + absl::flat_hash_map<string, std::vector<string>> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 1); std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR", @@ -586,7 +587,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); - gtl::FlatMap<string, std::vector<string>> cluster_sets = + absl::flat_hash_map<string, std::vector<string>> cluster_sets = GetClusterSets(*graph); ASSERT_EQ(cluster_sets.size(), 1); std::vector<string> expected_clustered_nodes = {"AssignmentW", @@ -616,7 +617,7 @@ TEST(XlaCompilationTest, ChainOfOps) { TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::vector<string> cluster_names; - gtl::FlatMap<string, std::vector<string>> cluster_sets = + absl::flat_hash_map<string, std::vector<string>> cluster_sets = GetClusterSets(*graph, &cluster_names); ASSERT_EQ(cluster_sets.size(), 2); @@ -894,5 +895,71 @@ TEST(XlaCompilationTest, RandomShapeWithFunc) { EXPECT_EQ(clusters["fn_call"], ""); } +TEST(XlaCompilationTest, RandomShapeOnXlaDevice) { + absl::string_view xla_gpu_device = + "/job:worker/replica:0/task:0/device:XLA_GPU:0"; + + Scope root = Scope::NewRootScope().ExitOnError(); + Output shape_shape = + ops::Const(root.WithOpName("test/shape_shape"), {2}, {1}); + Output shape = + ops::RandomUniformInt(root.WithOpName("test/shape_rng"), shape_shape, + ops::Const(root.WithOpName("test/minval"), 1), + ops::Const(root.WithOpName("test/maxval"), 20)); + Output reshape_input = + ops::Placeholder(root.WithOpName("test/reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("test/reshape"), reshape_input, shape); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_gpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map<string, string> clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/shape_rng"], ""); + EXPECT_NE(clusters["test/reshape"], ""); + EXPECT_NE(clusters["test/shape_rng"], clusters["test/reshape"]); +} + 
+TEST(XlaCompilationTest, TensorArrayShapeOnXlaDevice) { + absl::string_view xla_gpu_device = + "/job:worker/replica:0/task:0/device:XLA_GPU:0"; + Scope root = Scope::NewRootScope().ExitOnError(); + ops::TensorArray tensor_array(root.WithOpName("test/tensor_array"), 1, + DT_INT32); + Output zero = ops::Const(root.WithOpName("test/zero"), 0); + ops::TensorArrayWrite tensor_array_write( + root.WithOpName("test/write"), tensor_array.handle, zero, + ops::Const(root.WithOpName("test/forty_two"), 42.0f), tensor_array.flow); + Output tensor_array_read = + ops::TensorArrayRead(root.WithOpName("test/read"), tensor_array.handle, + zero, tensor_array_write.flow_out, DT_INT32); + Output reshape = + ops::Reshape(root.WithOpName("test/reshape"), + ops::Placeholder(root.WithOpName("placeholder"), DT_FLOAT), + tensor_array_read); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + + for (Node* n : graph->nodes()) { + if (absl::StartsWith(n->name(), /*prefix=*/"test/")) { + n->set_assigned_device_name(string(xla_gpu_device)); + } + } + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + + std::unordered_map<string, string> clusters = GetClusters(*graph); + EXPECT_NE(clusters["test/read"], ""); + EXPECT_EQ(clusters["test/read"], clusters["test/reshape"]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc index 65669877f7..d56d0f8ccf 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -14,18 +14,35 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { /*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def, SessionOptions* session_options) { - // Assign all nodes to the CPU device. + // Assign all unassigned nodes to the CPU device. static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; for (Node* n : (*graph)->nodes()) { - n->set_assigned_device_name(kCpuDevice); + if (n->assigned_device_name().empty()) { + n->set_assigned_device_name(kCpuDevice); + } } + // Call AddDevices to register the XLA devices. + // + // It may be worth refactoring out XlaOpRegistry::RegisterCompilationDevice to + // make this more direct, but probably not worth it solely for this test. 
+ std::vector<Device*> devices; + TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(*session_options, "", &devices)); + + auto delete_devices = gtl::MakeCleanup([&] { + for (Device* d : devices) { + delete d; + } + }); + GraphOptimizationPassOptions opt_options; opt_options.graph = graph; opt_options.session_options = session_options; diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index 13804c6a05..f72224545b 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -4,9 +4,17 @@ package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") + cc_library( name = "xla_ops", srcs = ["xla_ops.cc"], deps = ["//tensorflow/core:framework"], alwayslink = 1, ) + +tf_gen_op_wrapper_py( + name = "xla_ops_wrapper_py", + out = "xla_ops.py", + deps = ["//tensorflow/compiler/jit/ops:xla_ops"], +) diff --git a/tensorflow/compiler/jit/ops/xla_ops.cc b/tensorflow/compiler/jit/ops/xla_ops.cc index 1a29c3caab..bcd1a29b1f 100644 --- a/tensorflow/compiler/jit/ops/xla_ops.cc +++ b/tensorflow/compiler/jit/ops/xla_ops.cc @@ -51,4 +51,43 @@ REGISTER_OP("XlaClusterOutput") "Operator that connects the output of an XLA computation to other " "consumer graph nodes."); +REGISTER_OP("_XlaCompile") + .Input("constants: Tconstants") + .Attr("Tconstants: list(type) >= 0") + .Input("args: Targs") + .Attr("Targs: list(type) >= 0") + .Input("resources: Nresources * resource") + .Attr("Nresources: int >= 0") + .Output("key: string") + .Output("compilation_successful: bool") + .Attr("function: func") + // The compilation cache is stateful. + .SetIsStateful() + .Doc(R"(XLA Compile Op. For use by the XLA JIT only. + +Compiles a TensorFlow function into an XLA LocalExecutable and returns a key +that _XlaRun can use to look up the LocalExecutable and execute it. + +key: A key that can be used to look up the local executable compiled by the + node and associated metadata. + +compilation_successful: True iff the compilation was successful. Always true +for now. +)"); + +REGISTER_OP("_XlaRun") + .Input("args: Targs") + .Attr("Targs: list(type) >= 0") + .Output("results: Tresults") + .Attr("Tresults: list(type) >= 0") + .Input("key: string") + // XLA random-number generation ops are stateful. + // TODO(phawkins): create stateful and non-stateful variants of _XlaRun. + .SetIsStateful() + .Doc(R"(XLA Run Op. For use by the XLA JIT only. + +Executes a TensorFlow function previously compiled into a LocalExecutable by an +_XlaCompile op. +)"); + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index 10fc9e85d9..b1f9e9088f 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -15,17 +15,18 @@ limitations under the License. 
#include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace tensorflow { namespace { -Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result, +Status FindNodesToDecluster(const Graph& graph, + absl::flat_hash_set<Node*>* result, absl::Span<Node* const> post_order) { // Find nodes that have at least one user outside their cluster that expects // hostmem output. These nodes should be cloned to outside the cluster to @@ -171,7 +172,7 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) { GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), /*edge_filter=*/NotBackedge); - gtl::FlatSet<Node*> nodes_to_partially_decluster; + absl::flat_hash_set<Node*> nodes_to_partially_decluster; TF_RETURN_IF_ERROR( FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order)); diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index 35872daa65..0feb73a89e 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -60,9 +60,9 @@ class FakeBinaryOp : public OpKernel { void Compute(OpKernelContext* ctx) override { CHECK(false); } }; -class FakeResourceVarUpdateOp : public OpKernel { +class FakeResourceUpdateOp : public OpKernel { public: - explicit FakeResourceVarUpdateOp(OpKernelConstruction* context) + explicit FakeResourceUpdateOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* ctx) override { CHECK(false); } @@ -74,10 +74,9 @@ REGISTER_KERNEL_BUILDER(Name("FakeBinary") .HostMemory("host_out"), FakeBinaryOp); -REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate") - .Device(DEVICE_CPU) - .HostMemory("something_else"), - FakeResourceVarUpdateOp); +REGISTER_KERNEL_BUILDER( + Name("FakeResourceUpdate").Device(DEVICE_CPU).HostMemory("something_else"), + FakeResourceUpdateOp); Status PartiallyDecluster(std::unique_ptr<Graph>* graph) { FixupSourceAndSinkEdges(graph->get()); diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 56e35c0059..e039d46ec8 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -82,6 +82,7 @@ limitations under the License. #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" @@ -89,8 +90,6 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/util/ptr_util.h" @@ -177,7 +176,7 @@ string ResourceOpToString(const ResourceOp& resource_op) { // point. 
class ResourceOpSet { private: - using Impl = gtl::FlatSet<ResourceOp>; + using Impl = absl::flat_hash_set<ResourceOp>; public: ResourceOpSet() = default; diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 3aa9e9c7ed..0471995015 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -228,37 +228,38 @@ Status XlaCompilationCache::Compile( const XlaCompiler::Options& options, const NameAttrList& function, const std::map<int, Tensor>& constant_args, const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options) { + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable) { return CompileImpl(options, function, constant_args, variable_args, ctx, - compilation_result, executable, compile_options, false); + compile_options, /*compile_single_op=*/false, + out_compilation_result, out_executable); } Status XlaCompilationCache::CompileSingleOp( const XlaCompiler::Options& options, const std::map<int, Tensor>& constant_args, const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options) { + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable) { const NodeDef& def = ctx->op_kernel().def(); NameAttrList name; name.set_name(def.op()); *name.mutable_attr() = def.attr(); - return CompileImpl(options, name, constant_args, variable_args, ctx, - compilation_result, executable, compile_options, true); + return CompileImpl( + options, name, constant_args, variable_args, ctx, compile_options, + /*compile_single_op=*/true, out_compilation_result, out_executable); } Status XlaCompilationCache::CompileImpl( const XlaCompiler::Options& options, const NameAttrList& function, const std::map<int, Tensor>& constant_args, const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options, - bool compile_single_op) { - CHECK_NE(executable, nullptr); + const XlaCompiler::CompileOptions& compile_options, bool compile_single_op, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable) { + DCHECK_NE(out_executable, nullptr); VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { @@ -357,8 +358,8 @@ Status XlaCompilationCache::CompileImpl( } } TF_RETURN_IF_ERROR(entry->compilation_status); - *compilation_result = &entry->compilation_result; - *executable = entry->executable.get(); + *out_compilation_result = &entry->compilation_result; + *out_executable = entry->executable.get(); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index 10ad87e38c..75c7758f73 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -68,9 +68,9 @@ class XlaCompilationCache : public ResourceBase { const std::map<int, Tensor>& constant_args, const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options); + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); // As above, but calls XlaCompiler::CompileSingleOp instead of // XlaCompiler::CompileFunction. @@ -78,9 +78,9 @@ class XlaCompilationCache : public ResourceBase { const XlaCompiler::Options& options, const std::map<int, Tensor>& constant_args, const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options); + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } @@ -89,15 +89,14 @@ class XlaCompilationCache : public ResourceBase { private: // Common implementation of Compile and CompileSingleOp. - Status CompileImpl(const XlaCompiler::Options& options, - const NameAttrList& function, - const std::map<int, Tensor>& constant_args, - const std::map<int, OptionalTensor>& variable_args, - OpKernelContext* ctx, - const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable, - const XlaCompiler::CompileOptions& compile_options, - bool compile_single_op); + Status CompileImpl( + const XlaCompiler::Options& options, const NameAttrList& function, + const std::map<int, Tensor>& constant_args, + const std::map<int, OptionalTensor>& variable_args, OpKernelContext* ctx, + const XlaCompiler::CompileOptions& compile_options, + bool compile_single_op, + const XlaCompiler::CompilationResult** out_compilation_result, + xla::LocalExecutable** out_executable); // Takes `result` which has been compiled from a Tensorflow subgraph to a // XLA computation already, and generates an XLA LocalExecutable `executable`. @@ -152,7 +151,7 @@ class XlaCompilationCache : public ResourceBase { }; mutex compile_cache_mu_; - gtl::FlatMap<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_ + absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_ GUARDED_BY(compile_cache_mu_); struct CompileStats { @@ -165,7 +164,7 @@ class XlaCompilationCache : public ResourceBase { mutex compile_stats_mu_; // Maps cluster names to compilation statistics for said cluster. 
- gtl::FlatMap<string, CompileStats> compile_stats_ + absl::flat_hash_map<string, CompileStats> compile_stats_ GUARDED_BY(compile_stats_mu_); TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 3ba48e8c31..79976c85df 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -34,6 +34,7 @@ std::map<int, OptionalTensor> GetVariables(OpKernelContext* ctx) { OptionalTensor& optional = variables[i]; optional.name = handle.name(); if (LookupResource(ctx, handle, &variable).ok()) { + core::ScopedUnref scoped_unref(variable); tf_shared_lock lock(*variable->mu()); optional.present = true; optional.value = *variable->tensor(); @@ -58,7 +59,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, /*allocate_xla_tensors=*/true, /*use_multiple_streams=*/metadata.UseMultipleStreams()); - launch_context.PopulateInputs(ctx, result, variables); + launch_context.PopulateInputs(ctx, result, variables, + /*missing_ctx_input_prefix=*/0); se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; @@ -79,7 +81,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, TF_RETURN_IF_ERROR(run_result.status()); TF_RETURN_IF_ERROR(launch_context.PopulateOutputs( - ctx, result, run_result.ConsumeValueOrDie())); + ctx, result, run_result.ConsumeValueOrDie(), + /*missing_ctx_input_prefix=*/0)); return Status::OK(); } @@ -177,7 +180,7 @@ Status XlaCompileOnDemandOp::Compile( std::map<int, OptionalTensor> variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - result, executable, compile_options); + compile_options, result, executable); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 7e159e3171..003c1d8081 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -16,7 +16,7 @@ limitations under the License. // Registers the XLA_CPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "Host" (CPU) backend. 
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/legacy_flags/xla_device_flags.h" #include "tensorflow/compiler/jit/xla_compile_on_demand_op.h" #include "tensorflow/compiler/jit/xla_device.h" @@ -65,10 +65,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array<DataType, 7> kAllXlaCpuTypes = { - {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array<DataType, 12> kAllXlaCpuTypes = { + {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, + DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); +REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_CPU, XlaCompileOp, kAllXlaCpuTypes); +REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_CPU, XlaRunOp, kAllXlaCpuTypes); + REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 51797def04..0824c4644e 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -373,7 +373,7 @@ Status XlaDevice::FillContextMap(const Graph* graph, void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { VLOG(2) << "XlaDevice::Compute " << op_kernel->name() << ":" << op_kernel->type_string(); - TracingDevice::Compute(op_kernel, context); + op_kernel->Compute(context); } void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, @@ -434,6 +434,16 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, return status; } +void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) { + mutex_lock lock(mu_); + sync_on_completion_ = sync_on_completion; +} + +bool XlaDevice::RequiresSyncOnCompletion() const { + mutex_lock lock(mu_); + return sync_on_completion_; +} + XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device) { // Any op assigned to the device that isn't rewritten by the graph rewriter diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 92891ffa8c..0f06b3fc80 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -151,6 +151,12 @@ class XlaDevice : public LocalDevice { // information for GPU and TPU devices. Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_); + // Instructs this XlaDevice to return 'sync_on_completion' for + // RequiresSyncOnCompletion(). + void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_); + + bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_); + private: xla::LocalClient* client() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) @@ -165,7 +171,7 @@ class XlaDevice : public LocalDevice { static Status GetMetadataFromDevice(DeviceBase* device, const XlaDevice::Metadata** metadata); - mutex mu_; + mutable mutex mu_; // The metadata of this XlaDevice. const Metadata xla_metadata_; // Which hardware device in the client's platform this XlaDevice controls. @@ -207,6 +213,10 @@ class XlaDevice : public LocalDevice { // Thread pool used for running closures std::unique_ptr<thread::ThreadPool> thread_pool_; + + // True if the device requires XlaDevice::Sync to be called on completion + // regardless of status. 
+ bool sync_on_completion_ GUARDED_BY(mu_) = false; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 49c8582682..6967ad1f03 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -65,6 +65,16 @@ class XlaAssignVariableOp : public AsyncOpKernel { .HostMemory("resources"), \ KERNEL); +#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_XlaCompile") \ + .Device(DEVICE) \ + .HostMemory("constants") \ + .HostMemory("resources"), \ + KERNEL); + +#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \ + REGISTER_KERNEL_BUILDER(Name("_XlaRun").Device(DEVICE), KERNEL); + #define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \ REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \ REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \ @@ -90,9 +100,15 @@ class XlaAssignVariableOp : public AsyncOpKernel { Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \ ResourceHandleOp<Var>); \ REGISTER_KERNEL_BUILDER( \ + Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"), \ + ResourceHandlesOp<Var>); \ + REGISTER_KERNEL_BUILDER( \ Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ ReadVariableOp); \ REGISTER_KERNEL_BUILDER( \ + Name("_ReadVariablesOp").Device(DEVICE).HostMemory("resources"), \ + ReadVariablesOp); \ + REGISTER_KERNEL_BUILDER( \ Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"), \ DestroyResourceOp); \ REGISTER_KERNEL_BUILDER(Name("Shape") \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index ef4466f005..60979556a3 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -16,7 +16,7 @@ limitations under the License. // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs // operators using XLA via the XLA "CUDA" (GPU) backend. -#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -74,11 +74,14 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array<DataType, 8> kAllXlaGpuTypes = { - {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, - DT_BFLOAT16}}; +constexpr std::array<DataType, 13> kAllXlaGpuTypes = { + {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64, + DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); +REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes); +REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes); + REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 4574559674..19e681af0c 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -15,7 +15,7 @@ limitations under the License. // Registers the XLA_INTERPRETER device which exposes the XLA Interpreter. 
-#include "tensorflow/compiler/jit/kernels/xla_launch_op.h" +#include "tensorflow/compiler/jit/kernels/xla_ops.h" #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -72,6 +72,10 @@ static bool OpFilter(KernelDef* kdef) { return true; } REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp, kExecAllTypes); +REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp, + kExecAllTypes); +REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes); + REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes); REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter); diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index affeab4a8c..4f6fc4e068 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -42,13 +42,14 @@ using xla::ShapedBuffer; } // anonymous namespace std::map<int, OptionalTensor> SnapshotResourceVariables( - OpKernelContext* ctx, const std::vector<int>& variables) { + OpKernelContext* ctx, absl::Span<const int> variables) { std::map<int, OptionalTensor> snapshot; for (int i : variables) { Var* variable = nullptr; ResourceHandle handle = HandleFromInput(ctx, i); OptionalTensor& tensor = snapshot[i]; if (LookupResource(ctx, handle, &variable).ok()) { + core::ScopedUnref scoped_unref(variable); tf_shared_lock lock(*variable->mu()); tensor.name = handle.name(); tensor.present = true; @@ -133,7 +134,8 @@ XlaComputationLaunchContext::XlaComputationLaunchContext( void XlaComputationLaunchContext::PopulateInputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - const std::map<int, OptionalTensor>& variables) { + const std::map<int, OptionalTensor>& variables, + int missing_ctx_input_prefix) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; // Build ShapedBuffers that point directly to the Tensor buffers. @@ -145,12 +147,13 @@ void XlaComputationLaunchContext::PopulateInputs( const Tensor* t; for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) { int arg_num = kernel->input_mapping[i]; + DCHECK_GE(arg_num, missing_ctx_input_prefix); const xla::Shape& shape = kernel->xla_input_shapes[i]; if (variables.count(arg_num)) { t = &(variables.at(arg_num).value); CHECK(t); } else { - t = &(ctx->input(arg_num)); + t = &(ctx->input(arg_num - missing_ctx_input_prefix)); } if (use_multiple_streams_) { @@ -187,7 +190,7 @@ void XlaComputationLaunchContext::PopulateInputs( Status XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - ScopedShapedBuffer output) { + ScopedShapedBuffer output, int missing_ctx_input_prefix) { se::Stream* stream = ctx->op_device_context() ? 
ctx->op_device_context()->stream() : nullptr; @@ -275,6 +278,8 @@ Status XlaComputationLaunchContext::PopulateOutputs( VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type " << DataTypeString(type); if (type == DT_RESOURCE) { + TF_RET_CHECK(kernel->outputs[i].input_index >= 0) + << "Invalid input for outputs " << i; ctx->set_output(i, ctx->input(kernel->outputs[i].input_index)); } else { se::DeviceMemoryBase buffer = output.buffer({output_num}); @@ -313,7 +318,8 @@ Status XlaComputationLaunchContext::PopulateOutputs( for (int i = 0; i < kernel->resource_updates.size(); ++i) { Allocator* allocator = ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) { + int actual_input_index = write.input_index - missing_ctx_input_prefix; + if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) { return errors::Internal("Invalid input index for variable write."); } @@ -323,7 +329,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, // not a Tensor. TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>( - ctx, HandleFromInput(ctx, write.input_index), &variable, + ctx, HandleFromInput(ctx, actual_input_index), &variable, [&write](Var** ptr) { *ptr = new Var(write.type); return Status::OK(); diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 7ac275fab8..326d70a027 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { class XlaAllocator; @@ -43,7 +44,7 @@ class XlaAllocator; // resource variable is not initialized, the corresponding OptionalTensor // will have its `present` field set to false. std::map<int, OptionalTensor> SnapshotResourceVariables( - OpKernelContext* ctx, const std::vector<int>& variables); + OpKernelContext* ctx, absl::Span<const int> variables); // Adapter class that wraps a Tensorflow allocator as an XLA allocator. // Assumes that the Tensorflow allocator permits asynchronous deallocation: @@ -88,14 +89,24 @@ class XlaComputationLaunchContext { // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. + // + // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are + // missing and adjusts input indices accordingly. All elements in kernel's + // input_mapping must be greater than or equal to `missing_ctx_input_prefix` + // (in other words, no inputs actually required by the kernel can be missing). void PopulateInputs(OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - const std::map<int, OptionalTensor>& variables); + const std::map<int, OptionalTensor>& variables, + int missing_ctx_input_prefix); // Given the XLA output in `output`, populate all outputs of `ctx`. + // + // Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are + // missing and adjusts input indices accordingly. 
Status PopulateOutputs(OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, - xla::ScopedShapedBuffer output); + xla::ScopedShapedBuffer output, + int missing_ctx_input_prefix); // Return the argument list. Only valid after PopulateInputs() has been // called. diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 97ed554171..ba2401ed26 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -895,6 +895,22 @@ tf_xla_py_test( ) tf_xla_py_test( + name = "tensor_list_ops_test", + size = "small", + srcs = ["tensor_list_ops_test.py"], + # TensorList ops are not implemented in the on-demand compilation model yet. + disabled_backends = "cpu_ondemand", + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:list_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python/eager:function", + ], +) + +tf_xla_py_test( name = "ternary_ops_test", size = "small", srcs = ["ternary_ops_test.py"], @@ -978,7 +994,7 @@ tf_xla_py_test( name = "gather_test", size = "medium", srcs = ["gather_test.py"], - tags = ["noasan"], # times out, http://b/78599043 + tags = ["optonly"], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -1029,6 +1045,19 @@ tf_xla_py_test( ) tf_xla_py_test( + name = "permute_test", + size = "small", + srcs = ["permute_test.py"], + deps = [ + "//tensorflow/compiler/tests:xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:nn_ops", + ], +) + +tf_xla_py_test( name = "xla_device_test", size = "small", srcs = ["xla_device_test.py"], @@ -1105,6 +1134,7 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -1198,6 +1228,19 @@ tf_xla_py_test( ) tf_xla_py_test( + name = "quantized_ops_test", + size = "small", + srcs = ["quantized_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( name = "xla_ops_test", size = "medium", srcs = ["xla_ops_test.py"], diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index 4155342787..68f52e796c 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -50,12 +50,12 @@ class ArgMinMaxTest(xla_test.XLATestCase): def testArgMinMax(self): # Complex numbers do not support argmin/argmax. - minmax_types = set(self.numeric_types) - set(self.complex_types) + minmax_types = self.all_types & {np.int32, np.int64} for dtype in minmax_types: # output_type is a numpy data type that is used to specify the desired # output type of the op as well as to convert the Python number to the # array scalar of the type. 
- for output_type in self.int_types: + for output_type in minmax_types: self._assertOpOutputMatchesExpected( math_ops.argmax, axis=0, diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 17280e445b..1b39d53dc0 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -210,7 +210,7 @@ class BinaryOpsTest(xla_test.XLATestCase): equality_test=self.ListsAreClose) def testIntOps(self): - for dtype in self.int_types: + for dtype in self.signed_int_types: self._testBinary( gen_math_ops.truncate_div, np.array([3, 3, -1, -9, -8], dtype=dtype), @@ -287,7 +287,8 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(7), expected=np.array([[-6], [-5]], dtype=dtype)) - if dtype not in self.complex_types: # min/max not supported for complex + # min/max not supported for complex + if dtype not in self.complex_types | {np.uint8, np.int8}: self._testBinary( math_ops.maximum, np.array([1, 2], dtype=dtype), @@ -337,7 +338,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([[70], [14]], dtype=dtype)) # Complex support for squared_difference is incidental, see b/68205550 - if dtype not in self.complex_types: + if dtype not in self.complex_types | {np.uint8, np.int8}: self._testBinary( math_ops.squared_difference, np.array([1, 2], dtype=dtype), @@ -559,6 +560,13 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(2), expected=np.array([[5], [2]], dtype=dtype)) + if dtype in [np.float32, np.float64]: + nums = np.arange(-10, 10, .25, dtype=dtype).reshape(80, 1) + divs = np.arange(-3, 3, .25, dtype=dtype).reshape(1, 24) + np_result = np.true_divide(nums, divs) + np_result[:, divs[0] == 0] = 0 + self._testBinary(gen_math_ops.div_no_nan, nums, divs, expected=np_result) + if dtype not in self.complex_types: # floordiv unsupported for complex. 
self._testBinary( gen_math_ops.floor_div, @@ -567,7 +575,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) def testIntDivision(self): - for dtype in self.int_types: + for dtype in self.signed_int_types: self._testDivision(dtype) def testFloatDivision(self): @@ -588,7 +596,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([1, 1, -1, 0], dtype=dtype)) def testIntRemainder(self): - for dtype in self.int_types: + for dtype in self.signed_int_types - {np.int8}: self._testRemainder(dtype) def testFloatRemainder(self): @@ -1437,6 +1445,13 @@ class BinaryOpsTest(xla_test.XLATestCase): np.array([4, 0], dtype=np.int32), expected=np.zeros([4, 0], dtype=dtype)) + x = np.arange(3).reshape((3, 1, 1, 1)).astype(dtype) + self._testBinary( + array_ops.broadcast_to, + x, + np.array((3, 7, 8, 9), dtype=np.int32), + expected=np.tile(x, (1, 7, 8, 9))) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index a76f136736..1d3979b21b 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -2,6 +2,10 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") load("//tensorflow/compiler/tests:plugin.bzl", "plugins") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) def all_backends(): b = ["cpu"] + plugins.keys() @@ -58,14 +62,14 @@ def tf_xla_py_test( if backend == "cpu": backend_args += [ "--test_device=XLA_CPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64", ] elif backend == "gpu": backend_args += [ "--test_device=XLA_GPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_QUINT8,DT_INT8,DT_QINT8,DT_INT32,DT_QINT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16", ] - backend_tags += ["requires-gpu-sm35"] + backend_tags += tf_cuda_tests_tags() elif backend in plugins: backend_args += [ "--test_device=" + plugins[backend]["device"], diff --git a/tensorflow/compiler/tests/dense_layer_test.py b/tensorflow/compiler/tests/dense_layer_test.py index 0af74c2d8f..9390870e07 100644 --- a/tensorflow/compiler/tests/dense_layer_test.py +++ b/tensorflow/compiler/tests/dense_layer_test.py @@ -45,17 +45,21 @@ def InLabels(labels, substr): return any([substr in x for x in labels]) -def XlaLaunchOpCount(labels): - """Count how many XlaLaunch labels are present.""" - return sum("XlaLaunch(" in x for x in labels) +class DenseLayerTest(test.TestCase): + def countXlaOps(self, labels): + """Count how many XlaCompile/XlaRun labels are present.""" + xla_compile_count = sum("XlaCompile(" in x for x in labels) + xla_run_count = sum("XlaRun(" in x for x in labels) + self.assertEqual(xla_compile_count, xla_run_count) + return xla_run_count -class DenseLayerTest(test.TestCase): def testDenseLayerAutoJit(self): """Tests dense layer compilation in auto-jit mode. - Dense layer should be compiled into a single XlaLaunch op in auto-jit mode. + Dense layer should be compiled into a single XlaCompile/XlaRun op pair in + auto-jit mode. 
""" os.environ["TF_XLA_FLAGS"] = ( @@ -77,14 +81,14 @@ class DenseLayerTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = GetRunMetadataLabels(run_metadata) - self.assertEqual(1, XlaLaunchOpCount(labels)) + self.assertEqual(1, self.countXlaOps(labels)) self.assertFalse(InLabels(labels, "MatMult")) def testDenseLayerJitScopeDefinedShape(self): """Tests that the dense layer node is properly compiled in jit scope. Dense layer with static shape input tensor should be compiled into a single - XlaLaunch op by XLA. + XlaCompile/XlaRun op pair by XLA. """ with self.cached_session() as sess: @@ -101,7 +105,7 @@ class DenseLayerTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = GetRunMetadataLabels(run_metadata) - self.assertEqual(1, XlaLaunchOpCount(labels)) + self.assertEqual(1, self.countXlaOps(labels)) # No need to check whether ListDiff is compiled or not because ListDiff op # is not used when input tensor shape is fully defined. @@ -111,7 +115,8 @@ class DenseLayerTest(test.TestCase): Dense layer uses shape op to get shape of input tensor if its shape is not fully defined. XLA does not cluster shape op with other operators. But in experimental_jit_scope, XLA is forced to compile shape op into its own - cluster, causing dense layer to be split into TWO XlaLaunch ops. + cluster, causing dense layer to be split into TWO XlaCompile/XlaRun op + pairs. """ with self.cached_session() as sess: @@ -128,7 +133,7 @@ class DenseLayerTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = GetRunMetadataLabels(run_metadata) - self.assertEqual(2, XlaLaunchOpCount(labels)) + self.assertEqual(2, self.countXlaOps(labels)) self.assertFalse(InLabels(labels, "MatMult")) diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 8c018cccb8..374942a0b3 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -29,6 +29,11 @@ from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import nn from tensorflow.python.platform import test +DATA_FORMATS = ( + ("_data_format_NHWC", "NHWC"), + ("_data_format_NCHW", "NCHW"), +) + class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): @@ -65,12 +70,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): grad_offset = np.sum(grad_y, axis=(0, 1, 2)) return grad_x, grad_scale, grad_offset - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testInference(self, data_format): channel = 3 x_shape = [2, 2, 6, channel] @@ -170,30 +170,15 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): self.assertAllClose(y_val, y_ref_converted, atol=1e-3) self.assertAllClose(var_val, var_ref, atol=1e-3) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testLearning(self, data_format): self._testLearning(False, data_format) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testLearningWithGradientChecker(self, 
data_format): self._testLearning(True, data_format) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testGradientTraining(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. @@ -241,12 +226,7 @@ class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase): self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) - @parameterized.named_parameters( - ("_data_format_NHWC", "NHWC"), - ("_data_format_NCHW", "NCHW"), - ("_data_format_HWNC", "HWNC"), - ("_data_format_HWCN", "HWCN"), - ) + @parameterized.named_parameters(*DATA_FORMATS) def testGradientInference(self, data_format): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index 089d95daab..a38e1edafe 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -51,7 +51,7 @@ class GatherTest(xla_test.XLATestCase): indices_tf = constant_op.constant(indices) gather_t = array_ops.gather(params, indices_tf) gather_val = session.run(gather_t, feed_dict={params: params_np}) - np_val = params_np[indices] + np_val = constant_op.constant(params_np[indices]) self.assertAllEqual(np_val, gather_val) def testScalar2D(self): @@ -65,7 +65,8 @@ class GatherTest(xla_test.XLATestCase): indices = constant_op.constant(2) gather_t = array_ops.gather(params, indices, axis=axis) gather_val = session.run(gather_t, feed_dict={params: params_np}) - expected = np.take(params_np, 2, axis=axis) + expected = constant_op.constant( + np.take(params_np, 2, axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testSimpleTwoD32(self): @@ -80,7 +81,8 @@ class GatherTest(xla_test.XLATestCase): indices = constant_op.constant([0, 1, 0, 2]) gather_t = array_ops.gather(params, indices, axis=axis) gather_val = session.run(gather_t, feed_dict={params: params_np}) - expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + expected = constant_op.constant( + np.take(params_np, [0, 1, 0, 2], axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testSimpleTwoD32_Int64Indices(self): @@ -103,7 +105,8 @@ class GatherTest(xla_test.XLATestCase): params: params_np, indices: indices_np }) - expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + expected = constant_op.constant( + np.take(params_np, [0, 1, 0, 2], axis=axis), dtype) self.assertAllEqual(expected, gather_val) def testHigherRank(self): @@ -119,7 +122,8 @@ class GatherTest(xla_test.XLATestCase): tf_indices = constant_op.constant(indices, dtype=dtypes.int32) gather = array_ops.gather(tf_params, tf_indices, axis=axis) gather_value = sess.run(gather, feed_dict={tf_params: params}) - gather_np = np.take(params, indices, axis=axis) + gather_np = constant_op.constant( + np.take(params, indices, axis=axis), dtype) self.assertAllEqual(gather_np, gather_value) def testIndicesWithDifferentDimensions(self): diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 6fe5a66e0e..68fdb5caf4 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -605,10 +605,6 @@ class 
ResizeBilinearTest(xla_test.XLATestCase): class NonMaxSuppressionTest(xla_test.XLATestCase): def testNMS128From1024(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - with compat.forward_compatibility_horizon(2018, 8, 8): num_boxes = 1024 boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") @@ -644,10 +640,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(indices_tf.size, max_output_size) def testNMS3From6Boxes(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - with compat.forward_compatibility_horizon(2018, 8, 8): # Three boxes are selected based on IOU. boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], @@ -693,10 +685,6 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): # Three boxes are selected based on IOU. # One is filtered out by score threshold. - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - with compat.forward_compatibility_horizon(2018, 8, 8): boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] @@ -736,6 +724,49 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(num_valid, 2) self.assertAllClose(indices_tf[:num_valid], [3, 0]) + def testNMS3Then1WithScoreMaxThresh(self): + # Three boxes are selected based on IOU. + # One is filtered out by score threshold. + # One is filtered out by max_output_size. + + with compat.forward_compatibility_horizon(2018, 8, 8): + boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + boxes_np = np.array(boxes_data, dtype=np.float32) + + scores_data = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + scores_np = np.array(scores_data, dtype=np.float32) + max_output_size = 1 + iou_threshold_np = np.array(0.5, dtype=np.float32) + score_threshold_np = np.array(0.4, dtype=np.float32) + + with self.cached_session() as sess: + boxes = array_ops.placeholder(boxes_np.dtype, shape=boxes_np.shape) + scores = array_ops.placeholder(scores_np.dtype, shape=scores_np.shape) + iou_threshold = array_ops.placeholder(iou_threshold_np.dtype, + iou_threshold_np.shape) + score_threshold = array_ops.placeholder(score_threshold_np.dtype, + score_threshold_np.shape) + with self.test_scope(): + selected_indices = image_ops.non_max_suppression_padded( + boxes=boxes, + scores=scores, + max_output_size=max_output_size, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pad_to_max_output_size=True) + inputs_feed = { + boxes: boxes_np, + scores: scores_np, + iou_threshold: iou_threshold_np, + score_threshold: score_threshold_np + } + (indices_tf, num_valid) = sess.run( + selected_indices, feed_dict=inputs_feed) + + self.assertEqual(indices_tf.size, max_output_size) + self.assertEqual(num_valid, 1) + self.assertAllClose(indices_tf[:num_valid], [3]) if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 0839fb123e..de68ff0e32 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -77,11 +77,11 @@ def InLabels(labels, substr): return any([substr in x for x in labels]) -def MetadataHasXlaLaunch(run_metadata): - """Returns true if there is a XlaLaunch kernel in run_metadata's timeline.""" +def MetadataHasXlaOp(run_metadata): + 
"""Returns true if there are XlaRun kernels in run_metadata's timeline.""" # TODO(phawkins): find a less hacky way to test whether a kernel ran. - return InLabels(RunMetadataLabels(run_metadata), "XlaLaunch") + return InLabels(RunMetadataLabels(run_metadata), "XlaRun") class JitLaunchTest(test.TestCase): @@ -90,9 +90,10 @@ class JitLaunchTest(test.TestCase): # Verifies that the outputs match and that XLA was invoked. 'fn' must take # the same number of tensors as arguments that are in 'args', and must return # a tuple of output tensors. - # If 'require_kernel_launch' is True, then we verify that a XlaLaunch node - # actually ran. However, it is sometimes possible for XlaLaunch ops to be - # constant-folded away, so the check is optional. + # + # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun + # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun + # ops to be constant-folded away, so the check is optional. def _compare(self, fn, args, require_kernel_launch=True, noinline=None): with session_lib.Session(config=NoRewriteSessionConfig()) as sess: placeholders = [] @@ -115,7 +116,7 @@ class JitLaunchTest(test.TestCase): print("Compiled Result {}".format(compiled)) if require_kernel_launch: - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) direct = sess.run(direct_op, feeds) print("Direct Result {}".format(direct)) @@ -149,10 +150,10 @@ class JitLaunchTest(test.TestCase): y = math_ops.add(x, x) return y, y - # Exercises compling a function (say, Foo) which calls another - # function (say, Bar) which is not inlined. When the compiler compiles - # Foo, it needs to symbolic execute Bar correctly regardless whether - # Bar is inlined or not. + # Exercises compiling a function (say, Foo) which calls another function + # (say, Bar) which is not inlined. When the compiler compiles Foo, it needs + # to symbolically execute Bar correctly regardless of whether Bar is inlined + # or not. # TODO(b/36139787): Re-enable this test when noinline works again. # Tests compiled=True and noinline=True. @@ -259,7 +260,7 @@ class JitLaunchTest(test.TestCase): # TODO(phawkins): really we would like to test that there were exactly # two kernel launches. However, we have no reliable way to determine # that. 
- self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) expected = np.square(np.dot(dx, dw) + db) self.assertAllClose(expected, output, rtol=1e-1) @@ -289,7 +290,7 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out) def testIgnoredArguments(self): @@ -313,7 +314,7 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) self.assertAllClose(28, out) def testLoops(self): @@ -331,7 +332,7 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) self.assertAllClose(result, np.float32(95), rtol=1e-1) def testCond(self): @@ -356,7 +357,7 @@ class XlaCompilationTest(test.TestCase): run_metadata=run_metadata, options=config_pb2.RunOptions( trace_level=config_pb2.RunOptions.FULL_TRACE)) - self.assert_(MetadataHasXlaLaunch(run_metadata)) + self.assert_(MetadataHasXlaOp(run_metadata)) self.assertAllClose(result, np.float32(6), rtol=1e-1) def testNestedFunction(self): @@ -441,14 +442,16 @@ class XlaCompilationTest(test.TestCase): self.assertFalse(InLabels(labels, "Log")) self.assertTrue(InLabels(labels, "Reciprocal")) self.assertTrue(InLabels(labels, "Mul")) - self.assertFalse(InLabels(labels, "XlaLaunch")) + self.assertFalse(InLabels(labels, "XlaCompile")) + self.assertFalse(InLabels(labels, "XlaRun")) - # Compile the backprop. One XlaLaunch. + # Compile the backprop. One XlaCompile/XlaRun pair. 
labels = _Run(compiled=True) self.assertFalse(InLabels(labels, "Log")) self.assertFalse(InLabels(labels, "Reciprocal")) self.assertFalse(InLabels(labels, "Mul")) - self.assertTrue(InLabels(labels, "XlaLaunch")) + self.assertTrue(InLabels(labels, "XlaCompile")) + self.assertTrue(InLabels(labels, "XlaRun")) class ElementWiseFusionTest(test.TestCase): @@ -482,9 +485,12 @@ class ElementWiseFusionTest(test.TestCase): trace_level=config_pb2.RunOptions.FULL_TRACE)) labels = RunMetadataLabels(run_metadata) - count = sum("XlaLaunch(" in x for x in labels) - return output, count + xla_compile_count = sum("XlaCompile(" in x for x in labels) + xla_run_count = sum("XlaRun(" in x for x in labels) + self.assertEqual(xla_compile_count, xla_run_count) + + return output, xla_run_count def testElementWiseClustering(self): arg0 = np.random.rand(2, 2).astype(np.float32) diff --git a/tensorflow/compiler/tests/lstm.py b/tensorflow/compiler/tests/lstm.py index 43c469d032..73b3638e80 100644 --- a/tensorflow/compiler/tests/lstm.py +++ b/tensorflow/compiler/tests/lstm.py @@ -117,7 +117,7 @@ def LSTMLayer(cell_name, weights, m, c, x_seq, pad_seq): def RandomVar(shape, name=None): """Returns a variable of the given shape initialized to random values.""" - return variables.Variable( + return variables.VariableV1( random_ops.random_uniform(shape), dtype=dtypes.float32, name=name) diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py index f985c5d2d9..38cb2f83ef 100644 --- a/tensorflow/compiler/tests/nullary_ops_test.py +++ b/tensorflow/compiler/tests/nullary_ops_test.py @@ -43,18 +43,37 @@ class NullaryOpsTest(xla_test.XLATestCase): output.run() def testConstants(self): - constants = [ - np.float32(42), - np.array([], dtype=np.float32), - np.array([1, 2], dtype=np.float32), - np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32), - np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], - dtype=np.float32), - np.array([[[]], [[]]], dtype=np.float32), - np.array([[[[1]]]], dtype=np.float32), - ] - for c in constants: - self._testNullary(lambda c=c: constant_op.constant(c), expected=c) + for dtype in self.numeric_types: + constants = [ + dtype(42), + np.array([], dtype=dtype), + np.array([1, 2], dtype=dtype), + np.array([7, 7, 7, 7, 7], dtype=dtype), + np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype), + np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]], + dtype=dtype), + np.array([[[]], [[]]], dtype=dtype), + np.array([[[[1]]]], dtype=dtype), + ] + for c in constants: + self._testNullary(lambda c=c: constant_op.constant(c), expected=c) + + def testComplexConstants(self): + for dtype in self.complex_types: + constants = [ + dtype(42 + 3j), + np.array([], dtype=dtype), + np.ones([50], dtype=dtype) * (3 + 4j), + np.array([1j, 2 + 1j], dtype=dtype), + np.array([[1, 2j, 7j], [4, 5, 6]], dtype=dtype), + np.array([[[1, 2], [3, 4 + 6j], [5, 6]], + [[10 + 7j, 20], [30, 40], [50, 60]]], + dtype=dtype), + np.array([[[]], [[]]], dtype=dtype), + np.array([[[[1 + 3j]]]], dtype=dtype), + ] + for c in constants: + self._testNullary(lambda c=c: constant_op.constant(c), expected=c) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/permute_test.py b/tensorflow/compiler/tests/permute_test.py new file mode 100644 index 0000000000..dbb9274df4 --- /dev/null +++ b/tensorflow/compiler/tests/permute_test.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the DataFormatVecPermute operator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test + + +class XlaPermuteOpTest(xla_test.XLATestCase): + + def _runPermuteAndCompare(self, x, src_format, dst_format, expected): + with self.cached_session() as session: + with self.test_scope(): + placeholder = array_ops.placeholder(dtypes.as_dtype(x.dtype), x.shape) + param = {placeholder: x} + output = nn_ops.data_format_vec_permute( + placeholder, src_format=src_format, dst_format=dst_format) + result = session.run(output, param) + self.assertAllEqual(result, expected) + + def testNHWCToNCHW(self): + x = np.array([7, 4, 9, 3], dtype=np.int32) + self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9]) + + def testNCHWToNHWC(self): + x = np.array([7, 4, 9, 3], dtype=np.int32) + self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4]) + + def testNHWCToHWNC(self): + x = np.array([7, 4, 9, 3], dtype=np.int32) + self._runPermuteAndCompare(x, "NHWC", "HWNC", [4, 9, 7, 3]) + + def testHWNCToNHWC(self): + x = np.array([7, 4, 9, 3], dtype=np.int32) + self._runPermuteAndCompare(x, "HWNC", "NHWC", [9, 7, 4, 3]) + + def testNHWCToNCHW2D(self): + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) + self._runPermuteAndCompare(x, "NHWC", "NCHW", + [[7, 4], [5, 1], [9, 3], [4, 5]]) + + def testNHWCToHWNC2D(self): + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) + self._runPermuteAndCompare(x, "NHWC", "HWNC", + [[9, 3], [4, 5], [7, 4], [5, 1]]) + + def testHWNCToNHWC2D(self): + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) + self._runPermuteAndCompare(x, "HWNC", "NHWC", + [[4, 5], [7, 4], [9, 3], [5, 1]]) + + def testNCHWToNHWC2D(self): + x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32) + self._runPermuteAndCompare(x, "NCHW", "NHWC", + [[7, 4], [4, 5], [5, 1], [9, 3]]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/quantized_ops_test.py b/tensorflow/compiler/tests/quantized_ops_test.py new file mode 100644 index 0000000000..80c338513b --- /dev/null +++ b/tensorflow/compiler/tests/quantized_ops_test.py @@ -0,0 +1,48 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for quantized operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +class QuantizedOpsTest(xla_test.XLATestCase): + + # Verify that quantized types can be clustered by XLA. + def testQuantizedTypeRoundtrip(self): + with self.cached_session() as session: + for dtype in self.quantized_tf_types: + in_values = np.array([1, 2, 3, 4, 5, 6]) + expected = [[1, 2], [3, 4], [5, 6]] + with self.test_scope(): + p = array_ops.placeholder(dtype=dtypes.int32) + x = math_ops.cast(p, dtype) + x = array_ops.reshape(x, [3, 2]) + + value = session.run(x, {p: in_values}) + self.assertAllEqual(value, expected) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 6e18344117..36ef6ed5fe 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -35,7 +35,8 @@ class RandomOpsTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def _random_types(self): - return set(self.numeric_types) - set(self.complex_types) + return set(self.numeric_types) - set( + self.complex_types) - {np.uint8, np.int8} def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. @@ -68,9 +69,8 @@ class RandomOpsTest(xla_test.XLATestCase): def rng(dtype): return random_ops.random_normal(shape=[2], dtype=dtype) - # TODO(b/34339814): implement inverse erf support for non-F32 types. - dtype = dtypes.float32 - self._testRngIsNotConstant(rng, dtype) + for dtype in self._random_types() & self.float_types: + self._testRngIsNotConstant(rng, dtype) def testRandomUniformIsInRange(self): for dtype in self._random_types(): @@ -92,13 +92,13 @@ class RandomOpsTest(xla_test.XLATestCase): def rng(dtype): return random_ops.truncated_normal(shape=[2], dtype=dtype) - # TODO(b/34339814): implement inverse erf support for non-F32 types. - self._testRngIsNotConstant(rng, dtypes.float32) + for dtype in self._random_types() & self.float_types: + self._testRngIsNotConstant(rng, dtype) def testTruncatedNormalIsInRange(self): count = 10000000 - # TODO(b/34339814): implement inverse erf support for non-F32 types. - for dtype in [dtypes.float32]: + # TODO(b/34339814): make this test work with 16 bit float types. 
+ for dtype in self._random_types() & {dtypes.float32, dtypes.float64}: with self.cached_session() as sess: with self.test_scope(): x = random_ops.truncated_normal(shape=[count], dtype=dtype) @@ -144,9 +144,6 @@ class RandomOpsTest(xla_test.XLATestCase): self.assertAllClose(actual_variance, expected_variance, rtol=2*1e-3) def testShuffle1d(self): - # TODO(b/26783907): this test requires the CPU backend to implement sort. - if self.device in ["XLA_CPU"]: - return with self.cached_session() as sess: with self.test_scope(): x = math_ops.range(1 << 16) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index bddda6f302..dc119fb0f8 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -45,6 +45,7 @@ limitations under the License. #include <random> #include <unordered_map> +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/defs.h" @@ -63,7 +64,6 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" @@ -457,7 +457,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, Tensor tensor(dtype, TensorShape(shape)); switch (dtype) { case DT_FLOAT: { - gtl::FlatSet<float> already_generated; + absl::flat_hash_set<float> already_generated; std::uniform_real_distribution<float> distribution(-1.0f, 1.0f); test::FillFn<float>(&tensor, [&](int i) -> float { float generated; @@ -470,7 +470,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_DOUBLE: { - gtl::FlatSet<double> already_generated; + absl::flat_hash_set<double> already_generated; std::uniform_real_distribution<double> distribution(-1.0, 1.0); test::FillFn<double>(&tensor, [&](int i) -> double { double generated; @@ -483,7 +483,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_COMPLEX64: { - gtl::FlatSet<std::pair<float, float>> already_generated; + absl::flat_hash_set<std::pair<float, float>> already_generated; std::uniform_real_distribution<float> distribution(-1.0f, 1.0f); test::FillFn<complex64>(&tensor, [&](int i) { complex64 generated; @@ -500,7 +500,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_INT32: { - gtl::FlatSet<int32> already_generated; + absl::flat_hash_set<int32> already_generated; std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20); test::FillFn<int32>(&tensor, [&](int i) -> int32 { int32 generated; @@ -513,7 +513,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_INT64: { - gtl::FlatSet<int64> already_generated; + absl::flat_hash_set<int64> already_generated; std::uniform_int_distribution<int64> distribution(-(1LL << 40), 1LL << 40); test::FillFn<int64>(&tensor, [&](int i) -> int64 { @@ -527,7 +527,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values, break; } case DT_BOOL: { - gtl::FlatSet<bool> already_generated; + absl::flat_hash_set<bool> already_generated; std::bernoulli_distribution distribution; test::FillFn<bool>(&tensor, [&](int i) -> bool { bool generated; @@ -1820,7 +1820,7 @@ 
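The in-range check above boils down to the following property, sketched here with NumPy rejection sampling standing in for random_ops.truncated_normal:

import numpy as np

# truncated_normal draws from a standard normal truncated at two standard
# deviations, so every sample should fall inside [-2, 2]; this rejection
# sampler is only an illustrative stand-in for the real op.
rng = np.random.default_rng(0)
samples = rng.standard_normal(100000)
samples = samples[np.abs(samples) <= 2.0]
assert samples.min() >= -2.0 and samples.max() <= 2.0
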
TEST_F(OpTest, Diag) { do { dims = RandomDims(1); size = TensorShape(dims).num_elements(); - } while (size * size < tf_xla_max_tensor_size); + } while (size * size > tf_xla_max_tensor_size); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Diag").RandomInput(type, dims).Attr("T", type)); }); diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py index 60c2337743..abc822ef36 100644 --- a/tensorflow/compiler/tests/reverse_sequence_op_test.py +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -85,7 +85,7 @@ class ReverseSequenceTest(xla_test.XLATestCase): def testSeqLength(self): for dtype in self.all_types: - for seq_dtype in self.int_types: + for seq_dtype in self.all_types & {np.int32, np.int64}: self._testBasic(dtype, seq_dtype) diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 51c04b5c47..57f0ab7a9e 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -48,22 +48,30 @@ class XlaSortOpTest(xla_test.XLATestCase): self.assertAllClose(v, result, rtol=1e-3) def testSort(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - - supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): x = np.arange(101, dtype=dtype) np.random.shuffle(x) self._assertOpOutputMatchesExpected( xla.sort, [x], expected=[np.arange(101, dtype=dtype)]) - def testTopK(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return + def testKeyValueSort(self): + supported_types = set( + [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) + for key_type in supported_types.intersection(self.numeric_types): + for value_type in supported_types.intersection(self.numeric_types): + x = np.arange(101, dtype=key_type) + np.random.shuffle(x) + y = (-x).astype(value_type) + self._assertOpOutputMatchesExpected( + xla.key_value_sort, [x, y], + expected=[ + np.arange(101, dtype=key_type), + -np.arange(101, dtype=value_type) + ]) + def testTopK(self): supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): @@ -89,10 +97,6 @@ class XlaSortOpTest(xla_test.XLATestCase): expected=[x[indices].astype(dtype), indices]) def testTopK2D(self): - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - supported_types = set( [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32]) for dtype in supported_types.intersection(self.numeric_types): @@ -122,10 +126,6 @@ class XlaSortOpTest(xla_test.XLATestCase): def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. - if self.device in ["XLA_CPU", "XLA_GPU"]: - return - # Only bfloat16 is implemented. bfloat16 = dtypes.bfloat16.as_numpy_dtype if bfloat16 not in self.numeric_types: @@ -144,10 +144,6 @@ class XlaSortOpTest(xla_test.XLATestCase): def testTopKInfinities(self): """Tests that positive and negative infinity sort correctly.""" - # TODO(b/26783907): The Sort HLO is not implemented on CPU or GPU. 
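A NumPy sketch of what xla.key_value_sort is expected to produce in testKeyValueSort; np.argsort stands in for the XLA sort, and the data mirrors the test's setup:

import numpy as np

# Sort the keys and carry the values along with the same permutation; with
# values defined as the negated keys, the sorted outputs are arange and
# -arange, exactly as the test expects.
keys = np.arange(101, dtype=np.float32)
np.random.shuffle(keys)
values = (-keys).astype(np.int32)

order = np.argsort(keys, kind="stable")
assert np.array_equal(keys[order], np.arange(101, dtype=np.float32))
assert np.array_equal(values[order], -np.arange(101, dtype=np.int32))
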
- if self.device in ["XLA_CPU", "XLA_GPU"]: - return - # Only bfloat16 is implemented. bfloat16 = dtypes.bfloat16.as_numpy_dtype if bfloat16 not in self.numeric_types: diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 1bea7d9355..e8741bc468 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -34,7 +34,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): """Test cases for stateless random-number generator operators.""" def _random_types(self): - return [dtypes.float32] + return self.float_types & {dtypes.float32, dtypes.float64} def testDeterminism(self): # Stateless values should be equal iff the seeds are equal (roughly) @@ -91,7 +91,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): with self.cached_session() as sess, self.test_scope(): for dtype in self._random_types(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) - x = stateless.stateless_random_uniform( + x = stateless.stateless_random_normal( shape=[10000], seed=seed_t, dtype=dtype) y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) self.assertTrue(np.all(np.isfinite(y))) @@ -124,8 +124,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): self.assertTrue(self._anderson_darling(y) < 2.492) def testTruncatedNormalIsInRange(self): - # TODO(b/34339814): implement inverse erf support for non-F32 types. - for dtype in [dtypes.float32]: + for dtype in self._random_types(): with self.cached_session() as sess, self.test_scope(): seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) n = 10000000 @@ -159,7 +158,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase): # Department of Scientific Computing website. Florida State University. expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma actual_mean = np.mean(y) - self.assertAllClose(actual_mean, expected_mean, atol=2e-4) + self.assertAllClose(actual_mean, expected_mean, atol=5e-4) expected_median = mu + probit( (normal_cdf(alpha) + normal_cdf(beta)) / 2.) * sigma diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py new file mode 100644 index 0000000000..5c079d595c --- /dev/null +++ b/tensorflow/compiler/tests/tensor_list_ops_test.py @@ -0,0 +1,96 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
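The expected_mean used by the stateless truncated-normal test follows the standard truncated-normal formula; a plain-Python evaluation for mu=0, sigma=1 truncated at two standard deviations (the concrete values here are only illustrative, the test derives alpha, beta and z in the same way):

import math

def normal_pdf(x):
    return math.exp(-x * x / 2.0) / math.sqrt(2.0 * math.pi)

def normal_cdf(x):
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

# Mean of normal(mu, sigma) truncated to [a, b]; for the symmetric two-sigma
# truncation used by truncated_normal the expected mean is exactly mu.
mu, sigma, a, b = 0.0, 1.0, -2.0, 2.0
alpha, beta = (a - mu) / sigma, (b - mu) / sigma
z = normal_cdf(beta) - normal_cdf(alpha)
expected_mean = mu + (normal_pdf(alpha) - normal_pdf(beta)) / z * sigma
print(expected_mean)  # 0.0 by symmetry
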
+# ============================================================================== +"""Tests for ops which manipulate lists of tensors via bridge.""" + +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +def scalar_shape(): + return ops.convert_to_tensor([], dtype=dtypes.int32) + + +class ListOpsTest(xla_test.XLATestCase): + + def testElementShape(self): + with self.cached_session() as sess, self.test_scope(): + dim = array_ops.placeholder(dtypes.int32) + l = list_ops.tensor_list_reserve( + element_shape=(dim, 15), num_elements=20, + element_dtype=dtypes.float32) + e32 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32) + e64 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int64) + self.assertAllEqual(sess.run(e32, {dim: 10}), (10, 15)) + self.assertAllEqual(sess.run(e64, {dim: 7}), (7, 15)) + + def testPushPop(self): + with self.cached_session() as sess, self.test_scope(): + num = array_ops.placeholder(dtypes.int32) + l = list_ops.tensor_list_reserve( + element_shape=(7, 15), num_elements=num, element_dtype=dtypes.float32) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(1.0, shape=(7, 15))) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(2.0, shape=(7, 15))) + l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + self.assertAllEqual(sess.run(e2, {num: 10}), 2.0 * np.ones((7, 15))) + self.assertAllEqual(sess.run(e1, {num: 10}), 1.0 * np.ones((7, 15))) + + def testPushPopSeparateLists(self): + with self.cached_session() as sess, self.test_scope(): + num = array_ops.placeholder(dtypes.int32) + l = list_ops.tensor_list_reserve( + element_shape=scalar_shape(), + num_elements=num, + element_dtype=dtypes.float32) + l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0)) + l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0)) + l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0)) + _, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32) + l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32) + l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) + l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32) + result = sess.run([e11, [e21, e22], [e31, e32]], {num: 20}) + self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]]) + + def testEmptyTensorList(self): + dim = 7 + with self.cached_session() as sess, self.test_scope(): + p = array_ops.placeholder(dtypes.int32) + l = list_ops.empty_tensor_list( + element_shape=(p, 15), element_dtype=dtypes.float32) + l = list_ops.tensor_list_push_back( + l, constant_op.constant(1.0, shape=(dim, 15))) + _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Use TensorListReserve instead"): + self.assertEqual(sess.run(e, {p: dim}), 1.0 * np.ones((dim, 15))) + + +if __name__ == 
"__main__": + test.main() diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index 55a992195f..98a07709c6 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -122,8 +122,7 @@ class TernaryOpsTest(xla_test.XLATestCase): expected=np.array([[2], [5]], dtype=dtype)) def testClipByValue(self): - # TODO(b/78258593): enable integer types here too. - for dtype in self.float_types: + for dtype in self.numeric_types - self.complex_types: test_cases = [ (np.array([2, 4, 5], dtype=dtype), dtype(7)), # (dtype(1), np.array([2, 4, 5], dtype=dtype)), # diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 5b0e57f83f..77f6eee0cf 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -84,7 +84,7 @@ class UnaryOpsTest(xla_test.XLATestCase): self.assertAllClose(result[i], expected[i], rtol, atol) def testAllTypeOps(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype), np.array( @@ -158,9 +158,6 @@ class UnaryOpsTest(xla_test.XLATestCase): def testFloatOps(self): for dtype in self.float_types: - # TODO(b/77694432): Half test failed on CPU, last ran on 04-06-2018. - if dtype == np.float16 and self.device == "XLA_CPU": - continue x = np.arange(-0.90, 0.90, 0.25) self._assertOpOutputMatchesExpected( math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype)) @@ -633,7 +630,7 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array([-1, 0, -2, -17, -43], dtype=dtype)) def testNumericOps(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[2, -1]], dtype=dtype), diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 1e600c44e9..4cf88fc523 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -181,7 +181,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): dtype=dtype)) def testNeg(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.uint8, np.int8}: self._assertOpOutputMatchesExpected( xla.neg, args=(np.array([1, 2, 3], dtype=dtype),), diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 88827cb53b..98a41981cf 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -97,10 +97,23 @@ class XLATestCase(test.TestCase): ]) self._numeric_tf_types = set( self.int_tf_types | self._float_tf_types | self.complex_tf_types) - - self._all_types = set( - [dtype.as_numpy_dtype for dtype in self._all_tf_types]) + self.quantized_tf_types = set( + dtype for dtype in self._all_tf_types if dtype.is_quantized) + + # Quantized types don't have a numpy equivalent, include them in + # all_tf_types but not in all_types. + # TODO(b/115960798): Parametrize tests on TF types instead of numpy types + # and remove all_types. 
+ self._all_types = set(dtype.as_numpy_dtype + for dtype in self._all_tf_types + if not dtype.is_quantized) self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) + self.signed_int_types = set(dtype.as_numpy_dtype + for dtype in self.int_tf_types + if not dtype.is_unsigned) + self.unsigned_int_types = set(dtype.as_numpy_dtype + for dtype in self.int_tf_types + if dtype.is_unsigned) self._float_types = set( [dtype.as_numpy_dtype for dtype in self._float_tf_types]) self.complex_types = set([ diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index ba1e3b2b4f..3f631f91ec 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -635,6 +635,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ops", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) @@ -649,6 +650,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD index ea8d1b3d14..adcdb6c8f7 100644 --- a/tensorflow/compiler/tf2xla/cc/BUILD +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -30,14 +30,15 @@ cc_library( tf_gen_op_wrapper_cc( name = "xla_jit_op_gen", - out_ops_file = "ops/xla_jit_op", + include_internal_ops = 1, + out_ops_file = "ops/xla_jit_ops", deps = ["//tensorflow/compiler/jit/ops:xla_ops"], ) cc_library( name = "xla_jit_ops", - srcs = ["ops/xla_jit_op.cc"], - hdrs = ["ops/xla_jit_op.h"], + srcs = ["ops/xla_jit_ops.cc"], + hdrs = ["ops/xla_jit_ops.h"], deps = [ "//tensorflow/cc:const_op", "//tensorflow/cc:ops", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 922ae7c79a..027ca6d2d2 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -29,14 +29,6 @@ Status BackwardsConstAnalysis(const Graph& g, std::vector<bool>* compile_time_const_arg_indices, std::vector<bool>* compile_time_const_nodes, std::function<bool(const Edge&)> edge_filter) { - // Operators that don't look at the data of their inputs, just the shapes. - const std::unordered_set<string> metadata_ops = { - "Rank", - "Shape", - "ShapeN", - "Size", - }; - std::vector<bool> compile_time_const_nodes_impl; if (compile_time_const_nodes) { CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids()); @@ -50,7 +42,9 @@ Status BackwardsConstAnalysis(const Graph& g, if (!status.ok()) return; // If this is a metadata-only op, don't propagate the const requirement. - if (metadata_ops.find(node->type_string()) != metadata_ops.end()) return; + if (XlaOpRegistry::IsMetadataOp(node->type_string())) { + return; + } // If this node must be const, and it isn't a metadata op, then all of its // parents must be const. diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index f792c52032..0362682bd6 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -31,11 +31,13 @@ limitations under the License. 
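A small NumPy sketch of the dtype partition the new signed_int_types / unsigned_int_types helpers perform; the dtype list below is illustrative rather than the exact set XLATestCase builds:

import numpy as np

int_types = {np.int8, np.int16, np.int32, np.int64,
             np.uint8, np.uint16, np.uint32, np.uint64}
unsigned_int_types = {t for t in int_types
                      if np.issubdtype(t, np.unsignedinteger)}
signed_int_types = int_types - unsigned_int_types
print(sorted(t.__name__ for t in unsigned_int_types))
# ['uint16', 'uint32', 'uint64', 'uint8']
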
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -77,7 +79,10 @@ Status FunctionalizeControlFlowForFunction( const string& func_name, const string& new_func_name, const protobuf::Map<string, tensorflow::AttrValue>& attrs, FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, - std::map<string, string>* canonicalized_name_to_new_name) { + std::map<string, absl::optional<string>>* canonicalized_name_to_new_name, + bool* modified) { + *modified = false; + // Convert the function to Graph. FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); @@ -89,7 +94,20 @@ Status FunctionalizeControlFlowForFunction( } }); const FunctionBody* body = flr->GetFunctionBody(handle); - const FunctionDef& fdef = body->fdef; + Graph* g = body->graph; + + // Check if the graph has Switch or Merge node. + bool has_switch_or_merge = false; + for (Node* n : body->graph->nodes()) { + if (n->type_string() == "Switch" || n->type_string() == "Merge") { + has_switch_or_merge = true; + break; + } + } + // We cannot return here directly if the graph has no Switch/Merge. + // It might contain function call nodes, or If/While nodes with Switch/Merge + // in function body. We still need to rewrite those functions and modify + // corresponding nodes. // If any node has associated functions, functionalize them first. // Gather nodes with associated functions first, because rewriting those nodes @@ -97,7 +115,7 @@ Status FunctionalizeControlFlowForFunction( // it. std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>> nodes_to_associated_functions; - for (auto* n : body->graph->nodes()) { + for (auto* n : g->nodes()) { auto associated_functions = GetAssociatedFunctions(*n, flr); if (!associated_functions.empty()) { nodes_to_associated_functions.push_back({n, associated_functions}); @@ -108,57 +126,86 @@ Status FunctionalizeControlFlowForFunction( auto associated_functions = iter.second; for (auto& associated_function : associated_functions) { string name = associated_function.func_name(); - string canonicalized_name = Canonicalize(name, AttrSlice(&attrs)); + string canonicalized_name = + Canonicalize(name, AttrSlice(&associated_function.attrs())); auto iter = canonicalized_name_to_new_name->find(canonicalized_name); string new_name; + bool function_modified; if (iter != canonicalized_name_to_new_name->end()) { - // If we already functionalized this function, skip functionalization - // but still rewrite the node. - new_name = iter->second; + // If we already processed this function, check if it was rewritten. If + // the function was rewritten, the entry will be non-empty. Otherwise + // the entry will be empty. 
+ function_modified = iter->second.has_value(); + if (function_modified) { + new_name = iter->second.value(); + } } else { - new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); + if (associated_function.type() == + AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) { + // For SymbolicGradient, `name` is always "SymbolicGradient", + // which is not very informative. Use node name instead. + new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_")); + } else { + new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); + } TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - name, new_name, attrs, fld, flr, canonicalized_name_to_new_name)); - (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + name, new_name, associated_function.attrs(), fld, flr, + canonicalized_name_to_new_name, &function_modified)); + if (function_modified) { + // If the function was rewritten, add an non-empty entry. So later we + // know we have processed this function, and it was rewritten into + // another function. + (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; + } else { + // If the function was not rewritten, add an empty entry. So later + // we know we have processed this function, and it does not need to be + // rewritten. + (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt; + } + } + if (function_modified) { + *modified = true; + + // Notice that if "n" is a function call, RewriteAssociatedFunction() + // will delete it and create a new node instead, making "n" an invalid + // pointer. That's fine because in that case, associated_functions will + // only have one member and the loop will only run once. + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + g, n, fld, associated_function, new_name)); } - // Notice that if "n" is a function call, RewriteAssociatedFunction() will - // delete it and create a new node instead, making "n" an invalid pointer. - // That's fine because in that case, associated_functions will only have - // one member and the loop will only run once. - TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - body->graph, n, fld, associated_function, new_name)); } } - // Functionalize the function body. - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *body->graph, fld); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(body->graph, fld)); - if (VLOG_IS_ON(4)) { - dump_graph::DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_fdef_", func_name), - *body->graph, fld); + if (has_switch_or_merge) { + *modified = true; + + // Functionalize the function body. + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *g, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld)); + if (VLOG_IS_ON(4)) { + dump_graph::DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, + fld); + } } - FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*body->graph, new_func_name, &functionalized_fdef)); - - // Copy signature and ret from original FunctionDef. - *functionalized_fdef.mutable_signature() = fdef.signature(); - *functionalized_fdef.mutable_ret() = fdef.ret(); - functionalized_fdef.mutable_signature()->set_name(new_func_name); - - // Add rewritten FunctionDef into library. 
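The canonicalized_name_to_new_name bookkeeping described above can be summarized with a small memoization sketch (illustrative Python with hypothetical names; the C++ pass stores absl::optional<string> and also rewrites the calling node when an entry is non-empty):

# None records "processed but not rewritten"; a string records the name of
# the rewritten function, so later callers can reuse the result directly.
canonicalized_name_to_new_name = {}

def functionalize(canonical_name, rewrite):
    if canonical_name in canonicalized_name_to_new_name:
        return canonicalized_name_to_new_name[canonical_name]
    new_name = rewrite(canonical_name)  # new function name, or None
    canonicalized_name_to_new_name[canonical_name] = new_name
    return new_name

print(functionalize("f(T=float)", lambda name: None))        # not rewritten
print(functionalize("g(T=float)", lambda name: "g_f15n_0"))  # rewritten
print(functionalize("g(T=float)", lambda name: "ignored"))   # cached result
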
- if (func_name == new_func_name) { - VLOG(2) << "Replacing function " << func_name; + + if (*modified) { + // Add rewritten FunctionDef into library. + FunctionDef functionalized_fdef; TF_RETURN_IF_ERROR( - fld->ReplaceFunction(new_func_name, functionalized_fdef)); - } else { - VLOG(2) << "Adding function " << new_func_name; - TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } } return ret_status; @@ -184,7 +231,7 @@ Status FunctionalizeControlFlowPass::Run( {"TPUCompile", "function"}, {"XlaLaunch", "function"}, }; - std::map<string, string> canonicalized_name_to_new_name; + std::map<string, absl::optional<string>> canonicalized_name_to_new_name; for (Node* n : graph->nodes()) { auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); if (it == kNodeTypeToFunctionAttrMapping->end()) { @@ -199,12 +246,15 @@ Status FunctionalizeControlFlowPass::Run( << ". Corresponding function: " << func.name(); string new_func_name = options.flib_def->UniqueFunctionName( absl::StrCat(func.name(), "_f15n_")); + bool modified; TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name)); - n->ClearAttr(func_attr); - func.set_name(new_func_name); - n->AddAttr(func_attr, func); + &canonicalized_name_to_new_name, &modified)); + if (modified) { + n->ClearAttr(func_attr); + func.set_name(new_func_name); + n->AddAttr(func_attr, func); + } } } diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 46794f7b50..224e5ea123 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -62,6 +62,7 @@ tf_kernel_library( "one_hot_op.cc", "pack_op.cc", "pad_op.cc", + "permute_op.cc", "pooling_ops.cc", "qr_op.cc", "quantize_and_dequantize_op.cc", @@ -94,6 +95,7 @@ tf_kernel_library( "stateless_random_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", + "tensor_list_ops.cc", "tile_ops.cc", "topk_op.cc", "training_ops.cc", @@ -113,11 +115,13 @@ tf_kernel_library( "shape_util.h", ], deps = [ + ":conv_op_helpers", ":if_op", ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", + "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/lib:cholesky", "//tensorflow/compiler/tf2xla/lib:qr", "//tensorflow/compiler/tf2xla/lib:random", @@ -156,6 +160,7 @@ tf_kernel_library( "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:conv_ops", "//tensorflow/core/kernels:cwise_op", + "//tensorflow/core/kernels:list_kernels", "//tensorflow/core/kernels:no_op", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/kernels:pooling_ops", @@ -172,6 +177,27 @@ tf_kernel_library( ], ) +cc_library( + name = "conv_op_helpers", + srcs = ["conv_op_helpers.cc"], + hdrs = ["conv_op_helpers.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:xla_builder", + 
"//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:numeric", + "//tensorflow/core:framework", + "//tensorflow/core/kernels:bounds_check", + "//tensorflow/core/kernels:conv_ops", + "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/types:span", + ], +) + tf_kernel_library( name = "while_op", srcs = ["while_op.cc"], diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index b3ad0aea84..a267c0c72f 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -34,12 +34,6 @@ class FusedBatchNormOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); - OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || - data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN), - errors::InvalidArgument( - "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { @@ -110,12 +104,6 @@ class FusedBatchNormGradOp : public XlaOpKernel { OP_REQUIRES( ctx, FormatFromString(data_format_str, &data_format_), errors::InvalidArgument("Invalid data format: ", data_format_str)); - OP_REQUIRES(ctx, - (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW || - data_format_ == FORMAT_HWNC || data_format_ == FORMAT_HWCN), - errors::InvalidArgument( - "Unsupported data format ", ToString(data_format_), - "; supported formats are NHWC, NCHW, HWNC and HWCN")); } void Compile(XlaOpKernelContext* ctx) override { diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 0d9a768a6f..47e517a657 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -55,6 +56,24 @@ XLA_MAKE_BINARY(Div, xla::Div(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Atan2, xla::Atan2(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); +// Implementation of DivNoNan. Pseudo-code: +// if (y == 0) { +// return 0 +// } else { +// return x / y; +// } +static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto y_equals_0 = xla::Eq(y, zero); + auto zeros = xla::ZerosLike(x); + auto result = xla::Select(y_equals_0, zeros, xla::Div(x, y)); + return result; +} +XLA_MAKE_BINARY(DivNoNan, + DivNoNanImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + // Implementation of FloorDiv. 
Pseudo-code: // if ((x < 0) != (y < 0)) { // T abs_x = std::abs(x); @@ -65,7 +84,7 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions)); // } static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); if (DataTypeIsUnsigned(dtype)) { return xla::Div(x, y); } @@ -84,12 +103,30 @@ static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, XLA_MAKE_BINARY(FloorDiv, FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper)); +static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto is_zero = xla::Eq(x, zero); + return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y))); +} +XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + +static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, + xla::XlaOp y, const BCast& broadcast_helper) { + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); + auto zero = XlaHelpers::Zero(b, dtype); + auto is_zero = xla::Eq(x, zero); + return xla::Select(is_zero, zero, xla::Div(x, y)); +} +XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper)); + // Implementation of FloorMod. Pseudo-code: // T trunc_mod = std::fmod(x, y); // return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y); static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x, xla::XlaOp y, const BCast& broadcast_helper) { - std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper); + std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper); auto zero = XlaHelpers::Zero(b, dtype); auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero)); auto trunc_mod = xla::Rem(x, y); diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index 4bd7c74dca..9bb11fb67e 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -13,16 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
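NumPy renderings of the DivNoNan, Xlogy, Xdivy and FloorMod lowerings added or adjusted above, handy for sanity-checking expected values (the input arrays are illustrative):

import numpy as np

# DivNoNan: where the divisor is zero the result is zero, not inf or nan.
x = np.array([1.0, 2.0, 3.0])
y = np.array([2.0, 0.0, 4.0])
div_no_nan = np.divide(x, y, out=np.zeros_like(x), where=(y != 0))
print(div_no_nan)  # [0.5  0.   0.75]

# Xlogy and Xdivy: return 0 wherever x == 0, regardless of y.
x2 = np.array([0.0, 2.0, 3.0])
y2 = np.array([5.0, 0.5, 2.0])
xlogy = np.where(x2 == 0, 0.0, x2 * np.log(y2))
xdivy = np.where(x2 == 0, 0.0, x2 / y2)

# FloorMod: keep the truncated remainder when the signs agree, otherwise
# shift it by y so the result takes the sign of the divisor.
a = np.array([-7, 7, -7, 7])
b = np.array([3, 3, -3, -3])
trunc_mod = np.fmod(a, b)
floor_mod = np.where((a < 0) == (b < 0), trunc_mod, np.fmod(trunc_mod + b, b))
print(floor_mod)  # [ 2  1 -1 -2] -- same as Python's a % b
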
==============================================================================*/ -#include "absl/algorithm/container.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/bcast.h" namespace tensorflow { namespace { @@ -37,60 +32,9 @@ class BroadcastToOp : public XlaOpKernel { TensorShape output_shape; OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); - OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(), - errors::InvalidArgument( - "Input rank (", input_shape.dims(), - ") must be less than or equal to the output rank (", - output_shape.dims(), ")")); - - auto input_dims = input_shape.dim_sizes(); - auto output_dims = output_shape.dim_sizes(); - - // Broadcasting is done right-to-left on right-aligned dimensions; reverse - // the two vectors so elements to be broadcast are aligned. - absl::c_reverse(input_dims); - absl::c_reverse(output_dims); - - std::vector<int64> broadcast_dims; - std::vector<int64> broadcast_shape; - for (int i = 0; i < output_shape.dims(); ++i) { - if (i < input_shape.dims()) { - OP_REQUIRES( - context, - (output_dims[i] == 0 && input_dims[i] == 0) || - (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0), - errors::InvalidArgument("invalid shape to broadcast from ", - input_shape.DebugString(), " to ", - output_shape.DebugString())); - - broadcast_dims.push_back(broadcast_shape.size()); - if (output_dims[i] == input_dims[i] || input_dims[i] == 1) { - broadcast_shape.push_back(output_dims[i]); - } - if (output_dims[i] != input_dims[i]) { - // Add dimensions [I, O/I], which we will later flatten to just - // [O]. We must do this in two phases since XLA broadcasting does not - // support tiling. - broadcast_shape.push_back(input_dims[i]); - broadcast_shape.push_back(output_dims[i] / input_dims[i]); - } - } else { - broadcast_shape.push_back(output_dims[i]); - } - } - absl::c_reverse(broadcast_dims); - int broadcast_shape_size = broadcast_shape.size(); - for (int64& broadcast_dim : broadcast_dims) { - broadcast_dim = broadcast_shape_size - broadcast_dim - 1; - } - absl::c_reverse(broadcast_shape); - xla::XlaOp output = xla::Reshape( - xla::BroadcastInDim(context->Input(0), - xla::ShapeUtil::MakeShape( - context->input_xla_type(0), broadcast_shape), - broadcast_dims), - output_shape.dim_sizes()); - context->SetOutput(0, output); + auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes()); + OP_REQUIRES_OK(context, output.status()); + context->SetOutput(0, output.ValueOrDie()); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index da8cf3fc6f..2628ef8e24 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
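The BroadcastTo kernel now delegates to the shared lib/broadcast helper; a NumPy sketch of the basic semantics it implements. NumPy's broadcast_to covers the common case where broadcast dimensions have size 1, while the removed kernel code (and the new helper) additionally handles tiling a dimension by an integer factor:

import numpy as np

# Right-aligned broadcasting: trailing dimensions are matched and size-1
# dimensions are expanded to the requested output shape.
x = np.arange(3, dtype=np.float32).reshape(3, 1)
y = np.broadcast_to(x, (2, 3, 4))
print(y.shape)  # (2, 3, 4)
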
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/types.pb.h" namespace tensorflow { namespace { @@ -76,6 +77,17 @@ class ConstOp : public XlaOpKernel { return; } break; + case DT_COMPLEX64: + if (proto_.scomplex_val_size() == 2) { + ctx->SetOutput( + 0, + xla::Broadcast(xla::ConstantR0<xla::complex64>( + b, xla::complex64(proto_.scomplex_val(0), + proto_.scomplex_val(1))), + shape.dim_sizes())); + return; + } + break; case DT_INT32: if (proto_.int_val_size() == 1) { ctx->SetOutput( diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc new file mode 100644 index 0000000000..c9a1be4940 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -0,0 +1,509 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for 2D convolution. + +#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace { + +// Returns the expanded size of a filter used for depthwise convolution. +// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N]. +xla::Shape ExpandedFilterShapeForDepthwiseConvolution(const xla::Shape& shape) { + int num_dims = shape.dimensions_size(); + CHECK_GE(num_dims, 2); // Crash OK + xla::Shape expanded_shape = shape; + expanded_shape.set_dimensions( + num_dims - 1, + shape.dimensions(num_dims - 2) * shape.dimensions(num_dims - 1)); + return expanded_shape; +} + +// Create a mask for depthwise convolution that will make a normal convolution +// produce the same results as a depthwise convolution. 
For a [2, 2, 3, 2] +// depthwise filter this returns a [2, 2, 3, 6] tensor +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// The first step is to create a one tensor, A, that is [3] +// 0 1 2 +// +// and another tensor, B, that is [3 * 2] +// 0 1 2 3 4 5 +// +// and divide B it by 2 to get +// 0 0 1 1 2 2 +// +// then we broadcast the B to [2, 2, 3, 3 * 2] +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and broadcasted B in dimension 2 amd return the result at +// the beginning of the comment. +xla::XlaOp CreateExpandedFilterMask(const xla::Shape& filter_shape, + xla::XlaBuilder* builder) { + xla::Shape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + int64 depthwise_multiplier = + filter_shape.dimensions(filter_shape.dimensions_size() - 1); + int64 input_feature = + filter_shape.dimensions(filter_shape.dimensions_size() - 2); + + // Create a M sized linspace and an M*N sized linspace that will be + // broadcasted into perpendicular dimensions and compared. + xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); + xla::XlaOp expanded_feature_iota = + xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); + + // Divide the M*N sized linspace by the depthwise_multiplier to create + // [0 0 1 1 2 2] in the example in the function comment. + expanded_feature_iota = + xla::Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); + + // Broadcast the N*M linspace to [H, W, ..., M, M*N]. + std::vector<int64> expanded_feature_broadcast_dims( + expanded_filter_shape.dimensions().begin(), + expanded_filter_shape.dimensions().end()); + expanded_feature_broadcast_dims.pop_back(); + auto broadcasted_expanded_feature_iota = + xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); + + // Compare the broadcasted linspace to the input feature linspace in the + // input feature dimension to create a diagonal predicate. + return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dimensions_size() - 2}); +} + +// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to +// build a depthwise convolution. +xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape, + const xla::XlaOp& filter) { + int64 input_feature_dim = filter_shape.dimensions_size() - 2; + int64 output_feature_dim = filter_shape.dimensions_size() - 1; + int64 depthwise_multiplier = filter_shape.dimensions(output_feature_dim); + int64 input_feature = filter_shape.dimensions(input_feature_dim); + + // Create a [H, W, ..., 1, N*M] reshape of the filter. + xla::Shape implicit_broadcast_filter_shape = filter_shape; + implicit_broadcast_filter_shape.set_dimensions(input_feature_dim, 1); + implicit_broadcast_filter_shape.set_dimensions( + output_feature_dim, depthwise_multiplier * input_feature); + return xla::Reshape( + filter, xla::AsInt64Slice(implicit_broadcast_filter_shape.dimensions())); +} + +// Reduces the results of the convolution with an expanded filter to the +// non-expanded filter. 
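The mask construction described in the comment reduces to an iota comparison; a NumPy sketch for the M=3, N=2 case (the real CreateExpandedFilterMask additionally broadcasts this pattern over the filter's spatial dimensions):

import numpy as np

# Expanded-filter mask for in_depth M=3 and depth multiplier N=2: compare an
# M-sized iota against an (M*N)-sized iota divided by N; equality gives the
# block pattern shown in the comment above.
m, n = 3, 2
input_feature_iota = np.arange(m)              # [0 1 2]
expanded_feature_iota = np.arange(m * n) // n  # [0 0 1 1 2 2]
mask = input_feature_iota[:, None] == expanded_feature_iota[None, :]
print(mask.astype(int))
# [[1 1 0 0 0 0]
#  [0 0 1 1 0 0]
#  [0 0 0 0 1 1]]
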
+xla::XlaOp ContractFilterForDepthwiseBackprop(const xla::Shape& filter_shape, + const xla::XlaOp& filter_backprop, + xla::XlaBuilder* builder) { + auto masked_expanded_filter = + xla::Select(CreateExpandedFilterMask(filter_shape, builder), + filter_backprop, xla::ZerosLike(filter_backprop)); + + auto elem_type = filter_shape.element_type(); + return xla::Reshape( + // This reduce does not need inputs to be converted with + // XlaHelpers::SumAccumulationType() since the select above guarantees + // that only one element is non zero, so there cannot be accumulated + // precision error. + xla::Reduce(masked_expanded_filter, xla::Zero(builder, elem_type), + CreateScalarAddComputation(elem_type, builder), + {filter_shape.dimensions_size() - 2}), + xla::AsInt64Slice(filter_shape.dimensions())); +} + +// Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA +// convolutions (as currently implemented). +Status CheckConvAttrs(const ConvOpAttrs& attrs) { + const int num_dims = attrs.num_spatial_dims + 2; + if (attrs.strides.size() != num_dims) { + return errors::InvalidArgument("Sliding window strides field must specify ", + num_dims, " dimensions"); + } + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + if (attrs.strides[batch_dim] != 1 || attrs.strides[feature_dim] != 1) { + return errors::Unimplemented( + "Current implementation does not yet support strides in the batch and " + "depth dimensions."); + } + if (attrs.dilations.size() != num_dims) { + return errors::InvalidArgument("Dilations field must specify ", num_dims, + " dimensions"); + } + if (attrs.dilations[batch_dim] != 1 || attrs.dilations[feature_dim] != 1) { + return errors::Unimplemented( + "Current implementation does not support dilations in the batch and " + "depth dimensions."); + } + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + if (attrs.dilations[input_dim] < 1) { + return errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + attrs.dilations[input_dim]); + } + } + return Status::OK(); +} + +// Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes +// to TensorShapes. 
+Status ConvBackpropComputeDimensionsV2XlaShapes( + StringPiece label, int num_spatial_dims, const xla::Shape& input_shape, + const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape, + absl::Span<const int32> dilations, const std::vector<int32>& strides, + Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) { + TensorShape input_tensor_shape, filter_tensor_shape, + out_backprop_tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(filter_shape, &filter_tensor_shape)); + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape)); + return ConvBackpropComputeDimensionsV2( + label, num_spatial_dims, input_tensor_shape, filter_tensor_shape, + out_backprop_tensor_shape, dilations, strides, padding, data_format, + dims); +} + +} // anonymous namespace + +xla::StatusOr<ConvOpAttrs> ConvOpAttrs::Create(int num_spatial_dims, + bool depthwise, + OpKernelConstruction* ctx) { + ConvOpAttrs attrs; + attrs.num_spatial_dims = num_spatial_dims; + attrs.depthwise = depthwise; + TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations)); + TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides)); + TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding)); + + string data_format; + TF_RETURN_IF_ERROR(ctx->GetAttr("data_format", &data_format)); + if (!FormatFromString(data_format, &attrs.data_format)) { + return errors::InvalidArgument("Invalid data format: ", data_format); + } + + return attrs; +} + +xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece /*type_string*/, + xla::XlaOp conv_input, + xla::XlaOp filter, + const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + auto* builder = conv_input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(conv_input)); + // Filter has the form [filter_rows, filter_cols, ..., in_depth, out_depth] + TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter)); + + // For 2D convolution, there should be 4 dimensions. + int num_dims = attrs.num_spatial_dims + 2; + if (input_shape.dimensions_size() != num_dims) { + return errors::InvalidArgument("input must be ", num_dims, "-dimensional", + input_shape.DebugString()); + } + if (filter_shape.dimensions_size() != num_dims) { + return errors::InvalidArgument( + "filter must be ", num_dims, + "-dimensional: ", filter_shape.DebugString()); + } + + // The last two dimensions of the filter are the input and output shapes. + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + int64 in_depth = filter_shape.dimensions(attrs.num_spatial_dims); + // The 'C' dimension for input is in_depth. It must be the same as + // the filter's in_depth. 
+ if (in_depth != input_shape.dimensions(feature_dim)) { + return errors::InvalidArgument( + "input and filter must have the same depth: ", in_depth, " vs ", + input_shape.dimensions(feature_dim)); + } + + if (attrs.depthwise) { + filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter); + } + + xla::ConvolutionDimensionNumbers dims; + std::vector<int64> window_strides(attrs.num_spatial_dims); + std::vector<int64> lhs_dilation(attrs.num_spatial_dims, 1); + std::vector<int64> rhs_dilation(attrs.num_spatial_dims); + std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims); + + dims.set_input_batch_dimension(batch_dim); + dims.set_output_batch_dimension(batch_dim); + dims.set_input_feature_dimension(feature_dim); + dims.set_output_feature_dimension(feature_dim); + dims.set_kernel_input_feature_dimension(attrs.num_spatial_dims); + dims.set_kernel_output_feature_dimension(attrs.num_spatial_dims + 1); + + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + const int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dims.add_input_spatial_dimensions(dim); + dims.add_kernel_spatial_dimensions(i); + dims.add_output_spatial_dimensions(dim); + window_strides[i] = attrs.strides.at(dim); + rhs_dilation[i] = attrs.dilations.at(dim); + + int64 unused_output_size; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2( + input_shape.dimensions(dim), filter_shape.dimensions(i), + rhs_dilation[i], window_strides[i], attrs.padding, &unused_output_size, + &padding[i].first, &padding[i].second)); + } + + return xla::ConvGeneralDilated( + conv_input, filter, window_strides, padding, lhs_dilation, rhs_dilation, + dims, /*feature_group_count=*/attrs.depthwise ? in_depth : 1); +} + +xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp( + StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, + xla::XlaOp out_backprop, const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + int num_dims = attrs.num_spatial_dims + 2; + int batch_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + auto* builder = filter.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape filter_shape, builder->GetShape(filter)); + TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, + builder->GetShape(out_backprop)); + + xla::Shape expanded_filter_shape = + attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) + : filter_shape; + // Reuse dimension computation logic from conv_grad_ops.cc. + ConvBackpropDimensions dims; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape, + out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding, + attrs.data_format, &dims)); + + // The input gradients are computed by a convolution of the output + // gradients and the filter, with some appropriate padding. See the + // comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(batch_dim); + dnums.set_output_batch_dimension(batch_dim); + dnums.set_input_feature_dimension(feature_dim); + dnums.set_output_feature_dimension(feature_dim); + + // TF filter shape is [ H, W, ..., inC, outC ] + // Transpose the input and output features for computing the gradient. 
+ dnums.set_kernel_input_feature_dimension(attrs.num_spatial_dims + 1); + dnums.set_kernel_output_feature_dimension(attrs.num_spatial_dims); + + std::vector<int64> kernel_spatial_dims(attrs.num_spatial_dims); + std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims); + std::vector<int64> lhs_dilation(attrs.num_spatial_dims); + std::vector<int64> rhs_dilation(attrs.num_spatial_dims); + std::vector<int64> ones(attrs.num_spatial_dims, 1); + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dnums.add_input_spatial_dimensions(dim); + dnums.add_kernel_spatial_dimensions(i); + dnums.add_output_spatial_dimensions(dim); + + kernel_spatial_dims[i] = i; + padding[i] = {dims.spatial_dims[i].pad_before, + dims.spatial_dims[i].pad_after}; + lhs_dilation[i] = dims.spatial_dims[i].stride; + rhs_dilation[i] = attrs.dilations[dim]; + } + + // Mirror the filter in the spatial dimensions. + xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); + + // activation gradients + // = gradients (with padding and dilation) <conv> mirrored_weights + return xla::ConvGeneralDilated( + out_backprop, mirrored_weights, /*window_strides=*/ones, padding, + lhs_dilation, rhs_dilation, dnums, + /*feature_group_count=*/ + attrs.depthwise ? out_backprop_shape.dimensions(feature_dim) / + filter_shape.dimensions(attrs.num_spatial_dims + 1) + : 1); +} + +xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp( + StringPiece type_string, xla::XlaOp activations, + const xla::Shape& filter_shape, xla::XlaOp gradients, + const ConvOpAttrs& attrs) { + TF_RETURN_IF_ERROR(CheckConvAttrs(attrs)); + + auto* builder = activations.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape activations_shape, + builder->GetShape(activations)); + TF_ASSIGN_OR_RETURN(xla::Shape out_backprop_shape, + builder->GetShape(gradients)); + const xla::Shape expanded_filter_shape = + attrs.depthwise ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) + : filter_shape; + + // Reuse dimension computation logic from conv_grad_ops.cc. + ConvBackpropDimensions dims; + TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes( + type_string, attrs.num_spatial_dims, activations_shape, + expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides, + attrs.padding, attrs.data_format, &dims)); + + // The filter gradients are computed by a convolution of the input + // activations and the output gradients, with some appropriate padding. + // See the comment at the top of conv_grad_ops.h for details. + + xla::ConvolutionDimensionNumbers dnums; + + // The activations (inputs) form the LHS of the convolution. + // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] + // For the gradient computation, we flip the roles of the batch and + // feature dimensions. + // Each spatial entry has size in_depth * batch + + // The last two dimensions of the filter are the input and output shapes. + int num_dims = attrs.num_spatial_dims + 2; + int n_dim = GetTensorBatchDimIndex(num_dims, attrs.data_format); + int c_dim = GetTensorFeatureDimIndex(num_dims, attrs.data_format); + + // Swap n_dim and c_dim in the activations. + dnums.set_input_batch_dimension(c_dim); + dnums.set_input_feature_dimension(n_dim); + + // The gradients become the RHS of the convolution. + // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] + // where the batch becomes the input feature for the convolution. 
+ dnums.set_kernel_input_feature_dimension(n_dim); + dnums.set_kernel_output_feature_dimension(c_dim); + + std::vector<std::pair<int64, int64>> padding(attrs.num_spatial_dims); + std::vector<int64> rhs_dilation(attrs.num_spatial_dims); + std::vector<int64> window_strides(attrs.num_spatial_dims); + std::vector<int64> ones(attrs.num_spatial_dims, 1); + + // Tensorflow filter shape is [ H, W, ..., inC, outC ]. + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + dnums.add_output_spatial_dimensions(i); + } + dnums.set_output_batch_dimension(attrs.num_spatial_dims); + dnums.set_output_feature_dimension(attrs.num_spatial_dims + 1); + + for (int i = 0; i < attrs.num_spatial_dims; ++i) { + int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i); + dnums.add_input_spatial_dimensions(dim); + dnums.add_kernel_spatial_dimensions(dim); + + // We will also need to pad the input with zeros such that after the + // convolution, we get the right size for the filter. + // The padded_in_rows should be such that when we convolve this with the + // expanded_out_rows as a filter, we should get filter_rows back. + // + const int64 padded_in_size = + dims.spatial_dims[i].expanded_output_size + + (dims.spatial_dims[i].filter_size - 1) * attrs.dilations[dim]; + + // However it can be smaller than input_rows: in this + // case it means some of the inputs are not used. + // + // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: + // + // INPUT = [ A B C ] + // + // FILTER = [ x y ] + // + // and the output will only have one column: a = A * x + B * y + // + // and input "C" is not used at all. + // + // We apply negative padding in this case. + const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; + + // + For the VALID padding, we don't pad anything on the top/left side + // and pad the bottom/right side with the remaining space. + // + For the SAME padding, we pad top/left side the same as bottom/right + // side. + // + // In addition, if the padded input size is smaller than the input size, + // we need to ignore some training elements of the input. We do this by + // applying negative padding on the right/bottom. + const int64 pad_before = + attrs.padding == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0; + + padding[i] = {pad_before, pad_total - pad_before}; + rhs_dilation[i] = dims.spatial_dims[i].stride; + window_strides[i] = attrs.dilations[dim]; + } + + // Besides padding the input, we will also expand output_rows to + // expanded_out_rows = (output_rows - 1) * stride + 1 + // with zeros in between: + // + // a . . . b . . . c . . . d . . . e + // + // This is done by specifying the window dilation factors in the + // convolution HLO below. + auto filter_backprop = + xla::ConvGeneralDilated(activations, gradients, window_strides, padding, + /*lhs_dilation=*/ones, rhs_dilation, dnums); + + if (attrs.depthwise) { + filter_backprop = ContractFilterForDepthwiseBackprop( + filter_shape, filter_backprop, activations.builder()); + } + + return filter_backprop; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h new file mode 100644 index 0000000000..6e1b70a478 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ + +#include <vector> + +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +// This header exposes utilities for translating TensorFlow convolution ops into +// XLA ops. +// +// conv_ops.cc contains lowerings for many of these TF convolution ops (e.g. +// Conv2D, Conv3DBackpropFilterV2), but you might want to use the utilities in +// this header to implement a new and exciting convolution op, for example a +// fused TensorFlow op that contains a convolution and other things. + +namespace tensorflow { + +// ConvOpAttrs contains all of the metadata necessary to specify a TF or XLA +// convolution. +struct ConvOpAttrs { + // Constructs a ConvOpAttrs, reading most of the attributes from `ctx`. + static xla::StatusOr<ConvOpAttrs> Create(int num_spatial_dims, bool depthwise, + OpKernelConstruction* ctx); + + bool depthwise; + int num_spatial_dims; + std::vector<int32> dilations; + std::vector<int32> strides; + Padding padding; + TensorFormat data_format; +}; + +// Creates a new XLA forward or backward convolution with the given inputs and +// attributes. +xla::StatusOr<xla::XlaOp> MakeXlaForwardConvOp(StringPiece type_string, + xla::XlaOp conv_input, + xla::XlaOp filter, + const ConvOpAttrs& attrs); +xla::StatusOr<xla::XlaOp> MakeXlaBackpropInputConvOp( + StringPiece type_string, const xla::Shape& input_shape, xla::XlaOp filter, + xla::XlaOp out_backprop, const ConvOpAttrs& attrs); +xla::StatusOr<xla::XlaOp> MakeXlaBackpropFilterConvOp( + StringPiece type_string, xla::XlaOp activations, + const xla::Shape& filter_shape, xla::XlaOp gradients, + const ConvOpAttrs& attrs); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CONV_OP_HELPERS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 674720e22f..cd7c820be0 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -15,12 +15,17 @@ limitations under the License. // XLA-specific Ops for 2D convolution. 
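// Illustrative sketch (annotation, not part of this change): a hypothetical
// kernel reusing the helpers declared in conv_op_helpers.h above. The class
// name MyFusedConvOp is made up; the calls mirror the rewritten ConvOp below,
// and MakeXlaBackpropInputConvOp / MakeXlaBackpropFilterConvOp follow the
// same pattern.
class MyFusedConvOp : public XlaOpKernel {
 public:
  explicit MyFusedConvOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    // Read strides, dilations, padding and data_format once at construction.
    xla::StatusOr<ConvOpAttrs> attrs =
        ConvOpAttrs::Create(/*num_spatial_dims=*/2, /*depthwise=*/false, ctx);
    OP_REQUIRES_OK(ctx, attrs.status());
    attrs_ = attrs.ValueOrDie();
  }

  void Compile(XlaOpKernelContext* ctx) override {
    // Emit the convolution; further XLA ops could be composed on `conv` here
    // to build a fused kernel.
    xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp(
        ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_);
    OP_REQUIRES_OK(ctx, conv.status());
    ctx->SetOutput(0, conv.ValueOrDie());
  }

 private:
  ConvOpAttrs attrs_;
};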
+#include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -33,250 +38,28 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { - namespace { -// Returns the expanded size of a filter used for depthwise convolution. -// If `shape` is [H, W, ..., M, N] returns [H, W, ..., M, M*N]. -TensorShape ExpandedFilterShapeForDepthwiseConvolution( - const TensorShape& shape) { - int num_dims = shape.dims(); - CHECK_GE(num_dims, 2); - TensorShape expanded_shape = shape; - expanded_shape.set_dim(num_dims - 1, shape.dim_size(num_dims - 2) * - shape.dim_size(num_dims - 1)); - return expanded_shape; -} - -// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. -xla::XlaOp CreateExpandedZero(const TensorShape& filter_shape, DataType dtype, - xla::XlaBuilder* builder) { - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return xla::Broadcast(XlaHelpers::Zero(builder, dtype), - expanded_filter_shape.dim_sizes()); -} - -// Create a mask for depthwise convolution that will make a normal convolution -// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] -// depthwise filter this returns a [2, 2, 3, 6] tensor -// 1 1 0 0 0 0 1 1 0 0 0 0 -// 0 0 1 1 0 0 0 0 1 1 0 0 -// 0 0 0 0 1 1 0 0 0 0 1 1 -// -// 1 1 0 0 0 0 1 1 0 0 0 0 -// 0 0 1 1 0 0 0 0 1 1 0 0 -// 0 0 0 0 1 1 0 0 0 0 1 1 -// -// The first step is to create a one tensor, A, that is [3] -// 0 1 2 -// -// and another tensor, B, that is [3 * 2] -// 0 1 2 3 4 5 -// -// and divide B it by 2 to get -// 0 0 1 1 2 2 -// -// then we broadcast the B to [2, 2, 3, 3 * 2] -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// 0 0 1 1 2 2 0 0 1 1 2 2 -// -// Finally compare A and broadcasted B in dimension 2 amd return the result at -// the beginning of the comment. -xla::XlaOp CreateExpandedFilterMask(const TensorShape& filter_shape, - xla::XlaBuilder* builder) { - TensorShape expanded_filter_shape = - ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); - int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); - - // Create a M sized linspace and an M*N sized linspace that will be - // broadcasted into perpendicular dimensions and compared. - xla::XlaOp input_feature_iota = xla::Iota(builder, xla::S32, input_feature); - xla::XlaOp expanded_feature_iota = - xla::Iota(builder, xla::S32, input_feature * depthwise_multiplier); - - // Divide the M*N sized linspace by the depthwise_multiplier to create - // [0 0 1 1 2 2] in the example in the function comment. 
- expanded_feature_iota = - xla::Div(expanded_feature_iota, - XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, - depthwise_multiplier)); - - // Broadcast the N*M linspace to [H, W, ..., M, M*N]. - auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); - expanded_feature_broadcast_dims.pop_back(); - auto broadcasted_expanded_feature_iota = - xla::Broadcast(expanded_feature_iota, expanded_feature_broadcast_dims); - - // Compare the broadcasted linspace to the input feature linspace in the - // input feature dimension to create a diagonal predicate. - return xla::Eq(broadcasted_expanded_feature_iota, input_feature_iota, - {expanded_filter_shape.dims() - 2}); -} - -// Reshapes a filter of shape [H, W, ..., M, N] to [H, W, ..., 1, M*N]. Used to -// build a depthwise convolution. -xla::XlaOp ReshapeFilterForDepthwiseConvolution(const TensorShape& filter_shape, - const xla::XlaOp& filter) { - int64 input_feature_dim = filter_shape.dims() - 2; - int64 output_feature_dim = filter_shape.dims() - 1; - int64 depthwise_multiplier = filter_shape.dim_size(output_feature_dim); - int64 input_feature = filter_shape.dim_size(input_feature_dim); - - // Create a [H, W, ..., 1, N*M] reshape of the filter. - TensorShape implicit_broadcast_filter_shape = filter_shape; - implicit_broadcast_filter_shape.set_dim(input_feature_dim, 1); - implicit_broadcast_filter_shape.set_dim(output_feature_dim, - depthwise_multiplier * input_feature); - return xla::Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); -} - -// Reduces the results of the convolution with an expanded filter to the -// non-expanded filter. -xla::XlaOp ContractFilterForDepthwiseBackprop(XlaOpKernelContext* ctx, - const TensorShape& filter_shape, - DataType dtype, - const xla::XlaOp& filter_backprop, - xla::XlaBuilder* builder) { - auto masked_expanded_filter = xla::Select( - CreateExpandedFilterMask(filter_shape, builder), filter_backprop, - CreateExpandedZero(filter_shape, dtype, builder)); - return xla::Reshape( - // This reduce does not need inputs to be converted with - // XlaHelpers::SumAccumulationType() since the ExpandedFilterMask with - // ExpandedZero guarantees that only one element is non zero, so there - // cannot be accumulated precision error. 
- xla::Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), - *ctx->GetOrCreateAdd(dtype), {filter_shape.dims() - 2}), - filter_shape.dim_sizes()); -} - class ConvOp : public XlaOpKernel { public: explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr<ConvOpAttrs> attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, strides_.size() == num_dims(), - errors::InvalidArgument("Sliding window strides field must " - "specify ", - num_dims(), " dimensions")); - int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - OP_REQUIRES( - ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - const TensorShape input_shape = ctx->InputShape(0); - // Input filter is of the following dimensions: - // [ filter_rows, filter_cols, ..., in_depth, out_depth] - const TensorShape filter_shape = ctx->InputShape(1); - - // For 2D convolution, there should be 4 dimensions. - OP_REQUIRES( - ctx, input_shape.dims() == num_dims(), - errors::InvalidArgument("input must be ", num_dims(), "-dimensional", - input_shape.DebugString())); - OP_REQUIRES( - ctx, filter_shape.dims() == num_dims(), - errors::InvalidArgument("filter must be ", num_dims(), - "-dimensional: ", filter_shape.DebugString())); - - // The last two dimension of the filter are the input and output shapes. - const int64 in_depth = filter_shape.dim_size(num_spatial_dims_); - - // The 'C' dimension for input is in_depth. It must be the same as - // the filter's in_depth. 
- OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim), - errors::InvalidArgument( - "input and filter must have the same depth: ", in_depth, - " vs ", input_shape.dim_size(feature_dim))); - - xla::XlaOp filter = ctx->Input(1); - if (depthwise_) { - filter = ReshapeFilterForDepthwiseConvolution(filter_shape, filter); - } - - xla::ConvolutionDimensionNumbers dims; - std::vector<int64> window_strides(num_spatial_dims_); - std::vector<int64> lhs_dilation(num_spatial_dims_, 1); - std::vector<int64> rhs_dilation(num_spatial_dims_); - std::vector<std::pair<int64, int64>> padding(num_spatial_dims_); - - dims.set_input_batch_dimension(batch_dim); - dims.set_output_batch_dimension(batch_dim); - dims.set_input_feature_dimension(feature_dim); - dims.set_output_feature_dimension(feature_dim); - dims.set_kernel_input_feature_dimension(num_spatial_dims_); - dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); - - for (int i = 0; i < num_spatial_dims_; ++i) { - const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dims.add_input_spatial_dimensions(dim); - dims.add_kernel_spatial_dimensions(i); - dims.add_output_spatial_dimensions(dim); - window_strides[i] = strides_.at(dim); - rhs_dilation[i] = dilations_.at(dim); - - int64 unused_output_size; - OP_REQUIRES_OK( - ctx, GetWindowedOutputSizeVerboseV2( - input_shape.dim_size(dim), filter_shape.dim_size(i), - rhs_dilation[i], window_strides[i], padding_, - &unused_output_size, &padding[i].first, &padding[i].second)); - } - - xla::XlaOp conv = xla::ConvGeneralDilated( - ctx->Input(0), filter, window_strides, padding, lhs_dilation, - rhs_dilation, dims, - /*feature_group_count=*/depthwise_ ? in_depth : 1); - ctx->SetOutput(0, conv); + xla::StatusOr<xla::XlaOp> conv = MakeXlaForwardConvOp( + ctx->op_kernel().type_string(), ctx->Input(0), ctx->Input(1), attrs_); + OP_REQUIRES_OK(ctx, conv.status()); + ctx->SetOutput(0, conv.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector<int32> dilations_; - std::vector<int32> strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvOp); @@ -308,124 +91,28 @@ class ConvBackpropInputOp : public XlaOpKernel { public: explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr<ConvOpAttrs> attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, strides_.size() == num_dims(), - errors::InvalidArgument("Sliding window strides field must " - "specify ", - num_dims(), " dimensions")); - int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - int feature_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - OP_REQUIRES( - ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, - 
errors::Unimplemented("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - TensorShape input_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); - - const TensorShape filter_shape = ctx->InputShape(1); - const TensorShape out_backprop_shape = ctx->InputShape(2); - - const TensorShape expanded_filter_shape = - depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) - : filter_shape; - // Reuse dimension computation logic from conv_grad_ops.cc. - ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, - ConvBackpropComputeDimensionsV2( - type_string(), num_spatial_dims_, input_shape, - expanded_filter_shape, out_backprop_shape, dilations_, - strides_, padding_, data_format_, &dims)); - - auto filter = ctx->Input(1); - auto out_backprop = ctx->Input(2); - - // The input gradients are computed by a convolution of the output - // gradients and the filter, with some appropriate padding. See the - // comment at the top of conv_grad_ops.h for details. - - xla::ConvolutionDimensionNumbers dnums; - dnums.set_input_batch_dimension(batch_dim); - dnums.set_output_batch_dimension(batch_dim); - dnums.set_input_feature_dimension(feature_dim); - dnums.set_output_feature_dimension(feature_dim); - - // TF filter shape is [ H, W, ..., inC, outC ] - // Transpose the input and output features for computing the gradient. - dnums.set_kernel_input_feature_dimension(num_spatial_dims_ + 1); - dnums.set_kernel_output_feature_dimension(num_spatial_dims_); - - std::vector<int64> kernel_spatial_dims(num_spatial_dims_); - std::vector<std::pair<int64, int64>> padding(num_spatial_dims_); - std::vector<int64> lhs_dilation(num_spatial_dims_); - std::vector<int64> rhs_dilation(num_spatial_dims_); - std::vector<int64> ones(num_spatial_dims_, 1); - for (int i = 0; i < num_spatial_dims_; ++i) { - int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dnums.add_input_spatial_dimensions(dim); - dnums.add_kernel_spatial_dimensions(i); - dnums.add_output_spatial_dimensions(dim); - - kernel_spatial_dims[i] = i; - padding[i] = {dims.spatial_dims[i].pad_before, - dims.spatial_dims[i].pad_after}; - lhs_dilation[i] = dims.spatial_dims[i].stride; - rhs_dilation[i] = dilations_[dim]; - } - - // Mirror the filter in the spatial dimensions. - xla::XlaOp mirrored_weights = xla::Rev(filter, kernel_spatial_dims); - - // activation gradients - // = gradients (with padding and dilation) <conv> mirrored_weights - xla::XlaOp in_backprop = xla::ConvGeneralDilated( - out_backprop, mirrored_weights, /*window_strides=*/ones, padding, - lhs_dilation, rhs_dilation, dnums, - /*feature_group_count=*/ - depthwise_ ? 
out_backprop_shape.dim_size(feature_dim) / - filter_shape.dim_size(num_spatial_dims_ + 1) - : 1); - - ctx->SetOutput(0, in_backprop); + TensorShape input_tensor_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_tensor_shape)); + xla::Shape input_shape = + TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape); + + xla::StatusOr<xla::XlaOp> in_backprop = + MakeXlaBackpropInputConvOp(ctx->op_kernel().type_string(), input_shape, + ctx->Input(1), ctx->Input(2), attrs_); + OP_REQUIRES_OK(ctx, in_backprop.status()); + ctx->SetOutput(0, in_backprop.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector<int32> dilations_; - std::vector<int32> strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropInputOp); @@ -462,172 +149,28 @@ class ConvBackpropFilterOp : public XlaOpKernel { public: explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims, bool depthwise) - : XlaOpKernel(ctx), - num_spatial_dims_(num_spatial_dims), - depthwise_(depthwise) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - string data_format; - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + : XlaOpKernel(ctx) { + xla::StatusOr<ConvOpAttrs> attrs = + ConvOpAttrs::Create(num_spatial_dims, depthwise, ctx); + OP_REQUIRES_OK(ctx, attrs.status()); + attrs_ = attrs.ValueOrDie(); } - int num_dims() const { return num_spatial_dims_ + 2; } - void Compile(XlaOpKernelContext* ctx) override { - const int n_dim = GetTensorBatchDimIndex(num_dims(), data_format_); - const int c_dim = GetTensorFeatureDimIndex(num_dims(), data_format_); - - OP_REQUIRES( - ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1), - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - - OP_REQUIRES(ctx, dilations_.size() == num_dims(), - errors::InvalidArgument("Dilations field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES( - ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1, - errors::Unimplemented("Current implementation does not support " - "dilations in the batch and depth dimensions.")); - for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - OP_REQUIRES(ctx, dilations_[input_dim] >= 1, - errors::Unimplemented("Dilation values must be positive; ", i, - "th spatial dimension had dilation ", - dilations_[input_dim])); - } - - const TensorShape activations_shape = ctx->InputShape(0); - TensorShape filter_shape; - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape)); - const TensorShape out_backprop_shape = ctx->InputShape(2); - - const TensorShape expanded_filter_shape = - depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape) - : filter_shape; - - // Reuse dimension computation logic from conv_grad_ops.cc. 
- ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, - ConvBackpropComputeDimensionsV2( - type_string(), num_spatial_dims_, activations_shape, - expanded_filter_shape, out_backprop_shape, dilations_, - strides_, padding_, data_format_, &dims)); - - xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp activations = ctx->Input(0); - xla::XlaOp gradients = ctx->Input(2); - - // The filter gradients are computed by a convolution of the input - // activations and the output gradients, with some appropriate padding. - // See the comment at the top of conv_grad_ops.h for details. - - xla::ConvolutionDimensionNumbers dnums; - - // The activations (inputs) form the LHS of the convolution. - // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] - // For the gradient computation, we flip the roles of the batch and - // feature dimensions. - // Each spatial entry has size in_depth * batch - - // Swap n_dim and c_dim in the activations. - dnums.set_input_batch_dimension(c_dim); - dnums.set_input_feature_dimension(n_dim); - - // The gradients become the RHS of the convolution. - // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] - // where the batch becomes the input feature for the convolution. - dnums.set_kernel_input_feature_dimension(n_dim); - dnums.set_kernel_output_feature_dimension(c_dim); - - std::vector<std::pair<int64, int64>> padding(num_spatial_dims_); - std::vector<int64> rhs_dilation(num_spatial_dims_); - std::vector<int64> window_strides(num_spatial_dims_); - std::vector<int64> ones(num_spatial_dims_, 1); - - // Tensorflow filter shape is [ H, W, ..., inC, outC ]. - for (int i = 0; i < num_spatial_dims_; ++i) { - dnums.add_output_spatial_dimensions(i); - } - dnums.set_output_batch_dimension(num_spatial_dims_); - dnums.set_output_feature_dimension(num_spatial_dims_ + 1); - - for (int i = 0; i < num_spatial_dims_; ++i) { - int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dnums.add_input_spatial_dimensions(dim); - dnums.add_kernel_spatial_dimensions(dim); - - // We will also need to pad the input with zeros such that after the - // convolution, we get the right size for the filter. - // The padded_in_rows should be such that when we convolve this with the - // expanded_out_rows as a filter, we should get filter_rows back. - // - const int64 padded_in_size = - dims.spatial_dims[i].expanded_output_size + - (dims.spatial_dims[i].filter_size - 1) * dilations_[dim]; - - // However it can be smaller than input_rows: in this - // case it means some of the inputs are not used. - // - // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: - // - // INPUT = [ A B C ] - // - // FILTER = [ x y ] - // - // and the output will only have one column: a = A * x + B * y - // - // and input "C" is not used at all. - // - // We apply negative padding in this case. - const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; - - // + For the VALID padding, we don't pad anything on the top/left side - // and pad the bottom/right side with the remaining space. - // + For the SAME padding, we pad top/left side the same as bottom/right - // side. - // - // In addition, if the padded input size is smaller than the input size, - // we need to ignore some training elements of the input. We do this by - // applying negative padding on the right/bottom. - const int64 pad_before = - padding_ == Padding::SAME ? 
std::max<int64>(pad_total / 2, 0) : 0; - - padding[i] = {pad_before, pad_total - pad_before}; - rhs_dilation[i] = dims.spatial_dims[i].stride; - window_strides[i] = dilations_[dim]; - } - - // Besides padding the input, we will also expand output_rows to - // expanded_out_rows = (output_rows - 1) * stride + 1 - // with zeros in between: - // - // a . . . b . . . c . . . d . . . e - // - // This is done by specifying the window dilation factors in the - // convolution HLO below. - auto filter_backprop = - xla::ConvGeneralDilated(activations, gradients, window_strides, padding, - /*lhs_dilation=*/ones, rhs_dilation, dnums); - - if (depthwise_) { - filter_backprop = ContractFilterForDepthwiseBackprop( - ctx, filter_shape, ctx->input_type(0), filter_backprop, b); - } - ctx->SetOutput(0, filter_backprop); + TensorShape filter_tensor_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_tensor_shape)); + xla::Shape filter_shape = + TensorShapeToXLAShape(ctx->input_xla_type(0), filter_tensor_shape); + + xla::StatusOr<xla::XlaOp> filter_backprop = MakeXlaBackpropFilterConvOp( + ctx->op_kernel().type_string(), ctx->Input(0), filter_shape, + ctx->Input(2), attrs_); + OP_REQUIRES_OK(ctx, filter_backprop.status()); + ctx->SetOutput(0, filter_backprop.ValueOrDie()); } protected: - const int num_spatial_dims_; - const bool depthwise_; - std::vector<int32> dilations_; - std::vector<int32> strides_; - Padding padding_; - TensorFormat data_format_ = FORMAT_NHWC; + ConvOpAttrs attrs_; private: TF_DISALLOW_COPY_AND_ASSIGN(ConvBackpropFilterOp); diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index ef1015552d..234f7b4a01 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -39,7 +40,8 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { // compute valid broadcast shapes, but rely below on XLA to // automatically perform the broadcast assuming its valid shapes are // a superset of TensorFlow's valid shapes. - BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape)); + BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape), + /*fewer_dims_optimization=*/false); if (!bcast.IsValid()) { ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ", lhs_shape.DebugString(), " vs. ", @@ -86,51 +88,18 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { } /* static */ std::pair<xla::XlaOp, xla::XlaOp> XlaBinaryOp::Broadcast( - xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, - const BCast& broadcast_helper) { - // Manually construct the broadcasting since MapN does not do - // automatic broadcasting. The bcast helper ensures that - // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and - // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have - // the same shape, so can be operated on by MapN. - - // First reshape the inputs, which should be a metadata-only - // operation since we are flattening the dimensions in order. 
- auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape()); - auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape()); - - // Next broadcast the necessary input dimensions. We rely on the - // XLA optimizer to be smart about the fact that we are asking - // it to broadcast size 1 on some of these dimensions, to avoid - // adding complexity to this code. - auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast()); - int lhs_size = broadcast_helper.x_bcast().size(); - auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast()); - int rhs_size = broadcast_helper.y_bcast().size(); - - // Now reshape them to the correct output shape. After the - // broadcast each side is twice as wide as it should be, since the - // broadcast dimensions were prepended to the shape. Reshape - // flattening each original dimension with the prepended broadcast - // dimension. E.g. if we started out with lhs_shaped with shape - // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have - // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21]. - std::vector<int64> lhs_reorder; - for (int i = 0; i < lhs_size; ++i) { - lhs_reorder.push_back(i); - lhs_reorder.push_back(i + lhs_size); + xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper) { + auto lhs_output = BroadcastTo(lhs, broadcast_helper.output_shape()); + if (!lhs_output.ok()) { + xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status()); + return {error, error}; } - auto lhs_output = - xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape()); - std::vector<int64> rhs_reorder; - for (int i = 0; i < rhs_size; ++i) { - rhs_reorder.push_back(i); - rhs_reorder.push_back(i + rhs_size); + auto rhs_output = BroadcastTo(rhs, broadcast_helper.output_shape()); + if (!rhs_output.ok()) { + xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status()); + return {error, error}; } - auto rhs_output = - xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape()); - - return {lhs_output, rhs_output}; + return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()}; } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index 6653944a91..516ead4bfe 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -67,8 +67,7 @@ class XlaBinaryOp : public XlaOpKernel { // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same // shape. 
static std::pair<xla::XlaOp, xla::XlaOp> Broadcast( - xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs, - const BCast& broadcast_helper); + xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper); }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index 33a73fe5fd..921b4340c0 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -355,6 +355,9 @@ class NonMaxSuppressionOp : public XlaOpKernel { OP_REQUIRES( context, output_size >= 0, errors::InvalidArgument("Need output_size >= 0, got ", output_size)); + OP_REQUIRES(context, output_size <= kint32max, + errors::InvalidArgument("Need output_size <= kint32Max, got ", + output_size)); xla::XlaOp score_thresh = context->Input("score_threshold"); xla::XlaOp iou_thresh = context->Input("iou_threshold"); @@ -439,12 +442,14 @@ class NonMaxSuppressionOp : public XlaOpKernel { xla::Broadcast(xla::ConstantR0<int32>(builder, 1), {num_boxes}), xla::Broadcast(xla::ConstantR0<int32>(builder, 0), {num_boxes})); - // num_valid is scalar. - xla::XlaOp num_valid = xla::Reduce( + // num_valid is scalar. Value should be bound by output_size. + xla::XlaOp num_valid_total = xla::Reduce( ones_included, /*init_value=*/xla::ConstantR0<int>(builder, 0), /*computation=*/CreateScalarAddComputation(xla::S32, builder), /*dimensions_to_reduce=*/{0}); + xla::XlaOp num_valid = + xla::Min(num_valid_total, xla::ConstantR0<int32>(builder, output_size)); xla::XlaOp output_tuple = TopK(scores_included, output_size); xla::XlaOp selected_indices = xla::GetTupleElement(output_tuple, 1); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index d9a0257b70..7b2bb4a7c5 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" @@ -132,14 +133,14 @@ int64 CalculateUpperPadding(int64 in_size, int64 out_size, int64 kernel_size, // If the 2D kernel would be very large, the 1D kernel can be applied once in // each dimension due to the symmetry of the kernel along all axis to reduce the // computational intensity. 
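// Worked example (annotation, not part of this change): Make1DKernel builds a
// symmetric triangle filter with 2 * n - 1 taps. For n = 3:
//
//   xla::XlaOp k = Make1DKernel(builder, 3);
//   // k is an R1<float> constant holding {1/3, 2/3, 1, 2/3, 1/3}.
//
// The separable 2D resize kernel assembled below is effectively the outer
// product of two such vectors, replicated across channels.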
-std::vector<float> Make1DKernel(int64 n) { +xla::XlaOp Make1DKernel(xla::XlaBuilder* builder, int64 n) { std::vector<float> kernel(n * 2 - 1); for (int64 i = 0; i < n; ++i) { float v = (i + 1.0f) / n; kernel[i] = v; kernel[n * 2 - 2 - i] = v; } - return kernel; + return xla::ConstantR1<float>(builder, kernel); } // Kernels with more than 16 spatial elements are considered intense and the @@ -149,41 +150,26 @@ const int64 kMax2DKernelSize = 16; xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, absl::Span<const int64> kernel_size, int64 channels) { - xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); + auto depthwise_kernel = xla::Broadcast( + xla::Zero(builder, xla::F32), + {(2 * kernel_size[0] - 1), (2 * kernel_size[1] - 1), channels, 1}); - auto diag = xla::ConvertElementType( - xla::Eq(xla::Broadcast(channels_iota, {2 * kernel_size[0] - 1, - 2 * kernel_size[1] - 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), - xla::PrimitiveType::F32); return xla::Mul( - xla::Mul(diag, - xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])), + xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[1]), /*broadcast_dimensions=*/{1}), - xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])), + Make1DKernel(builder, kernel_size[0]), /*broadcast_dimensions=*/{0}); } xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, absl::Span<const int64> kernel_size, int64 channels, int64 dim) { - xla::XlaOp channels_iota = xla::Iota(builder, xla::S32, channels); - - auto diag = xla::ConvertElementType( - xla::Eq( - xla::Broadcast(channels_iota, - {dim == 0 ? (2 * kernel_size[0] - 1) : 1, - dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), - channels_iota, /*broadcast_dimensions=*/{2}), - xla::PrimitiveType::F32); - if (dim == 1) { - return xla::Mul( - diag, xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[1])), - /*broadcast_dimensions=*/{1}); - } - return xla::Mul(diag, - xla::ConstantR1<float>(builder, Make1DKernel(kernel_size[0])), - /*broadcast_dimensions=*/{0}); + auto depthwise_kernel = + xla::Broadcast(xla::Zero(builder, xla::F32), + {dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? 
(2 * kernel_size[1] - 1) : 1, channels, 1}); + return xla::Add(depthwise_kernel, Make1DKernel(builder, kernel_size[dim]), + /*broadcast_dimensions=*/{dim}); } xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, @@ -206,8 +192,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, xla::ConvolutionDimensionNumbers dimension_numbers; dimension_numbers.set_input_batch_dimension(0); dimension_numbers.set_output_batch_dimension(0); - dimension_numbers.set_input_feature_dimension(3); - dimension_numbers.set_output_feature_dimension(3); + dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1); for (int i = 0; i < num_spatial_dims; ++i) { dimension_numbers.add_input_spatial_dimensions(1 + i); dimension_numbers.add_output_spatial_dimensions(1 + i); @@ -285,7 +271,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, {{dims.kernel_size[0] - 1, upper_padding[0]}, {dims.kernel_size[1] - 1, upper_padding[1]}}, /*lhs_dilation=*/dims.kernel_size, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); @@ -294,7 +281,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*padding=*/ {{dims.kernel_size[0] - 1, upper_padding[0]}, {0, 0}}, /*lhs_dilation=*/{dims.kernel_size[0], 1}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); xla::XlaOp kernel1 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); output = xla::ConvGeneralDilated( @@ -302,7 +290,8 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, /*padding=*/ {{0, 0}, {dims.kernel_size[1] - 1, upper_padding[1]}}, /*lhs_dilation=*/{1, dims.kernel_size[1]}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } // Add broadcasts to handle expanding from a size == 1 dimension to a @@ -331,15 +320,15 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, xla::ConvolutionDimensionNumbers dimension_numbers; dimension_numbers.set_input_batch_dimension(0); dimension_numbers.set_output_batch_dimension(0); - dimension_numbers.set_input_feature_dimension(3); - dimension_numbers.set_output_feature_dimension(3); + dimension_numbers.set_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_output_feature_dimension(num_spatial_dims + 1); for (int i = 0; i < num_spatial_dims; ++i) { - dimension_numbers.add_input_spatial_dimensions(1 + i); - dimension_numbers.add_output_spatial_dimensions(1 + i); + dimension_numbers.add_input_spatial_dimensions(i + 1); + dimension_numbers.add_output_spatial_dimensions(i + 1); dimension_numbers.add_kernel_spatial_dimensions(i); } - dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); - dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); xla::XlaOp output; if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { xla::XlaOp kernel = @@ -362,7 +351,8 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, {{dims.kernel_size[0] - 1, 
dims.kernel_size[0] - 1}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, /*lhs_dilation=*/dims.stride, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } else { xla::XlaOp kernel0 = MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); @@ -388,14 +378,16 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, /*padding=*/ {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, /*lhs_dilation=*/{dims.stride[0], 1}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); output = xla::ConvGeneralDilated( output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]}, /*padding=*/ {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, /*lhs_dilation=*/{1, dims.stride[1]}, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + /*rhs_dilation=*/{1, 1}, dimension_numbers, + /*feature_group_count=*/channels); } // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. diff --git a/tensorflow/compiler/tf2xla/kernels/permute_op.cc b/tensorflow/compiler/tf2xla/kernels/permute_op.cc new file mode 100644 index 0000000000..0764e5503d --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/permute_op.cc @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <string> +#include <vector> + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { +namespace { + +class DataFormatVecPermuteOp : public XlaOpKernel { + public: + explicit DataFormatVecPermuteOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("src_format", &src_format_)); + OP_REQUIRES( + ctx, src_format_.size() == 4, + errors::InvalidArgument("Data format should have 4 characters")); + TensorFormat data_format; + OP_REQUIRES(ctx, FormatFromString(src_format_, &data_format), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("dst_format", &dst_format_)); + OP_REQUIRES( + ctx, dst_format_.size() == 4, + errors::InvalidArgument("Data format should have 4 characters")); + OP_REQUIRES(ctx, FormatFromString(dst_format_, &data_format), + errors::InvalidArgument("Invalid data format")); + } + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + const TensorShape input_tensor_shape = ctx->InputShape(0); + int input_rank = input_tensor_shape.dims(); + OP_REQUIRES(ctx, input_rank == 1 || input_rank == 2, + errors::InvalidArgument( + "Input must be a vector or matrix, but got shape ", + input_tensor_shape.DebugString())); + OP_REQUIRES( + ctx, input_tensor_shape.dim_size(0) == 4, + errors::InvalidArgument( + "First dimension of input must be of size 4, but got shape ", + input_tensor_shape.DebugString())); + if (input_rank == 2) { + OP_REQUIRES( + ctx, input_tensor_shape.dim_size(1) == 2, + errors::InvalidArgument( + "Second dimension of 2D input must be of size 2, but got shape ", + input_tensor_shape.DebugString())); + } + std::vector<int32> dst_indices(4, 0); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + if (src_format_[i] == dst_format_[j]) { + dst_indices[i] = j; + break; + } + } + } + auto keys = xla::ConstantR1(builder, absl::Span<const int32>(dst_indices)); + if (input_rank == 2) { + keys = xla::BroadcastInDim( + keys, xla::ShapeUtil::MakeShape(xla::S32, {4, 2}), {0}); + } + auto sorted = xla::Sort(keys, ctx->Input(0), 0); + auto output = xla::GetTupleElement(sorted, 1); + ctx->SetOutput(0, output); + } + + private: + string src_format_; + string dst_format_; + + TF_DISALLOW_COPY_AND_ASSIGN(DataFormatVecPermuteOp); +}; + +// TODO(b/115384656): Support DT_INT64. 
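// Worked example (annotation, not part of this change): with
// src_format = "NHWC" and dst_format = "NCHW" the nested loop above produces
// dst_indices = {0, 2, 3, 1}, i.e. the position of each source axis in the
// destination format. Sorting the (keys, input) pair by those keys along
// dimension 0 then reorders a 1D input [n, h, w, c] into [n, c, h, w], and
// permutes the rows of a [4, 2] input (e.g. padding pairs) the same way.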
+REGISTER_XLA_OP(Name("DataFormatVecPermute").TypeConstraint("T", DT_INT32), + DataFormatVecPermuteOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc index 8102faad28..8eee5b1299 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc @@ -40,10 +40,16 @@ class ReduceWindowOp : public XlaOpKernel { std::vector<int64> window_dimensions; std::vector<int64> window_strides; + std::vector<int64> base_dilations; + std::vector<int64> window_dilations; OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( "window_dimensions", &window_dimensions)); OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides", &window_strides)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("base_dilations", + &base_dilations)); + OP_REQUIRES_OK(context, context->ConstantInputAsIntVector( + "window_dilations", &window_dilations)); const int rank = input_shape.dims(); OP_REQUIRES(context, rank == window_dimensions.size(), @@ -56,6 +62,16 @@ class ReduceWindowOp : public XlaOpKernel { "The size of window_strides must be equal to the input " "rank (", window_strides.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == base_dilations.size(), + errors::InvalidArgument( + "The size of base_dilations must be equal to the input " + "rank (", + base_dilations.size(), " vs. ", rank, ")")); + OP_REQUIRES(context, rank == window_dilations.size(), + errors::InvalidArgument( + "The size of window_dilations must be equal to the input " + "rank (", + window_dilations.size(), " vs. ", rank, ")")); // Build the reducer function. XlaCompiler::Argument reducer_arg; @@ -102,7 +118,8 @@ class ReduceWindowOp : public XlaOpKernel { xla::XlaOp output = xla::ReduceWindowWithGeneralPadding( context->Input(0), context->Input(1), *reducer.computation, - window_dimensions, window_strides, padding); + window_dimensions, window_strides, base_dilations, window_dilations, + padding); context->SetOutput(0, output); } @@ -115,6 +132,8 @@ class ReduceWindowOp : public XlaOpKernel { REGISTER_XLA_OP(Name("XlaReduceWindow") .CompileTimeConstInput("window_dimensions") .CompileTimeConstInput("window_strides") + .CompileTimeConstInput("base_dilations") + .CompileTimeConstInput("window_dilations") .CompileTimeConstInput("padding"), ReduceWindowOp); diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index ab094d7dd1..57afd608de 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -104,7 +104,8 @@ class ScanOp : public XlaOpKernel { } auto output = xla::ReduceWindowWithGeneralPadding( XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init, - *reducer, window_dims, window_strides, padding); + *reducer, window_dims, window_strides, + /*base_dilations=*/{}, /*window_dilations=*/{}, padding); output = XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0)); diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 25a5bcbe1d..0c32b8def0 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,7 +18,9 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/numeric.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -55,10 +57,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { // The type-specific part of the implementation of Range. template <typename T> -Status CreateRangeTensor(const xla::LiteralSlice& start_literal, - const xla::LiteralSlice& limit_literal, - const xla::LiteralSlice& delta_literal, - Tensor* output) { +xla::StatusOr<xla::XlaOp> CreateRangeTensor( + const xla::LiteralSlice& start_literal, + const xla::LiteralSlice& limit_literal, + const xla::LiteralSlice& delta_literal, xla::XlaBuilder* builder) { T start = start_literal.Get<T>({}); T limit = limit_literal.Get<T>({}); T delta = delta_literal.Get<T>({}); @@ -82,14 +84,10 @@ Status CreateRangeTensor(const xla::LiteralSlice& start_literal, ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta)) : std::ceil(std::abs((limit - start) / delta))); - *output = Tensor(DataTypeToEnum<T>::v(), TensorShape({size})); - auto flat = output->flat<T>(); - T val = start; - for (int64 i = 0; i < size; ++i) { - flat(i) = val; - val += delta; - } - return Status::OK(); + return xla::ConstantR0(builder, start) + + xla::ConstantR0(builder, delta) * + xla::Iota(builder, xla::primitive_util::NativeToPrimitiveType<T>(), + size); } class RangeOp : public XlaOpKernel { @@ -115,27 +113,26 @@ class RangeOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta)); DataType type = input_type(0); - Tensor output; - Status status; + xla::StatusOr<xla::XlaOp> output; switch (type) { case DT_INT32: - status = CreateRangeTensor<int32>(start, limit, delta, &output); + output = CreateRangeTensor<int32>(start, limit, delta, ctx->builder()); break; case DT_INT64: - status = CreateRangeTensor<int64>(start, limit, delta, &output); + output = CreateRangeTensor<int64>(start, limit, delta, ctx->builder()); break; case DT_FLOAT: - status = CreateRangeTensor<float>(start, limit, delta, &output); + output = CreateRangeTensor<float>(start, limit, delta, ctx->builder()); break; case DT_DOUBLE: - status = CreateRangeTensor<double>(start, limit, delta, &output); + output = CreateRangeTensor<double>(start, limit, delta, ctx->builder()); break; default: - status = errors::InvalidArgument("Invalid type for Range ", + output = errors::InvalidArgument("Invalid type for Range ", DataTypeString(type)); } - OP_REQUIRES_OK(ctx, status); - ctx->SetConstantOutput(0, output); + OP_REQUIRES_OK(ctx, output.status()); + ctx->SetOutput(0, output.ValueOrDie()); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 2e0a69b70e..c8a0f31a03 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -44,7 +44,7 @@ class ShapeOp : public XlaOpKernel { DataType out_dtype_; }; -REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp); +REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp); class ShapeNOp : public XlaOpKernel { public: @@ -66,7 +66,7 @@ class ShapeNOp : public XlaOpKernel { private: DataType out_dtype_; }; 
-REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp); +REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp); class RankOp : public XlaOpKernel { public: @@ -82,7 +82,7 @@ class RankOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp); +REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp); class SizeOp : public XlaOpKernel { public: @@ -101,7 +101,7 @@ class SizeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp); +REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp); class ExpandDimsOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index aaeeae01cc..45f03d8c21 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -25,11 +25,26 @@ class XlaSortOp : public XlaOpKernel { explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - context->SetOutput(0, xla::Sort(context->Input(0))); + context->SetOutput(0, xla::Sort(context->Input("input"))); } }; REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp); +class XlaKeyValueSortOp : public XlaOpKernel { + public: + explicit XlaKeyValueSortOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::XlaOp result = + xla::Sort(context->Input("keys"), context->Input("values")); + context->SetOutput(0, xla::GetTupleElement(result, 0)); + context->SetOutput(1, xla::GetTupleElement(result, 1)); + } +}; + +REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp); + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc new file mode 100644 index 0000000000..74d4fcc425 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -0,0 +1,226 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA TensorList operators. 
+ +#include <limits> +#include <vector> + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op, + TensorShape* tensor_list_shape) { + auto shape_or_status = builder->GetShape(op); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + xla::Shape shape = shape_or_status.ValueOrDie(); + TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape)); + return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0), + tensor_list_shape); +} + +class TensorListReserveOp : public XlaOpKernel { + public: + explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape element_shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape)); + int64 num_elements; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements)); + + TensorShape tensor_shape; + tensor_shape.AddDim(num_elements); + tensor_shape.AppendShape(element_shape); + + xla::XlaBuilder* b = ctx->builder(); + ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_), + tensor_shape.dim_sizes()), + xla::ConstantR0<int32>(b, 0)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp); +}; + +REGISTER_XLA_OP(Name("TensorListReserve") + .CompileTimeConstInput("element_shape") + .CompileTimeConstInput("num_elements"), + TensorListReserveOp); + +class EmptyTensorListOp : public XlaOpKernel { + public: + explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + ctx->CtxFailure( + errors::InvalidArgument("XLA compilation requires a fixed tensor list " + "size. 
Use TensorListReserve instead.")); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp); +}; + +REGISTER_XLA_OP(Name("EmptyTensorList"), EmptyTensorListOp); + +class TensorListElementShapeOp : public XlaOpKernel { + public: + explicit TensorListElementShapeOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape)); + shape.RemoveDim(0); + + switch (shape_type_) { + case DT_INT64: + ctx->SetOutput(0, xla::ConstantR1<int64>(b, shape.dim_sizes())); + break; + case DT_INT32: { + std::vector<int32> size; + for (int64 s : shape.dim_sizes()) { + size.push_back(s); + } + ctx->SetOutput(0, xla::ConstantR1<int32>(b, size)); + break; + } + default: + ctx->CtxFailure( + errors::InvalidArgument("Unsupported shape type requested")); + return; + } + } + + private: + DataType shape_type_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp); +}; + +REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp); + +class TensorListPushBackOp : public XlaOpKernel { + public: + explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp list = ctx->Input(0); + TensorShape elem_shape = ctx->InputShape(1); + + xla::XlaOp ta = xla::GetTupleElement(list, 0); + xla::XlaOp index = xla::GetTupleElement(list, 1); + xla::XlaOp value = ctx->Input(1); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + auto start_indices = + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + + TensorShape slice_shape = elem_shape; + slice_shape.InsertDim(0, 1LL); + auto update = xla::Reshape(value, slice_shape.dim_sizes()); + + // TODO(phawkins): We don't check the index is in bounds --- there is no + // error mechanism in XLA. + ctx->SetOutput( + 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices), + index + xla::ConstantR0<int32>(b, 1)})); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp); +}; + +REGISTER_XLA_OP(Name("TensorListPushBack"), TensorListPushBackOp); + +class TensorListPopBackOp : public XlaOpKernel { + public: + explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::XlaBuilder* b = ctx->builder(); + xla::XlaOp state = ctx->Input(0); + + TensorShape shape; + OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape)); + + xla::XlaOp ta = xla::GetTupleElement(state, 0); + xla::XlaOp index = xla::GetTupleElement(state, 1); + + index = index - xla::ConstantR0<int32>(b, 1); + + // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. + auto start_indices = + xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0), + xla::MakeEdgePaddingConfig({{0, shape.dims() - 1}})); + + auto slice_shape = shape.dim_sizes(); + slice_shape[0] = 1LL; + + // TODO(phawkins): We don't check the index is in bounds --- there is no + // error mechanism in XLA. + xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape); + // Remove the leading '1' dimension. 
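The TensorList kernels above share one representation: a tuple of (fixed-size buffer, int32 position). PushBack writes a slice at the position via DynamicUpdateSlice and increments it; PopBack decrements the position and reads via DynamicSlice, leaving the buffer itself unchanged. A plain-C++ analogue with assumed sizes (capacity 3, elements of shape [2]); only the bookkeeping is meaningful, the XLA ops are stand-ins:

#include <array>
#include <cstdint>
#include <iostream>

int main() {
  std::array<std::array<float, 2>, 3> buffer{};  // capacity from TensorListReserve (assumed 3)
  int32_t pos = 0;                               // second element of the tuple

  auto push_back = [&](const std::array<float, 2>& value) {
    buffer[pos] = value;  // DynamicUpdateSlice at start_indices = {pos, 0}
    ++pos;
  };
  auto pop_back = [&]() {
    --pos;
    return buffer[pos];   // DynamicSlice at start_indices = {pos, 0}
  };

  push_back({1.0f, 2.0f});
  push_back({3.0f, 4.0f});
  const std::array<float, 2> value = pop_back();
  std::cout << value[0] << ", " << value[1] << '\n';  // 3, 4
  return 0;
}

As the TODOs note, the kernels do not range-check the position, and this sketch inherits that behavior.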
+ std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end()); + + ctx->SetOutput(0, xla::Tuple(b, {ta, index})); + ctx->SetOutput(1, xla::Reshape(read, value_shape)); + } + + private: + DataType dtype_; + + TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp); +}; + +REGISTER_XLA_OP(Name("TensorListPopBack"), TensorListPopBackOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 8597e7f139..1ce3930fd1 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -32,6 +32,22 @@ cc_library( ) cc_library( + name = "broadcast", + srcs = ["broadcast.cc"], + hdrs = ["broadcast.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( name = "cholesky", srcs = ["cholesky.cc"], hdrs = ["cholesky.h"], diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc new file mode 100644 index 0000000000..3e402ef855 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/broadcast.h" + +#include <vector> + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +namespace tensorflow { + +xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input, + absl::Span<int64 const> output_dims) { + xla::XlaBuilder* builder = input.builder(); + TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input)); + absl::Span<int64 const> input_dims = + xla::AsInt64Slice(input_shape.dimensions()); + + if (input_dims == output_dims) { + return input; + } + + if (input_dims.size() > output_dims.size()) { + return errors::InvalidArgument( + "Input shape (", xla::ShapeUtil::HumanString(input_shape), + ") must have rank less than or equal to the output shape [", + absl::StrJoin(output_dims, ","), "]"); + } + + std::vector<int64> broadcast_dims; + std::vector<int64> broadcast_shape; + auto input_it = input_dims.rbegin(); + for (auto output_it = output_dims.rbegin(); output_it != output_dims.rend(); + ++output_it) { + if (input_it != input_dims.rend()) { + if (!(*output_it == 0 && *input_it == 0) && + !(*input_it != 0 && *output_it % *input_it == 0)) { + return errors::InvalidArgument("Invalid shape broadcast from ", + xla::ShapeUtil::HumanString(input_shape), + " to [", absl::StrJoin(output_dims, ","), + "]"); + } + + broadcast_dims.push_back(broadcast_shape.size()); + if (*output_it == *input_it) { + broadcast_shape.push_back(*output_it); + } else if (*output_it != *input_it) { + // Add dimensions [I, O/I], which we will later flatten to just + // [O]. We must do this in two phases since XLA broadcasting does not + // support tiling. + broadcast_shape.push_back(*input_it); + broadcast_shape.push_back(*output_it / *input_it); + } + ++input_it; + } else { + broadcast_shape.push_back(*output_it); + } + } + TF_RET_CHECK(input_it == input_dims.rend()); + + absl::c_reverse(broadcast_dims); + int broadcast_shape_size = broadcast_shape.size(); + for (int64& broadcast_dim : broadcast_dims) { + broadcast_dim = broadcast_shape_size - broadcast_dim - 1; + } + absl::c_reverse(broadcast_shape); + xla::XlaOp output = xla::BroadcastInDim( + input, + xla::ShapeUtil::MakeShape(input_shape.element_type(), broadcast_shape), + broadcast_dims); + if (broadcast_shape != output_dims) { + output = xla::Reshape(output, output_dims); + } + return output; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/compiler/tf2xla/lib/broadcast.h index 35b4b4e20b..591e696f06 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_options.cc +++ b/tensorflow/compiler/tf2xla/lib/broadcast.h @@ -13,16 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. 
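BroadcastTo above implements TensorFlow-style broadcasting plus tiling: when an output dimension is a whole multiple of the corresponding input dimension, BroadcastInDim alone cannot produce it, so the helper first broadcasts into an intermediate that carries the tiling factor and the original extent as separate dimensions and then flattens them with a Reshape. As a concrete walk-through with assumed shapes: broadcasting a [2] operand to [6] goes through an intermediate of shape [3, 2], with the input mapped to the minor dimension, and the final Reshape to [6] repeats the two input values three times.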
==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" -#include "tensorflow/core/lib/gtl/map_util.h" +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ -namespace xla { -namespace gpu { +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" -bool ConvUseLayoutHeuristic(const HloModuleConfig& config) { - return !config.debug_options().xla_backend_extra_options().count( - "xla_gpu_experimental_conv_disable_layout_heuristic"); -} +namespace tensorflow { -} // namespace gpu -} // namespace xla +// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting +// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling. +xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input, + absl::Span<int64 const> output_dims); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 38dfde165d..2b1c2ced92 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -38,12 +38,10 @@ xla::StatusOr<xla::XlaOp> XlaScatter( combiner, xla::XlaBuilder* builder) { TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer)); - TF_RETURN_IF_ERROR(builder->GetShape(updates).status()); + TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates)); TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices)); absl::Span<const int64> indices_dims = xla::AsInt64Slice(indices_shape.dimensions()); - absl::Span<const int64> buffer_dims = - xla::AsInt64Slice(buffer_shape.dimensions()); // If the indices are N-dimensional, the minor dimension of indices contains // the indices to update. Otherwise the indices are all scalars. @@ -81,104 +79,129 @@ xla::StatusOr<xla::XlaOp> XlaScatter( } } - // Shape of the non-indexed dimensions of the buffer. - std::vector<int64> buffer_shape_post_axes( - buffer_dims.begin() + num_index_dims, buffer_dims.end()); - - // Flatten the major dimensions of indices and updates into a single dimension - // for ease of iteration. 
- std::vector<int64> flat_indices_shape({num_indices}); - if (indices_are_vectors) { - flat_indices_shape.push_back(num_index_dims); + // Example of a 1-D scatter that updates two [3,1] tensors in a tensor of + // shape [3,3]: + // NOTE: ***This case will not be generated by any of the tf.scatter ops.*** + // + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // updates = s32[3,2] parameter(2) + // scatter = s32[3,3] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={0}, + // inserted_window_dims={1}, + // scatter_dims_to_operand_dims={1}, + // index_vector_dim=1 + // + // + // Example of a 1-D scatter that updates two [1,3] tensors in a tensor of + // shape [3,3]: + // + // operand = s32[3,3] parameter(0) + // indices = s32[2] parameter(1) + // updates = s32[2,3] parameter(2) + // scatter = s32[3,3] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={1}, + // inserted_window_dims={0}, + // scatter_dims_to_operand_dims={0}, + // index_vector_dim=1 + // + // + // Example of an N-D scatter updating slices of shape [1,1,2] in a tensor of + // shape [3,3,2] + // + // operand = s32[3,3,2] parameter(0) + // indices = s32[2,2] parameter(1) + // updates = s32[2,2] parameter(2) + // scatter = s32[3,3,2] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={1}, + // inserted_window_dims={0,1}, + // scatter_dims_to_operand_dims={0,1}, + // index_vector_dim=1 + // + // + // Example of a scatter updating slices of shape [] in a tensor of shape [1,1] + // + // operand = s32[1,1] parameter(0) + // indices = s32[1] parameter(1) + // updates = s32[1] parameter(2) + // scatter = s32[1,1] scatter(operand, indices, updates), + // to_apply=update_computation, + // update_window_dims={}, + // inserted_window_dims={0,1}, + // scatter_dims_to_operand_dims={0}, + // index_vector_dim=1 + // Note that updates operand would be broadcasted into [1] in this case. + // + + xla::ScatterDimensionNumbers dim_numbers; + dim_numbers.set_index_vector_dim(indices_are_vectors + ? indices_shape.dimensions_size() - 1 + : indices_shape.dimensions_size()); + + int64 updates_rank = xla::ShapeUtil::Rank(updates_shape); + int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape); + int64 num_window_dims_in_updates = buffer_rank - num_index_dims; + + // If the rank of `updates` is 0 and does not match the expected rank of + // updates, broadcast `updates` to the expected shape of updates. + auto new_updates = updates; + std::vector<int64> expected_updates_dims(indices_dims.begin(), + indices_dims.end()); + for (int64 dim = num_index_dims; dim < buffer_rank; ++dim) { + expected_updates_dims.push_back(buffer_shape.dimensions(dim)); + } + int64 expected_updates_rank = expected_updates_dims.size(); + if (updates_rank == 0 && expected_updates_rank != 0) { + new_updates = xla::Broadcast(updates, expected_updates_dims); + TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); + updates_rank = xla::ShapeUtil::Rank(updates_shape); } - std::vector<int64> flat_updates_shape({num_indices}); - flat_updates_shape.insert(flat_updates_shape.end(), - buffer_shape_post_axes.begin(), - buffer_shape_post_axes.end()); - - // Construct the initial values of the loop-carried Tensors. - auto flat_indices = xla::Reshape(indices, flat_indices_shape); - auto flat_updates = xla::Reshape(updates, flat_updates_shape); - auto init = {flat_indices, flat_updates, buffer}; - - // Constructs the loop body. 
The implementation of scatter is essentially: - // for i in range(num_indices): - // index = dynamic-slice(indices, i) - // update = dynamic-slice(updates, i) - // buffer = dynamic-update-slice(buffer, update, index) - auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars, - xla::XlaBuilder* body_builder) { - auto indices = loop_vars[0]; - auto updates = loop_vars[1]; - auto buffer = loop_vars[2]; - - auto zero_index = xla::ConstantLiteral( - body_builder, xla::LiteralUtil::Zero(indices_shape.element_type())); - - // Slice the i-th index from the indices array. - xla::XlaOp index; - auto indices_offset = xla::Reshape(i, {1}); - if (indices_are_vectors) { - indices_offset = xla::Pad(indices_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, 1}})); - - index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims}); - index = xla::Collapse(index, {0, 1}); - } else { - index = xla::DynamicSlice(indices, indices_offset, {1}); + if (updates_rank > 0) { + for (int64 i = (updates_rank - num_window_dims_in_updates); + i < updates_rank; ++i) { + dim_numbers.add_update_window_dims(i); } + } - // Discard updates with negative indices, since some users expect this. - auto index_in_range = xla::ReduceAll( - xla::Le(zero_index, index), xla::ConstantR0<bool>(body_builder, true), - xla::CreateScalarAndComputation(xla::PRED, body_builder)); - - // Make the index in bounds to prevent implementation defined behavior. - index = xla::Max(index, zero_index); - index = xla::Pad( - index, zero_index, - xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); - - // Slice the i-th index from the updates array. - auto updates_offset = xla::Reshape(i, {1}); - updates_offset = xla::Pad( - updates_offset, zero_index, - xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); - std::vector<int64> flat_updates_slice_shape({1}); - flat_updates_slice_shape.insert(flat_updates_slice_shape.end(), - buffer_shape_post_axes.begin(), - buffer_shape_post_axes.end()); - auto update = - xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape); - - // Unflatten the major (iteration) dimensions of the slice to their - // original shape. - std::vector<int64> updates_slice_shape(num_index_dims, 1); - updates_slice_shape.insert(updates_slice_shape.end(), - buffer_shape_post_axes.begin(), - buffer_shape_post_axes.end()); - update = xla::Reshape(update, updates_slice_shape); - - // Apply the update to the buffer. If there is a combiner, use it to merge - // the current values with the update. - auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape); + for (int64 i = 0; i < num_index_dims; ++i) { + dim_numbers.add_inserted_window_dims(i); + dim_numbers.add_scatter_dims_to_operand_dims(i); + } + + // Build the combiner computation. + xla::XlaComputation combiner_computation; + { + xla::XlaBuilder cb("scatter-combiner"); + auto xla_scalar_shape = + xla::ShapeUtil::MakeShape(buffer_shape.element_type(), {}); + auto p0 = xla::Parameter(&cb, 0, xla_scalar_shape, "p0"); + auto p1 = xla::Parameter(&cb, 1, xla_scalar_shape, "p1"); if (combiner) { - update = combiner(current_value, update, body_builder); + combiner(p0, p1, &cb); } - // Use the current value instead of the update if the index is out of - // bounds. - update = xla::Select(index_in_range, update, current_value); - // Apply the update. 
- buffer = xla::DynamicUpdateSlice(buffer, update, index); - - return std::vector<xla::XlaOp>{indices, updates, buffer}; - }; - - TF_ASSIGN_OR_RETURN(auto outputs, - XlaForEachIndex(num_indices, indices_shape.element_type(), - body_fn, init, "scatter", builder)); - return outputs[2]; + combiner_computation = cb.Build().ConsumeValueOrDie(); + } + + VLOG(3) << "Scatter op:"; + VLOG(3) << " Input: " << xla::ShapeUtil::HumanString(buffer_shape); + VLOG(3) << " Indices: " << xla::ShapeUtil::HumanString(indices_shape); + VLOG(3) << " Updates: " << xla::ShapeUtil::HumanString(updates_shape); + VLOG(3) << " Scatter Dimension Numbers: "; + VLOG(3) << " index_vector_dim: " << dim_numbers.index_vector_dim(); + VLOG(3) << " update_window_dims: [" + << absl::StrJoin(dim_numbers.update_window_dims(), ",") << "]"; + VLOG(3) << " inserted_window_dims: [" + << absl::StrJoin(dim_numbers.inserted_window_dims(), ",") << "]"; + VLOG(3) << " scatter_dims_to_operand_dims: [" + << absl::StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ",") + << "]"; + + return xla::Scatter(buffer, indices, new_updates, combiner_computation, + dim_numbers); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 13a5f1b850..4cf478c4b9 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -34,7 +34,11 @@ namespace tensorflow { // Otherwise, `indices_are_vectors`, then indices are multidimensional and the // minor dimension of `indices` represents a vector of indices. // -// If any indices are negative, the corresponding update is discarded. +// If `updates` is a scalar, then it will be broadcasted into the expected shape +// of updates. +// +// If any part of the update region is out-of-bounds, the corresponding update +// is discarded. // // If a `combiner` is provided, updates are combined with the existing values in // the buffer using the combiner function. Otherwise, the updates replace the diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 02363500ef..bd2c0a5ee8 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -121,8 +121,8 @@ Wraps the XLA DynamicSlice operator, documented at DynamicSlice extracts a sub-array from the input array at dynamic start_indices. The size of the slice in each dimension is passed in size_indices, which specify the end point of exclusive slice intervals in each -dimension -- [start, start + size). The shape of start_indices must be rank == -1, with dimension size equal to the rank of operand. +dimension -- [start, start + size). The shape of start_indices must have rank 1, +with dimension size equal to the rank of operand. input: A `Tensor` of type T. @@ -131,7 +131,8 @@ start_indices: Rank 1 tensor of N integers containing the starting indices of start_indices: List of N integers containing the slice size for each dimension. Each value must be strictly greater than zero, and start + size - must be less + must be less than or equal to the size of the dimension to avoid + implementation defined behavior. 
)doc"); REGISTER_OP("XlaDynamicUpdateSlice") @@ -282,6 +283,8 @@ REGISTER_OP("XlaReduceWindow") .Input("init_value: T") .Input("window_dimensions: Tindices") .Input("window_strides: Tindices") + .Input("base_dilations: Tindices") + .Input("window_dilations: Tindices") .Input("padding: Tindices") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") @@ -353,12 +356,33 @@ Wraps the XLA Sort operator, documented at https://www.tensorflow.org/performance/xla/operation_semantics#sort . -Sorts a tensor. Currently only rank 1 sorts in ascending order are supported. +Sorts a tensor. Currently only sorts in ascending order are supported. input: A `Tensor` of type T. output: A `Tensor` of type T. )doc"); +REGISTER_OP("XlaKeyValueSort") + .Input("keys: K") + .Input("values: V") + .Output("sorted_keys: K") + .Output("sorted_values: V") + .Attr("K: realnumbertype") + .Attr("V: type") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Wraps the XLA Sort operator, documented at + https://www.tensorflow.org/performance/xla/operation_semantics#sort +. + +Sorts a tensor. Currently only sorts in ascending order are supported. + +keys: A `Tensor` of type K. +values: A `Tensor` of type V. +sorted_keys: A `Tensor` of type K. +sorted_values: A `Tensor` of type V. +)doc"); + // TODO(b/37549631) setting the While Op to always be stateful is too // conservative. REGISTER_OP("XlaWhile") diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index 27dd18a9bb..5e86b5d8ec 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -212,9 +212,9 @@ bitcast_convert_type = array_ops.bitcast def broadcast(x, dims, name=None): x = ops.convert_to_tensor(x) - shape = array_ops.concat( - [constant_op.constant(dims), - array_ops.shape(x)], axis=0) + shape = array_ops.concat([constant_op.constant(dims), + array_ops.shape(x)], + axis=0) return array_ops.broadcast_to(x, shape, name=name) @@ -320,6 +320,8 @@ def reduce_window(operand, reducer, window_dimensions, window_strides=None, + base_dilations=None, + window_dilations=None, padding=None, name=None): """Wraps the XLA ReduceWindow operator. @@ -332,22 +334,27 @@ def reduce_window(operand, init: a scalar tensor representing the initial value for the reduction reducer: a reduction function that combines a pair of scalars. window_dimensions: shape of the window, as a list of integers - window_strides: inter-window strides, as a list of integers. Optional; - if omitted, defaults to strides of 1. + window_strides: inter-window strides, as a list of integers. Optional; if + omitted, defaults to strides of 1. padding: padding to apply to 'operand'. List of (low, high) pairs of integers that specify the padding to apply before and after each dimension. Optional; if omitted, defaults to no padding. name: the operator name, or None. + Returns: A tensor that represents the output of the reduce_window operator. 
""" window_strides = window_strides or [1] * len(window_dimensions) + base_dilations = base_dilations or [1] * len(window_dimensions) + window_dilations = window_dilations or [1] * len(window_dimensions) padding = padding or [(0, 0)] * len(window_dimensions) return gen_xla_ops.xla_reduce_window( input=operand, init_value=init, window_dimensions=window_dimensions, window_strides=window_strides, + base_dilations=base_dilations, + window_dilations=window_dilations, padding=padding, computation=reducer, name=name) @@ -377,4 +384,5 @@ def slice(x, start_dims, limit_dims, strides): sort = gen_xla_ops.xla_sort +key_value_sort = gen_xla_ops.xla_key_value_sort while_loop = gen_xla_ops.xla_while diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 20f2ce2919..72b240996f 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "absl/algorithm/container.h" -#include "tensorflow/core/lib/gtl/flatmap.h" +#include "absl/container/flat_hash_map.h" namespace tensorflow { /*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString( @@ -30,9 +30,9 @@ namespace tensorflow { } } -static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* +static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>* CreateResourceOpInfoMap() { - auto* result = new gtl::FlatMap<absl::string_view, XlaResourceOpInfo>; + auto* result = new absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>; auto add = [&](absl::string_view op, XlaResourceOpKind op_kind, XlaResourceKind resource_kind) { @@ -103,15 +103,15 @@ CreateResourceOpInfoMap() { return result; } -static const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& +static const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>& GetStaticResourceOpInfoMap() { - static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* op_info_map = - CreateResourceOpInfoMap(); + static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>* + op_info_map = CreateResourceOpInfoMap(); return *op_info_map; } const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) { - const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& op_infos = + const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>& op_infos = GetStaticResourceOpInfoMap(); auto it = op_infos.find(op); return it == op_infos.end() ? nullptr : &it->second; diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc index a85ef040a7..956f597301 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -33,7 +34,7 @@ bool HasResourceInputOrOutput(const OpDef& op_def) { } TEST(ResourceOperationTableTest, HaveAllResourceOps) { - gtl::FlatMap<string, bool> known_resource_ops; + absl::flat_hash_map<string, bool> known_resource_ops; for (absl::string_view known_resource_op : resource_op_table_internal::GetKnownResourceOps()) { ASSERT_TRUE( diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 9d1992205b..b589512dcd 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -41,6 +41,14 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, // Convert a TensorShape into the equivalent XLA Shape proto. Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape* shape) { + xla::PrimitiveType type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); + *shape = TensorShapeToXLAShape(type, tensor_shape); + return Status::OK(); +} + +xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const TensorShape& tensor_shape) { int rank = tensor_shape.dims(); std::vector<int64> dimensions(rank); std::vector<int64> layout(rank); @@ -50,11 +58,7 @@ Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, // XLA uses minor-to-major; Tensorflow uses major-to-minor. std::iota(layout.rbegin(), layout.rend(), 0); - xla::PrimitiveType type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); - - *shape = xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); - return Status::OK(); + return xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index 58240b9c96..f7e34a5b40 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -35,6 +35,11 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, xla::Shape* shape); +// Converts a TensorShape into the equivalent XLA Shape proto, taking an +// xla::PrimitiveType to specify the element type. This never fails. +xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, + const TensorShape& tensor_shape); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index d6f42bac86..01dd3ba10f 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -336,9 +336,9 @@ bool HasAssociatedFunction(const NodeDef& node_def, } if (node_def.op() == FunctionLibraryDefinition::kGradientOp) { - // Skip gradient op. Gradient op has "f" attr, which is set to the function - // we are getting gradient for. That function is not associated with the op. - return false; + // Gradient op has "f" attr, which is set to the function we are getting + // gradient for. We need to functionalize the gradient function. 
+ return true; } for (const auto& iter : node_def.attr()) { @@ -357,17 +357,18 @@ std::vector<AssociatedFunctionInfo> GetAssociatedFunctions( if (flr->GetFunctionLibraryDefinition()->Contains(op)) { // This is a function call node. AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); - results.emplace_back(AssociatedFunctionInfo(op, attrs)); + results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs)); } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) { - // Skip gradient op. Gradient op has "f" attr, which is set to the function - // we are getting gradient for. That function is not associated with the op. + // This is a SymbolicGradient op. + AttrValueMap attrs(node.attrs().begin(), node.attrs().end()); + results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs)); } else { // Collect all function attrs for the node. for (auto& iter : node.attrs()) { if (iter.second.has_func()) { VLOG(2) << "Found function attr for node " << node.name() << ": " << iter.first << " = " << iter.second.func().name(); - results.emplace_back(AssociatedFunctionInfo( + results.emplace_back(AssociatedFunctionInfo::FunctionAttr( iter.second.func().name(), iter.second.func().attr(), iter.first)); } } @@ -410,6 +411,21 @@ Status RewriteAssociatedFunction( graph->RemoveNode(node); break; } + case AssociatedFunctionInfo::kSymbolicGradient: { + NameAttrList func; + TF_RETURN_IF_ERROR(GetNodeAttr( + node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func)); + GradientDef gradient_def; + gradient_def.set_function_name(func.name()); + gradient_def.set_gradient_func(rewritten_function_name); + string original_grad_func = fld->FindGradient(func.name()); + if (original_grad_func.empty()) { + TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def)); + } else if (original_grad_func != rewritten_function_name) { + TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def)); + } + break; + } case AssociatedFunctionInfo::kFunctionAttr: { // Change function attr to rewritten functions. NameAttrList func; diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 6065d0bb9a..53eab8b63e 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -65,21 +65,33 @@ uint32 GetXLARandomSeed(); class AssociatedFunctionInfo { public: enum AssociatedFunctionType { - kFunctionCallNode = 0, - kFunctionAttr = 1, + kFunctionAttr = 0, + kFunctionCallNode = 1, + kSymbolicGradient = 2, }; - // The node is a function call. - AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs) - : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {} - // The function is an attr of the node. - AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs, - const string& attr_name) - : type_(kFunctionAttr), - func_name_(func_name), - attrs_(attrs), - attr_name_(attr_name) {} + static AssociatedFunctionInfo FunctionAttr(const string& func_name, + const AttrValueMap& attrs, + const string& attr_name) { + return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name); + } + + // The node is a function call. + static AssociatedFunctionInfo FunctionCall(const string& func_name, + const AttrValueMap& attrs) { + // attr_name will not be used in this case. + return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs, + /*attr_name=*/""); + } + + // The node is a SymbolicGradient op. 
+ static AssociatedFunctionInfo SymbolicGradient(const string& func_name, + const AttrValueMap& attrs) { + // attr_name will not be used in this case. + return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs, + /*attr_name=*/""); + } AssociatedFunctionType type() const { return type_; } @@ -90,6 +102,13 @@ class AssociatedFunctionInfo { const AttrValueMap& attrs() const { return attrs_; } private: + AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name, + const AttrValueMap& attrs, const string& attr_name) + : type_(type), + func_name_(func_name), + attrs_(attrs), + attr_name_(attr_name) {} + // Available for all instances. AssociatedFunctionType type_; string func_name_; @@ -105,14 +124,18 @@ bool HasAssociatedFunction(const NodeDef& node_def, // Gets functions associated with the node. Current cases: // 1. For function call node, its function name; -// 2. For nodes like XlaWhile/XlaIf, all their function attributes. +// 2. For SymbolicGradient op, returned func_name will be "SymbolicGradient", +// and returned attrs will be this node's attributes; +// 3. For nodes like XlaWhile/XlaIf, all their function attributes. std::vector<AssociatedFunctionInfo> GetAssociatedFunctions( const Node& node, FunctionLibraryRuntime* flr); // Changes associated functions for the node. Current cases: // 1. For function call node, creates a new node with the new function name and // remove the old node; -// 2. For nodes like XlaWhile/XlaIf, modify their function attributes. +// 2. For SymbolicGradient op, add or replace GradientDef in +// FunctionLibraryDefinition; +// 3. For nodes like XlaWhile/XlaIf, modify their function attributes. Status RewriteAssociatedFunction( Graph* graph, Node* node, FunctionLibraryDefinition* fld, const AssociatedFunctionInfo& associated_function, diff --git a/tensorflow/compiler/tf2xla/type_util.h b/tensorflow/compiler/tf2xla/type_util.h index bda667eb1f..6354216eee 100644 --- a/tensorflow/compiler/tf2xla/type_util.h +++ b/tensorflow/compiler/tf2xla/type_util.h @@ -25,6 +25,14 @@ namespace tensorflow { // Converts a Tensorflow DataType to an XLA PrimitiveType. Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type); +// N.B.: there is intentionally no function to convert an XLA PrimitiveType to +// a TensorFlow DataType. The mapping from TF types to XLA types is not +// one-to-one: for example, both DT_INT8 and DT_QINT8 map to xla::S8. So the +// inverse would not be a well-defined function. If you find that you want the +// inverse mapping, then most likely you should be preserving the original +// TensorFlow type, rather than trying to convert an XLA type into a TensorFlow +// type. + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 739e47778a..b2c57e8880 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -194,6 +194,17 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, std::unique_ptr<Graph> graph = GetGraph(fbody); + // Clear the "_kernel" attribute if it is set to "host". This is used to + // indicate that a computation should happen on the host instead of the + // accelerator, but doesn't make sense in XLA. 
+ const char* const kKernelAttr = "_kernel"; + for (Node* n : graph->nodes()) { + string value; + if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") { + n->ClearAttr(kKernelAttr); + } + } + // _Arg and _Retval nodes don't exist in the stored subgraph for the function; // they are added by the function body looked up. Therefore, they don't have // core assignments here. @@ -333,10 +344,8 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph, } // Builds the XLA computation. -// -// `retvals` is the list of retvals produced by _Retval operators, in index -// order. `variable_map` is a map from variable ID numbers to XlaOpContext -// variable states, generated by the symbolic evaluation. +// `args` is the list of input arguments, `retvals` is the list of retvals +// produced by _Retval operators, in index order. // If `return_updated_values_for_all_resources` is true, all resources will be // included in `resource_updates`, regardless of whether their value changed. // Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. diff --git a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc index 23d04d43b3..bc44301d40 100644 --- a/tensorflow/compiler/tf2xla/xla_cpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_cpu_backend.cc @@ -20,21 +20,6 @@ limitations under the License. namespace tensorflow { bool CpuOpFilter(KernelDef* kdef) { - // TODO(b/34339814): implement inverse erf for double types and remove this - // workaround. - if (kdef->op() == "RandomStandardNormal") { - kdef->clear_constraint(); - // Change the type constraint to permit only DTD_FLOAT. - KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); - attr_constraint->set_name("dtype"); - attr_constraint->mutable_allowed_values()->mutable_list()->add_type( - DT_FLOAT); - return true; - } - // TODO(b/26783907): The CPU backend currently does not implement sort. - if (kdef->op() == "XlaSort" || kdef->op() == "TopKV2") { - return false; - } if (kdef->op() == "Const") { AddDtypeToKernalDefConstraint("dtype", DT_STRING, kdef); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 2a9eaeee14..dd3498ef7a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -455,23 +455,43 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } +Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape, + Tensor** output) { + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + if (expected_output_dtype(index) == DT_VARIANT) { + // tensor_data() is not supported for variant Tensor (i.e., + // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the + // XlaExpression inside the Tensor's tensor_data() does not work for + // variant. Instead construct a uint8 tensor and store the expression in its + // value. + // TODO(jpienaar): This should be refactored to stop masquerading + // XlaExpressions as Tensors. 
+ *output = new Tensor(); + TensorShape tensor_shape; + TF_RETURN_IF_ERROR( + context_->allocate_temp(DT_UINT8, tensor_shape, *output)); + context_->set_output(index, **output); + } else { + TensorShape tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape)); + TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output)); + } + return Status::OK(); +} + void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { // Makes the host Tensor that will refer to the expression. Tensor* output = nullptr; - auto shape = builder()->GetShape(handle); - if (!shape.ok()) { - SetStatus(shape.status()); + auto shape_or = builder()->GetShape(handle); + if (!shape_or.ok()) { + SetStatus(shape_or.status()); return; } - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - TensorShape tensor_shape; - OP_REQUIRES_OK(context_, - XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape)); OP_REQUIRES_OK(context_, - context_->allocate_output(index, tensor_shape, &output)); + allocate_output(index, shape_or.ValueOrDie(), &output)); // The expression is stored in the tensor's data buffer. Fill in the // fields now. diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index a3a0d10cc0..aa00a45496 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -255,6 +255,11 @@ class XlaOpKernelContext { // Returns the tensor of input `name`. const Tensor& GetInputTensorByName(absl::string_view name); + // Wraps OpKernelContext's allocate_output method while providing special + // behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the + // type to allow mapping for variant to more generic types. + Status allocate_output(int index, const xla::Shape& shape, Tensor** output); + OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index b0eeee3174..91d48125f1 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -90,6 +90,11 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible compile time constant inputs."; return false; } + if (x.is_metadata_op != y.is_metadata_op) { + LOG(WARNING) << "Registrations of " << x.name + << " have incompatible values for is_metadata_op."; + return false; + } return true; } @@ -350,6 +355,20 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) { return &it->second.front()->compile_time_constant_inputs; } +/*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto it = registry.ops_.find(op); + if (it == registry.ops_.end() || it->second.empty()) { + return false; + } + + // The test in IsCompatible ensures that if there are multiple matching + // registrations for this op name, they all have the same value of + // is_metadata_op, so only the first match is returned. 
+ return it->second.front()->is_metadata_op; +} + std::vector<string> XlaOpRegistry::BackendNames() { std::vector<string> names; XlaOpRegistry& registry = Instance(); @@ -432,6 +451,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( return *this; } +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() { + registration_->is_metadata_op = true; + return *this; +} + std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build( XlaOpRegistry::Factory factory) { registration_->factory = factory; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 74a4885f1f..4b2c2bacd6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -47,17 +47,18 @@ extern const char* const DEVICE_XLA_GPU; constexpr std::array<DataType, 4> kFloatTypes = { {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; -constexpr std::array<DataType, 9> kNumericTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BFLOAT16}}; +constexpr std::array<DataType, 11> kNumericTypes = { + {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF, + DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}}; -constexpr std::array<DataType, 9> kCpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BOOL}}; +constexpr std::array<DataType, 14> kCpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; -constexpr std::array<DataType, 10> kGpuAllTypes = { - {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; +constexpr std::array<DataType, 15> kGpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, + DT_BFLOAT16}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. @@ -136,6 +137,10 @@ class XlaOpRegistry { static const std::unordered_set<string>* CompileTimeConstantInputs( const string& op); + // Returns true if `op` is a "metadata" op, one that only looks at the shapes + // of its operands and not their values. + static bool IsMetadataOp(const string& op); + private: friend class XlaBackendRegistrar; friend class XlaOpRegistrar; @@ -192,6 +197,10 @@ class XlaOpRegistry { // Names of arguments that must be compile-time constants. std::unordered_set<string> compile_time_constant_inputs; + // True if this is a "metadata" op, one that only looks at the shapes of its + // operands and not their values. + bool is_metadata_op = false; + // Factory used to build OpKernels that perform symbolic execution. Factory factory; }; @@ -256,6 +265,10 @@ class XlaOpRegistrationBuilder { // Mark 'input_name' as an argument whose value must be known at compile-time. XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name); + // Mark this op as a "metadata" op, one that only looks at the shapes of its + // operands and not their values. 
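The new IsMetadataOp query is the registry-side counterpart of the .IsMetadataOp() registrations added to Shape, ShapeN, Rank and Size earlier in this change. A hedged sketch of how it can be consulted; the lookups only return true in a binary that links in those kernel registrations, and real callers are compiler passes rather than a standalone program:

#include <iostream>

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

int main() {
  // True for ops registered with .IsMetadataOp() (e.g. "Shape" above), false
  // for ops that read operand values and for names with no registration.
  std::cout << tensorflow::XlaOpRegistry::IsMetadataOp("Shape") << '\n';
  std::cout << tensorflow::XlaOpRegistry::IsMetadataOp("MatMul") << '\n';
  return 0;
}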
+ XlaOpRegistrationBuilder& IsMetadataOp(); + std::unique_ptr<XlaOpRegistry::OpRegistration> Build( XlaOpRegistry::Factory factory); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index ef70c1f8ac..cc7390c6e6 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -245,6 +245,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index f825f67b44..dc097f3696 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -220,6 +220,8 @@ cc_library( "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 25cc37edc4..ff0ec76a7f 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -97,13 +97,11 @@ std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie( << "Computation should have progran shape."; auto program_shape = computation.proto().program_shape(); - // Create and run a program which produces a tuple with one element per - // parameter, then return the tuple's constituent buffers. - std::vector<Shape> param_shapes(program_shape.parameters().begin(), - program_shape.parameters().end()); - auto fake_input_tuple = - MakeFakeDataOrDie(ShapeUtil::MakeTupleShape(param_shapes), client); - return client->DeconstructTuple(*fake_input_tuple).ValueOrDie(); + std::vector<std::unique_ptr<GlobalData>> results; + for (const Shape& shape : program_shape.parameters()) { + results.push_back(MakeFakeDataOrDie(shape, client)); + } + return results; } } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 95ff6432a5..6b31831010 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -22,6 +22,7 @@ limitations under the License. #include <utility> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/mutex.h" namespace xla { @@ -208,6 +208,9 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle, case HloOpcode::kWhile: // TODO(b/32495713): We aren't checking the condition and body // computations themselves. + case HloOpcode::kScatter: + // TODO(b/32495713): We aren't checking the embedded computation in + // Scatter. 
case HloOpcode::kSend: case HloOpcode::kRecv: case HloOpcode::kParameter: @@ -1278,7 +1281,7 @@ XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) { XlaOp XlaBuilder::CustomCall(const string& call_target_name, absl::Span<const XlaOp> operands, - const Shape& shape) { + const Shape& shape, const string& opaque) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; if (absl::StartsWith(call_target_name, "$")) { @@ -1289,6 +1292,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, } *instr.mutable_shape() = shape; instr.set_custom_call_target(call_target_name); + instr.set_custom_call_opaque(opaque); return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); }); } @@ -1785,9 +1789,9 @@ XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value, std::vector<std::pair<int64, int64>> padding_values = MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions, window_strides, padding); - return ReduceWindowWithGeneralPadding(operand, init_value, computation, - window_dimensions, window_strides, - padding_values); + return ReduceWindowWithGeneralPadding( + operand, init_value, computation, window_dimensions, window_strides, + /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values); }); } @@ -1796,6 +1800,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( const XlaComputation& computation, absl::Span<const int64> window_dimensions, absl::Span<const int64> window_strides, + absl::Span<const int64> base_dilations, + absl::Span<const int64> window_dilations, absl::Span<const std::pair<int64, int64>> padding) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; @@ -1806,7 +1812,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding( computation.GetProgramShape()); TF_ASSIGN_OR_RETURN(*instr.mutable_window(), MakeWindow(window_dimensions, window_strides, padding, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{})); + /*lhs_dilation=*/base_dilations, + /*rhs_dilation=*/window_dilations)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferReduceWindowShape(operand_shape, init_shape, @@ -2289,7 +2296,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph( // also a valid dependency order). The related ops will be added to the // subgraph in the same order. std::set<int64> related_ops; - tensorflow::gtl::FlatSet<int64> related_calls; // Related computations. + absl::flat_hash_set<int64> related_calls; // Related computations. 
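Returning to the ReduceWindowWithGeneralPadding change above: base_dilations is threaded through as the window's lhs_dilation and window_dilations as its rhs_dilation, so the output extents follow the usual XLA window arithmetic. A standalone sketch with assumed sizes (input 5, window 3, both dilations 2):

#include <cstdint>
#include <iostream>

// A dilation d spreads s elements over (s - 1) * d + 1 positions.
int64_t DilatedExtent(int64_t size, int64_t dilation) {
  return size <= 0 ? 0 : (size - 1) * dilation + 1;
}

int main() {
  const int64_t input = 5, window = 3;  // assumed sizes
  const int64_t base_dilation = 2, window_dilation = 2;
  const int64_t dilated_base = DilatedExtent(input, base_dilation);         // 9
  const int64_t effective_window = DilatedExtent(window, window_dilation);  // 5
  // With stride 1 and no padding, the output along this dimension has
  // dilated_base - effective_window + 1 elements.
  std::cout << dilated_base - effective_window + 1 << '\n';  // prints: 5
  return 0;
}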
std::queue<int64> worklist; worklist.push(root->id()); related_ops.insert(root->id()); @@ -2681,8 +2688,9 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, } XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - absl::Span<const XlaOp> operands, const Shape& shape) { - return builder->CustomCall(call_target_name, operands, shape); + absl::Span<const XlaOp> operands, const Shape& shape, + const string& opaque) { + return builder->CustomCall(call_target_name, operands, shape, opaque); } XlaOp Complex(const XlaOp& real, const XlaOp& imag, @@ -2795,10 +2803,12 @@ XlaOp ReduceWindowWithGeneralPadding( const XlaComputation& computation, absl::Span<const int64> window_dimensions, absl::Span<const int64> window_strides, + absl::Span<const int64> base_dilations, + absl::Span<const int64> window_dilations, absl::Span<const std::pair<int64, int64>> padding) { return operand.builder()->ReduceWindowWithGeneralPadding( operand, init_value, computation, window_dimensions, window_strides, - padding); + base_dilations, window_dilations, padding); } XlaOp CrossReplicaSum(const XlaOp& operand, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index d0c59fa6f2..2e14e47a35 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -21,6 +21,8 @@ limitations under the License. #include <type_traits> #include <utility> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/padding.h" @@ -34,8 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stacktrace.h" #include "tensorflow/core/platform/types.h" @@ -577,11 +577,9 @@ class XlaBuilder { absl::Span<const XlaOp> operands); // Enqueues a custom call instruction onto the computation. - // During code generation, a call instruction is emitted which targets a - // symbol with the name |call_target_name|. The |operands| are passed to the - // call instruction. |shape| is the resultant shape. XlaOp CustomCall(const string& call_target_name, - absl::Span<const XlaOp> operands, const Shape& shape); + absl::Span<const XlaOp> operands, const Shape& shape, + const string& opaque); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one @@ -673,6 +671,8 @@ class XlaBuilder { const XlaComputation& computation, absl::Span<const int64> window_dimensions, absl::Span<const int64> window_strides, + absl::Span<const int64> base_dilations, + absl::Span<const int64> window_dilations, absl::Span<const std::pair<int64, int64>> padding); // Returns the sum of the operand value within each subgroup of replicas. All @@ -1029,7 +1029,7 @@ class XlaBuilder { // A map from XlaOp::Handle to the index in the instructions_ vector where the // instruction is held. - tensorflow::gtl::FlatMap<int64, int64> handle_to_index_; + absl::flat_hash_map<int64, int64> handle_to_index_; // The embedded computations used by this computation. 
Each computation was // the entry computation of some XlaComputation, the key is the unique id of @@ -1037,7 +1037,7 @@ std::map<int64, HloComputationProto> embedded_; // The unique parameter numbers. - tensorflow::gtl::FlatSet<int64> parameter_numbers_; + absl::flat_hash_set<int64> parameter_numbers_; // The metadata to attach to each op. This is structured as a "modal"-like // operation, in order to simplify client code (and not sprinkle this metadata @@ -1195,7 +1195,8 @@ class XlaBuilder { friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, absl::Span<const XlaOp> operands); friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - absl::Span<const XlaOp> operands, const Shape& shape); + absl::Span<const XlaOp> operands, const Shape& shape, + const string& opaque); friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, absl::Span<const int64> broadcast_dimensions); friend XlaOp Conj(const XlaOp& operand); @@ -1246,6 +1247,8 @@ class XlaBuilder { const XlaComputation& computation, absl::Span<const int64> window_dimensions, absl::Span<const int64> window_strides, + absl::Span<const int64> base_dilations, + absl::Span<const int64> window_dilations, absl::Span<const std::pair<int64, int64>> padding); friend XlaOp CrossReplicaSum(const XlaOp& operand, absl::Span<const ReplicaGroup> replica_groups); @@ -1717,12 +1720,17 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, absl::Span<const XlaOp> operands); -// Enqueues a custom call instruction onto the computation. -// During code generation, a call instruction is emitted which targets a -// symbol with the name |call_target_name|. The |operands| are passed to the -// call instruction. |shape| is the resultant shape. +// Enqueues a custom call instruction onto the computation. A custom call +// invokes code external to XLA. The |operands| are passed to the external code, +// and the external code is expected to produce a result of the given +// |shape|. The exact mechanism is backend-specific. For example, in the CPU +// backend, a call instruction is emitted which targets a symbol with the name +// |call_target_name|. |call_target_name| and |opaque| can be arbitrary strings, +// but |call_target_name| should be short as it may be used in labels. |opaque| +// can encode arbitrarily large amounts of information. XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - absl::Span<const XlaOp> operands, const Shape& shape); + absl::Span<const XlaOp> operands, const Shape& shape, + const string& opaque = ""); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one @@ -1814,6 +1822,8 @@ XlaOp ReduceWindowWithGeneralPadding( const XlaComputation& computation, absl::Span<const int64> window_dimensions, absl::Span<const int64> window_strides, + absl::Span<const int64> base_dilations, + absl::Span<const int64> window_dilations, absl::Span<const std::pair<int64, int64>> padding); // Returns the sum of the operand value within each subgroup of replicas.
All diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index a472747bd1..0f9b591c70 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -45,6 +45,16 @@ stream_executor::Stream* ExecutableRunOptions::stream() const { return stream_; } +ExecutableRunOptions& ExecutableRunOptions::set_host_to_device_stream( + stream_executor::Stream* stream) { + host_to_device_stream_ = stream; + return *this; +} + +stream_executor::Stream* ExecutableRunOptions::host_to_device_stream() const { + return host_to_device_stream_; +} + ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool( const Eigen::ThreadPoolDevice* intra_op_thread_pool) { intra_op_thread_pool_ = intra_op_thread_pool; diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 416131be00..ba3217f31b 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -65,6 +65,13 @@ class ExecutableRunOptions { ExecutableRunOptions& set_stream(stream_executor::Stream* stream); stream_executor::Stream* stream() const; + // If set, this is the stream to perform any pre-computation transfers on. + // The platform of the stream must match the platform the executable was + // built for. A value of nullptr indicates the option has not been set. + ExecutableRunOptions& set_host_to_device_stream( + stream_executor::Stream* stream); + stream_executor::Stream* host_to_device_stream() const; + // Sets the thread pool device on which to run Eigen subcomputations. // Does not take ownership. ExecutableRunOptions& set_intra_op_thread_pool( @@ -90,6 +97,7 @@ class ExecutableRunOptions { const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; ExecutionProfile* execution_profile_ = nullptr; int rng_seed_ = 0; + stream_executor::Stream* host_to_device_stream_ = nullptr; }; } // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 0d3136b0cc..3ed3afcfce 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -57,6 +57,8 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // regression. flags->set_xla_cpu_enable_fast_math(true); flags->set_xla_gpu_enable_fast_math(true); + + flags->set_xla_force_host_platform_device_count(1); } // Allocates flag_values and flag_objects; this function must not be called more @@ -323,6 +325,17 @@ void AllocateFlags() { flag_values->xla_gpu_crash_on_verification_failures(), "Crashes the program on extra verification failures, e.g. cuDNN " "cross checking failures"), + tensorflow::Flag( + "xla_force_host_platform_device_count", + int32_setter_for( + &DebugOptions::set_xla_force_host_platform_device_count), + flag_values->xla_force_host_platform_device_count(), + "Force the host platform to pretend that there are this many " + "host \"devices\". All of these host devices are backed by the same " "threadpool. 
Setting this to anything other than 1 can increase " + "overhead from context switching but we let the user override this " + "behavior to help run tests on the host that run models in parallel " + "across multiple devices."), }); ParseFlagsFromEnv(*flag_objects); } diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc index 5035f41988..656ce720a1 100644 --- a/tensorflow/compiler/xla/literal.cc +++ b/tensorflow/compiler/xla/literal.cc @@ -287,6 +287,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal, return InvalidArgument("LiteralProto has no layout"); } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + Literal literal(proto.shape()); TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus( @@ -725,16 +727,34 @@ Literal LiteralBase::Slice(absl::Span<const int64> start_indices, ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); switch (result_shape.element_type()) { - case F32: - return SliceInternal<float>(result_shape, start_indices); + case PRED: + return SliceInternal<bool>(result_shape, start_indices); + case U8: + return SliceInternal<uint8>(result_shape, start_indices); + case U16: + return SliceInternal<uint16>(result_shape, start_indices); + case U32: + return SliceInternal<uint32>(result_shape, start_indices); + case U64: + return SliceInternal<uint64>(result_shape, start_indices); + case S8: + return SliceInternal<int8>(result_shape, start_indices); + case S16: + return SliceInternal<int16>(result_shape, start_indices); + case S32: + return SliceInternal<int32>(result_shape, start_indices); + case S64: + return SliceInternal<int64>(result_shape, start_indices); + case F16: + return SliceInternal<half>(result_shape, start_indices); case BF16: return SliceInternal<bfloat16>(result_shape, start_indices); + case F32: + return SliceInternal<float>(result_shape, start_indices); + case F64: + return SliceInternal<double>(result_shape, start_indices); case C64: return SliceInternal<complex64>(result_shape, start_indices); - case S32: - return SliceInternal<int32>(result_shape, start_indices); - case U32: - return SliceInternal<uint32>(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); @@ -1850,6 +1870,24 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); + if (LayoutUtil::IsSparseArray(subshape())) { + // Compute the number of elements (indices) in the sparse shape and reserve + // the necessary space in spare_indices. + TF_RET_CHECK(ShapeUtil::Rank(subshape()) != 0) + << "Scalar shapes cannot be sparse"; + TF_RET_CHECK(proto.sparse_indices_size() % ShapeUtil::Rank(subshape()) == 0) + << "Unexpected number of indices in proto (" + << proto.sparse_indices_size() << ") for shape of rank " + << ShapeUtil::Rank(subshape()); + const int64 index_count = + proto.sparse_indices_size() / ShapeUtil::Rank(subshape()); + sparse_indices()->Resize(index_count); + + // Copy the indices from the proto into the SparseIndexArray object. 
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(sparse_indices()->mutable_data(), + proto.sparse_indices())); + } + switch (subshape().element_type()) { case PRED: TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds())); @@ -1907,11 +1945,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { } } break; case TUPLE: - LOG(FATAL) << "Should not be called on tuple shapes: " - << ShapeUtil::HumanString(subshape()); - break; + return InvalidArgument("Should not be called on tuple shapes: %s", + ShapeUtil::HumanString(subshape())); default: - LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); + return InvalidArgument("Is called on unsupported shape: %s", + ShapeUtil::HumanString(subshape())); } return Status::OK(); } diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h index 1e0a2ad0dd..3cd3541fe1 100644 --- a/tensorflow/compiler/xla/literal.h +++ b/tensorflow/compiler/xla/literal.h @@ -203,6 +203,10 @@ class LiteralBase { // Returns the count of the elements in the array at the given shape index in // this literal. int64 element_count(const ShapeIndex& index = {}) const { + if (index.empty()) { + // Common case, avoid GetSubshape(). + return ShapeUtil::ElementsIn(shape()); + } return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); } @@ -852,9 +856,9 @@ class BorrowingLiteral : public LiteralBase { template <typename NativeT> absl::Span<const NativeT> LiteralBase::Piece::data() const { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType<NativeT>()) + DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType<NativeT>()) << "Attempting to access " << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>()) << " type, but literal element type is " @@ -865,9 +869,9 @@ absl::Span<const NativeT> LiteralBase::Piece::data() const { template <typename NativeT> absl::Span<NativeT> LiteralBase::Piece::data() { - CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); - CHECK_EQ(subshape().element_type(), - primitive_util::NativeToPrimitiveType<NativeT>()) + DCHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + DCHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType<NativeT>()) << "Attempting to access " << PrimitiveType_Name(primitive_util::NativeToPrimitiveType<NativeT>()) << " type, but literal element type is " diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc index 7ad287c897..dd5b54e4c9 100644 --- a/tensorflow/compiler/xla/literal_test.cc +++ b/tensorflow/compiler/xla/literal_test.cc @@ -224,6 +224,16 @@ TEST_F(LiteralUtilTest, CreateSparse) { absl::Span<const int64>(expected_indices.data(), expected_indices.num_elements())); EXPECT_EQ(literal.data<int64>(), absl::Span<const int64>(expected_values)); + + // Serialize then deserialize and verify the resulting literal. 
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal_from_proto, + Literal::CreateFromProto(literal.ToProto())); + + EXPECT_EQ(literal_from_proto.sparse_indices()->data(), + absl::Span<const int64>(expected_indices.data(), + expected_indices.num_elements())); + EXPECT_EQ(literal_from_proto.data<int64>(), + absl::Span<const int64>(expected_values)); } TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 9da5dc0d2d..ffa336f304 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -469,9 +469,11 @@ LocalOp LocalComputationBuilder::ConvGeneralDilated( absl::Span<const int64> window_strides, absl::Span<const std::pair<int64, int64>> padding, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers) { + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count) { return xla::ConvGeneralDilated(lhs.op(), rhs.op(), window_strides, padding, - lhs_dilation, rhs_dilation, dimension_numbers); + lhs_dilation, rhs_dilation, dimension_numbers, + feature_group_count); } LocalOp LocalComputationBuilder::ConvertElementType( @@ -530,10 +532,13 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding( const LocalComputation& local_computation, absl::Span<const int64> window_dimensions, absl::Span<const int64> window_strides, + absl::Span<const int64> base_dilations, + absl::Span<const int64> window_dilations, absl::Span<const std::pair<int64, int64>> padding) { return xla::ReduceWindowWithGeneralPadding( operand.op(), init_value.op(), local_computation.computation(), - window_dimensions, window_strides, padding); + window_dimensions, window_strides, base_dilations, window_dilations, + padding); } LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 1d5dfe5911..43332e0abd 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -248,7 +248,8 @@ class LocalComputationBuilder { absl::Span<const std::pair<int64, int64> > padding, absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers); + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count); LocalOp ConvertElementType(const LocalOp& operand, PrimitiveType new_element_type); @@ -277,6 +278,8 @@ class LocalComputationBuilder { const LocalComputation& local_computation, absl::Span<const int64> window_dimensions, absl::Span<const int64> window_strides, + absl::Span<const int64> base_dilations, + absl::Span<const int64> window_dilations, absl::Span<const std::pair<int64, int64> > padding); LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma, diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index fa4366ff07..f8197488fb 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -995,7 +995,30 @@ class ComputationBuilder(object): window_strides) return self._client.ReduceWindowWithGeneralPadding( operand, init_value, computation_to_apply.c_local_computation, - window_dimensions, window_strides, pads) + 
window_dimensions, window_strides, (), (), pads) + + def ReduceWindowWithGeneralPadding( + self, operand, init_value, computation_to_apply, window_dimensions, + window_strides, base_dilations, window_dilations, padding): + """Enqueues a windowed reduction operation onto the computation. + + Args: + operand: reduction operand (LocalOp). + init_value: reduction initial value (LocalOp). + computation_to_apply: a binary reduction function (Computation). + window_dimensions: dimensions of window (sequence of integers). + window_strides: strides for window (sequence of integers). + base_dilations: dilations for the base (sequence of integers). + window_dilations: dilations for window (sequence of integers). + padding: length-N array-like of pairs of integers of (low, high) padding. + + Returns: + A LocalOp representing the added ReduceWindow op. + """ + return self._client.ReduceWindowWithGeneralPadding( + operand, init_value, computation_to_apply.c_local_computation, + window_dimensions, window_strides, base_dilations, window_dilations, + padding) def RngNormal(self, mu, sigma, dims): """Enqueues an RngNormal operation onto the computation. @@ -1109,7 +1132,7 @@ class ComputationBuilder(object): dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) return self._client.DotGeneral(lhs, rhs, dimension_numbers) - def Conv(self, lhs, rhs, window_strides, padding): + def Conv(self, lhs, rhs, window_strides, padding, feature_group_count=1): """Enqueues a Conv operation onto the computation. Args: @@ -1117,6 +1140,7 @@ class ComputationBuilder(object): rhs: LocalOp for the rank N+2 array of kernel weights. window_strides: length-N array-like of integer kernel strides. padding: PaddingType representing either 'SAME' or 'VALID' padding. + feature_group_count: number of feature groups for grouped convolution. Returns: a LocalOp representing the Conv operation. """ @@ -1125,10 +1149,11 @@ class ComputationBuilder(object): self.GetShape(rhs).dimensions()[2:], window_strides) dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) return self._client.ConvGeneralDilated(lhs, rhs, window_strides, pads, (), - (), dimension_numbers) + (), dimension_numbers, + feature_group_count) def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, - lhs_dilation, rhs_dilation): + lhs_dilation, rhs_dilation, feature_group_count=1): """Enqueues a ConvWithGeneralPadding operation onto the computation. Args: @@ -1138,6 +1163,7 @@ class ComputationBuilder(object): padding: length-N array-like of pairs of integers of (low, high) padding. lhs_dilation: length-N array-like of dilation factors. rhs_dilation: length-N array-like of dilation factors. + feature_group_count: number of feature groups for grouped convolution. Returns: A ComputationdataHandle representing the added ConvWithGeneralPadding op. 
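For orientation, here is a minimal C++ sketch of how the widened ReduceWindowWithGeneralPadding signature above can be invoked through XlaBuilder, which is the API the new local_computation_builder and xla_client plumbing forwards to. It is illustrative rather than part of this change; the helper name MakeScalarAddComputation, the 3x3 window, and the dilation and padding values are assumptions.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace {

// Builds a scalar f32 addition reducer (assumed helper, not from this diff).
xla::XlaComputation MakeScalarAddComputation() {
  xla::XlaBuilder b("add");
  auto x = xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
  auto y = xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {}), "y");
  xla::Add(x, y);
  return b.Build().ValueOrDie();
}

// Enqueues a dilated windowed sum over a 2D f32 operand `input`.
xla::XlaOp DilatedWindowSum(xla::XlaBuilder* b, const xla::XlaOp& input) {
  xla::XlaComputation add = MakeScalarAddComputation();
  return xla::ReduceWindowWithGeneralPadding(
      input, xla::ConstantR0<float>(b, 0.0f), add,
      /*window_dimensions=*/{3, 3},
      /*window_strides=*/{1, 1},
      /*base_dilations=*/{1, 1},     // dilation applied to the operand
      /*window_dilations=*/{2, 2},   // dilation applied to the window
      /*padding=*/{{1, 1}, {1, 1}});
}

}  // namespace

Passing empty spans for base_dilations and window_dilations, as the plain ReduceWindow overload now does internally, keeps the previous undilated behavior.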
@@ -1145,7 +1171,8 @@ class ComputationBuilder(object): dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers) + dimension_numbers, + feature_group_count) def _GetConvDimensionNumbers(self, num_spatial_dims): """Create ConvolutionDimensionNumbers proto for convolutions.""" @@ -1163,7 +1190,8 @@ class ComputationBuilder(object): return dimension_numbers def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation, - rhs_dilation, dimension_numbers): + rhs_dilation, dimension_numbers, + feature_group_count=1): """Enqueues a ConvGeneralDilated operation onto the computation. Args: @@ -1190,6 +1218,7 @@ class ComputationBuilder(object): labels appear in the rhs_spec string, so that window_strides[0] is matched with the dimension corresponding to the first character appearing in rhs_spec that is not 'I' or 'O'. + feature_group_count: number of feature groups for grouped convolution. Returns: a LocalOp representing the ConvGenralDilated operation. """ @@ -1215,7 +1244,8 @@ class ComputationBuilder(object): key=lambda i: rhs_spec.index(out_spec[i]))) return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers) + dimension_numbers, + feature_group_count) def Sort(self, operand, dimension=-1): """Enqueues a sort operation onto the computation.""" diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index fd98e19457..82103f0313 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -661,6 +661,30 @@ class SingleOpTest(LocalComputationTest): [40., 50., 0.]]]]) self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) + def testConvGeneralDilatedGroupedConvolutionF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 2, 3) + rhs = a(2, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = ("NCHW", "OIHW", "NCHW") + feature_group_count = 2 + c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers, feature_group_count) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]], + [[0., 0., 0.], + [330., 380., 160.], + [0., 0., 0.], + [480., 530., 220.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + def testBooleanNot(self): c = self._NewComputation() arr = NumpyArrayBool([True, False, True]) diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 97fcd37f6b..3abb3855a4 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -34,19 +34,28 @@ cc_library( ], ) -tf_cc_binary( - name = "grpc_service_main_cpu", +cc_library( + name = "grpc_service_main_library", srcs = ["grpc_service_main.cc"], deps = [ ":grpc_service", "//tensorflow:grpc++", "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "@com_google_absl//absl/strings:str_format", ], ) +tf_cc_binary( + name = "grpc_service_main_cpu", + deps = [ + ":grpc_service_main_library", + "//tensorflow/compiler/xla/service:cpu_plugin", + ], +) + tf_cc_test( name = 
"grpc_client_test", srcs = ["grpc_client_test.cc"], diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc index d6b5149a24..522ab99fb1 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -20,6 +20,7 @@ limitations under the License. #include "grpcpp/server_builder.h" #include "absl/strings/str_format.h" #include "tensorflow/compiler/xla/rpc/grpc_service.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/command_line_flags.h" @@ -29,8 +30,15 @@ namespace { int RealMain(int argc, char** argv) { int32 port = 1685; + bool any_address = false; + string platform_str; std::vector<tensorflow::Flag> flag_list = { - tensorflow::Flag("port", &port, "port to listen on"), + tensorflow::Flag("platform", &platform_str, + "The XLA platform this service should be bound to"), + tensorflow::Flag("port", &port, "The TCP port to listen on"), + tensorflow::Flag( + "any", &any_address, + "Whether to listen to any host address or simply localhost"), }; string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); @@ -40,19 +48,24 @@ int RealMain(int argc, char** argv) { } tensorflow::port::InitMain(argv[0], &argc, &argv); + se::Platform* platform = nullptr; + if (!platform_str.empty()) { + platform = PlatformUtil::GetPlatform(platform_str).ValueOrDie(); + } std::unique_ptr<xla::GRPCService> service = - xla::GRPCService::NewService().ConsumeValueOrDie(); + xla::GRPCService::NewService(platform).ConsumeValueOrDie(); ::grpc::ServerBuilder builder; - string server_address(absl::StrFormat("localhost:%d", port)); + string server_address( + absl::StrFormat("%s:%d", any_address ? 
"[::]" : "localhost", port)); + builder.SetMaxReceiveMessageSize(INT_MAX); builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials()); builder.RegisterService(service.get()); std::unique_ptr<::grpc::Server> server(builder.BuildAndStart()); LOG(INFO) << "Server listening on " << server_address; server->Wait(); - return 0; } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index fb80c78f68..2b292ed053 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -146,6 +146,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -182,6 +184,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -251,6 +254,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -296,6 +300,7 @@ cc_library( "hlo_opcode.cc", "hlo_schedule.cc", "hlo_sharding.cc", + "hlo_sharding_metadata.cc", ], hdrs = [ "dfs_hlo_visitor.h", @@ -309,6 +314,7 @@ cc_library( "hlo_opcode.h", "hlo_schedule.h", "hlo_sharding.h", + "hlo_sharding_metadata.h", ], deps = [ ":hlo_casting_utils", @@ -333,6 +339,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -365,8 +373,11 @@ cc_library( hdrs = ["pattern_matcher.h"], deps = [ ":hlo", + ":hlo_casting_utils", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "@com_google_absl//absl/strings", + "@com_google_absl//absl/utility", ], ) @@ -392,6 +403,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:span", ], ) @@ -482,6 +494,8 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -590,6 +604,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", @@ -772,6 +787,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -899,6 +915,7 @@ 
cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -948,6 +965,8 @@ cc_library( deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -983,6 +1002,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], @@ -1030,6 +1051,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1083,6 +1106,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], @@ -1121,6 +1145,8 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) @@ -1142,6 +1168,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], ) @@ -1166,6 +1193,7 @@ tf_cc_test( ":hlo", ":hlo_matchers", ":hlo_module_group", + ":hlo_module_group_metadata", ":hlo_parser", ":hlo_proto", "//tensorflow/compiler/xla:test", @@ -1191,6 +1219,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:optional", ], @@ -1211,6 +1240,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -1255,6 +1286,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -1275,6 +1308,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -1290,15 +1324,25 @@ cc_library( ) cc_library( + name = "fusion_queue", + hdrs = ["fusion_queue.h"], + deps = [ + ":hlo", + ], +) + +cc_library( name = "instruction_fusion", srcs = ["instruction_fusion.cc"], hdrs = ["instruction_fusion.h"], deps = [ + ":fusion_queue", ":hlo", ":hlo_pass", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + 
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], ) @@ -1325,6 +1369,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -1380,6 +1426,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], @@ -1635,6 +1682,8 @@ cc_library( ":while_loop_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -1666,6 +1715,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -1792,42 +1842,6 @@ tf_cc_test( ) cc_library( - name = "inliner", - srcs = ["inliner.cc"], - hdrs = ["inliner.h"], - deps = [ - ":hlo", - ":hlo_pass", - ":hlo_query", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "@com_google_absl//absl/types:span", - ], -) - -tf_cc_test( - name = "inliner_test", - srcs = ["inliner_test.cc"], - deps = [ - ":cpu_plugin", - ":hlo", - ":hlo_matchers", - ":inliner", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:hlo_verified_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "@com_google_absl//absl/memory", - ], -) - -cc_library( name = "computation_placer", srcs = ["computation_placer.cc"], hdrs = ["computation_placer.h"], @@ -2038,6 +2052,7 @@ cc_library( ":logical_buffer", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2073,6 +2088,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -2094,6 +2110,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -2177,6 +2194,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -2198,6 +2216,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -2258,6 +2278,8 @@ cc_library( 
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -2314,6 +2336,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -2340,6 +2364,8 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], ) @@ -2411,6 +2437,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], @@ -2423,6 +2450,7 @@ tf_cc_test( ":hlo", ":hlo_parser", ":hlo_verifier", + ":layout_assignment", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", @@ -2455,6 +2483,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -2557,6 +2587,7 @@ cc_library( ], deps = [ ":hlo", + ":hlo_module_group", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -2582,12 +2613,34 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], ) +tf_cc_test( + name = "hlo_pass_pipeline_test", + srcs = ["hlo_pass_pipeline_test.cc"], + deps = [ + ":hlo", + ":hlo_parser", + ":hlo_pass_pipeline", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_cse", srcs = ["hlo_cse.cc"], @@ -2601,6 +2654,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -2675,27 +2729,13 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", ], ) cc_library( - name = "hlo_sharding_metadata", - srcs = 
["hlo_sharding_metadata.cc"], - hdrs = [ - "hlo_sharding_metadata.h", - ], - deps = [ - ":hlo", - "//tensorflow/compiler/xla:shape_tree", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/core:lib", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:span", - ], -) - -cc_library( name = "hlo_domain_verifier", srcs = ["hlo_domain_verifier.cc"], hdrs = ["hlo_domain_verifier.h"], @@ -2745,7 +2785,6 @@ tf_cc_test( ":hlo_domain_isolator", ":hlo_domain_remover", ":hlo_parser", - ":hlo_sharding_metadata", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -3121,6 +3160,7 @@ cc_library( ":hlo_pass_pipeline", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) @@ -3243,6 +3283,8 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -3272,6 +3314,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", ], ) @@ -3328,6 +3371,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:ptr_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -3355,7 +3400,6 @@ cc_library( deps = [ ":hlo", ":hlo_lexer", - ":hlo_sharding_metadata", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -3413,6 +3457,39 @@ cc_library( deps = ["//tensorflow/core:lib"], ) +cc_library( + name = "map_inliner", + srcs = ["map_inliner.cc"], + hdrs = ["map_inliner.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "map_inliner_test", + srcs = ["map_inliner_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":map_inliner", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/memory", + ], +) + tf_cc_test( name = "hlo_casting_utils_test", srcs = ["hlo_casting_utils_test.cc"], diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 4ef1dffa73..86d9dbea90 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -754,11 +754,12 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction( }; auto reshape_if_necessary = [&](HloInstruction* hlo) { + hlo = as_type(hlo, dot->shape().element_type()); if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { hlo = computation_->AddInstruction( 
HloInstruction::CreateReshape(dot->shape(), hlo)); } - return as_type(hlo, dot->shape().element_type()); + return hlo; }; auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) { @@ -2056,6 +2057,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return Status::OK(); } + // Bail on dilation. + if (window_util::HasDilation(window)) { + VLOG(10) << "Not folding pad into reduce-window as there is dilation."; + return Status::OK(); + } + VLOG(10) << "Considering folding Pad: " << pad->ToString() << "\ninto reduce-window: " << reduce_window->ToString() << (convert != nullptr diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index b864c372fa..9f8d0ee88b 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -24,7 +24,7 @@ limitations under the License. namespace xla { // A pass which performs algebraic simplifications. -class AlgebraicSimplifier : public HloPassInterface { +class AlgebraicSimplifier : public HloModulePass { public: // Given shapes 'from_shape' and 'to_shape', determines if it is valid to // bitcast from 'from_shape' to 'to_shape' after considering platform diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 3fc1ba2427..2047f894b4 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -3233,17 +3233,18 @@ INSTANTIATE_TEST_CASE_P( class DotStrengthReductionTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface< - ::testing::tuple<int, int, int, bool, bool>> {}; + ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {}; TEST_P(DotStrengthReductionTest, DotStrengthReduction) { int m, k, n; bool transpose_lhs, transpose_rhs; - std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam(); + PrimitiveType element_type; + std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam(); - Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); - Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); - Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); - Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k}); + Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k}); + Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m}); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n}); + Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k}); HloComputation::Builder builder(TestName()); auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( @@ -3285,7 +3286,7 @@ INSTANTIATE_TEST_CASE_P( DotStrengthReductionTestInstantiation, DotStrengthReductionTest, ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Bool(), - ::testing::Bool())); + ::testing::Bool(), ::testing::Values(F32, BF16))); struct DotOfConcatTestSpec { int64 m; diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index a7d8927cf7..43feccee3c 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -22,6 +22,7 @@ limitations under the License. 
#include <string> #include <vector> +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -110,7 +111,7 @@ class AllocationTracker { // A map from device memory opaque value to allocation. One such map is // maintained per device ordinal. - using AllocationMap = tensorflow::gtl::FlatMap<const void*, Allocation>; + using AllocationMap = absl::flat_hash_map<const void*, Allocation>; tensorflow::mutex mutex_; @@ -123,10 +124,7 @@ class AllocationTracker { int64 next_handle_ GUARDED_BY(mutex_); // A map from device ordinal to AllocationMap. - // - // This is not a TF FlatMap because (currently) FlatMap (and therefore - // AllocationMap) is not movable. - std::unordered_map<int, AllocationMap> opaque_to_allocation_map_ + absl::flat_hash_map<int, AllocationMap> opaque_to_allocation_map_ GUARDED_BY(mutex_); // A map from data handle to a vector of shaped buffers that represent the @@ -146,7 +144,7 @@ class AllocationTracker { // non-owning "view" into a tuple's sub-buffers. The sub-buffers are then // free'd when both the view *and* the original tuple are Unregistered. This // refcounting is managed in opaque_to_allocation_map_. - tensorflow::gtl::FlatMap<int64, std::vector<std::unique_ptr<ShapedBuffer>>> + absl::flat_hash_map<int64, std::vector<std::unique_ptr<ShapedBuffer>>> handle_to_shaped_buffers_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.h b/tensorflow/compiler/xla/service/batch_dot_simplification.h index 79d37f08d3..5b625bf3b9 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.h +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.h @@ -25,7 +25,7 @@ namespace xla { // Normally these would live in the algebraic simplifier, but we want to run // this to fixpoint (this pass reaches fixed point in one execution) before we // run the DotDecomposer. -class BatchDotSimplification : public HloPassInterface { +class BatchDotSimplification : public HloModulePass { public: StatusOr<bool> Run(HloModule* module) override; absl::string_view name() const override; diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 30d33e0d35..f70f6ddfec 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -35,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 76e32174f3..147f3ae7b6 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -26,7 +26,7 @@ namespace xla { // A pass which rewrites batch norm operations into more operations. Breaking a // big operation into smaller operations helps leverage our generic fusion // logic. -class BatchNormExpander : public HloPassInterface { +class BatchNormExpander : public HloModulePass { public: // When use_fusion is set, a multi-output fusion node is created. 
BatchNormExpander(bool rewrite_training_op = false, diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h index 5dcd31b83d..cb3d12f0bf 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h @@ -31,7 +31,7 @@ namespace xla { // optimization pipeline followed by a DCE pass. If other passes are needed // after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the // changed made by this pass. -class BFloat16ConversionFolding : public HloPassInterface { +class BFloat16ConversionFolding : public HloModulePass { public: explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support) : bfloat16_support_(bfloat16_support) {} diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h index 30b6346312..f48e925823 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.h +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h @@ -25,7 +25,7 @@ namespace xla { // A pass which adds F32 <-> BF16 conversions for HLO instructions that do not // support BF16 input/output or mixed precision, according to the passed-in // backend-specific BF16 support rules. -class BFloat16Normalization : public HloPassInterface { +class BFloat16Normalization : public HloModulePass { public: explicit BFloat16Normalization(const BFloat16Support* bfloat16_support) : bfloat16_support_(bfloat16_support) {} @@ -48,7 +48,7 @@ class BFloat16Normalization : public HloPassInterface { // use mixed precision; it removes mixed precision even if the backend supports // it. This pass is used to make the HLO module valid for other HLO passes which // do not support mixed precision. -class BFloat16MixedPrecisionRemoval : public HloPassInterface { +class BFloat16MixedPrecisionRemoval : public HloModulePass { public: BFloat16MixedPrecisionRemoval() {} diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index 58f78f8e24..002be9c970 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/bfloat16_propagation.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -81,7 +82,7 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes( }; auto root = fusion->fused_instructions_computation()->root_instruction(); - tensorflow::gtl::FlatSet<const HloValue*> changed_root_buffers; + absl::flat_hash_set<const HloValue*> changed_root_buffers; auto root_changes_it = changes_to_bf16_.find(root); if (root_changes_it != changes_to_bf16_.end()) { @@ -500,7 +501,7 @@ void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) { bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( HloComputation* computation, - tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations) { + absl::flat_hash_set<const HloComputation*>* visited_computations) { bool parameter_changed = false; auto insts = computation->MakeInstructionPostOrder(); // Do the adjustment on each instruction in the computation in reverse @@ -560,7 +561,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( // another input parameter. A fixed point will be reached because the // parameters can only be changed from BF16 to F32, not the other way // around. - tensorflow::gtl::FlatSet<const HloComputation*> visited_in_while; + absl::flat_hash_set<const HloComputation*> visited_in_while; while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(), &visited_in_while) || ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(), @@ -587,7 +588,7 @@ void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( HloModule* module) { const auto& computations_topological_order = module->MakeComputationPostOrder(); - tensorflow::gtl::FlatSet<const HloComputation*> resolved; + absl::flat_hash_set<const HloComputation*> resolved; for (auto comp_it = computations_topological_order.rbegin(); comp_it != computations_topological_order.rend(); ++comp_it) { if (ContainsKey(resolved, *comp_it)) { diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h index 1ee64971ab..5fcaa15c83 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.h +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -21,6 +21,8 @@ limitations under the License. #include <unordered_set> #include <vector> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/bfloat16_support.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -58,7 +60,7 @@ namespace xla { // BFloat16ConversionFolding. If other passes are needed after this pass, run // BFloat16MixedPrecisionRemoval first to undo some of the changes made by this // pass. -class BFloat16Propagation : public HloPassInterface { +class BFloat16Propagation : public HloModulePass { public: explicit BFloat16Propagation(const BFloat16Support* bfloat16_support); @@ -81,7 +83,7 @@ class BFloat16Propagation : public HloPassInterface { // The set of instructions to consider using bfloat16, computed in the forward // pass. 
- tensorflow::gtl::FlatSet<const HloInstruction*> consider_using_bfloat16_; + absl::flat_hash_set<const HloInstruction*> consider_using_bfloat16_; // *************************** // Functions called and state produced by the backward pass (from root to @@ -110,12 +112,12 @@ class BFloat16Propagation : public HloPassInterface { // The set of HloInstructions that have been visited in the // opportunity-finding pass. - tensorflow::gtl::FlatSet<const HloInstruction*> + absl::flat_hash_set<const HloInstruction*> instructions_visited_in_backward_pass_; // The set of HloComputations that have been visited in the // opportunity-finding pass. - tensorflow::gtl::FlatSet<const HloComputation*> + absl::flat_hash_set<const HloComputation*> computations_visited_in_backward_pass_; // *************************** @@ -131,7 +133,7 @@ class BFloat16Propagation : public HloPassInterface { // point is reached. bool ResolveInconsistencyOfAliasingBuffersHelper( HloComputation* computation, - tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations); + absl::flat_hash_set<const HloComputation*>* visited_computations); // Makes the parameters of called computations match how they are called by // the given HLO. @@ -182,11 +184,11 @@ class BFloat16Propagation : public HloPassInterface { PrimitiveType target_type); // The set of F32 HLO values that must be kept in F32. - tensorflow::gtl::FlatSet<const HloValue*> values_that_must_be_kept_as_f32_; + absl::flat_hash_set<const HloValue*> values_that_must_be_kept_as_f32_; // Mapping from each HloComputation to the number of callers to it in the // module. Populated at the beginning of this pass. - tensorflow::gtl::FlatMap<const HloComputation*, int64> caller_counts_; + absl::flat_hash_map<const HloComputation*, int64> caller_counts_; // We first store the potential F32-to-BF16 changes to changes_to_bf16_, which // are subject to further adjustment, then finally applied to the HLOs. This @@ -195,8 +197,7 @@ class BFloat16Propagation : public HloPassInterface { // // For each HloInstruction, changes_to_bf16_ stores the affected buffers in // the output as a map from in-place pointers to subshapes to shape indices. - tensorflow::gtl::FlatMap<HloInstruction*, - tensorflow::gtl::FlatMap<Shape*, ShapeIndex>> + absl::flat_hash_map<HloInstruction*, absl::flat_hash_map<Shape*, ShapeIndex>> changes_to_bf16_; // Whether the last processed HLO module has been changed by this pass. diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc index 23645346e6..5b48f10505 100644 --- a/tensorflow/compiler/xla/service/bfloat16_support.cc +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -78,8 +78,10 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( const HloInstruction& hlo, int64 operand_index) { switch (hlo.opcode()) { case HloOpcode::kAbs: + case HloOpcode::kAllToAll: case HloOpcode::kBroadcast: case HloOpcode::kClamp: + case HloOpcode::kCollectivePermute: case HloOpcode::kConcatenate: case HloOpcode::kConvert: case HloOpcode::kCopy: diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 65fa951afe..2c2d1626c2 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -22,6 +22,8 @@ limitations under the License. 
#include <ostream> #include <utility> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -41,10 +43,10 @@ limitations under the License. namespace xla { namespace { +using absl::flat_hash_map; +using absl::flat_hash_set; using absl::StrAppend; using absl::StrAppendFormat; -using ::tensorflow::gtl::FlatMap; -using ::tensorflow::gtl::FlatSet; using ::tensorflow::strings::HumanReadableNumBytes; template <typename T> @@ -128,8 +130,8 @@ Status GatherComputationsByAllocationType( // Sets for quickly checking membership. Computations are returned in vectors // for stable iteration. - FlatSet<const HloComputation*> thread_local_set; - FlatSet<const HloComputation*> global_set; + flat_hash_set<const HloComputation*> thread_local_set; + flat_hash_set<const HloComputation*> global_set; while (!worklist.empty()) { auto worklist_front = worklist.front(); @@ -444,7 +446,7 @@ bool BufferAssignment::SharesSliceAtIndex( bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, const HloInstruction* hlo_b) const { using SliceSet = - FlatSet<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>; + flat_hash_set<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>; // Gets the slices all of instr's subshapes. If any subshape doesn't have an // assigned slice, returns the empty set. auto collect_slices = [&](const HloInstruction* instr) -> SliceSet { @@ -519,7 +521,8 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation, // BufferAllocation. void BufferAssignment::CombineTempAllocations() { VLOG(1) << "CombineTempAllocations()"; - FlatMap<LogicalBuffer::Color, BufferAllocation, LogicalBuffer::Color::Hasher> + flat_hash_map<LogicalBuffer::Color, BufferAllocation, + LogicalBuffer::Color::Hasher> combined_allocation_map; // Move all temp allocations into a single run at the end of the allocations @@ -582,7 +585,8 @@ void BufferAssignment::CombineTempAllocations() { } // Update allocation indices to their new positions. - allocation_index_for_buffer_.clear_no_resize(); + allocation_index_for_buffer_.erase(allocation_index_for_buffer_.begin(), + allocation_index_for_buffer_.end()); for (size_t index = 0; index < allocations_.size(); ++index) { BufferAllocation* allocation = &allocations_[index]; allocation->set_index(index); @@ -812,9 +816,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, Status BufferAssigner::AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const FlatSet<const LogicalBuffer*>& colocated_buffers, - const FlatSet<BufferAllocation::Index>& colocated_allocations, - FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>* + const flat_hash_set<const LogicalBuffer*>& colocated_buffers, + const flat_hash_set<BufferAllocation::Index>& colocated_allocations, + flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>* buffers_to_assign_sequentially, BufferAssignment* assignment) { // Buffers are sorted and assigned to BufferAllocations in decreasing order of @@ -833,7 +837,7 @@ Status BufferAssigner::AssignBuffersForComputation( // Generate a post order sort of instructions for sorting of the // LogicalBuffers. 
- FlatMap<const HloInstruction*, int> post_order_position; + flat_hash_map<const HloInstruction*, int> post_order_position; int position = 0; for (auto* instruction : computation->MakeInstructionPostOrder()) { post_order_position.emplace(instruction, position); @@ -850,8 +854,8 @@ Status BufferAssigner::AssignBuffersForComputation( // buffers_to_assign_sequentially map, even if we end up with an empty set // of buffers. This ensures we can correctly determine whether to run // whole-module heap simulation. - buffers_to_assign_sequentially->emplace(computation, - FlatSet<const LogicalBuffer*>()); + buffers_to_assign_sequentially->emplace( + computation, flat_hash_set<const LogicalBuffer*>()); } // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers @@ -1043,12 +1047,12 @@ Status BufferAssigner::AssignBuffersForComputation( return Status::OK(); } -FlatMap<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>, - LogicalBuffer::Color::Hasher> +flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>, + LogicalBuffer::Color::Hasher> BufferAssigner::SplitBuffersByColor( - const FlatSet<const LogicalBuffer*>& buffers) { - FlatMap<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>, - LogicalBuffer::Color::Hasher> + const flat_hash_set<const LogicalBuffer*>& buffers) { + flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>, + LogicalBuffer::Color::Hasher> color_map; for (auto buffer : buffers) { color_map[buffer->color()].insert(buffer); @@ -1057,23 +1061,38 @@ BufferAssigner::SplitBuffersByColor( } Status BufferAssigner::AssignBuffersWithSequentialOrdering( - const FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>& + const flat_hash_map<const HloComputation*, + flat_hash_set<const LogicalBuffer*>>& buffers_to_assign_sequentially, bool run_whole_module_heap_simulation, BufferAssignment* assignment) { // Run the sequence of instructions through the heap simulator. The heuristic // that seems to give the best results is lazy-best-fit, with all runs of // alloc / free calls sorted in decreasing size order. const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); + + // Returns a heap algorithm that chooses the best result from several + // algorithms. + auto get_heap_algorithm = [&](int64 alignment) { + auto algorithms = + absl::make_unique<std::vector<std::unique_ptr<HeapAlgorithm>>>(); + algorithms->push_back(absl::make_unique<DecreasingSizeRunsHeap>( + absl::make_unique<LazyBestFitHeap>(alignment))); + algorithms->push_back( + absl::make_unique<GlobalDecreasingSizeBestFitHeap>(alignment)); + return absl::make_unique<ChooseBestHeapAlgorithm>(std::move(algorithms)); + }; + if (run_whole_module_heap_simulation) { // Run the heap simulation over the whole module. This reduces memory usage, // since buffers for kCall, kWhile, and kConditional sub-computations are // only live for the duration of their calling instructions. 
VLOG(1) << "Running whole-module heap simulation"; HloSchedule schedule(&assignment->module()); - FlatSet<const LogicalBuffer*> all_buffers_to_assign; + flat_hash_set<const LogicalBuffer*> all_buffers_to_assign; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; - const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second; + const flat_hash_set<const LogicalBuffer*>& buffers_to_assign = + pair.second; const std::vector<const HloInstruction*>* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); @@ -1093,8 +1112,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run(absl::make_unique<DecreasingSizeRunsHeap>( - absl::make_unique<LazyBestFitHeap>(alignment)), + HeapSimulator::Run(get_heap_algorithm(alignment), assignment->module(), schedule, assignment->points_to_analysis(), assignment->buffer_size_, options)); @@ -1108,7 +1126,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( VLOG(1) << "Running per-computation heap simulation"; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; - const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second; + const flat_hash_set<const LogicalBuffer*>& buffers_to_assign = + pair.second; const std::vector<const HloInstruction*>* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); @@ -1123,12 +1142,10 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &buffer_value_set; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, - HeapSimulator::Run( - absl::make_unique<DecreasingSizeRunsHeap>( - absl::make_unique<LazyBestFitHeap>(alignment)), - *computation, HloInstructionSequence(*instruction_sequence), - assignment->points_to_analysis(), assignment->buffer_size_, - options)); + HeapSimulator::Run(get_heap_algorithm(alignment), *computation, + HloInstructionSequence(*instruction_sequence), + assignment->points_to_analysis(), + assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1145,9 +1162,8 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers( const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) { // Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical // buffers in this allocation. - tensorflow::gtl::FlatMap<LogicalBuffer::Id, const LogicalBuffer*> - id_to_buffer; - tensorflow::gtl::FlatMap<const LogicalBuffer*, int64> buffer_sizes; + absl::flat_hash_map<LogicalBuffer::Id, const LogicalBuffer*> id_to_buffer; + absl::flat_hash_map<const LogicalBuffer*, int64> buffer_sizes; for (const auto& pair : allocation.assigned_buffers()) { const LogicalBuffer* buffer = pair.first; const BufferAllocation::OffsetSize& offset_size = pair.second; @@ -1186,7 +1202,7 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers( // Next gather the set of logical buffers live at the earliest point of // maximal live set size. 
- tensorflow::gtl::FlatSet<const LogicalBuffer*> live_buffers; + absl::flat_hash_set<const LogicalBuffer*> live_buffers; live_size = 0; for (const auto& event : heap_trace.events()) { const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id()); @@ -1576,8 +1592,8 @@ void BufferAssigner::BuildColocatedBufferSets( void BufferAssigner::AssignColocatedBufferSets( const std::vector<ColocatedBufferSet>& colocated_buffer_sets, BufferAssignment* assignment, - FlatSet<const LogicalBuffer*>* colocated_buffers, - FlatSet<BufferAllocation::Index>* colocated_allocations) { + flat_hash_set<const LogicalBuffer*>* colocated_buffers, + flat_hash_set<BufferAllocation::Index>* colocated_allocations) { for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) { BufferAllocation* allocation = nullptr; // Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if entry @@ -1650,8 +1666,8 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment( // Once b/32491382 enables module-level liveness analysis, we may be able // to assign colocated buffers (or at least reuse their allocation for // buffers outside of the set) in AssignBuffersForComputation. - FlatSet<const LogicalBuffer*> colocated_buffers; - FlatSet<BufferAllocation::Index> colocated_allocations; + flat_hash_set<const LogicalBuffer*> colocated_buffers; + flat_hash_set<BufferAllocation::Index> colocated_allocations; std::vector<ColocatedBufferSet> colocated_buffer_sets; BuildColocatedBufferSets(module, assignment->liveness(), assignment->buffer_size_, &colocated_buffer_sets); @@ -1669,7 +1685,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment( // First assign buffers for global computatations. Temporary buffers for // sequential computations are collected in 'buffers_to_assign_sequentially'. - FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>> + flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>> buffers_to_assign_sequentially; for (auto* computation : global_computations) { TF_RETURN_IF_ERROR(AssignBuffersForComputation( diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 24ba7c16f5..899cd36e1f 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -22,6 +22,8 @@ limitations under the License. #include <string> #include <vector> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" @@ -33,8 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -148,7 +148,7 @@ class BufferAllocation { // Access to the logical buffers assigned to this allocation, and their // associated logical offsets and sizes. 
- const tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize>& + const absl::flat_hash_map<const LogicalBuffer*, OffsetSize>& assigned_buffers() const { return assigned_buffers_; } @@ -323,7 +323,7 @@ class BufferAllocation { // Mapping from the set of buffers assigned to this allocation to their // logical offsets and sizes. - tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize> assigned_buffers_; + absl::flat_hash_map<const LogicalBuffer*, OffsetSize> assigned_buffers_; int64 fragmentation_bytes_ = 0; std::vector<HeapSimulatorTrace> heap_traces_; @@ -500,7 +500,7 @@ class BufferAssignment { int64 temp_allocation_total_size_ = 0; // Maps Buffers to the index of the BufferAllocation which holds the buffer. - tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferAllocation::Index> + absl::flat_hash_map<const LogicalBuffer*, BufferAllocation::Index> allocation_index_for_buffer_; const HloModule* module_; @@ -554,11 +554,10 @@ class BufferAssigner { // true. Status AssignBuffersForComputation( const HloComputation* computation, bool is_thread_local, - const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers, - const tensorflow::gtl::FlatSet<BufferAllocation::Index>& - colocated_allocations, - tensorflow::gtl::FlatMap<const HloComputation*, - tensorflow::gtl::FlatSet<const LogicalBuffer*>>* + const absl::flat_hash_set<const LogicalBuffer*>& colocated_buffers, + const absl::flat_hash_set<BufferAllocation::Index>& colocated_allocations, + absl::flat_hash_map<const HloComputation*, + absl::flat_hash_set<const LogicalBuffer*>>* buffers_to_assign_sequentially, BufferAssignment* assignment); @@ -568,9 +567,8 @@ class BufferAssigner { // 'run_whole_module_heap_simulation' is true, the heap simulation will be run // assuming all global computations are sequentially ordered. Status AssignBuffersWithSequentialOrdering( - const tensorflow::gtl::FlatMap< - const HloComputation*, - tensorflow::gtl::FlatSet<const LogicalBuffer*>>& + const absl::flat_hash_map<const HloComputation*, + absl::flat_hash_set<const LogicalBuffer*>>& buffers_to_assign_sequentially, bool run_whole_module_heap_simulation, BufferAssignment* assignment); @@ -590,7 +588,7 @@ class BufferAssigner { // alias. Explicitly handling these colocated buffers is necessary because // points-to analysis is computation level scope and does not recognize // aliasing across computations (b/32491382). - using ColocatedBufferSet = tensorflow::gtl::FlatSet<const LogicalBuffer*>; + using ColocatedBufferSet = absl::flat_hash_set<const LogicalBuffer*>; // Returns a vector of ColocatedBufferSet objects, where each // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module' @@ -605,8 +603,8 @@ class BufferAssigner { void AssignColocatedBufferSets( const std::vector<ColocatedBufferSet>& colocated_buffer_sets, BufferAssignment* assignment, - tensorflow::gtl::FlatSet<const LogicalBuffer*>* colocated_buffers, - tensorflow::gtl::FlatSet<BufferAllocation::Index>* colocated_allocations); + absl::flat_hash_set<const LogicalBuffer*>* colocated_buffers, + absl::flat_hash_set<BufferAllocation::Index>* colocated_allocations); // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining // the invariant that all sets in 'colocated_buffer_sets' are disjoint. @@ -624,11 +622,10 @@ class BufferAssigner { // Split a set of buffers into several sets, each of which contains buffers // colored with the same color. 
- tensorflow::gtl::FlatMap<LogicalBuffer::Color, - tensorflow::gtl::FlatSet<const LogicalBuffer*>, - LogicalBuffer::Color::Hasher> - SplitBuffersByColor( - const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers); + absl::flat_hash_map<LogicalBuffer::Color, + absl::flat_hash_set<const LogicalBuffer*>, + LogicalBuffer::Color::Hasher> + SplitBuffersByColor(const absl::flat_hash_set<const LogicalBuffer*>& buffers); // If true, buffer assignments assumes that input parameter buffers and output // buffers can be shared if their sizes match. diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index cdd3cf4032..f939a426ea 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -20,6 +20,7 @@ limitations under the License. #include <string> #include <utility> +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -27,8 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -102,7 +101,7 @@ class BufferLiveness { // Set of LogicalBuffers which are aliased in the output of other // instructions. For example, a LogicalBuffer which is inserted into a tuple // is considered to be aliased and will be in this set. - tensorflow::gtl::FlatSet<const LogicalBuffer*> aliased_buffers_; + absl::flat_hash_set<const LogicalBuffer*> aliased_buffers_; // LogicalBuffers that may be live out of the entry computation. PointsToSet::BufferSet maybe_live_out_buffers_; diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h index 305914fca8..cc46af5eee 100644 --- a/tensorflow/compiler/xla/service/buffer_value_containers.h +++ b/tensorflow/compiler/xla/service/buffer_value_containers.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_ +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/core/lib/gtl/compactptrset.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -38,7 +38,7 @@ BufferValueCompactPointerSet ToBufferValueCompactPointerSet( return output; } -using BufferValueFlatSet = tensorflow::gtl::FlatSet<const BufferValue*>; +using BufferValueFlatSet = absl::flat_hash_set<const BufferValue*>; template <class LogicalBufferContainerT> BufferValueFlatSet ToBufferValueFlatSet( const LogicalBufferContainerT& logical_buffer_container) { diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 23b2a32709..bdd5069632 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include <queue> +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -138,7 +139,7 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) { bool CallGraph::DominatesHelper( const HloComputation* a, const HloComputation* b, - tensorflow::gtl::FlatSet<const HloComputation*>* visited) const { + absl::flat_hash_set<const HloComputation*>* visited) const { if (a == b || ContainsKey(*visited, b)) { // The call graph is guaranteed to be acyclic so any previously visited node // we encounter was already determined to be dominated. @@ -163,7 +164,7 @@ bool CallGraph::DominatesHelper( bool CallGraph::Dominates(const HloComputation* a, const HloComputation* b) const { - tensorflow::gtl::FlatSet<const HloComputation*> visited; + absl::flat_hash_set<const HloComputation*> visited; return DominatesHelper(a, b, &visited); } @@ -277,7 +278,7 @@ std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) { Status CallGraph::VisitNodesInternal( const VisitorFunction& visitor_func, const CallGraphNode& node, - tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const { + absl::flat_hash_set<const CallGraphNode*>* visited) const { auto pair = visited->insert(&node); if (!pair.second) { // Node was not inserted. Node has already been visited. @@ -294,7 +295,7 @@ Status CallGraph::VisitNodesInternal( Status CallGraph::VisitNodes(const VisitorFunction& visitor_func, bool visit_unreachable_nodes) const { - tensorflow::gtl::FlatSet<const CallGraphNode*> visited; + absl::flat_hash_set<const CallGraphNode*> visited; if (visit_unreachable_nodes) { // Traverse from all roots in the call graph. for (const CallGraphNode& node : nodes()) { diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 3af2ab5edf..cb56f4789d 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -20,11 +20,11 @@ limitations under the License. #include <ostream> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -145,19 +145,19 @@ class CallGraphNode { // The computations called by this computation. The vector is used for a // stable ordering and the set enables fast membership testing. std::vector<HloComputation*> callees_; - tensorflow::gtl::FlatSet<HloComputation*> callee_set_; + absl::flat_hash_set<HloComputation*> callee_set_; // The computations which call this computation. The vector is used for a // stable ordering and the set enables fast membership testing. std::vector<HloComputation*> callers_; - tensorflow::gtl::FlatSet<HloComputation*> caller_set_; + absl::flat_hash_set<HloComputation*> caller_set_; // The call sites in this computation std::vector<CallSite> callsites_; // The map from instruction to index in callsites_ for looking up the callsite // (if any) associated with a particular instruction in this computation. - tensorflow::gtl::FlatMap<const HloInstruction*, int64> callsite_instructions_; + absl::flat_hash_map<const HloInstruction*, int64> callsite_instructions_; // The call sites in other computations which call this computation. 
std::vector<CallSite> caller_callsites_; @@ -250,14 +250,14 @@ class CallGraph { // 'visited'. Status VisitNodesInternal( const VisitorFunction& visitor_func, const CallGraphNode& node, - tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const; + absl::flat_hash_set<const CallGraphNode*>* visited) const; // Recursive helper for computing whether 'a' dominates 'b' in the call // graph. 'b_ancestor' is the currently visited node (which starts at 'b'), // and 'visited' is the set of computations which have been visited. bool DominatesHelper( const HloComputation* a, const HloComputation* b, - tensorflow::gtl::FlatSet<const HloComputation*>* visited) const; + absl::flat_hash_set<const HloComputation*>* visited) const; // The HLO module represented by this call graph. const HloModule* module_ = nullptr; @@ -267,7 +267,7 @@ class CallGraph { // Map from HLO computation to the index of the corresponding call graph node // in nodes_. - tensorflow::gtl::FlatMap<const HloComputation*, int64> node_indices_; + absl::flat_hash_map<const HloComputation*, int64> node_indices_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h index c5cd88b9ea..08c4aff4f7 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -25,7 +25,7 @@ namespace xla { // For every kCall operation in the main computation, we inline the body of the // called function, and proceed recursively. -class CallInliner : public HloPassInterface { +class CallInliner : public HloModulePass { public: using InlinedInstructionMap = std::unordered_map<HloInstruction*, HloInstruction*>; diff --git a/tensorflow/compiler/xla/service/conditional_simplifier.h b/tensorflow/compiler/xla/service/conditional_simplifier.h index 3de50cbd7f..2223ad6753 100644 --- a/tensorflow/compiler/xla/service/conditional_simplifier.h +++ b/tensorflow/compiler/xla/service/conditional_simplifier.h @@ -25,7 +25,7 @@ namespace xla { // HLO pass that removes kConditional with a constant predicate, replacing them // with their true or false computation as appropriate. -class ConditionalSimplifier : public HloPassInterface { +class ConditionalSimplifier : public HloModulePass { public: absl::string_view name() const override { return "simplify-conditional"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h index 498894737f..ce0138e56f 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.h +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.h @@ -25,7 +25,7 @@ namespace xla { // A pass which rewrites convolutions with feature_group_count > 1 into // convolutions with feature_group_count = 1. -class ConvolutionFeatureGroupConverter : public HloPassInterface { +class ConvolutionFeatureGroupConverter : public HloModulePass { public: ConvolutionFeatureGroupConverter() {} diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index b65dfef9c9..f35324aa35 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,6 +15,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" @@ -31,8 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -432,7 +432,7 @@ class CopyRemover { // Construct a list for each HLO buffer in the alias analysis. Maintain a // map from HloValue to the respective list element representing that // value. The map is used to construct the copy info map below. - tensorflow::gtl::FlatMap<const HloValue*, ValueNode*> value_to_node; + absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node; for (const HloBuffer& buffer : alias_analysis.buffers()) { // Verify values contained in the buffer are strictly ordered. This // should always be the case after adding copies to eliminate @@ -480,7 +480,7 @@ class CopyRemover { // respective ValueNode representing that value. void AddValueList( absl::Span<const HloValue* const> values, - tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) { + absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) { ValueNode* tail = nullptr; ValueNode* head = nullptr; for (const HloValue* value : values) { @@ -516,8 +516,7 @@ class CopyRemover { // respective ValueNode. void CreateCopyMap( const HloModule& module, - const tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>& - value_to_node) { + const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) { for (HloComputation* computation : module.computations()) { for (HloInstruction* instruction : computation->instructions()) { // Add copies with unambiguous source values to the map. Copies with @@ -905,7 +904,7 @@ class CopyRemover { // The heads of all the value lists. Each value list represents the HLO // values contained in a particular HLO buffer. The values in the list are // in dependency order. - tensorflow::gtl::FlatSet<const ValueNode*> value_lists_; + absl::flat_hash_set<const ValueNode*> value_lists_; // Copy removal requires fast access to the value list elements // corresponding to the source and destination values of the kCopy @@ -916,7 +915,7 @@ class CopyRemover { ValueNode* src = nullptr; ValueNode* dest = nullptr; }; - tensorflow::gtl::FlatMap<const HloInstruction*, CopyNodes> copy_map_; + absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_; }; HloModule* module_; @@ -1010,7 +1009,7 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph, HloInstruction* root = computation->root_instruction(); // Mark nondistinct/ambiguous indices. 
- tensorflow::gtl::FlatSet<const HloBuffer*> seen; + absl::flat_hash_set<const HloBuffer*> seen; ShapeUtil::ForEachSubshape( root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { std::vector<const HloBuffer*> buffers_at_index = diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index d308f6bc84..c097089e30 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -43,7 +43,7 @@ namespace xla { // (3) The buffer set of the root instruction of the entry computation must be // unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and // InstructionAliasSet::IsDistinct return true. -class CopyInsertion : public HloPassInterface { +class CopyInsertion : public HloModulePass { public: absl::string_view name() const override { return "copy-insertion"; } diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 8cc522a59e..58abb330a6 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -50,6 +50,7 @@ cc_library( "//tensorflow/compiler/xla/service/cpu:cpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/stream_executor", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -93,6 +94,7 @@ cc_library( ":target_machine_features", "@com_google_absl//absl/types:span", "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", @@ -126,7 +128,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:indexed_array_analysis", - "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", @@ -180,6 +181,7 @@ cc_library( ":runtime_conv2d_mkl", ":runtime_fft", ":runtime_fork_join", + ":runtime_key_value_sort", ":runtime_matmul", ":runtime_matmul_mkl", ":runtime_single_threaded_conv2d", @@ -288,6 +290,8 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -307,6 +311,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@llvm//:analysis", "@llvm//:target", ], @@ -461,12 +466,16 @@ cc_library( ], copts = runtime_copts(), deps = [ + "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", + "//tensorflow/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", ], ) @@ -624,6 +633,18 @@ cc_library( ) cc_library( + name = 
"runtime_key_value_sort", + srcs = ["runtime_key_value_sort.cc"], + hdrs = ["runtime_key_value_sort.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + +cc_library( name = "runtime_fork_join", srcs = ["runtime_fork_join.cc"], hdrs = ["runtime_fork_join.h"], @@ -745,6 +766,7 @@ cc_library( "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h index 59437e88af..becee3f81f 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h @@ -31,7 +31,7 @@ namespace cpu { // called canonical convolutions). This pass expands non-canonical convolutions // into reshapes and canonical convolutions, so that these non-canonical // convolutions can run faster. -class ConvCanonicalization : public HloPassInterface { +class ConvCanonicalization : public HloModulePass { public: explicit ConvCanonicalization( const TargetMachineFeatures* target_machine_features) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 18fc144efe..68c715a086 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -86,8 +86,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" -#include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/scatter_expander.h" @@ -249,9 +249,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); - // TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding - // where we will take this pass in future. - // pipeline.AddPass<Inliner>(); + pipeline.AddPass<MapInliner>(); // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. @@ -308,7 +306,8 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( ReducePrecisionInsertion::PassTiming::AFTER_FUSION); pipeline.AddPass<CpuLayoutAssignment>( - module->mutable_entry_computation_layout(), target_machine_features); + module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout, target_machine_features); return pipeline.Run(module).status(); } @@ -328,8 +327,13 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn( { auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>( "simplification after layout assignement"); - pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + // TODO(b/117156505): When the bug is fixed, the CPU backend should not + // produce layout changing elementwise operations. 
We will then pass + // LayoutAssignment::InstructionCanChangeLayout to the HLO verifier to + // enable stricter verification. + pass.AddInvariantChecker<HloVerifier>( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); pass.AddPass<HloPassFix<AlgebraicSimplifier>>( /*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return true; }, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h index d49f7d7cc2..076235f887 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h @@ -30,7 +30,7 @@ namespace xla { // // TODO(b/62548313): Remove this when buffer assignment is smarter // (module-scoped). -class CpuCopyInsertion : public HloPassInterface { +class CpuCopyInsertion : public HloModulePass { public: absl::string_view name() const override { return "copy-insertion"; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h index 6af724b2a5..a39a9d4724 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -23,7 +23,7 @@ namespace xla { // This pass should run early in the HLO pipeline and checks for HLO constructs // which are not supported by the CPU backend and cannot be removed via HLO // transformations (eg, sparse layouts). -class CpuHloSupportChecker : public HloPassInterface { +class CpuHloSupportChecker : public HloModulePass { public: CpuHloSupportChecker() = default; ~CpuHloSupportChecker() override = default; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index bfecbd6e01..c291bf2d1b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include <numeric> +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" @@ -38,7 +39,7 @@ using absl::nullopt; using absl::optional; using ShouldMakeOperandColMajorCache = - tensorflow::gtl::FlatMap<const HloInstruction*, bool>; + absl::flat_hash_map<const HloInstruction*, bool>; } // namespace static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index 3c4fe68b83..f4da35dd37 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -30,8 +30,11 @@ class CpuLayoutAssignment : public LayoutAssignment { public: explicit CpuLayoutAssignment( ComputationLayout* entry_computation_layout, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func, const TargetMachineFeatures* target_machine_features) - : LayoutAssignment(entry_computation_layout), + : LayoutAssignment(entry_computation_layout, + std::move(instruction_can_change_layout_func)), target_machine_features_(*target_machine_features) {} ~CpuLayoutAssignment() override {} diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 4668f3872d..97659b88a7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -54,8 +54,9 @@ class CpuLayoutAssignmentTest : public HloTestBase { [](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); - cpu::CpuLayoutAssignment layout_assignment(entry_computation_layout, - &target_machine_features); + cpu::CpuLayoutAssignment layout_assignment( + entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout, + &target_machine_features); EXPECT_IS_OK(layout_assignment.Run(module).status()); } }; @@ -321,8 +322,9 @@ static StatusOr<DotOutputFusionLayoutAssignmentResult> RunDotOutputFusion( [](int64 shape_size) { return cpu::TargetMachineFeatures::kEigenExpectedTensorAlignment; }); - cpu::CpuLayoutAssignment layout_assignment(&computation_layout, - &target_machine_features); + cpu::CpuLayoutAssignment layout_assignment( + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + &target_machine_features); TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, layout_assignment.Run(module)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 8a44c384bb..a9febe891b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -17,19 +17,29 @@ limitations under the License. 
#include <functional> +#include "absl/container/flat_hash_map.h" +#include "absl/synchronization/mutex.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/stream_executor.h" namespace xla { namespace cpu { namespace runtime { -XfeedManager* GetXfeedManager() { - static XfeedManager* manager = new XfeedManager; - return manager; +XfeedManager* GetXfeedManager(int device_ordinal) { + static auto* managers = new absl::flat_hash_map<int, XfeedManager*>(); + static absl::Mutex* mutex = new absl::Mutex(); + + absl::MutexLock lock(mutex); + auto it = managers->find(device_ordinal); + if (it == managers->end()) { + it = managers->emplace(device_ordinal, new XfeedManager()).first; + } + return it->second; } extern const char* const kEigenMatMulF16SymbolName = @@ -74,6 +84,30 @@ extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; extern const char* const kParallelForkJoinSymbolName = "__xla_cpu_runtime_ParallelForkJoin"; +extern const char* const kKeyValueSortPREDSymbolName = + "__xla_cpu_runtime_KeyValueSortPRED"; +extern const char* const kKeyValueSortS8SymbolName = + "__xla_cpu_runtime_KeyValueSortS8"; +extern const char* const kKeyValueSortU8SymbolName = + "__xla_cpu_runtime_KeyValueSortU8"; +extern const char* const kKeyValueSortS16SymbolName = + "__xla_cpu_runtime_KeyValueSortS16"; +extern const char* const kKeyValueSortU16SymbolName = + "__xla_cpu_runtime_KeyValueSortU16"; +extern const char* const kKeyValueSortF16SymbolName = + "__xla_cpu_runtime_KeyValueSortF16"; +extern const char* const kKeyValueSortS32SymbolName = + "__xla_cpu_runtime_KeyValueSortS32"; +extern const char* const kKeyValueSortU32SymbolName = + "__xla_cpu_runtime_KeyValueSortU32"; +extern const char* const kKeyValueSortF32SymbolName = + "__xla_cpu_runtime_KeyValueSortF32"; +extern const char* const kKeyValueSortS64SymbolName = + "__xla_cpu_runtime_KeyValueSortS64"; +extern const char* const kKeyValueSortU64SymbolName = + "__xla_cpu_runtime_KeyValueSortU64"; +extern const char* const kKeyValueSortF64SymbolName = + "__xla_cpu_runtime_KeyValueSortF64"; extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime @@ -94,14 +128,18 @@ tensorflow::string ShapeString(const void* shape_ptr, xla::int32 shape_length) { } // namespace TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* -__xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, - const void* shape, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "AcquireInfeedBufferForDequeue: " - << ShapeString(shape, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_AcquireInfeedBufferForDequeue( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "AcquireInfeedBufferForDequeue: " + << ShapeString(shape, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); // Wait until there's a buffer to dequeue. 
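// Usage sketch for the per-device xfeed lookup above: repeated calls with the
// same device ordinal return the same XfeedManager, while different ordinals
// get independent managers. XfeedManagerLookupExample is hypothetical and
// for illustration only; it is not part of this change.
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/core/platform/logging.h"

void XfeedManagerLookupExample() {
  xla::cpu::runtime::XfeedManager* m0 = xla::cpu::runtime::GetXfeedManager(0);
  xla::cpu::runtime::XfeedManager* m0_again =
      xla::cpu::runtime::GetXfeedManager(0);
  xla::cpu::runtime::XfeedManager* m1 = xla::cpu::runtime::GetXfeedManager(1);
  CHECK_EQ(m0, m0_again);  // same manager for the same ordinal
  CHECK_NE(m0, m1);        // distinct manager per ordinal
}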
xla::cpu::runtime::XfeedBuffer* buffer = xfeed->infeed()->BlockingDequeueBuffer(); @@ -114,15 +152,18 @@ __xla_cpu_runtime_AcquireInfeedBufferForDequeue(xla::int32 buffer_length, } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, - void* buffer_ptr, - const void* shape_ptr, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "ReleaseInfeedBufferAfterDeque: " - << ShapeString(shape_ptr, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "ReleaseInfeedBufferAfterDeque: " + << ShapeString(shape_ptr, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); xla::StatusOr<xla::Shape> shape = xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); xfeed->infeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr, @@ -130,14 +171,18 @@ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length, } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void* -__xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length, - const void* shape_ptr, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "AcquireOutfeedBufferForPopulation: " - << ShapeString(shape_ptr, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_AcquireOutfeedBufferForPopulation( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape_ptr, xla::int32 shape_length) { + int device_ordinal = + run_options ? run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "AcquireOutfeedBufferForPopulation: " + << ShapeString(shape_ptr, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); // Wait until there's a buffer to dequeue. xla::cpu::runtime::XfeedBuffer* buffer = xfeed->outfeed()->BlockingDequeueBuffer(); @@ -150,15 +195,18 @@ __xla_cpu_runtime_AcquireOutfeedBufferForPopulation(xla::int32 buffer_length, } TF_ATTRIBUTE_NO_SANITIZE_MEMORY void -__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation(xla::int32 buffer_length, - void* buffer_ptr, - const void* shape_ptr, - xla::int32 shape_length) { - if (VLOG_IS_ON(2)) { - LOG(INFO) << "ReleaseOutfeedBufferAfterPopulation: " - << ShapeString(shape_ptr, shape_length); - } - xla::cpu::runtime::XfeedManager* xfeed = xla::cpu::runtime::GetXfeedManager(); +__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length) { + int device_ordinal = + run_options ? 
run_options->stream()->parent()->device_ordinal() : 0; + + VLOG(2) << "ReleaseOutfeedBufferAfterPopulation: " + << ShapeString(shape_ptr, shape_length) << " on stream executor " + << device_ordinal; + + xla::cpu::runtime::XfeedManager* xfeed = + xla::cpu::runtime::GetXfeedManager(device_ordinal); xla::StatusOr<xla::Shape> shape = xla::llvm_ir::DecodeSelfDescribingShapeConstant(shape_ptr, shape_length); xfeed->outfeed()->ReleaseCurrentBuffer(buffer_length, buffer_ptr, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index aa0e967123..b2e760a224 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -26,6 +26,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_ +#include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h" #include "tensorflow/compiler/xla/types.h" @@ -63,13 +64,26 @@ extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; extern const char* const kParallelForkJoinSymbolName; +extern const char* const kKeyValueSortPREDSymbolName; +extern const char* const kKeyValueSortS8SymbolName; +extern const char* const kKeyValueSortU8SymbolName; +extern const char* const kKeyValueSortS16SymbolName; +extern const char* const kKeyValueSortU16SymbolName; +extern const char* const kKeyValueSortF16SymbolName; +extern const char* const kKeyValueSortS32SymbolName; +extern const char* const kKeyValueSortU32SymbolName; +extern const char* const kKeyValueSortF32SymbolName; +extern const char* const kKeyValueSortS64SymbolName; +extern const char* const kKeyValueSortU64SymbolName; +extern const char* const kKeyValueSortF64SymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. extern const char* const kXlaCpuRuntimeSymbolNamePrefix; -// Returns the infeed manager used by the CPU runtime. -XfeedManager* GetXfeedManager(); +// Returns the infeed manager used by the CPU runtime for the CPU device +// `device_ordinal`. Note the device ordinal does not name a CPU +XfeedManager* GetXfeedManager(int device_ordinal); } // namespace runtime } // namespace cpu @@ -77,6 +91,18 @@ XfeedManager* GetXfeedManager(); extern "C" { +// Some things common to all of the runtime entry points below: +// +// * The shape pointer and shape_length reflect values that can be deserialized +// via llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass +// reified type information from the generated program to the runtime, which +// helps check the type safety and contract for the emitted-code/runtime +// communication. +// +// * run_options is used to look up the device ordinal for the stream executor +// we're executing under. If it is null the device ordinal is assumed to be +// 0 (this behavior helps in writing tests). + // Note: in the runtime entry points below, the shape pointer and shape_length // reflect values that can be deserialized via // llvm_ir::DecodeSelfDescribingShapeConstant. This is the way we pass reified @@ -89,7 +115,8 @@ extern "C" { // the length would be more exact, but the length check is chosen as a // tradeoff between error checking and speed/simplicity. 
extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( - xla::int32 buffer_length, const void* shape, xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape, xla::int32 shape_length); // Relinquishes the next infeed buffer that was returned by // __xla_cpu_runtime_AcquireInfeedBufferForDequeue. Once this call @@ -104,13 +131,14 @@ extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue( // implemented we will add support for multiple outstanding buffers // that can be returned out of order. extern void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length); // Blocks until the next outfeed buffer is available to be populated, then // returns it. extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( - xla::int32 buffer_length, const void* shape_ptr, xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + const void* shape_ptr, xla::int32 shape_length); // Relinquishes the outfeed buffer after it has been populated. // buffer_ptr must have been previously returned by @@ -122,8 +150,8 @@ extern void* __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( // acquired, i.e., there may only be one outstanding outfeed buffer in // use by the runtime. extern void __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( - xla::int32 buffer_length, void* buffer_ptr, const void* shape_ptr, - xla::int32 shape_length); + const xla::ExecutableRunOptions* run_options, xla::int32 buffer_length, + void* buffer_ptr, const void* shape_ptr, xla::int32 shape_length); } // extern "C" diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index 5519a43b2f..1cc2844470 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -35,6 +35,7 @@ limitations under the License. 
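// Illustrative sketch of the acquire/copy/release contract that generated
// code follows for infeed after run_options became the leading argument of
// the entry points declared above. A null run_options means device ordinal 0,
// as documented above (useful in tests). InfeedDequeueSketch is hypothetical
// and not part of this change.
#include <cstring>

#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"

void InfeedDequeueSketch(const void* shape_ptr, xla::int32 shape_length,
                         void* program_buffer, xla::int32 buffer_length) {
  void* acquired = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
      /*run_options=*/nullptr, buffer_length, shape_ptr, shape_length);
  // The runtime check-fails on a length mismatch, so the acquired buffer holds
  // exactly buffer_length bytes; copy it into the program's buffer, then hand
  // the runtime buffer back.
  std::memcpy(program_buffer, acquired, buffer_length);
  __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
      /*run_options=*/nullptr, buffer_length, acquired, shape_ptr,
      shape_length);
}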
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/stream_executor/stream_executor.h" namespace xla { @@ -128,7 +129,8 @@ Status CpuTransferManager::TransferLiteralToInfeed( buffers.push_back(buffer); } - cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed_manager = + cpu::runtime::GetXfeedManager(executor->device_ordinal()); xfeed_manager->infeed()->EnqueueBuffersAtomically(buffers); cleanup.release(); @@ -141,7 +143,8 @@ Status CpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer, TransferBufferToInfeedInternal(executor, size, source)); - cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed_manager = + cpu::runtime::GetXfeedManager(executor->device_ordinal()); xfeed_manager->infeed()->EnqueueBuffersAtomically({buffer}); return Status::OK(); @@ -265,7 +268,8 @@ StatusOr<Shape> CpuTransferManager::TransferBuffersFromOutfeedInternal( buffer_pointers.push_back(b.get()); } - cpu::runtime::XfeedManager* xfeed_manager = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed_manager = + cpu::runtime::GetXfeedManager(executor->device_ordinal()); xfeed_manager->outfeed()->EnqueueBuffersAtomically(buffer_pointers); VLOG(2) << "Waiting for buffer to be notified as populated."; std::vector<Shape> outfed_shapes; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index df8c2a636b..b2abdb39a5 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,6 +24,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" @@ -67,8 +69,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -404,13 +404,12 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value * shape_ptr, llvm_ir::EncodeSelfDescribingShapeConstant(shape, &shape_length, &b_)); - // The signature of the acquire infeed buffer function is: - // - // (void*)(int32 length); llvm::Type* int32_type = b_.getInt32Ty(); llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); llvm::FunctionType* acquire_type = llvm::FunctionType::get( - i8_ptr_type, {int32_type, i8_ptr_type, int32_type}, + i8_ptr_type, + {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type, + /*shape_ptr*/ i8_ptr_type, /*shape_length*/ int32_type}, /*isVarArg=*/false); llvm::Function* acquire_func; @@ -423,11 +422,11 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, } acquire_func->setCallingConv(llvm::CallingConv::C); - // The signature of the release infeed buffer function is: - // - // (void)(int32 length, void* buffer); llvm::FunctionType* release_type = llvm::FunctionType::get( - b_.getVoidTy(), {int32_type, i8_ptr_type, i8_ptr_type, int32_type}, + b_.getVoidTy(), + {/*run_options*/ i8_ptr_type, /*buffer_length*/ int32_type, + /*buffer_ptr*/ i8_ptr_type, /*shape_ptr*/ i8_ptr_type, + /*shape_length*/ int32_type}, /*isVarArg=*/false); llvm::Function* release_func; @@ -444,9 +443,9 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, // of size exactly 'length_32', and the runtime is responsible for // check-failing the process if there is a mismatch, versus passing us back a // buffer that we might overrun. - llvm::Value* acquired_pointer = - Call(acquire_func, - {b_.getInt32(length_32), shape_ptr, b_.getInt32(shape_length)}); + llvm::Value* acquired_pointer = Call( + acquire_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32), + shape_ptr, b_.getInt32(shape_length)}); if (kind == XfeedKind::kInfeed) { // Copy to the program buffer address from the acquired buffer. @@ -458,8 +457,8 @@ Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape, /*SrcAlign=*/1, length_32); } - Call(release_func, {b_.getInt32(length_32), acquired_pointer, shape_ptr, - b_.getInt32(shape_length)}); + Call(release_func, {GetExecutableRunOptionsArgument(), b_.getInt32(length_32), + acquired_pointer, shape_ptr, b_.getInt32(shape_length)}); return Status::OK(); } @@ -495,8 +494,150 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { } Status IrEmitter::HandleSort(HloInstruction* sort) { - // TODO(b/26783907): Implement sort on CPU. - return Unimplemented("Sort is not implemented on CPU."); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(sort)); + auto keys = sort->operand(0); + auto values = sort->operand_count() > 1 ? 
sort->operand(1) : nullptr; + ShapeIndex keys_shape_index({}); + ShapeIndex values_shape_index({}); + if (values != nullptr) { + keys_shape_index = ShapeIndex({0}); + values_shape_index = ShapeIndex({1}); + } + auto keys_destination = GetAllocationSlice(*sort, keys_shape_index); + auto keys_destination_address = + EmitBufferPointer(keys_destination, keys->shape()); + auto values_destination = GetAllocationSlice(*sort, values_shape_index); + llvm::Value* values_destination_address = nullptr; + + // The sort is implemented in-place, therefore we first copy the operand + // buffer to the output buffer if they are not the same. + if (keys_destination != GetAllocationSlice(*keys)) { + int64 primitive_type_size = + ShapeUtil::ByteSizeOfPrimitiveType(keys->shape().element_type()); + auto source_buffer = GetEmittedValueFor(keys); + int64 keys_size = ByteSizeOf(keys->shape()); + MemCpy(keys_destination_address, /*DstAlign=*/primitive_type_size, + source_buffer, + /*SrcAlign=*/primitive_type_size, keys_size); + } + if (values != nullptr) { + values_destination_address = + EmitBufferPointer(values_destination, values->shape()); + if (values_destination != GetAllocationSlice(*values)) { + int64 primitive_type_size = + ShapeUtil::ByteSizeOfPrimitiveType(values->shape().element_type()); + auto source_buffer = GetEmittedValueFor(values); + int64 values_size = ByteSizeOf(values->shape()); + MemCpy(values_destination_address, /*DstAlign=*/primitive_type_size, + source_buffer, + /*SrcAlign=*/primitive_type_size, values_size); + } + } + + // Normalize the shape and the dimension to sort. + Shape normalized_keys_shape = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + keys->shape()); + int64 physical_dimension_to_sort = LayoutUtil::MakeLogicalToPhysical( + keys->shape().layout())[sort->dimensions(0)]; + + int64 sort_dimension_elements = + normalized_keys_shape.dimensions(physical_dimension_to_sort); + int64 higher_dimensions = 1; + for (int64 i = 0; i < physical_dimension_to_sort; ++i) { + higher_dimensions *= normalized_keys_shape.dimensions(i); + } + int64 lower_dimensions = 1; + for (int64 i = ShapeUtil::Rank(normalized_keys_shape) - 1; + i > physical_dimension_to_sort; --i) { + lower_dimensions *= normalized_keys_shape.dimensions(i); + } + + PrimitiveType keys_type = keys->shape().element_type(); + const char* fn_name = nullptr; + llvm::Type* keys_native_type = nullptr; + switch (keys_type) { + case PRED: + fn_name = runtime::kKeyValueSortPREDSymbolName; + keys_native_type = b_.getInt8PtrTy(); + break; + case S8: + fn_name = runtime::kKeyValueSortS8SymbolName; + keys_native_type = b_.getInt8PtrTy(); + break; + case U8: + fn_name = runtime::kKeyValueSortU8SymbolName; + keys_native_type = b_.getInt8PtrTy(); + break; + case S16: + fn_name = runtime::kKeyValueSortS16SymbolName; + keys_native_type = b_.getInt16Ty()->getPointerTo(); + break; + case U16: + fn_name = runtime::kKeyValueSortU16SymbolName; + keys_native_type = b_.getInt16Ty()->getPointerTo(); + break; + case F16: + fn_name = runtime::kKeyValueSortF16SymbolName; + keys_native_type = b_.getHalfTy()->getPointerTo(); + break; + case S32: + fn_name = runtime::kKeyValueSortS32SymbolName; + keys_native_type = b_.getInt32Ty()->getPointerTo(); + break; + case U32: + fn_name = runtime::kKeyValueSortU32SymbolName; + keys_native_type = b_.getInt32Ty()->getPointerTo(); + break; + case F32: + fn_name = runtime::kKeyValueSortF32SymbolName; + keys_native_type = b_.getFloatTy()->getPointerTo(); + break; + case S64: + fn_name = 
runtime::kKeyValueSortS64SymbolName; + keys_native_type = b_.getInt64Ty()->getPointerTo(); + break; + case U64: + fn_name = runtime::kKeyValueSortU64SymbolName; + keys_native_type = b_.getInt64Ty()->getPointerTo(); + break; + case F64: + fn_name = runtime::kKeyValueSortF64SymbolName; + keys_native_type = b_.getDoubleTy()->getPointerTo(); + break; + default: + return Unimplemented( + "Element type %s not supported in the Sort op on CPU.", + PrimitiveType_Name(keys_type)); + } + + llvm::FunctionType* key_value_sort_type = llvm::FunctionType::get( + b_.getVoidTy(), + {keys_native_type, b_.getInt64Ty(), b_.getInt64Ty(), b_.getInt64Ty(), + b_.getInt8PtrTy(), b_.getInt32Ty()}, + /*isVarArg=*/false); + auto* key_value_sort_func = llvm::cast<llvm::Function>( + module_->getOrInsertFunction(fn_name, key_value_sort_type)); + key_value_sort_func->setCallingConv(llvm::CallingConv::C); + key_value_sort_func->setDoesNotThrow(); + key_value_sort_func->setOnlyAccessesArgMemory(); + Call(key_value_sort_func, + {PointerCast(keys_destination_address, keys_native_type), + b_.getInt64(higher_dimensions), b_.getInt64(sort_dimension_elements), + b_.getInt64(lower_dimensions), + values != nullptr + ? PointerCast(values_destination_address, b_.getInt8PtrTy()) + : llvm::Constant::getNullValue(b_.getInt8PtrTy()), + b_.getInt32(values != nullptr ? ShapeUtil::ByteSizeOfPrimitiveType( + values->shape().element_type()) + : 0)}); + + if (values != nullptr) { + llvm_ir::EmitTuple(GetIrArrayFor(sort), + {keys_destination_address, values_destination_address}, + &b_, module_); + } + return Status::OK(); } Status IrEmitter::HandleTuple(HloInstruction* tuple) { @@ -547,8 +688,25 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow( for (size_t i = 0; i < index.size(); ++i) { llvm::Value* strided_index = NSWMul(index[i], b_.getInt64(window.dimensions(i).stride())); - input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]), - b_.getInt64(window.dimensions(i).padding_low())); + input_index[i] = NSWSub( + NSWAdd(strided_index, + NSWMul(window_index[i], + b_.getInt64(window.dimensions(i).window_dilation()))), + b_.getInt64(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. + llvm::Value* dilation_condition = ICmpEQ( + SRem(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())), + b_.getInt64(0)); + if (in_bounds_condition == nullptr) { + in_bounds_condition = dilation_condition; + } else { + in_bounds_condition = And(in_bounds_condition, dilation_condition); + } + + // Apply base dilation to the index. + input_index[i] = + SDiv(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())); // We need to check if 0 <= input_index[i] < bound, as otherwise we are in // the padding so that we can skip the computation. That is equivalent to @@ -587,12 +745,6 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { /*operands=*/{reduce_window->operand(0)}, /*supported_types=*/{F32, BF16, S32, F16})); - // TODO(b/31410564): Implement dilation for reduce-window. - if (window_util::HasDilation(reduce_window->window())) { - return Unimplemented( - "Dilation for ReduceWindow is not implemented on CPU."); - } - // Pseudo code for reduce window: // // for (coordinates O in the output) @@ -1257,10 +1409,10 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) { // // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains // [0->0, 3->1]. 
- gtl::FlatMap<int64, int64> unreduced_dim_map; + absl::flat_hash_map<int64, int64> unreduced_dim_map; - gtl::FlatSet<int64> reduced_dims(reduce.dimensions().begin(), - reduce.dimensions().end()); + absl::flat_hash_set<int64> reduced_dims(reduce.dimensions().begin(), + reduce.dimensions().end()); const Shape& operand_shape = reduce.operand(0)->shape(); const Shape& result_shape = reduce.shape(); @@ -1836,7 +1988,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // // * Implement the memcpy within the innermost loop. - gtl::FlatSet<int64> inner_dims; + absl::flat_hash_set<int64> inner_dims; for (int64 dim : LayoutUtil::MinorToMajor(layout)) { if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) { break; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 3df99464ba..586f27b104 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -23,6 +23,7 @@ limitations under the License. #include <unordered_map> #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/Triple.h" @@ -47,7 +48,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -163,6 +163,12 @@ class IrEmitter : public DfsHloVisitorWithDefault, Status Preprocess(HloInstruction* hlo) override; Status Postprocess(HloInstruction* hlo) override; + // A convenient helper for calling BufferAssignment::GetUniqueSlice. + BufferAllocation::Slice GetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index = {}) const { + return assignment_.GetUniqueSlice(&hlo, index).ConsumeValueOrDie(); + } + private: // Private helper to initialize an IR function for the computation. void InitializeIrFunction(const string& function_name); @@ -421,7 +427,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, // Maps the buffer allocation slices for the parameters to the computation // being compiled to their parameter numbers. Only relevant for thread local // computations. - tensorflow::gtl::FlatMap<BufferAllocation::Index, int64> + absl::flat_hash_map<BufferAllocation::Index, int64> computation_parameter_allocations_; // Maps HLO instructions to their index into the profile counter array. 
@@ -561,11 +567,11 @@ class IrEmitter : public DfsHloVisitorWithDefault, } }; - tensorflow::gtl::FlatMap<const Literal*, llvm::Constant*, - LiteralPtrHashFunctor, LiteralPtrEqualityFunctor> + absl::flat_hash_map<const Literal*, llvm::Constant*, LiteralPtrHashFunctor, + LiteralPtrEqualityFunctor> emitted_literals_; - tensorflow::gtl::FlatMap<BufferAllocation::Index, llvm::Constant*> + absl::flat_hash_map<BufferAllocation::Index, llvm::Constant*> constant_buffer_to_global_; std::vector<const HloComputation*> thread_local_computations_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index b4c0c09ec0..ede7f433ca 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -142,6 +142,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast || opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng || + opcode == HloOpcode::kSort || (opcode == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction, target_machine_features_)) || diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index a99cd99c14..3822d5300e 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -60,7 +60,7 @@ class ParallelTaskAssignment { // own embedded computation, which is compiled as a parallel compute function, // and which is invoked from a kCall instruction that is lowered in codegen to // a runtime parallel fork/join call. -class ParallelTaskAssigner : public HloPassInterface { +class ParallelTaskAssigner : public HloModulePass { public: // 'max_parallelism': the maximum parallel task count per instruction. // 'shape_size': shape size function used by HloCostAnalysis during parallel diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc new file mode 100644 index 0000000000..e0e7deb98e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.cc @@ -0,0 +1,236 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" + +#include <algorithm> +#include <cmath> +#include <cstring> +#include <memory> +#include <string> +#include <utility> + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace { +using tensorflow::int16; +using tensorflow::int32; +using tensorflow::int64; +using tensorflow::int8; +using tensorflow::uint16; +using tensorflow::uint32; +using tensorflow::uint64; +using tensorflow::uint8; + +template <typename KeyType> +void KeyValueSort(std::pair<KeyType, int64>* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements); +} + +// For floating point numbers, we want a total order comparator. -NaN and NaN +// should appear at the beginning and end of the ordering, and -0.0 should +// appear before 0.0. Also we want to have a stable sort, so if the keys are the +// same, we compare the index values. +template <typename KeyType> +bool LessThan(KeyType lhs, int64 lhs_index, KeyType rhs, int64 rhs_index) { + bool lhs_is_negative = std::signbit(lhs); + bool rhs_is_negative = std::signbit(rhs); + // If the signs are different, we can just compare the signs. + if (lhs_is_negative != rhs_is_negative) { + return lhs_is_negative && !rhs_is_negative; + } + bool lhs_nan = std::isnan(lhs); + bool rhs_nan = std::isnan(rhs); + // Exactly one number is nan? + if (lhs_nan != rhs_nan) { + if (lhs_nan) { + return lhs_is_negative; + } + return !rhs_is_negative; + } + if (lhs != rhs) { + return lhs < rhs; + } + return lhs_index < rhs_index; +} + +template <> +void KeyValueSort(std::pair<double, int64>* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair<double, int64>& lhs, + const std::pair<double, int64>& rhs) -> bool { + return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); + }); +} + +template <> +void KeyValueSort(std::pair<float, int64>* row_to_sort, int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair<float, int64>& lhs, + const std::pair<float, int64>& rhs) -> bool { + return LessThan(lhs.first, lhs.second, rhs.first, rhs.second); + }); +} + +template <> +void KeyValueSort(std::pair<Eigen::half, int64>* row_to_sort, + int64 num_elements) { + std::sort(row_to_sort, row_to_sort + num_elements, + [](const std::pair<Eigen::half, int64>& lhs, + const std::pair<Eigen::half, int64>& rhs) -> bool { + return LessThan( + Eigen::half_impl::half_to_float(lhs.first), lhs.second, + Eigen::half_impl::half_to_float(rhs.first), rhs.second); + }); +} + +template <typename KeyType> +void KeyValueSortImpl(KeyType* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + // High-level idea of the iteration/sorting logic: + // Conceptually we have a 3-dimensional shape [a, b, c]. b corresponds to the + // dimension to sort, c is the product of the more minor dimensions (set to 1 + // if b is the most minor dimension), and a is the product of the more major + // dimensions (set to 1 if b is the most major dimension). There are a * c + // many rows that we need to sort. 
We iterate through these, calculate a + // 'base_offset' value which points to the first element in that row, and add + // i * c for accessing the 'i'-th element in that row. + + int64 sort_dimension_elements = b; + int64 num_iteration_elements = a * c; + int64 sort_dimension_offset = c; + + std::unique_ptr<std::pair<KeyType, int64>[]> row_to_sort( + new std::pair<KeyType, int64>[sort_dimension_elements]); + std::unique_ptr<std::string[]> reordered_values( + new std::string[sort_dimension_elements]); + for (int64 index = 0; index < num_iteration_elements; ++index) { + // 'index' can be split into two values which index into the 'c' dimension + // and the 'a' dimension, respectively. 'index' % 'c' is the index into the + // 'c' dimension, 'index' / 'c' is the index into the 'a' dimension. When + // calculating the base offset, we need to multiply the index into the 'a' + // dimension with 'b' * 'c'. + // 'index' / 'c' * 'c' * 'b' = ('index' - 'index' % 'c') * 'b'. + int64 base_offset = + index % sort_dimension_offset + + (index - index % sort_dimension_offset) * sort_dimension_elements; + // TODO(b/26783907): We could define a custom iterator class that references + // both arrays. Then we could avoid the intermediate copy. However this + // would become more complicated, and it is not clear if the benefit is high + // enough. + for (int64 i = 0; i < sort_dimension_elements; ++i) { + row_to_sort[i] = + std::make_pair(keys[base_offset + i * sort_dimension_offset], i); + } + KeyValueSort(row_to_sort.get(), sort_dimension_elements); + for (int64 i = 0; i < sort_dimension_elements; ++i) { + keys[base_offset + i * sort_dimension_offset] = row_to_sort[i].first; + } + if (values == nullptr) { + continue; + } + + // Reorder the values according to the order defined by the keys. 
+ for (int64 i = 0; i < sort_dimension_elements; ++i) { + int64 memory_index = + (base_offset + row_to_sort[i].second * sort_dimension_offset) * + values_primitive_type_size_in_bytes; + + reordered_values[i] = std::string(values + memory_index, + values_primitive_type_size_in_bytes); + } + for (int64 i = 0; i < sort_dimension_elements; ++i) { + int64 memory_index = (base_offset + i * sort_dimension_offset) * + values_primitive_type_size_in_bytes; + memcpy(values + memory_index, reordered_values[i].c_str(), + values_primitive_type_size_in_bytes); + } + } +} +} // namespace + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortPRED( + bool* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS8( + int8* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU8( + uint8* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS16( + int16* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU16( + uint16* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF16( + Eigen::half* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS32( + int32* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU32( + uint32* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF32( + float* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortS64( + int64* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortU64( + uint64* keys, int64 a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_KeyValueSortF64( + double* keys, int64 
a, int64 b, int64 c, char* values, + int32 values_primitive_type_size_in_bytes) { + KeyValueSortImpl(keys, a, b, c, values, values_primitive_type_size_in_bytes); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h new file mode 100644 index 0000000000..28e35e82c1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/platform/types.h" + +extern "C" { + +// 'keys' represents a 3-dimensional shape with dimensions [a, b, c]. The 'b' +// dimension of 'keys' is sorted into ascending order. 'values' can be nullptr. +// If 'values' is not nullptr, the elements in 'values' are reordered in such a +// way that if the element at index 'i' in 'keys' was moved to index 'j', the +// element at index 'i' in 'values' is also moved to index 'j' (which means that +// the same elements correspond to each other as before). 
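As a usage illustration (not part of the patch, data made up): sorting a row-major f32[2,4] array along its last dimension maps to a = 2, b = 4, c = 1, and the payload array is permuted the same way as its keys.

#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"

void KeyValueSortExample() {
  // Row-major f32[2][4], sorted along the last dimension: a = 2 rows,
  // b = 4 elements per sorted row, c = 1 (no dimensions minor to the sort
  // dimension). Row i then starts at flat offset i * b * c, e.g. row 1 at 4.
  float keys[2][4] = {{3.f, 1.f, 4.f, 1.5f}, {0.f, -0.f, 2.f, -2.f}};
  tensorflow::int32 payload[2][4] = {{0, 1, 2, 3}, {0, 1, 2, 3}};
  __xla_cpu_runtime_KeyValueSortF32(
      &keys[0][0], /*a=*/2, /*b=*/4, /*c=*/1,
      reinterpret_cast<char*>(&payload[0][0]),
      /*values_primitive_type_size_in_bytes=*/sizeof(tensorflow::int32));
  // keys is now {{1, 1.5, 3, 4}, {-2, -0, 0, 2}}: ascending per row, with the
  // total order placing -0.0 before 0.0.
  // payload is now {{1, 3, 0, 2}, {3, 1, 0, 2}}: each entry records where the
  // key now at that position originally lived.
}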
+extern void __xla_cpu_runtime_KeyValueSortPRED( + bool* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, + char* values, tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS8( + tensorflow::int8* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU8( + tensorflow::uint8* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS16( + tensorflow::int16* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU16( + tensorflow::uint16* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortF16( + Eigen::half* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS32( + tensorflow::int32* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU32( + tensorflow::uint32* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortF32( + float* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, + char* values, tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortS64( + tensorflow::int64* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortU64( + tensorflow::uint64* keys, tensorflow::int64 a, tensorflow::int64 b, + tensorflow::int64 c, char* values, + tensorflow::int32 values_primitive_type_size_in_bytes); + +extern void __xla_cpu_runtime_KeyValueSortF64( + double* keys, tensorflow::int64 a, tensorflow::int64 b, tensorflow::int64 c, + char* values, tensorflow::int32 values_primitive_type_size_in_bytes); +} + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_KEY_VALUE_SORT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index bf98064647..9ec0c8f657 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -35,6 +35,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" @@ -202,6 +203,18 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64); + REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64); registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee)); registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee)); diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc index a0cd8ee2d2..5cdac203af 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/core/platform/logging.h" namespace xla { namespace cpu { diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h index 8b00ae9e47..a383b4a4a0 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ +#include "absl/container/flat_hash_map.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace cpu { @@ -97,8 +97,7 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures { // This is mutated from within `GetTargetTransformInfoFor` which is // semantically a getter (and thus `const`); and is therefore declared // mutable. Making this mutable is okay because it has cache semantics. 
- mutable tensorflow::gtl::FlatMap<const llvm::Function*, - llvm::TargetTransformInfo> + mutable absl::flat_hash_map<const llvm::Function*, llvm::TargetTransformInfo> target_transform_info_cache_; llvm::TargetMachine* target_machine_; }; diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index c55206eee7..4b129c95d4 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -180,3 +180,17 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "cpu_key_value_sort_test", + srcs = ["cpu_key_value_sort_test.cc"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc new file mode 100644 index 0000000000..3934c03a04 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" + +namespace xla { +namespace cpu { +namespace { +class CpuKeyValueSortTest : public CpuCodegenTest {}; + +TEST_F(CpuKeyValueSortTest, SortR1) { + const string hlo_text = R"( +HloModule KeyValueSort + +ENTRY main { + a = f32[10] parameter(0) + + ROOT result = f32[10] sort(f32[10] a), dimensions={0} +} +)"; + + string filecheck_pattern = R"( +CHECK: call void @__xla_cpu_runtime_KeyValueSort +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(hlo_text)); + + CpuAotCompilationOptions options{ + /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"", + /*entry_point_name=*/"entry", + /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; + + CompileAheadOfTimeAndVerifyIr(std::move(module), options, filecheck_pattern, + /*match_optimized_ir=*/true); +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc index 7af51db55a..b35fd9dad8 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc @@ -121,7 +121,7 @@ TEST_F(CpuNoAliasTest, Concat) { CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]] CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32 CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48 - CHECK-DAG: [[param_x_noalias]] = !{[[buf_size32]], [[buf_size48]]} + CHECK-DAG: [[param_x_noalias]] = !{[[buf_size48]], [[buf_size32]]} CHECK-DAG: [[concat1_scope]] = !{[[buf_size32]]} CHECK-DAG: [[concat1_noalias]] = !{[[buf_size48]]} )"; diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc index 8fe65f488a..cc38b81455 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager_test.cc @@ -66,9 +66,9 @@ void ProcessNextBuffer(int32 length) { auto shape = ShapeUtil::MakeShape(U8, {length}); string bytes = shape.SerializeAsString(); void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue( - length, bytes.data(), bytes.size()); - __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer, - bytes.data(), bytes.size()); + /*run_options=*/nullptr, length, bytes.data(), bytes.size()); + __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue( + /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size()); } // Performs the acquire/release sequence on the outfeed, as the generated CPU @@ -76,16 +76,16 @@ void ProcessNextBuffer(int32 length) { void ProcessNextOutfeedBuffer(int32 length, const Shape& shape) { string bytes = shape.SerializeAsString(); void* buffer = __xla_cpu_runtime_AcquireOutfeedBufferForPopulation( - length, bytes.data(), bytes.size()); + /*run_options=*/nullptr, length, bytes.data(), bytes.size()); __xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation( - length, buffer, bytes.data(), bytes.size()); + /*run_options=*/nullptr, length, buffer, bytes.data(), bytes.size()); } TEST_F(InfeedManagerTest, SingleThreadedSequential) { TestInfeedBuffer* a = new TestInfeedBuffer(64); TestInfeedBuffer* b = new TestInfeedBuffer(32); - 
cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); xfeed->infeed()->EnqueueBuffersAtomically({a}); xfeed->infeed()->EnqueueBuffersAtomically({b}); @@ -97,7 +97,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { TestInfeedBuffer* a = new TestInfeedBuffer(64); TestInfeedBuffer* b = new TestInfeedBuffer(32); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); xfeed->infeed()->EnqueueBuffersAtomically({a}); ProcessNextBuffer(a->length()); @@ -108,7 +108,7 @@ TEST_F(InfeedManagerTest, SingleThreadedInterleaved) { TEST_F(InfeedManagerTest, MultiThreaded) { tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); const int32 length = 64; @@ -130,7 +130,7 @@ TEST_F(InfeedManagerTest, MultiThreaded) { TEST_F(InfeedManagerTest, OutfeedWrongShape) { TestInfeedBuffer* b = new TestInfeedBuffer(32, /*expect_shape_match=*/false); - cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(); + cpu::runtime::XfeedManager* xfeed = cpu::runtime::GetXfeedManager(0); xfeed->outfeed()->EnqueueBuffersAtomically({b}); ProcessNextOutfeedBuffer(32, ShapeUtil::MakeShape(U8, {33})); diff --git a/tensorflow/compiler/xla/service/defuser.cc b/tensorflow/compiler/xla/service/defuser.cc index d124f74d19..661539cccb 100644 --- a/tensorflow/compiler/xla/service/defuser.cc +++ b/tensorflow/compiler/xla/service/defuser.cc @@ -22,6 +22,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -48,7 +49,7 @@ Status Defuse(HloInstruction* fusion_instruction) { fusion_instruction->fused_instructions_computation(); // A map from fused instruction to its defused clone. - tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*> + absl::flat_hash_map<const HloInstruction*, HloInstruction*> defused_instructions; // Initialize map to contain the fusion instruction parameters mapping // to the operands of the fusion instruction. diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h index c326beb899..aaa41fc4fe 100644 --- a/tensorflow/compiler/xla/service/defuser.h +++ b/tensorflow/compiler/xla/service/defuser.h @@ -25,7 +25,7 @@ namespace xla { // A pass which replaces all fusion instructions with the equivalent un-fused // instructions. -class Defuser : public HloPassInterface { +class Defuser : public HloModulePass { public: Defuser() {} ~Defuser() override {} diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index ba2a674d9a..b3549acfc2 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -24,7 +24,7 @@ namespace xla { namespace { // Pass which strips control dependencies from all instructions in the module. 
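The repeated HloPassInterface -> HloModulePass changes in this diff only re-parent module-level passes; the pass surface (name() plus StatusOr<bool> Run(HloModule*)) visible in the declarations above and below is unchanged. As an illustration of the kind of Run body the comment above describes, a hedged sketch (not the actual ControlDepRemover implementation; the control-dependency helpers are assumed from the HloInstruction API):

// (inside namespace xla)
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

class StripControlDepsSketch : public HloModulePass {
 public:
  absl::string_view name() const override { return "strip-control-deps"; }

  StatusOr<bool> Run(HloModule* module) override {
    bool changed = false;
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (!instruction->control_predecessors().empty()) {
          changed = true;
          // Assumed API: drops the control edges recorded on `instruction`
          // and returns a Status.
          TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
        }
      }
    }
    return changed;
  }
};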
-class ControlDepRemover : public HloPassInterface { +class ControlDepRemover : public HloModulePass { public: ControlDepRemover() = default; absl::string_view name() const override { return "control-dep-remover"; } diff --git a/tensorflow/compiler/xla/service/despecializer.h b/tensorflow/compiler/xla/service/despecializer.h index 7be70add2f..46dcc3a438 100644 --- a/tensorflow/compiler/xla/service/despecializer.h +++ b/tensorflow/compiler/xla/service/despecializer.h @@ -30,7 +30,7 @@ namespace xla { // // Current despecialization passes are Defuser, ImplicitBroadcastRemover, // and BFloat16MixedPrecisionRemoval. -class Despecializer : public HloPassInterface { +class Despecializer : public HloModulePass { public: Despecializer(); absl::string_view name() const override { return "despecializer"; } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 5761573791..68d01d75a2 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h index fc38e31700..40e7a3b4c2 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.h +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -23,7 +23,7 @@ namespace xla { // DotDecomposer is a pass which decomposes batch Dot operations into a // sequence of smaller (R2) Dot operations. -class DotDecomposer : public HloPassInterface { +class DotDecomposer : public HloModulePass { public: // Decomposes batch Dot operations when 'decompose_batch_dot' is true. DotDecomposer(bool decompose_batch_dot = true) diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 4bb1e071d8..515267edd7 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -847,29 +847,34 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, llvm::Value* x) { - if (prim_type != F32) { - // TODO(b/34339814): Implement inverse erf for F64. + if (prim_type != F16 && prim_type != F32 && prim_type != F64) { return Unimplemented( "Inverse erf is only implemented for element " - "type F32."); + "types F16, F32 and F64."); } - auto getFloat = [&](const float f) { - return llvm::ConstantFP::get(b_->getFloatTy(), f); + + // Upcast half to float. 
+ if (prim_type == F16) { + x = b_->CreateFPExt(x, b_->getFloatTy()); + } + + auto get_float = [&](const double f) { + return llvm::ConstantFP::get(x->getType(), f); }; - auto multiply_add = [&](absl::Span<const float> coefficients, + auto multiply_add = [&](absl::Span<const double> coefficients, llvm::Value* w) { - llvm::Value* p = getFloat(coefficients.front()); + llvm::Value* p = get_float(coefficients.front()); coefficients.remove_prefix(1); for (float coefficient : coefficients) { - p = FAdd(FMul(p, w), getFloat(coefficient)); + p = FAdd(FMul(p, w), get_float(coefficient)); } return p; }; // Approximation for inverse error function from // Giles, M., "Approximating the erfinv function". - // The approximation has the form: - // w = log((1-x)*(1+x)) + // The approximation has the form (float version): + // w = -log((1-x)*(1+x)) // if ( w < 5 ) { // w = w - 2.5 // p = sum_{i=1}^n lq[i]*w^i @@ -879,46 +884,124 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, // } // return p*x llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::log, {b_->getFloatTy()}); + module_, llvm::Intrinsic::log, {x->getType()}); - llvm::Value* w = FNeg( - Call(logf_fn, {FMul(FSub(getFloat(1.0f), x), FAdd(getFloat(1.0f), x))})); + llvm::Value* w = FNeg(Call( + logf_fn, {FMul(FSub(get_float(1.0f), x), FAdd(get_float(1.0f), x))})); llvm::Value* p_addr = - llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_); + llvm_ir::EmitAllocaAtFunctionEntry(x->getType(), "p.addr", b_); + + if (prim_type == F16 || prim_type == F32) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(5.0f)), "w_less_than_five", b_); + // Handle true BB. + SetToFirstInsertPoint(if_data.true_block, b_); + { + llvm::Value* lw = FSub(w, get_float(2.5f)); + absl::Span<const double> lq{ + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + llvm::Value* p = multiply_add(lq, lw); + Store(p, p_addr); + } - llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( - FCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_); - // Handle true BB. - SetToFirstInsertPoint(if_data.true_block, b_); - { - llvm::Value* lw = FSub(w, getFloat(2.5f)); - absl::Span<const float> lq{ - 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, - -4.39150654e-06f, 0.00021858087f, -0.00125372503f, - -0.00417768164f, 0.246640727f, 1.50140941f}; - llvm::Value* p = multiply_add(lq, lw); - Store(p, p_addr); - } + // Handle false BB. + SetToFirstInsertPoint(if_data.false_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.0f)); + absl::Span<const double> gq{ + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + llvm::Value* p = multiply_add(gq, gw); + Store(p, p_addr); + } - // Handle false BB. 
- SetToFirstInsertPoint(if_data.false_block, b_); - { - llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()}); - - llvm::Value* gw = FSub(Call(sqrtf_fn, w), getFloat(3.0f)); - absl::Span<const float> gq{ - -0.000200214257f, 0.000100950558f, 0.00134934322f, - -0.00367342844f, 0.00573950773f, -0.0076224613f, - 0.00943887047f, 1.00167406f, 2.83297682f}; - llvm::Value* p = multiply_add(gq, gw); - Store(p, p_addr); - } + SetToFirstInsertPoint(if_data.after_block, b_); + } else { + DCHECK(prim_type == F64); + + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(6.25)), "w_less_than_6.25", b_); + + SetToFirstInsertPoint(if_data.true_block, b_); + { + llvm::Value* lw = FSub(w, get_float(3.125)); + absl::Span<const double> c{ + -3.6444120640178196996e-21, -1.685059138182016589e-19, + 1.2858480715256400167e-18, 1.115787767802518096e-17, + -1.333171662854620906e-16, 2.0972767875968561637e-17, + 6.6376381343583238325e-15, -4.0545662729752068639e-14, + -8.1519341976054721522e-14, 2.6335093153082322977e-12, + -1.2975133253453532498e-11, -5.4154120542946279317e-11, + 1.051212273321532285e-09, -4.1126339803469836976e-09, + -2.9070369957882005086e-08, 4.2347877827932403518e-07, + -1.3654692000834678645e-06, -1.3882523362786468719e-05, + 0.0001867342080340571352, -0.00074070253416626697512, + -0.0060336708714301490533, 0.24015818242558961693, + 1.6536545626831027356}; + llvm::Value* p = multiply_add(c, lw); + Store(p, p_addr); + } - SetToFirstInsertPoint(if_data.after_block, b_); + SetToFirstInsertPoint(if_data.false_block, b_); + llvm_ir::LlvmIfData if_data_second = llvm_ir::EmitIfThenElse( + FCmpOLT(w, get_float(16.0)), "w_less_than_16", b_); + SetToFirstInsertPoint(if_data_second.true_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(3.25)); + absl::Span<const double> t1{ + 2.2137376921775787049e-09, 9.0756561938885390979e-08, + -2.7517406297064545428e-07, 1.8239629214389227755e-08, + 1.5027403968909827627e-06, -4.013867526981545969e-06, + 2.9234449089955446044e-06, 1.2475304481671778723e-05, + -4.7318229009055733981e-05, 6.8284851459573175448e-05, + 2.4031110387097893999e-05, -0.0003550375203628474796, + 0.00095328937973738049703, -0.0016882755560235047313, + 0.0024914420961078508066, -0.0037512085075692412107, + 0.005370914553590063617, 1.0052589676941592334, + 3.0838856104922207635}; + llvm::Value* p = multiply_add(t1, gw); + Store(p, p_addr); + } + + SetToFirstInsertPoint(if_data_second.false_block, b_); + { + llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {b_->getDoubleTy()}); + + llvm::Value* gw = FSub(Call(sqrtf_fn, w), get_float(5.0)); + absl::Span<const double> t2{ + -2.7109920616438573243e-11, -2.5556418169965252055e-10, + 1.5076572693500548083e-09, -3.7894654401267369937e-09, + 7.6157012080783393804e-09, -1.4960026627149240478e-08, + 2.9147953450901080826e-08, -6.7711997758452339498e-08, + 2.2900482228026654717e-07, -9.9298272942317002539e-07, + 4.5260625972231537039e-06, -1.9681778105531670567e-05, + 7.5995277030017761139e-05, -0.00021503011930044477347, + -0.00013871931833623122026, 1.0103004648645343977, + 4.8499064014085844221}; + llvm::Value* p = multiply_add(t2, gw); + Store(p, p_addr); + } + + SetToFirstInsertPoint(if_data.after_block, b_); + } llvm::Value* p = Load(p_addr); - return FMul(p, x); + x = FMul(p, x); + // 
Trunc back to half if needed. + if (prim_type == F16) { + x = b_->CreateFPTrunc(x, b_->getHalfTy()); + } + return x; } StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(PrimitiveType prim_type, diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.h b/tensorflow/compiler/xla/service/flatten_call_graph.h index 3cccec9862..986970f886 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.h +++ b/tensorflow/compiler/xla/service/flatten_call_graph.h @@ -26,7 +26,7 @@ namespace xla { // Flattening associates each call site with a unique computation (for // sequential calling contexts) This simplifies buffer assignment and // points-to analysis (see b/36865746 for details). -class FlattenCallGraph : public HloPassInterface { +class FlattenCallGraph : public HloModulePass { public: absl::string_view name() const override { return "flatten-call-graph"; } diff --git a/tensorflow/compiler/xla/service/fusion_queue.h b/tensorflow/compiler/xla/service/fusion_queue.h new file mode 100644 index 0000000000..1208a7dda8 --- /dev/null +++ b/tensorflow/compiler/xla/service/fusion_queue.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ + +#include <utility> + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +// A queue interface that allows implementations to choose fusion candidates in +// custom order. +class FusionQueue { + public: + FusionQueue() = default; + virtual ~FusionQueue() = default; + + // Dequeues the next fusion candidates: a consumer and the list of producers + // as operand indices. + virtual std::pair<HloInstruction*, std::vector<int64>> + DequeueNextInstructionAndOperandsToFuseInOrder() = 0; + + // A callback passed to the queue implementation right before the producer is + // fused into the consumer. + virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {} + + // A callback passed to the queue implementation right after the fusion is + // created. Note that original_producer could have been destroyed. + virtual void OnFusingInstruction(HloInstruction* fusion, + HloInstruction* original_producer, + HloInstruction* original_consumer) {} + + // A callback passed to the queue implementation to notify the removal of an + // instruction. + virtual void RemoveInstruction(HloInstruction* instruction) = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_ diff --git a/tensorflow/compiler/xla/service/gather_expander.h b/tensorflow/compiler/xla/service/gather_expander.h index 7bd9ea5984..2b39359aae 100644 --- a/tensorflow/compiler/xla/service/gather_expander.h +++ b/tensorflow/compiler/xla/service/gather_expander.h @@ -23,7 +23,7 @@ namespace xla { // This pass rewrites gather operations into (roughly) while loops of dynamic // slices. 
This lets backends that don't support gather directly to // nevertheless have a minimum level of support. -class GatherExpander : public HloPassInterface { +class GatherExpander : public HloModulePass { public: absl::string_view name() const override { return "gather_expander"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 64b9683628..350fd32537 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -68,9 +68,7 @@ cc_library( # srcs = [ # "partition_assignment_test.cc", # ], -# tags = [ -# "requires-gpu-sm35", -# ], +# tags = tf_cuda_tests_tags(), # deps = [ # ":partition_assignment", # "//tensorflow/core:stream_executor_no_cuda", @@ -93,6 +91,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_reachability", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", ], ) @@ -359,6 +358,7 @@ cc_library( "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep "//tensorflow/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -373,7 +373,6 @@ cc_library( hdrs = ["ir_emission_utils.h"], deps = [ ":backend_configs", - ":cudnn_convolution_runner", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", @@ -405,6 +404,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", ], ) @@ -414,6 +414,8 @@ cc_library( srcs = ["cudnn_convolution_runner.cc"], hdrs = ["cudnn_convolution_runner.h"], deps = [ + ":backend_configs", + ":ir_emission_utils", ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -422,8 +424,10 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -432,6 +436,7 @@ cc_library( srcs = ["cudnn_convolution_rewriter.cc"], hdrs = ["cudnn_convolution_rewriter.h"], deps = [ + ":backend_configs", ":ir_emission_utils", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:util", @@ -472,6 +477,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:instruction_fusion", "//tensorflow/compiler/xla/service:pattern_matcher", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -504,6 +510,7 @@ cc_library( "//tensorflow/compiler/xla/service:multi_output_fusion", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -537,6 +544,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -596,14 +604,11 @@ cc_library( hdrs = ["pad_for_tensor_cores.h"], deps = [ ":ir_emission_utils", - "//tensorflow/compiler/xla:literal", 
"//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo_creation_utils", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_pass", - "//tensorflow/compiler/xla/service:shape_inference", ], ) @@ -656,6 +661,7 @@ cc_library( deps = [ ":cudnn_convolution_algorithm_picker", ":cudnn_convolution_rewriter", + ":cudnn_fused_convolution_rewriter", ":fusion_merger", ":gpu_constants", ":gpu_copy_insertion", @@ -713,6 +719,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -774,7 +781,6 @@ cc_library( srcs = ["gpu_layout_assignment.cc"], hdrs = ["gpu_layout_assignment.h"], deps = [ - ":gpu_options", ":ir_emission_utils", ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", @@ -783,6 +789,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -875,16 +882,6 @@ cc_library( ) cc_library( - name = "gpu_options", - srcs = ["gpu_options.cc"], - hdrs = ["gpu_options.h"], - deps = [ - "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/core:lib_internal", - ], -) - -cc_library( name = "stream_executor_util", srcs = ["stream_executor_util.cc"], hdrs = ["stream_executor_util.h"], @@ -967,3 +964,19 @@ tf_cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "cudnn_fused_convolution_rewriter", + srcs = ["cudnn_fused_convolution_rewriter.cc"], + hdrs = ["cudnn_fused_convolution_rewriter.h"], + deps = [ + ":backend_configs", + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:pattern_matcher", + "//tensorflow/core:stream_executor_no_cuda", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto index 640c6392b8..78e14d860e 100644 --- a/tensorflow/compiler/xla/service/gpu/backend_configs.proto +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -24,4 +24,18 @@ message CudnnConvBackendConfig { // true, cudnn may choose not to use tensor cores, e.g. because the GPU or // selected algorithm doesn't support it. bool tensor_ops_enabled = 2; + + // The scaling factor multiplied with the convolution result. + double conv_result_scale = 4; + + // Below are the fields related to cuDNN's fused convolution. Refer to + // CudnnConvParams for their meanings. + + // The requested activation (e.g. relu) after the convolution. It is with type + // stream_executor::dnn::ActivationMode. + int64 activation_mode = 3; + + // The scaling factor multiplied with the side input. If no side input buffer + // is provided, this field must be 0. 
+ double side_input_scale = 5; } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 3a23ac1d63..4effea637d 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -29,37 +29,38 @@ limitations under the License. namespace xla { namespace gpu { -using se::dnn::AlgorithmDesc; +ConvolutionThunk::ConvolutionThunk( + const HloCustomCallInstruction* cudnn_call, + std::vector<BufferAllocation::Slice> operand_slices, + BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, + BufferAllocation::Slice tuple_result_slice) + : Thunk(Kind::kConvolution, cudnn_call), + cudnn_call_(cudnn_call), + operand_buffers_(std::move(operand_slices)), + result_buffer_(result_slice), + scratch_buffer_(scratch_slice), + tuple_result_buffer_(tuple_result_slice) {} Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream, HloExecutionProfiler* profiler) { - CudnnConvParams params; + std::vector<se::DeviceMemoryBase> operand_se_buffers; + for (const auto& buffer : operand_buffers_) { + operand_se_buffers.push_back(buffer_allocations.GetDeviceAddress(buffer)); + } + + se::DeviceMemoryBase result_buffer = + buffer_allocations.GetDeviceAddress(result_buffer_); - params.input_buf = buffer_allocations.GetDeviceAddress(input_buffer_); - params.filter_buf = buffer_allocations.GetDeviceAddress(filter_buffer_); - params.output_buf = buffer_allocations.GetDeviceAddress(output_buffer_); se::DeviceMemoryBase scratch = buffer_allocations.GetDeviceAddress(scratch_buffer_); - TF_RETURN_IF_ERROR(PopulateCudnnConvParams(cudnn_call_, ¶ms)); - auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction()); - TF_RETURN_IF_ERROR(RunCudnnConvolution(params, scratch, stream)); + TF_RETURN_IF_ERROR(RunCudnnConvolution(cudnn_call_, + absl::MakeSpan(operand_se_buffers), + result_buffer, scratch, stream)); - // Figure out which of output/input/filter is the result produced by - // this op, and write the result tuple. - void* result_ptr = [&] { - switch (params.kind) { - case CudnnConvKind::kForward: - return params.output_buf.opaque(); - case CudnnConvKind::kBackwardInput: - return params.input_buf.opaque(); - case CudnnConvKind::kBackwardFilter: - return params.filter_buf.opaque(); - } - }(); - void* ptrs[] = {result_ptr, scratch.opaque()}; + void* ptrs[] = {result_buffer.opaque(), scratch.opaque()}; se::DeviceMemory<void*> tuple_addr( buffer_allocations.GetDeviceAddress(tuple_result_buffer_)); stream->ThenMemcpyH2D<void*>(ptrs, &tuple_addr); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index d7d1f91fba..f53bc54198 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -42,24 +42,12 @@ class ConvolutionThunk : public Thunk { // Constructs a thunk for launching a DNN convolution. When run, it will // write a tuple (result, scratch_memory) into `tuple_result_buffer`. // - // Note that "output" here doesn't refer to the output from running this - // thunk, but rather to the "output" of a hypothetical forward convolution - // that corresponds to this input+filter+output triple. That is, the result - // generated by this thunk is "output" for forward convs, "input" for - // backward-input convs, and "filter" for backward-filter convs. 
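With the constructor declared below, the thunk no longer needs to know which operand plays the role of input, filter, or output; it just forwards the operand buffers in operand order and writes the (result, scratch) tuple. A hedged sketch of a call site follows; the helper GetAllocationSlice and the variable custom_call are assumptions for illustration, not code from this change:

// Assumed emitter-side sketch: slices must follow cudnn_call->operands() order.
std::vector<BufferAllocation::Slice> operand_slices;
for (const HloInstruction* operand : custom_call->operands()) {
  operand_slices.push_back(GetAllocationSlice(*operand));
}
auto thunk = absl::make_unique<ConvolutionThunk>(
    Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices),
    /*result_slice=*/GetAllocationSlice(*custom_call, /*index=*/{0}),
    /*scratch_slice=*/GetAllocationSlice(*custom_call, /*index=*/{1}),
    /*tuple_result_slice=*/GetAllocationSlice(*custom_call));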
+ // operand_slices should be in the same order as cudnn_call->operands(). ConvolutionThunk(const HloCustomCallInstruction* cudnn_call, - BufferAllocation::Slice input_slice, - BufferAllocation::Slice filter_slice, - BufferAllocation::Slice output_slice, + std::vector<BufferAllocation::Slice> operand_slices, + BufferAllocation::Slice result_slice, BufferAllocation::Slice scratch_slice, - BufferAllocation::Slice tuple_result_slice) - : Thunk(Kind::kConvolution, cudnn_call), - cudnn_call_(cudnn_call), - input_buffer_(std::move(input_slice)), - filter_buffer_(std::move(filter_slice)), - output_buffer_(std::move(output_slice)), - scratch_buffer_(std::move(scratch_slice)), - tuple_result_buffer_(std::move(tuple_result_slice)) {} + BufferAllocation::Slice tuple_result_slice); ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; @@ -71,9 +59,8 @@ class ConvolutionThunk : public Thunk { private: const HloCustomCallInstruction* cudnn_call_; - BufferAllocation::Slice input_buffer_; - BufferAllocation::Slice filter_buffer_; - BufferAllocation::Slice output_buffer_; + std::vector<BufferAllocation::Slice> operand_buffers_; + BufferAllocation::Slice result_buffer_; BufferAllocation::Slice scratch_buffer_; BufferAllocation::Slice tuple_result_buffer_; }; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h index 6e2e330edd..c3f58508dd 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h @@ -52,7 +52,7 @@ namespace gpu { // The GPU backend does not implement a lowering for the batchnorm HLOs -- it // expects them to be lowered to cudnn calls via this pass or to HLO soup via // BatchNormRewriter. -class CudnnBatchNormRewriter : public HloPassInterface { +class CudnnBatchNormRewriter : public HloModulePass { public: absl::string_view name() const override { return "cudnn_batchnorm_rewriter"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index f528e62b17..590c0a7d54 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -76,54 +76,24 @@ StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes( return se::DeviceMemory<uint8>(buffer_addr); } -// Determines whether we can safely perform a winograd non-fused convolution for -// the given input and output shapes. This works around b/68264959, an integer -// overflow in cuDNNv5 and cuDNNv6. -bool ShouldIncludeWinogradNonfusedAlgo(const Shape& input_shape, - const Shape& output_shape, - const ConvolutionDimensionNumbers& dnums, - se::StreamExecutor* stream_exec) { - // Skip this check for cudnn7 and newer. - auto version = stream_exec->AsDnn()->GetVersion(); - if (version.ok() && version.ValueOrDie().major_version() >= 7) { - return true; - } - - int64 batch = input_shape.dimensions(dnums.input_batch_dimension()); - int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension()); - int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0)); - int64 in_cols = - dnums.input_spatial_dimensions_size() == 1 - ? 
1 - : input_shape.dimensions(dnums.input_spatial_dimensions(1)); - int64 out_depths = output_shape.dimensions(dnums.output_feature_dimension()); - - int64 total_size = CeilOfRatio(batch, int64{16}) * - std::max(in_depths, out_depths) * in_cols * in_rows * - sizeof(float); - - const int64 threshold = 1L << 31; - return total_size < threshold; -} - std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind, - bool with_winograd_nonfused, se::StreamExecutor* stream_exec) { std::vector<AlgorithmDesc> algorithms; + bool succ = false; switch (kind) { case CudnnConvKind::kBackwardFilter: - CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( - with_winograd_nonfused, &algorithms)); + succ = + stream_exec->GetConvolveBackwardFilterAlgorithms(true, &algorithms); break; case CudnnConvKind::kBackwardInput: - CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( - with_winograd_nonfused, &algorithms)); + succ = stream_exec->GetConvolveBackwardDataAlgorithms(true, &algorithms); break; case CudnnConvKind::kForward: - CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused, - &algorithms)); + case CudnnConvKind::kForwardActivation: + succ = stream_exec->GetConvolveAlgorithms(true, &algorithms); break; } + DCHECK(succ); return algorithms; } @@ -175,21 +145,13 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // cache misses and doing extra work. Overall, caching doesn't seem worth the // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. -StatusOr<std::tuple<int64, bool, int64>> +StatusOr<CudnnConvolutionAlgorithmPicker::AutotuneResult> CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( - const HloCustomCallInstruction* instr) { - CudnnConvParams params; - TF_RETURN_IF_ERROR(PopulateCudnnConvParams(instr, ¶ms)); - - const Shape& input_shape = *params.input_shape; - const Shape& filter_shape = *params.filter_shape; - const Shape& output_shape = *params.output_shape; - - CHECK_EQ(input_shape.element_type(), filter_shape.element_type()); - CHECK_EQ(input_shape.element_type(), output_shape.element_type()); + HloCustomCallInstruction* instr) { // TODO(timshen): for now only check fp16. It can be expanded to other types, // with some work on the HLO routines. - const bool cross_check_enabled = input_shape.element_type() == xla::F16; + const bool cross_check_enabled = + instr->shape().tuple_shapes(0).element_type() == xla::F16; // Don't run this function concurrently on the same GPU. // @@ -257,51 +219,43 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // use a ScratchAllocator for this instead of calling allocator_ directly so // that our allocations don't leak. 
ScratchAllocator input_output_allocator(device_ordinal, allocator); - TF_ASSIGN_OR_RETURN(params.input_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(input_shape))); - TF_ASSIGN_OR_RETURN(params.filter_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(filter_shape))); - TF_ASSIGN_OR_RETURN(params.output_buf, - input_output_allocator.AllocateBytes( - &stream, ShapeUtil::ByteSizeOf(output_shape))); - - initialize_buffer(params.input_buf); - initialize_buffer(params.filter_buf); - initialize_buffer(params.output_buf); - - DeviceMemoryBase* result_buf = [&] { - switch (params.kind) { - case CudnnConvKind::kBackwardFilter: - return &params.filter_buf; - case CudnnConvKind::kBackwardInput: - return &params.input_buf; - case CudnnConvKind::kForward: - return &params.output_buf; - } - }(); + std::vector<se::DeviceMemoryBase> operand_buffers; + for (const auto* operand : instr->operands()) { + TF_ASSIGN_OR_RETURN(auto buffer, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(operand->shape()))); + initialize_buffer(buffer); + operand_buffers.push_back(buffer); + } + TF_ASSIGN_OR_RETURN( + auto result_buffer, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0)))); + initialize_buffer(result_buffer); - const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( - input_shape, output_shape, *params.dnums, stream_exec_); se::dnn::ProfileResult best_result; int64 best_result_bytes_used = 0; + TF_ASSIGN_OR_RETURN(auto backend_config, + instr->backend_config<CudnnConvBackendConfig>()); optional<F16BufferComparator> comparator; // Use the first algorithm that's supported as the reference. There isn't a // particular reason to use it; any algorithm suffices. Being the reference // doesn't make this algorithm's results considered correct, though.
optional<AlgorithmDesc> first_algorithm; - for (const AlgorithmDesc& alg : - GetAlgorithms(params.kind, use_winograd_nonfused, stream_exec_)) { + TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr)); + for (const AlgorithmDesc& alg : GetAlgorithms(kind, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - params.algorithm = AlgorithmConfig(alg); - bool launch_ok = RunCudnnConvolution(params, &scratch_allocator, &stream, - &profile_result) + backend_config.set_algorithm(alg.algo_id()); + backend_config.set_tensor_ops_enabled(alg.tensor_ops_enabled()); + TF_RETURN_IF_ERROR(instr->set_backend_config(backend_config)); + bool launch_ok = RunCudnnConvolution(instr, absl::MakeSpan(operand_buffers), + result_buffer, &scratch_allocator, + &stream, &profile_result) .ok(); if (launch_ok && profile_result.is_valid()) { @@ -312,7 +266,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( .xla_gpu_crash_on_verification_failures(); if (comparator.has_value()) { StatusOr<bool> result = comparator->CompareEqual( - se::DeviceMemory<Eigen::half>(*result_buf)); + se::DeviceMemory<Eigen::half>(result_buffer)); if (!result.ok()) { LOG(ERROR) << "Unable to compare " << AlgorithmToString(*first_algorithm) << " against " @@ -330,7 +284,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( } } else if (cross_check_enabled) { auto comp = F16BufferComparator::Create( - se::DeviceMemory<Eigen::half>(*result_buf), compiler_, allocator, + se::DeviceMemory<Eigen::half>(result_buffer), compiler_, allocator, &stream); if (comp.ok()) { comparator.emplace(comp.ConsumeValueOrDie()); @@ -362,9 +316,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( << AlgorithmToString(best_result.algorithm()) << ", takes " << best_result.elapsed_time_in_ms() << "ms, and uses " << best_result_bytes_used << "B of scratch memory."; - return std::make_tuple(best_result.algorithm().algo_id(), - best_result.algorithm().tensor_ops_enabled(), - best_result_bytes_used); + return AutotuneResult{best_result.algorithm().algo_id(), + best_result.algorithm().tensor_ops_enabled(), + best_result_bytes_used, + absl::Milliseconds(best_result.elapsed_time_in_ms())}; } return InternalError( @@ -377,40 +332,34 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction( HloInstruction* instr) { CHECK(IsCustomCallToDnnConvolution(*instr)); - StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc = + StatusOr<AutotuneResult> best_algo_or = PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr)); - - if (!alg_scratch_and_tc.ok()) { - LOG(ERROR) << alg_scratch_and_tc.status(); + if (!best_algo_or.ok()) { + LOG(ERROR) << best_algo_or.status(); return false; } - int64 algorithm; - bool tensor_ops_enabled; - int64 scratch_bytes; - - std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = - alg_scratch_and_tc.ConsumeValueOrDie(); - - VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " - << NumBytesToString(scratch_bytes) + auto best_algo = std::move(best_algo_or).ValueOrDie(); + VLOG(1) << "Setting cudnn conv to use algorithm " << best_algo.algorithm + << " and " << NumBytesToString(best_algo.scratch_bytes) << " of scratch memory: " << instr->ToString() - << " tensor_ops_enabled: " << tensor_ops_enabled; + << " tensor_ops_enabled: " << best_algo.tensor_ops_enabled; // Replace instr with a new CustomCall which has the correct algorithm, and // whose output 
shape has the appropriate amount of scratch memory. HloComputation* computation = instr->parent(); - Shape new_call_shape = - ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), - ShapeUtil::MakeShape(U8, {scratch_bytes})}); + Shape new_call_shape = ShapeUtil::MakeTupleShape( + {instr->shape().tuple_shapes(0), + ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes})}); - CudnnConvBackendConfig backend_config; - backend_config.set_algorithm(algorithm); - backend_config.set_tensor_ops_enabled(tensor_ops_enabled); + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + instr->backend_config<CudnnConvBackendConfig>()); + backend_config.set_algorithm(best_algo.algorithm); + backend_config.set_tensor_ops_enabled(best_algo.tensor_ops_enabled); HloInstruction* new_call = computation->AddInstruction( - instr->CloneWithNewOperands(new_call_shape, {instr->mutable_operand(0), - instr->mutable_operand(1)})); + instr->CloneWithNewOperands(new_call_shape, instr->operands())); + TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index f79b113f8f..136c32210a 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -30,7 +31,7 @@ namespace gpu { // Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for // each and adding explicit scratch space to the CustomCalls. -class CudnnConvolutionAlgorithmPicker : public HloPassInterface { +class CudnnConvolutionAlgorithmPicker : public HloModulePass { public: // If the `allocator` parameter is not null, we will use it to allocate temp // memory while timing the various convolution algorithms. If it's null, @@ -47,10 +48,16 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { StatusOr<bool> Run(HloModule* module) override; private: + struct AutotuneResult { + int64 algorithm; + bool tensor_ops_enabled; + int64 scratch_bytes; + absl::Duration runtime; + }; + StatusOr<bool> RunOnComputation(HloComputation* computation); StatusOr<bool> RunOnInstruction(HloInstruction* instr); - StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm( - const HloCustomCallInstruction* instr); + StatusOr<AutotuneResult> PickBestAlgorithm(HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 228379a248..ef29237301 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -35,6 +36,32 @@ namespace gpu { namespace { +HloInstruction* CreateCudnnConv(const char* call_target, const Shape& shape, + HloInstruction* lhs, HloInstruction* rhs, + const Window& window, + const ConvolutionDimensionNumbers& dnums, + int64 feature_group_count) { + HloComputation* computation = lhs->parent(); + + // This call returns a tuple of (conv_result, scratch_memory), where + // conv_result is the actual result of the convolution, and scratch_memory is + // temporary memory used by cudnn. + // + // At the moment, we don't know how much scratch memory this conv is going to + // use, so we put u8[0] in this place. Later on another pass will choose + // which conv algorithm to use, and at that point we'll modify the shape of + // this second tuple element. + Shape call_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); + + HloInstruction* custom_call = computation->AddInstruction( + HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); + custom_call->set_window(window); + custom_call->set_convolution_dimension_numbers(dnums); + custom_call->set_feature_group_count(feature_group_count); + return custom_call; +} + bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { const ConvolutionDimensionNumbers& dnums = conv->convolution_dimension_numbers(); @@ -450,6 +477,12 @@ MatchBackwardInput(HloInstruction* conv) { return std::make_tuple(true, new_window, dnums, rhs); } +CudnnConvBackendConfig GetDefaultBackendConfig() { + CudnnConvBackendConfig config; + config.set_conv_result_scale(1); + return config; +} + // Tries to rewrite a single convolution into a call to cudnn. StatusOr<bool> RunOnInstruction(HloInstruction* conv) { CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); @@ -462,24 +495,24 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) { std::tie(match, window, dnums) = MatchBackwardFilter(conv); if (match) { - return CreateCudnnConvBackwardFilter( - conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), - window, dnums, conv->feature_group_count()); + return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, conv->shape(), + conv->mutable_operand(0), conv->mutable_operand(1), + window, dnums, conv->feature_group_count()); } std::tie(match, window, dnums, rhs) = MatchBackwardInput(conv); if (match) { - return CreateCudnnConvBackwardInput(conv->shape(), - conv->mutable_operand(0), rhs, window, - dnums, conv->feature_group_count()); + return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, conv->shape(), + conv->mutable_operand(0), rhs, window, dnums, + conv->feature_group_count()); } // If all else fails, try a forward convolution. 
if (CanImplementAsCudnnForwardConv(conv)) { - return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0), - conv->mutable_operand(1), conv->window(), - conv->convolution_dimension_numbers(), - conv->feature_group_count()); + return CreateCudnnConv( + kCudnnConvForwardCallTarget, conv->shape(), conv->mutable_operand(0), + conv->mutable_operand(1), conv->window(), + conv->convolution_dimension_numbers(), conv->feature_group_count()); } return nullptr; @@ -489,6 +522,9 @@ StatusOr<bool> RunOnInstruction(HloInstruction* conv) { return false; } + TF_RETURN_IF_ERROR( + custom_call->set_backend_config(GetDefaultBackendConfig())); + // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out // the conv result and replace `conv` with it. TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h index fbe7e98494..8d7c6fdab5 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h @@ -24,7 +24,7 @@ namespace gpu { // Rewrites plain convolutions, backwards-filter convolutions, and // backwards-input convolutions into CustomCall HLOs that call into cuDNN. -class CudnnConvolutionRewriter : public HloPassInterface { +class CudnnConvolutionRewriter : public HloModulePass { public: absl::string_view name() const override { return "cudnn-convolution-rewriter"; diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 2a86ac265e..89dd1bb272 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -37,6 +39,42 @@ using se::dnn::FilterDescriptor; using se::dnn::FilterLayout; using se::dnn::ProfileResult; +struct CudnnConvParams { + // Here are the fields related to cuDNN's fused convolution. The result thus + // is defined as: + // activation(conv_result_scale * conv(x, w) + + // side_input_scale * side_input + broadcast(bias)) + // + // The most common fused conv is conv forward + relu/identity, for example. + // + // bias_buf is a single-dimensional array, with the length equal to the number + // of output features. It'll be broadcasted to the output shape in order to be + // added to the final results. + // + // side_input_buf, if valid, must have the same shape as the output buffer. 
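For orientation, the following standalone sketch spells out what one output element of the fused call computes, following the comment above. It is illustrative only; the function and the bool flag standing in for the activation mode are assumptions, not the stream_executor API:

#include <algorithm>

// Reference semantics of one fused-convolution output element:
//   activation(conv_result_scale * conv(x, w)
//              + side_input_scale * side_input + broadcast(bias))
// `conv_result` is the plain forward-convolution value at this output
// position; `apply_relu` stands in for the requested activation mode.
float FusedConvElement(float conv_result, float side_input, float bias,
                       double conv_result_scale, double side_input_scale,
                       bool apply_relu) {
  double pre_activation = conv_result_scale * conv_result +
                          side_input_scale * side_input + bias;
  double activated =
      apply_relu ? std::max(0.0, pre_activation) : pre_activation;
  return static_cast<float>(activated);
}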
+ struct FusionParams { + se::dnn::ActivationMode mode; + double side_input_scale; + se::DeviceMemoryBase bias_buf; + se::DeviceMemoryBase side_input_buf; // nullable + }; + + CudnnConvKind kind; + const Shape* input_shape; + const Shape* filter_shape; + const Shape* output_shape; + se::DeviceMemoryBase input_buf; + se::DeviceMemoryBase filter_buf; + se::DeviceMemoryBase output_buf; + const Window* window; + const ConvolutionDimensionNumbers* dnums; + int64 feature_group_count; + se::dnn::AlgorithmConfig algorithm; + double conv_result_scale; + + absl::optional<FusionParams> fusion; +}; + // A StreamExecutor ScratchAllocator that wraps a single XLA allocation, // returning it (in its entirety) the first time Allocate() is called. class ScratchBufAllocator : public se::ScratchAllocator { @@ -92,9 +130,9 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params, VLOG(3) << "tensor_ops_enabled: " << algorithm.algorithm().tensor_ops_enabled(); VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); - VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }"; - VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }"; - VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }"; + VLOG(3) << "input shape: " << ShapeUtil::HumanStringWithLayout(input_shape); + VLOG(3) << "filter shape: " << ShapeUtil::HumanStringWithLayout(filter_shape); + VLOG(3) << "Output shape: " << ShapeUtil::HumanStringWithLayout(output_shape); VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; @@ -186,23 +224,73 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params, switch (kind) { case CudnnConvKind::kForward: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } stream->ThenConvolveWithAlgorithm( input_descriptor, input_buf, filter_descriptor, filter_buf, convolution_descriptor, output_descriptor, &output_buf, scratch_allocator, algorithm, profile_result); break; case CudnnConvKind::kBackwardInput: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } stream->ThenConvolveBackwardDataWithAlgorithm( filter_descriptor, filter_buf, output_descriptor, output_buf, convolution_descriptor, input_descriptor, &input_buf, scratch_allocator, algorithm, profile_result); break; case CudnnConvKind::kBackwardFilter: + if (params.conv_result_scale != 1) { + return InternalError( + "StreamExecutor doesn't support scaled convolution: %lf.", + params.conv_result_scale); + } stream->ThenConvolveBackwardFilterWithAlgorithm( input_descriptor, input_buf, output_descriptor, output_buf, convolution_descriptor, filter_descriptor, &filter_buf, scratch_allocator, algorithm, profile_result); break; + case CudnnConvKind::kForwardActivation: { + BatchDescriptor bias_desc; + bias_desc.set_count(1) + .set_height(1) + .set_width(1) + .set_feature_map_count( + output_shape.dimensions(dnums.output_feature_dimension())) + .set_layout(output_dl); + + se::DeviceMemory<T> side_input(params.fusion->side_input_buf); + // If there is no side input, use output as the side input. 
+ if (side_input.is_null()) { + if (params.fusion->side_input_scale != 0) { + return InternalError( + "Side input scale is not 0, yet no side input buffer is " + "provided"); + } + // Since side-input scale is 0, the values in the side input don't + // matter. The simplest thing to do would be to pass in a null buffer + // for the side input, but cudnn doesn't allow this. cudnn does promise + // that if side-input-scale is 0 the side input won't be read, so we + // just pass in the output buffer, since it's handy and has the correct + // size. + side_input = output_buf; + } + + stream->ThenFusedConvolveWithAlgorithm( + input_descriptor, input_buf, params.conv_result_scale, + filter_descriptor, filter_buf, convolution_descriptor, side_input, + params.fusion->side_input_scale, bias_desc, + DeviceMemory<T>(params.fusion->bias_buf), params.fusion->mode, + output_descriptor, &output_buf, scratch_allocator, algorithm, + profile_result); + break; + } } if (!stream->ok()) { @@ -214,32 +302,104 @@ Status RunCudnnConvolutionImpl(CudnnConvParams params, return Status::OK(); } -} // anonymous namespace +// Returns the cudnn convolution parameters generated from conv, which must be a +// custom-call to a cudnn convolution. +StatusOr<CudnnConvParams> GetCudnnConvParams( + const HloCustomCallInstruction* conv, + absl::Span<se::DeviceMemoryBase> operand_buffers, + se::DeviceMemoryBase result_buffer) { + CudnnConvParams params; -string CudnnConvKindToString(CudnnConvKind kind) { - switch (kind) { - case CudnnConvKind::kForward: - return "forward"; - case CudnnConvKind::kBackwardFilter: - return "backward_filter"; - case CudnnConvKind::kBackwardInput: - return "backward_input"; + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + conv->backend_config<CudnnConvBackendConfig>()); + const auto& target = conv->custom_call_target(); + const auto& lhs_shape = conv->operand(0)->shape(); + const auto& rhs_shape = conv->operand(1)->shape(); + const auto& conv_result_shape = conv->shape().tuple_shapes(0); + + params.window = &conv->window(); + params.dnums = &conv->convolution_dimension_numbers(); + params.feature_group_count = conv->feature_group_count(); + params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( + backend_config.algorithm(), backend_config.tensor_ops_enabled())); + params.conv_result_scale = backend_config.conv_result_scale(); + + if (target == kCudnnConvForwardCallTarget) { + params.kind = CudnnConvKind::kForward; + params.input_shape = &lhs_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + } else if (target == kCudnnConvBackwardInputCallTarget) { + params.kind = CudnnConvKind::kBackwardInput; + params.input_shape = &conv_result_shape; + params.filter_shape = &rhs_shape; + params.output_shape = &lhs_shape; + params.input_buf = result_buffer; + params.filter_buf = operand_buffers[1]; + params.output_buf = operand_buffers[0]; + } else if (target == kCudnnConvBackwardFilterCallTarget) { + params.kind = CudnnConvKind::kBackwardFilter; + params.input_shape = &lhs_shape; + params.filter_shape = &conv_result_shape; + params.output_shape = &rhs_shape; + params.input_buf = operand_buffers[0]; + params.filter_buf = result_buffer; + params.output_buf = operand_buffers[1]; + } else if (target == kCudnnConvBiasActivationForwardCallTarget) { + params.kind = CudnnConvKind::kForwardActivation; + params.input_shape = &lhs_shape; + 
params.filter_shape = &rhs_shape; + params.output_shape = &conv_result_shape; + params.fusion.emplace(); + auto& fusion = *params.fusion; + if (backend_config.activation_mode() < + static_cast<int64>(se::dnn::ActivationMode::kNumActivationModes)) { + fusion.mode = static_cast<se::dnn::ActivationMode>( + backend_config.activation_mode()); + } else { + return InternalError("Bad activation mode: %s", + backend_config.ShortDebugString()); + } + fusion.side_input_scale = backend_config.side_input_scale(); + params.input_buf = operand_buffers[0]; + params.filter_buf = operand_buffers[1]; + params.output_buf = result_buffer; + params.fusion->bias_buf = operand_buffers[2]; + if (operand_buffers.size() >= 4) { + params.fusion->side_input_buf = operand_buffers[3]; + } + } else { + return InternalError("Unexpected custom call target: %s", target); } + return params; } -Status RunCudnnConvolution(CudnnConvParams params, +} // anonymous namespace + +Status RunCudnnConvolution(const HloCustomCallInstruction* conv, + absl::Span<se::DeviceMemoryBase> operand_buffers, + se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, se::dnn::ProfileResult* profile_result) { ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution(params, &scratch_allocator, stream, - profile_result); + return RunCudnnConvolution(conv, operand_buffers, result_buffer, + &scratch_allocator, stream, profile_result); } -Status RunCudnnConvolution(CudnnConvParams params, +Status RunCudnnConvolution(const HloCustomCallInstruction* conv, + absl::Span<se::DeviceMemoryBase> operand_buffers, + se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, se::dnn::ProfileResult* profile_result) { - PrimitiveType output_primitive_type = params.output_shape->element_type(); + TF_ASSIGN_OR_RETURN(CudnnConvParams params, + GetCudnnConvParams(conv, operand_buffers, result_buffer)); + + PrimitiveType output_primitive_type = + conv->shape().tuple_shapes(0).element_type(); switch (output_primitive_type) { case F16: return RunCudnnConvolutionImpl<Eigen::half>(params, scratch_allocator, diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h index 381aa37a1b..61aec1cecc 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -27,52 +30,8 @@ namespace gpu { // This file contains low-level routines for running cudnn convolutions. -// Different types of convolutions supported by cudnn. -// -// A way to think about these is that a convolution is defined by three arrays -// -- the "input", the "filter", and the "output" -- and given any two of these, -// we can compute the third. 
For example, a backward-input convolution takes as -// input a filter and an "output" and produces an "input" such that if one were -// to do a forward convolution of "input" using filter, the result would be -// something with the same shape as "output". -// -// This way of thinking is not correct if you look at the values produced. For -// example, a backward-input convolution is not actually the mathematical -// inverse of a forward convolution. But it's right as far as the shapes and -// "connectivity" (i.e. which elements of the input affect which elements of -// the output) are concerned. -enum class CudnnConvKind { - kForward, // input + filter => output - kBackwardInput, // filter + output => input - kBackwardFilter, // input + output => filter -}; - -struct CudnnConvParams { - CudnnConvKind kind; - const Shape* input_shape; - const Shape* filter_shape; - const Shape* output_shape; - se::DeviceMemoryBase input_buf; - se::DeviceMemoryBase filter_buf; - se::DeviceMemoryBase output_buf; - const Window* window; - const ConvolutionDimensionNumbers* dnums; - int64 feature_group_count; - se::dnn::AlgorithmConfig algorithm; -}; - -// Converts a CudnnConvKind value to a string. -string CudnnConvKindToString(CudnnConvKind kind); - // Calls into cudnn to run the specified convolution. // -// Note that depending on the value of CudnnConvKind, the result of this call -// may be written into input_buf, filter_buf, or output_buf! -// -// At the moment convolution with half data type is implemented with cudnn -// PSEUDO_HALF configuration, that is, the input values are half and the -// internal computation type is float. -// // We provide one overload which takes a scratch buffer, and another which takes // an allocator which is responsible for allocating the scratch space. In // theory the second one shouldn't be necessary -- users of this function could @@ -83,11 +42,15 @@ string CudnnConvKindToString(CudnnConvKind kind); // allocator and take note of how much memory is used. The next time you call // the same conv, you can provide an explicitly preallocated scratch buffer of // that size, if you like. -Status RunCudnnConvolution(CudnnConvParams params, +Status RunCudnnConvolution(const HloCustomCallInstruction* conv, + absl::Span<se::DeviceMemoryBase> operand_buffers, + se::DeviceMemoryBase result_buffer, se::DeviceMemoryBase scratch_buf, se::Stream* stream, se::dnn::ProfileResult* profile_result = nullptr); -Status RunCudnnConvolution(CudnnConvParams params, +Status RunCudnnConvolution(const HloCustomCallInstruction* conv, + absl::Span<se::DeviceMemoryBase> operand_buffers, + se::DeviceMemoryBase result_buffer, se::ScratchAllocator* scratch_allocator, se::Stream* stream, se::dnn::ProfileResult* profile_result = nullptr); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc new file mode 100644 index 0000000000..3761c19cfc --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.cc @@ -0,0 +1,278 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/pattern_matcher.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { +namespace { + +// Describes a matched pattern: +// max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); +// Where side_input has the shape of output buffer, and bias is a 1D array with +// the dimension of number of output features. +struct ConvWithRelu { + HloInstruction* maximum; + HloCustomCallInstruction* conv; + HloInstruction* bias; + HloInstruction* side_input; + HloConstantInstruction* alpha_conv; + HloConstantInstruction* alpha_side_input; +}; + +absl::optional<ConvWithRelu> FindConvWithRelu(HloInstruction* instr) { + using match::Add; + using match::AddAnyOrder; + using match::AnyOf; + using match::Broadcast; + using match::Constant; + using match::GetTupleElement; + using match::Maximum; + using match::MultiplyAnyOrder; + using match::Op; + + // The pattern we want to match: + // max(0, alpha1 * conv(x, w) + alpha2 * side_input + broadcast(bias)); + // + // With its variants involving commute/reassociation of adds, multiplies, and + // max, and omission of alpha1, side_input, alpha2, or bias. + + HloInstruction* relu_input; + + // Match max(0, relu_input). + auto zero_pattern = Broadcast(match::ConstantScalar(0)); + if (!Match(instr, Maximum(zero_pattern, Op(&relu_input))) && + !Match(instr, Maximum(Op(&relu_input), zero_pattern))) { + return absl::nullopt; + } + HloInstruction* conv_instr = nullptr; + HloInstruction* alpha_conv_instr = nullptr; + HloInstruction* alpha_side_input_instr = nullptr; + HloInstruction* bias_broadcast_instr = nullptr; + HloInstruction* bias = nullptr; + HloInstruction* side_input = nullptr; + + // These nodes will not be in the returned value, but we need to check them + // for single use. + HloInstruction *gte = nullptr, *add1 = nullptr, *add2 = nullptr, + *mul1 = nullptr, *mul2 = nullptr; + + const auto bias_pattern = Broadcast(&bias_broadcast_instr, Op(&bias)); + const auto conv_pattern = [&] { + auto alpha_pattern = Broadcast(Constant(&alpha_conv_instr)); + auto conv_pattern = GetTupleElement( + >e, Op(&conv_instr).WithOpcode(HloOpcode::kCustomCall), 0); + return AnyOf<HloInstruction>( + MultiplyAnyOrder(&mul1, alpha_pattern, conv_pattern), conv_pattern); + }(); + const auto side_input_pattern = [&] { + auto alpha_pattern = Broadcast(Constant(&alpha_side_input_instr)); + // If bias is already matched, match arbitrary additional input as side + // input. Note this may force a cheap operation (e.g. broadcast) to be + // materialized into a large buffer, as large as the output buffer. 
+ // + // TODO(timshen): If in practice there are significant false positives, we + // should fix it. + auto side_input_pattern = Op(&side_input); + return AnyOf<HloInstruction>( + MultiplyAnyOrder(&mul2, alpha_pattern, side_input_pattern), + side_input_pattern); + }(); + + { + // Try to match any of the following form of add, in any association: + // addends[0] + // addends[0] + addends[1] + // addends[0] + addends[1] + addends[2] + // + // Then try to match each addend with one of the three patterns: bias, conv, + // or side_input. Notice that side_input matching must go last, as it + // also matches a conv or a bias. + HloInstruction* addends[3] = {nullptr, nullptr, nullptr}; + auto add3_pattern = [&] { + auto add2_pattern = Add(&add1, Op(&addends[0]), Op(&addends[1])); + return AnyOf<HloInstruction>( + AddAnyOrder(&add2, add2_pattern, Op(&addends[2])), add2_pattern, + Op(&addends[0])); + }(); + CHECK(Match(relu_input, add3_pattern)); + for (auto addend : addends) { + if (addend) { + if (bias == nullptr && Match(addend, bias_pattern)) { + CHECK(bias); + } else if (conv_instr == nullptr && Match(addend, conv_pattern)) { + CHECK(conv_instr); + } else if (side_input == nullptr && Match(addend, side_input_pattern)) { + CHECK(side_input); + } else { + return absl::nullopt; + } + } + } + } + + if (conv_instr == nullptr) { + return absl::nullopt; + } + + for (HloInstruction* instr : + {conv_instr, bias_broadcast_instr, gte, add1, add2, mul1, mul2}) { + if (instr && instr->user_count() > 1) { + return absl::nullopt; + } + } + + auto conv = Cast<HloCustomCallInstruction>(conv_instr); + auto bias_broadcast = + CastOrNull<HloBroadcastInstruction>(bias_broadcast_instr); + + if (conv->custom_call_target() != kCudnnConvForwardCallTarget) { + return absl::nullopt; + } + + if (bias_broadcast) { + // TODO(timshen): handle bias_broadcast_instr->dimensions() == {}. 
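As a worked example of the dimension checks just below (the shapes, names, and NHWC layout are assumed for illustration): for a conv whose output is f32[8,32,32,64] with output_feature_dimension() == 3, the matchable bias is a rank-1 f32[64] value broadcast along that dimension.

// Illustrative HLO-like sketch of a matching graph (not taken from the patch):
//   %bias  = f32[64]            constant or parameter
//   %bcast = f32[8,32,32,64]    broadcast(%bias), dimensions={3}
//   %conv  = (f32[8,32,32,64], u8[0]) custom-call(%x, %w),
//            target=kCudnnConvForwardCallTarget
//   %gte   = f32[8,32,32,64]    get-tuple-element(%conv), index=0
//   %sum   = f32[8,32,32,64]    add(%gte, %bcast)
//   %zero  = f32[8,32,32,64]    broadcast(f32[] constant(0)), dimensions={}
//   %relu  = f32[8,32,32,64]    maximum(%zero, %sum)
// Here bias_broadcast_instr->dimensions() == {3}, which has size 1 and equals
// the convolution's output_feature_dimension(), so the checks below pass.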
+ if (bias_broadcast_instr->dimensions().size() != 1) { + return absl::nullopt; + } + if (bias_broadcast_instr->dimensions(0) != + conv->convolution_dimension_numbers().output_feature_dimension()) { + return absl::nullopt; + } + } + + return ConvWithRelu{ + instr, + conv, + bias, + side_input, + CastOrNull<HloConstantInstruction>(alpha_conv_instr), + CastOrNull<HloConstantInstruction>(alpha_side_input_instr)}; +} + +StatusOr<std::unique_ptr<HloInstruction>> TryRewriteToCudnnForwardRelu( + ConvWithRelu match) { + auto conv = match.conv; + + HloComputation* computation = conv->parent(); + PrimitiveType element_type = conv->operand(0)->shape().element_type(); + + const auto get_alpha_value = + [](HloConstantInstruction* instr) -> StatusOr<double> { + TF_ASSIGN_OR_RETURN( + auto alpha, + Cast<HloConstantInstruction>(instr)->literal().Convert(F64)); + return alpha.GetFirstElement<double>(); + }; + + double alpha_conv = 1; + if (match.alpha_conv) { + TF_ASSIGN_OR_RETURN(alpha_conv, get_alpha_value(match.alpha_conv)); + } + + double alpha_side_input; + if (match.side_input) { + if (match.alpha_side_input) { + TF_ASSIGN_OR_RETURN(alpha_side_input, + get_alpha_value(match.alpha_side_input)); + } else { + alpha_side_input = 1; + } + } else { + CHECK(match.alpha_side_input == nullptr); + alpha_side_input = 0; + } + + auto bias = match.bias; + if (!bias) { + auto zero = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(element_type))); + + int64 num_output_feature = conv->shape().tuple_shapes(0).dimensions( + conv->convolution_dimension_numbers().output_feature_dimension()); + bias = computation->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShapeWithDescendingLayout(element_type, + {num_output_feature}), + zero, {})); + } + + CHECK(bias); + std::vector<HloInstruction*> args = {conv->mutable_operand(0), + conv->mutable_operand(1), bias}; + if (match.side_input) { + args.push_back(match.side_input); + } + auto new_conv = computation->AddInstruction(HloInstruction::CreateCustomCall( + conv->shape(), args, kCudnnConvBiasActivationForwardCallTarget)); + new_conv->set_window(conv->window()); + new_conv->set_convolution_dimension_numbers( + conv->convolution_dimension_numbers()); + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig config, + conv->backend_config<CudnnConvBackendConfig>()); + config.set_activation_mode( + static_cast<int64>(se::dnn::ActivationMode::kRelu)); + config.set_conv_result_scale(alpha_conv); + config.set_side_input_scale(alpha_side_input); + TF_RETURN_IF_ERROR(new_conv->set_backend_config(config)); + + VLOG(1) << "Rewriting " << conv->name() << " to " << new_conv->name(); + return HloInstruction::CreateGetTupleElement(conv->shape().tuple_shapes(0), + new_conv, 0); +} + +} // namespace + +StatusOr<bool> CudnnFusedConvolutionRewriter::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + std::vector<ConvWithRelu> matches; + int num_forward_convs = 0; + for (auto instr : computation->instructions()) { + auto match = FindConvWithRelu(instr); + if (match.has_value()) { + matches.push_back(*match); + } + if (auto call = DynCast<HloCustomCallInstruction>(instr)) { + if (call->custom_call_target() == kCudnnConvForwardCallTarget) { + num_forward_convs++; + } + } + } + VLOG(1) << "Identified cuDNN forward conv + relu: " << matches.size() + << " out of " << num_forward_convs << " forward convs."; + std::vector<std::pair<HloInstruction*, std::unique_ptr<HloInstruction>>> + 
replacements; + for (const ConvWithRelu& match : matches) { + TF_ASSIGN_OR_RETURN(auto new_instr, TryRewriteToCudnnForwardRelu(match)); + replacements.push_back({match.maximum, std::move(new_instr)}); + changed = true; + } + for (auto& replacement : replacements) { + TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( + replacement.first, std::move(replacement.second))); + } + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h index 498d4a9495..bd12aadded 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_options.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h @@ -13,21 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ -#include "tensorflow/compiler/xla/service/hlo_module_config.h" - -// Helper functions for querying options that are specific to the GPU backend. +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { namespace gpu { -// Returns true if we should use heuristics to assign convolution layouts, as -// opposed to always assigning NCHW. -bool ConvUseLayoutHeuristic(const HloModuleConfig& config); +class CudnnFusedConvolutionRewriter : public HloModulePass { + public: + absl::string_view name() const override { + return "cudnn-fused-convolution-rewriter"; + } + + StatusOr<bool> Run(HloModule* module) override; +}; } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_FUSED_CONVOLUTION_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index c1aaa4bf04..6dcdaf1cfe 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -358,13 +358,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); const Window& window = hlo->window(); - // TODO(b/31410564): Implement dilation for reduce-window. - if (window_util::HasDilation(window)) { - return Unimplemented( - "Dilation for reduce-window not implemented on GPU. " - "See b/31410564."); - } - PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), @@ -397,9 +390,24 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( for (size_t i = 0; i < index.size(); ++i) { llvm::Value* stridden_index = NSWMul( index[i], index_typed_const(window.dimensions(i).stride())); + input_index[i] = NSWSub( + NSWAdd(stridden_index, + NSWMul(window_index[i], + index_typed_const( + window.dimensions(i).window_dilation()))), + index_typed_const(window.dimensions(i).padding_low())); + + // We need to verify that we are not in the dilated base area. 
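The index computation above is easier to follow in scalar form. A hedged standalone sketch follows (the helper name is illustrative; the emitter does the same arithmetic in LLVM IR, and the separate 0 <= index < bound check that follows still handles the padding region):

#include <cstdint>

#include "absl/types/optional.h"

// For output position `output_index` and window offset `window_index`, the
// participating input element is
//   x = output_index * stride + window_index * window_dilation - padding_low,
// and it exists only when x lands on a real base element, i.e.
// x % base_dilation == 0 (otherwise we are in a dilation gap, treated like
// padding). The undilated input index is then x / base_dilation.
absl::optional<int64_t> MapToInputIndex(int64_t output_index,
                                        int64_t window_index, int64_t stride,
                                        int64_t window_dilation,
                                        int64_t base_dilation,
                                        int64_t padding_low) {
  const int64_t x =
      output_index * stride + window_index * window_dilation - padding_low;
  if (x % base_dilation != 0) {
    return absl::nullopt;  // Falls into the dilated (inserted) area.
  }
  return x / base_dilation;  // May still be out of [0, bound); caller checks.
}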
+ llvm::Value* dilation_condition = ICmpEQ( + SRem(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())), + index_typed_const(0)); + in_bounds = And(in_bounds, dilation_condition); + + // Apply base dilation to the index. input_index[i] = - NSWSub(NSWAdd(stridden_index, window_index[i]), - index_typed_const(window.dimensions(i).padding_low())); + SDiv(input_index[i], + index_typed_const(window.dimensions(i).base_dilation())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise // we are in the pad and so can skip the computation. This diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h index 7e3f5775b8..f19996edfe 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h @@ -32,7 +32,7 @@ namespace gpu { // 2) The result of merging the fusion instruction into its users would not // increase bytes transferred. // -class FusionMerger : public HloPassInterface { +class FusionMerger : public HloModulePass { public: absl::string_view name() const override { return "fusion merger"; } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 75f414e47f..e2ab00ce41 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -19,6 +19,7 @@ limitations under the License. #include <set> #include <vector> +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" @@ -27,22 +28,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace gpu { -StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy( - HloInstruction* hlo) { - HloInstruction*& copy = hlo_to_copy_map_[hlo]; - if (copy == nullptr) { - TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo)); - } - return copy; -} - StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) { CopyInsertion generic_copy_insertion; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 8ffae18fe8..4c7e38ffeb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -25,20 +25,11 @@ namespace gpu { // Besides the modifications made by the generic xla::CopyInsertion, this // GPU-specific copy insertion also materializes operands of library calls by // inserting kCopy instructions. -class GpuCopyInsertion : public HloPassInterface { +class GpuCopyInsertion : public HloModulePass { public: absl::string_view name() const override { return "copy-insertion"; } StatusOr<bool> Run(HloModule* module) override; - - protected: - // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making - // duplicate copies. - StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo); - - // A map containing all copies inserted to materialize operands of library - // calls. The key is the copied instruction and the value is the copy. 
- tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> hlo_to_copy_map_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 31a9f9b1be..5742632782 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -19,6 +19,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" @@ -197,7 +198,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) { } module_spec.AddCudaPtxInMemory(ptx().c_str()); - tensorflow::gtl::FlatMap<int64, se::DeviceMemoryBase> globals; + absl::flat_hash_map<int64, se::DeviceMemoryBase> globals; se::ModuleHandle module_handle; executor->LoadModule(module_spec, &module_handle); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 38b0f8f15b..0e276282e4 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -19,6 +19,7 @@ limitations under the License. #include <memory> #include <string> +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -101,7 +101,7 @@ class GpuExecutable : public Executable { const PointsToSet& GetRootPointsToSet() const; using BufferAllocToDeviceMemoryMap = - tensorflow::gtl::FlatMap<BufferAllocation::Index, se::DeviceMemoryBase>; + absl::flat_hash_map<BufferAllocation::Index, se::DeviceMemoryBase>; // Loads the PTX or CUBIN for this executable into `executor` and resolves the // globals corresponding to constant buffers. Returns a map mapping buffer diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h index bbb3340760..9c64b4d10c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -23,7 +23,7 @@ namespace xla { // his pass should run early in the HLO pipeline and checks for HLO constructs // which are not supported by the GPU backend and cannot be removed via HLO // transformations (eg, sparse layouts). -class GpuHloSupportChecker : public HloPassInterface { +class GpuHloSupportChecker : public HloModulePass { public: GpuHloSupportChecker() = default; ~GpuHloSupportChecker() override = default; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index d033faee8d..1ffe855750 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -18,11 +18,12 @@ limitations under the License. 
#include <memory> #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -90,45 +91,46 @@ HeuristicLayoutAssignment(const HloInstruction* instr, // operands and the output shape. Depending on the underlying algorithm, one of // { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen. Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( - HloInstruction* instr, LayoutConstraints* constraints) { - CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); - Shape input_shape; - Shape filter_shape; - Shape output_shape; - const auto& target = instr->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { - input_shape = instr->operand(0)->shape(); - filter_shape = instr->operand(1)->shape(); - output_shape = instr->shape().tuple_shapes(0); - } else if (target == kCudnnConvBackwardInputCallTarget) { - input_shape = instr->shape().tuple_shapes(0); - filter_shape = instr->operand(1)->shape(); - output_shape = instr->operand(0)->shape(); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - input_shape = instr->operand(0)->shape(); - filter_shape = instr->shape().tuple_shapes(0); - output_shape = instr->operand(1)->shape(); - } else { - LOG(FATAL) << "Unexpected custom call target: " - << instr->custom_call_target(); + HloCustomCallInstruction* instr, LayoutConstraints* constraints) { + Shape lhs_shape = instr->operand(0)->shape(); + Shape rhs_shape = instr->operand(1)->shape(); + Shape result_shape = instr->shape().tuple_shapes(0); + + Shape* input_shape; + Shape* filter_shape; + Shape* output_shape; + + TF_ASSIGN_OR_RETURN(auto kind, GetCudnnConvKind(instr)); + switch (kind) { + case CudnnConvKind::kForward: + case CudnnConvKind::kForwardActivation: + input_shape = &lhs_shape; + filter_shape = &rhs_shape; + output_shape = &result_shape; + break; + case CudnnConvKind::kBackwardInput: + input_shape = &result_shape; + filter_shape = &rhs_shape; + output_shape = &lhs_shape; + break; + case CudnnConvKind::kBackwardFilter: + input_shape = &lhs_shape; + filter_shape = &result_shape; + output_shape = &rhs_shape; + break; } { DataLayout input; FilterLayout filter; DataLayout output; - if (ConvUseLayoutHeuristic(instr->GetModule()->config())) { - std::tie(input, filter, output) = - HeuristicLayoutAssignment(instr, stream_executor_); - } else { - input = DataLayout::kBatchDepthYX; - filter = FilterLayout::kOutputInputYX; - output = DataLayout::kBatchDepthYX; - } + std::tie(input, filter, output) = + HeuristicLayoutAssignment(instr, stream_executor_); TF_ASSIGN_OR_RETURN( - std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(), - *output_shape.mutable_layout()), + std::tie(*input_shape->mutable_layout(), + *filter_shape->mutable_layout(), + *output_shape->mutable_layout()), StreamExecutorConvLayoutsToXlaLayouts( instr->convolution_dimension_numbers(), input, filter, output)); } @@ -141,24 +143,23 @@ Status 
GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( instr, /*index=*/{0})); // Set layouts of the instructions' shapes. - if (target == kCudnnConvForwardCallTarget) { - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(output_shape.layout(), *call_result_buf)); - } else if (target == kCudnnConvBackwardInputCallTarget) { - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 0)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(input_shape.layout(), *call_result_buf)); - } else if (target == kCudnnConvBackwardFilterCallTarget) { - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 1)); - TF_RETURN_IF_ERROR( - constraints->SetBufferLayout(filter_shape.layout(), *call_result_buf)); - } else { - LOG(FATAL) << "Unexpected custom call target: " - << instr->custom_call_target(); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(result_shape.layout(), *call_result_buf)); + // instr->operand(2), if exists, is the bias buffer. There is no need to + // assign layout to it, as it has only one dimension. + + // instr->opernad(3), if exists, is the side input buffer. + if (instr->operand_count() == 4) { + if (kind != CudnnConvKind::kForwardActivation) { + return InternalError( + "Invalid convolution. Conv has a side input, but kind is not fused " + "conv forward: %s", + instr->ToString()); + } + // The side input layout must match the output layout. + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(*output_shape, instr, 3)); } return Status::OK(); } @@ -173,8 +174,8 @@ Status GpuLayoutAssignment::AddBackendConstraints( ++iterator) { HloInstruction* instruction = *iterator; if (IsCustomCallToDnnConvolution(*instruction)) { - TF_RETURN_IF_ERROR( - AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); + TF_RETURN_IF_ERROR(AddBackendConstraintsToDnnConvCustomCall( + Cast<HloCustomCallInstruction>(instruction), constraints)); } // For batched dot we require the default layout. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index ce24af1cf8..4ba7989e9c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -29,8 +30,11 @@ namespace gpu { class GpuLayoutAssignment : public LayoutAssignment { public: explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func, se::StreamExecutor* stream_executor) - : LayoutAssignment(entry_computation_layout), + : LayoutAssignment(entry_computation_layout, + std::move(instruction_can_change_layout_func)), stream_executor_(stream_executor) {} ~GpuLayoutAssignment() override {} @@ -47,7 +51,7 @@ class GpuLayoutAssignment : public LayoutAssignment { private: Status AddBackendConstraintsToDnnConvCustomCall( - HloInstruction* instr, LayoutConstraints* constraints); + HloCustomCallInstruction* instr, LayoutConstraints* constraints); se::StreamExecutor* stream_executor_; }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index fbc8ddf599..04681cfcec 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -75,7 +75,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) { ShapeLayout(result_shape_with_layout); GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); for (const HloInstruction* operand : add->operands()) { @@ -163,7 +164,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { } GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -233,7 +235,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { } GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -314,7 +317,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { } GpuLayoutAssignment layout_assignment( - &computation_layout, backend().default_stream_executor()); + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first and fourth operands to the batchnorm call should have the @@ -348,8 +352,9 @@ TEST_F(LayoutAssignmentTest, DotLayout) { ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); - GpuLayoutAssignment layout_assignment(&computation_layout, - backend().default_stream_executor()); + GpuLayoutAssignment layout_assignment( + &computation_layout, LayoutAssignment::InstructionCanChangeLayout, + 
backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); Shape expected_shape = diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 4d5d8e99f8..b61f038739 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -125,8 +126,8 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) { } // Compute the precise number of operands to the new fusion. - tensorflow::gtl::FlatSet<const HloInstruction*> operands( - a->operands().begin(), a->operands().end()); + absl::flat_hash_set<const HloInstruction*> operands(a->operands().begin(), + a->operands().end()); operands.insert(b->operands().begin(), b->operands().end()); // If there's an edge between `a` and `b`, don't count it: We're fusing that // producer -> consumer relationship. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 22f43bc08b..ec3d8f9405 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -129,6 +129,8 @@ const char* const kCudnnConvBackwardInputCallTarget = "__cudnn$convBackwardInput"; const char* const kCudnnConvBackwardFilterCallTarget = "__cudnn$convBackwardFilter"; +const char* const kCudnnConvBiasActivationForwardCallTarget = + "__cudnn$convBiasActivationForward"; bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { if (hlo.opcode() != HloOpcode::kCustomCall) { @@ -137,7 +139,8 @@ bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { const auto& target = hlo.custom_call_target(); return target == kCudnnConvForwardCallTarget || target == kCudnnConvBackwardInputCallTarget || - target == kCudnnConvBackwardFilterCallTarget; + target == kCudnnConvBackwardFilterCallTarget || + target == kCudnnConvBiasActivationForwardCallTarget; } bool ImplementedAsLibraryCall(const HloInstruction& hlo) { @@ -145,59 +148,6 @@ bool ImplementedAsLibraryCall(const HloInstruction& hlo) { IsCustomCallToDnnConvolution(hlo); } -static HloInstruction* CreateCudnnConv(const char* call_target, - const Shape& shape, HloInstruction* lhs, - HloInstruction* rhs, - const Window& window, - const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count) { - HloComputation* computation = lhs->parent(); - - // This call returns a tuple of (conv_result, scratch_memory), where - // conv_result is the actual result of the convolution, and scratch_memory is - // temporary memory used by cudnn. - // - // At the moment, we don't know how much scratch memory this conv is going to - // use, so we put u8[0] in this place. Later on another pass will choose - // which conv algorithm to use, and at that point we'll modify the shape of - // this second tuple element. 
- Shape call_shape = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); - - HloInstruction* custom_call = computation->AddInstruction( - HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); - custom_call->set_window(window); - custom_call->set_convolution_dimension_numbers(dnums); - custom_call->set_feature_group_count(feature_group_count); - return custom_call; -} - -HloInstruction* CreateCudnnConvForward(const Shape& shape, - HloInstruction* input, - HloInstruction* kernel, - const Window& window, - const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count) { - return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel, - window, dnums, feature_group_count); -} - -HloInstruction* CreateCudnnConvBackwardInput( - const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, - const Window& window, const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count) { - return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output, - reverse_filter, window, dnums, feature_group_count); -} - -HloInstruction* CreateCudnnConvBackwardFilter( - const Shape& shape, HloInstruction* input, HloInstruction* output, - const Window& window, const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count) { - return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input, - output, window, dnums, feature_group_count); -} - bool IsReductionToVector(const HloInstruction& reduce) { if (HloOpcode::kReduce != reduce.opcode()) { return false; @@ -288,41 +238,35 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, value->getType()); } -Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call, - CudnnConvParams* params) { - TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, - custom_call->backend_config<CudnnConvBackendConfig>()); - const auto& target = custom_call->custom_call_target(); - const auto& lhs_shape = custom_call->operand(0)->shape(); - const auto& rhs_shape = custom_call->operand(1)->shape(); - const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); - - params->window = &custom_call->window(); - params->dnums = &custom_call->convolution_dimension_numbers(); - params->feature_group_count = custom_call->feature_group_count(); - params->algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc( - backend_config.algorithm(), backend_config.tensor_ops_enabled())); - +StatusOr<CudnnConvKind> GetCudnnConvKind( + const HloCustomCallInstruction* instr) { + absl::string_view target = instr->custom_call_target(); if (target == kCudnnConvForwardCallTarget) { - params->kind = CudnnConvKind::kForward; - params->input_shape = &lhs_shape; - params->filter_shape = &rhs_shape; - params->output_shape = &conv_result_shape; - } else if (target == kCudnnConvBackwardInputCallTarget) { - params->kind = CudnnConvKind::kBackwardInput; - params->input_shape = &conv_result_shape; - params->filter_shape = &rhs_shape; - params->output_shape = &lhs_shape; - } else if (target == kCudnnConvBackwardFilterCallTarget) { - params->kind = CudnnConvKind::kBackwardFilter; - params->input_shape = &lhs_shape; - params->filter_shape = &conv_result_shape; - params->output_shape = &rhs_shape; - } else { - LOG(FATAL) << "Unexpected custom call target: " - << custom_call->custom_call_target(); + return CudnnConvKind::kForward; + } + if (target == kCudnnConvBackwardInputCallTarget) { + return CudnnConvKind::kBackwardInput; + } + if (target == 
kCudnnConvBackwardFilterCallTarget) { + return CudnnConvKind::kBackwardFilter; + } + if (target == kCudnnConvBiasActivationForwardCallTarget) { + return CudnnConvKind::kForwardActivation; + } + return InternalError("Unexpected call target: %s", target); +} + +string CudnnConvKindToString(CudnnConvKind kind) { + switch (kind) { + case CudnnConvKind::kForward: + return "forward"; + case CudnnConvKind::kBackwardFilter: + return "backward_filter"; + case CudnnConvKind::kBackwardInput: + return "backward_input"; + case CudnnConvKind::kForwardActivation: + return "forward with activation"; } - return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 09c455cc1e..a64a616ab1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -20,7 +20,6 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" -#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" @@ -30,6 +29,33 @@ limitations under the License. namespace xla { namespace gpu { +// Different types of convolutions supported by cudnn. +// +// A way to think about these is that a convolution is defined by three arrays +// -- the "input", the "filter", and the "output" -- and given any two of these, +// we can compute the third. For example, a backward-input convolution takes as +// input a filter and an "output" and produces an "input" such that if one were +// to do a forward convolution of "input" using filter, the result would be +// something with the same shape as "output". +// +// This way of thinking is not correct if you look at the values produced. For +// example, a backward-input convolution is not actually the mathematical +// inverse of a forward convolution. But it's right as far as the shapes and +// "connectivity" (i.e. which elements of the input affect which elements of +// the output) are concerned. +enum class CudnnConvKind { + kForward, // input + filter => output + kBackwardInput, // filter + output => input + kBackwardFilter, // input + output => filter + kForwardActivation, // activation(conv(input, filter) + broadcast(bias) + + // (optionally) side_input) => output +}; + +StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr); + +// Converts a CudnnConvKind value to a string. +string CudnnConvKindToString(CudnnConvKind kind); + constexpr int64 kWarpSize = 32; // Returns true if `hlo` will be implemented as a call to BLAS gemm. @@ -95,6 +121,7 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); extern const char* const kCudnnConvForwardCallTarget; extern const char* const kCudnnConvBackwardInputCallTarget; extern const char* const kCudnnConvBackwardFilterCallTarget; +extern const char* const kCudnnConvBiasActivationForwardCallTarget; // Returns true if `hlo` will be implemented as a call to a cuDNN convolution // routine. @@ -104,28 +131,6 @@ extern const char* const kCudnnConvBackwardFilterCallTarget; // kConvolution opcode. bool IsCustomCallToDnnConvolution(const HloInstruction& hlo); -// Creates a CustomCall for a cudnn forward/backward-input/backward-filter conv. -// Note that these CustomCalls return a tuple (conv_result, scratch_memory). 
If -// you want just the conv result, you'll need to get-tuple-element the value -// returned by this function. -// -// The created cudnn call will use the default cudnn algorithm and no scratch -// space. -HloInstruction* CreateCudnnConvForward(const Shape& shape, - HloInstruction* input, - HloInstruction* kernel, - const Window& window, - const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count); -HloInstruction* CreateCudnnConvBackwardInput( - const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, - const Window& window, const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count); -HloInstruction* CreateCudnnConvBackwardFilter( - const Shape& shape, HloInstruction* input, HloInstruction* output, - const Window& window, const ConvolutionDimensionNumbers& dnums, - int64 feature_group_count); - // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm // or cuDNN convolution. bool ImplementedAsLibraryCall(const HloInstruction& hlo); @@ -150,11 +155,6 @@ llvm::Value* EmitPrintf(absl::string_view fmt, llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); -// Populates params using conv, which must be a custom-call to a cudnn -// convolution. Does not modify any buffers in the params. -Status PopulateCudnnConvParams(const HloCustomCallInstruction* custom_call, - CudnnConvParams* params); - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index b669881026..c792dd2ddb 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -465,35 +465,18 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { if (IsCustomCallToDnnConvolution(*custom_call)) { const auto& assn = ir_emitter_context_->buffer_assignment(); - auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); - auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); + std::vector<BufferAllocation::Slice> operand_slices; + operand_slices.reserve(custom_call->operand_count()); + for (const auto* operand : custom_call->operands()) { + operand_slices.push_back(GetAllocationSlice(*operand)); + } auto tuple_result_slice = GetAllocationSlice(*custom_call); auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - const auto& target = custom_call->custom_call_target(); - BufferAllocation::Slice input_slice, filter_slice, output_slice; - - if (target == kCudnnConvForwardCallTarget) { - input_slice = lhs_slice; - filter_slice = rhs_slice; - output_slice = conv_result_slice; - } else if (target == kCudnnConvBackwardInputCallTarget) { - input_slice = conv_result_slice; - filter_slice = rhs_slice; - output_slice = lhs_slice; - } else if (target == kCudnnConvBackwardFilterCallTarget) { - input_slice = lhs_slice; - filter_slice = conv_result_slice; - output_slice = rhs_slice; - } else { - LOG(FATAL) << "Unexpected custom call target: " - << custom_call->custom_call_target(); - } - thunk_sequence_->emplace_back(absl::make_unique<ConvolutionThunk>( - Cast<HloCustomCallInstruction>(custom_call), input_slice, filter_slice, - output_slice, scratch_slice, tuple_result_slice)); + Cast<HloCustomCallInstruction>(custom_call), std::move(operand_slices), + conv_result_slice, scratch_slice, tuple_result_slice)); return 
Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc index c21f76f6eb..835924024b 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -24,6 +24,7 @@ limitations under the License. #include <utility> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -101,7 +101,7 @@ bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, HloInstruction* instr2) { - tensorflow::gtl::FlatSet<HloInstruction*> in_list; + absl::flat_hash_set<HloInstruction*> in_list; for (auto instr : instr1->operands()) { if (!IsProfitableOperand(instr)) { continue; @@ -148,7 +148,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() { bool changed = false; RecomputeReachability(); - tensorflow::gtl::FlatSet<HloInstruction*> to_fuse; + absl::flat_hash_set<HloInstruction*> to_fuse; // Keep a list of the instructions to fuse after making all the fusion // decisions. We first aggressively add instructions to potential_fusion_list, // then filter out instructions that will be no longer fusible because of diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index dfdcf1875d..ac6c2c5565 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" @@ -208,6 +209,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); pipeline.AddPass<CudnnConvolutionRewriter>(); + pipeline.AddPass<CudnnFusedConvolutionRewriter>(); pipeline.AddPass<PadInsertion>(); if (IsVoltaOrLater(*stream_exec)) { pipeline.AddPass<PadForTensorCores>(); @@ -230,14 +232,17 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // a layout-sensitive verifier! 
HloPassPipeline pipeline("layout assignment"); pipeline.AddPass<GpuLayoutAssignment>( - hlo_module->mutable_entry_computation_layout(), stream_exec); + hlo_module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout, stream_exec); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } { HloPassPipeline pipeline("post-layout_assignment"); - pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + pipeline.AddInvariantChecker<HloVerifier>( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. @@ -283,8 +288,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, { HloPassFix<HloPassPipeline> fusion("fusion"); - fusion.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + fusion.AddInvariantChecker<HloVerifier>( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false); fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true); fusion.AddPass<FusionMerger>(); @@ -296,7 +303,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, HloPassPipeline reduce_pipeline("reduce-precision"); reduce_pipeline.AddInvariantChecker<HloVerifier>( - /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false); + /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -322,8 +330,10 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true, - /*allow_mixed_precision=*/false); + pipeline.AddInvariantChecker<HloVerifier>( + /*layout_sensitive=*/true, + /*allow_mixed_precision=*/false, + LayoutAssignment::InstructionCanChangeLayout); // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -398,11 +408,11 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { "prefers >= 9.2.88). Compilation of XLA kernels below will likely " "fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas " "binary is sufficient."; - } else if ((vmaj < 9 || vmin < 2 || vdot < 88)) { + } else if (std::make_tuple(vmaj, vmin, vdot) < std::make_tuple(9, 2, 88)) { LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." << vdot - << ", which older than 9.2.88. ptxas 9.x before 9.2.88 is known to " + << ", which is older than 9.2.88. 
ptxas 9.x before 9.2.88 is known to " "miscompile XLA code, leading to incorrect results or " "invalid-address errors.\n\nYou do not need to update to CUDA " "9.2.88; cherry-picking the ptxas binary is sufficient."; diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h index 8e97774750..c4a0b727cd 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h @@ -20,6 +20,7 @@ limitations under the License. #include <string> #include <vector> +#include "absl/container/node_hash_map.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/executable.h" @@ -140,10 +141,10 @@ class NVPTXCompiler : public LLVMCompiler { tensorflow::condition_variable compilation_done_cv_; }; - // Don't even think about switching this to FlatMap; iterator stability is - // critical here. - std::unordered_map<CompilationCacheKey, CompilationCacheValue, - CompilationCacheHash, CompilationCacheEq> + // Don't even think about switching this to flat_hash_map; iterator stability + // is critical here. + absl::node_hash_map<CompilationCacheKey, CompilationCacheValue, + CompilationCacheHash, CompilationCacheEq> compilation_cache_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(NVPTXCompiler); diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc index b0061fa655..e3869b5c36 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -36,15 +37,32 @@ static constexpr int64 kDesiredNumFeaturesFactor = 8; // there's additional room for speedups. Achieving those speedups without also // slowing other things down will likely require a more sophisticated heuristic, // possibly some form of auto-tuning. -static constexpr double kMaxBytesTouchedIncrease = 1.2; +// +// This value should be >= 4/3, otherwise the "dims of size 3 padded up to 4" +// special case inside PadShape won't fire. +static constexpr double kMaxBytesTouchedIncrease = 1.35; // Pads the given dimensions in the given shape up to a multiple of // kDesiredNumFeaturesFactor. static Shape PadShape(Shape s, absl::Span<const int64> dims) { for (int64 dim : dims) { int64 dim_to_pad_size = s.dimensions(dim); - int64 new_dim_to_pad_size = - RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor); + + // Round dim_to_pad_size up to the next multiple of + // kDesiredNumFeaturesFactor. + // + // Special case: dims of size 3 are rounded up to 4, not + // kDesiredNumFeaturesFactor. Empirically (and on the advice of nvidia), + // this helps, but as of writing, it's not supported by anything in the + // cudnn docs. 
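Concretely, the rounding rule described above behaves as in this standalone sketch (it assumes kDesiredNumFeaturesFactor is 8, matching the constant defined earlier in this file; the helper name is illustrative):

#include <cstdint>

// Dimension sizes of 3 are padded to 4; everything else is rounded up to the
// next multiple of 8. For example: 3 -> 4, 5 -> 8, 8 -> 8, 17 -> 24.
int64_t PadDimForTensorCores(int64_t dim_size) {
  if (dim_size == 3) {
    return 4;
  }
  constexpr int64_t kFactor = 8;  // kDesiredNumFeaturesFactor
  return (dim_size + kFactor - 1) / kFactor * kFactor;
}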
+ int64 new_dim_to_pad_size; + if (dim_to_pad_size == 3) { + new_dim_to_pad_size = 4; + } else { + new_dim_to_pad_size = + RoundUpToNearest(dim_to_pad_size, kDesiredNumFeaturesFactor); + } + s.set_dimensions(dim, new_dim_to_pad_size); } return s; @@ -209,7 +227,11 @@ static std::vector<HloInstruction*> GetRelevantConvs(HloComputation* comp) { std::vector<HloInstruction*> convs; for (HloInstruction* instr : comp->instructions()) { if (IsCustomCallToDnnConvolution(*instr) && - instr->operand(0)->shape().element_type() == F16) { + instr->operand(0)->shape().element_type() == F16 && + // TODO(timshen): Disable for fused conv for now. Implement it if it's + // needed. + Cast<HloCustomCallInstruction>(instr)->custom_call_target() != + kCudnnConvBiasActivationForwardCallTarget) { convs.push_back(instr); } } diff --git a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h index 11dc56a64f..e592a3774e 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h +++ b/tensorflow/compiler/xla/service/gpu/pad_for_tensor_cores.h @@ -30,7 +30,7 @@ namespace gpu { // targeting before running this pass. // // TODO(jlebar): Also pad dots. -class PadForTensorCores : public HloPassInterface { +class PadForTensorCores : public HloModulePass { public: absl::string_view name() const override { return "pad for tensor cores"; } diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 2a6415d0b6..b42a19e3a2 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -30,7 +30,8 @@ namespace gpu { namespace { bool IsForwardConvolutionCanonical(const HloInstruction& conv) { - CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget); + CHECK(conv.custom_call_target() == kCudnnConvForwardCallTarget || + conv.custom_call_target() == kCudnnConvBiasActivationForwardCallTarget); return window_util::HasSymmetricPadding(conv.window()) && !window_util::HasNegativePadding(conv.window()) && !window_util::HasDilation(conv.window()); @@ -161,12 +162,14 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract // out the shape of conv_result. - Shape old_conv_shape = conv->shape().tuple_shapes(0); - VLOG(1) << "Canonicalizing forward conv"; - auto new_conv = CreateCudnnConvForward( - old_conv_shape, new_input, new_kernel, new_conv_window, - conv->convolution_dimension_numbers(), conv->feature_group_count()); + std::vector<HloInstruction*> operands(conv->operands().begin(), + conv->operands().end()); + operands[0] = new_input; + operands[1] = new_kernel; + auto new_conv = conv->parent()->AddInstruction( + conv->CloneWithNewOperands(conv->shape(), operands)); + new_conv->set_window(new_conv_window); VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " << new_conv->ToString(); TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv)); @@ -242,10 +245,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // The shape of the backward_conv CustomCall is a tuple (conv_result, // scratch_buffer). Extract out the shape of conv_result. 
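For reference, pulling the actual convolution result out of such a (conv_result, scratch) tuple uses the usual get-tuple-element pattern; a sketch under the assumption that `conv` is the tuple-shaped custom call and `computation` its parent (this helper is illustrative, not a line from the diff):

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

// Element 0 of the custom call's tuple shape is the convolution result;
// element 1 is cudnn's scratch buffer.
xla::HloInstruction* ExtractConvResult(xla::HloComputation* computation,
                                       xla::HloInstruction* conv) {
  const xla::Shape& result_shape = conv->shape().tuple_shapes(0);
  return computation->AddInstruction(
      xla::HloInstruction::CreateGetTupleElement(result_shape, conv, 0));
}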
- Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); - HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter( - backward_conv_shape, padded_input, output, new_backward_conv_window, - backward_conv_dnums, backward_conv->feature_group_count()); + HloInstruction* new_backward_conv = + computation->AddInstruction(backward_conv->CloneWithNewOperands( + backward_conv->shape(), {padded_input, output})); + new_backward_conv->set_window(new_backward_conv_window); VLOG(1) << "Canonicalizing backward filter conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " @@ -308,9 +311,12 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( HloInstruction* output = backward_conv->mutable_operand(0); HloInstruction* filter = backward_conv->mutable_operand(1); - HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput( - new_backward_conv_shape, output, filter, new_backward_conv_window, - backward_conv_dnums, backward_conv->feature_group_count()); + HloInstruction* new_backward_conv_call = + computation->AddInstruction(backward_conv->CloneWithNewOperands( + ShapeUtil::MakeTupleShape( + {new_backward_conv_shape, ShapeUtil::MakeShape(U8, {0})}), + {output, filter})); + new_backward_conv_call->set_window(new_backward_conv_window); // The CustomCall created above returns a tuple (conv_result, scratch_memory). // Extract out the two elements. @@ -380,7 +386,8 @@ StatusOr<bool> PadInsertion::RunOnComputation(HloComputation* computation) { } for (HloInstruction* instruction : convs) { const auto& target = instruction->custom_call_target(); - if (target == kCudnnConvForwardCallTarget) { + if (target == kCudnnConvForwardCallTarget || + target == kCudnnConvBiasActivationForwardCallTarget) { changed |= CanonicalizeForwardConvolution(instruction); } else if (target == kCudnnConvBackwardFilterCallTarget) { changed |= CanonicalizeBackwardFilterConvolution(instruction); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h index a622e894ed..25cdf64c4c 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h @@ -24,7 +24,7 @@ namespace gpu { // An HLO pass that canonicalizes convolution instructions for GPU codegen. It // inserts Pad instructions before Convolution instructions with uncanonicalized // padding, so that they can be lowered to cuDNN convolution. 
-class PadInsertion : public HloPassInterface { +class PadInsertion : public HloModulePass { public: absl::string_view name() const override { return "pad insertion"; } diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index cf9f102d31..375f68a159 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -62,13 +62,8 @@ LaunchDimensions CalculateLaunchDimensions( // // <num threads per block> * <max blocks per core> = <max threads per core> - auto threads_per_core = device_desc.threads_per_core_limit(); - auto blocks_per_core = device_desc.blocks_per_core_limit(); - int64 threads_per_block; - if (threads_per_core != 0 && blocks_per_core != 0) { - threads_per_block = device_desc.threads_per_core_limit() / - device_desc.blocks_per_core_limit(); - } else { + int64 threads_per_block = device_desc.threads_per_block_limit(); + if (threads_per_block == 0) { static std::atomic<int64> log_count{0}; if (log_count.fetch_add(1) < 8) { LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.h b/tensorflow/compiler/xla/service/gpu/stream_assignment.h index c2df83aaa4..52d38b6f20 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_ +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace gpu { @@ -34,7 +34,7 @@ class StreamAssignment { private: int stream_count_ = 1; // At least the main stream. - tensorflow::gtl::FlatMap<const HloInstruction*, int> hlo_to_stream_number_; + absl::flat_hash_map<const HloInstruction*, int> hlo_to_stream_number_; }; // Assigns GPU streams to instructions in `module`. 
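The container change in this header mirrors the gtl-to-absl migration made throughout this diff; a minimal sketch of how such a map is queried (a string key stands in for const HloInstruction*, and falling back to stream 0 for unassigned instructions is an assumption of the sketch, not something the header states):

#include <string>
#include "absl/container/flat_hash_map.h"

// absl::flat_hash_map is a drop-in replacement for tensorflow::gtl::FlatMap
// at these call sites: the find/insert/erase surface is the same, and neither
// iterators nor pointers stay valid across a rehash.
int StreamNumberFor(
    const absl::flat_hash_map<std::string, int>& hlo_to_stream_number,
    const std::string& hlo_name) {
  auto it = hlo_to_stream_number.find(hlo_name);
  return it == hlo_to_stream_number.end() ? 0 : it->second;
}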
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index db4a33dc56..a725533567 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -25,15 +25,17 @@ filegroup( ) load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) cc_library( name = "gpu_codegen_test", testonly = True, srcs = ["gpu_codegen_test.cc"], hdrs = ["gpu_codegen_test.h"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:gpu_plugin", @@ -48,9 +50,7 @@ cc_library( tf_cc_test( name = "gpu_copy_test", srcs = ["gpu_copy_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -67,9 +67,7 @@ tf_cc_test( tf_cc_test( name = "gpu_ftz_test", srcs = ["gpu_ftz_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/core:test_main", @@ -79,9 +77,7 @@ tf_cc_test( tf_cc_test( name = "gpu_index_test", srcs = ["gpu_index_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -102,9 +98,7 @@ tf_cc_test( tf_cc_test( name = "gpu_infeed_test", srcs = ["infeed_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -125,9 +119,7 @@ tf_cc_test( tf_cc_test( name = "gpu_kernel_tiling_test", srcs = ["gpu_kernel_tiling_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo", @@ -142,7 +134,7 @@ tf_cc_test( tf_cc_test( name = "gpu_ldg_test", srcs = ["gpu_ldg_test.cc"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -159,9 +151,7 @@ tf_cc_test( tf_cc_test( name = "gpu_noalias_test", srcs = ["gpu_noalias_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla:literal", @@ -178,9 +168,7 @@ tf_cc_test( tf_cc_test( name = "gpu_fusion_test", srcs = ["gpu_fusion_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo_module_config", @@ -194,9 +182,7 @@ tf_cc_test( tf_cc_test( name = "gpu_unrolling_test", srcs = ["gpu_unrolling_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:hlo_module_config", @@ -211,9 +197,7 @@ tf_cc_test( name = "gpu_alignment_test", testonly = True, srcs = ["gpu_alignment_test.cc"], - tags = [ - "requires-gpu-sm35", - ], + tags = tf_cuda_tests_tags(), deps = [ ":gpu_codegen_test", "//tensorflow/compiler/xla/service:gpu_plugin", @@ -225,3 +209,17 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +tf_cc_test( + name = "cudnn_fused_convolution_rewriter_test", + srcs = ["cudnn_fused_convolution_rewriter_test.cc"], + tags = tf_cuda_tests_tags(), + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", + 
"//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc new file mode 100644 index 0000000000..5632cac186 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/cudnn_fused_convolution_rewriter_test.cc @@ -0,0 +1,283 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/strings/str_replace.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class CudnnFusedConvolutionRewriterTest : public HloTestBase { + protected: + string GetOptimizedHlo(absl::string_view hlo_string) { + return backend() + .compiler() + ->RunHloPasses(ParseHloString(hlo_string, GetModuleConfigForTest()) + .ConsumeValueOrDie(), + backend().default_stream_executor(), + backend().memory_allocator()) + .ConsumeValueOrDie() + ->ToString(); + } + + void TestMatchWithAllTypes(absl::string_view hlo_string) { + for (absl::string_view type : {"f16", "f32", "f64"}) { + const string hlo_with_new_type = + absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); + const string optimized_hlo_string = GetOptimizedHlo(hlo_with_new_type); + EXPECT_EQ(absl::string_view::npos, + optimized_hlo_string.find("__cudnn$convForward")) + << optimized_hlo_string; + EXPECT_NE(absl::string_view::npos, + optimized_hlo_string.find("__cudnn$convBiasActivationForward")) + << optimized_hlo_string; + EXPECT_TRUE(RunAndCompare(hlo_with_new_type, ErrorSpec{0.01})) + << optimized_hlo_string; + } + } + + void TestNotMatchWithAllTypes(absl::string_view hlo_string) { + for (absl::string_view type : {"f16", "f32", "f64"}) { + const string hlo_with_new_type = + absl::StrReplaceAll(hlo_string, {{"TYPE", type}}); + string optimized_hlo = GetOptimizedHlo(hlo_with_new_type); + EXPECT_NE(absl::string_view::npos, + optimized_hlo.find("__cudnn$convForward")) + << optimized_hlo; + EXPECT_EQ(absl::string_view::npos, + optimized_hlo.find("__cudnn$convBiasActivationForward")) + << optimized_hlo; + } + } +}; + +TEST_F(CudnnFusedConvolutionRewriterTest, TestConvOnly) { + // max(0, conv(x, w)); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + ROOT relu = TYPE[1,32,9,9] maximum(zeros, conv) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestBias) { + // max(0, conv(x, w) + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + 
zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + bias = TYPE[64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestSideInputOnly) { + // max(0, conv(x, w) + side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + add1 = TYPE[1,3,3,64] add(conv, side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestBiasAndSideInput) { + // max(0, conv(x, w) + side_input + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + bias = TYPE[64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(conv, broadcasted_bias) + add2 = TYPE[1,3,3,64] add(add1, side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConv) { + // max(0, 0.999994934 * conv(x, w)); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,32,9,9] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + alpha_conv = TYPE[1,32,9,9] broadcast(alpha_conv_scalar), dimensions={} + scaled_conv = TYPE[1,32,9,9] multiply(conv, alpha_conv) + ROOT relu = TYPE[1,32,9,9] maximum(zeros, scaled_conv) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndSideInput) { + // max(0, conv(x, w) + 0.899994934 * side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + add1 = TYPE[1,3,3,64] add(conv, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestScaledConvAndScaledSideInput) { + // max(0, 0.999994934 
* conv(x, w) + 0.899994934 * side_input); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + add1 = TYPE[1,3,3,64] add(scaled_conv, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add1) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, + TestScaledConvAndScaledSideInputWithBias) { + // max(0, 0.999994934 * conv(x, w) + 0.899994934 * side_input + bias); + TestMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + alpha_conv_scalar = TYPE[] constant(0.999994934) + alpha_conv = TYPE[1,3,3,64] broadcast(alpha_conv_scalar), dimensions={} + alpha_side_input_scalar = TYPE[] constant(0.899994934) + alpha_side_input = TYPE[1,3,3,64] broadcast(alpha_side_input_scalar), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input = TYPE[1,3,3,64] parameter(2) + bias = TYPE[64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + scaled_conv = TYPE[1,3,3,64] multiply(conv, alpha_conv) + scaled_side_input = TYPE[1,3,3,64] multiply(side_input, alpha_side_input) + broadcasted_bias = TYPE[1,3,3,64] broadcast(bias), dimensions={3} + add1 = TYPE[1,3,3,64] add(scaled_conv, broadcasted_bias) + add2 = TYPE[1,3,3,64] add(add1, scaled_side_input) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchMaxZeroOnly) { + // max(0.1, conv(x, w)) shouldn't match. + TestNotMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + point_one = TYPE[] constant(0.1) + point_ones = TYPE[1,32,9,9] broadcast(point_one), dimensions={} + + input = TYPE[1,17,9,9] parameter(0) + filter = TYPE[3,3,17,32] parameter(1) + + conv = TYPE[1,32,9,9] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, feature_group_count=1 + ROOT relu = TYPE[1,32,9,9] maximum(point_ones, conv) + })"); +} + +TEST_F(CudnnFusedConvolutionRewriterTest, TestMatchBroadcastedBiasOnly) { + // max(0, conv(x, w) + side_input1 + side_input2) shouldn't match. 
+ TestNotMatchWithAllTypes(R"( + HloModule Test + + ENTRY Test { + zero = TYPE[] constant(0) + zeros = TYPE[1,3,3,64] broadcast(zero), dimensions={} + + input = TYPE[1,3,3,64] parameter(0) + filter = TYPE[3,3,64,64] parameter(1) + side_input1 = TYPE[1,3,3,64] parameter(2) + side_input2 = TYPE[1,3,3,64] parameter(3) + + conv = TYPE[1,3,3,64] convolution(input, filter), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, feature_group_count=1 + add1 = TYPE[1,3,3,64] add(conv, side_input2) + add2 = TYPE[1,3,3,64] add(add1, side_input1) + ROOT relu = TYPE[1,3,3,64] maximum(zeros, add2) + })"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index e0f3a7e0e2..9220865867 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -18,14 +18,16 @@ limitations under the License. #include <algorithm> #include <vector> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/util.h" namespace xla { -using tensorflow::gtl::FlatMap; -using tensorflow::gtl::FlatSet; +using absl::flat_hash_map; +using absl::flat_hash_set; /*static*/ StatusOr<int64> HeapSimulator::MinimumMemoryForModule( @@ -56,7 +58,7 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>* + const absl::flat_hash_map<const HloComputation*, int64>* memory_by_computation) { TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, @@ -88,7 +90,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run( const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>* + const absl::flat_hash_map<const HloComputation*, int64>* memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, /*schedule=*/nullptr, memory_by_computation); @@ -115,8 +117,10 @@ Status HeapSimulator::RunComputation( // 'used_buffers' is the reverse map - it tracks which buffers were used by an // instruction, so that we can remove the instructions from a buffer's live // set after they are visited. 
- FlatMap<const BufferValue*, FlatSet<const HloInstruction*>> live_buffers; - FlatMap<const HloInstruction*, FlatSet<const BufferValue*>> used_buffers; + flat_hash_map<const BufferValue*, flat_hash_set<const HloInstruction*>> + live_buffers; + flat_hash_map<const HloInstruction*, flat_hash_set<const BufferValue*>> + used_buffers; auto add_user_to_buffer = [this, &live_buffers, &used_buffers]( const HloInstruction* user, const BufferValue* buffer) { @@ -213,7 +217,7 @@ Status HeapSimulator::RunComputation( VLOG(4) << " Removing user " << instruction->name() << " from buffer " << operand_buffer->ToString(); auto it = live_buffers.find(operand_buffer); - FlatSet<const HloInstruction*>* live_set = &it->second; + flat_hash_set<const HloInstruction*>* live_set = &it->second; live_set->erase(instruction); if (live_set->empty()) { live_buffers.erase(it); @@ -235,7 +239,8 @@ Status HeapSimulator::RunComputation( // that we should assign. // Make sure each buffer get reused at most once. - FlatSet<const BufferValue*> reused_buffers; + flat_hash_set<const BufferValue*> reused_buffers; + int64 alloc_size_by_instruction = 0; for (const BufferValue* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; @@ -268,14 +273,15 @@ Status HeapSimulator::RunComputation( if (!shared) { VLOG(3) << " Allocating: " << buffer->ToString(); + alloc_size_by_instruction += size_fn_(*buffer); Alloc(buffer, instruction); } } // Account for the memory used by subcomputations when estimating the // current heap size. if (memory_by_computation_ != nullptr) { - algorithm_->AccountForSubcomputationMemory(instruction, - *memory_by_computation_); + algorithm_->AccountForSubcomputationMemory( + instruction, alloc_size_by_instruction, *memory_by_computation_); } // If all computations in the module have been scheduled, we can save memory @@ -323,7 +329,7 @@ Status HeapSimulator::RunComputation( to_free.reserve(live_buffers.size()); for (const auto& buffer_pending : live_buffers) { const BufferValue* buffer = buffer_pending.first; - const FlatSet<const HloInstruction*>& pending = buffer_pending.second; + const flat_hash_set<const HloInstruction*>& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; to_free.push_back(buffer); @@ -345,7 +351,7 @@ HeapSimulator::HeapSimulator( std::unique_ptr<HeapAlgorithm> algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, const HloSchedule* schedule, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>* + const absl::flat_hash_map<const HloComputation*, int64>* memory_by_computation) : no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()), algorithm_(std::move(algorithm)), @@ -381,10 +387,8 @@ void HeapSimulator::Alloc(const BufferValue* buffer, allocated_buffers_.insert(buffer); const int64 size = size_fn_(*buffer); - const HloInstruction* instruction_to_calc_aliasing = - memory_by_computation_ == nullptr ? 
nullptr : instruction; - algorithm_->Alloc(buffer, size, instruction_to_calc_aliasing); - no_fragmentation_stats_->Alloc(buffer, size, instruction_to_calc_aliasing); + algorithm_->Alloc(buffer, size); + no_fragmentation_stats_->Alloc(buffer, size); FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction, nullptr); } @@ -522,21 +526,9 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { } } -void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size, - const HloInstruction* instruction) { - // The output buffer of while/call/conditional is always aliased with the - // output buffer of the root instruction in the body. Don't double count. - if (instruction == nullptr || - (instruction->opcode() != HloOpcode::kWhile && - instruction->opcode() != HloOpcode::kCall && - instruction->opcode() != HloOpcode::kConditional)) { - Alloc(buffer, size); - } -} - void NoFragmentationStatsHeap::AccountForSubcomputationMemory( - const HloInstruction* instruction, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const HloInstruction* instruction, int64 alloc_size_by_instruction, + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation) { // We only count the memory usage of the largest subcomputation, instead of // adding them all, because subcomputations won't execute in parallel. @@ -550,6 +542,14 @@ void NoFragmentationStatsHeap::AccountForSubcomputationMemory( } } } + if (max_subcomputation_bytes > 0 && + (instruction->opcode() == HloOpcode::kWhile || + instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kConditional)) { + // The output buffer of while/call/conditional is always aliased with the + // output buffer of the root instruction in the body. Don't double count. + max_subcomputation_bytes -= alloc_size_by_instruction; + } max_heap_size_ = std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes); } @@ -736,4 +736,209 @@ HeapSimulator::Result LazyBestFitHeap::Finish() { return result_; } +void GlobalDecreasingSizeBestFitHeap::Alloc(const BufferValue* buffer, + int64 size) { + // Degenerate case: 0-sized buffers are always allocated at offset 0. + if (size == 0) { + result_.chunk_map.emplace(buffer, Chunk{0, 0}); + return; + } + auto emplace_result = buffer_intervals_.emplace( + buffer, BufferInterval{buffer, size, current_time_, -1}); + DCHECK(emplace_result.second); + ++current_time_; +} + +void GlobalDecreasingSizeBestFitHeap::Free(const BufferValue* buffer, + int64 size) { + // Degenerate case: 0-sized buffers are always allocated at offset 0. + if (size == 0) { + return; + } + BufferInterval& buffer_interval = FindOrDie(buffer_intervals_, buffer); + DCHECK_EQ(buffer_interval.buffer, buffer); + DCHECK_EQ(buffer_interval.size, size); + DCHECK_EQ(buffer_interval.end, -1); + buffer_interval.end = current_time_; + ++current_time_; +} + +namespace { + +// Node in BufferIntervalTree that stores the alloc and free times of a buffer, +// and the chunk assigned to it. +struct BufferIntervalTreeNode { + // Alloc time. + int64 start; + // Free time. + int64 end; + // Maximum free time of all nodes in the subtree where this node is the root. + int64 subtree_end; + // Allocated chunk for the buffer. + HeapSimulator::Chunk chunk; + // Left child. + BufferIntervalTreeNode* left; + // Right child. + BufferIntervalTreeNode* right; +}; + +// An interval tree that can query buffers overlapping in time. 
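// A minimal usage sketch of the interface declared below (illustrative only,
// not part of this change):
//
//   BufferIntervalTree tree(/*capacity=*/2);
//   tree.Add(/*start=*/0, /*end=*/5,
//            HeapSimulator::Chunk{/*offset=*/0, /*size=*/16});
//   tree.Add(/*start=*/2, /*end=*/8,
//            HeapSimulator::Chunk{/*offset=*/16, /*size=*/32});
//   // Both intervals overlap [3, 4], so both chunks are returned.
//   std::vector<HeapSimulator::Chunk> live = tree.ChunksOverlappingInTime(3, 4);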
+class BufferIntervalTree { + public: + explicit BufferIntervalTree(int capacity) : node_storage_(capacity) {} + + using Chunk = HeapSimulator::Chunk; + + // Adds a buffer to the interval tree, with the time interval and allocated + // chunk specified. + void Add(int64 start, int64 end, const Chunk& chunk) { + int index = node_count_; + DCHECK_LT(index, node_storage_.size()); + ++node_count_; + + node_storage_[index] = + BufferIntervalTreeNode{start, end, end, chunk, nullptr, nullptr}; + + if (index == 0) { + // This is root. + return; + } + + BufferIntervalTreeNode* parent = &node_storage_[0]; + while (true) { + parent->subtree_end = std::max(parent->subtree_end, end); + if (parent->start > start) { + if (parent->left == nullptr) { + parent->left = &node_storage_[index]; + return; + } + parent = parent->left; + } else { + if (parent->right == nullptr) { + parent->right = &node_storage_[index]; + return; + } + parent = parent->right; + } + } + } + + // Returns vector of allocated chunks that overlap with the given time + // interval. + std::vector<Chunk> ChunksOverlappingInTime(int64 start, int64 end) { + std::vector<Chunk> result; + if (node_count_ == 0) { + return result; + } + std::vector<BufferIntervalTreeNode*> visiting_stack; + visiting_stack.push_back(&node_storage_[0]); + while (!visiting_stack.empty()) { + BufferIntervalTreeNode* top = visiting_stack.back(); + visiting_stack.pop_back(); + if (start > top->subtree_end) { + continue; + } + if (top->left != nullptr) { + visiting_stack.push_back(top->left); + } + if (top->start <= end && top->end >= start) { + result.push_back(top->chunk); + } + if (end < top->start) { + continue; + } + if (top->right != nullptr) { + visiting_stack.push_back(top->right); + } + } + return result; + } + + private: + int64 node_count_ = 0; + std::vector<BufferIntervalTreeNode> node_storage_; +}; + +} // namespace + +HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() { + std::vector<BufferInterval> sorted_buffer_intervals; + for (auto& entry : buffer_intervals_) { + sorted_buffer_intervals.push_back(entry.second); + } + std::sort(sorted_buffer_intervals.begin(), sorted_buffer_intervals.end(), + [](const BufferInterval& x, const BufferInterval& y) { + if (x.size != y.size) { + return x.size > y.size; + } + if (x.end - x.start != y.end - y.start) { + return x.end - x.start > y.end - y.start; + } + return x.buffer->id() < y.buffer->id(); + }); + + BufferIntervalTree interval_tree(sorted_buffer_intervals.size()); + for (auto& buffer_interval : sorted_buffer_intervals) { + auto chunks_overlapping_in_time = interval_tree.ChunksOverlappingInTime( + buffer_interval.start, buffer_interval.end); + std::sort( + chunks_overlapping_in_time.begin(), chunks_overlapping_in_time.end(), + [](const Chunk& x, const Chunk& y) { return x.offset < y.offset; }); + + // Find the minimum free chunk that can hold this buffer. 
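// Worked example of the best-fit search below (illustrative, assuming
// alignment_ == 1): suppose the chunks already live during this buffer's
// interval occupy offsets [0, 10) and [30, 40), and result_.heap_size is 50.
// The candidate free chunks are then [10, 30) and [40, 50). For a buffer of
// size 15, the smallest fitting free chunk is [10, 30), so the buffer is
// placed at offset 10; only if no free chunk fits is the heap grown.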
+ Chunk min_fit_chunk{-1, INT64_MAX}; + auto use_free_chunk_if_smaller = [&](int64 free_offset, int64 free_size) { + if (free_size < buffer_interval.size) { + return; + } + + if (free_size < min_fit_chunk.size) { + min_fit_chunk = {free_offset, free_size}; + } + }; + + int64 offset = 0; + for (auto& chunk : chunks_overlapping_in_time) { + if (offset < chunk.offset) { + use_free_chunk_if_smaller(offset, chunk.offset - offset); + } + offset = + std::max(offset, RoundUpToNearest(chunk.chunk_end(), alignment_)); + } + use_free_chunk_if_smaller(offset, result_.heap_size - offset); + + if (min_fit_chunk.offset == -1) { + // Increase the heap size to fit in the last free chunk. + result_.heap_size = offset + buffer_interval.size; + min_fit_chunk = {offset, buffer_interval.size}; + } + + min_fit_chunk.size = buffer_interval.size; + const auto emplace_result = + result_.chunk_map.emplace(buffer_interval.buffer, min_fit_chunk); + DCHECK(emplace_result.second); + + interval_tree.Add(buffer_interval.start, buffer_interval.end, + min_fit_chunk); + } + return result_; +} + +HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() { + DCHECK(!algorithms_.empty()); + std::vector<Result> results(algorithms_.size()); + int64 min_size = INT64_MAX; + int min_size_index = -1; + for (int i = 0; i < algorithms_.size(); ++i) { + results[i] = algorithms_[i]->Finish(); + if (results[i].heap_size < min_size) { + min_size = results[i].heap_size; + min_size_index = i; + } + } + + DCHECK_GE(min_size_index, 0); + return results[min_size_index]; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index ffbf947d5a..dbbf43082f 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -21,6 +21,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/buffer_value_containers.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -30,8 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -58,7 +58,7 @@ class HeapSimulator { // Result represents the result of the heap simulation. struct Result { // The assignment of buffers to chunks. - tensorflow::gtl::FlatMap<const BufferValue*, Chunk> chunk_map; + absl::flat_hash_map<const BufferValue*, Chunk> chunk_map; // The total size in bytes of the heap, containing all assigned chunks. 
int64 heap_size = 0; @@ -100,7 +100,7 @@ class HeapSimulator { const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>* + const absl::flat_hash_map<const HloComputation*, int64>* memory_by_computation = nullptr); // Run the heap simulation with the given algorithm, assuming the given @@ -130,7 +130,7 @@ class HeapSimulator { const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options = Options(), - const tensorflow::gtl::FlatMap<const HloComputation*, int64>* + const absl::flat_hash_map<const HloComputation*, int64>* memory_by_computation = nullptr); private: @@ -140,7 +140,7 @@ class HeapSimulator { HeapSimulator(std::unique_ptr<HeapAlgorithm> algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, const HloSchedule* schedule = nullptr, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>* + const absl::flat_hash_map<const HloComputation*, int64>* memory_by_computation = nullptr); ~HeapSimulator(); @@ -172,7 +172,7 @@ class HeapSimulator { // handle subcomputations. It would be good to unify the handling of // subcomputations, but it's not clear how. const HloSchedule* schedule_; - const tensorflow::gtl::FlatMap<const HloComputation*, int64>* + const absl::flat_hash_map<const HloComputation*, int64>* memory_by_computation_; // In addition to Alloc and Free, the heap simulator exposes a concept of @@ -193,12 +193,12 @@ class HeapSimulator { const BufferValue* canonical = nullptr; int64 refcount = 0; }; - tensorflow::gtl::FlatMap<const BufferValue*, std::shared_ptr<SharedGroup>> + absl::flat_hash_map<const BufferValue*, std::shared_ptr<SharedGroup>> shared_buffers_; // Hold some sets for error-checking the sequence of Alloc and Free calls. - tensorflow::gtl::FlatSet<const BufferValue*> allocated_buffers_; - tensorflow::gtl::FlatSet<const BufferValue*> freed_buffers_; + absl::flat_hash_set<const BufferValue*> allocated_buffers_; + absl::flat_hash_set<const BufferValue*> freed_buffers_; // Debugging information filled in while the heap simulator runs. HeapSimulatorTrace debug_trace_; @@ -218,12 +218,6 @@ class HeapAlgorithm { // Alloc allocates a buffer of 'size' bytes. virtual void Alloc(const BufferValue* buffer, int64 size) = 0; - // NoFragmentationStatsHeap overrides this method. - virtual void Alloc(const BufferValue* buffer, int64 size, - const HloInstruction* instruction) { - Alloc(buffer, size); - } - // Takes memory usage of subcomputations into account when calculating the // memory usage of a computation. Currently, we don't handle buffer aliasing // between computations entirely correctly. We are careful to not double count @@ -235,7 +229,9 @@ class HeapAlgorithm { // analysis, it's not worth making major changes to HeapSimulator now. virtual void AccountForSubcomputationMemory( const HloInstruction* instruction, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + // The total number of bytes allocated by instruction. + int64 alloc_size_by_instruction, + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation) {} // Free de-allocates a previously allocated buffer. 
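// Worked example of the new alloc_size_by_instruction parameter (illustrative,
// based on the SubcomputationAccounting test added in heap_simulator_test.cc
// below): a while loop whose body peaks at 16 bytes also defines a 16-byte
// output buffer, and that output aliases the body root's buffer. The
// accounting therefore charges max_subcomputation_bytes -
// alloc_size_by_instruction = 16 - 16 = 0 extra bytes for the subcomputation
// instead of double counting.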
@@ -257,12 +253,9 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { void Alloc(const BufferValue* buffer, int64 size) override; - void Alloc(const BufferValue* buffer, int64 size, - const HloInstruction* instruction) override; - void AccountForSubcomputationMemory( - const HloInstruction* instruction, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const HloInstruction* instruction, int64 alloc_size_by_instruction, + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation) override; void Free(const BufferValue* buffer, int64 size) override; @@ -351,6 +344,67 @@ class LazyBestFitHeap : public HeapAlgorithm { std::set<Chunk, OrderChunkByIncreasingSize> free_; }; +// GlobalDecreasingSizeBestFitHeap collects the live intervals of all buffers, +// then allocates them in decreasing sizes regardless of the alloc/free time. It +// internally tracks the allocated buffers and their live intervals; when +// allocating a buffer, it finds the best-fit free chunk during its live +// interval. +class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { + public: + GlobalDecreasingSizeBestFitHeap(int64 alignment) : alignment_(alignment) {} + ~GlobalDecreasingSizeBestFitHeap() override {} + + void Alloc(const BufferValue* buffer, int64 size) override; + void Free(const BufferValue* buffer, int64 size) override; + Result Finish() override; + + private: + int64 alignment_; + Result result_; + + // The current time represented as an integer. It increments by 1 at each + // Alloc or Free call. + int64 current_time_ = 0; + + // BufferInterval stores a buffer's size and time interval. + struct BufferInterval { + const BufferValue* buffer; + int64 size; + // Alloc time of the buffer. + int64 start; + // Free time of the buffer. + int64 end; + }; + absl::flat_hash_map<const BufferValue*, BufferInterval> buffer_intervals_; +}; + +// A heap algorithm that chooses the best results from other algorithms added to +// it. +class ChooseBestHeapAlgorithm : public HeapAlgorithm { + public: + ChooseBestHeapAlgorithm( + std::unique_ptr<std::vector<std::unique_ptr<HeapAlgorithm>>> algorithms) + : algorithms_(std::move(*algorithms)) {} + ~ChooseBestHeapAlgorithm() override {} + + void Alloc(const BufferValue* buffer, int64 size) override { + for (auto& algorithm : algorithms_) { + algorithm->Alloc(buffer, size); + } + } + + void Free(const BufferValue* buffer, int64 size) override { + for (auto& algorithm : algorithms_) { + algorithm->Free(buffer, size); + } + } + + Result Finish() override; + + private: + std::vector<std::unique_ptr<HeapAlgorithm>> algorithms_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HEAP_SIMULATOR_H_ diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 957c4a6891..e30e7667f3 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_value.h" @@ -31,7 +32,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { @@ -98,6 +98,124 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie()); } +TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { + // HloModule SubcomputationAccounting + + // %WhileBody (body_param: f32[4]) -> f32[4] { + // %body_param = f32[4]{0} parameter(0) + // %constant.1 = f32[4]{0} constant({1, 1, 1, 1}) + // ROOT %subtract = f32[4]{0} subtract(f32[4]{0} %body_param, f32[4]{0} + // %constant.1) + // } + + // %WhileCond (cond_param: f32[4]) -> pred[] { + // %cond_param = f32[4]{0} parameter(0) + // %slice = f32[1]{0} slice(f32[4]{0} %cond_param), slice={[0:1]} + // %reshape = f32[] reshape(f32[1]{0} %slice) + // %constant = f32[] constant(0) + // ROOT %not-equal-to = pred[] not-equal-to(f32[] %reshape, f32[] %constant) + // } + + // ENTRY %SubcomputationAccounting () -> f32[2,4] { + // %constant.3 = f32[2,4]{1,0} constant(f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, + // 3, 4 } }) %transpose = f32[2,4]{1,0} transpose(f32[2,4]{1,0} + // %constant.3), dimensions={0,1} %constant.2 = f32[4]{0} constant({1, 1, 1, + // 1}) %while = f32[4]{0} while(f32[4]{0} %constant.2), + // condition=%WhileCond, body=%WhileBody %broadcast = f32[2,4]{1,0} + // broadcast(f32[4]{0} %while), dimensions={1} ROOT %add = f32[2,4]{1,0} + // add(f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) + // } + + auto module = CreateNewVerifiedModule(); + const Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); + const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); + + // reshape(slice(param)) != 0 + // Needs 5 bytes + auto cond_builder = HloComputation::Builder("WhileCond"); + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "cond_param")); + HloInstruction* slice = + cond_builder.AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(F32, {1}), cond_param, {0}, {1}, {1})); + HloInstruction* reshape = + cond_builder.AddInstruction(HloInstruction::CreateReshape(r0f32, slice)); + HloInstruction* zero = cond_builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0))); + HloInstruction* cond_comparison = + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, reshape, zero)); + auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); + + // param - 1 + // Needs 16 bytes + auto body_builder = HloComputation::Builder("WhileBody"); + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "body_param")); + HloInstruction* one_vector = + body_builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1<float>({1, 1, 1, 1}))); + HloInstruction* subtract = + body_builder.AddInstruction(HloInstruction::CreateBinary( + r1f32, HloOpcode::kSubtract, body_param, one_vector)); + auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); + + // transpose(matrix) + bcast(while) + auto builder = HloComputation::Builder(TestName()); + HloInstruction* while_init = + builder.AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1<float>({1, 1, 1, 1}))); + // Creates 16 bytes, ignoring subcomputations + 
HloInstruction* while_loop = + builder.AddInstruction(HloInstruction::CreateWhile( + r1f32, cond_computation, body_computation, while_init)); + + // Creates 32 bytes and frees 16 + HloInstruction* bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(r2f32, while_loop, {1})); + + HloInstruction* matrix = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>( + {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); + // Creates 32 bytes + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); + + // Creates 32 bytes and frees 64 + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); + + auto entry_computation = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + std::vector<HloInstruction*> cond_vec = {cond_param, slice, reshape, zero, + cond_comparison}; + std::vector<HloInstruction*> while_body_vec = {body_param, one_vector, + subtract}; + std::vector<HloInstruction*> entry_comp_vec = {while_init, while_loop, bcast, + matrix, transpose, add}; + schedule.set_sequence(cond_computation, cond_vec); + schedule.set_sequence(body_computation, while_body_vec); + schedule.set_sequence(entry_computation, entry_comp_vec); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + }; + absl::flat_hash_map<const HloComputation*, int64> memory_by_computation; + memory_by_computation[cond_computation] = 5; + memory_by_computation[body_computation] = 16; + std::unique_ptr<TuplePointsToAnalysis> points_to_analysis = + TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); + + // HeapSimulator accounts for subcomputations. The output buffer is aliased, + // so we don't double count. + EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( + *entry_computation, schedule.sequence(entry_computation), + *points_to_analysis, size_fn, &memory_by_computation) + .ValueOrDie()); +} + const char kAlloc[] = "Alloc"; const char kFree[] = "Free"; const char kFinish[] = "Finish"; @@ -174,7 +292,7 @@ class HeapSimulatorTracker { // Construct the module sequence grouped by computation. 
HloSchedule schedule(module_.get()); - tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position; + absl::flat_hash_map<const HloInstruction*, int> reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { const HloInstruction* instruction = full_module_sequence[i]; schedule.GetOrCreateSequence(instruction->parent()) @@ -1021,5 +1139,135 @@ TEST_F(LazyBestFitHeapTest, Alignment) { EXPECT_EQ(128, result.chunk_map.at(buffer_e_).offset); } +class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase {}; + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) { + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(0, result.heap_size); + EXPECT_EQ(0, result.chunk_map.size()); +} + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) { + // space + // ^ + // | +---a---+ + // | +-------+ + // | +---c---+ + // | +-------+ + // | | b | + // | +-------+ + // | +-------+ + // | | | + // | | d | + // | +-------+ + // -----------------> time + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + heap.Alloc(buffer_a_, 10); + heap.Alloc(buffer_b_, 30); + heap.Alloc(buffer_c_, 20); + heap.Alloc(buffer_d_, 40); + heap.Free(buffer_a_, 10); + heap.Free(buffer_b_, 30); + heap.Free(buffer_c_, 20); + heap.Free(buffer_d_, 40); + + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(100, result.heap_size); + EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); + EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size); + EXPECT_EQ(20, result.chunk_map.at(buffer_c_).size); + EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size); + + EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset); + EXPECT_EQ(40, result.chunk_map.at(buffer_b_).offset); + EXPECT_EQ(70, result.chunk_map.at(buffer_c_).offset); + EXPECT_EQ(0, result.chunk_map.at(buffer_d_).offset); +} + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) { + // space + // ^ + // | +-------+ + // | +---b---+ + // | +-------+ + // | | | + // | | d | + // | +---a---+ +-------+ + // | + // | +-------+ + // | | | + // | | c | + // | | | + // | +-------+ + // ---------------------> time + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20); + heap.Alloc(buffer_a_, 10); + heap.Alloc(buffer_b_, 20); + heap.Alloc(buffer_c_, 50); + heap.Free(buffer_a_, 10); + heap.Alloc(buffer_d_, 40); + heap.Free(buffer_b_, 20); + heap.Free(buffer_c_, 50); + heap.Free(buffer_d_, 40); + + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(120, result.heap_size); + EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); + EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); + EXPECT_EQ(50, result.chunk_map.at(buffer_c_).size); + EXPECT_EQ(40, result.chunk_map.at(buffer_d_).size); + + EXPECT_EQ(60, result.chunk_map.at(buffer_a_).offset); + EXPECT_EQ(100, result.chunk_map.at(buffer_b_).offset); + EXPECT_EQ(0, result.chunk_map.at(buffer_c_).offset); + EXPECT_EQ(60, result.chunk_map.at(buffer_d_).offset); +} + +TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) { + // space + // ^ + // | +-------+ + // | +---b---+ + // | +-------+ + // | | d | + // | +--a--+ +-------+ + // | +-------+ + // | | | + // | | c | + // | +-------+ + // | +-------+ + // | | | + // | | e | + // | | | + // | +-------+ + // ---------------------> time + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + heap.Alloc(buffer_a_, 10); + heap.Alloc(buffer_b_, 20); + heap.Alloc(buffer_c_, 40); + heap.Free(buffer_a_, 10); + heap.Alloc(buffer_d_, 30); + 
heap.Alloc(buffer_e_, 50); + heap.Free(buffer_b_, 20); + heap.Free(buffer_c_, 40); + heap.Free(buffer_d_, 30); + heap.Free(buffer_e_, 50); + + const HeapSimulator::Result result = heap.Finish(); + EXPECT_EQ(140, result.heap_size); + EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); + EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); + EXPECT_EQ(40, result.chunk_map.at(buffer_c_).size); + EXPECT_EQ(30, result.chunk_map.at(buffer_d_).size); + EXPECT_EQ(50, result.chunk_map.at(buffer_e_).size); + + EXPECT_EQ(90, result.chunk_map.at(buffer_a_).offset); + EXPECT_EQ(120, result.chunk_map.at(buffer_b_).offset); + EXPECT_EQ(50, result.chunk_map.at(buffer_c_).offset); + EXPECT_EQ(90, result.chunk_map.at(buffer_d_).offset); + EXPECT_EQ(0, result.chunk_map.at(buffer_e_).offset); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index b19ec12638..1ea26ddd5b 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 53 +// Next ID: 56 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -124,9 +124,13 @@ message HloInstructionProto { // The string representation of the infeed configuration. bytes infeed_config = 27; - // Name of a global symbol to call, only present for kCustomCall. + // Name of a external target (eg, global symbol) to call, only present for + // kCustomCall. string custom_call_target = 28; + // Opaque string, only present for kCustomCall. + string custom_call_opaque = 53; + // Shape of outfeed request. xla.Shape outfeed_shape = 29; @@ -176,6 +180,10 @@ message HloInstructionProto { // Collective permute field. repeated SourceTarget source_target_pairs = 52; + + // Sharding for kDomain instructions. + xla.OpSharding domain_entry_sharding = 54; + xla.OpSharding domain_exit_sharding = 55; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 0986da65cb..c3da12e273 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -20,6 +20,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -119,7 +121,7 @@ class BufferValueMap { } // Return a set of all the values in the given buffer. - const tensorflow::gtl::FlatSet<const HloValue*>& GetValuesInBuffer( + const absl::flat_hash_set<const HloValue*>& GetValuesInBuffer( BufferNumber buffer_number) const { return buffers_.at(buffer_number); } @@ -142,7 +144,7 @@ class BufferValueMap { // Move the given value into the given buffer. void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) { BufferNumber old_buffer_number = value_to_buffer_number_.at(&value); - tensorflow::gtl::FlatSet<const HloValue*>& old_value_set = + absl::flat_hash_set<const HloValue*>& old_value_set = buffers_.at(old_buffer_number); old_value_set.erase(&value); if (old_value_set.empty()) { @@ -290,13 +292,11 @@ class BufferValueMap { const HloDataflowAnalysis& dataflow_; // A map containing the set of values contained in each buffer. 
- tensorflow::gtl::FlatMap<BufferNumber, - tensorflow::gtl::FlatSet<const HloValue*>> + absl::flat_hash_map<BufferNumber, absl::flat_hash_set<const HloValue*>> buffers_; // A map indicating which buffer each value is contained in. - tensorflow::gtl::FlatMap<const HloValue*, BufferNumber> - value_to_buffer_number_; + absl::flat_hash_map<const HloValue*, BufferNumber> value_to_buffer_number_; // The buffer number of the next buffer to be created. BufferNumber next_buffer_number_ = 0; @@ -352,7 +352,7 @@ bool HloAliasAnalysis::InstructionBuffersAreAmbiguous( bool HloAliasAnalysis::InstructionBuffersAreDistinct( const HloInstruction* instruction) const { - tensorflow::gtl::FlatSet<const HloBuffer*> buffers_seen; + absl::flat_hash_set<const HloBuffer*> buffers_seen; for (const auto& pair : dataflow_analysis_->GetInstructionValueSet(instruction)) { const HloValueSet& value_set = pair.second; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index e345804537..372f99ff01 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -20,6 +20,7 @@ limitations under the License. #include <string> #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_buffer.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" @@ -110,7 +111,7 @@ class HloAliasAnalysis { std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_; // A map indicating which buffer a value is contained in. - tensorflow::gtl::FlatMap<const HloValue*, HloBuffer*> value_to_buffer_; + absl::flat_hash_map<const HloValue*, HloBuffer*> value_to_buffer_; // A lazily constructed vector containing all HloBuffers sorted by // HloBuffer::Id. diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc index 6c11a073b7..9c3aa0e64d 100644 --- a/tensorflow/compiler/xla/service/hlo_buffer.cc +++ b/tensorflow/compiler/xla/service/hlo_buffer.cc @@ -20,6 +20,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_clone_context.h b/tensorflow/compiler/xla/service/hlo_clone_context.h index 658643b427..24910ca07b 100644 --- a/tensorflow/compiler/xla/service/hlo_clone_context.h +++ b/tensorflow/compiler/xla/service/hlo_clone_context.h @@ -18,8 +18,8 @@ limitations under the License. 
#include <string> +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -73,12 +73,12 @@ class HloCloneContext { return FindOrDie(computations_, old_computation); } - const tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*>& + const absl::flat_hash_map<const HloInstruction*, HloInstruction*>& cloned_instructions() const { return instructions_; } - const tensorflow::gtl::FlatMap<const HloComputation*, HloComputation*>& + const absl::flat_hash_map<const HloComputation*, HloComputation*>& cloned_computations() const { return computations_; } @@ -86,10 +86,8 @@ class HloCloneContext { private: HloModule* module_; string suffix_; - tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*> - instructions_; - tensorflow::gtl::FlatMap<const HloComputation*, HloComputation*> - computations_; + absl::flat_hash_map<const HloInstruction*, HloInstruction*> instructions_; + absl::flat_hash_map<const HloComputation*, HloComputation*> computations_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 601a008d9f..c2041c4667 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -24,6 +24,8 @@ limitations under the License. #include <sstream> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -39,7 +41,6 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -122,30 +123,6 @@ HloInstruction* HloComputation::AddParameter( return instructions_.back().get(); } -namespace { - -// Returns the new name for a fusion parameter when we change its number. -// -// Fusion parameters are named foo.param_1, bar.param_2, etc. We are -// renumbering the parameters, so replace the final number in the name with -// the updated value. 
-string RenameFusionParameter(const string& original_name, int64 new_param_no) { - const string param_underscore = ".param_"; - size_t index = original_name.rfind(param_underscore); - if (index == string::npos) { - return original_name; - } - string after_param = original_name.substr(index + param_underscore.size()); - int64 numeric_suffix; - if (absl::SimpleAtoi(after_param, &numeric_suffix)) { - return StrCat(original_name.substr(0, index + param_underscore.size()), - new_param_no); - } - return original_name; -} - -} // namespace - Status HloComputation::RemoveParameter(int64 param_no) { CHECK_GE(param_no, 0); CHECK_LT(param_no, param_instructions_.size()); @@ -158,11 +135,9 @@ Status HloComputation::RemoveParameter(int64 param_no) { while (param_no < param_instructions_.size()) { param_instruction = param_instructions_[param_no]; - string param_name = - RenameFusionParameter(param_instruction->name(), param_no); HloInstruction* new_instr = AddInstructionInternal(HloInstruction::CreateParameter( - param_no, param_instruction->shape(), param_name)); + param_no, param_instruction->shape(), StrCat("param_", param_no))); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); param_instructions_[param_no] = new_instr; TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); @@ -186,11 +161,9 @@ Status HloComputation::RemoveUnusedParameters() { if (removed > 0) { const int64 param_no = i - removed; - string param_name = - RenameFusionParameter(param_instruction->name(), param_no); - HloInstruction* new_instr = - AddInstructionInternal(HloInstruction::CreateParameter( - param_no, param_instruction->shape(), param_name)); + HloInstruction* new_instr = AddInstructionInternal( + HloInstruction::CreateParameter(param_no, param_instruction->shape(), + StrCat("param_", param_no))); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); param_instructions_[param_no] = new_instr; TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction)); @@ -272,10 +245,11 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) { << "instruction " << instruction->name() << " has control successors and cannot be removed"; - TF_RET_CHECK(instruction_iterators_.count(instruction) != 0); - auto inst_it = instruction_iterators_.at(instruction); - (*inst_it)->set_parent(nullptr); - instructions_.erase(inst_it); + auto inst_it = instruction_iterators_.find(instruction); + TF_RET_CHECK(inst_it != instruction_iterators_.end()); + (*inst_it->second)->set_parent(nullptr); + instructions_.erase(inst_it->second); + instruction_iterators_.erase(inst_it); return Status::OK(); } @@ -304,10 +278,9 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, namespace { // Helper which builds a post order of the HLO call graph. 
-void ComputeComputationPostOrder( - HloComputation* computation, - tensorflow::gtl::FlatSet<HloComputation*>* visited, - std::vector<HloComputation*>* post_order) { +void ComputeComputationPostOrder(HloComputation* computation, + absl::flat_hash_set<HloComputation*>* visited, + std::vector<HloComputation*>* post_order) { if (visited->insert(computation).second) { for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : @@ -324,7 +297,7 @@ void ComputeComputationPostOrder( void HloComputation::ComputeInstructionPostOrder( const HloComputation::ChannelDependencyMap& channel_dependency_map, std::vector<HloInstruction*>* post_order, HloInstruction* root, - tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const { + absl::flat_hash_map<HloInstruction*, VisitState>* visited) const { std::vector<HloInstruction*> dfs_stack; dfs_stack.push_back(root); while (!dfs_stack.empty()) { @@ -421,7 +394,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { std::vector<HloInstruction*> post_order; post_order.reserve(instruction_count()); std::vector<HloInstruction*> trace_instructions; - tensorflow::gtl::FlatMap<HloInstruction*, VisitState> visited; + absl::flat_hash_map<HloInstruction*, VisitState> visited; for (auto& instruction : instructions_) { if (instruction->opcode() == HloOpcode::kTrace) { // Trace instructions aren't handled by the DFS visitor. Add trace @@ -442,7 +415,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const { std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList() const { - tensorflow::gtl::FlatSet<HloComputation*> visited; + absl::flat_hash_set<HloComputation*> visited; std::vector<HloComputation*> post_order; // To avoid special handling of this computation, cast away const of @@ -532,9 +505,9 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr<std::unique_ptr<HloComputation>> HloComputation::CreateFromProto( const HloComputationProto& proto, - const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) { - tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map; - tensorflow::gtl::FlatMap<HloInstruction*, int64> to_proto_id; + const absl::flat_hash_map<int64, HloComputation*>& computation_map) { + absl::flat_hash_map<int64, HloInstruction*> instruction_map; + absl::flat_hash_map<HloInstruction*, int64> to_proto_id; std::vector<std::unique_ptr<HloInstruction>> instructions; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { @@ -562,6 +535,28 @@ HloComputation::CreateFromProto( return to_proto_id[a.get()] < to_proto_id[b.get()]; }); + TF_RETURN_IF_ERROR([&]() -> Status { + std::vector<bool> parameters_seen(parameter_count); + int parameters_seen_count = 0; + for (auto& instruction : instructions) { + if (instruction->opcode() == HloOpcode::kParameter) { + int64 param_no = instruction->parameter_number(); + TF_RET_CHECK(param_no >= 0 && param_no < parameter_count) + << "Invalid parameter number. 
Expected [0, " << parameter_count + << "), got " << param_no; + TF_RET_CHECK(!parameters_seen[param_no]) + << "Parameter number " << param_no + << " already allocated in this computation"; + parameters_seen[param_no] = true; + parameters_seen_count++; + } + } + TF_RET_CHECK(parameters_seen_count == parameter_count) + << "Not all parameters in range [0, " << parameter_count + << ") were referenced"; + return Status::OK(); + }()); + auto computation = absl::WrapUnique( new HloComputation(proto.name(), parameter_count, &instructions, root, /*fusion_instruction=*/nullptr)); @@ -916,13 +911,14 @@ std::unique_ptr<HloComputation> HloComputation::Clone( return CloneWithReplacements( /*replacements=*/std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>(), - context, suffix); + /*extras=*/{}, context, suffix); } std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements( std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>> replacements, - HloCloneContext* context, const string& suffix) { + absl::Span<HloInstruction*> extras, HloCloneContext* context, + const string& suffix) { std::unique_ptr<HloCloneContext> context_ptr; if (context == nullptr) { context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix); @@ -944,6 +940,9 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements( VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; std::vector<HloInstruction*> postorder; + for (HloInstruction* instr : extras) { + postorder.push_back(instr); + } for (HloInstruction* instr : MakeInstructionPostOrder()) { if (HloInstruction* replacement = replace(instr)) { postorder.push_back(replacement); diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index a880e9ab30..d87ab4bda1 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -25,6 +25,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/map_util.h" @@ -40,8 +42,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -188,7 +188,7 @@ class HloComputation { // calls. static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto( const HloComputationProto& proto, - const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map); + const absl::flat_hash_map<int64, HloComputation*>& computation_map); // Gets the instructions in this computation. // @@ -227,7 +227,7 @@ class HloComputation { void UpdateReachabilityThroughInstruction( const HloInstruction* instruction, HloReachabilityMap* reachability_map); - int64 instruction_count() const { return instructions_.size(); } + int64 instruction_count() const { return instruction_iterators_.size(); } // Creates and returns a list of the embedded computations called by this // computation. 
This includes all embedded computations called directly or @@ -333,10 +333,13 @@ class HloComputation { // // If replacements maps a key to nullptr, we remove that instruction from the // new computation. + // If additional instructions are used by instructions in replacement map, + // they must be passed in post-order in the extras span. std::unique_ptr<HloComputation> CloneWithReplacements( std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>> replacements, - HloCloneContext* context = nullptr, const string& suffix = "clone"); + absl::Span<HloInstruction*> extras, HloCloneContext* context = nullptr, + const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of @@ -411,14 +414,14 @@ class HloComputation { // cross-replica-sum the union of the dependencies for all participating // instructions. using ChannelDependencyMap = - tensorflow::gtl::FlatMap<int64, absl::InlinedVector<HloInstruction*, 1>>; + absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>; ChannelDependencyMap ComputeChannelDependencies() const; enum VisitState { kVisiting, kVisited }; void ComputeInstructionPostOrder( const HloComputation::ChannelDependencyMap& channel_dependency_map, std::vector<HloInstruction*>* post_order, HloInstruction* root, - tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const; + absl::flat_hash_map<HloInstruction*, VisitState>* visited) const; string name_; int64 unique_id_; @@ -436,7 +439,7 @@ class HloComputation { // instruction pointer to location in the list for fast lookup. using InstructionList = std::list<std::unique_ptr<HloInstruction>>; InstructionList instructions_; - std::unordered_map<const HloInstruction*, InstructionList::iterator> + absl::flat_hash_map<const HloInstruction*, InstructionList::iterator> instruction_iterators_; std::vector<HloInstruction*> param_instructions_; diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index f837816cea..4f898ce61c 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -76,6 +76,26 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) { continue; } + // Don't constant fold unless it's a net positive or the output is small. + if (ShapeUtil::IsArray(instruction->shape())) { + int64 elements_in_removed_operands = 0; + for (HloInstruction* operand : instruction->operands()) { + if (operand->user_count() == 1 && + ShapeUtil::IsArray(operand->shape())) { + elements_in_removed_operands += + ShapeUtil::ElementsIn(operand->shape()); + } + } + int64 elements_in_constant = + ShapeUtil::ElementsIn(instruction->shape()); + + static const int64 kMaximumConstantSizeElements = 2 * 1000 * 1000; + if (elements_in_constant > elements_in_removed_operands && + elements_in_constant > kMaximumConstantSizeElements) { + continue; + } + } + Literal result; // Currently we skip unimplemented operations. // TODO(b/35975797): Fold constant computations for more operations. 
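// Worked example of the folding threshold above (illustrative): in the
// kConstantFoldLargePad test added below, the pad output has
// 2048 * 2048 * 128 = 536,870,912 elements, while the operands that would be
// removed contribute only 1 + 1 = 2 elements. Since the output is both larger
// than what it replaces and larger than kMaximumConstantSizeElements
// (2,000,000 elements), the fold is skipped.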
@@ -84,6 +104,7 @@ StatusOr<bool> HloConstantFolding::Run(HloModule* module) { << instruction->ToString(); continue; } + VLOG(4) << "Constant folded: " << instruction->ToString(); TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction( instruction, HloInstruction::CreateConstant(std::move(result)))); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.h b/tensorflow/compiler/xla/service/hlo_constant_folding.h index 4557983a9c..4a624cc7b8 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.h +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.h @@ -23,7 +23,7 @@ namespace xla { // A pass which performs constant folding in order to avoid unnecessary // computation on constants. -class HloConstantFolding : public HloPassInterface { +class HloConstantFolding : public HloModulePass { public: absl::string_view name() const override { return "constant_folding"; } diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 3e0def5d26..e45f905f71 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -242,5 +242,25 @@ TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) { EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reduce()); } +const char* const kConstantFoldLargePad = R"( + HloModule ConstantFoldLargePad + + ENTRY r { + a = f32[1,1,1] constant(f32[1,1,1]{{{7}}}) + b = f32[] constant(42) + ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63 + })"; + +TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) { + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kConstantFoldLargePad)); + HloConstantFolding const_folder; + TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get())); + EXPECT_FALSE(result); + + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Pad(op::Constant(), op::Constant())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index b76c50bb5b..b2005d3c21 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/util.h" @@ -201,6 +202,44 @@ StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands, HloInstruction::CreateMap(map_shape, operands, map_computation)); } +StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand, + HloInstruction* init_value, + HloOpcode binary_opcode, + HloModule* module) { + DCHECK_NE(nullptr, module); + std::vector<int64> all_dims(ShapeUtil::Rank(operand->shape())); + std::iota(all_dims.begin(), all_dims.end(), 0); + + auto scalar_shape = ShapeUtil::MakeShape(operand->shape().element_type(), {}); + HloComputation* reduce_computation; + { + HloComputation::Builder b(operand->name() + ".reduce_sub_computation"); + auto lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + b.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, binary_opcode, lhs, rhs)); + reduce_computation = module->AddEmbeddedComputation(b.Build()); + } + + return operand->parent()->AddInstruction(HloInstruction::CreateReduce( + scalar_shape, operand, init_value, all_dims, reduce_computation)); +} + +StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false) { + HloComputation* computation = pred->parent(); + DCHECK_EQ(computation, on_true->parent()); + DCHECK_EQ(computation, on_false->parent()); + TF_ASSIGN_OR_RETURN(Shape select_shape, + ShapeInference::InferTernaryOpShape( + HloOpcode::kSelect, pred, on_true, on_false)); + return computation->AddInstruction(HloInstruction::CreateTernary( + select_shape, HloOpcode::kSelect, pred, on_true, on_false)); +} + StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index b22058abb4..8e5ddbbd50 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_ +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/statusor.h" @@ -107,6 +108,35 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands, HloComputation* map_computation); +// Creates a Reduce HLO instruction and adds it to the computation containing +// the operand. This will create the sub-computation needed for the reduction in +// the given module. binary_opcode should represent a binary operation. +StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand, + HloInstruction* init_value, + HloOpcode binary_opcode, + HloModule* module); + +// Creates a Select HLO instruction and adds it to the computation containing +// the predicate. The on_true and on_false instructions must also be contained +// in the same computation. 
+StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred, + HloInstruction* on_true, + HloInstruction* on_false); + +// Creates an R1 Constant HLO instruction of the given PrimitiveType with the +// given values and adds it to the given computation. +template <typename NativeT> +StatusOr<HloInstruction*> MakeR1ConstantHlo(HloComputation* computation, + PrimitiveType type, + absl::Span<const NativeT> values) { + Literal literal = LiteralUtil::CreateR1<NativeT>(values); + if (literal.shape().element_type() != type) { + TF_ASSIGN_OR_RETURN(literal, literal.Convert(type)); + } + return computation->AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); +} + // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of // these add all the instructions they generate into the computation containing diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index b59c9ba3ed..e602107cbe 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -23,6 +23,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/hash/hash.h" namespace xla { @@ -137,8 +137,8 @@ StatusOr<bool> HloCSE::Run(HloModule* module) { // HLO instructions are grouped into equivalency classes by using the // cse_equal predicate defined above. This set holds a representative // instruction for each class. - tensorflow::gtl::FlatSet<HloInstruction*, decltype(&CseHash), - decltype(cse_equal)> + absl::flat_hash_set<HloInstruction*, decltype(&CseHash), + decltype(cse_equal)> representatives(/*N=*/computation->instruction_count() + 1, &CseHash, cse_equal); for (auto instruction : computation->MakeInstructionPostOrder()) { diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h index a28c03599a..e4857fd3fd 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.h +++ b/tensorflow/compiler/xla/service/hlo_cse.h @@ -25,7 +25,7 @@ namespace xla { // and identical instructions with the same operands are commoned. The pass // iterates over the instructions in topological order which enables the pass to // find arbitrarily large common expressions. -class HloCSE : public HloPassInterface { +class HloCSE : public HloModulePass { public: // If is_layout_sensitive is true, then the simplifier preserves layout during // transformation. Otherwise, layout is ignored. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 6a63681996..c22adcdd8d 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. 
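Stepping back to the hlo_creation_utils helpers added above: below is a hypothetical sketch of how MakeReduceHlo and MakeSelectHlo might be combined inside a pass. The helper signatures come from the header hunk; the wrapper function, its name, and the use of LiteralUtil::Zero are assumptions, not code from this change.

    // Hypothetical usage sketch only; `SumIfPositive` and its arguments are invented.
    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
    #include "tensorflow/compiler/xla/service/hlo_module.h"
    #include "tensorflow/compiler/xla/status_macros.h"

    namespace xla {

    // Sums every element of `operand`, then selects between that sum and a
    // scalar zero depending on the scalar predicate `pred`.
    StatusOr<HloInstruction*> SumIfPositive(HloInstruction* pred,
                                            HloInstruction* operand,
                                            HloModule* module) {
      HloComputation* computation = operand->parent();
      HloInstruction* zero = computation->AddInstruction(
          HloInstruction::CreateConstant(
              LiteralUtil::Zero(operand->shape().element_type())));
      // MakeReduceHlo builds the scalar-add sub-computation in `module` and
      // reduces `operand` over all of its dimensions, yielding a scalar.
      TF_ASSIGN_OR_RETURN(HloInstruction * sum,
                          MakeReduceHlo(operand, zero, HloOpcode::kAdd, module));
      // MakeSelectHlo infers the select shape and adds the instruction to the
      // computation containing `pred`.
      return MakeSelectHlo(pred, sum, zero);
    }

    }  // namespace xla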
#include <queue> #include <vector> +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -91,7 +92,7 @@ HloDataflowAnalysis::HloDataflowAnalysis( bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( const HloInstruction* inst) { - tensorflow::gtl::FlatSet<const HloInstruction*> visited; + absl::flat_hash_set<const HloInstruction*> visited; absl::InlinedVector<const HloInstruction*, 4> stack; stack.push_back(inst); while (!stack.empty()) { @@ -159,8 +160,8 @@ void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { void HloDataflowAnalysis::DeleteMarkedValues() { #ifndef NDEBUG // Verify that no marked-for-deletion values are in any of the value sets. - tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(), - value_ids_to_delete_.end()); + absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(), + value_ids_to_delete_.end()); for (const auto& pair : value_sets_) { const HloInstruction* instruction = pair.first; const InstructionValueSet& instruction_value_set = pair.second; @@ -355,23 +356,6 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) { return false; } -bool HloDataflowAnalysis::UpdateSliceValueSet(HloInstruction* slice) { - CHECK_EQ(slice->opcode(), HloOpcode::kSlice); - if (!slice->IsInPlaceSlice()) { - return false; - } - // If this slice is lowered to an in-place version, then it forwards the - // operand value to the output. - const InstructionValueSet& operand_set = - GetInstructionValueSet(slice->operand(0)); - InstructionValueSet& slice_set = GetInstructionValueSet(slice); - if (operand_set != slice_set) { - slice_set = operand_set; - return true; - } - return false; -} - bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { CHECK_EQ(send->opcode(), HloOpcode::kSend); bool changed = false; @@ -640,8 +624,6 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( switch (instruction->opcode()) { case HloOpcode::kBitcast: return UpdateBitcastValueSet(instruction); - case HloOpcode::kSlice: - return UpdateSliceValueSet(instruction); case HloOpcode::kDomain: return UpdateDomainValueSet(instruction); case HloOpcode::kCopy: @@ -673,7 +655,7 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( void HloDataflowAnalysis::Propagate() { std::queue<HloInstruction*> worklist; - tensorflow::gtl::FlatSet<HloInstruction*> workset; + absl::flat_hash_set<HloInstruction*> workset; auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) { if (workset.insert(instruction).second) { worklist.push(instruction); @@ -813,11 +795,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_all_values(); } break; - case HloOpcode::kSlice: - if (!instruction->IsInPlaceSlice()) { - define_all_values(); - } - break; case HloOpcode::kWhile: case HloOpcode::kCall: case HloOpcode::kConditional: diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index e62c1c2ac8..abac398c04 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -182,7 +182,6 @@ class HloDataflowAnalysis { // Updates the value set for a particular instruction type. Returns whether // the instruction value set changed. 
bool UpdateBitcastValueSet(HloInstruction* bitcast); - bool UpdateSliceValueSet(HloInstruction* slice); bool UpdateCallValueSet(HloInstruction* call); bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h index 1fe69b1395..4012042672 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.h +++ b/tensorflow/compiler/xla/service/hlo_dce.h @@ -33,7 +33,7 @@ namespace xla { // // This pass does not remove dead parameter instructions, as parameter // instructions cannot be deleted. -class HloDCE : public HloPassInterface { +class HloDCE : public HloModulePass { public: ~HloDCE() override {} absl::string_view name() const override { return "dce"; } diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index d36631fc2f..c0bf1b9e16 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -30,7 +30,7 @@ namespace xla { // used to break an HLO graph edge connecting two instructions with different // sharding. If a set of connected instructions have all the same sharding, no // kDomain instruction will be placed. -class HloDomainIsolator : public HloPassInterface { +class HloDomainIsolator : public HloModulePass { public: // Creates a new kDomain instruction for the edge between the use instruction // (the first HloInstruction argument), and the operand instruction (the diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 113fd18eae..c6d02f9f67 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -17,6 +17,8 @@ limitations under the License. 
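Several hunks in this change (hlo_cse.cc above, hlo_domain_map.cc here, hlo_dataflow_analysis.cc, hlo_instruction.cc) swap tensorflow::gtl::FlatSet/FlatMap for the absl flat hash containers, including the variants that take a custom hash, a custom equality, and a bucket-count hint. A minimal standalone sketch of that pattern, with no XLA types (Widget and the functors are invented):

    #include <cstdio>
    #include <functional>
    #include "absl/container/flat_hash_set.h"

    struct Widget {
      int id;
      int color;
    };

    // Hash and equality that only consider `id`, so Widgets with the same id
    // fall into the same equivalence class, much like CseHash/cse_equal above.
    struct WidgetIdHash {
      size_t operator()(const Widget& w) const { return std::hash<int>()(w.id); }
    };
    struct WidgetIdEq {
      bool operator()(const Widget& a, const Widget& b) const {
        return a.id == b.id;
      }
    };

    int main() {
      // The first constructor argument is a bucket-count hint, mirroring the
      // representatives(/*N=*/instruction_count() + 1, ...) call in hlo_cse.cc.
      absl::flat_hash_set<Widget, WidgetIdHash, WidgetIdEq> seen(/*bucket_count=*/16);
      seen.insert({1, 10});
      bool inserted = seen.insert({1, 20}).second;  // same id -> not inserted
      std::printf("duplicate inserted: %s, size: %zu\n",
                  inserted ? "true" : "false", seen.size());
      return 0;
    }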
#include <algorithm> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -40,18 +42,19 @@ namespace xla { return std::move(domain_map); } -bool HloDomainMap::InSameDomain(HloInstruction* instruction1, - HloInstruction* instruction2) const { +bool HloDomainMap::InSameDomain(const HloInstruction* instruction1, + const HloInstruction* instruction2) const { int64 domain_id1 = GetDomainId(instruction1); int64 domain_id2 = GetDomainId(instruction2); return domain_id1 >= 0 && domain_id1 == domain_id2; } -int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const { +int64 HloDomainMap::GetDomainId(const HloInstruction* instruction) const { return FindOrDefault(instruction_to_domain_, instruction, -1); } -int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const { +int64 HloDomainMap::GetDomainMetadataId( + const HloInstruction* instruction) const { return FindOrDie(domain_metadata_id_, instruction); } @@ -106,8 +109,8 @@ Status HloDomainMap::PopulateDomainMetadataMap() { auto equal = [](const DomainMetadata* a, const DomainMetadata* b) { return a->Matches(*b); }; - tensorflow::gtl::FlatMap<const DomainMetadata*, int64, decltype(hash), - decltype(equal)> + absl::flat_hash_map<const DomainMetadata*, int64, decltype(hash), + decltype(equal)> domain_metadata(1024, hash, equal); for (auto& domain : instruction_domains_) { @@ -198,7 +201,8 @@ StatusOr<std::unique_ptr<DomainMetadata::Domain>> HloDomainMap::CreateDomain( return std::move(domain); } -bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { +bool HloDomainMap::IsDomainInstruction( + const HloInstruction* instruction) const { if (instruction->opcode() != HloOpcode::kDomain) { return false; } @@ -216,7 +220,7 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { /* static */ std::vector<HloInstruction*> HloDomainMap::MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set, + const absl::flat_hash_set<HloInstruction*>& instruction_set, const InstructionOrderMap& instructions_order) { std::vector<HloInstruction*> instructions; instructions.reserve(instruction_set.size()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index 56b557d7ce..bce7d1aa7c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -19,14 +19,14 @@ limitations under the License. #include <memory> #include <vector> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -58,27 +58,26 @@ class HloDomainMap { } // Checks whether two instructions are within the same domain. 
- bool InSameDomain(HloInstruction* instruction1, - HloInstruction* instruction2) const; + bool InSameDomain(const HloInstruction* instruction1, + const HloInstruction* instruction2) const; // Checks whether instruction is a kDomain instruction of the kind we are // currently processing. - bool IsDomainInstruction(HloInstruction* instruction) const; + bool IsDomainInstruction(const HloInstruction* instruction) const; // Retrieves the domain identifier of the instruction, or -1 in case // instruction is not found within any domain. - int64 GetDomainId(HloInstruction* instruction) const; + int64 GetDomainId(const HloInstruction* instruction) const; // Returns the unique id of the domain metadata for the domain the given // instruction belongs to. The given instruction must not be a kDomain // instruction since each domain instruction is associated with 2 domains. - int64 GetDomainMetadataId(HloInstruction* instruction) const; + int64 GetDomainMetadataId(const HloInstruction* instruction) const; private: // Map used for representing instruction ordering, i.e. // order_map[a] < order_map[b] means a must be ordered before b. - using InstructionOrderMap = - tensorflow::gtl::FlatMap<const HloInstruction*, int64>; + using InstructionOrderMap = absl::flat_hash_map<const HloInstruction*, int64>; HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} @@ -111,7 +110,7 @@ class HloDomainMap { // Out of an instruction set, returns a vector of all the ones which are not // a kDomain kind. static std::vector<HloInstruction*> MakeNonDomainInstructions( - const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set, + const absl::flat_hash_set<HloInstruction*>& instruction_set, const InstructionOrderMap& instructions_order); // Populates domain_metadata_id_ that maps each HloInstruction to the unique @@ -120,8 +119,8 @@ class HloDomainMap { string domain_kind_; std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_; - tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_; - tensorflow::gtl::FlatMap<HloInstruction*, int64> domain_metadata_id_; + absl::flat_hash_map<const HloInstruction*, int64> instruction_to_domain_; + absl::flat_hash_map<const HloInstruction*, int64> domain_metadata_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index 302807f816..d3c83c15ae 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -20,11 +20,11 @@ limitations under the License. #include <string> #include <vector> +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -42,7 +42,7 @@ class DomainMetadata { // operand/user pathways, without crossing a kDomain instruction of a given // kind. The reach_set can contain kDomain instructions of other kinds, if // two domains of different kind intersect each other. - tensorflow::gtl::FlatSet<HloInstruction*> reach_set; + absl::flat_hash_set<HloInstruction*> reach_set; // The same instructions in reach_set, but purged from kDomain instructions // and ordered according to their computation graph post-order, i.e. 
@@ -55,8 +55,8 @@ class DomainMetadata { // whose dataflow enters the reach set (domain), while the exit_domains // contains the set of kDomain instructions whose dataflow exit the reach // set. - tensorflow::gtl::FlatSet<HloInstruction*> enter_domains; - tensorflow::gtl::FlatSet<HloInstruction*> exit_domains; + absl::flat_hash_set<HloInstruction*> enter_domains; + absl::flat_hash_set<HloInstruction*> exit_domains; }; virtual ~DomainMetadata() = default; diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h index 97bc8ef604..0fc30fb86c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_remover.h +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -26,7 +26,7 @@ namespace xla { // Removes all the kDomain instructions of a given kind from the input module, // and calls the normalizer to propagate the properties on the possibly new born // instructions. -class HloDomainRemover : public HloPassInterface { +class HloDomainRemover : public HloModulePass { public: // Creates a new HloDomainRemover object tasked at removing all the kDomain // instructions of a given kind. diff --git a/tensorflow/compiler/xla/service/hlo_domain_verifier.h b/tensorflow/compiler/xla/service/hlo_domain_verifier.h index 81d6d69a8c..bea5cba38d 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_domain_verifier.h @@ -29,7 +29,7 @@ namespace xla { // Verifies that the domain instructions are consistent, and the each domain is // surrounded by the same metadata. -class HloDomainVerifier : public HloPassInterface { +class HloDomainVerifier : public HloModulePass { public: HloDomainVerifier(std::vector<string> kinds) : kinds_(std::move(kinds)) {} diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h index 44ded2c2fa..4d2a942925 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.h +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -25,7 +25,7 @@ namespace xla { // inserting Convert ops. This allows a backend to support an element type while // only actually implementing the Convert op for that element type. This is // generally not the fastest approach, but it works. -class HloElementTypeConverter : public HloPassInterface { +class HloElementTypeConverter : public HloModulePass { public: // eliminate_type is the type to eliminate as the input or output of ops, // using Convert ops to replace it with replace_with_type. 
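Further down, the BF16 path in the new HloEvaluatorTest::EvaluateWithModule runs exactly this converter over a module before evaluating it. A hedged usage sketch of the pass (the wrapper function name and the logging line are assumptions; the constructor arguments follow the comment above):

    #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
    #include "tensorflow/compiler/xla/service/hlo_module.h"
    #include "tensorflow/compiler/xla/status_macros.h"
    #include "tensorflow/core/platform/logging.h"

    namespace xla {

    // Rewrites every op that consumes or produces F32 to use BF16 instead,
    // inserting Convert instructions at the boundaries.
    Status ReplaceF32WithBF16(HloModule* module) {
      HloElementTypeConverter converter(/*eliminate_type=*/F32,
                                        /*replace_with_type=*/BF16);
      TF_ASSIGN_OR_RETURN(bool changed, converter.Run(module));
      VLOG(1) << "HloElementTypeConverter changed the module: " << changed;
      return Status::OK();
    }

    }  // namespace xla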
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 06b6d5b559..eec8d242fa 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -496,6 +496,61 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { return Status::OK(); } +Status HloEvaluator::HandleReal(HloInstruction* real) { + auto operand = real->operand(0); + switch (operand->shape().element_type()) { + case BF16: { + auto result_or = ElementWiseUnaryOpImpl<bfloat16, bfloat16>( + real, [](bfloat16 elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case C64: { + auto result_or = ElementWiseUnaryOpImpl<float, complex64>( + real, [](complex64 elem_operand) { return std::real(elem_operand); }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F16: { + auto result_or = ElementWiseUnaryOpImpl<Eigen::half, Eigen::half>( + real, [](Eigen::half elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F32: { + auto result_or = ElementWiseUnaryOpImpl<float, float>( + real, [](float elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + case F64: { + auto result_or = ElementWiseUnaryOpImpl<double, double>( + real, [](double elem_operand) { return elem_operand; }, + GetEvaluatedLiteralFor(operand)); + TF_ASSIGN_OR_RETURN(evaluated_[real], std::move(result_or)); + break; + } + default: + LOG(FATAL) << "HandleReal: unknown/unhandled primitive type: " + << PrimitiveType_Name(operand->shape().element_type()); + } + + return Status::OK(); +} + +Status HloEvaluator::HandleImag(HloInstruction* imag) { + auto result_or = ElementWiseUnaryOpImpl<float, complex64>( + imag, [](complex64 elem_operand) { return std::imag(elem_operand); }, + GetEvaluatedLiteralFor(imag->operand(0))); + + TF_ASSIGN_OR_RETURN(evaluated_[imag], std::move(result_or)); + return Status::OK(); +} + Status HloEvaluator::HandleCompare(HloInstruction* compare) { HloOpcode opcode = compare->opcode(); auto lhs = compare->operand(0); @@ -1173,80 +1228,85 @@ StatusOr<Literal> EvaluateSortInternal(HloInstruction* sort, TF_RET_CHECK( ShapeUtil::SameDimensions(keys_literal.shape(), values_literal.shape())) << "Sort keys and values must have the same dimensions"; - TF_RET_CHECK(rank > 0 && rank <= 2) - << "Sort is only supported for rank-1 and rank-2 shapes, rank is: " - << rank; TF_RET_CHECK(sort->operand_count() == 2) << "Expected key-value sort"; - // We need to sort and array of keys and an array of values, where the + // We need to sort an array of keys and an array of values, where the // sorted order of the values is determined by the keys. The simplest(?) // way to do this is to go to an array-of-pairs representation, sort the // array using the keys, and then go back to pair-of-arrays. 
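The comment above spells out the array-of-pairs strategy for key-value sort: zip keys and values together, sort by key, then unzip. A tiny standalone illustration of that idea (plain C++, no XLA types; the data is made up):

    #include <algorithm>
    #include <cstdio>
    #include <utility>
    #include <vector>

    int main() {
      // Keys determine the order; values are carried along with them.
      std::vector<int> keys = {3, 1, 2};
      std::vector<char> values = {'c', 'a', 'b'};

      // 1) Zip into an array of (key, value) pairs.
      std::vector<std::pair<int, char>> kv;
      kv.reserve(keys.size());
      for (size_t i = 0; i < keys.size(); ++i) kv.push_back({keys[i], values[i]});

      // 2) Sort by key only.
      std::sort(kv.begin(), kv.end(),
                [](const auto& a, const auto& b) { return a.first < b.first; });

      // 3) Unzip back into the two parallel arrays.
      for (size_t i = 0; i < kv.size(); ++i) {
        keys[i] = kv[i].first;
        values[i] = kv[i].second;
      }
      for (size_t i = 0; i < keys.size(); ++i) {
        std::printf("%d -> %c\n", keys[i], values[i]);  // 1->a, 2->b, 3->c
      }
      return 0;
    }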
VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); VLOG(3) << "HandleSort values_literal: " << values_literal.ToString(); - auto sort_r1 = [](const Literal& keys_literal, - const Literal& values_literal) { - const auto& keys_data = keys_literal.data<KeyType>(); - const auto& values_data = values_literal.data<ValueType>(); - - using kv_pair = std::pair<KeyType, ValueType>; - std::vector<kv_pair> key_value_vector; - CHECK_EQ(keys_data.size(), values_data.size()); - key_value_vector.reserve(keys_data.size()); - for (int i = 0; i < keys_data.size(); ++i) { - key_value_vector.push_back(std::make_pair(keys_data[i], values_data[i])); - } - std::sort(key_value_vector.begin(), key_value_vector.end(), - [](const kv_pair& a, const kv_pair& b) { - return SafeLess<KeyType>(a.first, b.first); - }); - std::vector<KeyType> result_keys; - std::vector<ValueType> result_values; - for (const auto& key_value : key_value_vector) { - result_keys.push_back(key_value.first); - result_values.push_back(key_value.second); - } - Literal result_keys_literal(keys_literal.shape()); - result_keys_literal.PopulateR1(absl::Span<const KeyType>(result_keys)); - Literal result_values_literal(values_literal.shape()); - result_values_literal.PopulateR1( - absl::Span<const ValueType>(result_values)); - return std::make_pair(std::move(result_keys_literal), - std::move(result_values_literal)); - }; - - Literal result_tuple; - if (rank == 1) { - auto result_pair = sort_r1(keys_literal, values_literal); - result_tuple = - LiteralUtil::MakeTuple({&result_pair.first, &result_pair.second}); - } else { - // For R2 sort, the desired semantics are to sort each matrix row - // independently. - Literal keys_result_literal(keys_literal.shape()); - Literal values_result_literal(values_literal.shape()); - int64 r1_length = keys_literal.shape().dimensions(1); - for (int64 row = 0; row < keys_literal.shape().dimensions(0); ++row) { - TF_ASSIGN_OR_RETURN(auto keys_r1_slice, - keys_literal.Slice({row, 0}, {row + 1, r1_length}) - .Reshape({r1_length})); - TF_ASSIGN_OR_RETURN(auto values_r1_slice, - values_literal.Slice({row, 0}, {row + 1, r1_length}) - .Reshape({r1_length})); - auto r1_result_pair = sort_r1(keys_r1_slice, values_r1_slice); - TF_ASSIGN_OR_RETURN(auto sorted_keys, - r1_result_pair.first.Reshape({1, r1_length})); - TF_ASSIGN_OR_RETURN(auto sorted_values, - r1_result_pair.second.Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( - sorted_keys, {0, 0}, {row, 0}, {1, r1_length})); - TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( - sorted_values, {0, 0}, {row, 0}, {1, r1_length})); - } - result_tuple = - LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); + if (rank == 0) { + // Nothing to sort. + return LiteralUtil::MakeTuple({&keys_literal, &values_literal}); } + Literal keys_result_literal(keys_literal.shape()); + Literal values_result_literal(values_literal.shape()); + std::vector<int64> zero_base(rank, 0); + std::vector<int64> increment(rank, 1); + int64 sort_dim = sort->dimensions(0); + int64 sort_dim_elements = keys_literal.shape().dimensions(sort_dim); + increment[sort_dim] = sort_dim_elements; + // Iterate through each dimension except 'sort_dim'. 
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + keys_literal.shape(), zero_base, + AsInt64Slice(keys_literal.shape().dimensions()), increment, + [&](absl::Span<const int64> indices) -> StatusOr<bool> { + // Extract a slice from the keys and values literals that correspond to + // exactly the row in dimension 'sort_dim'. + std::vector<int64> limit_indices(indices.begin(), indices.end()); + std::for_each(limit_indices.begin(), limit_indices.end(), + [](int64& index) { ++index; }); + limit_indices[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto keys_to_sort, + keys_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& keys_data = keys_to_sort.data<KeyType>(); + TF_ASSIGN_OR_RETURN(auto values_to_sort, + values_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& values_data = values_to_sort.data<ValueType>(); + using kv_pair = std::pair<KeyType, ValueType>; + std::vector<kv_pair> key_value_vector; + key_value_vector.reserve(keys_data.size()); + for (int i = 0; i < keys_data.size(); ++i) { + key_value_vector.push_back( + std::make_pair(keys_data[i], values_data[i])); + } + std::sort(key_value_vector.begin(), key_value_vector.end(), + [](const kv_pair& a, const kv_pair& b) { + return SafeLess<KeyType>(a.first, b.first); + }); + std::vector<KeyType> result_keys; + std::vector<ValueType> result_values; + for (const auto& key_value : key_value_vector) { + result_keys.push_back(key_value.first); + result_values.push_back(key_value.second); + } + Literal sorted_keys(ShapeUtil::MakeShape( + keys_literal.shape().element_type(), {sort_dim_elements})); + sorted_keys.PopulateR1(absl::Span<const KeyType>(result_keys)); + Literal sorted_values(ShapeUtil::MakeShape( + values_literal.shape().element_type(), {sort_dim_elements})); + sorted_values.PopulateR1(absl::Span<const ValueType>(result_values)); + std::vector<int64> slice_dimensions(rank, 1); + slice_dimensions[sort_dim] = sort_dim_elements; + std::vector<int64> start_indices(rank, 0); + TF_ASSIGN_OR_RETURN(auto sorted_keys_reshaped, + sorted_keys.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(keys_result_literal.CopySliceFrom( + sorted_keys_reshaped, start_indices, indices, slice_dimensions)); + TF_ASSIGN_OR_RETURN(auto sorted_values_reshaped, + sorted_values.Reshape(slice_dimensions)); + TF_RETURN_IF_ERROR(values_result_literal.CopySliceFrom( + sorted_values_reshaped, start_indices, indices, slice_dimensions)); + return true; + })); + + Literal result_tuple; + result_tuple = + LiteralUtil::MakeTuple({&keys_result_literal, &values_result_literal}); VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString(); return std::move(result_tuple); } @@ -1292,15 +1352,6 @@ StatusOr<Literal> EvaluateSort(HloInstruction* sort, } // namespace Status HloEvaluator::HandleSort(HloInstruction* sort) { - const int64 sort_dim = sort->dimensions(0); - const int64 rank = ShapeUtil::Rank(sort->operand(0)->shape()); - if (sort_dim != rank - 1) { - return Unimplemented( - "Trying to sort along dimension %d, which is not the last " - "dimension", - sort_dim); - } - if (!ShapeUtil::IsTuple(sort->shape())) { return DefaultAction(sort); } else { @@ -1327,7 +1378,7 @@ Status HloEvaluator::HandleReduce(HloInstruction* reduce) { "unsupported"); } } - return reduce->Visit(typed_visitors_.at(first_element_type).get()); + return reduce->Visit(typed_visitors_[first_element_type].get()); } } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h 
b/tensorflow/compiler/xla/service/hlo_evaluator.h index 21e676d671..07f8d0aad4 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -18,6 +18,7 @@ limitations under the License. #include <memory> +#include "absl/container/node_hash_map.h" #include "absl/memory/memory.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -134,7 +134,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Wraps around instruction handling to infer types before dispatching to // the corresponding typed Visitor. Status DefaultAction(HloInstruction* hlo) override { - return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get()); + return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get()); } Status Preprocess(HloInstruction* hlo) override; @@ -184,6 +184,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSort(HloInstruction* sort) override; + Status HandleReal(HloInstruction* real) override; + + Status HandleImag(HloInstruction* imag) override; + Status HandleReduce(HloInstruction* reduce) override; // Returns the already-evaluated literal result for the instruction. @@ -206,8 +210,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // post-orderring. // Must be cleared for each evaluation. // Storing Literal in place require the container to have pointer stability so - // we cannot use FlatMap any more. - std::unordered_map<const HloInstruction*, Literal> evaluated_; + // we cannot use flat_hash_map any more. + absl::node_hash_map<const HloInstruction*, Literal> evaluated_; private: template <typename ReturnT, typename NativeT> @@ -237,12 +241,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { } // Map from a primitive type to its associated (templated) DfsHloVisitor. - // Note: the hash function here is only needed because current gcc std::hash - // does not specialize for enum types. This should however be fixed in the - // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5 - tensorflow::gtl::FlatMap<PrimitiveType, std::unique_ptr<DfsHloVisitor>, - std::hash<int>> - typed_visitors_; + std::unique_ptr<DfsHloVisitor> typed_visitors_[PrimitiveType_ARRAYSIZE]; // Caches pointers to input literals, assuming they are in post-order. // Literals are not owned by this class, and they must outlive the lifetime of diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 01e88566a5..608a42bb60 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -66,6 +66,20 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>, .ConsumeValueOrDie(); } + // Evaluate function that takes in a local module instead of using module_ + // that is in HloVerifiedTestBase. Once module_ in HloVerifiedTestBase is + // removed, this should be the default Evaluate function. + Literal EvaluateWithModule( + HloModule* module, absl::Span<const Literal* const> arg_literals = {}) { + if (use_bfloat16_) { + // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. 
+ auto type_converter = HloElementTypeConverter(F32, BF16); + type_converter.Run(module).ValueOrDie(); + } + return evaluator_->Evaluate(*module->entry_computation(), arg_literals) + .ConsumeValueOrDie(); + } + std::unique_ptr<HloEvaluator> evaluator_; void TestUnaryOp(HloOpcode opcode, Literal expected, Literal input, @@ -1449,6 +1463,58 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } +TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) { + HloComputation::Builder b(TestName()); + + // arg: + // f32[3,3] { + // { 1, 2, 3 }, + // { 5, 6, 7 }, + // { 9, 10, 11 }, + // } + auto arg_array = absl::make_unique<Array2D<float>>(3, 3); + arg_array->FillUnique(1.0f); + auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array); + + HloInstruction* arg_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); + + auto init_value = b.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f))); + + HloComputation::Builder max_computation("max"); + Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + auto param_lhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "lhs")); + auto param_rhs = max_computation.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "rhs")); + max_computation.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); + auto max_func = module().AddEmbeddedComputation(max_computation.Build()); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(2); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + Shape shape = ShapeUtil::MakeShape(F32, {1, 1}); + b.AddInstruction(HloInstruction::CreateReduceWindow( + shape, arg_instruction, init_value, window, max_func)); + + module().AddEntryComputation(b.Build()); + + Literal result = Evaluate(); + + auto expected = LiteralUtil::CreateR2<float>({{11}}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + TEST_P(HloEvaluatorTest, ReduceWindowAdd) { HloComputation::Builder b(TestName()); @@ -2530,6 +2596,114 @@ ENTRY main { expected, Evaluate({&operand, &scatter_indices, &updates}))); } +TEST_P(HloEvaluatorTest, EvaluateScatter_NegativeIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_NegativeIndices + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(hlo_text)); + Literal operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + // No updates should happen for the negative indices. 
+ Literal scatter_indices = LiteralUtil::CreateR1<int32>({-1, 2}); + Literal updates = LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {77, 88, 99}}), + EvaluateWithModule(module.get(), + {&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_OobIndices) { + const string hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(hlo_text)); + Literal operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + // No updates should happen for the OOB indices. + Literal scatter_indices = LiteralUtil::CreateR2<int32>( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); + Literal updates = LiteralUtil::CreateR3<int32>( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + EXPECT_TRUE(LiteralTestUtil::Equal( + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 30, 60}, {7, 20, 9}}), + EvaluateWithModule(module.get(), + {&operand, &scatter_indices, &updates}))); +} + +TEST_P(HloEvaluatorTest, EvaluateScatter_OobUpdateWindow) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd_OobUpdateWindow + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[1,2] parameter(1) + updates = s32[1,2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(hlo_text)); + Literal operand = + LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}}); + Literal updates = LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}}); + // Given the update window size of 2,2 and the index of 0,2, the update window + // will be OOB. So, nothing should be updated. + Literal expected = operand.Clone(); + EXPECT_TRUE(LiteralTestUtil::Equal( + expected, EvaluateWithModule(module.get(), + {&operand, &scatter_indices, &updates}))); +} + // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise comparison with 2 bfloat16 operands. TEST_P(HloEvaluatorTest, DoesCompareBF16) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 8fb17a0033..a450dc6ff5 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -16,6 +16,8 @@ limitations under the License. 
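The scatter tests above exercise the new out-of-bounds behaviour: an index whose update window does not fit entirely inside the operand is skipped rather than clamped, so negative and overflowing indices contribute nothing. A standalone 1-D sketch of that bounds check (plain C++; the function and data are invented):

    #include <cstdio>
    #include <vector>

    // Adds updates[i] into operand[indices[i]], skipping any index whose update
    // window (of size 1 here) would fall outside the operand.
    void ScatterAdd1D(std::vector<int>& operand, const std::vector<int>& indices,
                      const std::vector<int>& updates) {
      const int update_window_size = 1;
      for (size_t i = 0; i < indices.size(); ++i) {
        const int start = indices[i];
        // Negative or overflowing start indices perform no update at all.
        if (start < 0 ||
            start > static_cast<int>(operand.size()) - update_window_size) {
          continue;
        }
        operand[start] += updates[i];
      }
    }

    int main() {
      std::vector<int> operand = {1, 2, 3};
      ScatterAdd1D(operand, /*indices=*/{-1, 2, 7}, /*updates=*/{10, 20, 30});
      // Only index 2 is in bounds, so the result is {1, 2, 23}.
      std::printf("%d %d %d\n", operand[0], operand[1], operand[2]);
      return 0;
    }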
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ +#include <cmath> + #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" @@ -41,7 +43,9 @@ template <typename T> using is_complex64_t = std::is_same<T, complex64>; // It's UB to use std::sort with std::less<float>, because of NaNs. Define -// "safe" less functions which are actually strict weak orders. +// "safe" less functions which are actually strict weak orders. -NaN and NaN +// should appear at the beginning and end of the ordering, and -0.0 should +// appear before 0.0. template < typename NativeT, typename std::enable_if<std::is_integral<NativeT>::value>::type* = nullptr> @@ -49,26 +53,33 @@ bool SafeLess(const NativeT& a, const NativeT& b) { return a < b; } -template <typename NativeT, - typename std::enable_if< - std::is_floating_point<NativeT>::value || - std::is_same<NativeT, bfloat16>::value>::type* = nullptr> +template <typename NativeT, typename std::enable_if<std::is_floating_point< + NativeT>::value>::type* = nullptr> bool SafeLess(const NativeT& a, const NativeT& b) { - if (std::isnan(b)) { - return !std::isnan(a); - } else { - return a < b; + bool lhs_is_negative = std::signbit(a); + bool rhs_is_negative = std::signbit(b); + // If the signs are different, we can just compare the signs. + if (lhs_is_negative != rhs_is_negative) { + return lhs_is_negative && !rhs_is_negative; + } + bool lhs_nan = std::isnan(a); + bool rhs_nan = std::isnan(b); + // Exactly one number is nan? + if (lhs_nan != rhs_nan) { + if (lhs_nan) { + return lhs_is_negative; + } + return !rhs_is_negative; } + return a < b; } -template <typename NativeT, typename std::enable_if<std::is_same< - NativeT, Eigen::half>::value>::type* = nullptr> +template <typename NativeT, + typename std::enable_if< + std::is_same<NativeT, bfloat16>::value || + std::is_same<NativeT, Eigen::half>::value>::type* = nullptr> bool SafeLess(const NativeT& a, const NativeT& b) { - if (Eigen::half_impl::isnan(b)) { - return !Eigen::half_impl::isnan(a); - } else { - return a < b; - } + return SafeLess(static_cast<float>(a), static_cast<float>(b)); } // Templated DfsHloVisitor for use by HloEvaluator. @@ -78,6 +89,8 @@ bool SafeLess(const NativeT& a, const NativeT& b) { // to this rule, notably: // - HandleCompare and HandleIsFinite: where the resulting literal type is // always boolean. +// - HandleImag and HandleReal: where the resulting literal type is always float +// and the operand is always complex, or real in the case of HandleReal. // These operations are handled outside of the parent HloEvaluator handlers // instead of from within TypedVisitor. 
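The reworked SafeLess above defines a strict weak ordering in which -NaN sorts first, NaN sorts last, and -0.0 sorts before +0.0, so std::sort stays well-defined even with NaN keys. A standalone comparator following the same decision sequence (plain C++, not the XLA template itself):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Same three steps as SafeLess: compare sign bits, then break ties between
    // NaN and non-NaN operands, then fall back to operator<.
    bool TotalOrderLess(float a, float b) {
      const bool a_neg = std::signbit(a), b_neg = std::signbit(b);
      if (a_neg != b_neg) return a_neg;  // negatives sort before positives
      const bool a_nan = std::isnan(a), b_nan = std::isnan(b);
      if (a_nan != b_nan) return a_nan ? a_neg : !b_neg;  // -NaN first, NaN last
      return a < b;  // two NaNs compare equivalent (false both ways)
    }

    int main() {
      std::vector<float> v = {1.5f, std::nanf(""), -0.0f, 0.0f,
                              std::copysign(std::nanf(""), -1.0f), -2.0f};
      std::sort(v.begin(), v.end(), TotalOrderLess);  // no UB despite the NaNs
      // Expected order: -NaN, -2, -0, 0, 1.5, NaN.
      for (float x : v) std::printf("%g ", x);
      std::printf("\n");
      return 0;
    }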
// @@ -318,14 +331,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleFloor<ReturnT>(floor); } - Status HandleImag(HloInstruction* imag) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[imag], - ElementWiseUnaryOp(imag, [](ElementwiseT elem_operand) { - return std::imag(elem_operand); - })); - return Status::OK(); - } - Status HandleLog(HloInstruction* log) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { @@ -673,14 +678,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleReal(HloInstruction* real) override { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[real], - ElementWiseUnaryOp(real, [](ElementwiseT elem_operand) { - return std::real(elem_operand); - })); - return Status::OK(); - } - template <typename NativeT, typename std::enable_if<std::is_floating_point< NativeT>::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { @@ -1527,47 +1524,55 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { !std::is_same<NativeT, bool>::value>::type* = nullptr> Status HandleSort(HloInstruction* sort) { auto keys = sort->operand(0); - auto rank = ShapeUtil::Rank(keys->shape()); - TF_RET_CHECK(rank > 0 && rank <= 2) - << "Sort is only supported for R1 and R2 shapes"; TF_RET_CHECK(sort->operand_count() == 1) << "Typed visitor does not support key-value sort"; const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys); - - auto sort_r1 = [this](const Literal& keys_literal) { - VLOG(3) << "HandleSort keys_literal: " << keys_literal.ToString(); - const auto& keys_data = keys_literal.data<ReturnT>(); - - std::vector<ReturnT> result_data(keys_data.begin(), keys_data.end()); - std::sort(result_data.begin(), result_data.end(), - [](const ReturnT& a, const ReturnT& b) { - return SafeLess<ReturnT>(a, b); - }); - Literal result_literal(keys_literal.shape()); - result_literal.PopulateR1(absl::Span<const ReturnT>(result_data)); - VLOG(3) << "HandleSort result_literal: " << result_literal.ToString(); - return result_literal; - }; - - if (rank == 1) { - parent_->evaluated_[sort] = std::move(sort_r1(keys_literal)); - } else { - // For R2 sort, the desired semantics are to sort each matrix row - // independently. - Literal result_literal(keys_literal.shape()); - int64 r1_length = keys->shape().dimensions(1); - for (int64 row = 0; row < keys->shape().dimensions(0); ++row) { - TF_ASSIGN_OR_RETURN(auto r1_slice, - keys_literal.Slice({row, 0}, {row + 1, r1_length}) - .Reshape({r1_length})); - auto r1_result = sort_r1(r1_slice); - TF_ASSIGN_OR_RETURN(r1_result, r1_result.Reshape({1, r1_length})); - TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( - r1_result, {0, 0}, {row, 0}, {1, r1_length})); - } - parent_->evaluated_[sort] = std::move(result_literal); + int64 sort_dim = sort->dimensions(0); + int64 sort_dim_elements = keys->shape().dimensions(sort_dim); + int64 rank = ShapeUtil::Rank(keys->shape()); + if (rank == 0) { + // Nothing to sort. + parent_->evaluated_[sort] = keys_literal.Clone(); + return Status::OK(); } + Literal result_literal(keys_literal.shape()); + std::vector<int64> zero_base(rank, 0); + std::vector<int64> increment(rank, 1); + increment[sort_dim] = sort_dim_elements; + // Iterate through each dimension except 'sort_dim'. 
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus( + keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()), + increment, [&](absl::Span<const int64> indices) -> StatusOr<bool> { + // Extract a slice from the literal that corresponds to exactly the + // row in dimension 'sort_dim'. + std::vector<int64> limit_indices(indices.begin(), indices.end()); + std::for_each(limit_indices.begin(), limit_indices.end(), + [](int64& index) { ++index; }); + limit_indices[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto row_to_sort, + keys_literal.Slice(indices, limit_indices) + .Reshape({sort_dim_elements})); + const auto& row_data = row_to_sort.data<NativeT>(); + + std::vector<NativeT> result_data(row_data.begin(), row_data.end()); + std::sort(result_data.begin(), result_data.end(), + [](const NativeT& a, const NativeT& b) { + return SafeLess<NativeT>(a, b); + }); + Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(), + {sort_dim_elements})); + sorted_row.PopulateR1(absl::Span<const NativeT>(result_data)); + std::vector<int64> slice_dimensions(rank, 1); + slice_dimensions[sort_dim] = sort_dim_elements; + TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped, + sorted_row.Reshape(slice_dimensions)); + std::vector<int64> start_indices(rank, 0); + TF_RETURN_IF_ERROR(result_literal.CopySliceFrom( + sorted_row_reshaped, start_indices, indices, slice_dimensions)); + return true; + })); + parent_->evaluated_[sort] = std::move(result_literal); return Status::OK(); } @@ -2265,19 +2270,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // be 1. int64 update_dim_size = update_dim == -1 ? 1 : updates_shape.dimensions(update_dim); - // Clamp the scatter index so that the scatter region fits in the - // operand. input_scatter_index_clamped[i] = - // clamp(input_scatter_index[i], 0, - // operand_shape.dimensions(i) - - // update_dim_size); - input_scatter_index_clamped[i] = - std::min(operand_shape.dimensions(i) - update_dim_size, - std::max(0LL, input_scatter_index[i])); + // If any part of the update region is out-of-bounds, then do not + // perform any update on the input. + if ((input_scatter_index[i] < 0) || + (input_scatter_index[i] > + operand_shape.dimensions(i) - update_dim_size)) { + return true; + } } for (int i = 0, e = input_index.size(); i < e; i++) { - input_index[i] = input_scatter_index_clamped[i] + input_window_index[i]; - DCHECK_GE(input_index[i], 0); - DCHECK_LT(input_index[i], operand_shape.dimensions(i)); + input_index[i] = input_scatter_index[i] + input_window_index[i]; } auto result_value_literal = @@ -2611,8 +2613,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { std::vector<int64> base_index(rank); bool out_of_bound = false; for (int64 i = 0; i < rank; ++i) { - base_index[i] = window_count_index[i] * window.dimensions(i).stride() + - window_index[i] - window.dimensions(i).padding_low(); + base_index[i] = + window_count_index[i] * window.dimensions(i).stride() + + window_index[i] * window.dimensions(i).window_dilation() - + window.dimensions(i).padding_low(); + // We are not in the base area if the dilation placed us out of bounds. + if (base_index[i] % window.dimensions(i).base_dilation() != 0) { + out_of_bound = true; + break; + } + // Apply the dilation to the base area. 
+ base_index[i] /= window.dimensions(i).base_dilation(); if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { out_of_bound = true; break; diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index de3d7a1677..ce4cad4235 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -90,8 +90,9 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData( HloInstructionInfo* instruction_info = computation_info->add_instruction_infos(); instruction_info->set_long_name(hlo->ToString()); - instruction_info->set_short_name( - hlo->ToString(HloPrintOptions().set_compact_operands(true))); + instruction_info->set_short_name(hlo->ToString( + HloPrintOptions().set_compact_operands(true).set_print_operand_names( + false))); instruction_info->set_category(hlo->ToCategory()); instruction_info->set_flop_count(cost_analysis.flop_count(*hlo)); instruction_info->set_transcendental_count( diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 287ba84b3b..13a74fd8a1 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1110,7 +1110,7 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { instr->metadata().source_line())); } - return StrJoin(lines, "<br/>"); + return StrJoin(lines, "\n"); } string HloDotDumper::GetInstructionNodeBackendConfig( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index e905f2983a..2f6db7cd7c 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -22,6 +22,8 @@ limitations under the License. #include <utility> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/ascii.h" @@ -37,14 +39,13 @@ limitations under the License. 
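Circling back to the ReduceWindow change a few hunks above: with window dilation, window element j maps to base index window_count * stride + j * window_dilation - padding_low, and base dilation additionally requires that index to be divisible by base_dilation before it lands on real data. The sketch below replays the 3x3 input / 2x2 window / window_dilation=2 case from the new ReduceWindowMaxWindowDilation test (plain C++; variable names invented):

    #include <algorithm>
    #include <cstdio>

    int main() {
      // Input and window configuration from ReduceWindowMaxWindowDilation.
      const float base[3][3] = {{1, 2, 3}, {5, 6, 7}, {9, 10, 11}};
      const int size = 2, stride = 1, pad_low = 0;
      const int window_dilation = 2, base_dilation = 1;

      // The output is 1x1, so the only window_count index is (0, 0). With the
      // dilated window it reads base rows {0, 2} and columns {0, 2}.
      float result = -1e30f;
      for (int wi = 0; wi < size; ++wi) {
        for (int wj = 0; wj < size; ++wj) {
          int i = 0 * stride + wi * window_dilation - pad_low;
          int j = 0 * stride + wj * window_dilation - pad_low;
          // Base dilation: only indices divisible by it land on real elements.
          if (i % base_dilation != 0 || j % base_dilation != 0) continue;
          i /= base_dilation;
          j /= base_dilation;
          if (i < 0 || i >= 3 || j < 0 || j >= 3) continue;  // out of bounds
          result = std::max(result, base[i][j]);
        }
      }
      std::printf("%g\n", result);  // 11, matching the expected literal {{11}}
      return 0;
    }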
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" @@ -59,8 +60,8 @@ using absl::StrJoin; /* static */ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map, - const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) { + const absl::flat_hash_map<int64, HloInstruction*>& instruction_map, + const absl::flat_hash_map<int64, HloComputation*>& computation_map) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); @@ -80,6 +81,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( const auto computations = [&computation_map, &proto](int index) { return computation_map.at(proto.called_computation_ids(index)); }; + + TF_RET_CHECK(std::all_of( + proto.operand_ids().begin(), proto.operand_ids().end(), + [&instruction_map](int64 id) { return instruction_map.contains(id); })) + << proto.name() << " instruction contains invalid operand id(s)"; + + TF_RET_CHECK(std::all_of( + proto.called_computation_ids().begin(), + proto.called_computation_ids().end(), + [&computation_map](int64 id) { return computation_map.contains(id); })) + << proto.name() << " instruction references invalid computation id(s)"; + + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape())); + switch (opcode) { // Ops migrated to subclasses. 
case HloOpcode::kBatchNormTraining: @@ -266,7 +281,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( << "Expect 1 called computation for fusion instruction but sees " << proto.called_computation_ids_size(); const int64 fusion_id = proto.called_computation_ids(0); - auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); + auto* fused_computation = + tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id); TF_RET_CHECK(fused_computation != nullptr) << "No fusion computation with id " << fusion_id; instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(), @@ -302,6 +318,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( } break; case HloOpcode::kOutfeed: TF_RET_CHECK(proto.operand_ids_size() == 2); + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape())); instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), operands(1), proto.outfeed_config()); break; @@ -379,7 +397,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( break; case HloOpcode::kCustomCall: instruction = CreateCustomCall(proto.shape(), all_operands(), - proto.custom_call_target()); + proto.custom_call_target(), + proto.custom_call_opaque()); if (proto.has_window()) { static_cast<HloCustomCallInstruction*>(instruction.get()) ->set_window(proto.window()); @@ -446,8 +465,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( break; } case HloOpcode::kIota: - TF_RET_CHECK(proto.dimensions_size() <= 1) - << "Iota instruction should have at most 1 dimension but sees " + TF_RET_CHECK(proto.dimensions_size() == 1) + << "Iota instruction should have 1 dimension but sees " << proto.dimensions_size(); instruction = CreateIota(proto.shape(), proto.dimensions(0)); break; @@ -465,31 +484,34 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( proto.dot_dimension_numbers(), precision_config); break; } - case HloOpcode::kDomain: + case HloOpcode::kDomain: { TF_RET_CHECK(proto.operand_ids_size() == 1) << "Domain instruction should have 1 operands but sees " << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_domain_entry_sharding()) + << "Domain instruction must domain_entry_sharding"; + TF_RET_CHECK(proto.has_domain_exit_sharding()) + << "Domain instruction must domain_exit_sharding"; + TF_ASSIGN_OR_RETURN( + HloSharding entry_hlo_sharding, + HloSharding::FromProto(proto.domain_entry_sharding())); + TF_ASSIGN_OR_RETURN(HloSharding exit_hlo_sharding, + HloSharding::FromProto(proto.domain_exit_sharding())); instruction = absl::make_unique<HloDomainInstruction>( - proto.shape(), operands(0), /*operand_side_metadata=*/nullptr, - /*user_side_metadata=*/nullptr); + proto.shape(), operands(0), + absl::make_unique<ShardingMetadata>( + std::make_shared<const HloSharding>(entry_hlo_sharding)), + absl::make_unique<ShardingMetadata>( + std::make_shared<const HloSharding>(exit_hlo_sharding))); break; + } default: { instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) - << "No instruction with id " << operand_id; instruction->AppendOperand(instruction_map.at(operand_id)); } - for (const int64 predecessor_id : proto.control_predecessor_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) - << "No instruction with id " << predecessor_id; - TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) - 
->AddControlDependencyTo(instruction.get())); - } if (instruction->opcode() != HloOpcode::kFusion) { for (const int64 computation_id : proto.called_computation_ids()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_id)) - << "No computation with id " << computation_id; instruction->called_computations_.push_back( computation_map.at(computation_id)); } @@ -501,6 +523,13 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( } } + for (const int64 predecessor_id : proto.control_predecessor_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) + << "No instruction with id " << predecessor_id; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) + ->AddControlDependencyTo(instruction.get())); + } + TF_RET_CHECK(!proto.name().empty()); instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); @@ -1108,9 +1137,9 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall( const Shape& shape, absl::Span<HloInstruction* const> operands, - absl::string_view custom_call_target) { - return absl::make_unique<HloCustomCallInstruction>(shape, operands, - custom_call_target); + absl::string_view custom_call_target, absl::string_view opaque) { + return absl::make_unique<HloCustomCallInstruction>( + shape, operands, custom_call_target, opaque); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple( @@ -1431,7 +1460,7 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { HloInstruction::InstructionVector HloInstruction::unique_operands() const { InstructionVector unique; - tensorflow::gtl::FlatSet<const HloInstruction*> seen; + absl::flat_hash_set<const HloInstruction*> seen; for (HloInstruction* operand : operands()) { if (seen.insert(operand).second) { unique.push_back(operand); @@ -2005,7 +2034,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( options.is_in_nested_computation()) { str.push_back(PrintName( canonical_name_map->LookupOrInsert(operand->name()), options)); - } else if (!options.compact_operands()) { + } else if (options.print_operand_names()) { str.push_back(PrintName(operand->name(), options)); } StrAppend(out, StrJoin(str, " ")); @@ -2423,7 +2452,7 @@ template <typename Visitor> static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, const InternalCompareFunction* operand_order, bool ignore_control_predecessors) { - visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds()); + visitor->ReserveVisitStates(root->GetModule()->instruction_count()); // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>. // @@ -2660,14 +2689,14 @@ class HloInstruction::FusionReusesParamElements { // the value of this parameter, which would save stack space but not allow us // to finish early if we find a reuse. 
static UseKind Compute(int64 i, const HloInstruction& hlo) { - tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> memoization_cache; + absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache; return ComputeInternal(i, hlo, &memoization_cache); } private: static UseKind ComputeInternal( int64 i, const HloInstruction& hlo, - tensorflow::gtl::FlatMap<const HloInstruction*, UseKind>* cache) { + absl::flat_hash_map<const HloInstruction*, UseKind>* cache) { if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) { if (hlo_param->parameter_number() == i) { return UseKind::kUse; @@ -2910,6 +2939,26 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } +bool HloPtrComparator::operator()(const HloInstruction* const& lhs, + const HloInstruction* const& rhs) const { + if (rhs == nullptr) { + // Nothing compares less than nullptr. + return false; + } + if (lhs == nullptr) { + return true; + } + auto lhs_module = lhs->GetModule(); + auto rhs_module = rhs->GetModule(); + CHECK((lhs_module == nullptr && rhs_module == nullptr) || + (lhs_module != nullptr && rhs_module != nullptr)); + if (lhs_module != nullptr && + lhs_module->unique_id() != rhs_module->unique_id()) { + return lhs_module->unique_id() < rhs_module->unique_id(); + } + return lhs->unique_id() < rhs->unique_id(); +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: @@ -3027,10 +3076,6 @@ const std::vector<int64>& HloInstruction::slice_strides() const { return Cast<HloSliceInstruction>(this)->slice_strides(); } -bool HloInstruction::IsInPlaceSlice() const { - return Cast<HloSliceInstruction>(this)->IsInPlaceSlice(); -} - const Literal& HloInstruction::literal() const { return Cast<HloConstantInstruction>(this)->literal(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 4f6cac1396..374862c4b6 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -32,6 +32,7 @@ limitations under the License. #include <unordered_set> #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" @@ -50,7 +51,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -80,6 +80,7 @@ class HloPrintOptions { print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), + print_operand_names_(true), print_program_shape_(true), print_percent_(true), print_control_dependencies_(true), @@ -107,6 +108,7 @@ class HloPrintOptions { .set_print_metadata(false) .set_print_backend_config(false) .set_compact_operands(true) + .set_print_operand_names(false) .set_print_operand_shape(true) .set_print_program_shape(false) .set_print_percent(false) @@ -144,6 +146,12 @@ class HloPrintOptions { return *this; } + // If true, the operand names will be printed. + HloPrintOptions& set_print_operand_names(bool value) { + print_operand_names_ = value; + return *this; + } + // If true, program shape of hlo computations will be printed. 
HloPrintOptions& set_print_program_shape(bool value) { print_program_shape_ = value; @@ -162,8 +170,8 @@ class HloPrintOptions { return *this; } - // If true, only a part of operands will be printed out, and their names will - // be omitted (note that in this case the text will not be parsable). + // If true, only a part of operands will be printed out (note that in this + // case the text will not be parsable). HloPrintOptions& set_compact_operands(bool value) { compact_operands_ = value; return *this; @@ -197,6 +205,7 @@ class HloPrintOptions { bool print_backend_config() const { return print_backend_config_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } + bool print_operand_names() const { return print_operand_names_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } bool print_control_dependencies() const { @@ -215,6 +224,7 @@ class HloPrintOptions { bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; + bool print_operand_names_; bool print_program_shape_; bool print_percent_; bool print_control_dependencies_; @@ -247,7 +257,7 @@ class CanonicalNameMap { private: int64 index; - tensorflow::gtl::FlatMap<string, string> canonical_name_map; + absl::flat_hash_map<string, string> canonical_name_map; }; // HLO instructions are the atomic unit of the high-level compiler's IR. @@ -350,8 +360,8 @@ class HloInstruction { // calls. static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto( const HloInstructionProto& proto, - const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map, - const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map); + const absl::flat_hash_map<int64, HloInstruction*>& instruction_map, + const absl::flat_hash_map<int64, HloComputation*>& computation_map); // Creates a parameter-retrieving instruction. static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number, @@ -718,10 +728,11 @@ class HloInstruction { HloComputation* computation); // Creates a custom call instruction that applies the given custom call target - // to the given operands. "shape" is the resultant shape. + // to the given operands. "opaque" can be an arbitrary string with a + // backend-specific interpretation. "shape" is the resultant shape. static std::unique_ptr<HloInstruction> CreateCustomCall( const Shape& shape, absl::Span<HloInstruction* const> operands, - absl::string_view custom_call_target); + absl::string_view custom_call_target, absl::string_view opaque = ""); // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. @@ -1319,9 +1330,6 @@ class HloInstruction { int64 slice_strides(int64 dimension) const; const std::vector<int64>& slice_strides() const; - // Delegates to HloSliceInstruction::IsInPlaceSlice. - bool IsInPlaceSlice() const; - // Returns the literal associated with this instruction. const Literal& literal() const; @@ -1616,6 +1624,10 @@ class HloInstruction { InstructionVector operands_; // The set of control predecessors of this instruction. + // Note that the order of the instructions in the vector influences the order + // computed in HloComputation::ComputeInstructionPostOrder, which may + // influence the result of the compilation by changing the scheduling. We are + // not sure if it matters. std::vector<HloInstruction*> control_predecessors_; // The users of this instruction. 
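The print_operand_names option added above defaults to true and is switched off for Canonical() printing together with compact_operands; it lets a dump keep operand shapes while omitting operand names. A minimal usage sketch, assuming some HloInstruction* instr is in scope:

// Print without operand names but with operand shapes.
HloPrintOptions options = HloPrintOptions()
                              .set_print_metadata(false)
                              .set_print_operand_names(false)
                              .set_print_operand_shape(true);
const std::string text = instr->ToString(options);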
Users are HLOs where this instruction is an @@ -1689,21 +1701,9 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); // To make the iteration order over the map deterministic, the comparator // should not be using the pointer values, but rather an intrinsic property of // the hlo. Exception: null pointer values compare less than non-null. -// -// Note that this cannot be used for HLO instructions across multiple modules -// since the id of HLO instructions are only unique within each HLO module. struct HloPtrComparator { bool operator()(const HloInstruction* const& lhs, - const HloInstruction* const& rhs) const { - if (rhs == nullptr) { - // Nothing compares less than nullptr. - return false; - } - if (lhs == nullptr) { - return true; - } - return lhs->unique_id() < rhs->unique_id(); - } + const HloInstruction* const& rhs) const; }; template <typename ValueT> diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index e92882c22a..152d8eacdb 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -18,6 +18,7 @@ limitations under the License. #include <deque> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" @@ -27,8 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/window_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { @@ -213,6 +214,7 @@ HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, HloInstructionProto HloSendRecvInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); proto.set_channel_id(channel_id_); + proto.set_is_host_transfer(is_host_transfer_); return proto; } @@ -641,14 +643,6 @@ HloTransposeInstruction::HloTransposeInstruction( absl::Span<const int64> dimensions) : HloInstruction(HloOpcode::kTranspose, shape), dimensions_(dimensions.begin(), dimensions.end()) { - CHECK_EQ(shape.dimensions().size(), dimensions.size()); - CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); - CHECK(std::equal(operand->shape().dimensions().begin(), - operand->shape().dimensions().end(), - Permute(dimensions, shape.dimensions()).begin())) - << "shape: " << ShapeUtil::HumanString(shape) - << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << StrJoin(dimensions, ", ") << "}"; AppendOperand(operand); } @@ -1042,7 +1036,8 @@ HloInstruction* HloFusionInstruction::AddFusionOperand( const int64 param_no = operand_count(); // Name the parameter after the instruction it represents in the outer // (non-fusion) computation. - string param_name = StrCat(new_operand->name(), ".param_", param_no); + // string param_name = StrCat(new_operand->name(), ".param_", param_no); + string param_name = StrCat("param_", param_no); HloInstruction* fused_parameter = fused_instructions_computation()->AddParameter( HloInstruction::CreateParameter(param_no, new_operand->shape(), @@ -1098,7 +1093,7 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput( // Note that we add the unfused instructions to this->parent_ computation. 
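The HloPtrComparator rework above (declaration kept in hlo_instruction.h, definition moved to hlo_instruction.cc) orders instructions first by their module's unique_id and only then by the instruction's unique_id, since instruction ids are unique only within one module. Any map keyed with it therefore iterates deterministically even when keys span modules. A short sketch, assuming two valid HloInstruction* values a and b and the HloInstructionMap alias declared next to the comparator:

// Deterministic iteration order: (module unique_id, instruction unique_id),
// never raw pointer values.
HloInstructionMap<int64> use_count;
++use_count[a];
++use_count[b];
for (const auto& pair : use_count) {
  VLOG(2) << pair.first->name() << " seen " << pair.second << " time(s)";
}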
// This is necessary because the unique_id needs for an instruction and // it's only added when inserting to the computation. - tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> old_to_new; + absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new; std::vector<HloInstruction*> unfused_instructions; auto computation_to_merge = instruction_to_merge->fused_instructions_computation(); @@ -1391,7 +1386,7 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl( } Status HloFusionInstruction::DeduplicateFusionOperands() { - tensorflow::gtl::FlatMap<const HloInstruction*, int> operand_indices; + absl::flat_hash_map<const HloInstruction*, int> operand_indices; std::vector<int> operands_to_remove; for (int i = 0; i < operand_count(); ++i) { auto emplace_result = operand_indices.emplace(operand(i), i); @@ -1488,7 +1483,6 @@ HloParameterInstruction::CloneWithNewOperandsImpl( HloGetTupleElementInstruction::HloGetTupleElementInstruction( const Shape& shape, HloInstruction* operand, int64 index) : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) { - CHECK(ShapeUtil::IsTuple(operand->shape())); AppendOperand(operand); } @@ -1610,9 +1604,6 @@ HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape, : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()), outfeed_shape_(outfeed_shape), outfeed_config_(outfeed_config) { - CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape)) - << "Outfeed shape " << outfeed_shape - << " must be compatible with operand shape " << operand->shape(); AppendOperand(operand); AppendOperand(token_operand); } @@ -1830,9 +1821,10 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( HloCustomCallInstruction::HloCustomCallInstruction( const Shape& shape, absl::Span<HloInstruction* const> operands, - absl::string_view custom_call_target) + absl::string_view custom_call_target, absl::string_view opaque) : HloInstruction(HloOpcode::kCustomCall, shape), custom_call_target_(custom_call_target.begin(), custom_call_target.end()), + opaque_(opaque.begin(), opaque.end()), feature_group_count_(1) { for (auto operand : operands) { AppendOperand(operand); @@ -1849,6 +1841,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { *convolution_dimension_numbers_; } proto.set_custom_call_target(custom_call_target_); + proto.set_custom_call_opaque(opaque_); proto.set_feature_group_count(feature_group_count_); return proto; } @@ -1872,6 +1865,11 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl( // an HloComputation. extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + // If the opaque string becomes enormous we may want to reconsider printing + // this inline and consider other options. 
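The opaque string threaded through HloCustomCallInstruction above is serialized in ToProto, restored in CreateFromProto, printed as opaque="...", and compared in IdenticalSlowPath, so a backend-specific payload survives round trips and cloning. A hedged sketch of attaching one; the target name, payload, shape, and operand are invented for illustration:

// XLA treats the opaque payload as an uninterpreted string; only the backend
// that implements "my_custom_target" parses it.
Shape out_shape = ShapeUtil::MakeShape(F32, {16});
std::unique_ptr<HloInstruction> call = HloInstruction::CreateCustomCall(
    out_shape, {operand}, "my_custom_target", /*opaque=*/"alpha=0.5;mode=fast");
auto* custom_call = Cast<HloCustomCallInstruction>(call.get());
VLOG(1) << "opaque payload: " << custom_call->opaque();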
+ if (!opaque_.empty()) { + extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\"")); + } return extra; } @@ -1897,7 +1895,8 @@ bool HloCustomCallInstruction::IdenticalSlowPath( if (feature_group_count_ != casted_other.feature_group_count_) { return false; } - return custom_call_target_ == casted_other.custom_call_target_; + return custom_call_target_ == casted_other.custom_call_target_ && + opaque_ == casted_other.opaque_; } std::unique_ptr<HloInstruction> @@ -1905,7 +1904,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span<HloInstruction* const> new_operands, HloCloneContext* context) const { auto cloned = absl::make_unique<HloCustomCallInstruction>( - shape, new_operands, custom_call_target()); + shape, new_operands, custom_call_target(), opaque()); if (window_ != nullptr) { cloned->set_window(*window_); } @@ -2301,4 +2300,23 @@ std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl( shape, new_operands[0], operand_side_metadata_->Clone(), user_side_metadata_->Clone()); } + +HloInstructionProto HloDomainInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + auto operand_side_sharding = + dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get()); + if (operand_side_sharding) { + *proto.mutable_domain_entry_sharding() = + operand_side_sharding->sharding()->ToProto(); + } + + auto user_side_sharding = + dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get()); + if (user_side_sharding) { + *proto.mutable_domain_exit_sharding() = + user_side_sharding->sharding()->ToProto(); + } + + return proto; +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 2d7bc83855..e169604072 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -546,17 +546,6 @@ class HloSliceInstruction : public HloInstruction { } const std::vector<int64>& slice_strides() const { return slice_strides_; } - // Returns the flag that describes whether a slice must be lowered into an - // offset into the original operand. - bool IsInPlaceSlice() const { return is_in_place_slice_; } - - // Sets and returns the flag that describes whether a slice must be lowered - // into an offset into the original operand. - bool SetIsInPlaceSlice(bool value) { - is_in_place_slice_ = value; - return value; - } - private: std::vector<string> ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -573,9 +562,6 @@ class HloSliceInstruction : public HloInstruction { std::vector<int64> slice_starts_; std::vector<int64> slice_limits_; std::vector<int64> slice_strides_; - - // Describes whether the slice can be lowered to an offset into the operand. - bool is_in_place_slice_ = false; }; class HloConstantInstruction : public HloInstruction { @@ -910,7 +896,6 @@ class HloOutfeedInstruction : public HloInstruction { absl::string_view outfeed_config); // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const { - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_)); return outfeed_shape_; } // Returns the config for the Outfeed instruction. 
@@ -1070,7 +1055,8 @@ class HloCustomCallInstruction : public HloInstruction { public: explicit HloCustomCallInstruction(const Shape& shape, absl::Span<HloInstruction* const> operands, - absl::string_view custom_call_target); + absl::string_view custom_call_target, + absl::string_view opaque); const Window& window() const override { CHECK(window_ != nullptr); return *window_; @@ -1090,6 +1076,7 @@ class HloCustomCallInstruction : public HloInstruction { convolution_dimension_numbers_ = absl::make_unique<ConvolutionDimensionNumbers>(dnums); } + const string& opaque() const { return opaque_; } const string& custom_call_target() const { return custom_call_target_; } void set_feature_group_count(int64 feature_group_count) { feature_group_count_ = feature_group_count; @@ -1109,8 +1096,10 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( const Shape& shape, absl::Span<HloInstruction* const> new_operands, HloCloneContext* context) const override; - // Name of a global symbol to call, only present for kCustomCall. + // Name of a global symbol to call. string custom_call_target_; + // Opaque string interpreted by the backend. + string opaque_; // Describes the window in a windowed operation such as convolution. std::unique_ptr<Window> window_; // Describes the dimension numbers used for a convolution. @@ -1337,6 +1326,9 @@ class HloDomainInstruction : public HloInstruction { std::unique_ptr<DomainMetadata> operand_side_metadata, std::unique_ptr<DomainMetadata> user_side_metadata); + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + // Retrieves the operand side metadata of a kDomain instruction. const DomainMetadata& operand_side_metadata() const { return *operand_side_metadata_; diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc index 3a1dd471c6..5bf055f3c0 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis.cc @@ -219,6 +219,33 @@ void PropagateLivenessToParameterCallers( } } +// Makes sure that if a live instruction is within a computation used in control +// flow operations, we mark live even other related instructions. +void PropagateLivenessThroughControlFlow( + const HloInstruction* instruction, + HloLivenessAnalysis::HloIndexMap* live_index_map, Worklist* worklist, + Workset* workset, CallGraph* call_graph) { + const CallGraphNode& call_graph_node = + call_graph->GetNode(instruction->parent()); + if (call_graph_node.context() == CallContext::kSequential) { + for (const CallSite& callsite : call_graph_node.caller_callsites()) { + HloInstruction* caller = callsite.instruction(); + if (caller->opcode() == HloOpcode::kWhile) { + // If a live instruction is within the %while body or condition + // computation, mark the predicate value returned by the condition + // computation live as well. + MarkLiveAtIndex(caller->while_condition()->root_instruction(), {}, + live_index_map, worklist, workset); + } else if (caller->opcode() == HloOpcode::kConditional) { + // If a live instruction is within the true or false branches of a + // conditional, we mark the predicate operand live as well. 
+ MarkLiveAtIndex(caller->operand(0), {}, live_index_map, worklist, + workset); + } + } + } +} + } // namespace HloLivenessAnalysis::HloLivenessAnalysis(const HloModule& module) @@ -257,12 +284,10 @@ void HloLivenessAnalysis::RunAnalysis() { } else if (instruction->opcode() == HloOpcode::kGetTupleElement) { PropagateLivenessThroughGTE(instruction, &live_index_map_, &worklist, &workset); - } else if (instruction->opcode() == HloOpcode::kWhile && - ShapeUtil::IsTuple(instruction->shape())) { + } else if (instruction->opcode() == HloOpcode::kWhile) { PropagateLivenessThroughWhile(instruction, &live_index_map_, &worklist, &workset); - } else if (instruction->opcode() == HloOpcode::kParameter && - ShapeUtil::IsTuple(instruction->shape())) { + } else if (instruction->opcode() == HloOpcode::kParameter) { PropagateLivenessToParameterCallers(instruction, &live_index_map_, &worklist, &workset, call_graph_.get()); @@ -277,6 +302,8 @@ void HloLivenessAnalysis::RunAnalysis() { MarkLiveAtAllIndices(operand, &live_index_map_, &worklist, &workset); } } + PropagateLivenessThroughControlFlow(instruction, &live_index_map_, + &worklist, &workset, call_graph_.get()); } } diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index 01b625c29c..e0ae1173c6 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -398,5 +398,89 @@ TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "loop_var.1"), {2})); } +TEST_F(HloLivenessAnalysisTest, WhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + WhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + WhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=WhileCondition, + body=WhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + +TEST_F(HloLivenessAnalysisTest, NestedWhileWithOutfeed) { + auto module = ParseHloString(R"( + HloModule OutfeedLoop + InnerWhileBody { + body_param = (s32[]) parameter(0) + token = token[] after-all() + constant.2 = s32[] constant(2) + outfeed_tuple = (s32[]) outfeed(constant.2, token) + get-tuple-element.1 = s32[] get-tuple-element(body_param), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]) tuple(add) + } + InnerWhileCondition { + cond_param = (s32[]) parameter(0) + get-tuple-element.3 = s32[] get-tuple-element(cond_param), index=0 + constant.2 = s32[] constant(10) + ROOT less-than = pred[] less-than(get-tuple-element.3, constant.2) + } + OuterWhileCondition { + cond_param.2 = 
(s32[]) parameter(0) + get-tuple-element.5 = s32[] get-tuple-element(cond_param.2), index=0 + constant.5 = s32[] constant(5) + ROOT less-than.2 = pred[] less-than(get-tuple-element.5, constant.5) + } + OuterWhileBody { + body_param.2 = (s32[]) parameter(0) + get-tuple-element.8 = s32[] get-tuple-element(body_param.2), index=0 + constant.6 = s32[] constant(0) + tuple.2 = (s32[]) tuple(constant.6) + inner_while = (s32[]) while(tuple.2), condition=InnerWhileCondition, + body=InnerWhileBody + constant.7 = s32[] constant(1) + add.2 = s32[] add(get-tuple-element.8, constant.7) + ROOT rtuple = (s32[]) tuple(add.2) + } + ENTRY SimpleLoop { + constant.3 = s32[] constant(0) + tuple.1 = (s32[]) tuple(constant.3) + while = (s32[]) while(tuple.1), condition=OuterWhileCondition, + body=OuterWhileBody + ROOT rtuple = () tuple() + })") + .ValueOrDie(); + + const HloLivenessAnalysis& liveness = RunLiveness(module.get()); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "add.2"), {})); + EXPECT_TRUE(liveness.IsLive(GetInstruction(module.get(), "constant.3"), {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc index c7ec88d450..5cee865b7a 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc @@ -20,6 +20,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -74,7 +76,7 @@ class ListScheduler { const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation) { ListScheduler scheduler(computation, points_to_analysis, size_function, memory_by_computation); @@ -99,7 +101,7 @@ class ListScheduler { ListScheduler(const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation) : computation_(computation), points_to_analysis_(points_to_analysis), @@ -110,7 +112,7 @@ class ListScheduler { // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. for (auto* instruction : computation.instructions()) { - tensorflow::gtl::FlatSet<const LogicalBuffer*> instr_uses; + absl::flat_hash_set<const LogicalBuffer*> instr_uses; for (auto* operand : instruction->operands()) { points_to_analysis.GetPointsToSet(operand).ForEachElement( [&](const ShapeIndex& /*index*/, @@ -193,13 +195,15 @@ class ListScheduler { return entry; } - // Returns the number of bytes freed if the HLO instruction is scheduled. - // If the instruction calls subcomputations, we count the memory used by the - // subcomputations as memory "defined" by the instruction. This is not - // entirely accurate, because subcomputation memory will be freed after the - // instruction finishes. 
But it is more accurate than not taking - // subcomputations into account at all. In the future, we may improve - // accounting for subcomputation memory (b/65409243). + // Returns the number of bytes freed *after* the HLO instruction finishes. + // The current List algorithm only considers two states for an instruction: + // right before it runs, and after it finishes. We don't represent memory + // usage during the execution of an instruction. But if the instruction calls + // subcomputations, they are only live during the instruction's execution. + // We end up counting the memory used by subcomputations as memory "defined" + // by the instruction. This is not entirely accurate, but it is more accurate + // than not taking subcomputations into account at all. In the future, we may + // improve accounting for subcomputation memory (b/65409243). int64 BytesFreedIfScheduled(const ReadyListEntry& entry) { int64 freed_bytes = 0; for (const auto& kv : entry.used_buffer_unscheduled_use_counts) { @@ -221,7 +225,18 @@ class ListScheduler { } } } - return freed_bytes - entry.bytes_defined - max_subcomputation_bytes; + int64 bytes_defined; + if (max_subcomputation_bytes > 0 && + (entry.instruction->opcode() == HloOpcode::kWhile || + entry.instruction->opcode() == HloOpcode::kCall || + entry.instruction->opcode() == HloOpcode::kConditional)) { + // The output buffer of while/call/conditional is always aliased with the + // output buffer of the root instruction in the body. Don't double count. + bytes_defined = max_subcomputation_bytes; + } else { + bytes_defined = entry.bytes_defined + max_subcomputation_bytes; + } + return freed_bytes - bytes_defined; } // Constructs the scheduling priority of the given instruction. @@ -234,8 +249,7 @@ class ListScheduler { // Populate the ready list with instructions which have no operands or // control predecessors. - tensorflow::gtl::FlatMap<const HloInstruction*, int64> - unscheduled_pred_count; + absl::flat_hash_map<const HloInstruction*, int64> unscheduled_pred_count; for (auto* instruction : computation_.instructions()) { // TODO(b/34466113): Replace this and above with successors() or // predecessors() when these methods are added to HloInstruction. @@ -251,8 +265,8 @@ class ListScheduler { std::multimap<Priority, ReadyListEntry> ready_queue; // Map of ready instructions to their iterators in ready_queue. - tensorflow::gtl::FlatMap<const HloInstruction*, - std::multimap<Priority, ReadyListEntry>::iterator> + absl::flat_hash_map<const HloInstruction*, + std::multimap<Priority, ReadyListEntry>::iterator> ready_instructions; auto add_to_ready_queue = [&](HloInstruction* inst) { @@ -262,9 +276,8 @@ class ListScheduler { }; for (auto* instruction : computation_.instructions()) { - // Instruction with no operands or control predecessors will - // not be in the map. - if (unscheduled_pred_count.count(instruction) == 0) { + if (instruction->operands().empty() && + instruction->control_predecessors().empty()) { add_to_ready_queue(instruction); } } @@ -347,21 +360,19 @@ class ListScheduler { // Computations are analyzed in post-order. When scheduling an instruction // that includes subcomputations, such as a while loop, we use this map to // look up the memory needed by subcomputations. - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation_; // A map containing the LogicalBuffers that each instruction uses. 
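The BytesFreedIfScheduled change above avoids double counting for kWhile, kCall, and kConditional: their output buffer is aliased with the root of the called computation, so the subcomputation peak already covers the bytes the caller "defines". A toy, self-contained restatement of just that arithmetic (the byte counts are invented):

#include <cstdint>
#include <iostream>

// A kWhile defines a 16-byte output and its body peaks at 64 bytes. Since the
// output aliases the body root's buffer, charging 16 + 64 = 80 bytes would be
// too pessimistic; the instruction is charged max_subcomputation_bytes only.
int64_t BytesDefined(bool output_aliases_subcomputation, int64_t bytes_defined,
                     int64_t max_subcomputation_bytes) {
  if (max_subcomputation_bytes > 0 && output_aliases_subcomputation) {
    return max_subcomputation_bytes;
  }
  return bytes_defined + max_subcomputation_bytes;
}

int main() {
  std::cout << BytesDefined(true, /*bytes_defined=*/16,
                            /*max_subcomputation_bytes=*/64)
            << "\n";  // prints 64
  return 0;
}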
- tensorflow::gtl::FlatMap<const HloInstruction*, - std::vector<const LogicalBuffer*>> + absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>> buffer_uses_; // A map containing the count of unscheduled HLOs which using a particular - // LogicalBuffer. We rely on iterator stability in this map, and that the map - // entries are std::pair's. - std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_; + // LogicalBuffer. + absl::flat_hash_map<const LogicalBuffer*, int64> unscheduled_use_count_; // Set of instructions which have been scheduled. - tensorflow::gtl::FlatSet<const HloInstruction*> scheduled_instructions_; + absl::flat_hash_set<const HloInstruction*> scheduled_instructions_; }; int64 SumLogicalBufferSizes( @@ -379,7 +390,7 @@ StatusOr<HloInstructionSequence> ScheduleComputationHelper( const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation) { VLOG(2) << "Computation: " << computation.name(); if (algorithm) { @@ -396,13 +407,13 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation) { // These variables are a hack to prevent overflows. int64 cumulative_total_size = 0; - int64 total_hlos = computation.parent()->NumUniqueInstructionIds(); - tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users; - tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes; + int64 total_hlos = computation.parent()->instruction_count(); + absl::flat_hash_map<const HloInstruction*, int64> extra_users; + absl::flat_hash_map<const HloInstruction*, int64> total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { if (ListScheduler::IgnoreInstruction(*hlo)) { extra_users[hlo] = 0; @@ -419,7 +430,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); total_sizes[hlo] = logical_buffer_size; cumulative_total_size += logical_buffer_size; - tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands( + absl::flat_hash_set<const HloInstruction*> unique_operands( hlo->operands().begin(), hlo->operands().end()); for (const HloInstruction* operand : unique_operands) { extra_users[hlo] += extra_users[operand]; @@ -467,7 +478,7 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation) { return ListScheduler::Run(computation, points_to_analysis, size_function, memory_by_computation); @@ -477,7 +488,7 @@ StatusOr<HloInstructionSequence> PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation) { return 
HloInstructionSequence(computation.MakeInstructionPostOrder()); } @@ -486,7 +497,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation) { // We try a few schedulers and choose whichever returns a lower min-memory, // not accounting for fragmentation. @@ -549,7 +560,7 @@ StatusOr<HloSchedule> ScheduleModule( HloSchedule schedule(&module); TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, TuplePointsToAnalysis::Run(&module)); - tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation; + absl::flat_hash_map<const HloComputation*, int64> memory_by_computation; for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, @@ -577,7 +588,7 @@ StatusOr<HloInstructionSequence> ScheduleComputation( CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); - tensorflow::gtl::FlatMap<const HloComputation*, int64> empty_map; + absl::flat_hash_map<const HloComputation*, int64> empty_map; return ScheduleComputationHelper(computation, *points_to_analysis, size_function, nullptr, empty_map); } diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h index 5e02868eba..a4c1d3db81 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h @@ -18,6 +18,7 @@ limitations under the License. 
#include <vector> +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -37,7 +38,7 @@ namespace xla { typedef std::function<StatusOr<HloInstructionSequence>( const HloComputation&, const TuplePointsToAnalysis&, const LogicalBuffer::SizeFunction&, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>&)> + const absl::flat_hash_map<const HloComputation*, int64>&)> MemorySchedulerAlgorithm; // List scheduler @@ -45,7 +46,7 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation); // DFS-order scheduler @@ -53,7 +54,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation); // Naive Post Order scheduler @@ -61,7 +62,7 @@ StatusOr<HloInstructionSequence> PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation); // The default scheduling algorithm. Runs both the list scheduler @@ -71,7 +72,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, - const tensorflow::gtl::FlatMap<const HloComputation*, int64>& + const absl::flat_hash_map<const HloComputation*, int64>& memory_by_computation); // Returns an HloSchedule which seeks to minimize the memory required for @@ -90,7 +91,7 @@ StatusOr<HloInstructionSequence> ScheduleComputation( // A pass which schedules the HLO instructions in a module. The HloModule's // schedule field is set to the resulting HloSchedule using // HloModule::set_schedule. -class HloMemoryScheduler : public HloPassInterface { +class HloMemoryScheduler : public HloModulePass { public: // size_function is the function returning the number of bytes required for a // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not @@ -109,7 +110,7 @@ class HloMemoryScheduler : public HloPassInterface { // A trivial pass which clears the schedule currently set on the // HloModule. After this pass runs HloModudle::has_schedule will return false. -class HloDescheduler : public HloPassInterface { +class HloDescheduler : public HloModulePass { public: HloDescheduler() = default; ~HloDescheduler() override = default; diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 1b9e9bfc77..214119fba8 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
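The MemorySchedulerAlgorithm typedef above now takes the per-computation memory map as absl::flat_hash_map. A custom algorithm is just a callable with that signature; the sketch below (a hypothetical scheduler name) simply defers to the post-order scheduler and could be handed to ScheduleModule in place of ListMemoryScheduler or DFSMemoryScheduler:

// Conforms to MemorySchedulerAlgorithm; defers to the post-order scheduler.
MemorySchedulerAlgorithm trivial_scheduler =
    [](const HloComputation& computation,
       const TuplePointsToAnalysis& points_to_analysis,
       const LogicalBuffer::SizeFunction& size_function,
       const absl::flat_hash_map<const HloComputation*, int64>&
           memory_by_computation) -> StatusOr<HloInstructionSequence> {
  return PostOrderMemoryScheduler(computation, points_to_analysis,
                                  size_function, memory_by_computation);
};
// e.g. ScheduleModule(*module, size_function, trivial_scheduler);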
#include <string> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -146,126 +147,6 @@ ENTRY root { instructions_by_name.at("e"))); } -TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { - // %WhileCond (cond_param: f32[4]) -> pred[] { - // %cond_param = f32[4]{0} parameter(0) - // %constant = f32[1,4]{1,0} constant(f32[1,4] { { 0, 0, 0, 0 } }) - // ROOT %not-equal-to = pred[] not-equal-to( - // f32[4]{0} %cond_param, f32[1,4]{1,0} %constant) - // } - // %WhileBody (body_param: f32[4]) -> f32[4] { - // %body_param = f32[4]{0} parameter(0) - // %constant.1 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) - // ROOT %subtract = f32[4]{0} subtract( - // f32[4]{0} %body_param, f32[1,4]{1,0} %constant.1) - // } - // %ListAccountsForSubcomputations () -> f32[2,4] { - // %constant.3 = f32[2,4]{1,0} constant( - // f32[2,4] { { 1, 2, 3, 4 }, { 1, 2, 3, 4 } }) - // %transpose = f32[2,4]{1,0} transpose( - // f32[2,4]{1,0} %constant.3), dimensions={0,1} - // %constant.2 = f32[1,4]{1,0} constant(f32[1,4] { { 1, 1, 1, 1 } }) - // %while = f32[4]{0} while(f32[1,4]{1,0} %constant.2), - // condition=%WhileCond, - // body=%WhileBody - // %broadcast = f32[2,4]{1,0} broadcast(f32[4]{0} %while), dimensions={0} - // ROOT %add = f32[2,4]{1,0} add( - // f32[2,4]{1,0} %transpose, f32[2,4]{1,0} %broadcast) - // } - - auto module = CreateNewModule(); - const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); - const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); - - // param != 0 - // Needs 17 bytes - auto cond_builder = HloComputation::Builder("WhileCond"); - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "cond_param")); - HloInstruction* zero_vector = - cond_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2<float>({{0, 0, 0, 0}}))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNe, cond_param, zero_vector)); - auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build()); - - // param - 1 - // Needs 16 bytes - auto body_builder = HloComputation::Builder("WhileBody"); - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, r1f32, "body_param")); - HloInstruction* one_vector = - body_builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2<float>({{1, 1, 1, 1}}))); - body_builder.AddInstruction(HloInstruction::CreateBinary( - r1f32, HloOpcode::kSubtract, body_param, one_vector)); - auto body_computation = module->AddEmbeddedComputation(body_builder.Build()); - - // transpose(matrix) + bcast(while) - auto builder = HloComputation::Builder(TestName()); - HloInstruction* while_init = - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2<float>({{1, 1, 1, 1}}))); - // Creates 16 bytes, ignoring subcomputations - HloInstruction* while_loop = - builder.AddInstruction(HloInstruction::CreateWhile( - r1f32, cond_computation, body_computation, while_init)); - - // Creates 32 bytes and frees 16 - HloInstruction* bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(r2f32, while_loop, {0})); - - HloInstruction* matrix = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>( - {{1.0, 2.0, 3.0, 4.0}, {1.0, 2.0, 3.0, 4.0}}))); - 
// Creates 32 bytes - HloInstruction* transpose = builder.AddInstruction( - HloInstruction::CreateTranspose(r2f32, matrix, {0, 1})); - - // Creates 32 bytes and frees 64 - HloInstruction* add = builder.AddInstruction( - HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, transpose, bcast)); - - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - }; - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); - // Verify that all instructions are in the sequence. - auto entry_computation = module->entry_computation(); - EXPECT_EQ(entry_computation->instruction_count(), - schedule.sequence(entry_computation).size()); - SequentialHloOrdering ordering(schedule); - // This schedule is an example of List's greedy heuristics being suboptimal. - // The while_loop is more expensive than transpose, so it would have been - // better to schedule it first, instead of during the busy time. - EXPECT_TRUE(ordering.ExecutesBefore(transpose, while_loop)); - EXPECT_TRUE(ordering.ExecutesBefore(transpose, bcast)); - EXPECT_TRUE(ordering.ExecutesBefore(bcast, add)); - EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); - - tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation; - memory_by_computation[cond_computation] = 17; - memory_by_computation[body_computation] = 16; - std::unique_ptr<TuplePointsToAnalysis> points_to_analysis = - TuplePointsToAnalysis::Run(module.get()).ValueOrDie(); - - // HeapSimulator doesn't account for subcomputations - EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), - *points_to_analysis, size_fn) - .ValueOrDie()); - // HeapSimulator accounts for subcomputations. The output buffer is aliased, - // so we don't double count. - EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), - *points_to_analysis, size_fn, &memory_by_computation) - .ValueOrDie()); -} - TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto builder = HloComputation::Builder(TestName()); const auto TUPLE_SIZE = 1; @@ -409,7 +290,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { EXPECT_EQ(module->entry_computation()->instruction_count(), schedule.sequence(module->entry_computation()).size()); - tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation; + absl::flat_hash_map<const HloComputation*, int64> memory_by_computation; memory_by_computation[cond_computation] = 17; memory_by_computation[body_computation] = 16; std::unique_ptr<TuplePointsToAnalysis> points_to_analysis = diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index b3949f3a6d..93e04eb3db 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -23,6 +23,8 @@ limitations under the License. 
#include <utility> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" @@ -144,7 +146,8 @@ void HloModule::ReplaceComputations( case HloOpcode::kCall: case HloOpcode::kMap: case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: { + case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: { HloComputation* new_arg = tensorflow::gtl::FindWithDefault( replacements, instruction->to_apply(), nullptr); if (new_arg != nullptr) { @@ -285,8 +288,8 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape); - tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map; - tensorflow::gtl::FlatMap<HloComputation*, int64> to_proto_id; + absl::flat_hash_map<int64, HloComputation*> computation_map; + absl::flat_hash_map<HloComputation*, int64> to_proto_id; std::vector<std::unique_ptr<HloComputation>> computations; HloComputation* entry = nullptr; for (const HloComputationProto& computation_proto : proto.computations()) { @@ -327,10 +330,10 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( // Because we didn't uniquify the names or the ids, double-check that the // instruction and computation names and ids are unique from the proto. - tensorflow::gtl::FlatSet<string> computation_names; - tensorflow::gtl::FlatSet<string> instruction_names; - tensorflow::gtl::FlatSet<int> computation_ids; - tensorflow::gtl::FlatSet<int> instruction_ids; + absl::flat_hash_set<string> computation_names; + absl::flat_hash_set<string> instruction_names; + absl::flat_hash_set<int> computation_ids; + absl::flat_hash_set<int> instruction_ids; for (HloComputation* computation : module->computations()) { TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) << "Computation name is not unique: " << computation->name(); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 3bc2d13781..735804e827 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -63,6 +63,7 @@ class HloModule { // tests). The versioned handle is used by the service in the compilation // cache. A default configuration is created for this module. explicit HloModule(const string& name, const HloModuleConfig& config); + virtual ~HloModule() {} // Adds an entry computation to the module. A module can only have one entry // computation. Returns a pointer to the newly added computation. @@ -87,6 +88,7 @@ class HloModule { const std::unordered_map<HloComputation*, HloComputation*>& replacements); const string& name() const { return name_; } + void set_name(string name) { name_ = std::move(name); } // Returns a deep copy of this module including all computations. 
std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const; @@ -255,7 +257,7 @@ class HloModule { std::unique_ptr<HloComputation> computation, bool is_entry, bool uniquify_identifiers); - const string name_; + string name_; HloModuleConfig config_; HloComputation* entry_computation_ = nullptr; std::vector<std::unique_ptr<HloComputation>> computations_; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.cc b/tensorflow/compiler/xla/service/hlo_module_dce.cc index f7be5cae22..31d26cc51e 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce.cc @@ -50,9 +50,7 @@ StatusOr<bool> RunWhileDCE(HloModule* module, HloLivenessAnalysis* liveness) { auto* while_body_root = while_body_comp->root_instruction(); if (!ShapeUtil::IsTuple(xla_while->shape()) || - while_body_root->opcode() != HloOpcode::kTuple || - while_body_comp->HasSideEffect() || - xla_while->while_condition()->HasSideEffect()) { + while_body_root->opcode() != HloOpcode::kTuple) { // Only run DCE on tuple-shaped while loops where body root is Tuple, // with no I/O instructions. VLOG(1) << "WhileDCE SKIP while: " << xla_while->ToString(); diff --git a/tensorflow/compiler/xla/service/hlo_module_dce.h b/tensorflow/compiler/xla/service/hlo_module_dce.h index 12ca2340a6..d472211d2a 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce.h +++ b/tensorflow/compiler/xla/service/hlo_module_dce.h @@ -28,7 +28,7 @@ namespace xla { // Sweeps through live instructions which cross computation boundaries (kWhile), // and removes code at dead shape indices. // -class HloModuleDCE : public HloPassInterface { +class HloModuleDCE : public HloModulePass { public: ~HloModuleDCE() override {} absl::string_view name() const override { return "hlo-module-dce"; } diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index 9c01862a4b..b4aac4c807 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -59,7 +59,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const { } /* static */ StatusOr<std::unique_ptr<HloModuleGroupMetadata>> -HloModuleGroupMetadata::Build(const std::vector<HloModule*>& modules) { +HloModuleGroupMetadata::Build(absl::Span<HloModule* const> modules) { auto metadata = absl::make_unique<HloModuleGroupMetadata>(modules); TF_RETURN_IF_ERROR(metadata->Build()); return std::move(metadata); @@ -392,22 +392,28 @@ Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1, if (!ContainsKey(companion_set_index_, instruction1) && !ContainsKey(companion_set_index_, instruction2)) { companion_sets_.push_back( - absl::make_unique<std::unordered_set<HloInstruction*>>()); + absl::make_unique<std::vector<HloInstruction*>>()); auto companion_set = companion_sets_.back().get(); - companion_set->insert(instruction1); - companion_set->insert(instruction2); + companion_set->push_back(instruction1); + companion_set->push_back(instruction2); companion_set_index_[instruction1] = companion_sets_.size() - 1; companion_set_index_[instruction2] = companion_sets_.size() - 1; } else if (!ContainsKey(companion_set_index_, instruction1)) { - companion_sets_[companion_set_index_[instruction2]]->insert(instruction1); + companion_sets_[companion_set_index_[instruction2]]->push_back( + instruction1); companion_set_index_[instruction1] = companion_set_index_[instruction2]; } else if 
(!ContainsKey(companion_set_index_, instruction2)) { - companion_sets_[companion_set_index_[instruction1]]->insert(instruction2); + companion_sets_[companion_set_index_[instruction1]]->push_back( + instruction2); companion_set_index_[instruction2] = companion_set_index_[instruction1]; } else if (companion_set_index_[instruction1] != companion_set_index_[instruction2]) { - companion_sets_[companion_set_index_[instruction1]]->insert( - Companions(instruction2).begin(), Companions(instruction2).end()); + // At any point while building the companion sets, each instruction belongs + // to at most 1 companion set, so the union of two companion sets is + // concatenating two disjoint sets. + absl::c_copy(Companions(instruction2), + std::back_inserter( + *companion_sets_[companion_set_index_[instruction1]])); int64 index_to_remove = companion_set_index_[instruction2]; for (HloInstruction* hlo : Companions(instruction2)) { companion_set_index_[hlo] = companion_set_index_[instruction1]; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 768b0c7eb3..928df0f5a7 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -22,6 +22,7 @@ limitations under the License. #include <unordered_set> #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -102,14 +102,14 @@ class HloModuleGroupMetadata { HloInstruction* recv_done = nullptr; }; - explicit HloModuleGroupMetadata(const std::vector<HloModule*>& modules) - : modules_(modules) {} + explicit HloModuleGroupMetadata(absl::Span<HloModule* const> modules) + : modules_(modules.begin(), modules.end()) {} ~HloModuleGroupMetadata() = default; // Build and return the metadata for the given modules. static StatusOr<std::unique_ptr<HloModuleGroupMetadata>> Build( - const std::vector<HloModule*>& modules); + absl::Span<HloModule* const> modules); // Returns true if the instruction is one of the 4 channel instructions (Send, // Recv, SendDone, RecvDone). @@ -169,14 +169,14 @@ class HloModuleGroupMetadata { // Returns the companion instructions for the given instruction. // // Precondition: IsCompanionWhile(instruction) is true. - const std::unordered_set<HloInstruction*>& Companions( + const std::vector<HloInstruction*>& Companions( const HloInstruction* instruction) const { CHECK_EQ(companion_set_index_.count(instruction), 1); return companion_set(companion_set_index_.at(instruction)); } // Returns the companion set at the given index. - const std::unordered_set<HloInstruction*>& companion_set(int64 index) const { + const std::vector<HloInstruction*>& companion_set(int64 index) const { CHECK_LT(index, companion_sets_.size()); return *companion_sets_[index]; } @@ -187,7 +187,7 @@ class HloModuleGroupMetadata { } // Returns the list of all companion sets in the HLO module group. 
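Switching the companion sets from std::unordered_set to std::vector above makes their iteration order, and everything derived from it, deterministic; because each instruction belongs to at most one set, taking the union of two sets is just concatenating two disjoint vectors. A standalone sketch of that merge step with toy types (ints standing in for HloInstruction*):

#include <cstdint>
#include <iterator>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"

using Instr = int;  // stand-in for HloInstruction*

// Appends set `src` onto set `dst` and repoints the index map; insertion
// order inside each set is preserved, so repeated runs give the same order.
void MergeCompanionSets(std::vector<std::vector<Instr>>* sets,
                        absl::flat_hash_map<Instr, int64_t>* set_index,
                        int64_t dst, int64_t src) {
  absl::c_copy((*sets)[src], std::back_inserter((*sets)[dst]));
  for (Instr instr : (*sets)[src]) {
    (*set_index)[instr] = dst;
  }
  (*sets)[src].clear();
}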
- const std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>>& + const std::vector<std::unique_ptr<std::vector<HloInstruction*>>>& companion_sets() const { return companion_sets_; } @@ -247,37 +247,36 @@ class HloModuleGroupMetadata { void DumpCollectedStats() const; // List of all companion instructions sets in the module. - std::vector<std::unique_ptr<std::unordered_set<HloInstruction*>>> - companion_sets_; + std::vector<std::unique_ptr<std::vector<HloInstruction*>>> companion_sets_; // Map from each companion while instruction to the index into companion_set_. - tensorflow::gtl::FlatMap<const HloInstruction*, int64> companion_set_index_; + absl::flat_hash_map<const HloInstruction*, int64> companion_set_index_; // Map from computation to the instruction using it (a kWhile, kConditional). - tensorflow::gtl::FlatMap<const HloComputation*, TrackedInstruction> + absl::flat_hash_map<const HloComputation*, TrackedInstruction> tracked_instructions_; // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of // communicating instructions within the proper called computation(s). - tensorflow::gtl::FlatMap<HloInstruction*, std::vector<HloInstruction*>> + absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>> tracked_instructions_comms_; // All channels in the module. std::vector<Channel> channels_; // Map from channel ids to the index in channels_. - tensorflow::gtl::FlatMap<int64, int64> channel_id_map_; + absl::flat_hash_map<int64, int64> channel_id_map_; // Map from all-reduce ids to the all reduce instructions. - tensorflow::gtl::FlatMap<int64, std::vector<HloInstruction*>> all_reduce_map_; + absl::flat_hash_map<int64, std::vector<HloInstruction*>> all_reduce_map_; // The maximum channel id used in the module group. int64 max_channel_id_ = -1; // The modules that this metadata was built from. - const std::vector<HloModule*>& modules_; + const std::vector<HloModule*> modules_; - tensorflow::gtl::FlatMap<HloModule*, std::unique_ptr<TuplePointsToAnalysis>> + absl::flat_hash_map<HloModule*, std::unique_ptr<TuplePointsToAnalysis>> points_to_analyses_; }; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_test.cc b/tensorflow/compiler/xla/service/hlo_module_group_test.cc index ebf790ba6f..b7b12cb72b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -137,6 +138,69 @@ ENTRY %entry (a: f32[]) -> f32[] { ::testing::ElementsAre(op::Parameter())); } +// Tests that the order of companion instructions in the companion set doesn't +// change across runs. +TEST_F(HloModuleGroupTest, ModuleGroupCompanionOrder) { + // A simple while loop template for core i sending to core i+1. 
+ constexpr char text[] = R"( +HloModule module_%d + +while_cond { + ROOT p = pred[] constant(true) +} + +while_body { + param = s32[] parameter(0) + token.s = token[] after-all() + token.r = token[] after-all() + send = (s32[], u32[], token[]) send(param, token.s), channel_id=%d + send-done = token[] send-done(send), channel_id=%d + recv = (s32[], u32[], token[]) recv(token.r), channel_id=%d + ROOT recv-done = (s32[], token[]) recv-done(recv), channel_id=%d +} + +ENTRY entry { + while_init = s32[] constant(1) + ROOT while = s32[] while(while_init), condition=while_cond, body=while_body +} +)"; + + // Try creating the module and the metadata kTrialCount times and check the + // companion instructions remain in the same order. + const int64 kTrialCount = 5; + const int64 kDeviceCount = 10; + std::vector<int64> companion_order; + + for (int64 t = 0; t < kTrialCount; ++t) { + HloModuleGroup group(TestName()); + for (int64 i = 0; i < kDeviceCount; ++i) { + const int64 send_channel = i; + const int64 recv_channel = i == 0 ? kDeviceCount - 1 : i - 1; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<HloModule> module, + ParseHloString(absl::StrFormat(text, i, send_channel, send_channel, + recv_channel, recv_channel))); + group.push_back(std::move(module)); + } + ASSERT_EQ(group.modules().size(), kDeviceCount); + + TF_ASSERT_OK_AND_ASSIGN(auto metadata, + HloModuleGroupMetadata::Build(group.modules())); + ASSERT_EQ(metadata->companion_sets().size(), 1); + + std::vector<int64> module_ids; + for (HloInstruction* companion : *metadata->companion_sets()[0]) { + module_ids.push_back(metadata->GetModuleId(companion->GetModule())); + } + + if (t == 0) { + companion_order = module_ids; + } else { + EXPECT_TRUE(absl::c_equal(companion_order, module_ids)); + } + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index d83ee71490..fddeb5f0a2 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include <string> #include <utility> +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -42,7 +42,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors( HloInstruction* instruction) { std::vector<HloInstruction*> predecessors; // Use a vector to avoid non-determinism. - tensorflow::gtl::FlatSet<HloInstruction*> unique; + absl::flat_hash_set<HloInstruction*> unique; // Adds to the unique predecessors list; if the predecessors is a companion // instruction, also add companion instructions; if the predecessors is a @@ -119,7 +119,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors( HloInstruction* instruction) { std::vector<HloInstruction*> successors; // Use a vector to avoid non-determinism. 
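// Schematically, the pattern GlobalPredecessors/GlobalSuccessors rely on: a
// vector fixes the output order (deterministic across runs), while a hash set
// answers membership queries only, so the hash table's iteration order never
// leaks into the result. std::unordered_set stands in for absl::flat_hash_set;
// AddUnique and the strings below are invented for illustration.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Appends `item` to `out` only the first time it is seen.
void AddUnique(const std::string& item, std::vector<std::string>* out,
               std::unordered_set<std::string>* seen) {
  if (seen->insert(item).second) {
    out->push_back(item);
  }
}

int main() {
  std::vector<std::string> predecessors;
  std::unordered_set<std::string> unique;
  for (const std::string& name : {"send", "recv", "send", "param", "recv"}) {
    AddUnique(name, &predecessors, &unique);
  }
  for (const std::string& name : predecessors) std::cout << name << '\n';
  // Always prints send, recv, param in first-seen order.
}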
- tensorflow::gtl::FlatSet<HloInstruction*> unique; + absl::flat_hash_set<HloInstruction*> unique; // Adds to the unique successors list; if the successor is a companion // instruction, also add companion instructions; if the successor is a diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h index 309c23045d..f21b44bcd9 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h @@ -20,6 +20,7 @@ limitations under the License. #include <memory> #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -87,7 +87,7 @@ class HloModuleGroupUtil { // * visit_state: map from each instruction to its visit state. // * visit_function: function called when each instruction group. // * root: the root instruction of the traversal. - using VisitStates = tensorflow::gtl::FlatMap<HloInstruction*, VisitState>; + using VisitStates = absl::flat_hash_map<HloInstruction*, VisitState>; Status VisitTopologicalOrder(VisitStates* visit_state, const VisitFunction& visit_function, HloInstruction* root); diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 2d4e38589f..4551a1c2e2 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -31,7 +31,7 @@ string HloOpcodeString(HloOpcode opcode) { } StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) { - static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>({ + static auto* opcode_map = new absl::flat_hash_map<string, HloOpcode>({ #define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \ {opcode_name, HloOpcode::enum_name}, HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY) diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index f1dc08bafa..23d41d91d6 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -92,14 +92,18 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, } bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { - // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b' - // is live into the module. + // Entry parameter should always be defined before other instructions. 
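// The hunk below implements this rule; condensed into a standalone predicate
// (Value and is_entry_parameter are stand-ins for the HloValue API, and the
// final return abbreviates the phi/ExecutesBefore logic that the real method
// falls through to):
#include <iostream>

struct Value {
  bool is_entry_parameter;  // defined by a kParameter of the entry computation
};

bool IsDefinedBefore(const Value& a, const Value& b) {
  if (b.is_entry_parameter) return false;  // nothing is defined before b
  if (a.is_entry_parameter) return true;   // a is defined before everything
  return false;  // ...phi handling and ExecutesBefore in the real code
}

int main() {
  Value param{true}, constant{false};
  std::cout << IsDefinedBefore(param, constant) << '\n';  // 1
  std::cout << IsDefinedBefore(constant, param) << '\n';  // 0
}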
const HloModule* module = b.defining_instruction()->parent()->parent(); if (b.defining_instruction()->parent() == module->entry_computation() && b.defining_instruction()->opcode() == HloOpcode::kParameter) { return false; } + if (a.defining_instruction()->parent() == module->entry_computation() && + a.defining_instruction()->opcode() == HloOpcode::kParameter) { + return true; + } + // Phi values require special handling. Because XLA does not have a phi // instruction, the definition instruction of the phis values are // placeholders: either the subcomputation parameter (body or condition) or @@ -316,7 +320,7 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const { for (auto predecessor : all) { if (predecessors_.at(computation) ->IsReachable(predecessor, instruction)) { - pieces.push_back(absl::StrFormat(" %s", predecessor->name())); + pieces.push_back(absl::StrFormat(" %s", predecessor->name())); } } } diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index b0361c3f02..66313492eb 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -20,6 +20,7 @@ limitations under the License. #include <string> #include <utility> +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -120,8 +120,8 @@ class PredecessorHloOrdering : public HloOrdering { // predecessors. An instruction is an element of its own predecessor set. // // Subclasses should fill this in to define the desired ordering. - tensorflow::gtl::FlatMap<const HloComputation*, - std::unique_ptr<HloReachabilityMap>> + absl::flat_hash_map<const HloComputation*, + std::unique_ptr<HloReachabilityMap>> predecessors_; }; @@ -204,7 +204,7 @@ class SequentialHloOrdering : public HloOrdering { // this map so more than one instruction may have the same position // value. This is not a problem because ExecutesBefore also verifies // instructions are in the same computation. - tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_; + absl::flat_hash_map<const HloInstruction*, int> order_position_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 00970bcda3..b045adc964 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -174,6 +174,26 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param)); } +TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) { + // Entry parameter should always be defined before other instruction. 
+ auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0))); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + DependencyHloOrdering ordering(module.get()); + EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(param), + dataflow->GetValueDefinedAt(constant))); + EXPECT_TRUE(!ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant), + dataflow->GetValueDefinedAt(param))); +} + TEST_F(HloOrderingTest, ValuesInWhileComputations) { // Tests the ordering of values (defined by dataflow analysis) in the body and // condition of a while instruction. HLO code: diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 11caa89c54..dd62988bcc 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -64,14 +64,11 @@ class HloParser { public: using LocTy = HloLexer::LocTy; - explicit HloParser(absl::string_view str, const HloModuleConfig& config) - : lexer_(str), config_(config) {} + explicit HloParser(absl::string_view str) : lexer_(str) {} - // Runs the parser. Returns false if an error occurred. - bool Run(); - - // Returns the parsed HloModule. - std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); } + // Runs the parser and constructs the resulting HLO in the given (empty) + // HloModule. Returns false if an error occurred. + Status Run(HloModule* module); // Returns the error information. string GetError() const { return StrJoin(error_, "\n"); } @@ -82,28 +79,37 @@ class HloParser { StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly(); StatusOr<PaddingConfig> ParsePaddingConfigOnly(); - // Stand-alone parsing utility for a single instruction worth of text. - Status ParseSingleInstruction(HloComputation::Builder* builder, - string* root_name); - private: - // Locates an instruction with the given name in the instruction_pool_ or + using InstrNameTable = + std::unordered_map<string, std::pair<HloInstruction*, LocTy>>; + + // Returns the map from the instruction name to the instruction itself and its + // location in the current scope. + InstrNameTable& current_name_table() { return scoped_name_tables_.back(); } + + // Locates an instruction with the given name in the current_name_table() or // returns nullptr. // - // If the missing_instruction_hook_ is registered and a "shape" is provided, - // the hook will be called and may satisfy the request for the given - // instruction. This is useful when we reify parameters as they're resolved; - // i.e. for ParseSingleInstruction. + // When the name is not found or name is empty, if create_missing_instruction_ + // hook is registered and a "shape" is provided, the hook will be called to + // create an instruction. This is useful when we reify parameters as they're + // resolved; i.e. for ParseSingleInstruction. std::pair<HloInstruction*, LocTy>* FindInstruction( const string& name, const optional<Shape>& shape = nullopt); + // Parse a single instruction worth of text. + bool ParseSingleInstruction(HloModule* module); + // ParseXXX returns false if an error occurred. 
- bool ParseHloModule(); - bool ParseComputations(); + bool ParseHloModule(HloModule* module); + + bool ParseComputations(HloModule* module); bool ParseComputation(HloComputation** entry_computation); - bool ParseInstructionList(HloComputation::Builder* builder, - string* root_name); + bool ParseInstructionList(HloComputation** computation, + const string& computation_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); + bool ParseInstruciontRhs(HloComputation::Builder* builder, const string& name, + LocTy name_loc); bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(Literal* literal, const Shape& shape); bool ParseTupleLiteral(Literal* literal, const Shape& shape); @@ -284,25 +290,47 @@ class HloParser { bool AddComputation(const string& name, HloComputation* computation, LocTy name_loc); - // The map from the instruction/computation name to the - // instruction/computation itself and it's location. This does not own the - // pointers. - std::unordered_map<string, std::pair<HloInstruction*, LocTy>> - instruction_pool_; + HloLexer lexer_; + + // A stack for the instruction names. The top of the stack stores the + // instruction name table for the current scope. + // + // A instruction's name is unique among its scope (i.e. its parent + // computation), but it's not necessarily unique among all computations in the + // module. When there are multiple levels of nested computations, the same + // name could appear in both an outer computation and an inner computation. So + // we need a stack to make sure a name is only visible within its scope, + std::vector<InstrNameTable> scoped_name_tables_; + + // A helper class which pushes and pops to an InstrNameTable stack via RAII. + class Scope { + public: + explicit Scope(std::vector<InstrNameTable>* scoped_name_tables) + : scoped_name_tables_(scoped_name_tables) { + scoped_name_tables_->emplace_back(); + } + ~Scope() { scoped_name_tables_->pop_back(); } + + private: + std::vector<InstrNameTable>* scoped_name_tables_; + }; + + // Map from the computation name to the computation itself and its location. std::unordered_map<string, std::pair<HloComputation*, LocTy>> computation_pool_; - HloLexer lexer_; - std::unique_ptr<HloModule> module_; std::vector<std::unique_ptr<HloComputation>> computations_; - const HloModuleConfig config_; std::vector<string> error_; - // Function that gets invoked when we try to resolve an instruction - // instruction_pool_ but fail to do so. - std::function<std::pair<HloInstruction*, LocTy>*(string, - const optional<Shape>&)> - missing_instruction_hook_; + // When an operand name cannot be resolved, this function is called to create + // a parameter instruction with the given name and shape. It registers the + // name, instruction, and a placeholder location in the name table. It returns + // the newly-created instruction and the placeholder location. If `name` is + // empty, this should create the parameter with a generated name. This is + // supposed to be set and used only in ParseSingleInstruction. 
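// A minimal sketch of the scoped name table mechanism declared above: one
// table per computation scope, pushed and popped by an RAII helper, with
// lookups confined to the innermost table. ScopedNames, NameTable, and the
// dummy int values are invented for illustration; the real parser maps names
// to std::pair<HloInstruction*, LocTy>.
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

using NameTable = std::unordered_map<std::string, int>;

class ScopedNames {
 public:
  // RAII helper mirroring HloParser::Scope: push a fresh table on
  // construction, pop it on destruction.
  class Scope {
   public:
    explicit Scope(std::vector<NameTable>* tables) : tables_(tables) {
      tables_->emplace_back();
    }
    ~Scope() { tables_->pop_back(); }

   private:
    std::vector<NameTable>* tables_;
  };

  NameTable& current() { return tables_.back(); }
  std::vector<NameTable>* tables() { return &tables_; }

 private:
  std::vector<NameTable> tables_;
};

int main() {
  ScopedNames names;
  ScopedNames::Scope outer(names.tables());
  names.current()["param"] = 1;
  {
    ScopedNames::Scope inner(names.tables());
    names.current()["param"] = 2;  // same name, different (nested) scope
    assert(names.current().at("param") == 2);
  }
  assert(names.current().at("param") == 1);  // the inner table was popped
  return 0;
}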
+ std::function<std::pair<HloInstruction*, LocTy>*(const string& name, + const Shape& shape)> + create_missing_instruction_; }; bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) { @@ -349,24 +377,50 @@ bool HloParser::TokenError(absl::string_view msg) { return Error(lexer_.GetLoc(), msg); } -bool HloParser::Run() { +Status HloParser::Run(HloModule* module) { lexer_.Lex(); - return ParseHloModule(); + if (lexer_.GetKind() == TokKind::kw_HloModule) { + // This means that the text contains a full HLO module. + if (!ParseHloModule(module)) { + return InvalidArgument( + "Syntax error when trying to parse the text as a HloModule:\n%s", + GetError()); + } + return Status::OK(); + } + // This means that the text is a single HLO instruction. + if (!ParseSingleInstruction(module)) { + return InvalidArgument( + "Syntax error when trying to parse the text as a single " + "HloInstruction:\n%s", + GetError()); + } + return Status::OK(); } std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction( const string& name, const optional<Shape>& shape) { - std::pair<HloInstruction*, LocTy>* instr = - tensorflow::gtl::FindOrNull(instruction_pool_, name); + std::pair<HloInstruction*, LocTy>* instr = nullptr; + if (!name.empty()) { + instr = tensorflow::gtl::FindOrNull(current_name_table(), name); + } + // Potentially call the missing instruction hook. - if (instr == nullptr && missing_instruction_hook_ != nullptr) { - return missing_instruction_hook_(name, shape); + if (instr == nullptr && create_missing_instruction_ != nullptr && + scoped_name_tables_.size() == 1) { + if (!shape.has_value()) { + Error(lexer_.GetLoc(), + "Operand had no shape in HLO text; cannot create parameter for " + "single-instruction module."); + return nullptr; + } + return create_missing_instruction_(name, *shape); } return instr; } // ::= 'HloModule' name computations -bool HloParser::ParseHloModule() { +bool HloParser::ParseHloModule(HloModule* module) { if (lexer_.GetKind() != TokKind::kw_HloModule) { return TokenError("expects HloModule"); } @@ -385,22 +439,20 @@ bool HloParser::ParseHloModule() { return false; } - module_ = absl::make_unique<HloModule>(name, config_); - - if (!ParseComputations()) { + module->set_name(name); + if (!ParseComputations(module)) { return false; } if (is_scheduled.has_value() && *is_scheduled) { - TF_CHECK_OK( - module_->set_schedule(ScheduleFromInstructionOrder(module_.get()))); + TF_CHECK_OK(module->set_schedule(ScheduleFromInstructionOrder(module))); } return true; } // computations ::= (computation)+ -bool HloParser::ParseComputations() { +bool HloParser::ParseComputations(HloModule* module) { HloComputation* entry_computation = nullptr; do { if (!ParseComputation(&entry_computation)) { @@ -416,21 +468,20 @@ bool HloParser::ParseComputations() { if ((entry_computation != nullptr && computations_[i].get() != entry_computation) || (entry_computation == nullptr && i != computations_.size() - 1)) { - module_->AddEmbeddedComputation(std::move(computations_[i])); + module->AddEmbeddedComputation(std::move(computations_[i])); continue; } - auto computation = - module_->AddEntryComputation(std::move(computations_[i])); + auto computation = module->AddEntryComputation(std::move(computations_[i])); // The parameters and result layouts were set to default layout. Here we // set the layouts to what the hlo text says. 
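// Returning to FindInstruction and the create_missing_instruction_ hook above:
// a self-contained sketch of the resolve-or-synthesize pattern, where an
// unknown (or empty) operand name with a known shape is turned into a fresh
// parameter. Resolver, Instr, and shapes-as-strings are invented stand-ins;
// the real hook creates kParameter instructions and generates "_N" names for
// unnamed operands.
#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>

struct Instr {
  std::string name;
  std::string shape;
};

class Resolver {
 public:
  Instr* Find(const std::string& name,
              const std::optional<std::string>& shape) {
    if (!name.empty()) {
      auto it = table_.find(name);
      if (it != table_.end()) return &it->second;
    }
    // Unknown or unnamed operand: defer to the hook if a shape is available.
    if (create_missing_ && shape.has_value()) {
      return create_missing_(name, *shape);
    }
    return nullptr;
  }

  std::function<Instr*(const std::string&, const std::string&)> create_missing_;
  std::unordered_map<std::string, Instr> table_;
};

int main() {
  Resolver resolver;
  int param_count = 0;
  resolver.create_missing_ = [&](const std::string& name,
                                 const std::string& shape) -> Instr* {
    const std::string new_name =
        name.empty() ? "_" + std::to_string(param_count) : name;
    ++param_count;
    resolver.table_[new_name] = Instr{new_name, shape};
    return &resolver.table_[new_name];
  };

  Instr* p = resolver.Find("", std::string("f32[2,4]"));  // unnamed operand
  std::cout << p->name << " : " << p->shape << '\n';      // _0 : f32[2,4]
}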
for (int p = 0; p < computation->num_parameters(); p++) { const Shape& param_shape = computation->parameter_instruction(p)->shape(); - TF_CHECK_OK(module_->mutable_entry_computation_layout() + TF_CHECK_OK(module->mutable_entry_computation_layout() ->mutable_parameter_layout(p) ->CopyLayoutFromShape(param_shape)); } const Shape& result_shape = computation->root_instruction()->shape(); - TF_CHECK_OK(module_->mutable_entry_computation_layout() + TF_CHECK_OK(module->mutable_entry_computation_layout() ->mutable_result_layout() ->CopyLayoutFromShape(result_shape)); } @@ -447,7 +498,6 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { if (!ParseName(&name)) { return false; } - auto builder = absl::make_unique<HloComputation::Builder>(name); LocTy shape_loc = nullptr; Shape shape; @@ -455,40 +505,21 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { return false; } - string root_name; - if (!ParseInstructionList(builder.get(), &root_name)) { + HloComputation* computation = nullptr; + if (!ParseInstructionList(&computation, name)) { return false; } - std::pair<HloInstruction*, LocTy>* root_node = FindInstruction(root_name); - // This means some instruction was marked as ROOT but we didn't find it in the - // pool, which should not happen. - if (!root_name.empty() && root_node == nullptr) { - LOG(FATAL) << "instruction " << root_name - << " was marked as ROOT but the parser has not seen it before"; - } - - HloInstruction* root = root_node == nullptr ? nullptr : root_node->first; - // Now root can be either an existing instruction or a nullptr. If it's a - // nullptr, the implementation of Builder will set the last instruction as - // root instruction. - computations_.emplace_back(builder->Build(root)); - HloComputation* computation = computations_.back().get(); - - if (!root) { - root = computation->root_instruction(); - } else { - CHECK_EQ(root, computation->root_instruction()); - } - // If param_list_to_shape was present, check compatibility. 
- if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) { + if (shape_loc != nullptr && + !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) { return Error( shape_loc, - StrCat("Shape of computation ", name, ", ", - ShapeUtil::HumanString(shape), - ", is not compatible with that of its root instruction ", - root_name, ", ", ShapeUtil::HumanString(root->shape()))); + StrCat( + "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape), + ", is not compatible with that of its root instruction ", + computation->root_instruction()->name(), ", ", + ShapeUtil::HumanString(computation->root_instruction()->shape()))); } if (is_entry_computation) { @@ -497,43 +528,62 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { } *entry_computation = computation; } - instruction_pool_.clear(); return AddComputation(name, computation, name_loc); } // instruction_list ::= '{' instruction_list1 '}' // instruction_list1 ::= (instruction)+ -bool HloParser::ParseInstructionList(HloComputation::Builder* builder, - string* root_name) { +bool HloParser::ParseInstructionList(HloComputation** computation, + const string& computation_name) { + Scope scope(&scoped_name_tables_); + HloComputation::Builder builder(computation_name); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of instruction list.")) { return false; } + string root_name; do { - if (!ParseInstruction(builder, root_name)) { + if (!ParseInstruction(&builder, &root_name)) { return false; } } while (lexer_.GetKind() != TokKind::kRbrace); - return ParseToken(TokKind::kRbrace, - "expects '}' at the end of instruction list."); + if (!ParseToken(TokKind::kRbrace, + "expects '}' at the end of instruction list.")) { + return false; + } + HloInstruction* root = nullptr; + if (!root_name.empty()) { + std::pair<HloInstruction*, LocTy>* root_node = + tensorflow::gtl::FindOrNull(current_name_table(), root_name); + + // This means some instruction was marked as ROOT but we didn't find it in + // the pool, which should not happen. + if (root_node == nullptr) { + LOG(FATAL) << "instruction " << root_name + << " was marked as ROOT but the parser has not seen it before"; + } + root = root_node->first; + } + + // Now root can be either an existing instruction or a nullptr. If it's a + // nullptr, the implementation of Builder will set the last instruction as + // the root instruction. + computations_.emplace_back(builder.Build(root)); + *computation = computations_.back().get(); + return true; } // instruction ::= ('ROOT')? 
name '=' shape opcode operands (attribute)* bool HloParser::ParseInstruction(HloComputation::Builder* builder, string* root_name) { string name; - Shape shape; - HloOpcode opcode; - std::vector<HloInstruction*> operands; - LocTy maybe_root_loc = lexer_.GetLoc(); bool is_root = EatIfPresent(TokKind::kw_ROOT); const LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name) || - !ParseToken(TokKind::kEqual, "expects '=' in instruction") || - !ParseShape(&shape) || !ParseOpcode(&opcode)) { + !ParseToken(TokKind::kEqual, "expects '=' in instruction")) { return false; } @@ -544,6 +594,19 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, *root_name = name; } + return ParseInstruciontRhs(builder, name, name_loc); +} + +bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, + const string& name, LocTy name_loc) { + Shape shape; + HloOpcode opcode; + std::vector<HloInstruction*> operands; + + if (!ParseShape(&shape) || !ParseOpcode(&opcode)) { + return false; + } + // Add optional attributes. std::unordered_map<string, AttrConfig> attrs; optional<OpSharding> sharding; @@ -1274,11 +1337,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kCustomCall: { optional<string> custom_call_target; + optional<string> opaque; optional<Window> window; optional<ConvolutionDimensionNumbers> dnums; optional<int64> feature_group_count; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; + attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/false, AttrTy::kConvolutionDimensionNumbers, &dnums}; @@ -1287,8 +1352,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( - shape, operands, *custom_call_target)); + instruction = builder->AddInstruction( + HloInstruction::CreateCustomCall(shape, operands, *custom_call_target, + opaque.has_value() ? *opaque : "")); if (window.has_value()) { instruction->set_window(*window); } @@ -2151,7 +2217,20 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { } } if (!ParseName(&name)) { - return false; + // When parsing a single instruction (as opposed to a whole module), an + // HLO may have one or more operands with a shape but no name: + // + // foo = add(f32[10], f32[10]) + // + // create_missing_instruction_ is always non-null when parsing a single + // instruction, and is responsible for creating kParameter instructions + // for these operands. + if (shape.has_value() && create_missing_instruction_ != nullptr && + scoped_name_tables_.size() == 1) { + name = ""; + } else { + return false; + } } std::pair<HloInstruction*, LocTy>* instruction = FindInstruction(name, shape); @@ -2304,9 +2383,17 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kHloComputation: { - HloComputation* result; - if (!ParseComputationName(&result)) { - return false; + HloComputation* result = nullptr; + if (lexer_.GetKind() == TokKind::kLbrace) { + // This means it is a nested computation. + if (!ParseInstructionList(&result, /*computation_name=*/"_")) { + return false; + } + } else { + // This means it is a computation name. 
+ if (!ParseComputationName(&result)) { + return false; + } } static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result); return true; @@ -3139,7 +3226,7 @@ bool HloParser::EatIfPresent(TokKind kind) { bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, LocTy name_loc) { - auto result = instruction_pool_.insert({name, {instruction, name_loc}}); + auto result = current_name_table().insert({name, {instruction, name_loc}}); if (!result.second) { Error(name_loc, StrCat("instruction already exists: ", name)); return Error(/*loc=*/result.first->second.second, @@ -3209,91 +3296,96 @@ StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() { return padding_config; } -Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, - string* root_name) { - TF_RET_CHECK(missing_instruction_hook_ == nullptr); +bool HloParser::ParseSingleInstruction(HloModule* module) { + if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) { + LOG(FATAL) << "Parser state is not clean. Please do not call any other " + "methods before calling ParseSingleInstruction."; + } + HloComputation::Builder builder(module->name()); // The missing instruction hook we register creates the shaped instruction on // the fly as a parameter and returns it. int64 parameter_count = 0; - missing_instruction_hook_ = - [this, builder, ¶meter_count]( - string name, - const optional<Shape>& shape) -> std::pair<HloInstruction*, LocTy>* { - if (!shape.has_value()) { - Error(lexer_.GetLoc(), - StrCat("Operand ", name, - " had no shape in HLO text; cannot create parameter for " - "single-instruction module.")); - return nullptr; - } - HloInstruction* parameter = builder->AddInstruction( - HloInstruction::CreateParameter(parameter_count++, *shape, name)); - instruction_pool_[name] = {parameter, lexer_.GetLoc()}; - return tensorflow::gtl::FindOrNull(instruction_pool_, name); + create_missing_instruction_ = + [this, &builder, ¶meter_count]( + const string& name, + const Shape& shape) -> std::pair<HloInstruction*, LocTy>* { + string new_name = name.empty() ? StrCat("_", parameter_count) : name; + HloInstruction* parameter = builder.AddInstruction( + HloInstruction::CreateParameter(parameter_count++, shape, new_name)); + current_name_table()[new_name] = {parameter, lexer_.GetLoc()}; + return tensorflow::gtl::FindOrNull(current_name_table(), new_name); }; - // Prime the lexer. - lexer_.Lex(); - // Parse the instruction with the registered hook. - if (!ParseInstruction(builder, root_name)) { - return InvalidArgument("Syntax error:\n%s", GetError()); + Scope scope(&scoped_name_tables_); + if (CanBeShape()) { + // This means that the instruction's left-hand side is probably omitted, + // e.g. + // + // f32[10] fusion(...), calls={...} + if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) { + return false; + } + } else { + // This means that the instruction's left-hand side might exist, e.g. 
+ // + // foo = f32[10] fusion(...), calls={...} + string root_name; + if (!ParseInstruction(&builder, &root_name)) { + return false; + } } - return Status::OK(); + + module->AddEntryComputation(builder.Build()); + for (auto& comp : computations_) { + module->AddEmbeddedComputation(std::move(comp)); + } + return true; } } // namespace StatusOr<std::unique_ptr<HloModule>> ParseHloString( absl::string_view str, const HloModuleConfig& config) { - HloParser parser(str, config); - if (!parser.Run()) { - return InvalidArgument("Syntax error:\n%s", parser.GetError()); - } - return parser.ConsumeHloModule(); + auto module = absl::make_unique<HloModule>(/*name=*/"_", config); + HloParser parser(str); + TF_RETURN_IF_ERROR(parser.Run(module.get())); + return std::move(module); } StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) { - HloModuleConfig config; - return ParseHloString(str, config); + auto module = absl::make_unique<HloModule>(/*name=*/"_", HloModuleConfig()); + HloParser parser(str); + TF_RETURN_IF_ERROR(parser.Run(module.get())); + return std::move(module); } -StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule( - absl::string_view str, absl::string_view name) { - HloModuleConfig config; - HloParser parser(str, config); - auto builder = absl::make_unique<HloComputation::Builder>(string(name)); - string root_name; - TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name)); - std::unique_ptr<HloComputation> computation = builder->Build(); - auto module = absl::make_unique<HloModule>(string(name), config); - module->AddEntryComputation(std::move(computation)); - return std::move(module); +Status ParseHloString(absl::string_view str, HloModule* module) { + TF_RET_CHECK(module->computation_count() == 0); + HloParser parser(str); + TF_RETURN_IF_ERROR(parser.Run(module)); + return Status::OK(); } StatusOr<HloSharding> ParseSharding(absl::string_view str) { - HloModuleConfig config; - HloParser parser(str, config); + HloParser parser(str); return parser.ParseShardingOnly(); } StatusOr<Window> ParseWindow(absl::string_view str) { - HloModuleConfig config; - HloParser parser(str, config); + HloParser parser(str); return parser.ParseWindowOnly(); } StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbers( absl::string_view str) { - HloModuleConfig config; - HloParser parser(str, config); + HloParser parser(str); return parser.ParseConvolutionDimensionNumbersOnly(); } StatusOr<PaddingConfig> ParsePaddingConfig(absl::string_view str) { - HloModuleConfig config; - HloParser parser(str, config); + HloParser parser(str); return parser.ParsePaddingConfigOnly(); } diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 1882a184da..81eeb9f13b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -30,18 +30,18 @@ namespace xla { // For details about the syntax accepted by this parser, see // g3doc/hlo_parser.md. -// The api of the hlo parser. Given a string in the HloModule::ToString() -// format, parses the string and creates a HloModule with the given config. +// Given a string in the HloModule::ToString() format, parses the string and +// creates a HloModule with the given config. 
StatusOr<std::unique_ptr<HloModule>> ParseHloString( absl::string_view str, const HloModuleConfig& config); -// Parses the text for a single HLO operation into an HLO module with a function -// that runs that operation (with the same parameters) as its entry computation. -StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule( - absl::string_view str, absl::string_view name = "single_op"); +// Given a string in the HloModule::ToString() format, parses the string and +// builds the HloModule in place at the given module pointer. 'module' must +// point to an empty module (no computations). +Status ParseHloString(absl::string_view str, HloModule* module); -// The api of the hlo parser. Given a string in the HloModule::ToString() -// format, parses the string and creates a HloModule with default config. +// Given a string in the HloModule::ToString() format, parses the string and +// creates a HloModule with default config. StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str); // Parses the result of HloSharding::ToString(), e.g. "{replicated}". diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index cca50fab54..255123d331 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1004,6 +1004,18 @@ ENTRY CustomCall { )" }, +// CustomCall with opaque value. +{ +"CustomCallWithOpaque", +R"(HloModule custom_call + +ENTRY CustomCall { + constant = f32[1]{0} constant({12345}) + ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", opaque="this string is opaque" +} + +)" +}, // Variables with non-default names { "NonDefaultNames", @@ -1151,49 +1163,80 @@ ENTRY Sort { // clang-format on } -class HloParserTest : public ::testing::Test, - public ::testing::WithParamInterface<TestData> { +// The test class for those tests defined above which round-trip through the +// parser and ToString is templatized on two bool parameters: +// +// short_form : used for the "short" test cases which use the ShortParsable +// output form. +// proto_round_trip : whether the module should also be round-tripped through +// HloProto form. This provides much better coverage for the proto +// serialization/deserialization. +// +// The proto_round_trip=true case also technically covers the Parser->ToString +// roundtrip as well, but separating out the Parser->ToString roundtrip as its +// own test provides better isolation and could conceivably catch weirdo bugs +// which are hidden by interaction between the textual and proto roundtripping. +template <bool short_form, bool proto_round_trip> +class HloParameterizedParserTest + : public ::testing::Test, + public ::testing::WithParamInterface<TestData> { protected: - static void ExpectHasSubstr(string_view s, string_view expected) { - EXPECT_TRUE(absl::StrContains(s, expected)) - << "'" << s << "' does not contain '" << expected << "'"; - } - // Expects "ToString(ParseHloString(string)) == string", that is, parses the // string, asserts that it succeeded, stringifies the parsed module, and // checks that the it equals the original string. 
void ExpectEqual() { const string& original = GetParam().module_string; - auto result = ParseHloString(original); - TF_ASSERT_OK(result.status()); - EXPECT_EQ(original, result.ValueOrDie()->ToString( - HloPrintOptions().set_print_large_constants(true))); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(original)); + if (proto_round_trip) { + TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto( + module->ToProto(), module->config())); + } + if (short_form) { + EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable())); + } else { + EXPECT_EQ( + original, + module->ToString(HloPrintOptions().set_print_large_constants(true))); + } } }; -class HloParserShortTest : public HloParserTest { - protected: - void ExpectEqualShort() { - const string& original = GetParam().module_string; - auto result = ParseHloString(original); - TF_ASSERT_OK(result.status()); - EXPECT_EQ(original, - result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); - } -}; +// These using shenanigans are required because the TEST_P macro doesn't like +// template instantiations which contain commas. +using HloParserTestLong = HloParameterizedParserTest<false, false>; +using HloParserTestLongProto = HloParameterizedParserTest<false, true>; +using HloParserTestShort = HloParameterizedParserTest<true, false>; +using HloParserTestShortProto = HloParameterizedParserTest<true, true>; -TEST_P(HloParserTest, Run) { ExpectEqual(); } +TEST_P(HloParserTestLong, Run) { ExpectEqual(); } +TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); } +TEST_P(HloParserTestShort, Run) { ExpectEqual(); } +TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); } -TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); } - -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestLong, ::testing::ValuesIn(CreateTestCases()), TestDataToString); - -INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, + HloParserTestLongProto, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestShort, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, + HloParserTestShortProto, ::testing::ValuesIn(CreateShortTestCases()), TestDataToString); +class HloParserTest : public ::testing::Test { + protected: + static void ExpectHasSubstr(string_view s, string_view expected) { + EXPECT_TRUE(absl::StrContains(s, expected)) + << "'" << s << "' does not contain '" << expected << "'"; + } +}; + TEST_F(HloParserTest, Empty) { const string original = ""; auto result = ParseHloString(original); @@ -1261,7 +1304,7 @@ TEST_F(HloParserTest, MoreConstants) { ENTRY %SelectScalarS32True.v4 () -> s32[] { %constant.2 = pred[] constant(true) - %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4} + %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4} %constant = s32[] constant(42) %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) } @@ -1720,6 +1763,25 @@ ENTRY entry { "was parsing 8:39: error: instruction does not exist: aparam"); } +TEST_F(HloParserTest, SameNameDiffComputations) { + const string original = R"(HloModule same_names: +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT result = f32[] add(p0, p1) +} + +ENTRY 
ReduceR3ToR2 { + p0 = f32[8,16,256]{2,1,0} parameter(0) + p1 = f32[] constant(0) + ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original)); + ASSERT_NE(module->entry_computation(), nullptr); + EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); +} + TEST_F(HloParserTest, ParseSharding) { const string original = "{maximal device=42}"; TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); @@ -1773,27 +1835,142 @@ TEST(HloParserSingleOpTest, SingleOp) { const string text = "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, " "f32[2,4]{1,0} %x)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Parameter(0), op::Parameter(1))); } -TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) { +TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) { + const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)"; + StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(text); + ASSERT_TRUE(!module.status().ok()); + LOG(INFO) << "Status: " << module.status(); + EXPECT_THAT(module.status().ToString(), + ::testing::HasSubstr("expects '=' in instruction")); +} + +TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) { const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)"; - StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text); + StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(text); ASSERT_TRUE(!module.status().ok()); LOG(INFO) << "Status: " << module.status(); - EXPECT_THAT( - module.status().ToString(), - ::testing::HasSubstr("Operand broadcast had no shape in HLO text")); + EXPECT_THAT(module.status().ToString(), + ::testing::HasSubstr("Operand had no shape in HLO text")); +} + +TEST(HloParserSingleOpTest, SingleOpNoNames) { + const string text = + "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Parameter(0), op::Parameter(1))); +} + +TEST(HloParserSingleOpTest, CanonicalOp) { + const string text = "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Parameter(0), op::Parameter(1))); + EXPECT_EQ( + computation->root_instruction()->ToString(HloPrintOptions::Canonical()), + text); +} + +TEST(HloParserSingleOpTest, CanonicalOpWithNested) { + const string text = + R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, 
rhs_contracting_dims={0} + } +}, body= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_EQ( + computation->root_instruction()->ToString(HloPrintOptions::Canonical()), + text); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested) { + const string text = + R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls= +{ + %param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0) + %param_1 = f32[2]{0} parameter(1) + %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1} + ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Fusion(op::Parameter(0), op::Parameter(1))); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + result = f32[] add(f32[] x, f32[] y) +})"; + auto status = ParseHloString(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("does not exist: x")); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + f32[] add(f32[] x, f32[] y) +})"; + auto status = ParseHloString(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name")); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + result = f32[] add(f32[], f32[]) +})"; + auto status = ParseHloString(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name")); } TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { const string text = R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/hlo_pass_interface.h b/tensorflow/compiler/xla/service/hlo_pass_interface.h index f1ad0f9b01..fdaac34386 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_interface.h +++ b/tensorflow/compiler/xla/service/hlo_pass_interface.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_INTERFACE_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_module_group.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -25,15 +26,45 @@ limitations under the License. namespace xla { // Base class for HLO passes. These are used with the HloPassPipeline to -// organize a sequence of passes. +// organize a sequence of passes. An HLO pass should not extend this class +// directly; it should extend HloModulePass or HloModuleGroupPass. class HloPassInterface { public: virtual ~HloPassInterface() = default; virtual absl::string_view name() const = 0; - // Run the pass on the given HLO module. Return whether it modified the + // Run the pass on the given HLO module. Returns whether it modified the // module. virtual StatusOr<bool> Run(HloModule* module) = 0; + + // Run the pass on the given HLO module group. Returns whether it modified the + // module group. Ideally, the module group variant would be named "Run" as + // well, but C++ does not handle overloaded virtual methods well. + virtual StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) = 0; +}; + +// Base class for passes which are module-scoped. +class HloModulePass : public HloPassInterface { + public: + // Runs the pass on a module group by iterating through each module in the + // group. + StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override { + bool changed = false; + for (HloModule* module : module_group->modules()) { + TF_ASSIGN_OR_RETURN(bool module_changed, Run(module)); + changed |= module_changed; + } + return changed; + }; +}; + +// Base class for passes which are module-group scoped. These passes cannot run +// on an HLO module. +class HloModuleGroupPass : public HloPassInterface { + public: + StatusOr<bool> Run(HloModule* module) override { + return InternalError("Module group pass cannot be run on a module"); + } }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 6e4ed0de62..5e004ce78a 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -17,7 +17,8 @@ limitations under the License. #include <functional> -#include "absl/strings/str_cat.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" @@ -25,112 +26,131 @@ limitations under the License. 
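// A stripped-down model of the pass interface split introduced above:
// module-scoped passes get a module-group implementation for free by visiting
// each module in the group. Module, ModuleGroup, and RenamePass are toy
// stand-ins, and plain bool replaces StatusOr<bool>.
#include <iostream>
#include <string>
#include <vector>

struct Module { std::string name; };
struct ModuleGroup { std::vector<Module*> modules; };

class PassInterface {
 public:
  virtual ~PassInterface() = default;
  virtual bool Run(Module* module) = 0;
  virtual bool RunOnModuleGroup(ModuleGroup* group) = 0;
};

class ModulePass : public PassInterface {
 public:
  // Runs the pass on a group by iterating through each module in the group.
  bool RunOnModuleGroup(ModuleGroup* group) override {
    bool changed = false;
    for (Module* m : group->modules) changed |= Run(m);
    return changed;
  }
};

class RenamePass : public ModulePass {
 public:
  bool Run(Module* m) override {
    if (m->name == "foo") {
      m->name = "bar";
      return true;
    }
    return false;
  }
};

int main() {
  Module a{"foo"}, b{"baz"};
  ModuleGroup group{{&a, &b}};
  RenamePass pass;
  std::cout << pass.RunOnModuleGroup(&group) << ' ' << a.name << '\n';  // 1 bar
}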
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { -namespace { -using absl::StrAppend; -using absl::StrCat; - -void DumpModuleGraph(const HloModule& module, const string& message) { - hlo_graph_dumper::MaybeDumpHloModule(module, message); - VLOG(3) << "HLO " << message << ":"; - XLA_VLOG_LINES(3, module.ToString()); +template <typename HloT> +Status HloPassPipeline::RunInvariantCheckers( + HloT* hlo, absl::string_view after_pass_name) { + for (auto& invariant_checker : invariant_checkers_) { + VLOG(1) << " Invariant checker " << invariant_checker->name(); + StatusOr<bool> changed_status = RunHelper(invariant_checker.get(), hlo); + VLOG(1) << " Invariant checker done " << invariant_checker->name(); + if (!changed_status.ok()) { + VLOG(2) << "Failed invariant check:"; + XLA_VLOG_LINES(2, hlo->ToString()); + return Status(changed_status.status().code(), + absl::StrCat(changed_status.status().error_message(), + "\n\nFailed after ", after_pass_name)); + } + TF_RET_CHECK(!changed_status.ValueOrDie()) + << "invariant checkers must not change the graph"; + } + return Status::OK(); } -void DumpModuleProto(const HloModule& module, const string& dump_to, - const string& pipeline_name, const string& pass_name) { - static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); - static auto* const module_id_to_pass_number = - new tensorflow::gtl::FlatMap<int64, int64>(); - - tensorflow::mutex_lock lock(mu); - const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; +template <typename HloT> +StatusOr<bool> HloPassPipeline::RunPassesInternal( + HloT* hlo, absl::Span<HloPassInterface* const> passes) { + string last_pass_name = "pipeline-start"; + TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, last_pass_name)); + bool changed = false; + for (HloPassInterface* pass : passes) { + VLOG(1) << " HLO pass " << pass->name(); + MaybeDumpHlo(*hlo, + /*after_pass_name=*/last_pass_name, + /*before_pass_name=*/pass->name()); + TF_ASSIGN_OR_RETURN(bool pass_changed, RunHelper(pass, hlo)); + changed |= pass_changed; + TF_RETURN_IF_ERROR(RunInvariantCheckers(hlo, pass->name())); + last_pass_name = string(pass->name()); + } + MaybeDumpHlo(*hlo, + /*after_pass_name=*/last_pass_name, + /*before_pass_name=*/"pipeline-end"); + return changed; +} - const string mod_name = SanitizeFileName( - absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(), - pass_number, pipeline_name, pass_name)); +std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses( + const DebugOptions& debug_options) { + auto repeated_field = debug_options.xla_disable_hlo_passes(); + absl::flat_hash_set<string> disabled_pass_names(repeated_field.begin(), + repeated_field.end()); + if (!disabled_pass_names.empty()) { + VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " + << absl::StrJoin(disabled_pass_names, ", "); + } - TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module), - dump_to, mod_name)); + std::vector<HloPassInterface*> enabled_passes; + for (auto& pass : passes_) { + if (disabled_pass_names.count(string(pass->name())) == 0) { + enabled_passes.push_back(pass.get()); + } + } + return enabled_passes; } -} // namespace -StatusOr<bool> HloPassPipeline::Run(HloModule* module) { - run_called_ = true; +void HloPassPipeline::MaybeDumpHlo(const HloModule& module, + absl::string_view after_pass_name, + 
absl::string_view before_pass_name) { + const string& proto_dump_path = + module.config().debug_options().xla_dump_per_pass_hlo_proto_to(); + if (!proto_dump_path.empty()) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static auto* const module_id_to_pass_number = + new absl::flat_hash_map<int64, int64>(); + + tensorflow::mutex_lock lock(mu); + const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; + + const string filename = SanitizeFileName( + absl::StrFormat("module_%04d.%04d.%s.after_%s", module.unique_id(), + pass_number, name(), after_pass_name)); + + TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory( + MakeHloProto(module), proto_dump_path, filename)); + } - VLOG(1) << "Running HLO pass pipeline " << name(); + const string message = + StrCat("after ", after_pass_name, ", before ", before_pass_name); + hlo_graph_dumper::MaybeDumpHloModule(module, message); + VLOG(3) << "HLO " << message << ":"; + XLA_VLOG_LINES(3, module.ToString()); +} - auto repeated_field = - module->config().debug_options().xla_disable_hlo_passes(); - tensorflow::gtl::FlatSet<string> disabled_passes(repeated_field.begin(), - repeated_field.end()); - if (!disabled_passes.empty()) { - VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: " - << absl::StrJoin(disabled_passes, ", "); +void HloPassPipeline::MaybeDumpHlo(const HloModuleGroup& module_group, + absl::string_view after_pass_name, + absl::string_view before_pass_name) { + for (const HloModule* module : module_group.modules()) { + MaybeDumpHlo(*module, after_pass_name, before_pass_name); } +} - auto run_invariant_checkers = [this, - module](const string& message) -> Status { - for (auto& invariant_checker : invariant_checkers_) { - VLOG(1) << " Invariant checker " << invariant_checker->name(); - StatusOr<bool> changed_status = invariant_checker->Run(module); - VLOG(1) << " Invariant checker done " << invariant_checker->name(); - if (!changed_status.ok()) { - VLOG(2) << "Module failed invariant check:"; - XLA_VLOG_LINES(2, module->ToString()); - return Status(changed_status.status().code(), - StrCat(changed_status.status().error_message(), - "\n\nFailed ", message)); - } - TF_RET_CHECK(!changed_status.ValueOrDie()) - << "invariant checkers must not change the graph"; - } - return Status::OK(); - }; +StatusOr<bool> HloPassPipeline::Run(HloModule* module) { + run_called_ = true; - string prefix = StrCat(name(), ": pipeline start"); - bool changed = false; - string message; - TF_RETURN_IF_ERROR( - run_invariant_checkers(StrCat("before running pipeline: ", name()))); - const string xla_dump_per_pass_hlo_proto_to = - module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); - if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()), - "pipeline_start"); - } + VLOG(1) << "Running HLO pass pipeline on module " << module->name() << ": " + << name(); - for (auto& pass : passes_) { - if (disabled_passes.count(string(pass->name())) > 0) { - VLOG(1) << " Skipping HLO pass " << pass->name() - << ", disabled by --xla_disable_hlo_passes"; - continue; - } + return RunPassesInternal(module, + GetEnabledPasses(module->config().debug_options())); +} - VLOG(1) << " HLO pass " << pass->name(); +StatusOr<bool> HloPassPipeline::RunOnModuleGroup(HloModuleGroup* module_group) { + run_called_ = true; - // Emit label containing: "after foo-pass, before bar-pass". 
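// A standalone sketch of the per-module dump counter used by MaybeDumpHlo
// above, with std::mutex standing in for tensorflow::mutex; NextDumpName and
// the exact format string are illustrative only.
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>

// Hands out an increasing pass number per module id so that successive dumps
// of the same module sort in execution order.
std::string NextDumpName(int64_t module_id, const std::string& pass_name) {
  static std::mutex mu;
  static auto* const counters = new std::unordered_map<int64_t, int64_t>();
  std::lock_guard<std::mutex> lock(mu);
  const int64_t pass_number = (*counters)[module_id]++;
  char buf[128];
  std::snprintf(buf, sizeof(buf), "module_%04lld.%04lld.after_%s",
                static_cast<long long>(module_id),
                static_cast<long long>(pass_number), pass_name.c_str());
  return buf;
}

int main() {
  std::cout << NextDumpName(7, "cse") << '\n';  // module_0007.0000.after_cse
  std::cout << NextDumpName(7, "dce") << '\n';  // module_0007.0001.after_dce
}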
- message.clear(); - StrAppend(&message, prefix, ", before ", pass->name()); - DumpModuleGraph(*module, message); - - TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); - TF_RETURN_IF_ERROR( - run_invariant_checkers(StrCat("after running pass: ", pass->name()))); - if (!xla_dump_per_pass_hlo_proto_to.empty()) { - DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, string(name()), - string(pass->name())); - } + VLOG(1) << "Running HLO pass pipeline on module group " + << module_group->name() << ": " << name(); - changed |= changed_this_pass; - prefix.clear(); - StrAppend(&prefix, name(), ": after ", pass->name()); + if (module_group->modules().empty()) { + VLOG(1) << "Module group is empty. Nothing to do."; + return false; } - DumpModuleGraph(*module, prefix + ", pipeline end"); - return changed; + + return RunPassesInternal( + module_group, + GetEnabledPasses(module_group->module(0).config().debug_options())); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h index 1d41a4dac1..09e7033ea4 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h @@ -22,6 +22,7 @@ limitations under the License. #include <vector> #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/statusor.h" @@ -61,10 +62,45 @@ class HloPassPipeline : public HloPassInterface { return *pass; } - // Run all passes on the given HLO module. StatusOr<bool> Run(HloModule* module) override; + StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override; private: + // Returns the set of passes which are enabled. DebugOptions can selectively + // disable passes via --xla_disable_hlo_passes flag. + std::vector<HloPassInterface*> GetEnabledPasses( + const DebugOptions& debug_options); + + // Maybe dumps the given module or module group depending on flag values + // contained in DebugOptions of module config. + void MaybeDumpHlo(const HloModuleGroup& module_group, + absl::string_view after_pass_name, + absl::string_view before_pass_name); + void MaybeDumpHlo(const HloModule& module, absl::string_view after_pass_name, + absl::string_view before_pass_name); + + // Runs the invariant checker on the given HLO. HloT can be either HloModule + // or HloModuleGroup. + template <typename HloT> + Status RunInvariantCheckers(HloT* hlo, absl::string_view after_pass_name); + + // Helper which runs the given pass on the given HLO. HloT can be either + // HloModule or HloModuleGroup. + template <typename HloT> + StatusOr<bool> RunPassesInternal(HloT* hlo, + absl::Span<HloPassInterface* const> passes); + + // Helpers which run the given passes on the given HLO construct. These + // helpers enable templating of the core of the pipeline logic by providing + // HloModule and HloModuleGroup specific methods with the same name. 
+ static StatusOr<bool> RunHelper(HloPassInterface* pass, HloModule* module) { + return pass->Run(module); + } + static StatusOr<bool> RunHelper(HloPassInterface* pass, + HloModuleGroup* module_group) { + return pass->RunOnModuleGroup(module_group); + } + const string name_; std::vector<std::unique_ptr<HloPassInterface>> passes_; std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_; diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc new file mode 100644 index 0000000000..ee8cb12b23 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc @@ -0,0 +1,259 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloPassPipelineTest : public HloVerifiedTestBase { + protected: + StatusOr<HloModuleGroup> ParseModuleGroup( + absl::Span<const string> hlo_strings) { + HloModuleGroup group(TestName()); + for (const string& hlo_string : hlo_strings) { + TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module, + ParseAndReturnVerifiedModule(hlo_string)); + group.push_back(std::move(module)); + } + return std::move(group); + } +}; + +// A module pass which renames instructions named 'foo' to 'bar'. +class FooToBarModulePass : public HloModulePass { + absl::string_view name() const override { return "foo2bar"; } + + StatusOr<bool> Run(HloModule* module) override { + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == "foo") { + instruction->SetAndSanitizeName("bar"); + changed = true; + } + } + } + return changed; + } +}; + +// A module group pass which renames instructions named 'baz' to 'qux'. +class BazToQuxModuleGroupPass : public HloModuleGroupPass { + absl::string_view name() const override { return "baz2qux"; } + + StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group) override { + bool changed = false; + for (HloModule* module : module_group->modules()) { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == "baz") { + instruction->SetAndSanitizeName("qux"); + changed = true; + } + } + } + } + return changed; + } +}; + +// An invariant checker pass which returns an error if there exists an +// instruction named 'bar'. 
+class BarBlowerUpper : public HloModulePass { + absl::string_view name() const override { return "bar-blower-upper"; } + + StatusOr<bool> Run(HloModule* module) override { + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == "bar") { + return InternalError("Module has instruction named bar"); + } + } + } + return false; + } +}; + +TEST_F(HloPassPipelineTest, ModulePassChanged) { + // Test an HLO module pass which changes a module. + const string module_str = R"( +HloModule ModulePassChanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module, + ParseAndReturnVerifiedModule(module_str)); + HloPassPipeline pipeline(TestName()); + pipeline.AddPass<FooToBarModulePass>(); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->name(), "foo"); + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_EQ(root->name(), "bar"); +} + +TEST_F(HloPassPipelineTest, ModulePassUnchanged) { + // Test an HLO module pass which does not change a module. + const string module_str = R"( +HloModule ModulePassUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT blahblah = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module, + ParseAndReturnVerifiedModule(module_str)); + HloPassPipeline pipeline(TestName()); + pipeline.AddPass<FooToBarModulePass>(); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(HloPassPipelineTest, MixedPipeline) { + // Test a pipeline with both a module pass and a module group pass. + const string module_0_str = R"( +HloModule MixedPipeline.1 + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT baz = f32[] multiply(a, b) +} +)"; + const string module_1_str = R"( +HloModule MixedPipeline.0 + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group, + ParseModuleGroup({module_0_str, module_1_str})); + + HloPassPipeline pipeline(TestName()); + pipeline.AddPass<BazToQuxModuleGroupPass>(); + pipeline.AddPass<FooToBarModulePass>(); + + HloInstruction* root0 = + module_group.module(0).entry_computation()->root_instruction(); + HloInstruction* root1 = + module_group.module(1).entry_computation()->root_instruction(); + EXPECT_EQ(root0->name(), "baz"); + EXPECT_EQ(root1->name(), "foo"); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + pipeline.RunOnModuleGroup(&module_group)); + EXPECT_TRUE(changed); + + EXPECT_EQ(root0->name(), "qux"); + EXPECT_EQ(root1->name(), "bar"); +} + +TEST_F(HloPassPipelineTest, InvariantChecker) { + const string module_str = R"( +HloModule InvariantChecker + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module, + ParseAndReturnVerifiedModule(module_str)); + { + // Run a pipeline with just the invariant checker. It should not fail + // because there is no 'bar' instruction in the module. 
+ HloPassPipeline pipeline(TestName()); + pipeline.AddInvariantChecker<BarBlowerUpper>(); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_FALSE(changed); + } + + { + // Run a pipeline which renames 'foo' to 'bar' then an invariant checker + // which fails if there is an instruction named 'bar'. + HloPassPipeline pipeline(TestName()); + pipeline.AddInvariantChecker<BarBlowerUpper>(); + pipeline.AddPass<FooToBarModulePass>(); + + Status status = pipeline.Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Module has instruction named bar")); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Failed after foo2bar")); + } + + { + // Run the invariant-checker only pipeline again. It should fail this time. + HloPassPipeline pipeline(TestName()); + pipeline.AddInvariantChecker<BarBlowerUpper>(); + + Status status = pipeline.Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Module has instruction named bar")); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Failed after pipeline-start")); + } +} + +TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) { + // Running a module group pass on a module should produce an error. + const string module_str = R"( +HloModule ModuleGroupPassOnModule + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT foo = f32[] multiply(a, b) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module, + ParseAndReturnVerifiedModule(module_str)); + HloPassPipeline pipeline(TestName()); + pipeline.AddPass<BazToQuxModuleGroupPass>(); + + Status status = pipeline.Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Module group pass cannot be run on a module")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index b66a2aa4bd..5a5f01f8fd 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -19,11 +19,11 @@ limitations under the License. #include <list> #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -154,7 +154,7 @@ class HloReachabilityMap { // Dense assignment from HloInstruction* to number. These numbers index // into the bit_vectors_ vector and into the bits within a BitVector. - tensorflow::gtl::FlatMap<const HloInstruction*, int> indices_; + absl::flat_hash_map<const HloInstruction*, int> indices_; // Bitvectors holding the reachability to each instruction. The bit vector for // instruction X includes ones for each instruction which X is reachable from. diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index bd6dd79b67..5ac43808ee 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -20,6 +20,8 @@ limitations under the License. 
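// [Editor's sketch, not part of the diff above or below.] The test file above exercises the
// reworked HloPassPipeline, which now exposes both Run(HloModule*) and
// RunOnModuleGroup(HloModuleGroup*); both overloads funnel into the templated
// RunPassesInternal via the overloaded static RunHelper. A minimal usage sketch of composing
// such a pipeline follows; HloDCE's no-argument constructor is assumed here, and the
// HloVerifier constructor is the two-argument form shown earlier in this diff.

#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"

xla::StatusOr<bool> RunCleanupPipeline(xla::HloModule* module) {
  xla::HloPassPipeline pipeline("cleanup");
  // Invariant checkers run before the first pass and after every pass, and
  // must not change the graph.
  pipeline.AddInvariantChecker<xla::HloVerifier>(
      /*layout_sensitive=*/false, /*allow_mixed_precision=*/false);
  pipeline.AddPass<xla::HloDCE>();
  return pipeline.Run(module);  // or pipeline.RunOnModuleGroup(&group) for a module group
}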
#include <set> #include <string> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -75,7 +77,7 @@ bool IsRematerializable(const HloInstruction* instruction) { // cache before, and eventually calling the IsRematerializable() API. bool CanBeRematerialized( const HloInstruction* instruction, - tensorflow::gtl::FlatMap<const HloInstruction*, bool>* remat_able) { + absl::flat_hash_map<const HloInstruction*, bool>* remat_able) { auto it = remat_able->find(instruction); if (it != remat_able->end()) { return it->second; @@ -268,7 +270,7 @@ class InstructionList { Item* first_; // Item for each instruction. - tensorflow::gtl::FlatMap<const HloInstruction*, Item*> item_map_; + absl::flat_hash_map<const HloInstruction*, Item*> item_map_; }; // Return the items which use the given LogicalBuffer. Sets @@ -503,7 +505,7 @@ MemoryUsageTracker::MemoryUsageTracker( PointsToSet::BufferSet live_out_set = points_to_analysis.GetPointsToSet(computation_->root_instruction()) .CreateFlattenedSet(); - tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId> + absl::flat_hash_map<const LogicalBuffer*, BufferId> logical_buffer_to_buffer_id; for (auto* item = instruction_list_.first(); item != nullptr; @@ -854,7 +856,7 @@ int64 RematerializationCost(const HloInstruction* instruction, Item* PickRematerializationCandidate( const MemoryUsageTracker& memory_tracker, const InstructionList& instruction_list, int64 memory_limit_bytes, - tensorflow::gtl::FlatMap<const HloInstruction*, bool>* remat_able) { + absl::flat_hash_map<const HloInstruction*, bool>* remat_able) { Item* best_item = nullptr; int64 best_cost = 0; @@ -980,10 +982,10 @@ StatusOr<bool> HloRematerialization::RematerializeComputation( // rematerialization is essentially a move). If the next rematerialization of // the instruction is also a move then the rematerialization is added to the // blacklist. - tensorflow::gtl::FlatSet<const HloInstruction*> remat_move_instructions; + absl::flat_hash_set<const HloInstruction*> remat_move_instructions; // The map from instructions to their rematerializable status. - tensorflow::gtl::FlatMap<const HloInstruction*, bool> remat_able; + absl::flat_hash_map<const HloInstruction*, bool> remat_able; // The peak memory of the computation at any point in the instruction // sequence. @@ -1198,6 +1200,12 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) { << HumanReadableNumBytes(memory_limit_bytes_); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); + // Initialize pass object state. 
+ computation_peak_memory_.clear(); + rematerialized_computations_.clear(); + instructions_rematerialized_ = 0; + net_instructions_added_ = 0; + TF_RET_CHECK(module->has_schedule()); TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index e2aaf18b3e..70d83c04f0 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -15,6 +15,8 @@ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_ +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -33,7 +35,7 @@ namespace xla { // CSE will undo the effects of this optimization and should not be run after // this pass. In general, this pass should be run very late, immediately before // code generation. -class HloRematerialization : public HloPassInterface { +class HloRematerialization : public HloModulePass { public: using ShapeSizeFunction = std::function<int64(const Shape&)>; @@ -115,14 +117,13 @@ class HloRematerialization : public HloPassInterface { // computations called from sequential context // (CallContext::kSequential). These values are updated as rematerialization // occurs. - tensorflow::gtl::FlatMap<const HloComputation*, int64> - computation_peak_memory_; + absl::flat_hash_map<const HloComputation*, int64> computation_peak_memory_; std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_; // Set of computations which have had rematerialization // applied. Rematerialization is only applied once per computation. - tensorflow::gtl::FlatSet<const HloComputation*> rematerialized_computations_; + absl::flat_hash_set<const HloComputation*> rematerialized_computations_; // Count of the total instructions rematerialized. int64 instructions_rematerialized_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index 3fc5dbeb02..9972eb2077 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -18,6 +18,8 @@ limitations under the License. 
#include <queue> #include <vector> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" @@ -30,7 +32,7 @@ namespace xla { /* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto( const HloModule* module, const HloScheduleProto& proto) { - tensorflow::gtl::FlatMap<int64, const HloComputation*> id_to_computation; + absl::flat_hash_map<int64, const HloComputation*> id_to_computation; for (const HloComputation* computation : module->computations()) { id_to_computation[computation->unique_id()] = computation; } @@ -44,7 +46,7 @@ namespace xla { << "No computation exists in HLO module with id " << computation_id; const HloComputation* computation = comp_it->second; - tensorflow::gtl::FlatMap<int64, const HloInstruction*> id_to_instruction; + absl::flat_hash_map<int64, const HloInstruction*> id_to_instruction; for (const HloInstruction* instruction : computation->instructions()) { id_to_instruction[instruction->unique_id()] = instruction; } @@ -112,13 +114,13 @@ Status HloSchedule::UpdateComputationSchedule( const HloComputation* computation) { // Map from unique ID to HloInstruction pointer for instructions in the // computation. - tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction; + absl::flat_hash_map<int, const HloInstruction*> id_to_instruction; for (const HloInstruction* instruction : computation->instructions()) { InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); } // Set of all HloInstructions in the schedule. - tensorflow::gtl::FlatSet<int> ids_in_schedule; + absl::flat_hash_set<int> ids_in_schedule; for (int id : sequences_.at(computation->unique_id()).ids()) { InsertOrDie(&ids_in_schedule, id); } @@ -126,15 +128,13 @@ Status HloSchedule::UpdateComputationSchedule( // Map from HloInstruction X to newly added instructions (instruction is in // computation, but not in schedule) which use X. If an instruction is not in // the map, then it has no users which are newly added instructions. - tensorflow::gtl::FlatMap<const HloInstruction*, - std::vector<const HloInstruction*>> + absl::flat_hash_map<const HloInstruction*, std::vector<const HloInstruction*>> new_instruction_uses; // For each newly added instruction, this is the count of the instruction's // operands that have not yet been scheduled. When this value reaches zero, // then the instruction may be placed in the schedule. - tensorflow::gtl::FlatMap<const HloInstruction*, int> - unscheduled_operand_count; + absl::flat_hash_map<const HloInstruction*, int> unscheduled_operand_count; // Create a worklist of newly added instructions which are ready to be added // to the schedule. Initialize worklist with those that have zero operands. @@ -211,15 +211,15 @@ Status HloSchedule::Update() { if (sequences_.size() > nonfusion_computations.size()) { // Schedule contains some computations which have been removed from the // HloModule. Remove them from the schedule as well. 
- tensorflow::gtl::FlatSet<int64> nonfusion_computations_ids; + absl::flat_hash_set<int64> nonfusion_computations_ids; for (const HloComputation* computation : nonfusion_computations) { nonfusion_computations_ids.insert(computation->unique_id()); } for (auto it = sequences_.begin(); it != sequences_.end();) { if (nonfusion_computations_ids.count(it->first) == 0) { - it = sequences_.erase(it); + sequences_.erase(it++); } else { - it++; + ++it; } } } @@ -254,7 +254,7 @@ Status HloSchedule::Verify() const { // For each computation verify the set of instructions is the same and that // each dependency and control edge is honored. for (const HloComputation* computation : nonfusion_computations) { - tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position; + absl::flat_hash_map<const HloInstruction*, int> instruction_position; int pos = 0; for (const HloInstruction* instruction : sequence(computation).instructions()) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 270fe6039f..0a714101ee 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -18,6 +18,7 @@ limitations under the License. #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -103,8 +104,7 @@ class HloSchedule { // Returns a map from HloComputation unique ID to instruction sequence. The // map contains all sequences in the schedule. - const tensorflow::gtl::FlatMap<int64, HloInstructionSequence>& sequences() - const { + const absl::flat_hash_map<int64, HloInstructionSequence>& sequences() const { return sequences_; } @@ -148,7 +148,7 @@ class HloSchedule { // A map from computation unique ID to instruction sequence. Unique IDs are // used rather than HloComputation pointers because HLO pointers are not // unique across HLO transformations because pointers may be recycled. - tensorflow::gtl::FlatMap<int64, HloInstructionSequence> sequences_; + absl::flat_hash_map<int64, HloInstructionSequence> sequences_; }; std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index de7e6b53d4..188f4acc79 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/overflow_util.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -369,10 +370,28 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, return HloSharding(tuple_shardings); } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { return Replicate(); - } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL || - proto.tile_assignment_devices().size() == 1) { + } else if (proto.tile_assignment_devices().size() == 1) { return HloSharding(proto.tile_assignment_devices(0)); } + + TF_RET_CHECK(proto.type() != OpSharding::Type::OpSharding_Type_MAXIMAL) + << "Maximal sharding is expected to have single device assignment, but " + << proto.tile_assignment_devices().size() << " has provided."; + + TF_RET_CHECK(proto.tile_assignment_devices().size() > 1); + TF_RET_CHECK(!proto.tile_assignment_dimensions().empty()); + + // RE: the product of tile assignment tensor dimensions must be + // equal to tile_assignment_devices.size(). + int64 product_of_dimensions = 1; + for (auto dimension : proto.tile_assignment_dimensions()) { + TF_RET_CHECK(dimension > 0); + product_of_dimensions = + MultiplyWithoutOverflow(product_of_dimensions, dimension); + TF_RET_CHECK(product_of_dimensions > 0); + } + TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size()); + // Some versions of gcc cannot infer the TileAssignment constructor from a // braced initializer-list, so create one manually. std::vector<int64> devices(proto.tile_assignment_devices().begin(), diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h index d1cf644f82..fa34bddde1 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h @@ -22,7 +22,7 @@ namespace xla { // Unify subcomputations of a `HloModule`: if any computations are equal, choose // one arbitrarily to use and delete the others. -class HloSubcomputationUnification : public HloPassInterface { +class HloSubcomputationUnification : public HloModulePass { public: absl::string_view name() const override { return "subcomputation-unification"; diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 773fc7d225..59594ab2f0 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -18,6 +18,7 @@ limitations under the License. #include <algorithm> #include <utility> +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -131,6 +131,7 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, CHECK_LE(operand_number, 2); return operand_number == 0 || index.empty(); + case HloOpcode::kDomain: case HloOpcode::kTuple: // These instructions always pass through their operands transparently. 
return false; @@ -166,7 +167,7 @@ void HloValue::SetPositionsAndComputeUses( positions_.insert(positions_.end(), positions.begin(), positions.end()); // Gather the computation roots at which this value appears. - tensorflow::gtl::FlatSet<HloInstruction*> root_positions; + absl::flat_hash_set<HloInstruction*> root_positions; for (const HloPosition& position : positions_) { if (position.instruction == position.instruction->parent()->root_instruction()) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 0f6ecd42f6..496fe1795d 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -15,6 +15,7 @@ limitations under the License. #include <set> +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -315,7 +315,7 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { int64 output_dimension = broadcast->dimensions()[operand_dimension]; TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) && (broadcast->shape().dimensions(output_dimension) == - operand_shape.dimensions(operand_dimension))) + operand_shape.dimensions(operand_dimension))) << broadcast->ToString() << " operand shape " << operand_shape; } return Status::OK(); @@ -549,6 +549,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kTupleSelect: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kSort: case HloOpcode::kTuple: case HloOpcode::kWhile: break; @@ -764,7 +765,136 @@ Status VerifyHloStructure(HloModule* module) { return Status::OK(); } -Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { +namespace { + +// Returns true if the given Shape has a TOKEN shape as any subshape. +bool ShapeContainsToken(const Shape& shape) { + bool contains_token = false; + ShapeUtil::ForEachSubshape( + shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsToken(subshape)) { + contains_token = true; + } + }); + return contains_token; +} + +// Verifies that all types entering and exiting the entry computation are +// legal. +Status VerifyEntryAndExitShapes(const HloModule& module) { + // Tokens cannot be passed as entry parameters. + // TODO(b/80000000): Remove this constraint. + for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { + HloInstruction* param = + module.entry_computation()->parameter_instruction(i); + if (ShapeContainsToken(param->shape())) { + return InternalError( + "Entry parameter %d is or contains a token shape: %s", i, + ShapeUtil::HumanString(param->shape())); + } + } + return Status::OK(); +} + +// Checks if the given two instructions share the same channel id. 
+Status CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2) { + if (instr1->channel_id() != instr2->channel_id()) { + return InternalError( + "Expected to have the same channel id, actual channel ids are: %s " + "(%d), %s (%d)", + instr1->ToString(), instr1->channel_id(), instr2->ToString(), + instr2->channel_id()); + } + return Status::OK(); +} + +// Checks if the given two instructions have the same is_host_transfer +// attribute value. Instructions must be send/recv instructions or their +// 'done' variant. +Status CheckSameIsHostTransfer(const HloInstruction* instr1, + const HloInstruction* instr2) { + const HloSendRecvInstruction* send_recv1 = + DynCast<const HloSendRecvInstruction>(instr1); + const HloSendRecvInstruction* send_recv2 = + DynCast<const HloSendRecvInstruction>(instr2); + TF_RET_CHECK(send_recv1 != nullptr); + TF_RET_CHECK(send_recv2 != nullptr); + if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) { + return InternalError( + "Expected instructions to have the same is-host-transfer property: " + "%s, " + "%s ", + instr1->ToString(), instr2->ToString()); + } + return Status::OK(); +} + +// Checks various invariants of send and recv instructions. +Status VerifySendsAndRecvs(const HloModule& module) { + absl::flat_hash_map<int64, const HloInstruction*> host_channels; + // Host send/recv instructions must have their own unique channel. + auto check_unique_host_channel = [&](const HloInstruction* instruction) { + const HloSendRecvInstruction* sendrecv = + DynCast<const HloSendRecvInstruction>(instruction); + if (sendrecv->is_host_transfer()) { + auto it_inserted = + host_channels.insert({sendrecv->channel_id(), sendrecv}); + if (!it_inserted.second) { + return FailedPrecondition( + "Channel %d is used for multiple host send/recv instructions: " + "%s " + "and " + "%s", + sendrecv->channel_id(), sendrecv->ToString(), + it_inserted.first->second->ToString()); + } + } + + return Status::OK(); + }; + + // Send/Recv instructions must have a single user: the corresponding + // SendDone/RecvDone with a matching channel. + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + switch (instruction->opcode()) { + case HloOpcode::kSend: { + TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* send_done = instruction->users().front(); + TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done)); + break; + } + case HloOpcode::kRecv: { + TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* recv_done = instruction->users().front(); + TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); + break; + } + case HloOpcode::kSendDone: + TF_RET_CHECK(instruction->operands().size() == 1); + TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend); + break; + case HloOpcode::kRecvDone: + TF_RET_CHECK(instruction->operands().size() == 1); + TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv); + break; + default: + break; + } + } + } + return Status::OK(); +} + +// CHECKs various invariants of a fusion instruction.
+Status CheckFusionInstruction(HloInstruction* fusion) { // The parent fusion instruction of the fusion computation must be 'fusion'. HloComputation* fused_computation = fusion->fused_instructions_computation(); if (fusion != fused_computation->FusionInstruction()) { @@ -867,50 +997,32 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { } } + TF_RET_CHECK(fusion->called_computations() == + absl::Span<HloComputation* const>( + {fusion->fused_instructions_computation()})) + << "Fusion HLO calls computations other than the " + "fused_instructions_computation: " + << fusion->ToString() << " fusion->fused_instructions_computation(): " + << fusion->fused_instructions_computation()->ToString() + << " fusion->called_computations(): " + << ComputationsToString(fusion->called_computations()); + + for (const auto& fused : fusion->fused_instructions()) { + TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation()) + << "Fused HLO was missing a parent: " << fused->ToString() + << " parent: " << fused->parent() + << " computation: " << fusion->parent(); + } + // TODO(b/65423525): We'd like to check that all operands are distinct. // This is currently disabled due to the invariant being violated by // multi-output fusion. return Status::OK(); } -Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { - auto* while_cond = instruction->while_condition(); - auto* while_body = instruction->while_body(); - if (while_cond->num_parameters() != 1) { - return FailedPrecondition( - "While condition must have exactly 1 parameter; had %d : %s", - while_cond->num_parameters(), while_cond->ToString()); - } - if (while_body->num_parameters() != 1) { - return FailedPrecondition( - "While body must have exactly 1 parameter; had %d : %s", - while_body->num_parameters(), while_body->ToString()); - } - if (instruction->operand_count() != 1) { - return FailedPrecondition( - "While loop must have exactly one operand; had %d : %s", - instruction->operand_count(), instruction->ToString()); - } - return Status::OK(); -} - -Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) { - if (instruction->true_computation()->num_parameters() != 1) { - return FailedPrecondition( - "True computation %s of %s must have 1 parameter insted of %d", - instruction->true_computation()->name(), instruction->ToString(), - instruction->true_computation()->num_parameters()); - } - if (instruction->false_computation()->num_parameters() != 1) { - return FailedPrecondition( - "False computation %s of %s must have 1 parameter insted of %d", - instruction->false_computation()->name(), instruction->ToString(), - instruction->false_computation()->num_parameters()); - } - return Status::OK(); -} - -Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { +// Checks that the non-scalar operand shapes are compatible to the output +// shape, i.e., that there are no implicit broadcasts of size-one dimensions. +Status CheckElementwiseInstruction(HloInstruction* instruction) { const Shape& out_shape = instruction->shape(); for (HloInstruction* operand : instruction->operands()) { const Shape& operand_shape = operand->shape(); @@ -927,199 +1039,158 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { return Status::OK(); } -namespace { +// Visitor which verifies various fields on the HLO instruction. This class does +// not check result shape as that is checked in the ShapeVerifier. 
+class InstructionVerifier : public DfsHloVisitorWithDefault { + public: + explicit InstructionVerifier(std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func) + : instruction_can_change_layout_func_( + instruction_can_change_layout_func) {} -// Returns true if the given Shape has a TOKEN shape as any subshape. -bool ShapeContainsToken(const Shape& shape) { - bool contains_token = false; - ShapeUtil::ForEachSubshape( - shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { - if (ShapeUtil::IsToken(subshape)) { - contains_token = true; - } - }); - return contains_token; -} + Status DefaultAction(HloInstruction*) override { return Status::OK(); } -// Verifies that all types entering and exiting the entry computation are -// legal. -Status VerifyEntryAndExitShapes(const HloModule& module) { - // Tokens cannot be passed as entry parameters. - // TODO(b/80000000): Remove this constraint. - for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { - HloInstruction* param = - module.entry_computation()->parameter_instruction(i); - if (ShapeContainsToken(param->shape())) { - return InternalError( - "Entry parameter %d is or contains a token shape: %s", i, - ShapeUtil::HumanString(param->shape())); - } + Status HandleFusion(HloInstruction* fusion) override { + return CheckFusionInstruction(fusion); } - return Status::OK(); -} -// Checks if the given two instructions share the same channel id. -Status CheckSameChannel(const HloInstruction* instr1, - const HloInstruction* instr2) { - if (instr1->channel_id() != instr2->channel_id()) { - return InternalError( - "Expected to have the same channel id, actual channel ids are: %s " - "(%d), %s (%d)", - instr1->ToString(), instr1->channel_id(), instr2->ToString(), - instr2->channel_id()); + Status HandleBroadcast(HloInstruction* broadcast) override { + // If you see this failure then someone has confused the difference + // between the HLO broadcast op, and the UserComputation broadcast + // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I + // or ComputationLowerer::Visit() + TF_RET_CHECK(broadcast->dimensions().size() == + ShapeUtil::Rank(broadcast->operand(0)->shape())) + << "Broadcast HLO (" << broadcast->ToShortString() + << ") has invalid number of dimensions: " + << broadcast->dimensions().size() + << " != " << ShapeUtil::Rank(broadcast->operand(0)->shape()); + return Status::OK(); } - return Status::OK(); -} -// Checks if the given two instructions have the same is_host_transfer -// attribute value. Intsructions must be send/recv instructions or their -// 'done' variant. 
-Status CheckSameIsHostTransfer(const HloInstruction* instr1, - const HloInstruction* instr2) { - const HloSendRecvInstruction* send_recv1 = - DynCast<const HloSendRecvInstruction>(instr1); - const HloSendRecvInstruction* send_recv2 = - DynCast<const HloSendRecvInstruction>(instr2); - TF_RET_CHECK(send_recv1 != nullptr); - TF_RET_CHECK(send_recv2 != nullptr); - if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) { - return InternalError( - "Expected instructions to have the same is-host-transfer property: " - "%s, " - "%s ", - instr1->ToString(), instr2->ToString()); + Status HandleWhile(HloInstruction* xla_while) override { + auto* while_cond = xla_while->while_condition(); + auto* while_body = xla_while->while_body(); + if (while_cond->num_parameters() != 1) { + return FailedPrecondition( + "While condition must have exactly 1 parameter; had %d : %s", + while_cond->num_parameters(), while_cond->ToString()); + } + if (while_body->num_parameters() != 1) { + return FailedPrecondition( + "While body must have exactly 1 parameter; had %d : %s", + while_body->num_parameters(), while_body->ToString()); + } + if (xla_while->operand_count() != 1) { + return FailedPrecondition( + "While loop must have exactly one operand; had %d : %s", + xla_while->operand_count(), xla_while->ToString()); + } + return Status::OK(); } - return Status::OK(); -} -// Checks various invariants of send and recv instructions. -Status VerifySendsAndRecvs(const HloModule& module) { - tensorflow::gtl::FlatMap<int64, const HloInstruction*> host_channels; - // Host send/recv instructions must have their own unique channel. - auto check_unique_host_channel = [&](const HloInstruction* instruction) { - const HloSendRecvInstruction* sendrecv = - DynCast<const HloSendRecvInstruction>(instruction); - if (sendrecv->is_host_transfer()) { - auto it_inserted = - host_channels.insert({sendrecv->channel_id(), sendrecv}); - if (!it_inserted.second) { - return FailedPrecondition( - "Channel %d is used for multiple host send/recv instructions: " - "%s " - "and " - "%s", - sendrecv->channel_id(), sendrecv->ToString(), - it_inserted.first->second->ToString()); - } + Status HandleConditional(HloInstruction* conditional) override { + if (conditional->true_computation()->num_parameters() != 1) { + return FailedPrecondition( + "True computation %s of %s must have 1 parameter instead of %d", + conditional->true_computation()->name(), conditional->ToString(), + conditional->true_computation()->num_parameters()); } + if (conditional->false_computation()->num_parameters() != 1) { + return FailedPrecondition( + "False computation %s of %s must have 1 parameter instead of %d", + conditional->false_computation()->name(), conditional->ToString(), + conditional->false_computation()->num_parameters()); + } + return Status::OK(); + } + + Status HandleElementwiseUnary(HloInstruction* instruction) override { + return CheckElementwiseInstruction(instruction); + } + + Status HandleElementwiseBinary(HloInstruction* instruction) override { + return CheckElementwiseInstruction(instruction); + } + Status HandleGetTupleElement(HloInstruction* gte) override { + TF_RET_CHECK(ShapeUtil::IsTuple(gte->operand(0)->shape())); return Status::OK(); - }; + } - // Send/Recv instruction must have a single user: the corresponding - // SendDone/RecvDone. with matching channel.
- for (const HloComputation* computation : module.computations()) { - for (const HloInstruction* instruction : computation->instructions()) { - switch (instruction->opcode()) { - case HloOpcode::kSend: { - TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); - TF_RET_CHECK(instruction->users().size() == 1); - const HloInstruction* send_done = instruction->users().front(); - TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done)); - TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done)); - break; - } - case HloOpcode::kRecv: { - TF_RETURN_IF_ERROR(check_unique_host_channel(instruction)); - TF_RET_CHECK(instruction->users().size() == 1); - const HloInstruction* recv_done = instruction->users().front(); - TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); - TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); - break; + Status HandleTranspose(HloInstruction* transpose) override { + const Shape& shape = transpose->shape(); + const HloInstruction* operand = transpose->operand(0); + TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size()); + TF_RET_CHECK(shape.dimensions().size() == + transpose->operand(0)->shape().dimensions().size()); + TF_RET_CHECK(std::equal( + operand->shape().dimensions().begin(), + operand->shape().dimensions().end(), + Permute(transpose->dimensions(), shape.dimensions()).begin())) + << "shape: " << shape << ", operand->shape(): " << shape + << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ") + << "}"; + return Status::OK(); + } + + Status Preprocess(HloInstruction* instruction) override { + auto previous = instructions_by_name_.find(instruction->name()); + TF_RET_CHECK(previous == instructions_by_name_.end()) + << "HLO has name that is not unique within module:\n" + << instruction->ToString() + << " in computation: " << instruction->parent()->name() + << "\nPrevious HLO with same name:\n" + << previous->second->ToString() + << " in computation: " << previous->second->parent()->name(); + instructions_by_name_[instruction->name()] = instruction; + return Status::OK(); + } + + Status Postprocess(HloInstruction* instruction) override { + if (instruction_can_change_layout_func_ && + LayoutUtil::IsDenseArray(instruction->shape()) && + !instruction_can_change_layout_func_(instruction)) { + const Shape& result_shape = instruction->shape(); + const Layout& result_layout = result_shape.layout(); + for (HloInstruction* operand : instruction->operands()) { + const Shape& operand_shape = operand->shape(); + if (LayoutUtil::IsDenseArray(operand_shape) && + ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) { + const Layout& operand_layout = operand_shape.layout(); + TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout)) + << "Instruction shouldn't change layouts " + << instruction->ToString() << " From " + << ShapeUtil::HumanString(result_shape) << " To " + << ShapeUtil::HumanString(operand_shape); } - case HloOpcode::kSendDone: - TF_RET_CHECK(instruction->operands().size() == 1); - TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend); - break; - case HloOpcode::kRecvDone: - TF_RET_CHECK(instruction->operands().size() == 1); - TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv); - break; - default: - break; } } + + return Status::OK(); } - return Status::OK(); -} + + private: + absl::flat_hash_map<string, const HloInstruction*> 
instructions_by_name_; + // Determines whether an instruction can change layouts. + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func_; +}; } // namespace StatusOr<bool> HloVerifier::Run(HloModule* module) { + TF_RET_CHECK(!module->name().empty()); TF_RETURN_IF_ERROR(VerifyHloStructure(module)); TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module)); - tensorflow::gtl::FlatMap<string, const HloInstruction*> instructions; - for (auto* computation : module->computations()) { - for (const auto& instruction : computation->instructions()) { - TF_RET_CHECK(instruction->parent() == computation); - if (instruction->opcode() == HloOpcode::kFusion) { - TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction)); - TF_RET_CHECK(instruction->called_computations() == - absl::Span<HloComputation* const>( - {instruction->fused_instructions_computation()})) - << "Fusion HLO calls computations other than the " - "fused_instructions_computation: " - << instruction->ToString() - << " instruction->fused_instructions_computation(): " - << instruction->fused_instructions_computation()->ToString() - << " instruction->called_computations(): " - << ComputationsToString(instruction->called_computations()); - - for (const auto& fused : instruction->fused_instructions()) { - TF_RET_CHECK(fused->parent() == - instruction->fused_instructions_computation()) - << "Fused HLO was missing a parent: " << fused->ToString() - << " parent: " << fused->parent() - << " computation: " << computation; - } - } else if (instruction->opcode() == HloOpcode::kBroadcast) { - // If you see this failure then someone has confused the difference - // between the HLO broadcast op, and the UserComputation broadcast - // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I - // or ComputationLowerer::Visit() - TF_RET_CHECK(instruction->dimensions().size() == - ShapeUtil::Rank(instruction->operand(0)->shape())) - << "Broadcast HLO (" << instruction->ToShortString() - << ") has invalid number of dimensions: " - << instruction->dimensions().size() - << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); - } else if (instruction->opcode() == HloOpcode::kWhile) { - TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction)); - } else if (instruction->opcode() == HloOpcode::kConditional) { - TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction)); - } else if (instruction->opcode() != - HloOpcode::kRng /* Rng operands are always scalar. 
*/ - && instruction->IsElementwise()) { - TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction)); - } - - auto previous = instructions.find(instruction->name()); - TF_RET_CHECK(previous == instructions.end()) - << "HLO has name that is not unique within module:\n" - << instruction->ToString() - << " in computation: " << computation->name() - << "\nPrevious HLO with same name:\n" - << previous->second->ToString() - << " in computation: " << previous->second->parent()->name(); - instructions[instruction->name()] = instruction; - } - std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_(); TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); + + InstructionVerifier instruction_verifier( + instruction_can_change_layout_func_); + TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier)); } TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 42e3027bf1..cb49cb95ba 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -151,15 +151,21 @@ class ShapeVerifier : public DfsHloVisitor { // HLO pass that verifies invariants of HLO instructions for each computation in // the module. -class HloVerifier : public HloPassInterface { +class HloVerifier : public HloModulePass { public: using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>; - explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision) + explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func = {}) : shape_verifier_factory_([layout_sensitive, allow_mixed_precision] { return absl::make_unique<ShapeVerifier>(layout_sensitive, allow_mixed_precision); - }) {} + }), + instruction_can_change_layout_func_( + std::move(instruction_can_change_layout_func)) { + CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive); + } // Uses custom shape verification. explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory) @@ -172,22 +178,15 @@ class HloVerifier : public HloPassInterface { StatusOr<bool> Run(HloModule* module) override; private: - // CHECKs various invariants of a fusion instruction. - Status CheckFusionInstruction(HloInstruction* fusion) const; - - Status CheckWhileInstruction(HloInstruction* instruction); - - Status CheckConditionalInstruction(HloInstruction* instruction); - - // Checks that the non-scalar operand shapes are compatible to the output - // shape, i.e., that there are no implicit broadcasts of size-one dimensions. - Status CheckElementwiseInstruction(HloInstruction* instruction); - // Creates a ShapeVerifier that checks that shapes match inferred // expectations. This is a factory function because ShapeVerifier, // being a DfsHloVisitor, is stateful. We want a clean object // for each run of the verifier. ShapeVerifierFactory shape_verifier_factory_; + + // Determines whether an instruction can change layouts. + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc index 8f0423bb1c..afe01e5487 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -22,6 +22,7 @@ limitations under the License. 
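// [Editor's sketch, not part of the diff.] The new optional constructor argument shown above
// lets callers ask the verifier to reject instructions that change the layout between a
// dense-array operand and a result of the same rank; the constructor CHECKs that the predicate
// is only supplied together with layout_sensitive=true. A minimal sketch of direct
// construction, using the LayoutAssignment predicate that the test fixture below also passes
// through HloTestBase:

#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"

xla::Status VerifyLayoutsDoNotChange(xla::HloModule* module) {
  xla::HloVerifier verifier(
      /*layout_sensitive=*/true, /*allow_mixed_precision=*/false,
      xla::LayoutAssignment::InstructionCanChangeLayout);
  return verifier.Run(module).status();
}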
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -50,6 +51,14 @@ class HloVerifierTestAllowMixedPrecision : public HloTestBase { /*allow_mixed_precision_in_hlo_verifier=*/true) {} }; +class HloVerifierTestLayoutSensitive : public HloTestBase { + public: + HloVerifierTestLayoutSensitive() + : HloTestBase(/*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/false, + LayoutAssignment::InstructionCanChangeLayout) {} +}; + TEST_F(HloVerifierTest, NullInstructionParent) { HloComputation::Builder builder(TestName()); const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -358,5 +367,63 @@ TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) { HasSubstr("non-positive base area dilation factor")); } +static const char* const kAddWithLayoutChangeHlo = R"( + HloModule AddWithLayoutChange + ENTRY AddWithLayoutChange { + par0 = f32[3,4]{1,0} parameter(0) + par1 = f32[3,4]{0,1} parameter(1) + ROOT add0 = f32[3,4]{1,0} add(par0,par1) + } + )"; + +TEST_F(HloVerifierTest, AddWithLayoutChange) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} + +TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) { + const char* const kSliceWithLayoutChangeHlo = R"( + HloModule SliceWithLayoutChange + ENTRY SliceWithLayoutChange { + par0 = f32[4,5]{0,1} parameter(0) + par1 = s32[2] parameter(1) + ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1), + dynamic_slice_sizes={3,4} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kSliceWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} + +TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) { + const char* const kConcatWithLayoutChangeHlo = R"( + HloModule ConcatWithLayoutChange + ENTRY ConcatWithLayoutChange { + par0 = f32[3,5]{0,1} parameter(0) + par1 = f32[3,3]{1,0} parameter(1) + ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1), + dimensions={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseHloString(kConcatWithLayoutChangeHlo)); + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("Instruction shouldn't change layouts")); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h index 85bb4a8b24..9c48b7db61 100644 --- a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h @@ -25,7 +25,7 @@ namespace xla { // Pass which replaces all implicit broadcasts 
with their equivalent sequence of // explicit broadcast and reshape instructions. -class ImplicitBroadcastRemover : public HloPassInterface { +class ImplicitBroadcastRemover : public HloModulePass { public: ImplicitBroadcastRemover() {} ~ImplicitBroadcastRemover() override {} diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 06f0e1ed25..1ebb331977 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -23,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { namespace gtl = ::tensorflow::gtl; @@ -95,7 +96,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache( absl::InlinedVector<const HloInstruction*, 4> stack; enum DfsState { kDiscovered, kVisited }; - gtl::FlatMap<const HloInstruction*, DfsState> dfs_state_map; + absl::flat_hash_map<const HloInstruction*, DfsState> dfs_state_map; stack.push_back(root); InsertOrDie(&dfs_state_map, root, kDiscovered); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index df9cbab915..e5aa67fd85 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -18,10 +18,10 @@ limitations under the License. #include <type_traits> +#include "absl/container/flat_hash_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/util/ptr_util.h" namespace xla { @@ -360,13 +360,13 @@ class IndexedArrayAnalysis { std::vector<std::unique_ptr<Array>> owned_tensors_; std::vector<Literal> owned_literals_; - tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_; + absl::flat_hash_map<const HloInstruction*, Array*> cache_; }; // A pass that prints all non-trivial results returned by IndexedArrayAnalysis. // This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to // unconditionally add to the regular HLO pass pipeline. -class IndexedArrayAnalysisPrinterPass : public HloPassInterface { +class IndexedArrayAnalysisPrinterPass : public HloModulePass { public: absl::string_view name() const override; StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 3fdc2cee9a..69a4c160ee 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -22,11 +22,12 @@ limitations under the License. 
#include <vector> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -188,13 +189,20 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { bool InstructionFusion::CanFuseOnAllPaths( HloInstruction* producer, HloInstruction* consumer, - const HloInstructionSet& do_not_duplicate) { + const HloInstructionSet& do_not_fuse, + absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>* + result_cache) { if (consumer == producer) { return true; } if (!consumer->IsFusible()) { return false; } + auto cache_it = result_cache->find(std::make_pair(producer, consumer)); + if (cache_it != result_cache->end()) { + return cache_it->second; + } + bool result = true; for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) { auto* consumer_operand = consumer->mutable_operand(i); // If the operand is not on a path to the producer, it doesn't matter @@ -202,20 +210,23 @@ bool InstructionFusion::CanFuseOnAllPaths( if (!reachability_->IsReachable(producer, consumer_operand)) { continue; } - if (do_not_duplicate.count(consumer_operand) > 0 || - !ShouldFuse(consumer, i)) { - return false; + if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) { + result = false; + break; } // The producer is reachable from consumer_operand which means we need // to be able to fuse consumer_operand into consumer in order for // producer to be fusible into consumer on all paths. // Perform the recursive step: make sure producer can be fused into // consumer_operand on all paths. - if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) { - return false; + if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_fuse, + result_cache)) { + result = false; + break; } } - return true; + result_cache->emplace(std::make_pair(producer, consumer), result); + return result; } InstructionFusion::HloInstructionSet @@ -231,6 +242,8 @@ InstructionFusion::ComputeGloballyUnfusible( // fusing operations that require duplication later depending on // is_expensive_(). HloInstructionSet do_not_duplicate; + absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool> + can_fuse_on_all_paths_result_cache; for (HloInstruction* consumer : post_order) { for (HloInstruction* producer : consumer->operands()) { if (do_not_duplicate.count(producer) > 0) { @@ -286,7 +299,8 @@ InstructionFusion::ComputeGloballyUnfusible( // A will be not allowed to be fused into B, as it cannot be fused via // all paths. 
if (producer->IsFusible() && - CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) { + CanFuseOnAllPaths(producer, consumer, do_not_duplicate, + &can_fuse_on_all_paths_result_cache)) { continue; } do_not_duplicate.insert(producer); @@ -417,7 +431,7 @@ class ReversePostOrderFusionQueue : public FusionQueue { private: std::vector<HloInstruction*> post_order_; - tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index_; + absl::flat_hash_map<HloInstruction*, int> post_order_index_; }; } // namespace diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index c1fde8ecfc..f14c667520 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -1,3 +1,4 @@ +#include "absl/container/flat_hash_map.h" /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +17,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ +#include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -24,39 +26,12 @@ limitations under the License. namespace xla { -// A queue interface that allows implementations to choose fusion candidates in -// custom order. -class FusionQueue { - public: - FusionQueue() = default; - virtual ~FusionQueue() = default; - - // Dequeues the next fusion candidates: a consumer and the list of producers - // as operand indices. - virtual std::pair<HloInstruction*, std::vector<int64>> - DequeueNextInstructionAndOperandsToFuseInOrder() = 0; - - // A callback passed to the queue implementation right before the producer is - // fused into the consumer. - virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {} - - // A callback passed to the queue implementation right after the fusion is - // created. Note that original_producer could have been destroyed. - virtual void OnFusingInstruction(HloInstruction* fusion, - HloInstruction* original_producer, - HloInstruction* original_consumer) {} - - // A callback passed to the queue implementation to notify the removal of an - // instruction. - virtual void RemoveInstruction(HloInstruction* instruction) = 0; -}; - // HLO pass which performs instruction fusion. Instructions are fused // "vertically", meaning producing instructions are fused into their consumers // with the intent that the loops which compute their values will be fused in // code generation. Derived classes define ShouldFuse method to select which // instructions to fuse. -class InstructionFusion : public HloPassInterface { +class InstructionFusion : public HloModulePass { public: explicit InstructionFusion( std::function<bool(const HloInstruction& instruction)> is_expensive, @@ -151,8 +126,15 @@ class InstructionFusion : public HloPassInterface { // Whether or not we can fuse producer into consumer on all paths // from the producer to the consumer where nodes are HLOs and edges are uses. 
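The hunks above thread a result cache through the recursive CanFuseOnAllPaths walk so the same <producer, consumer> pair is never re-evaluated. Below is a minimal, self-contained sketch of that memoization pattern on a toy DAG; Node and CanReachOnAllPaths are hypothetical stand-ins for HloInstruction and the real method, and std::map stands in for absl::flat_hash_map.

#include <map>
#include <utility>
#include <vector>

struct Node {
  std::vector<Node*> operands;  // Edges from a consumer to its operands.
};

// Returns true iff `consumer` is `producer` or every operand path out of
// `consumer` eventually ends at `producer`. Each <producer, consumer> answer
// is cached, mirroring the result_cache parameter added to CanFuseOnAllPaths.
bool CanReachOnAllPaths(Node* producer, Node* consumer,
                        std::map<std::pair<Node*, Node*>, bool>* result_cache) {
  if (consumer == producer) {
    return true;
  }
  if (consumer->operands.empty()) {
    return false;  // A path ended without reaching the producer.
  }
  auto cache_it = result_cache->find(std::make_pair(producer, consumer));
  if (cache_it != result_cache->end()) {
    return cache_it->second;  // Reuse the previously computed answer.
  }
  bool result = true;
  for (Node* operand : consumer->operands) {
    if (!CanReachOnAllPaths(producer, operand, result_cache)) {
      result = false;
      break;
    }
  }
  result_cache->emplace(std::make_pair(producer, consumer), result);
  return result;
}

int main() {
  Node a, b, c;
  b.operands = {&a};
  c.operands = {&b, &a};
  std::map<std::pair<Node*, Node*>, bool> cache;
  return CanReachOnAllPaths(&a, &c, &cache) ? 0 : 1;
}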
- bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer, - const HloInstructionSet& do_not_fuse); + // + // A map from <producer, consumer> to a bool is required as the result cache + // to store and query the results of calls to this function, in order to avoid + // repeated computations. + bool CanFuseOnAllPaths( + HloInstruction* producer, HloInstruction* consumer, + const HloInstructionSet& do_not_fuse, + absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>* + result_cache); // Computes the set of nodes that we do not want to fuse into any of their // consumers based on a global analysis of the HLO graph. diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 146c9052f1..1484e14df1 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -45,8 +45,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", - "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:layout_assignment", + "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index bb69cb9c47..7c79eb7d79 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -28,9 +28,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" -#include "tensorflow/compiler/xla/service/inliner.h" #include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -44,7 +44,8 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass<LayoutAssignment>( - hlo_module->mutable_entry_computation_layout()); + hlo_module->mutable_entry_computation_layout(), + LayoutAssignment::InstructionCanChangeLayout); return pipeline.Run(hlo_module).status(); } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 082bf8bffe..cc4a342e9d 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -498,6 +498,22 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR( constraints->SetBufferLayout(new_shape.layout(), *buffer)); } + } else if (instruction->IsCrossModuleAllReduce()) { + CHECK(get_channel_constraints(instruction)) + << "Multi-module layout assignment requires ChannelLayoutConstraints"; + int64 all_reduce_id = instruction->all_reduce_id().value(); + if (!get_channel_constraints(instruction) + ->IsChannelConstrained(all_reduce_id)) { + continue; + } + // TODO(b/68493863): Change to use SetOperandLayout(). 
+ const Shape& buffer_shape = instruction->operand(0)->shape(); + TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape)); + Shape new_buffer_shape = + get_channel_constraints(instruction) + ->LayoutShapeForChannel(buffer_shape, all_reduce_id); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(new_buffer_shape, instruction)); } } @@ -776,21 +792,27 @@ StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout( << " instruction: " << instruction->ToString(); if (ShapeUtil::IsTuple(instruction->shape())) { - // Deep-copy tuples. + // Copy tuple elements which have differing layouts. std::vector<HloInstruction*> element_copies; for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); ++i) { + const Shape& target_shape = + ShapeUtil::GetSubshape(shape_with_layout, {i}); + const Shape& instr_shape = + ShapeUtil::GetSubshape(instruction->shape(), {i}); HloInstruction* gte = instruction->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, - i)); - SetupCopiedInstruction(*instruction, gte, {i}); - // Recurse to copy each elements. - TF_ASSIGN_OR_RETURN( - HloInstruction * element_copy, - CreateCopyWithNewLayout( - ShapeUtil::GetSubshape(shape_with_layout, {i}), gte)); - element_copies.push_back(element_copy); + HloInstruction::CreateGetTupleElement(instr_shape, instruction, i)); + + if (ShapeUtil::Equal(target_shape, instr_shape)) { + // Shapes and layouts are equal, no need to copy. + element_copies.push_back(gte); + } else { + SetupCopiedInstruction(*instruction, gte, {i}); + // Recurse to copy each element. + TF_ASSIGN_OR_RETURN(HloInstruction * element_copy, + CreateCopyWithNewLayout(target_shape, gte)); + element_copies.push_back(element_copy); + } } // Gather element copies into a tuple with a new Tuple instruction. HloInstruction* tuple_copy = instruction->parent()->AddInstruction( @@ -958,10 +980,15 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { LayoutAssignment::LayoutAssignment( ComputationLayout* entry_computation_layout, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func, ChannelLayoutConstraints* channel_constraints) : entry_computation_layout_(entry_computation_layout), + saved_entry_computation_layout_(*entry_computation_layout), - channel_layout_constraints_(channel_constraints) { + channel_layout_constraints_(channel_constraints), + instruction_can_change_layout_func_( + std::move(instruction_can_change_layout_func)) { if (channel_layout_constraints_ != nullptr) { // Save a copy of the input ChannelLayoutConstraints so that we can reset it // if we have to undo previous operations (ClearPreviousPassSideEffects()). @@ -982,7 +1009,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (!ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(instruction->shape()) && - InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) { + !instruction_can_change_layout_func_(instruction)) { // Propagate the result layout to the operand layout if the instruction // requires the same layout out for the result and the operand. 
// @@ -1060,7 +1087,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (!ShapeUtil::IsScalar(operand->shape()) && ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape()) && - InstructionRequiresInputLayoutEqualToOutputLayout(user)) { + !instruction_can_change_layout_func_(user)) { // Assign users the same layout as the operand. return absl::make_unique<Layout>(operand_layout); } @@ -1512,19 +1539,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, // Verify all layouts in the shape have been set. TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); } - - // Copy the root instruction's result if its layout does not match the result - // layout constraint. - if (constraints.ResultLayout() != nullptr && - !constraints.ResultLayout()->MatchesLayoutInShape( - computation->root_instruction()->shape())) { - TF_ASSIGN_OR_RETURN( - HloInstruction * new_root, - CreateCopyWithNewLayout(constraints.ResultLayout()->shape(), - computation->root_instruction())); - computation->set_root_instruction(new_root); - } - return Status::OK(); } @@ -1654,6 +1668,18 @@ Status LayoutAssignment::RunOnComputation( TF_RETURN_IF_ERROR( ConstrainChannelLayouts(computation, channel_constraints)); } + + // Copy the root instruction's result if its layout does not match the result + // layout constraint. + if (constraints.ResultLayout() != nullptr && + !constraints.ResultLayout()->MatchesLayoutInShape( + computation->root_instruction()->shape())) { + TF_ASSIGN_OR_RETURN( + HloInstruction * new_root, + CreateCopyWithNewLayout(constraints.ResultLayout()->shape(), + computation->root_instruction())); + computation->set_root_instruction(new_root); + } return Status::OK(); } @@ -1709,6 +1735,30 @@ Status LayoutAssignment::ConstrainChannelLayouts( ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0}); *send_shape = shape; } + } else if (instruction->IsCrossModuleAllReduce()) { + const Layout* layout = + get_channel_constraints(instruction) + ->ConstrainChannel(instruction->all_reduce_id().value(), + instruction->shape().layout()); + if (layout != nullptr) { + // We found an already constrained layout which does not match the one + // the channel wants to impose. Either add a new kCopy, or use the + // existing one to marshal the correct shape. 
+ HloInstruction* operand = instruction->mutable_operand(0); + Shape shape = operand->shape(); + *shape.mutable_layout() = *layout; + if (operand->opcode() != HloOpcode::kCopy) { + HloInstruction* copy = operand->parent()->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand)); + RegisterAddedCopy(copy); + SetupCopiedInstruction(*operand, copy, {}); + TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy)); + operand = copy; + } else { + *operand->mutable_shape() = shape; + } + *instruction->mutable_shape() = shape; + } } } return Status::OK(); @@ -1803,7 +1853,8 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) { return true; } -bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( +/* static */ +bool LayoutAssignment::InstructionCanChangeLayout( const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kAbs: @@ -1869,7 +1920,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( case HloOpcode::kTanh: case HloOpcode::kTupleSelect: case HloOpcode::kWhile: - return true; + return false; case HloOpcode::kBatchNormGrad: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormTraining: @@ -1900,7 +1951,7 @@ bool LayoutAssignment::InstructionRequiresInputLayoutEqualToOutputLayout( case HloOpcode::kTrace: case HloOpcode::kTranspose: case HloOpcode::kTuple: - return false; + return true; } } diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index cf545031d3..2d48e12263 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -25,6 +25,8 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -38,8 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -228,8 +228,8 @@ class LayoutConstraints { // Array-shaped buffers which have not yet been constrained. std::set<LogicalBuffer::Id> unconstrained_buffer_ids_; - mutable tensorflow::gtl::FlatMap<const HloInstruction*, - std::unique_ptr<PointsToSet::BufferSet>> + mutable absl::flat_hash_map<const HloInstruction*, + std::unique_ptr<PointsToSet::BufferSet>> buffer_sets_cache_; HloComputation* computation_; @@ -281,11 +281,16 @@ class ChannelLayoutConstraints { // HLO pass which assigns layouts to all instructions in the HLO module while // satisfying all necessary invariants and minimizing cost. -class LayoutAssignment : public HloPassInterface { +class LayoutAssignment : public HloModulePass { public: // entry_computation_layout is modified to populate a layout for the result in // the case that no particular layout is requested. // + // instruction_can_change_layout_func is a function object that determines + // whether an instruction can change layouts. An instruction not being able to + // change layout means that it requires operands with the same rank as the + // output to have the same layout as the output. 
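For reference, a short sketch of how a backend pipeline constructs the pass after this change, matching the interpreter compiler and layout_assignment_test hunks elsewhere in this diff. It assumes the usual XLA build dependencies; AddLayoutAssignment and entry_layout are hypothetical caller-side names.

#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"

void AddLayoutAssignment(xla::HloPassPipeline* pipeline,
                         xla::ComputationLayout* entry_layout) {
  // The second constructor argument is the new predicate deciding which
  // instructions may change layouts. Passing the static default keeps the
  // existing behavior; a backend may substitute its own function object.
  pipeline->AddPass<xla::LayoutAssignment>(
      entry_layout, xla::LayoutAssignment::InstructionCanChangeLayout);
}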
+ // // channel_constraints is both an input and output. Any sends or recvs that // are present in channel_constraints will be laid out as constrained. Any // unconstrained sends or recvs will be laid out as locally optimal and their @@ -295,6 +300,8 @@ class LayoutAssignment : public HloPassInterface { // within any module passed to `Run`. explicit LayoutAssignment( ComputationLayout* entry_computation_layout, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func = InstructionCanChangeLayout, ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} absl::string_view name() const override { return "layout-assignment"; } @@ -303,10 +310,10 @@ class LayoutAssignment : public HloPassInterface { // (any layouts were changed). StatusOr<bool> Run(HloModule* module) override; - // Returns true if the instruction requires that operands with the same rank - // as the output have to have the same layout as the output. - virtual bool InstructionRequiresInputLayoutEqualToOutputLayout( - const HloInstruction* instruction); + // Determines whether an instruction can change layouts. An instruction not + // being able to change layout means that it requires operands with the same + // rank as the output to have the same layout as the output. + static bool InstructionCanChangeLayout(const HloInstruction* instruction); protected: // These methods, invoked by PropagateConstraints, propagate a layout @@ -504,7 +511,7 @@ class LayoutAssignment : public HloPassInterface { // Every copy added to the module by the layout assignment pass is registered // here. - tensorflow::gtl::FlatSet<HloInstruction*> added_copies_; + absl::flat_hash_set<HloInstruction*> added_copies_; // The pointer to the channel layout constraints passed in with the // constructor. If not nullptr, this is an input/output argument. @@ -521,8 +528,10 @@ class LayoutAssignment : public HloPassInterface { // The set of HLO instructions which lacked any layout constraint, thus // receiving propagated default layouts. - tensorflow::gtl::FlatSet<const HloInstruction*> - unconstrained_layout_instructions_; + absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_; + + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 752a61476d..2c549cd872 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -55,7 +55,8 @@ class LayoutAssignmentTest : public HloVerifiedTestBase { ComputationLayout* entry_computation_layout, ChannelLayoutConstraints* channel_constraints = nullptr) { LayoutAssignment layout_assignment( - entry_computation_layout, /*channel_constraints=*/channel_constraints); + entry_computation_layout, LayoutAssignment::InstructionCanChangeLayout, + /*channel_constraints=*/channel_constraints); EXPECT_IS_OK(layout_assignment.Run(module).status()); } @@ -860,6 +861,50 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}))); } +TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) { + // Pin non matching layouts to parameter and root. 
+ const char* module_str = R"( + HloModule test_module + + add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) + } + + ENTRY entry_computation { + param = (f32[2,2]) parameter(0) + gte = f32[2,2] get-tuple-element(param), index=0 + ar.0 = f32[2,2] cross-replica-sum(gte), + all_reduce_id=0, replica_groups={{0}}, to_apply=add, + sharding={maximal device=0} + const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}}) + ROOT ar.1 = f32[2,2] cross-replica-sum(const), + all_reduce_id=0, replica_groups={{0}}, to_apply=add, + sharding={maximal device=1} + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(module_str)); + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + Shape param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})}); + TF_ASSERT_OK( + computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape( + param_shape)); + computation_layout.mutable_result_layout()->ResetLayout( + LayoutUtil::MakeLayout({1, 0})); + + ChannelLayoutConstraints channel_constraints; + AssignLayouts(module.get(), &computation_layout, &channel_constraints); + + EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(module.get(), "ar.1"), ElementsAre(0, 1)); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0)); +} + TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { const char* module_str = R"( HloModule CopySliceOperandToAvoidImplicitLayoutChange @@ -998,5 +1043,64 @@ TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { op::ShapeWithLayout(shape_copy)))); } +TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { + // The first infeed uses layout {0,1}, while the second uses layout {1,0}. + // The mismatch forces a copy of the tuple. The tuple contains a token, so + // layout assignment will fail if it tries to copy the whole tuple. 
+ const char* module_str = R"( + HloModule TupleCopyOnLayoutMismatch + + condition.1 (tup: (s32[], token[], f32[512,1024]{0,1})) -> pred[] { + tup.1 = (s32[], token[], f32[512,1024]{0,1}) parameter(0) + counter.1 = s32[] get-tuple-element(tup.1), index=0 + five = s32[] constant(5) + ROOT lt = pred[] less-than(counter.1, five) + } + + body.2 (tup: (s32[], token[], f32[512,1024]{0,1})) -> (s32[], token[], f32[512,1024]{0,1}) { + tup.2 = (s32[], token[], f32[512,1024]{0,1}) parameter(0) + counter.2 = s32[] get-tuple-element(tup.2), index=0 + tok.2 = token[] get-tuple-element(tup.2), index=1 + + ifeed.2 = (f32[512,1024]{1,0}, token[]) infeed(tok.2) + next_tok = token[] get-tuple-element(ifeed.2), index=1 + next_buf = f32[512,1024]{1,0} get-tuple-element(ifeed.2), index=0 + + one = s32[] constant(1) + next_counter = s32[] add(counter.2, one) + ROOT tup = (s32[], token[], f32[512,1024]{0,1}) tuple(next_counter, next_tok, next_buf) + } + + ENTRY main () -> f32[512,1024]{0,1} { + start_tok = token[] after-all() + + ifeed.3 = (f32[512,1024]{0,1}, token[]) infeed(start_tok) + itok = token[] get-tuple-element(ifeed.3), index=1 + ibuf = f32[512,1024]{0,1} get-tuple-element(ifeed.3), index=0 + + zero = s32[] constant(0) + itup = (s32[], token[], f32[512,1024]{0,1}) tuple(zero, itok, ibuf) + + loop = (s32[], token[], f32[512,1024]{0,1}) while(itup), condition=condition.1, body=body.2 + ROOT result = f32[512,1024]{0,1} get-tuple-element(loop), index=2 + } + )"; + + ParseAndVerifyModule(module_str); + ComputationLayout computation_layout( + module().entry_computation()->ComputeProgramShape()); + + // Sanity check to verify that there's a layout mismatch. + EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); + + AssignLayouts(&module(), &computation_layout); + + // Make sure that layout assignment did not magically eliminate the mismatch, + // in which case the test didn't prove anything. + EXPECT_THAT(LayoutOf(&module(), "ibuf"), ElementsAre(0, 1)); + EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 540bbb7c7a..6223a34b12 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -38,6 +38,8 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm//:core", ], diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index e5370eca56..643ecd0fba 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -15,7 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h" -#include <unordered_set> +#include <map> #include "llvm/IR/MDBuilder.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -164,9 +164,7 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( add_buffers_to_worklist(operand); } - tensorflow::gtl::FlatSet<BufferAllocation::Slice, - BufferAllocation::Slice::Hasher> - buffers; + std::set<BufferAllocation::Slice> buffers; for (const LogicalBuffer* buffer : worklist) { // Skip buffers which cannot be added to the noalias set. if (!assignment.HasAllocation(*buffer) || diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h index 8d9fa99d82..2b46b3c396 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h @@ -16,14 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_ +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { namespace llvm_ir { @@ -77,14 +76,14 @@ class AliasAnalysis { // A map from a buffer slice to metadata corresponding to its alias.scope // metadata. The index kParameterAliasSet is used to hold aliasing // information for parameters. - tensorflow::gtl::FlatMap<BufferAllocation::Slice, llvm::MDNode*, - BufferAllocation::Slice::Hasher> + absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*, + BufferAllocation::Slice::Hasher> alias_scope_metadata_; // A map from a buffer slice to metadata corresponding to its noalias // metadata. - tensorflow::gtl::FlatMap<BufferAllocation::Slice, llvm::MDNode*, - BufferAllocation::Slice::Hasher> + absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*, + BufferAllocation::Slice::Hasher> noalias_metadata_; }; diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index eaa09591b7..ec52a24d78 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -54,7 +54,7 @@ Status LogicalBufferAnalysis::Analyze() { // so reserve 10% more than the number of instructions to avoid frequent // resizes. logical_buffers_.clear(); - logical_buffers_.reserve((module_->NumUniqueInstructionIds() * 11) / 10); + logical_buffers_.reserve((module_->instruction_count() * 11) / 10); // We filter out fusion computations, and get to them through fusion // instructions. This is because it's possible to have orphaned (unreachable) diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/map_inliner.cc index 5fd779ebf9..2200ef054a 100644 --- a/tensorflow/compiler/xla/service/inliner.cc +++ b/tensorflow/compiler/xla/service/map_inliner.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/inliner.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include <memory> #include <string> @@ -32,10 +32,10 @@ limitations under the License. namespace xla { -// InlinerVisitor traverses the HLO computation and inlines maps. -class InlinerVisitor : public DfsHloVisitorWithDefault { +// MapInlinerVisitor traverses the HLO computation and inlines maps. +class MapInlinerVisitor : public DfsHloVisitorWithDefault { public: - explicit InlinerVisitor(HloComputation* computation) + explicit MapInlinerVisitor(HloComputation* computation) : computation_(computation) {} // Default visitor action is to do nothing and return OK. @@ -49,48 +49,44 @@ class InlinerVisitor : public DfsHloVisitorWithDefault { StatusOr<bool> Run(HloComputation* computation); private: - // Current HloComputation instance the InlinerVisitor is traversing. + // Current HloComputation instance the MapInlinerVisitor is traversing. HloComputation* computation_; // Whether algebraic simplification has occurred. bool changed_ = false; }; -StatusOr<bool> InlinerVisitor::Run(HloComputation* computation) { +StatusOr<bool> MapInlinerVisitor::Run(HloComputation* computation) { changed_ = false; computation_ = computation; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this)); return changed_; } -Status InlinerVisitor::HandleMap(HloInstruction* map) { +Status MapInlinerVisitor::HandleMap(HloInstruction* map) { HloComputation* function = map->to_apply(); HloInstruction& root = *function->root_instruction(); - // TODO(b/29249531): Add DCE pass to remove unused HloComputations. // Only inlining functions that are simply a single operation until a better // profitability model for inlining is defined. if (hlo_query::AllOperandsAreParameters(root)) { if (root.opcode() == HloOpcode::kFusion || - root.opcode() == HloOpcode::kParameter || root.opcode() == HloOpcode::kTrace) { // Cloning not supported for these instructions. return Status::OK(); } VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function " << root.ToShortString(); - // If the input is a constant then the shape of the constant could be - // different than the map shape. Hence, a broadcast is needed, else the - // cloned operand with new shape and operands work. - if (root.opcode() != HloOpcode::kConstant) { - std::vector<HloInstruction*> params; - for (int64 o = 0; o < root.operands().size(); o++) { - params.push_back(map->operands()[root.operand(o)->parameter_number()]); - } - HloInstruction* placed_instruction = computation_->AddInstruction( - root.CloneWithNewOperands(map->shape(), params)); + if (root.opcode() == HloOpcode::kParameter) { + // If the root is a parameter, then use the corresponding operand as the + // result of the computation. TF_RETURN_IF_ERROR( - computation_->ReplaceInstruction(map, placed_instruction)); - } else { + map->ReplaceAllUsesWith(map->operands()[root.parameter_number()])); + TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map)); + } else if (root.opcode() == HloOpcode::kConstant) { + // If the input is a constant then the shape of the constant could be + // different than the map shape. Hence, a broadcast is needed, else the + // cloned operand with new shape and operands work. + // // The constant is in an embedded computation and needs to be recreated // as part of the computation that the broadcast is inserted into. 
HloInstruction* constant = computation_->AddInstruction(root.Clone()); @@ -98,6 +94,15 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) { HloInstruction::CreateBroadcast(map->shape(), constant, {})); TF_RETURN_IF_ERROR( computation_->ReplaceInstruction(map, placed_instruction)); + } else { + std::vector<HloInstruction*> params; + for (int64 o = 0; o < root.operands().size(); o++) { + params.push_back(map->operands()[root.operand(o)->parameter_number()]); + } + HloInstruction* placed_instruction = computation_->AddInstruction( + root.CloneWithNewOperands(map->shape(), params)); + TF_RETURN_IF_ERROR( + computation_->ReplaceInstruction(map, placed_instruction)); } changed_ = true; return Status::OK(); @@ -106,8 +111,8 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) { return Status::OK(); } -StatusOr<bool> Inliner::Run(HloModule* module) { - InlinerVisitor visitor(/*computation=*/nullptr); +StatusOr<bool> MapInliner::Run(HloModule* module) { + MapInlinerVisitor visitor(/*computation=*/nullptr); bool changed = false; for (HloComputation* computation : module->computations()) { TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation)); diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/map_inliner.h index efa8ed3abc..b679118118 100644 --- a/tensorflow/compiler/xla/service/inliner.h +++ b/tensorflow/compiler/xla/service/map_inliner.h @@ -13,27 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { -// A pass which performs inlining. Which can result, for example, in functions -// that were previously being mapped by Map instead directly applied to the -// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)). -class Inliner : public HloPassInterface { +// A pass which performs map inlining. This replaces kMap instructions with +// their equivalent sequence of array operations. For example: +// map({X, Y}, add) -> add(X, Y)). +class MapInliner : public HloModulePass { public: - ~Inliner() override = default; - absl::string_view name() const override { return "inline"; } + ~MapInliner() override = default; + absl::string_view name() const override { return "map-inline"; } - // Run inlining on the given computation. Returns whether the computation was - // changed. + // Run map inlining on the given computation. Returns whether the computation + // was changed. StatusOr<bool> Run(HloModule* module) override; }; } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_ diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/map_inliner_test.cc index 7e967f035c..84059dd0f7 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/map_inliner_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
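A minimal caller-side sketch of the renamed pass (assuming the XLA build environment; InlineMaps is a hypothetical helper name), summarizing the three rewrites HandleMap performs above.

#include "tensorflow/compiler/xla/service/map_inliner.h"

xla::StatusOr<bool> InlineMaps(xla::HloModule* module) {
  // map({X, Y}, add) becomes add(X, Y); a map whose to_apply root is a
  // parameter is replaced by the corresponding operand, and a constant root
  // becomes a broadcast of that constant into the map's shape.
  xla::MapInliner inliner;
  return inliner.Run(module);
}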
==============================================================================*/ -#include "tensorflow/compiler/xla/service/inliner.h" +#include "tensorflow/compiler/xla/service/map_inliner.h" #include <memory> #include <utility> @@ -35,10 +35,10 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using InlinerTest = HloVerifiedTestBase; +using MapInlinerTest = HloVerifiedTestBase; // Test that `map` with `max` is transformed to `max` -TEST_F(InlinerTest, MapMax) { +TEST_F(MapInlinerTest, MapMax) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto max_builder = HloComputation::Builder(TestName()); @@ -63,7 +63,7 @@ TEST_F(InlinerTest, MapMax) { hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); - Inliner inliner; + MapInliner inliner; EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Maximum(lhs, rhs)); @@ -75,7 +75,7 @@ TEST_F(InlinerTest, MapMax) { } // Test that `constant` function is changed to `broadcast`. -TEST_F(InlinerTest, MapConstant) { +TEST_F(MapInlinerTest, MapConstant) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); auto const2_builder = HloComputation::Builder(TestName()); @@ -97,7 +97,7 @@ TEST_F(InlinerTest, MapConstant) { hlo_module->AddEmbeddedComputation(std::move(const2_f32)); hlo_module->AddEntryComputation(std::move(computation)); HloInstruction* root = hlo_module->entry_computation()->root_instruction(); - Inliner inliner; + MapInliner inliner; EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); root = hlo_module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Broadcast(op::Constant())); @@ -108,7 +108,7 @@ TEST_F(InlinerTest, MapConstant) { EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } -TEST_F(InlinerTest, MapSubtractOppositeOrder) { +TEST_F(MapInlinerTest, MapSubtractOppositeOrder) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); // Note that the parameter ordinals are in the opposite order to their @@ -135,7 +135,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { hlo_module->AddEmbeddedComputation(std::move(max_f32)); hlo_module->AddEntryComputation(std::move(computation)); - Inliner inliner; + MapInliner inliner; EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie()); EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), op::Subtract(rhs, lhs)); @@ -146,6 +146,36 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) { EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); } +TEST_F(MapInlinerTest, MapParameter) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + auto param_builder = HloComputation::Builder(TestName()); + param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0")); + param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1")); + auto param_f32 = param_builder.Build(); + + auto builder = HloComputation::Builder("MapParamFunction"); + auto lhs = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1))); + auto rhs = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = CreateNewVerifiedModule(); + hlo_module->AddEmbeddedComputation(std::move(param_f32)); + hlo_module->AddEntryComputation(std::move(computation)); + + MapInliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + 
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs); + + // Verify execution on CPU. + auto result = ExecuteAndTransfer(hlo_module->Clone(), {}); + auto expected = LiteralUtil::CreateR0<float>(4); + EXPECT_TRUE(LiteralTestUtil::Equal(result, expected)); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index b9ec31c497..2ca527bc4c 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/multi_output_fusion.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -50,7 +50,7 @@ StatusOr<bool> MultiOutputFusion::Run(HloModule* module) { all_fusion_candidates_.push_back(instruction); std::vector<HloInstruction*> candidates; - tensorflow::gtl::FlatSet<HloInstruction*> candidates_set; + absl::flat_hash_set<HloInstruction*> candidates_set; VLOG(10) << "Looking at instruction: " << instruction->name(); for (auto operand : instruction->operands()) { // Filter out the non-interesting instructions -- they @@ -172,7 +172,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { // Update the fusible list for fusion. Variable new_fusibles keeps // track of the new or changed entries. std::vector<std::pair<HloInstruction*, int64>> new_fusibles; - tensorflow::gtl::FlatSet<HloInstruction*> in_list; + absl::flat_hash_set<HloInstruction*> in_list; auto it = fusion_node.fusibles.begin(); while (it != fusion_node.fusibles.end()) { HloInstruction* instr = it->first; diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h index d2c52651c4..9508ab2ed1 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.h +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -19,6 +19,7 @@ limitations under the License. #include <queue> #include <vector> +#include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -44,7 +45,7 @@ namespace xla { // Note that the reachability map is updated based on the original computation. // This works because the reachability is monotonically increasing with // instruction fusion. -class MultiOutputFusion : public HloPassInterface { +class MultiOutputFusion : public HloModulePass { public: MultiOutputFusion(int64 fuel) : fuel_(fuel) {} @@ -126,7 +127,7 @@ class MultiOutputFusion : public HloPassInterface { std::vector<FusionCandidate> candidates_; // A map that maps an instruction to the index_. - tensorflow::gtl::FlatMap<HloInstruction*, int> candidates_index_; + absl::flat_hash_map<HloInstruction*, int> candidates_index_; // The reachability map of current computation. 
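The container migration repeated throughout these files is mechanical. The toy, self-contained example below (integer keys standing in for HloInstruction pointers) shows that the replacement types are used with the same call sites.

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"

int main() {
  absl::flat_hash_map<int, int> candidates_index;  // was tensorflow::gtl::FlatMap
  absl::flat_hash_set<int> in_list;                // was tensorflow::gtl::FlatSet
  candidates_index.emplace(42, 0);
  in_list.insert(42);
  return (candidates_index.count(42) == 1 && in_list.count(42) == 1) ? 0 : 1;
}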
std::unique_ptr<HloReachabilityMap> reachability_; diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index bd8fb17a23..ac2f79674f 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -39,8 +39,10 @@ NameUniquer::NameUniquer(const string& separator) { } /*static*/ string NameUniquer::GetSanitizedName(const string& name) { + if (name.empty()) { + return ""; + } string result = name; - CHECK(!result.empty()) << "name should not be empty"; char c = static_cast<unsigned char>(result[0]); if (!isalpha(c) && c != '_') { result[0] = '_'; diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index 6dd89c240f..8909d0f4fe 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -18,10 +18,10 @@ limitations under the License. #include <string> +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -69,7 +69,7 @@ class NameUniquer { int64 next_ = 0; // Set of all the identifiers which has been used. - tensorflow::gtl::FlatSet<int64> used_; + absl::flat_hash_set<int64> used_; }; // The string to use to separate the prefix of the name from the uniquing @@ -78,7 +78,7 @@ class NameUniquer { // Map from name prefix to the generator data structure which tracks used // identifiers and generates new ones. - tensorflow::gtl::FlatMap<string, SequentialIdGenerator> generated_names_; + absl::flat_hash_map<string, SequentialIdGenerator> generated_names_; TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer); }; diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index 4869db79e7..380cde0e6a 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -17,8 +17,12 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_PATTERN_MATCHER_H_ #include "absl/strings/string_view.h" +#include "absl/utility/utility.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -116,15 +120,82 @@ namespace xla { // .WithOperand(1, Op(&c)) // .WithOperand(2, Op(&d)) // + +struct MatchOption { + // If true, actually capture matched item into the user pointer. + bool capture; +}; + template <typename Value, typename Pattern> -bool Match(Value* value, const Pattern& pattern) { - return pattern.Match(value); +bool Match(Value* value, const Pattern& pattern, + MatchOption option = {/*.capture=*/true}) { + if (option.capture) { + auto new_option = option; + new_option.capture = false; + if (!pattern.Match(value, new_option)) { + return false; + } + } + return pattern.Match(value, option); } namespace match { namespace detail { +template <typename Item, typename... Patterns> +class AllOfPattern { + public: + explicit AllOfPattern(const Patterns&... 
patterns) : patterns_(patterns...) {} + + bool Match(const Item* item, MatchOption option) const { + bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>()); + // This invariant is guaranteed by the top-level Match and AnyOf. + DCHECK(matched || !option.capture); + return matched; + } + + bool Match(Item* item, MatchOption option) const { + bool matched = MatchImpl(item, option, std::integral_constant<size_t, 0>()); + // This invariant is guaranteed by the top-level Match and AnyOf. + DCHECK(matched || !option.capture); + return matched; + } + + private: + template <typename ItemType, size_t index> + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant<size_t, index>) const { + return std::get<index>(patterns_).Match(item, option) && + MatchImpl(item, option, std::integral_constant<size_t, index + 1>()); + } + + template <typename ItemType> + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant<size_t, sizeof...(Patterns)>) const { + return true; + } + + std::tuple<Patterns...> patterns_; +}; + +} // namespace detail + +// Returns a pattern that represents the conjunction of all input patterns. All +// patterns need to match in order to have the AllOf pattern match. +// +// TODO(timshen): Currently AllOf is still nested, e.g. AllOf<AllOf<A>, B> is +// not AllOf<A, B>. We might want to flatten the AllOf type structure if the +// C++ compile error message gets annoying. +template <typename Item, typename... Patterns> +detail::AllOfPattern<typename std::remove_const<Item>::type, Patterns...> AllOf( + const Patterns&... patterns) { + return detail::AllOfPattern<typename std::remove_const<Item>::type, + Patterns...>(patterns...); +} + +namespace detail { + template <typename LayoutType, typename Impl> class LayoutPattern; @@ -132,57 +203,61 @@ class LayoutPattern; // nullptr. class LayoutPatternBaseImpl { public: - bool Match(const ::xla::Layout* layout) const { return layout != nullptr; } + bool Match(const ::xla::Layout* layout, MatchOption option) const { + return layout != nullptr; + } }; // A LayoutPattern implementation that matches only if the layout equals a // Layout proto. -template <typename Previous> class LayoutPatternEqualImpl { public: - explicit constexpr LayoutPatternEqualImpl(const Previous& previous, - const ::xla::Layout* layout) - : previous_(previous), layout_(layout) {} + explicit constexpr LayoutPatternEqualImpl(const ::xla::Layout* layout) + : layout_(layout) {} - bool Match(const ::xla::Layout* layout) const { - return previous_.Match(layout) && LayoutUtil::Equal(*layout_, *layout); + bool Match(const ::xla::Layout* layout, MatchOption option) const { + return LayoutUtil::Equal(*layout_, *layout); } private: - Previous previous_; const ::xla::Layout* layout_; }; // A LayoutPattern implementation that matches only if the layout has a given // format. -template <typename Previous> class LayoutPatternFormatImpl { public: - explicit constexpr LayoutPatternFormatImpl(const Previous& previous, - Format format) - : previous_(previous), format_(format) {} + explicit constexpr LayoutPatternFormatImpl(Format format) : format_(format) {} - bool Match(const ::xla::Layout* layout) const { - return previous_.Match(layout) && layout->format() == format_; + bool Match(const ::xla::Layout* layout, MatchOption option) const { + return layout->format() == format_; } private: - Previous previous_; Format format_; }; // A pattern that matches Layouts. 
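The MatchOption plumbing added above makes the top-level Match a two-phase operation: a dry run with capture disabled, then a capturing run only if the whole pattern matched, so a failed match never leaves partially filled capture pointers. A standalone toy illustration follows; MatchOption and Match mirror the definitions above, while GreaterThan is a hypothetical pattern type standing in for the HLO, shape, and layout patterns.

#include <cstdio>

struct MatchOption {
  bool capture;
};

struct GreaterThan {
  int bound;
  int* capture_slot;
  bool Match(const int* value, MatchOption option) const {
    if (*value <= bound) return false;
    if (option.capture && capture_slot != nullptr) *capture_slot = *value;
    return true;
  }
};

template <typename Value, typename Pattern>
bool Match(Value* value, const Pattern& pattern,
           MatchOption option = {/*.capture=*/true}) {
  if (option.capture) {
    MatchOption no_capture = option;
    no_capture.capture = false;
    if (!pattern.Match(value, no_capture)) {
      return false;  // Dry run failed; nothing was captured.
    }
  }
  return pattern.Match(value, option);  // Capturing run (or plain run).
}

int main() {
  int captured = 0;
  int value = 7;
  GreaterThan pattern{5, &captured};
  std::printf("%d %d\n", Match(&value, pattern), captured);  // prints "1 7"
  return 0;
}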
template <typename LayoutType, typename Impl> class LayoutPattern { + private: + template <typename NewImpl> + LayoutPattern<LayoutType, AllOfPattern<::xla::Layout, Impl, NewImpl>> + AppendImpl(NewImpl new_impl) const { + return LayoutPattern<LayoutType, + AllOfPattern<::xla::Layout, Impl, NewImpl>>( + AllOf<Layout>(impl_, std::move(new_impl)), matched_layout_); + } + public: explicit constexpr LayoutPattern(const Impl& impl, LayoutType** matched_layout) : impl_(impl), matched_layout_(matched_layout) {} // Returns true and captures the layout iff it matches the pattern. - bool Match(const ::xla::Layout* layout) const { - if (impl_.Match(layout)) { - if (matched_layout_) { + bool Match(const ::xla::Layout* layout, MatchOption option) const { + if (impl_.Match(layout, option)) { + if (option.capture && matched_layout_) { *matched_layout_ = layout; } return true; @@ -191,9 +266,9 @@ class LayoutPattern { } // Returns true and captures the layout iff it matches the pattern. - bool Match(::xla::Layout* layout) const { - if (impl_.Match(layout)) { - if (matched_layout_) { + bool Match(::xla::Layout* layout, MatchOption option) const { + if (impl_.Match(layout, option)) { + if (option.capture && matched_layout_) { *matched_layout_ = layout; } return true; @@ -203,24 +278,21 @@ class LayoutPattern { // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. - constexpr LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>> EqualTo( - const ::xla::Layout* layout) const { - return LayoutPattern<LayoutType, LayoutPatternEqualImpl<Impl>>( - LayoutPatternEqualImpl<Impl>(impl_, layout), matched_layout_); + constexpr auto EqualTo(const ::xla::Layout* layout) const + -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) { + return AppendImpl(LayoutPatternEqualImpl(layout)); } // Modifies the pattern to match only if the layout has a dense format. - constexpr LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>> - WithDenseFormat() const { - return LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>( - LayoutPatternFormatImpl<Impl>(impl_, DENSE), matched_layout_); + constexpr auto WithDenseFormat() const + -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) { + return AppendImpl(LayoutPatternFormatImpl(DENSE)); } // Modifies the pattern to match only if the layout has a sparse format. - constexpr LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>> - WithSparseFormat() const { - return LayoutPattern<LayoutType, LayoutPatternFormatImpl<Impl>>( - LayoutPatternFormatImpl<Impl>(impl_, SPARSE), matched_layout_); + constexpr auto WithSparseFormat() const + -> decltype(this->AppendImpl(LayoutPatternFormatImpl(SPARSE))) { + return AppendImpl(LayoutPatternFormatImpl(SPARSE)); } private: @@ -228,8 +300,72 @@ class LayoutPattern { LayoutType** matched_layout_; }; +template <typename Item, typename... Patterns> +class AnyOfPattern { + public: + explicit AnyOfPattern(const Patterns&... patterns) : patterns_(patterns...) 
{} + + bool Match(const Item* item, MatchOption option) const { + return MatchImpl(item, option, std::integral_constant<size_t, 0>()); + } + + bool Match(Item* item, MatchOption option) const { + return MatchImpl(item, option, std::integral_constant<size_t, 0>()); + } + + private: + template <typename ItemType, size_t index> + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant<size_t, index>) const { + auto new_option = option; + new_option.capture = false; + // Try to match the sub-pattern without capturing behavior. + if (std::get<index>(patterns_).Match(item, new_option)) { + // Capture the branch. + if (option.capture) { + // TODO(timshen): Currently the behavior can be exponential. Optimize it + // with memoization or recording the matched sub-pattern index, if it + // takes too long to run. + // + // Specifically, the "memoization" approach is to create an empty + // container with the key (pattern, instruction), and value as whether + // matched or not. + // + // Alternatively, we may run the pattern matching with captures off, but + // instead record a "trace" somewhere, indicating how exactly the + // pattern matches the input. For example, the trace information for + // AnyOf will be a runtime number indicate which sub-pattern is matched. + // Then we run another pass to do captures only with the help of the + // trace. + bool ret = std::get<index>(patterns_).Match(item, option); + DCHECK(ret); + } + return true; + } + return MatchImpl(item, option, std::integral_constant<size_t, index + 1>()); + } + + template <typename ItemType> + bool MatchImpl(ItemType* item, MatchOption option, + std::integral_constant<size_t, sizeof...(Patterns)>) const { + return false; + } + + std::tuple<Patterns...> patterns_; +}; + } // namespace detail +// Returns a pattern that represents the logical disjunction of the input +// patterns. The returned pattern matches from left to right, and stops on the +// first match. +template <typename Item, typename... Patterns> +detail::AnyOfPattern<typename std::remove_const<Item>::type, Patterns...> AnyOf( + const Patterns&... patterns) { + return detail::AnyOfPattern<typename std::remove_const<Item>::type, + Patterns...>(patterns...); +} + // Creates a layout pattern that will capture the matched layout in the // argument. inline constexpr detail::LayoutPattern<const ::xla::Layout, @@ -258,172 +394,145 @@ class ShapePattern; // nullptr. class ShapePatternBaseImpl { public: - bool Match(const ::xla::Shape* shape) const { return shape != nullptr; } + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return shape != nullptr; + } }; // A ShapePattern implementation that matches only if the shape equals a Shape // proto. -template <typename Previous> class ShapePatternEqualImpl { public: - explicit constexpr ShapePatternEqualImpl(const Previous& previous, - const ::xla::Shape* shape) - : previous_(previous), shape_(shape) {} + explicit constexpr ShapePatternEqualImpl(const ::xla::Shape* shape) + : shape_(shape) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::Equal(*shape_, *shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::Equal(*shape_, *shape); } private: - Previous previous_; const ::xla::Shape* shape_; }; // A ShapePattern implementation that matches only if the shape is compatible to // a Shape proto. 
-template <typename Previous> class ShapePatternCompatibleImpl { public: - explicit constexpr ShapePatternCompatibleImpl(const Previous& previous, - const ::xla::Shape* shape) - : previous_(previous), shape_(shape) {} + explicit constexpr ShapePatternCompatibleImpl(const ::xla::Shape* shape) + : shape_(shape) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::Compatible(*shape_, *shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::Compatible(*shape_, *shape); } private: - Previous previous_; const ::xla::Shape* shape_; }; // A ShapePattern implementation that matches only if the shape has a given // element type. -template <typename Previous> class ShapePatternElementTypeImpl { public: - explicit constexpr ShapePatternElementTypeImpl(const Previous& previous, - PrimitiveType element_type) - : previous_(previous), element_type_(element_type) {} + explicit constexpr ShapePatternElementTypeImpl(PrimitiveType element_type) + : element_type_(element_type) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && shape->element_type() == element_type_; + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return shape->element_type() == element_type_; } private: - Previous previous_; PrimitiveType element_type_; }; // A ShapePattern implementation that matches only if the shape is scalar. -template <typename Previous> class ShapePatternIsScalarImpl { public: - explicit constexpr ShapePatternIsScalarImpl(const Previous& previous) - : previous_(previous) {} + explicit constexpr ShapePatternIsScalarImpl() {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IsScalar(*shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IsScalar(*shape); } - - private: - Previous previous_; }; // A ShapePattern implementation that matches only if the shape is an array -template <typename Previous> class ShapePatternIsArrayImpl { public: - explicit constexpr ShapePatternIsArrayImpl(const Previous& previous) - : previous_(previous) {} + explicit constexpr ShapePatternIsArrayImpl() {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IsArray(*shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IsArray(*shape); } - - private: - Previous previous_; }; // A ShapePattern implementation that matches only if the shape is a tuple. -template <typename Previous> class ShapePatternIsTupleImpl { public: - explicit constexpr ShapePatternIsTupleImpl(const Previous& previous) - : previous_(previous) {} + explicit constexpr ShapePatternIsTupleImpl() {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IsTuple(*shape); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IsTuple(*shape); } - - private: - Previous previous_; }; // A ShapePattern implementation that matches only if the shape has a given // rank. 
-template <typename Previous> class ShapePatternRankImpl { public: - explicit constexpr ShapePatternRankImpl(const Previous& previous, int64 rank) - : previous_(previous), rank_(rank) {} + explicit constexpr ShapePatternRankImpl(int64 rank) : rank_(rank) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::Rank(*shape) == rank_; + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::Rank(*shape) == rank_; } private: - Previous previous_; int64 rank_; }; // A ShapePattern implementation that matches only if the shape has a layout // that matches a given pattern. -template <typename Previous, typename LayoutType, typename LayoutImpl> +template <typename LayoutType, typename LayoutImpl> class ShapePatternLayoutImpl { public: explicit constexpr ShapePatternLayoutImpl( - const Previous& previous, const LayoutPattern<LayoutType, LayoutImpl>& layout) - : previous_(previous), layout_(layout) {} + : layout_(layout) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && - layout_.Match(&shape->layout()); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return LayoutUtil::HasLayout(*shape) && + layout_.Match(&shape->layout(), option); } - bool Match(Shape* shape) const { - return previous_.Match(shape) && LayoutUtil::HasLayout(*shape) && - layout_.Match(shape->mutable_layout()); + bool Match(Shape* shape, MatchOption option) const { + return LayoutUtil::HasLayout(*shape) && + layout_.Match(shape->mutable_layout(), option); } private: - Previous previous_; LayoutPattern<LayoutType, LayoutImpl> layout_; }; // A ShapePattern implementation that matches only if the shape has a subshape // that matches a given pattern. -template <typename Previous, typename SubshapeType, typename SubshapeImpl> +template <typename SubshapeType, typename SubshapeImpl> class ShapePatternSubshapeImpl { public: explicit ShapePatternSubshapeImpl( - const Previous& previous, ShapeIndexView index, + ShapeIndexView index, const ShapePattern<SubshapeType, SubshapeImpl>& subshape) - : previous_(previous), index_(index), subshape_(subshape) {} + : index_(index), subshape_(subshape) {} - bool Match(const ::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_)); + bool Match(const ::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(&ShapeUtil::GetSubshape(*shape, index_), option); } - bool Match(::xla::Shape* shape) const { - return previous_.Match(shape) && ShapeUtil::IndexIsValid(*shape, index_) && - subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_)); + bool Match(::xla::Shape* shape, MatchOption option) const { + return ShapeUtil::IndexIsValid(*shape, index_) && + subshape_.Match(ShapeUtil::GetMutableSubshape(shape, index_), + option); } private: - Previous previous_; ShapeIndexView index_; ShapePattern<SubshapeType, SubshapeImpl> subshape_; }; @@ -431,14 +540,22 @@ class ShapePatternSubshapeImpl { // A pattern that matches Shapes. 
template <typename ShapeType, typename Impl> class ShapePattern { + private: + template <typename NewImpl> + ShapePattern<ShapeType, AllOfPattern<::xla::Shape, Impl, NewImpl>> AppendImpl( + NewImpl new_impl) const { + return ShapePattern<ShapeType, AllOfPattern<::xla::Shape, Impl, NewImpl>>( + AllOf<Shape>(impl_, std::move(new_impl)), matched_shape_); + } + public: explicit constexpr ShapePattern(const Impl& impl, ShapeType** matched_shape) : impl_(impl), matched_shape_(matched_shape) {} // Returns true and captures the shape iff it matches the pattern. - bool Match(const ::xla::Shape* shape) const { - if (impl_.Match(shape)) { - if (matched_shape_) { + bool Match(const ::xla::Shape* shape, MatchOption option) const { + if (impl_.Match(shape, option)) { + if (option.capture && matched_shape_) { *matched_shape_ = shape; } return true; @@ -447,9 +564,9 @@ class ShapePattern { } // Returns true and captures the shape iff it matches the pattern. - bool Match(::xla::Shape* shape) const { - if (impl_.Match(shape)) { - if (matched_shape_) { + bool Match(::xla::Shape* shape, MatchOption option) const { + if (impl_.Match(shape, option)) { + if (option.capture && matched_shape_) { *matched_shape_ = shape; } return true; @@ -459,108 +576,90 @@ class ShapePattern { // Modifies the pattern to match only if the shape equals the given proto. // The layout must outlive the returned pattern. - constexpr ShapePattern<ShapeType, ShapePatternEqualImpl<Impl>> EqualTo( - const ::xla::Shape* shape) const { - return ShapePattern<ShapeType, ShapePatternEqualImpl<Impl>>( - ShapePatternEqualImpl<Impl>(impl_, shape), matched_shape_); + constexpr auto EqualTo(const ::xla::Shape* shape) const + -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) { + return AppendImpl(ShapePatternEqualImpl(shape)); } // Modifies the pattern to match only if the shape is compatible to the given // proto. The layout must outlive the returned pattern. - constexpr ShapePattern<ShapeType, ShapePatternCompatibleImpl<Impl>> - CompatibleTo(const ::xla::Shape* shape) const { - return ShapePattern<ShapeType, ShapePatternCompatibleImpl<Impl>>( - ShapePatternCompatibleImpl<Impl>(impl_, shape), matched_shape_); + constexpr auto CompatibleTo(const ::xla::Shape* shape) const + -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) { + return AppendImpl(ShapePatternCompatibleImpl(shape)); } // Modifies the pattern to match only if the shape has the given element type. - constexpr ShapePattern<ShapeType, ShapePatternElementTypeImpl<Impl>> - WithElementType(PrimitiveType element_type) const { - return ShapePattern<ShapeType, ShapePatternElementTypeImpl<Impl>>( - ShapePatternElementTypeImpl<Impl>(impl_, element_type), matched_shape_); + constexpr auto WithElementType(PrimitiveType element_type) const + -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) { + return AppendImpl(ShapePatternElementTypeImpl(element_type)); } // Modifies the pattern to match only if the shape is scalar. - constexpr ShapePattern<ShapeType, ShapePatternIsScalarImpl<Impl>> IsScalar() - const { - return ShapePattern<ShapeType, ShapePatternIsScalarImpl<Impl>>( - ShapePatternIsScalarImpl<Impl>(impl_), matched_shape_); + constexpr auto IsScalar() const + -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) { + return AppendImpl(ShapePatternIsScalarImpl()); } // Modifies the pattern to match only if the shape is an array. 
- constexpr ShapePattern<ShapeType, ShapePatternIsArrayImpl<Impl>> IsArray() - const { - return ShapePattern<ShapeType, ShapePatternIsArrayImpl<Impl>>( - ShapePatternIsArrayImpl<Impl>(impl_), matched_shape_); + constexpr auto IsArray() const + -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) { + return AppendImpl(ShapePatternIsArrayImpl()); } // Modifies the pattern to match only if the shape is a tuple. - constexpr ShapePattern<ShapeType, ShapePatternIsTupleImpl<Impl>> IsTuple() - const { - return ShapePattern<ShapeType, ShapePatternIsTupleImpl<Impl>>( - ShapePatternIsTupleImpl<Impl>(impl_), matched_shape_); + constexpr auto IsTuple() const + -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) { + return AppendImpl(ShapePatternIsTupleImpl()); } // Modifies the pattern to match only if the shape has the given rank. - constexpr ShapePattern<ShapeType, ShapePatternRankImpl<Impl>> WithRank( - int64 rank) const { - return ShapePattern<ShapeType, ShapePatternRankImpl<Impl>>( - ShapePatternRankImpl<Impl>(impl_, rank), matched_shape_); + constexpr auto WithRank(int64 rank) const + -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) { + return AppendImpl(ShapePatternRankImpl(rank)); } // Modifies the pattern to match only if the shape has a layout that matches // the given pattern. template <typename LayoutType, typename LayoutImpl> - constexpr ShapePattern<ShapeType, - ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>> - WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const { - return ShapePattern<ShapeType, - ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>>( - ShapePatternLayoutImpl<Impl, LayoutType, LayoutImpl>(impl_, layout), - matched_shape_); - } - - constexpr ShapePattern< - ShapeType, - ShapePatternLayoutImpl<Impl, const ::xla::Layout, - LayoutPatternEqualImpl<LayoutPatternBaseImpl>>> - WithLayoutEqualTo(const ::xla::Layout* layout) const { + auto WithLayout(const LayoutPattern<LayoutType, LayoutImpl>& layout) const + -> decltype(this->AppendImpl( + ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout))) { + return AppendImpl(ShapePatternLayoutImpl<LayoutType, LayoutImpl>(layout)); + } + + constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const + -> decltype(this->WithLayout(Layout().EqualTo(layout))) { return WithLayout(Layout().EqualTo(layout)); } - constexpr ShapePattern< - ShapeType, - ShapePatternLayoutImpl<Impl, const ::xla::Layout, - LayoutPatternFormatImpl<LayoutPatternBaseImpl>>> - IsDenseArray() const { + constexpr auto IsDenseArray() const + -> decltype(this->WithLayout(Layout().WithDenseFormat())) { return WithLayout(Layout().WithDenseFormat()); } - constexpr ShapePattern< - ShapeType, - ShapePatternLayoutImpl<Impl, const ::xla::Layout, - LayoutPatternFormatImpl<LayoutPatternBaseImpl>>> - IsSparseArray() const { + constexpr auto IsSparseArray() const + -> decltype(this->WithLayout(Layout().WithSparseFormat())) { return WithLayout(Layout().WithSparseFormat()); } // Modifies the pattern to match only if the shape has a subshape that matches // the given pattern. 
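For reference, a sketch (not from the patch) of how the decltype-based builders chain after this rewrite; each call appends one small impl to an AllOfPattern instead of nesting a Previous template parameter. `some_shape` is a hypothetical ::xla::Shape, and the snippet assumes the match::Shape() creators and the generic Match() helper declared elsewhere in this file:

  const ::xla::Shape* captured = nullptr;
  auto pattern = match::Shape(&captured)
                     .IsArray()
                     .WithElementType(F32)
                     .WithRank(2)
                     .IsDenseArray();
  // Conceptually AllOf<Shape>(base, is-array, element-type, rank, dense
  // layout), evaluated in order with the same MatchOption threaded through
  // every stage.
  bool ok = Match(&some_shape, pattern);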
template <typename SubshapeType, typename SubshapeImpl> + auto WithSubshape(ShapeIndexView index, + const ShapePattern<SubshapeType, SubshapeImpl>& subshape) + const -> decltype(this->AppendImpl( + ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, + subshape))) { + return AppendImpl( + ShapePatternSubshapeImpl<SubshapeType, SubshapeImpl>(index, subshape)); + } + ShapePattern<ShapeType, - ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>> - WithSubshape(ShapeIndexView index, - const ShapePattern<SubshapeType, SubshapeImpl>& subshape) const { - return ShapePattern< - ShapeType, ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>>( - ShapePatternSubshapeImpl<Impl, SubshapeType, SubshapeImpl>(impl_, index, - subshape), - matched_shape_); - } - - ShapePattern<ShapeType, ShapePatternSubshapeImpl< - Impl, const ::xla::Shape, - ShapePatternEqualImpl<ShapePatternBaseImpl>>> + AllOfPattern<Shape, Impl, + ShapePatternSubshapeImpl< + const ::xla::Shape, + AllOfPattern<::xla::Shape, ShapePatternBaseImpl, + ShapePatternEqualImpl>>>> WithSubshapeEqualTo(ShapeIndexView index, const ::xla::Shape* shape) const { return WithSubshape(index, ShapePattern<const ::xla::Shape, ShapePatternBaseImpl>( @@ -568,9 +667,12 @@ class ShapePattern { .EqualTo(shape)); } - ShapePattern<ShapeType, ShapePatternSubshapeImpl< - Impl, const ::xla::Shape, - ShapePatternCompatibleImpl<ShapePatternBaseImpl>>> + ShapePattern<ShapeType, + AllOfPattern<Shape, Impl, + ShapePatternSubshapeImpl< + const ::xla::Shape, + AllOfPattern<::xla::Shape, ShapePatternBaseImpl, + ShapePatternCompatibleImpl>>>> WithSubshapeCompatibleTo(ShapeIndexView index, const ::xla::Shape* shape) const { return WithSubshape(index, @@ -611,159 +713,169 @@ class HloInstructionPattern; // instruction is not nullptr. class HloInstructionPatternBaseImpl { public: - bool Match(const ::xla::HloInstruction* inst) const { + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { return inst != nullptr; } }; // An HloInstructionPattern implementation that matches only if the instruction // has a given name. -template <typename Previous> class HloInstructionPatternNameImpl { public: - explicit HloInstructionPatternNameImpl(const Previous& previous, - absl::string_view name) - : previous_(previous), name_(name) {} + explicit HloInstructionPatternNameImpl(absl::string_view name) + : name_(name) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && inst->name() == name_; + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return inst->name() == name_; } private: - Previous previous_; absl::string_view name_; }; // An HloInstructionPattern implementation that matches only if the instruction // has a given opcode. 
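A small sketch of the subshape constraint as rewritten above, assuming `tuple_shape` is some ::xla::Shape and that ShapeIndexView accepts a braced index list the way the existing tests use it:

  bool ok = Match(&tuple_shape,
                  match::Shape()
                      .IsTuple()
                      .WithSubshape(/*index=*/{0}, match::Shape().IsScalar()));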
-template <typename Previous> class HloInstructionPatternOpcodeImpl { public: - explicit constexpr HloInstructionPatternOpcodeImpl(const Previous& previous, - HloOpcode opcode, + explicit constexpr HloInstructionPatternOpcodeImpl(HloOpcode opcode, bool invert) - : previous_(previous), opcode_(opcode), invert_(invert) {} + : opcode_(opcode), invert_(invert) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && (invert_ ^ (inst->opcode() == opcode_)); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return (invert_ ^ (inst->opcode() == opcode_)); } private: - Previous previous_; HloOpcode opcode_; bool invert_; }; // An HloInstructionPattern implementation that matches only if the instruction // has a shape that matches a given pattern. -template <typename Previous, typename ShapeType, typename ShapeImpl> +template <typename ShapeType, typename ShapeImpl> class HloInstructionPatternShapeImpl { public: explicit constexpr HloInstructionPatternShapeImpl( - const Previous& previous, const ShapePattern<ShapeType, ShapeImpl>& shape) - : previous_(previous), shape_(shape) {} + const ShapePattern<ShapeType, ShapeImpl>& shape) + : shape_(shape) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && shape_.Match(&inst->shape()); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return shape_.Match(&inst->shape(), option); } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && shape_.Match(inst->mutable_shape()); + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return shape_.Match(inst->mutable_shape(), option); } private: - Previous previous_; ShapePattern<ShapeType, ShapeImpl> shape_; }; // An HloInstructionPattern implementation that matches only if the instruction // has an operand that matches a given pattern. -template <typename Previous, typename OperandType, typename OperandImpl> +template <typename OperandType, typename OperandImpl> class HloInstructionPatternOperandImpl { public: explicit constexpr HloInstructionPatternOperandImpl( - const Previous& previous, int64 operand_index, + int64 operand_index, const HloInstructionPattern<OperandType, OperandImpl>& operand) - : previous_(previous), operand_index_(operand_index), operand_(operand) {} + : operand_index_(operand_index), operand_(operand) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && operand_index_ < inst->operand_count() && - operand_.Match(inst->operand(operand_index_)); + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return operand_index_ < inst->operand_count() && + operand_.Match(inst->operand(operand_index_), option); } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && operand_index_ < inst->operand_count() && - operand_.Match(inst->mutable_operand(operand_index_)); + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return operand_index_ < inst->operand_count() && + operand_.Match(inst->mutable_operand(operand_index_), option); } private: - Previous previous_; int64 operand_index_; HloInstructionPattern<OperandType, OperandImpl> operand_; }; // An HloInstructionPattern implementation that matches only if the instruction // is a fusion node with a particular kind. 
-template <typename Previous> class HloInstructionPatternFusionKindImpl { public: explicit constexpr HloInstructionPatternFusionKindImpl( - const Previous& previous, ::xla::HloInstruction::FusionKind kind) - : previous_(previous), kind_(kind) {} + ::xla::HloInstruction::FusionKind kind) + : kind_(kind) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && - inst->fusion_kind() == kind_; + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && inst->opcode() == HloOpcode::kFusion && - inst->fusion_kind() == kind_; + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kFusion && inst->fusion_kind() == kind_; } private: - Previous previous_; ::xla::HloInstruction::FusionKind kind_; }; // An HloInstructionPattern implementation that matches only if the instruction // is a kGetTupleElement with a particular tuple index. -template <typename Previous> class HloInstructionPatternTupleIndexImpl { public: - explicit constexpr HloInstructionPatternTupleIndexImpl( - const Previous& previous, int64 tuple_index) - : previous_(previous), tuple_index_(tuple_index) {} + explicit constexpr HloInstructionPatternTupleIndexImpl(int64 tuple_index) + : tuple_index_(tuple_index) {} - bool Match(const ::xla::HloInstruction* inst) const { - return previous_.Match(inst) && - inst->opcode() == HloOpcode::kGetTupleElement && + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kGetTupleElement && inst->tuple_index() == tuple_index_; } - bool Match(::xla::HloInstruction* inst) const { - return previous_.Match(inst) && - inst->opcode() == HloOpcode::kGetTupleElement && + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + return inst->opcode() == HloOpcode::kGetTupleElement && inst->tuple_index() == tuple_index_; } private: - Previous previous_; int64 tuple_index_; }; +template <typename ItemType, typename Predicate> +class HloPredicatePatternImpl { + public: + explicit HloPredicatePatternImpl(Predicate pred) : pred_(std::move(pred)) {} + + bool Match(const ItemType* item, MatchOption option) const { + return pred_(item); + } + + bool Match(ItemType* item, MatchOption option) const { return pred_(item); } + + private: + Predicate pred_; +}; + +struct PatternFriend; + // A pattern that matches HloInstructions. template <typename HloInstructionType, typename Impl> class HloInstructionPattern { + private: + template <typename NewImpl> + HloInstructionPattern<HloInstructionType, + AllOfPattern<::xla::HloInstruction, Impl, NewImpl>> + AppendImpl(NewImpl new_impl) const { + return HloInstructionPattern< + HloInstructionType, AllOfPattern<::xla::HloInstruction, Impl, NewImpl>>( + AllOf<HloInstruction>(impl_, std::move(new_impl)), matched_inst_); + } + public: explicit constexpr HloInstructionPattern(const Impl& impl, HloInstructionType** matched_inst) : impl_(impl), matched_inst_(matched_inst) {} // Returns true and captures the instruction iff it matches the pattern. 
- bool Match(const ::xla::HloInstruction* inst) const { - if (impl_.Match(inst)) { - if (matched_inst_) { + bool Match(const ::xla::HloInstruction* inst, MatchOption option) const { + if (impl_.Match(inst, option)) { + if (option.capture && matched_inst_) { *matched_inst_ = inst; } return true; @@ -772,9 +884,9 @@ class HloInstructionPattern { } // Returns true and captures the instruction iff it matches the pattern. - bool Match(::xla::HloInstruction* inst) const { - if (impl_.Match(inst)) { - if (matched_inst_) { + bool Match(::xla::HloInstruction* inst, MatchOption option) const { + if (impl_.Match(inst, option)) { + if (option.capture && matched_inst_) { *matched_inst_ = inst; } return true; @@ -783,102 +895,87 @@ class HloInstructionPattern { } // Modifies the pattern to match only if the instruction has the given name. - HloInstructionPattern<HloInstructionType, HloInstructionPatternNameImpl<Impl>> - WithName(absl::string_view name) const { - return HloInstructionPattern<HloInstructionType, - HloInstructionPatternNameImpl<Impl>>( - HloInstructionPatternNameImpl<Impl>(impl_, name), matched_inst_); + auto WithName(absl::string_view name) const + -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) { + return AppendImpl(HloInstructionPatternNameImpl(name)); } // Modifies the pattern to match only if the instruction has the given opcode. - constexpr HloInstructionPattern<HloInstructionType, - HloInstructionPatternOpcodeImpl<Impl>> - WithOpcode(HloOpcode opcode) const { - return HloInstructionPattern<HloInstructionType, - HloInstructionPatternOpcodeImpl<Impl>>( - HloInstructionPatternOpcodeImpl<Impl>(impl_, opcode, false), - matched_inst_); + auto WithOpcode(HloOpcode opcode) const + -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, + false))) { + return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false)); } // Modifies the pattern to match only if the instruction does not have the // given opcode. - constexpr HloInstructionPattern<HloInstructionType, - HloInstructionPatternOpcodeImpl<Impl>> - WithoutOpcode(HloOpcode opcode) const { - return HloInstructionPattern<HloInstructionType, - HloInstructionPatternOpcodeImpl<Impl>>( - HloInstructionPatternOpcodeImpl<Impl>(impl_, opcode, true), - matched_inst_); + auto WithoutOpcode(HloOpcode opcode) const + -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, + true))) { + return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true)); } // Modifies the pattern to match only if the instruction is a constant. - constexpr HloInstructionPattern<HloInstructionType, - HloInstructionPatternOpcodeImpl<Impl>> - IsConstant() const { + constexpr auto IsConstant() const + -> decltype(this->WithOpcode(HloOpcode::kConstant)) { return WithOpcode(HloOpcode::kConstant); } // Modifies the pattern to match only if the instruction is not a constant. - constexpr HloInstructionPattern<HloInstructionType, - HloInstructionPatternOpcodeImpl<Impl>> - IsNonConstant() const { + constexpr auto IsNonConstant() const + -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) { return WithoutOpcode(HloOpcode::kConstant); } // Modifies the pattern to match only if the instruction has a shape that // matches the given pattern. 
template <typename ShapeType, typename ShapeImpl> - constexpr HloInstructionPattern< - HloInstructionType, - HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>> - WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape) const { - return HloInstructionPattern< - HloInstructionType, - HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>>( - HloInstructionPatternShapeImpl<Impl, ShapeType, ShapeImpl>(impl_, - shape), - matched_inst_); + constexpr auto WithShape(const ShapePattern<ShapeType, ShapeImpl>& shape) + const -> decltype(this->AppendImpl( + HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape))) { + return AppendImpl( + HloInstructionPatternShapeImpl<ShapeType, ShapeImpl>(shape)); } // Modifies the pattern to match only if the instruction has an operand that // matches the given pattern. template <typename OperandType, typename OperandImpl> - constexpr HloInstructionPattern< - HloInstructionType, - HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>> - WithOperand( + constexpr auto WithOperand( int64 operand_index, - const HloInstructionPattern<OperandType, OperandImpl>& operand) const { - return HloInstructionPattern< - HloInstructionType, - HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>>( - HloInstructionPatternOperandImpl<Impl, OperandType, OperandImpl>( - impl_, operand_index, operand), - matched_inst_); + const HloInstructionPattern<OperandType, OperandImpl>& operand) const + -> decltype(this->AppendImpl( + HloInstructionPatternOperandImpl<OperandType, OperandImpl>( + operand_index, operand))) { + return AppendImpl( + HloInstructionPatternOperandImpl<OperandType, OperandImpl>( + operand_index, operand)); } // Modifies the pattern to match only if the instruction is a fusion node with // the given kind. - constexpr HloInstructionPattern<HloInstructionType, - HloInstructionPatternFusionKindImpl<Impl>> - WithFusionKind(HloInstruction::FusionKind kind) const { - return HloInstructionPattern<HloInstructionType, - HloInstructionPatternFusionKindImpl<Impl>>( - HloInstructionPatternFusionKindImpl<Impl>(impl_, kind), matched_inst_); + constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const + -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) { + return AppendImpl(HloInstructionPatternFusionKindImpl(kind)); } // Modifies the pattern to match only if the instruction is a // get-tuple-element with the given tuple index. 
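A usage sketch combining the instruction-level builders above; `root` is a hypothetical HloInstruction* and the snippet is not part of the patch:

  const HloInstruction* lhs = nullptr;
  bool ok = Match(root, match::Op()
                            .WithOpcode(HloOpcode::kAdd)
                            .WithShape(match::Shape().IsDenseArray())
                            .WithOperand(0, match::Op(&lhs).IsNonConstant()));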
- constexpr HloInstructionPattern<HloInstructionType, - HloInstructionPatternTupleIndexImpl<Impl>> - WithTupleIndex(int64 tuple_index) const { - return HloInstructionPattern<HloInstructionType, - HloInstructionPatternTupleIndexImpl<Impl>>( - HloInstructionPatternTupleIndexImpl<Impl>(impl_, tuple_index), - matched_inst_); + constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype( + this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) { + return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index)); } private: + template <typename Predicate> + constexpr auto WithPredicate(Predicate pred) const -> decltype( + this->AppendImpl(HloPredicatePatternImpl<HloInstruction, Predicate>( + std::move(pred)))) { + return AppendImpl( + HloPredicatePatternImpl<HloInstruction, Predicate>(std::move(pred))); + } + + friend struct PatternFriend; + Impl impl_; HloInstructionType** matched_inst_; }; @@ -1005,31 +1102,50 @@ XLA_UNOP_PATTERN(Transpose) .WithOperand(0, std::forward<Lhs>(lhs)) \ .WithOperand(1, std::forward<Rhs>(rhs)); \ } -XLA_BINOP_PATTERN(Add) + +#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ + XLA_BINOP_PATTERN(NAME) \ + \ + template <typename Lhs, typename Rhs> \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ + ->decltype(AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs))) { \ + return AnyOf<HloInstruction>(NAME(lhs, rhs), NAME(rhs, lhs)); \ + } \ + \ + template <typename HloInstructionType, typename Lhs, typename Rhs> \ + inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ + Rhs&& rhs) \ + ->decltype(AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \ + NAME(matched_inst, rhs, lhs))) { \ + return AnyOf<HloInstructionType>(NAME(matched_inst, lhs, rhs), \ + NAME(matched_inst, rhs, lhs)); \ + } +XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) XLA_BINOP_PATTERN(Divide) XLA_BINOP_PATTERN(Complex) XLA_BINOP_PATTERN(Dot) -XLA_BINOP_PATTERN(Eq) +XLA_COMMUTATIVE_BINOP_PATTERN(Eq) XLA_BINOP_PATTERN(Gather) XLA_BINOP_PATTERN(Ge) XLA_BINOP_PATTERN(Gt) XLA_BINOP_PATTERN(Le) XLA_BINOP_PATTERN(Lt) -XLA_BINOP_PATTERN(Maximum) -XLA_BINOP_PATTERN(Minimum) -XLA_BINOP_PATTERN(Multiply) -XLA_BINOP_PATTERN(Ne) +XLA_COMMUTATIVE_BINOP_PATTERN(Maximum) +XLA_COMMUTATIVE_BINOP_PATTERN(Minimum) +XLA_COMMUTATIVE_BINOP_PATTERN(Multiply) +XLA_COMMUTATIVE_BINOP_PATTERN(Ne) XLA_BINOP_PATTERN(Outfeed) XLA_BINOP_PATTERN(Power) XLA_BINOP_PATTERN(Remainder) XLA_BINOP_PATTERN(Send) XLA_BINOP_PATTERN(Subtract) -XLA_BINOP_PATTERN(And) -XLA_BINOP_PATTERN(Or) +XLA_COMMUTATIVE_BINOP_PATTERN(And) +XLA_COMMUTATIVE_BINOP_PATTERN(Or) XLA_BINOP_PATTERN(ShiftLeft) XLA_BINOP_PATTERN(ShiftRightArithmetic) XLA_BINOP_PATTERN(ShiftRightLogical) +#undef XLA_COMMUTATIVE_BINOP_PATTERN #undef XLA_BINOP_PATTERN // Helpers for ternary instructions. 
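To illustrate what the commutative macro buys, a short sketch mirroring the MultiplyAnyOrder test added further down: for a commutative op, FooAnyOrder(lhs, rhs) expands to AnyOf<HloInstruction>(Foo(lhs, rhs), Foo(rhs, lhs)), so operand order in the HLO no longer matters. ConstantScalar is the helper introduced later in this file, and `root` is a stand-in instruction:

  const HloInstruction* mul = nullptr;
  bool ok = Match(root, match::MultiplyAnyOrder(
                            &mul, match::ConstantScalar(42), match::Op()));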
@@ -1070,6 +1186,30 @@ XLA_TERNOP_PATTERN(Clamp); XLA_TERNOP_PATTERN(Select); #undef XLA_TERNOP_PATTERN +namespace detail { +struct PatternFriend { + template <typename T> + static auto ConstantScalar(T constant) -> decltype( + Constant() + .WithShape(match::Shape().IsScalar()) + .WithPredicate( + std::declval<std::function<bool(const HloInstruction*)>>())) { + std::function<bool(const HloInstruction*)> pred = + [constant](const HloInstruction* instr) { + const auto& literal = Cast<HloConstantInstruction>(instr)->literal(); + auto status_or_const = LiteralUtil::CreateR0(constant).Convert( + literal.shape().element_type()); + return status_or_const.ok() && + literal == status_or_const.ConsumeValueOrDie(); + }; + + return Constant() + .WithShape(match::Shape().IsScalar()) + .WithPredicate(std::move(pred)); + } +}; +} // namespace detail + // Helpers for matching non-constant instructions. inline auto NonConstant() -> decltype(Op().IsNonConstant()) { return Op().IsNonConstant(); @@ -1107,6 +1247,12 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, .WithTupleIndex(tuple_index); } +template <typename T> +inline auto ConstantScalar(T constant) + -> decltype(detail::PatternFriend::ConstantScalar(constant)) { + return detail::PatternFriend::ConstantScalar(constant); +} + } // namespace match } // namespace xla diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index a530581c34..3ab7b7fd71 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -211,5 +211,188 @@ TEST(PatternMatcherTest, GetTupleElement) { EXPECT_TRUE(Match(root, match::GetTupleElement(match::Op(), 1))); } +TEST(PatternMatcherTest, AnyOf) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_TRUE( + Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0), + match::ConstantScalar(1)))); + EXPECT_TRUE( + Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(1), + match::ConstantScalar(0)))); + EXPECT_FALSE( + Match(root, match::AnyOf<HloInstruction>(match::ConstantScalar(0), + match::ConstantScalar(2)))); +} + +TEST(PatternMatcherTest, ConstantScalar) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(42) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_TRUE(Match(root, match::ConstantScalar(42))); + EXPECT_FALSE(Match(root, match::ConstantScalar(41))); + EXPECT_FALSE(Match(root, match::ConstantScalar(0))); +} + +TEST(PatternMatcherTest, NoMatchConstantScalar) { + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT v = f16[] parameter(0) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + EXPECT_FALSE(Match(root, match::ConstantScalar(42))); +} + +TEST(PatternMatcherTest, MultiplyAnyOrder) { + using match::ConstantScalar; + using match::MultiplyAnyOrder; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + lhs = f16[] constant(42) + rhs = f16[] constant(52) + ROOT multiply = f16[] multiply(lhs, rhs) + })"; + 
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + const HloInstruction* instr; + + EXPECT_TRUE(Match( + root, MultiplyAnyOrder(&instr, ConstantScalar(42), ConstantScalar(52)))); + EXPECT_TRUE(Match( + root, MultiplyAnyOrder(&instr, ConstantScalar(52), ConstantScalar(42)))); +} + +TEST(PatternMatcherTest, AnyOfShortCircuit) { + using match::AnyOf; + using match::Multiply; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + lhs = f16[] constant(42) + rhs = f16[] constant(52) + ROOT multiply = f16[] multiply(lhs, rhs) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + { + const HloInstruction* mul = nullptr; + const HloInstruction* any = nullptr; + + ASSERT_TRUE(Match( + root, AnyOf<HloInstruction>(Multiply(&mul, Op(), Op()), Op(&any)))); + EXPECT_NE(nullptr, mul); + EXPECT_EQ(nullptr, any); + } + { + const HloInstruction* mul = nullptr; + const HloInstruction* any = nullptr; + + ASSERT_TRUE(Match( + root, AnyOf<HloInstruction>(Op(&any), Multiply(&mul, Op(), Op())))); + EXPECT_NE(nullptr, any); + EXPECT_EQ(nullptr, mul); + } +} + +TEST(PatternMatcherTest, AllOf) { + using match::AllOf; + using match::Broadcast; + using match::Constant; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module ENTRY test { ROOT constant = f16[] constant(1) })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + auto scalar_pattern = Constant().WithShape(match::Shape().IsScalar()); + auto f16_pattern = Constant().WithShape(match::Shape().WithElementType(F16)); + ASSERT_TRUE(Match(root, scalar_pattern)); + ASSERT_TRUE(Match(root, f16_pattern)); + EXPECT_TRUE(Match(root, AllOf<HloInstruction>(scalar_pattern, f16_pattern))); + EXPECT_TRUE(Match(root, AllOf<HloInstruction>(f16_pattern, scalar_pattern))); + EXPECT_FALSE( + Match(root, AllOf<HloInstruction>(Broadcast(Op()), f16_pattern))); + EXPECT_FALSE( + Match(root, AllOf<HloInstruction>(Broadcast(Op()), scalar_pattern))); +} + +TEST(PatternMatcherTest, AllOfNoCaptureIfNotMatch) { + using match::AllOf; + using match::Broadcast; + using match::Constant; + using match::Op; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + ROOT v = f16[] constant(42) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* constant = nullptr; + ASSERT_FALSE( + Match(root, AllOf<HloInstruction>(Constant(&constant), Broadcast(Op())))); + EXPECT_EQ(nullptr, constant); + ASSERT_TRUE(Match(root, Constant(&constant))); + EXPECT_NE(nullptr, constant); +} + +TEST(PatternMatcherTest, TestNoCapture) { + using match::Constant; + + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + ROOT v = f16[] constant(42) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* constant = nullptr; + ASSERT_TRUE(Match(root, Constant(&constant), {/*capture=*/false})); + EXPECT_EQ(nullptr, constant); +} + +TEST(PatternMatcherTest, TestCaptureMatchedSubPatternForAnyOf) { + using match::Add; + using match::AddAnyOrder; + using match::AnyOf; + using match::Op; + + constexpr char 
kModuleStr[] = R"( + HloModule test_module + ENTRY test { + u = f16[] parameter(0) + v = f16[] parameter(1) + ROOT add = f16[] add(u, v) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + + const HloInstruction* addend0 = nullptr; + const HloInstruction* addend1 = nullptr; + const HloInstruction* addend2 = nullptr; + auto add2_pattern = Add(Op(&addend0), Op(&addend1)); + auto add3_pattern = AnyOf<HloInstruction>( + AddAnyOrder(add2_pattern, Op(&addend2)), add2_pattern, Op(&addend0)); + + ASSERT_TRUE(Match(root, add3_pattern)); + EXPECT_NE(nullptr, addend0); + EXPECT_NE(nullptr, addend1); + EXPECT_EQ(nullptr, addend2); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 178a78ede0..c522e7ae23 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -217,9 +218,12 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { if (platform->id() == se::host::kHostPlatformId) { // On host "devices", StreamExecutor exports a device for each hardware // thread. Because we parallelize a single computation across threads, it - // doesn't make sense to expose these as separate devices, so fix the number - // of devices to one. - device_count = 1; + // doesn't make sense to expose these as separate devices, so by default we + // fix the number of devices to one. However we do let the user override + // this behavior to help run tests on the host that run models in parallel + // across multiple devices. + device_count = legacy_flags::GetDebugOptionsFromFlags() + .xla_force_host_platform_device_count(); } std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr); VLOG(1) << "Initializing devices"; diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h index 256b231e3a..0b4e82e8d6 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h @@ -22,14 +22,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { // HLO pass which inserts reduce-precision instructions into the HLO graph, for // purposes of experimenting with the effects of reduced-precision storage of // intermediate values. 
-class ReducePrecisionInsertion : public HloPassInterface { +class ReducePrecisionInsertion : public HloModulePass { using InstructionFilterFunction = std::function<bool(const HloInstruction*)>; public: diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h index 1e86a0823a..a3db439e34 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.h +++ b/tensorflow/compiler/xla/service/reshape_mover.h @@ -24,7 +24,7 @@ namespace xla { // This now only moves them outputward across elementwise ops all whose operands // are equivalent Reshapes or Transposes, but in future could potentially move // them inputward also. -class ReshapeMover : public HloPassInterface { +class ReshapeMover : public HloModulePass { public: absl::string_view name() const override { return "reshape-mover"; } diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index 2f4b2667c4..de7aee262e 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -155,6 +155,53 @@ static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace( return MakeConcatHlo(expanded_index_components, /*dimension=*/0); } +static StatusOr<HloInstruction*> CheckIndexValidity( + HloComputation* computation, HloInstruction* index, + absl::Span<const int64> operand_dims, absl::Span<const int64> window_sizes, + HloModule* module) { + DCHECK_NE(nullptr, module); + DCHECK_EQ(operand_dims.size(), window_sizes.size()); + + // Valid range for the index: [0, operand_dims - window_sizes] + + // Check if the index has any negative values. + TF_ASSIGN_OR_RETURN( + HloInstruction * zero_index, + BroadcastZeros(computation, index->shape().element_type(), + AsInt64Slice(index->shape().dimensions()))); + TF_ASSIGN_OR_RETURN(HloInstruction * negative_index_check, + MakeBinaryHlo(HloOpcode::kLe, zero_index, index)); + + // Check if the index is OOB w.r.t. the operand dimensions and window sizes. + std::vector<int64> max_valid_index(operand_dims.size()); + for (int i = 0; i < operand_dims.size(); ++i) { + max_valid_index[i] = operand_dims[i] - window_sizes[i]; + } + TF_ASSIGN_OR_RETURN( + HloInstruction * max_valid_index_constant, + MakeR1ConstantHlo<int64>(computation, index->shape().element_type(), + max_valid_index)); + TF_ASSIGN_OR_RETURN( + HloInstruction * oob_index_check, + MakeBinaryHlo(HloOpcode::kGe, max_valid_index_constant, index)); + + // Combine the results of the two checks above. + TF_ASSIGN_OR_RETURN( + HloInstruction * valid_index, + MakeBinaryHlo(HloOpcode::kAnd, negative_index_check, oob_index_check)); + + // Reduce the index validity check vector into a scalar predicate. + auto reduction_init = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true))); + TF_ASSIGN_OR_RETURN( + HloInstruction * valid_index_reduced, + MakeReduceHlo(valid_index, reduction_init, HloOpcode::kAnd, module)); + + // Return a broadcasted value of the scalar predicate to the same size as the + // window. + return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes); +} + // Body of the while loop that performs the scatter operation using other HLOs. 
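A worked example of the index-validity check added above, with made-up sizes: for an operand of shape [8, 10] and scatter window sizes [3, 4], the valid start indices are 0 <= i0 <= 8 - 3 = 5 and 0 <= i1 <= 10 - 4 = 6. CheckIndexValidity computes And(Le(0, index), Ge(operand_dims - window_sizes, index)) elementwise, AND-reduces that vector to a scalar predicate, and broadcasts the scalar back to the window shape so that ScatterLoopBody can Select between the updated slice and the original one.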
static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody( HloInstruction* scatter, HloInstruction* induction_var, @@ -222,7 +269,16 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody( InsertDegenerateDims(update_slice_for_scatter, AsInt64Slice(dim_numbers.inserted_window_dims()))); - // Extact the slice to update from `operand` tensor. + // Note that the following transformation assumes that both DynamicSlice and + // DynamicUpdateSlice follow the same semantics for OOB indices. For example, + // if there are negative indices and DynamicSlice uses "clamping" semantics, + // then the extracted data will be "shifted". Since DynamicUpdateSlice also + // follows the same "clamping" semantics, writing the update will also be + // "shifted" by exactly the same amount. So, this transformation is correct as + // long as the semantics of handling OOB indices remain the same in + // DynamicSlice and DynamicUpdateSlice. + + // Extract the slice to update from `operand` tensor. const Shape& update_slice_shape = update_slice_with_dims_inserted->shape(); TF_ASSIGN_OR_RETURN( HloInstruction * operand_slice_to_update, @@ -237,10 +293,24 @@ static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody( MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted}, scatter->to_apply())); + TF_ASSIGN_OR_RETURN( + HloInstruction * is_index_valid, + CheckIndexValidity( + operand->parent(), scatter_slice_start, + AsInt64Slice(operand->shape().dimensions()), + AsInt64Slice(update_slice_with_dims_inserted->shape().dimensions()), + scatter->GetModule())); + + // Select the updated operand only if the index is valid. If not, select the + // original value. + TF_ASSIGN_OR_RETURN(HloInstruction * update_to_apply, + MakeSelectHlo(is_index_valid, updated_operand_slice, + operand_slice_to_update)); + // Write the updated value of the slice into `operand` tensor. - TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand, - MakeDynamicUpdateSliceHlo(operand, updated_operand_slice, - scatter_slice_start)); + TF_ASSIGN_OR_RETURN( + HloInstruction * updated_operand, + MakeDynamicUpdateSliceHlo(operand, update_to_apply, scatter_slice_start)); return StatusOr<std::vector<HloInstruction*>>{ {updated_operand, scatter_indices, updates}}; diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h index 14f062c89c..559a85dccf 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.h +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -20,7 +20,7 @@ limitations under the License. namespace xla { -class ScatterExpander : public HloPassInterface { +class ScatterExpander : public HloModulePass { public: absl::string_view name() const override { return "scatter_expander"; } StatusOr<bool> Run(HloModule* module) override; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 74bdf2a2e3..e379911462 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -22,6 +22,7 @@ limitations under the License. #include <string> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -33,7 +34,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -577,7 +577,7 @@ Status ValidateDotDimensionNumbers( // Check that dimension numbers are unique. auto dims_unique = [](absl::Span<const int64> contracting_dims, absl::Span<const int64> batch_dims) -> bool { - tensorflow::gtl::FlatSet<int64> dim_set; + absl::flat_hash_set<int64> dim_set; auto is_unique = [&dim_set](int64 i) -> bool { return dim_set.insert(i).second; }; @@ -1665,10 +1665,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (input_features != kernel_input_features * feature_group_count) { return InvalidArgument( "Expected LHS feature dimension (value %d) to match RHS " - "input feature dimension * feature_group_count (value %d); " + "input feature dimension * feature_group_count (value %d * %d = %d); " "got <conv>(%s, %s)\n" "Dimension numbers: {%s}.", - input_features, kernel_input_features * feature_group_count, + input_features, kernel_input_features, feature_group_count, + kernel_input_features * feature_group_count, ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), dnums.DebugString()); } @@ -2379,7 +2380,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( - "Transpose dimensions not a permutation of the operand dimensions."); + "Transpose dimensions [%s] are not a permutation of the operand " + "dimensions (operand shape is %s).", + StrJoin(dimensions, ","), ShapeUtil::HumanString(operand)); } // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However, diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 921a984589..56952e3ada 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -18,6 +18,7 @@ limitations under the License. #include <string> #include <utility> +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -147,7 +147,7 @@ void ScopedShapedBuffer::Deallocate() { // Deallocate all non-null buffers. A buffer may appear in more than one spot // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. - tensorflow::gtl::FlatSet<void*> deallocated_ptrs; + absl::flat_hash_set<void*> deallocated_ptrs; for (auto& pair : buffers_) { se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc index 5d1cd1c442..ec09dff924 100644 --- a/tensorflow/compiler/xla/service/stream_pool.cc +++ b/tensorflow/compiler/xla/service/stream_pool.cc @@ -28,8 +28,14 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) { // Re-use an existing stream from the pool. 
stream = std::move(streams_.back()); streams_.pop_back(); - VLOG(1) << stream->DebugStreamPointers() - << " StreamPool reusing existing stream"; + if (stream->ok()) { + VLOG(1) << stream->DebugStreamPointers() + << " StreamPool reusing existing stream"; + } else { + VLOG(1) << stream->DebugStreamPointers() + << " stream was not ok, StreamPool deleting"; + stream = nullptr; + } } } diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc index aaf5c37b0d..92f47579d3 100644 --- a/tensorflow/compiler/xla/service/stream_pool_test.cc +++ b/tensorflow/compiler/xla/service/stream_pool_test.cc @@ -132,5 +132,39 @@ TEST_F(StreamPoolTest, BadStreamDiscarded) { EXPECT_EQ(stream2_ptr, stream3_ptr); } +TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) { + std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor(); + StreamPool pool; + + // Borrow a stream. + StreamPool::Ptr stream1 = pool.BorrowStream(executor.get()); + EXPECT_TRUE(stream1->ok()); + + // Return the stream, but hold a handle to it. + se::Stream* stream1_ptr = stream1.get(); + stream1 = nullptr; + + // Now stream1 is back in the pool, force an error on the stream. Here we call + // a method that requires DNN support, which we know the Host platform doesn't + // support. + stream1_ptr->ThenDepthConcatenate({}, {}, nullptr); + EXPECT_FALSE(stream1_ptr->ok()); + + // Borrow stream2. + StreamPool::Ptr stream2 = pool.BorrowStream(executor.get()); + EXPECT_TRUE(stream2->ok()); + + // The underlying streams should be different. They would have been + // the same, but since we forced an error on stream1, it cannot be + // put back into the pool. Sadly we can't just check: + // EXPECT_NE(stream1_ptr, stream2_ptr); + // + // The above should hold logically, but it may fail if the new + // stream instance allocated for stream2 happens to reside in the + // same memory address as stream1, which has been deleted. + // + // The check that stream2->ok() serves as a good-enough check. +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h index 3e5aa2db60..f95f982eb8 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.h +++ b/tensorflow/compiler/xla/service/transpose_folding.h @@ -23,7 +23,7 @@ namespace xla { // HLO pass that folds transpose operators into Dot operators, where the Dot // operator is implemented by a GEMM kernel that can transpose its inputs. -class TransposeFolding : public HloPassInterface { +class TransposeFolding : public HloModulePass { public: using OperandIndices = std::vector<int64>; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 6fed7c76d0..811ac55e2d 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -280,16 +280,6 @@ Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { return Status::OK(); } -Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { - // A kSlice instruction aliases its operand if the backend lowers it to an - // in-place implementation. 
- if (slice->IsInPlaceSlice()) { - CreateCopiedPointsToSet(slice, slice->operand(0)); - return Status::OK(); - } - return DefaultAction(slice); -} - Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { // RecvDone aliases its input (Recv) tuple element {0} to element {0} of its // output. The other indices ({} and {1}) define their own buffers. @@ -455,15 +445,10 @@ bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex( Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const { if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) { - // kSlice ops that are lowered to an in-place version are expected to not - // define their output buffer. - if (buffer.instruction()->opcode() != HloOpcode::kSlice || - !buffer.instruction()->IsInPlaceSlice()) { - return FailedPrecondition( - "LogicalBuffer %s is ill-defined: instruction %s does not define a " - "buffer at that index", - buffer.ToString(), buffer.instruction()->name()); - } + return FailedPrecondition( + "LogicalBuffer %s is ill-defined: instruction %s does not define a " + "buffer at that index", + buffer.ToString(), buffer.instruction()->name()); } if (buffer.id() < 0 || diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index a9e8a51e09..30c365053c 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -36,8 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/compactptrset.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -249,7 +247,6 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleDomain(HloInstruction* domain) override; - Status HandleSlice(HloInstruction* slice) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h index 8c91d6e69d..e126a53023 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -25,7 +25,7 @@ namespace xla { // A pass which simplifies patterns of Tuple and GetTupleElement instructions in // the module. -class TupleSimplifier : public HloPassInterface { +class TupleSimplifier : public HloModulePass { public: TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} explicit TupleSimplifier(bool exclude_entry_computation); diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index 56145822be..067cfcc17d 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h index 2dba7d7f75..577bad6c70 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.h +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.h @@ -50,7 +50,7 @@ namespace xla { // conditions as well. // // TODO(b/79121449): We should also sink broadcasts of constants. -class WhileLoopConstantSinking : public HloPassInterface { +class WhileLoopConstantSinking : public HloModulePass { public: ~WhileLoopConstantSinking() override = default; diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index e8fe33e626..9795b2830b 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -15,18 +15,18 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { +using absl::flat_hash_map; +using absl::flat_hash_set; using absl::InlinedVector; -using tensorflow::gtl::FlatMap; -using tensorflow::gtl::FlatSet; // Copies `to_hoist` to the computation containing `while_instr`, hoisting its // operands as needed. All of its transitive operands are expected to be either @@ -34,8 +34,8 @@ using tensorflow::gtl::FlatSet; // function hoists the operands in `unhoisted_invariant_instructions` and moves // them into `hoisted_instructions`. static void CreateLoopInvariantCopy( - FlatMap<HloInstruction*, HloInstruction*>* hoisted_instructions, - FlatSet<HloInstruction*>* unhoisted_invariant_instructions, + flat_hash_map<HloInstruction*, HloInstruction*>* hoisted_instructions, + flat_hash_set<HloInstruction*>* unhoisted_invariant_instructions, HloInstruction* while_instr, HloInstruction* to_hoist) { HloComputation* parent_of_while = while_instr->parent(); HloComputation* while_body = while_instr->while_body(); @@ -147,13 +147,13 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( // Maps instructions in the while body to instructions hoisted outside the // while that compute the same value. - FlatMap<HloInstruction*, HloInstruction*> hoisted_instructions; + flat_hash_map<HloInstruction*, HloInstruction*> hoisted_instructions; // Contains instructions that can be legally hoisted, but were deemed to be // unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we // hoist an instruction in this set, we move it from // unhoisted_invariant_instructions to hoisted_instructions. 
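The container migration in this pass is mechanical: tensorflow::gtl::FlatMap and FlatSet are swapped for their Abseil equivalents with the same usage pattern. A minimal sketch of the bookkeeping described in the comment above, assuming HloInstruction* keys; the helper name RecordHoist is made up for illustration and is not part of this change:

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"

// Sketch only: when an instruction is actually hoisted it moves from the
// "legal but not yet worth hoisting" set into the map of hoisted copies.
void RecordHoist(
    absl::flat_hash_map<HloInstruction*, HloInstruction*>* hoisted_instructions,
    absl::flat_hash_set<HloInstruction*>* unhoisted_invariant_instructions,
    HloInstruction* instr, HloInstruction* hoisted_copy) {
  unhoisted_invariant_instructions->erase(instr);
  (*hoisted_instructions)[instr] = hoisted_copy;
}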
- FlatSet<HloInstruction*> unhoisted_invariant_instructions; + flat_hash_set<HloInstruction*> unhoisted_invariant_instructions; // Invariant GTE's axiomatically satisfy the constraints for // unhoisted_invariant_instructions -- they can be legally hoisted, but there diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 2cdf20ce80..3031899f71 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -25,7 +25,7 @@ namespace xla { // HLO pass that rewrites while loops to hoist loop invariant instructions in // the while body into the computation that contains the while instruction. -class WhileLoopInvariantCodeMotion : public HloPassInterface { +class WhileLoopInvariantCodeMotion : public HloModulePass { public: // If `hoist_constants` is true then constants are always hoisted out of while // loop bodies. Otherwise they are only hoisted out if they enable other diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc index 6a7bfe3f12..630d71e5ca 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -14,12 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/while_loop_analysis.h" -#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -114,7 +115,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { return false; } - tensorflow::gtl::FlatSet<int64> used_tuple_indices; + absl::flat_hash_set<int64> used_tuple_indices; for (HloComputation* comp : {while_body, while_cond}) { // The HLO verifier ensures that while_input's shape matches while_init's // shape, which we verified above is a tuple. @@ -181,7 +182,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { used_tuple_indices.end()); std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end()); - tensorflow::gtl::FlatMap<int64, int64> old_to_new_tuple_idx; + absl::flat_hash_map<int64, int64> old_to_new_tuple_idx; for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { int64 old_idx = new_to_old_tuple_idx[new_idx]; old_to_new_tuple_idx[old_idx] = new_idx; @@ -252,7 +253,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { // Create the new while condition, body, and init value. 
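For context on the tuple-index bookkeeping earlier in this file: when dead while-loop tuple elements are removed, the surviving old indices are sorted and assigned compact new positions. A standalone sketch of that remapping using the new Abseil containers; the helper name is hypothetical and the usual <vector> and <algorithm> includes are assumed:

// Maps each surviving old tuple index to its new, compacted position,
// e.g. used indices {0, 2, 5} become {0 -> 0, 2 -> 1, 5 -> 2}.
absl::flat_hash_map<int64, int64> CompactTupleIndices(
    const absl::flat_hash_set<int64>& used_tuple_indices) {
  std::vector<int64> new_to_old(used_tuple_indices.begin(),
                                used_tuple_indices.end());
  std::sort(new_to_old.begin(), new_to_old.end());
  absl::flat_hash_map<int64, int64> old_to_new;
  for (int64 new_idx = 0; new_idx < static_cast<int64>(new_to_old.size());
       ++new_idx) {
    old_to_new[new_to_old[new_idx]] = new_idx;
  }
  return old_to_new;
}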
std::unique_ptr<HloComputation> new_while_cond = while_cond->CloneWithReplacements( - make_while_computation_replacements(while_cond)); + make_while_computation_replacements(while_cond), /*extras=*/{}); std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>> while_body_replacements = make_while_computation_replacements(while_body); @@ -265,7 +266,8 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { while_body_replacements.emplace( while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems)); std::unique_ptr<HloComputation> new_while_body = - while_body->CloneWithReplacements(std::move(while_body_replacements)); + while_body->CloneWithReplacements(std::move(while_body_replacements), + /*extras=*/{}); // Add a new while_init instruction that repackages the old while_init // instruction's elements. We rely on the AlgebraicSimplifier and DCE to @@ -404,7 +406,7 @@ static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) { // build a map from the tuple element index to the constant value. Limit this // to scalar constant values because propagating array constants can regress // performance by forcing us to copy constants. - tensorflow::gtl::FlatMap<int, const HloInstruction*> index_to_constant; + absl::flat_hash_map<int, const HloInstruction*> index_to_constant; for (int i = 0; i < root_operands.size(); i++) { HloInstruction* instr = root_operands[i]; if (instr->opcode() == HloOpcode::kGetTupleElement && diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 78024f14dc..0bc5a0107b 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -30,7 +30,7 @@ namespace xla { // - Elements of a while loop's tuple that the loop doesn't use are removed // from the tuple. // -class WhileLoopSimplifier : public HloPassInterface { +class WhileLoopSimplifier : public HloModulePass { public: ~WhileLoopSimplifier() override {} absl::string_view name() const override { return "simplify-while-loops"; } diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index a7f0e207eb..87294120d5 100644 --- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -21,7 +21,7 @@ limitations under the License. // HLO pass that replaces zero sized Hlos with a zero sized constant literal. 
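This header is one of several in this change whose pass base class is renamed from HloPassInterface to HloModulePass; the Run() and name() surface the passes implement is unchanged. A sketch of what a module-level pass looks like after the rename; the class name and body are made up:

class ExampleCleanupPass : public HloModulePass {
 public:
  absl::string_view name() const override { return "example-cleanup"; }

  StatusOr<bool> Run(HloModule* module) override {
    bool changed = false;
    // Walk module->computations() and rewrite instructions as needed,
    // setting `changed` when the module is modified.
    return changed;
  }
};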
namespace xla { -class ZeroSizedHloElimination : public HloPassInterface { +class ZeroSizedHloElimination : public HloModulePass { public: StatusOr<bool> Run(HloModule* module) override; absl::string_view name() const override { diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 96c80fd577..d244923532 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -422,8 +422,11 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); - CHECK_EQ(shape.dimensions_size(), Rank(shape)); + DCHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); + DCHECK_EQ(shape.dimensions_size(), Rank(shape)); + if (shape.dimensions().size() == 1) { + return shape.dimensions()[0]; + } return std::accumulate<decltype(shape.dimensions().begin()), int64>( shape.dimensions().begin(), shape.dimensions().end(), 1LL, std::multiplies<int64>()); @@ -828,7 +831,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) { /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { - if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { + if (shape.element_type() == PRIMITIVE_TYPE_INVALID || + !PrimitiveType_IsValid(shape.element_type())) { return InvalidArgument("shape has invalid element type: %s", shape.ShortDebugString()); } @@ -865,11 +869,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) { return Status::OK(); } - if (Rank(shape) != shape.dimensions_size()) { - return InvalidArgument( - "shape's rank is mismatched with dimension count; rank=%d " - "dimensions_size=%d", - Rank(shape), shape.dimensions_size()); + if (LayoutUtil::IsSparseArray(shape) && Rank(shape) == 0) { + return InvalidArgument("sparse arrays must have rank > 0"); } for (int64 i = 0; i < Rank(shape); ++i) { int64 dimension = shape.dimensions(i); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 623ae39de8..d8bb27beae 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -22,6 +22,7 @@ limitations under the License. #include <initializer_list> #include <string> +#include "absl/base/macros.h" #include "absl/container/inlined_vector.h" #include "absl/types/optional.h" #include "absl/types/span.h" @@ -479,8 +480,7 @@ class ShapeUtil { // Shorthand for testing whether a shape is of a given element type and // sequence of dimensions. - // - // DEPRECATED: Use Equal() instead. + ABSL_DEPRECATED("Use Equal() instead.") static bool ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list<int64> dimensions); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 30e3077edb..8a0ae33042 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -29,6 +29,10 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites" load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) # Generate test_suites for all backends, named "${backend}_tests". 
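Before the test BUILD changes below, a note on the ShapeUtil::ElementsIn() hunk above: the precondition checks become debug-only DCHECKs and rank-1 shapes take an early-return fast path. A free-standing behavioral sketch (hypothetical name, error checking omitted):

int64 ElementsInSketch(const Shape& shape) {
  if (shape.dimensions().size() == 1) {
    return shape.dimensions()[0];  // fast path for the common rank-1 case
  }
  int64 count = 1;
  for (int64 dim : shape.dimensions()) {
    count *= dim;
  }
  return count;
}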
generate_backend_suites() @@ -150,11 +154,31 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/core:lib", - "//tensorflow/core:test", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", ], ) +tf_cc_test( + name = "hlo_verified_test_base_test", + srcs = ["hlo_verified_test_base_test.cc"], + deps = [ + ":hlo_test_base", + ":hlo_verified_test_base", + ":test_macros_cpu", + ":test_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + tf_cc_binary( name = "local_client_aot_test_helper", srcs = ["local_client_aot_test_helper.cc"], @@ -398,6 +422,7 @@ xla_test( "//tensorflow/core:regexp_internal", "//tensorflow/core:test", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", ], ) @@ -1797,7 +1822,7 @@ xla_test( tf_cc_test( name = "llvm_compiler_test", srcs = ["llvm_compiler_test.cc"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:test_helpers", @@ -2096,7 +2121,7 @@ tf_cc_test( name = "sample_file_test", srcs = ["sample_file_test.cc"], data = ["isolated_convolution.hlo"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ ":hlo_test_base", "//tensorflow/compiler/xla:test", @@ -2121,11 +2146,11 @@ xla_test( ":test_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", "//tensorflow/core:test", + "@com_google_absl//absl/container:flat_hash_set", ], ) @@ -2144,3 +2169,21 @@ xla_test( "//tensorflow/core:lib", ], ) + +tf_cc_test( + name = "multiple_devices_on_host_test", + srcs = ["multiple_devices_on_host_test.cc"], + args = ["--xla_force_host_platform_device_count=4"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + "@com_google_absl//absl/synchronization", + ], +) diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 53f2c3bfbf..05d4d04034 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -3,256 +3,266 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) all_backends = ["cpu", "gpu"] + plugins.keys() def filter_backends(backends): - """Removes "gpu" from a backend list if CUDA is not enabled. 
- - This allows us to simply hardcode lists including "gpu" here and in the - BUILD file, without causing failures when CUDA isn't enabled.' - - Args: - backends: A list of backends to filter. - - Returns: - The filtered list of backends. - """ - if cuda_is_configured(): - return backends - else: - return [backend for backend in backends if backend != "gpu"] - - -def xla_test(name, - srcs, - deps, - xla_test_library_deps=[], - backends=[], - blacklisted_backends=[], - args=[], - tags=[], - copts=[], - data=[], - backend_tags={}, - backend_args={}, - **kwargs): - """Generates cc_test targets for the given XLA backends. - - This rule generates a cc_test target for one or more XLA backends and also a - platform-agnostic cc_library rule. The arguments are identical to cc_test with - two additions: 'backends' and 'backend_args'. 'backends' specifies the - backends to generate tests for ("cpu", "gpu"), and - 'backend_args'/'backend_tags' specifies backend-specific args parameters to - use when generating the cc_test. - - The name of the cc_tests are the provided name argument with the backend name - appended, and the cc_library target name is the provided name argument with - "_lib" appended. For example, if name parameter is "foo_test", then the cpu - test target will be "foo_test_cpu" and the cc_library target is "foo_lib". - - The cc_library target can be used to link with other plugins outside of - xla_test. - - The build rule also defines a test suite ${name} which includes the tests for - each of the supported backends. - - Each generated cc_test target has a tag indicating which backend the test is - for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These - tags can be used to gather tests for a particular backend into a test_suite. - - Examples: - - # Generates the targets: foo_test_cpu and foo_test_gpu. - xla_test( - name = "foo_test", - srcs = ["foo_test.cc"], - backends = ["cpu", "gpu"], - deps = [...], - ) + """Removes "gpu" from a backend list if CUDA is not enabled. - # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu - # includes the additional arg "--special_cpu_flag". - xla_test( - name = "bar_test", - srcs = ["bar_test.cc"], - backends = ["cpu", "gpu"], - backend_args = {"cpu": ["--special_cpu_flag"]} - deps = [...], - ) + This allows us to simply hardcode lists including "gpu" here and in the + BUILD file, without causing failures when CUDA isn't enabled.' - The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND} - to the value 1 where ${BACKEND} is the uppercase name of the backend. - - Args: - name: Name of the target. - srcs: Sources for the target. - deps: Dependencies of the target. - xla_test_library_deps: If set, the generated test targets will depend on the - respective cc_libraries generated by the xla_test_library rule. - backends: A list of backends to generate tests for. Supported values: "cpu", - "gpu". If this list is empty, the test will be generated for all supported - backends. - blacklisted_backends: A list of backends to NOT generate tests for. - args: Test arguments for the target. - tags: Tags for the target. - copts: Additional copts to pass to the build. - data: Additional data to pass to the build. - backend_tags: A dict mapping backend name to list of additional tags to - use for that target. - backend_args: A dict mapping backend name to list of additional args to - use for that target. - **kwargs: Additional keyword arguments to pass to native.cc_test. 
- """ - test_names = [] - if not backends: - backends = all_backends - - backends = [backend for backend in backends - if backend not in blacklisted_backends] - - native.cc_library( - name="%s_lib" % name, - srcs=srcs, - copts=copts, - testonly=True, - deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], - ) - - for backend in filter_backends(backends): - test_name = "%s_%s" % (name, backend) - this_backend_tags = ["xla_%s" % backend] - this_backend_copts = [] - this_backend_args = backend_args.get(backend, []) - this_backend_data = [] - if backend == "cpu": - backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] - backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] - elif backend == "gpu": - backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] - backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] - this_backend_tags += ["requires-gpu-sm35"] - elif backend in plugins: - backend_deps = [] - backend_deps += plugins[backend]["deps"] - this_backend_copts += plugins[backend]["copts"] - this_backend_tags += plugins[backend]["tags"] - this_backend_args += plugins[backend]["args"] - this_backend_data += plugins[backend]["data"] - else: - fail("Unknown backend %s" % backend) - - if xla_test_library_deps: - for lib_dep in xla_test_library_deps: - backend_deps += ["%s_%s" % (lib_dep, backend)] - - tf_cc_test( - name=test_name, - srcs=srcs, - tags=tags + backend_tags.get(backend, []) + this_backend_tags, - extra_copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + - this_backend_copts, - args=args + this_backend_args, - deps=deps + backend_deps, - data=data + this_backend_data, - **kwargs) - - test_names.append(test_name) - - native.test_suite(name=name, tests=test_names) - -def xla_test_library(name, - srcs, - hdrs=[], - deps=[], - backends=[]): - """Generates cc_library targets for the given XLA backends. - - This rule forces the sources to be compiled for each backend so that the - backend specific macros could expand correctly. It's useful when test targets - in different directories referring to the same sources but test with different - arguments. - - Examples: - - # Generates the targets: foo_test_library_cpu and foo_test_gpu. - xla_test_library( - name = "foo_test_library", - srcs = ["foo_test.cc"], - backends = ["cpu", "gpu"], - deps = [...], - ) - # Then use the xla_test rule to generate test targets: - xla_test( - name = "foo_test", - srcs = [], - backends = ["cpu", "gpu"], - deps = [...], - xla_test_library_deps = [":foo_test_library"], - ) + Args: + backends: A list of backends to filter. - Args: - name: Name of the target. - srcs: Sources for the target. - hdrs: Headers for the target. - deps: Dependencies of the target. - backends: A list of backends to generate libraries for. - Supported values: "cpu", "gpu". If this list is empty, the - library will be generated for all supported backends. - """ - - if not backends: - backends = all_backends - - for backend in filter_backends(backends): - this_backend_copts = [] - if backend in ["cpu", "gpu"]: - backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend] - elif backend in plugins: - backend_deps = plugins[backend]["deps"] - this_backend_copts += plugins[backend]["copts"] + Returns: + The filtered list of backends. 
+ """ + if cuda_is_configured(): + return backends else: - fail("Unknown backend %s" % backend) + return [backend for backend in backends if backend != "gpu"] + +def xla_test( + name, + srcs, + deps, + xla_test_library_deps = [], + backends = [], + blacklisted_backends = [], + args = [], + tags = [], + copts = [], + data = [], + backend_tags = {}, + backend_args = {}, + **kwargs): + """Generates cc_test targets for the given XLA backends. + + This rule generates a cc_test target for one or more XLA backends and also a + platform-agnostic cc_library rule. The arguments are identical to cc_test with + two additions: 'backends' and 'backend_args'. 'backends' specifies the + backends to generate tests for ("cpu", "gpu"), and + 'backend_args'/'backend_tags' specifies backend-specific args parameters to + use when generating the cc_test. + + The name of the cc_tests are the provided name argument with the backend name + appended, and the cc_library target name is the provided name argument with + "_lib" appended. For example, if name parameter is "foo_test", then the cpu + test target will be "foo_test_cpu" and the cc_library target is "foo_lib". + + The cc_library target can be used to link with other plugins outside of + xla_test. + + The build rule also defines a test suite ${name} which includes the tests for + each of the supported backends. + + Each generated cc_test target has a tag indicating which backend the test is + for. This tag is of the form "xla_${BACKEND}" (eg, "xla_cpu"). These + tags can be used to gather tests for a particular backend into a test_suite. + + Examples: + + # Generates the targets: foo_test_cpu and foo_test_gpu. + xla_test( + name = "foo_test", + srcs = ["foo_test.cc"], + backends = ["cpu", "gpu"], + deps = [...], + ) + + # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu + # includes the additional arg "--special_cpu_flag". + xla_test( + name = "bar_test", + srcs = ["bar_test.cc"], + backends = ["cpu", "gpu"], + backend_args = {"cpu": ["--special_cpu_flag"]} + deps = [...], + ) + + The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND} + to the value 1 where ${BACKEND} is the uppercase name of the backend. + + Args: + name: Name of the target. + srcs: Sources for the target. + deps: Dependencies of the target. + xla_test_library_deps: If set, the generated test targets will depend on the + respective cc_libraries generated by the xla_test_library rule. + backends: A list of backends to generate tests for. Supported values: "cpu", + "gpu". If this list is empty, the test will be generated for all supported + backends. + blacklisted_backends: A list of backends to NOT generate tests for. + args: Test arguments for the target. + tags: Tags for the target. + copts: Additional copts to pass to the build. + data: Additional data to pass to the build. + backend_tags: A dict mapping backend name to list of additional tags to + use for that target. + backend_args: A dict mapping backend name to list of additional args to + use for that target. + **kwargs: Additional keyword arguments to pass to native.cc_test. 
+ """ + test_names = [] + if not backends: + backends = all_backends + + backends = [ + backend + for backend in backends + if backend not in blacklisted_backends + ] native.cc_library( - name = "%s_%s" % (name, backend), + name = "%s_lib" % name, srcs = srcs, + copts = copts, testonly = True, - hdrs = hdrs, - copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] - + this_backend_copts, - deps = deps + backend_deps, + deps = deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], ) - -def generate_backend_suites(backends=[]): - if not backends: - backends = all_backends - for backend in filter_backends(backends): - native.test_suite(name="%s_tests" % backend, - tags = ["xla_%s" % backend]) - - -def generate_backend_test_macros(backends=[]): - if not backends: - backends = all_backends - for backend in filter_backends(backends): - manifest = "" - if backend in plugins: - manifest = plugins[backend]["disabled_manifest"] - - native.cc_library( - name="test_macros_%s" % backend, - testonly = True, - srcs = ["test_macros.cc"], - hdrs = ["test_macros.h"], - copts = [ - "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(), - "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest, - ], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - "//tensorflow/core:test", - ]) + for backend in filter_backends(backends): + test_name = "%s_%s" % (name, backend) + this_backend_tags = ["xla_%s" % backend] + this_backend_copts = [] + this_backend_args = backend_args.get(backend, []) + this_backend_data = [] + if backend == "cpu": + backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"] + backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"] + elif backend == "gpu": + backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"] + backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"] + this_backend_tags += tf_cuda_tests_tags() + elif backend in plugins: + backend_deps = [] + backend_deps += plugins[backend]["deps"] + this_backend_copts += plugins[backend]["copts"] + this_backend_tags += plugins[backend]["tags"] + this_backend_args += plugins[backend]["args"] + this_backend_data += plugins[backend]["data"] + else: + fail("Unknown backend %s" % backend) + + if xla_test_library_deps: + for lib_dep in xla_test_library_deps: + backend_deps += ["%s_%s" % (lib_dep, backend)] + + tf_cc_test( + name = test_name, + srcs = srcs, + tags = tags + backend_tags.get(backend, []) + this_backend_tags, + extra_copts = copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + + this_backend_copts, + args = args + this_backend_args, + deps = deps + backend_deps, + data = data + this_backend_data, + **kwargs + ) + + test_names.append(test_name) + + native.test_suite(name = name, tests = test_names) + +def xla_test_library( + name, + srcs, + hdrs = [], + deps = [], + backends = []): + """Generates cc_library targets for the given XLA backends. + + This rule forces the sources to be compiled for each backend so that the + backend specific macros could expand correctly. It's useful when test targets + in different directories referring to the same sources but test with different + arguments. + + Examples: + + # Generates the targets: foo_test_library_cpu and foo_test_gpu. 
+ xla_test_library( + name = "foo_test_library", + srcs = ["foo_test.cc"], + backends = ["cpu", "gpu"], + deps = [...], + ) + # Then use the xla_test rule to generate test targets: + xla_test( + name = "foo_test", + srcs = [], + backends = ["cpu", "gpu"], + deps = [...], + xla_test_library_deps = [":foo_test_library"], + ) + + Args: + name: Name of the target. + srcs: Sources for the target. + hdrs: Headers for the target. + deps: Dependencies of the target. + backends: A list of backends to generate libraries for. + Supported values: "cpu", "gpu". If this list is empty, the + library will be generated for all supported backends. + """ + + if not backends: + backends = all_backends + + for backend in filter_backends(backends): + this_backend_copts = [] + if backend in ["cpu", "gpu"]: + backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend] + elif backend in plugins: + backend_deps = plugins[backend]["deps"] + this_backend_copts += plugins[backend]["copts"] + else: + fail("Unknown backend %s" % backend) + + native.cc_library( + name = "%s_%s" % (name, backend), + srcs = srcs, + testonly = True, + hdrs = hdrs, + copts = ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] + + this_backend_copts, + deps = deps + backend_deps, + ) + +def generate_backend_suites(backends = []): + if not backends: + backends = all_backends + for backend in filter_backends(backends): + native.test_suite( + name = "%s_tests" % backend, + tags = ["xla_%s" % backend, "-broken", "manual"], + ) + +def generate_backend_test_macros(backends = []): + if not backends: + backends = all_backends + for backend in filter_backends(backends): + manifest = "" + if backend in plugins: + manifest = plugins[backend]["disabled_manifest"] + + native.cc_library( + name = "test_macros_%s" % backend, + testonly = True, + srcs = ["test_macros.cc"], + hdrs = ["test_macros.h"], + copts = [ + "-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper(), + "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest, + ], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + "//tensorflow/core:test", + ], + ) diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 070b092d18..b851db14ec 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -91,7 +91,14 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { XlaBuilder builder(TestName()); auto lhs = ConstantR4FromArray4D<T>(&builder, *alhs); auto rhs = ConstantR4FromArray4D<T>(&builder, *arhs); - Conv(lhs, rhs, {1, 1}, Padding::kValid); + PrecisionConfig precision; + // The left hand side of the convolution is numbers between 0 and 2304 which + // requires at least 11 mantissa bits and the DEFAULT precision config is + // allowed to round to bfloat16 which only has 7 mantissa bits. 
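A quick check of the arithmetic in the comment above (annotation only, not part of the change): the iota operand holds integer values up to roughly 2304, and

  2^11 = 2048 < 2304 <= 4096 = 2^12

so representing every such integer exactly takes 12 significand bits, that is, at least 11 stored mantissa bits plus the implicit leading one. bfloat16 stores only 7 mantissa bits, so the DEFAULT operand precision could silently round the left-hand side; the lines that follow request HIGHEST precision for that operand to rule this out.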
+ precision.add_operand_precision(PrecisionConfig::HIGHEST); + precision.add_operand_precision(PrecisionConfig::DEFAULT); + Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1, + &precision); ComputeAndCompare(&builder, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 0171f51583..6c0847a875 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -394,6 +394,10 @@ class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest { ParametricDotTestWithoutLayoutAssignment() { execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( "layout-assignment"); + // Disable algebraic simplification because the pass may replace a dot + // instruction with a layout-changing multiplication instruction. + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "algsimp"); } }; @@ -404,31 +408,18 @@ std::vector<DotTestParam> CreateNoLayoutAssignmentDotTestParameters() { for (bool lhs_row_major : {true, false}) { for (bool rhs_row_major : {true, false}) { for (bool has_addend : {true, false}) { + // The addend needs to be row major to match the result of the dot. params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, /*dot_lhs_row_major=*/lhs_row_major, /*dot_rhs_row_major=*/rhs_row_major, /*has_addend=*/has_addend, /*addend_row_major=*/true}); - if (has_addend) { - params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, - /*dot_lhs_row_major=*/lhs_row_major, - /*dot_rhs_row_major=*/rhs_row_major, - /*has_addend=*/has_addend, - /*addend_row_major=*/false}); - } if (n != 1) { params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, /*dot_lhs_row_major=*/lhs_row_major, /*dot_rhs_row_major=*/rhs_row_major, /*has_addend=*/has_addend, /*addend_row_major=*/true}); - if (has_addend) { - params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, - /*dot_lhs_row_major=*/lhs_row_major, - /*dot_rhs_row_major=*/rhs_row_major, - /*has_addend=*/has_addend, - /*addend_row_major=*/false}); - } } } } diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 9c94acb437..4d4b676a53 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -764,8 +764,10 @@ XLA_TEST_F(FusionTest, Clamp2D) { TestElementwise2D<float, 3>(HloOpcode::kClamp); } -// TODO(b/73903144): Enable on interpreter once interpreter supports bitcast. -XLA_TEST_F(FusionTest, DISABLED_ON_INTERPRETER(FusionWithLayout)) { +// TODO(b/117156505): Remove this test when the bug is fixed and the CPU backend +// should not generate layout changing elementwise operations. 
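The #ifdef guard added below relies on the per-backend define set up in build_defs.bzl earlier in this diff: each generated test target is compiled with -DXLA_TEST_BACKEND_${BACKEND}=1, so a test body can be restricted to a single backend at compile time. A sketch with made-up fixture and test names:

#ifdef XLA_TEST_BACKEND_CPU
XLA_TEST_F(SomeCpuOnlyTest, OnlyBuiltIntoTheCpuTarget) {
  // This body is compiled only into the *_cpu test target, since only that
  // target defines XLA_TEST_BACKEND_CPU.
}
#endif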
+#ifdef XLA_TEST_BACKEND_CPU +XLA_TEST_F(FusionTest, LayoutChangingElementWiseOp) { const string hlo_text = R"( HloModule Cluster @@ -794,6 +796,7 @@ ENTRY main { LiteralUtil::CreateR3<float>({{{0.}, {0.76159415595}}, {{0.}, {0.}}}), result)); } +#endif class FusionClientLibraryTest : public ClientLibraryTestBase {}; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index bdd4fd7e3d..7ab2ecda58 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -86,19 +86,25 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) { } // namespace HloTestBase::HloTestBase(bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func) : HloTestBase(GetTestPlatform(), GetReferencePlatform(), verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier) {} + allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func) {} HloTestBase::HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier) + bool allow_mixed_precision_in_hlo_verifier, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func) : test_runner_(test_platform), reference_runner_(reference_platform) { hlo_verifier_ = absl::make_unique<HloVerifier>( /*layout_sensitive=*/verifier_layout_sensitive, - /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier); + /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func); } std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 0ae4bdc104..217428befa 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -88,14 +88,18 @@ class HloTestBase : public ::testing::Test { // interpreter is the only supported backend, it will be both the test backend // and the reference backend. HloTestBase(bool verifier_layout_sensitive = false, - bool allow_mixed_precision_in_hlo_verifier = true); + bool allow_mixed_precision_in_hlo_verifier = true, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func = {}); // If your test doesn't use interpreter as the reference backend, you can use // this constructor. Note that your test target is responsible for linking in // both needed backends. HloTestBase(se::Platform* test_platform, se::Platform* reference_platform, bool verifier_layout_sensitive = false, - bool allow_mixed_precision_in_hlo_verifier = true); + bool allow_mixed_precision_in_hlo_verifier = true, + std::function<bool(const HloInstruction*)> + instruction_can_change_layout_func = {}); ~HloTestBase() override {} diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index 8f86c528d0..8bd0a729b7 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -21,64 +21,68 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/test.h" namespace xla { -HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, - bool allow_mixed_precision) - : HloTestBase( - /*verifier_layout_sensitive=*/layout_sensitive, - /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision) {} - -HloVerifiedTestBase::~HloVerifiedTestBase() { - // We can't call the ASSERT or EXPECT test macros in destructors, so we - // perform HLO verification in TearDown, and use the CHECK here to ensure - // users don't accidentally override the verification. - CHECK(tear_down_called_) - << "TearDown was never called; subclasses of HloVerifiedTestBase that " - << "override TearDown must call the superclass TearDown."; -} - -void HloVerifiedTestBase::TearDown() { - EXPECT_FALSE(tear_down_called_) - << "TearDown called more than once; it should be called exactly once."; - tear_down_called_ = true; - if (module_) { - VerifyModule(module_.get()); +Status VerifiedHloModule::Verify() { + if (computation_count() == 0) { + // The computation was never built. Nothing to verify. + return Status::OK(); } - for (int i = 0; i < modules_.size(); ++i) { - VerifyModule(modules_.at(i).get()); - } - HloTestBase::TearDown(); + return verifier_.Run(this).status(); } -void HloVerifiedTestBase::VerifyModule(HloModule* module) { - xla::StatusOr<bool> mutated = verifier().Run(module); - if (!mutated.ok()) { - ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); - } else { - EXPECT_FALSE(mutated.ValueOrDie()) - << "HloVerifier should never mutate the HloModule"; +void VerifiedHloModule::VerifyOrAddFailure(const string& message) { + Status status = Verify(); + if (!status.ok()) { + ADD_FAILURE() << "HloVerifier failed on module " << name() + << (message.empty() ? 
"" : absl::StrCat(" (", message, ")")) + << ": " << status; } } +HloVerifiedTestBase::HloVerifiedTestBase(bool layout_sensitive, + bool allow_mixed_precision) + : HloTestBase( + /*verifier_layout_sensitive=*/layout_sensitive, + /*allow_mixed_precision_in_hlo_verifier=*/allow_mixed_precision), + verifier_layout_sensitive_(layout_sensitive), + allow_mixed_precision_in_hlo_verifier_(allow_mixed_precision) {} + HloModule& HloVerifiedTestBase::module() { if (!module_) { - module_ = HloTestBase::CreateNewModule(); + module_ = CreateNewVerifiedModule(TestName()); } return *module_; } HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { - modules_.emplace_back(HloTestBase::CreateNewModule()); + modules_.emplace_back(CreateNewVerifiedModule(name)); return modules_.back().get(); } void HloVerifiedTestBase::ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config) { CHECK(!module_) << "Called ParseModule when test already has a module."; - TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text, config)); - VerifyModule(module_.get()); + module_ = CreateNewVerifiedModule(TestName()); + TF_CHECK_OK(ParseHloString(hlo_text, module_.get())); + module_->VerifyOrAddFailure("after parsing"); } + +StatusOr<std::unique_ptr<VerifiedHloModule>> +HloVerifiedTestBase::ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config) { + auto module = CreateNewVerifiedModule(TestName()); + TF_RETURN_IF_ERROR(ParseHloString(hlo_text, module.get())); + TF_RETURN_IF_ERROR(module->Verify()); + return std::move(module); +} + +std::unique_ptr<VerifiedHloModule> HloVerifiedTestBase::CreateNewVerifiedModule( + const string& name) { + return absl::make_unique<VerifiedHloModule>( + name, GetModuleConfigForTest(), verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index 8fbc4fa753..388a99bb36 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -20,53 +20,84 @@ limitations under the License. #include <memory> #include <utility> +#include "absl/base/macros.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" namespace xla { -// A base class for HLO tests that stores a default HloModule, and automatically -// performs verification on that module on tear-down. +// An HLO module derived class which verifies itself on destruction. This class +// is intended to be used in unit tests. Any verification errors are raised via +// ADD_FAILURE. +class VerifiedHloModule : public HloModule { + public: + VerifiedHloModule(const string& name, const HloModuleConfig& config, + bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier) + : HloModule(name, config), + verifier_(verifier_layout_sensitive, + allow_mixed_precision_in_hlo_verifier) {} + + ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } + + // Verifies the module using HloVerifier and returns the status. + Status Verify(); + + // Verifies the module and flags any error with ADD_FAILURE. 'message' is + // included in the failure message. + void VerifyOrAddFailure(const string& message); + + private: + HloVerifier verifier_; +}; + +// A base class for HLO tests that stores a default VerifiedHloModule. 
class HloVerifiedTestBase : public HloTestBase { protected: - explicit HloVerifiedTestBase(bool layout_sensitive = false, - bool allow_mixed_precision = false); - ~HloVerifiedTestBase() override; + HloVerifiedTestBase(bool layout_sensitive = false, + bool allow_mixed_precision = false); // Constructs a default shape verifier. std::unique_ptr<ShapeVerifier> MakeShapeVerifier(); - // Performs verification on the default HloModule returned by module(). - // Automatically called by the testing framework for each test. - // - // REQUIRED: subclasses that override TearDown() must call this explicitly. - void TearDown() override; - // Returns the default HloModule, lazily creating it if necessary via // HloTestBase::CreateNewModule(). + ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") HloModule& module(); + + ABSL_DEPRECATED("Use ParseAndReturnVerifiedModule() instead.") void ParseAndVerifyModule(absl::string_view hlo_text, const HloModuleConfig& config = HloModuleConfig()); + // Parses the given string and returns module as a VerifiedHloModule. + StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule( + absl::string_view hlo_text, + const HloModuleConfig& config = HloModuleConfig()); + // Creates a new module for a test, and stores it in modules_ so it can be // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent // creation of unverified modules. + ABSL_DEPRECATED("Use CreateNewVerifiedModule() instead.") HloModule* CreateNewModule(const string& name = TestName()); - private: - void VerifyModule(HloModule* module); + // Creates and returns a verified HLO module with the given name. + std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule( + const string& name = TestName()); + private: // It is confusing to store modules created by module() and CreateNewModule() // in different fields, but it allows us to migrate tests to // HloVerifiedTestBase more easily, so it's a win because we can verify more // modules. See b/80488902. // // Lazily populated. Access via module(). - std::unique_ptr<HloModule> module_; + std::unique_ptr<VerifiedHloModule> module_; + // Populated by calls to CreateNewModule. - std::vector<std::unique_ptr<HloModule>> modules_; + std::vector<std::unique_ptr<VerifiedHloModule>> modules_; - bool tear_down_called_ = false; + bool verifier_layout_sensitive_; + bool allow_mixed_precision_in_hlo_verifier_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc new file mode 100644 index 0000000000..5c0263e811 --- /dev/null +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base_test.cc @@ -0,0 +1,158 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +// This class includes unit tests which are expected to fail because invalid HLO +// modules are intentionally built. Unfortunately, Tensorflow doesn't appear to +// include the necessary gunit parts to test this test machinery (needs the +// macro EXPECT_NONFATAL_FAILURE). The disabled tests can be run with the +// disabled tests enabled and failures can be manually compared against +// expectations. +class HloVerifiedTestBaseTest : public HloVerifiedTestBase {}; + +XLA_TEST_F(HloVerifiedTestBaseTest, NoModule) { + // Test shouldn't fail if no module is created at all. +} + +XLA_TEST_F(HloVerifiedTestBaseTest, GoodLazilyCreatedModule) { + // Use module() to lazily create an empty module, build it up, and verify no + // failures. + HloModule& hlo_module = module(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + hlo_module.AddEntryComputation(builder.Build()); +} + +// This test is expected to fail. See test class comment. +XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadLazilyCreatedModule) { + // Use module() to lazily create an empty module and build up an invalid + // module. + HloModule& hlo_module = module(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + hlo_module.AddEntryComputation(builder.Build()); + + *hlo_module.entry_computation()->root_instruction()->mutable_shape() = + ShapeUtil::MakeShape(PRED, {1, 2, 3}); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, GoodCreateNewModule) { + // Call CreateNewModule and build up a valid module. + HloModule* module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + module->AddEntryComputation(builder.Build()); +} + +// This test is expected to fail. See test class comment. +XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_BadCreateNewModule) { + // Call CreateNewModule and build up a invalid module. 
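An aside on the class comment above about EXPECT_NONFATAL_FAILURE: if gtest's failure-spying support (gtest-spi.h) could be linked in here, the expected-to-fail cases would not need to stay disabled; inside a test body one could assert the verifier failure directly. A sketch under that assumption:

#include "gtest/gtest-spi.h"  // assumed available; not currently linked in

EXPECT_NONFATAL_FAILURE(
    {
      // Build the intentionally malformed module exactly as in the
      // disabled test continuing below.
    },
    "HloVerifier failed");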
+ HloModule* module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0))); + builder.AddInstruction( + HloInstruction::CreateUnary(input->shape(), HloOpcode::kNegate, input)); + module->AddEntryComputation(builder.Build()); + + *module->entry_computation()->root_instruction()->mutable_shape() = + ShapeUtil::MakeShape(PRED, {1, 2, 3}); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndVerifyModuleGood) { + const char* const hlo_string = R"( +HloModule ParseAndVerifyModuleGood + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x,y) +} +)"; + + ParseAndVerifyModule(hlo_string); + EXPECT_EQ(module().entry_computation()->instruction_count(), 3); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleGood) { + const char* const hlo_string = R"( +HloModule ParseAndReturnVerifiedModuleGood + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x,y) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_EQ(module->entry_computation()->instruction_count(), 3); +} + +XLA_TEST_F(HloVerifiedTestBaseTest, ParseAndReturnVerifiedModuleInvalidText) { + const char* const hlo_string = R"( +HloModule ParseAndReturnVerifiedModuleGood + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x,y) +} + +RANDOM GARBAGE +)"; + + ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); +} + +// This test is expected to fail. See test class comment. +XLA_TEST_F(HloVerifiedTestBaseTest, DISABLED_ParseAndReturnVerifiedModuleBad) { + const char* const hlo_string = R"( +HloModule ParseAndReturnVerifiedModuleBad + +ENTRY entry { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[1234] add(x,y) +} +)"; + + ASSERT_IS_NOT_OK(ParseAndReturnVerifiedModule(hlo_string).status()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc new file mode 100644 index 0000000000..c530591c6e --- /dev/null +++ b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc @@ -0,0 +1,120 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/synchronization/mutex.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { +StatusOr<XlaComputation> BuildComputation() { + XlaBuilder b("computation"); + Shape scalar_s32 = ShapeUtil::MakeShape(S32, {}); + XlaOp infeed = InfeedWithToken(CreateToken(&b), scalar_s32); + return b.Build( + OutfeedWithToken(GetTupleElement(infeed, 0) + + ConstantLiteral(&b, LiteralUtil::CreateR0<int32>(1)), + GetTupleElement(infeed, 1), scalar_s32, "")); +} + +void CompileAndExecute( + LocalExecutable* executable, int device_ordinal, LocalClient* client, + absl::Mutex* results_mutex, + std::vector<std::pair<int, StatusOr<ScopedShapedBuffer>>>* results) { + xla::ExecutableRunOptions execute_options; + execute_options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + execute_options.set_device_ordinal(device_ordinal); + execute_options.set_allocator( + xla::ClientLibrary::GetXlaService(client->platform()) + ->backend() + .memory_allocator()); + StatusOr<ScopedShapedBuffer> result = executable->Run({}, execute_options); + { + absl::MutexLock lock(results_mutex); + results->emplace_back(device_ordinal, std::move(result)); + } +} + +void TestWithDeviceCount(const int device_count) { + // Run `device_count` copies of the XLA program built by BuildComputation. + TF_ASSERT_OK_AND_ASSIGN( + se::Platform* const platform, + perftools::gputools::MultiPlatformManager::PlatformWithName("Host")); + xla::LocalClientOptions client_options; + client_options.set_platform(platform); + TF_ASSERT_OK_AND_ASSIGN( + LocalClient* const client, + xla::ClientLibrary::GetOrCreateLocalClient(client_options)); + + TF_ASSERT_OK_AND_ASSIGN(XlaComputation xla_computation, BuildComputation()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<LocalExecutable> executable, + client->Compile(xla_computation, {}, xla::ExecutableBuildOptions{})); + std::vector<tensorflow::Thread*> threads; + absl::Mutex results_mutex; + std::vector<std::pair<int, StatusOr<ScopedShapedBuffer>>> results; + tensorflow::Env* env = tensorflow::Env::Default(); + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + tensorflow::Thread* t = env->StartThread( + tensorflow::ThreadOptions{}, absl::StrCat("thread-", device_ordinal), + [&executable, device_ordinal, client, &results_mutex, &results] { + CompileAndExecute(executable.get(), device_ordinal, client, + &results_mutex, &results); + }); + threads.push_back(t); + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + TF_ASSERT_OK(client->TransferToInfeedLocal( + LiteralUtil::CreateR0<int32>(device_ordinal * 100), device_ordinal)); + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + TF_ASSERT_OK_AND_ASSIGN(Literal outfeed, + client->TransferFromOutfeedLocal( + ShapeUtil::MakeShape(S32, {}), device_ordinal)); + EXPECT_EQ(outfeed, LiteralUtil::CreateR0<int32>(device_ordinal * 100 + 1)); + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + delete threads[device_ordinal]; + } + + for (int device_ordinal = 0; device_ordinal < device_count; + device_ordinal++) { + 
TF_ASSERT_OK(results[device_ordinal].second.status()); + } +} + +// NB! This test requires --xla_force_host_platform_device_count=4 + +TEST(MultipleDeviceOnHostTest, OneDevice) { TestWithDeviceCount(1); } + +TEST(MultipleDeviceOnHostTest, TwoDevices) { TestWithDeviceCount(2); } + +TEST(MultipleDeviceOnHostTest, ThreeDevices) { TestWithDeviceCount(3); } + +TEST(MultipleDeviceOnHostTest, FourDevices) { TestWithDeviceCount(4); } +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 63491a90bf..22fe4a2670 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -638,6 +638,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, /*padding=*/padding); CHECK(reducer == kAdd || reducer == kMax); @@ -1158,7 +1160,10 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/padding); + /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, + /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } @@ -1303,11 +1308,19 @@ struct R1ReduceWindowTestData { /*pad_high=*/{0}, /*reducer=*/Reducer::kAdd}, + // The pattern generated by inclusive scan (cumsum/cumprod). {/*base_bounds=*/{4096}, /*window_bounds=*/{4096}, /*strides=*/{1}, /*pad_low=*/{4095}, /*pad_high=*/{0}, /*reducer=*/Reducer::kMax}, + + // The pattern generated by exclusive scan (cumsum/cumprod). + {/*base_bounds=*/{4096}, /*window_bounds=*/{4096}, + /*strides=*/{1}, + /*pad_low=*/{4096}, + /*pad_high=*/{0}, + /*reducer=*/Reducer::kMax}, }; string R1ReduceWindowTestDataToString( @@ -1361,7 +1374,10 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/padding); + /*window_strides=*/param.strides, + /*base_dilations=*/{}, + /*window_dilations=*/{}, + /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? 
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc
index d20dba028a..b21dd56045 100644
--- a/tensorflow/compiler/xla/tests/scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/scatter_test.cc
@@ -507,6 +507,36 @@ ENTRY main {
   RunTest(hlo_text, &operand, &scatter_indices, &updates);
 }
 
+XLA_TEST_F(ScatterTest, OutOfBoundsUpdateWindow) {
+  const char* hlo_text = R"(
+HloModule TensorFlowScatterNd_OobUpdateWindow
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+  lhs = s32[] parameter(0)
+  ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+  operand = s32[3,3,2] parameter(0)
+  indices = s32[1,2] parameter(1)
+  updates = s32[1,2,2] parameter(2)
+  ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
+      to_apply=update_s32,
+      update_window_dims={1,2},
+      inserted_window_dims={0},
+      scatter_dims_to_operand_dims={0,1},
+      index_vector_dim=1
+}
+)";
+  Literal operand =
+      LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}},  //
+                                    {{-4, 4}, {-5, 5}, {-6, 6}},  //
+                                    {{-7, 7}, {-8, 8}, {-9, 9}}});
+  Literal scatter_indices = LiteralUtil::CreateR2<int32>({{0, 2}});
+  Literal updates = LiteralUtil::CreateR3<int32>({{{-10, 10}, {-40, 40}}});
+  RunTest(hlo_text, &operand, &scatter_indices, &updates);
+}
+
 XLA_TEST_F(ScatterTest, OneScalarIndex) {
   const char* hlo_text = R"(
 HloModule OneScalarIndex
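The test name describes the situation being exercised rather than a particular numeric result: the single index vector {0, 2} anchors a [1,2,2] update at row 2 of the operand's second dimension, and a window of size 2 starting at offset 2 runs one element past that dimension's bound of 3. A small, hypothetical sketch of the bounds arithmetic (not code from the test) makes the "out of bounds" in the name concrete.

#include <cstddef>
#include <vector>

// Checks whether a scatter update window that starts at `start` and has
// per-dimension size `window` fits inside an operand of shape `bounds`.
// For the test above: bounds = {3, 3, 2}, window = {1, 2, 2},
// start = {0, 2, 0} -> the window covers rows 2..3 in dimension 1, which is
// outside a dimension of size 3, so the window is out of bounds.
bool WindowInBounds(const std::vector<int>& start,
                    const std::vector<int>& window,
                    const std::vector<int>& bounds) {
  for (size_t d = 0; d < bounds.size(); ++d) {
    if (start[d] < 0 || start[d] + window[d] > bounds[d]) return false;
  }
  return true;
}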
#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -145,7 +146,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> ( ASSERT_EQ(args.size(), 2); const Literal& key_arg = args[0]; - tensorflow::gtl::FlatSet<uint32> key_set; + absl::flat_hash_set<uint32> key_set; for (const float& value : key_arg.data<float>()) { EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second); } @@ -168,7 +169,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> ( ASSERT_EQ(args.size(), 2); const Literal& key_arg = args[0]; - tensorflow::gtl::FlatSet<int32> key_set; + absl::flat_hash_set<int32> key_set; for (const int32& value : key_arg.data<int32>()) { EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second); } diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 7abd8651d5..8b1b9e1519 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -763,9 +763,7 @@ TEST_F(WhileTest, TwoWhileLoopsAndSharedBody) { ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001)); } -// Test while nodes that share the while body computation. -// TODO(b/37245345): Fails on GPU backend. -TEST_F(WhileTest, DISABLED_ON_GPU(WhileLoopsWithSharedBodyAndInit)) { +TEST_F(WhileTest, WhileLoopsWithSharedBodyAndInit) { std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})}; Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc index db5a824de0..a6e70eb6ca 100644 --- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include <vector> #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" @@ -32,7 +33,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -83,7 +83,7 @@ struct ParsedProfileOutputLine { Status ParseOneProfileOutputLine( const string& line, bool expect_hlo, - gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results, + absl::flat_hash_map<string, ParsedProfileOutputLine>* parsed_results, absl::Span<const absl::string_view> opcodes_to_ignore = {}) { string separator = "[^:]*:: +"; string match_percentage = R"(\d+\.\d*% +\d+Σ)"; @@ -208,7 +208,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) { std::vector<string> profile_output_lines = absl::StrSplit(profile_output, '\n'); - gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines; + absl::flat_hash_map<string, ParsedProfileOutputLine> parsed_profile_lines; TF_ASSERT_OK(ParseOneProfileOutputLine( profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines)); @@ -314,7 +314,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) { ASSERT_NE(while_body_profile_end, profile_output_lines.end()); - gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines; + absl::flat_hash_map<string, ParsedProfileOutputLine> parsed_profile_lines; for (auto while_body_profile_i = while_body_profile_start + 1; while_body_profile_i != while_body_profile_end; while_body_profile_i++) { diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index b53f89d63b..60d25a6407 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -200,6 +200,15 @@ message DebugOptions { // among different algorithms. bool xla_gpu_crash_on_verification_failures = 101; + // Force the host platform to pretend that there are these many host + // "devices". All these devices are backed by the same threadpool. Defaults + // to 1. + // + // Setting this to anything other than 1 can increase overhead from context + // switching but we let the user override this behavior to help run tests on + // the host that run models in parallel across multiple devices. + int32 xla_force_host_platform_device_count = 102; + // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. map<string, string> xla_backend_extra_options = 500; diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc index fda4c31298..40ec1b0ba9 100644 --- a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc @@ -21,7 +21,7 @@ limitations under the License. 
diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
index fda4c31298..40ec1b0ba9 100644
--- a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
+++ b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
 namespace tensorflow {
 
 REGISTER_OP("XRTExecute")
-    .Attr("Ninputs: int")
+    .Attr("Ninputs: int >= 0")
     .Input("computation_handle: int64")
     .Input("execution_config: string")
     .Input("input_handles: Ninputs * int64")
diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD
index 09ab4ed95f..b6dcfc4eb9 100644
--- a/tensorflow/compiler/xrt/tests/BUILD
+++ b/tensorflow/compiler/xrt/tests/BUILD
@@ -8,6 +8,10 @@ package(
 )
 
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test", "tf_cc_test")
+load(
+    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "tf_cuda_tests_tags",
+)
 
 cc_library(
     name = "raw_api_test_lib",
@@ -57,7 +61,7 @@ tf_cuda_cc_test(
     size = "medium",
     srcs = [],
     args = ["--xla_test_device=XLA_GPU"],
-    tags = ["requires-gpu-sm35"],
+    tags = tf_cuda_tests_tags(),
    deps = [
        ":raw_api_test_lib",
        "//tensorflow/compiler/jit:xla_gpu_device",
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index 2952feb16a..f590fbf0d9 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -108,6 +108,14 @@ bool CompareLiteralToLiteralProto(const xla::Literal& a,
   return equal;
 }
 
+xla::XlaComputation OnePlusTwo() {
+  xla::XlaBuilder builder("OnePlusTwo");
+  auto c0 = xla::ConstantR0(&builder, 1.0f);
+  auto c1 = xla::ConstantR0(&builder, 2.0f);
+  xla::Add(c0, c1);
+  return builder.Build().ValueOrDie();
+}
+
 xla::XlaComputation AddAndScale() {
   xla::XlaBuilder builder("AddAndScale");
   auto p0 = xla::Parameter(&builder, 0,
@@ -346,6 +354,39 @@ TEST(RawApiTest, CompileAndExecute) {
   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
 }
 
+TEST(RawApiTest, CompileAndExecuteZeroArg) {
+  xrt::XLAComputation c;
+  auto config = c.mutable_config();
+  auto shapes = config->mutable_program_shape();
+  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {});
+
+  xrt::XRTExecutionConfig e;
+  e.set_release_input_handles(true);
+  e.set_release_compilation_handle(true);
+  StoreComputationSnapshot(OnePlusTwo(), c.mutable_hlo_snapshot());
+
+  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+  auto e_config =
+      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+  auto computation =
+      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+  auto c_handle = ops::XRTCompile(root, computation);
+  auto result = ops::XRTExecute(root, c_handle, e_config,
+                                std::initializer_list<Input>({}));
+  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
+  TF_ASSERT_OK(root.status());
+
+  ClientSession session(root);
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session.Run({read_back}, &outputs));
+
+  xla::LiteralProto response;
+  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+  auto expected = xla::LiteralUtil::CreateR0<float>(3.0f);
+  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
+}
+
 TEST(RawApiTest, CompileAndExecuteReturnTuple) {
   xrt::XLAAllocation p0;
   p0.set_device_ordinal(0);