Diffstat (limited to 'tensorflow/compiler')
68 files changed, 2977 insertions, 400 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 55b98da472..e059f77563 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -314,12 +314,16 @@ cc_library( "deadness_analysis_internal.h", "encapsulate_subgraphs_pass.cc", "mark_for_compilation_pass.cc", + "mark_for_compilation_pass_test_helper.cc", + "partially_decluster_pass.cc", ], hdrs = [ "build_xla_launch_ops_pass.h", "deadness_analysis.h", "encapsulate_subgraphs_pass.h", "mark_for_compilation_pass.h", + "mark_for_compilation_pass_test_helper.h", + "partially_decluster_pass.h", ], deps = [ ":common", @@ -354,6 +358,7 @@ cc_library( "//tensorflow/compiler/jit/graphcycles", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", ], @@ -418,10 +423,12 @@ tf_cc_test( srcs = [ "encapsulate_subgraphs_pass_test.cc", "mark_for_compilation_pass_test.cc", + "partially_decluster_pass_test.cc", ], deps = [ ":common", ":compilation_passes", + ":xla_cluster_util", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc index 4d49a14b24..c37b6112cc 100644 --- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc +++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/partially_decluster_pass.h" #include "tensorflow/core/common_runtime/optimization_registry.h" namespace tensorflow { @@ -23,15 +24,18 @@ namespace tensorflow { REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10, MarkForCompilationPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, + PartiallyDeclusterPass); + // The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We // also need to run it after the graph been rewritten to have _Send nodes added // for fetches. Before the _Send nodes are added, fetch nodes are identified by // name, and encapsulation might remove that node from the graph. -REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, EncapsulateSubgraphsPass); // Must run after EncapsulateSubgraphsPass. 
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30, +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40, BuildXlaLaunchOpsPass); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 37a2f3b5ac..7f4370b5b0 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -210,7 +210,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; - launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie()); + OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs( + ctx, kernel, run_result.ConsumeValueOrDie())); VLOG(1) << "Done"; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 45d422943c..90d5d56998 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -65,6 +65,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave // such nodes out of XLA clusters. if (HasForwardedRefInput(node)) { + VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast."; return false; } @@ -84,14 +85,13 @@ bool IsCompilableCall(const NodeDef& call_def, bool IsCompilableWhile(const Node& while_node, const DeviceType& jit_device_type, int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Loop marking: " << while_node.type_string(); - const NameAttrList* name_attr; NodeDef call; Status status; status = GetNodeAttr(while_node.attrs(), "cond", &name_attr); if (!status.ok()) { - VLOG(2) << "Missing 'cond' attribute on While node."; + VLOG(2) << "Rejecting While " << while_node.name() + << ": missing 'cond' attribute on While node."; return false; } const string cond_func = name_attr->name(); @@ -99,12 +99,14 @@ bool IsCompilableWhile(const Node& while_node, call.set_op(cond_func); *call.mutable_attr() = name_attr->attr(); if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Can't compile loop condition: " << cond_func; + VLOG(2) << "Rejecting While " << while_node.name() + << ": can't compile loop condition: " << cond_func; return false; } status = GetNodeAttr(while_node.attrs(), "body", &name_attr); if (!status.ok()) { - VLOG(2) << "Missing 'body' attribute on While node."; + VLOG(2) << "Rejecting While " << while_node.name() + << ": missing 'body' attribute on While node."; return false; } const string body_func = name_attr->name(); @@ -112,10 +114,10 @@ bool IsCompilableWhile(const Node& while_node, call.set_op(body_func); *call.mutable_attr() = name_attr->attr(); if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Can't compile loop body: " << body_func; + VLOG(2) << "Rejecting While " << while_node.name() + << ": can't compile loop body: " << body_func; return false; } - VLOG(2) << "Loop is compilable."; return true; } @@ -125,10 +127,9 @@ bool IsCompilableWhile(const Node& while_node, bool IsCompilableCall(const NodeDef& call_def, const DeviceType& jit_device_type, int depth, FunctionLibraryRuntime* lib_runtime) { - VLOG(2) << "Function marking: " << call_def.op(); - if (depth > kMaxRecursionDepth) { - VLOG(2) << "Function depth limit exceeded"; + VLOG(2) << "Rejecting " << 
call_def.op() + << ": function depth limit exceeded."; return false; } @@ -136,7 +137,8 @@ bool IsCompilableCall(const NodeDef& call_def, Status status = lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle); if (!status.ok()) { - VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status; + VLOG(2) << "Rejecting " << call_def.op() + << ": could not instantiate: " << status; return false; } const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle); @@ -150,7 +152,8 @@ bool IsCompilableCall(const NodeDef& call_def, // tf2xla to translate the TF graph into XLA. So we avoid this for now. // // TODO(b/36139787): Create a mechanism to set inlining hints. - VLOG(2) << "Can't compile noinline function: " << fdef.DebugString(); + VLOG(2) << "Rejecting " << call_def.op() + << ": can't compile noinline function."; return false; } @@ -164,23 +167,14 @@ bool IsCompilableCall(const NodeDef& call_def, if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, depth + 1, lib_runtime)) { - VLOG(2) << "Function marking failed: unsupported op " << node->name() - << ": " << node->def().ShortDebugString(); + VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op " + << node->name() << ": " << node->def().ShortDebugString(); return false; } } - VLOG(2) << "Function is compilable: " << call_def.op(); return true; } -// Tests whether `node` has a DT_RESOURCE typed input or output. -bool HasResourceInputOrOutput(const Node& node) { - return std::find(node.input_types().begin(), node.input_types().end(), - DT_RESOURCE) != node.input_types().end() || - std::find(node.output_types().begin(), node.output_types().end(), - DT_RESOURCE) != node.output_types().end(); -} - // Returns true if the op can be decomposed into XLA ops for which // there are fusable elemental implementations. // @@ -357,24 +351,27 @@ Status FindCompilationCandidates( } std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); + if (fuel >= std::numeric_limits<int64>::max() / 2) { + // The assumption is that if fuel started out as INT64_MAX, it will forever + // stay greater than INT64_MAX / 2. + VLOG(2) << "Starting fuel: infinity"; + } else { + VLOG(2) << "Starting fuel: " << fuel; + } + for (Node* node : sorted_nodes) { - VLOG(2) << "Fuel: " << fuel; if (fuel <= 0) { - VLOG(2) + VLOG(1) << "Hit fuel limit; not marking any remaining ops as clusterable."; break; } - VLOG(2) << "FindCompilationCandidates(): Processing " - << node->DebugString(); - DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceToDeviceType(node->assigned_device_name(), &device_type)); if (is_compilable_fn && !is_compilable_fn(node, device_type)) { - VLOG(2) << "Compilation rejected node: not compilable " << node->name() - << ": " << node->type_string(); + // is_compilable_fn has already logged the reason if it returned false. 
continue; } @@ -384,14 +381,14 @@ Status FindCompilationCandidates( DeviceType jit_device_type(registration->compilation_device_name); if (!HasXLAKernel(*node, jit_device_type) && !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) { - VLOG(2) << "Compilation rejected node: unsupported op " << node->name() - << ": " << node->type_string(); + VLOG(2) << "Rejecting " << node->name() << ": unsupported op " + << node->type_string(); continue; } if (!registration->compile_resource_ops && HasResourceInputOrOutput(*node)) { - VLOG(2) << "Compilation rejected node: resource input/output " - << node->name() << ": " << node->type_string(); + VLOG(2) << "Rejecting: " << node->name() << ": resource input/output " + << node->type_string(); continue; } if (node->type_string() == "While" && @@ -401,15 +398,11 @@ Status FindCompilationCandidates( // _Arg nodes in a top-level function represent feeds. // Do not compile them. if (node->type_string() == "_Arg") { - VLOG(2) << "Skipping jit compilation for '_Arg'-typed node " - << node->DebugString(); continue; } // _Retval nodes in a top-level function represent fetches. // Do not compile them. if (node->type_string() == "_Retval") { - VLOG(2) << "Compilation rejected node: return value " << node->name() - << ": " << node->type_string(); continue; } candidates->insert(node); @@ -475,6 +468,7 @@ Status MarkForCompilationPass::Run( const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { + VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device."; return false; } @@ -484,21 +478,36 @@ Status MarkForCompilationPass::Run( // If there is a _XlaCompile annotation, use its value. bool compile = false; Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); - if (status.ok()) return compile; + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") is false."; + } + return compile; + } status = fld->GetAttr(*node, kXlaCompileAttr, &compile); - if (status.ok()) return compile; + if (status.ok()) { + if (!compile) { + VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" + << kXlaCompileAttr << ") on callee is false."; + } + return compile; + } // If inputs to `node` can have conflicting deadness (i.e. some are alive // and some are dead) then don't compile it. XLA cannot represent the // deadness semantics of these nodes correctly and auto-clustering these // nodes can cause deadness to propagate to nodes that should be live. if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) { + VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness."; return false; } // Check for fusable ops only if requested. 
if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) { + VLOG(2) << "Rejecting " << node->name() + << ": not fusable op but fusion_only enabled."; return false; } @@ -506,8 +515,17 @@ Status MarkForCompilationPass::Run( // Ignore enable_jit_by_default if global jit compilation for CPU // is explicitly requested via tf_xla_cpu_global_jit flag bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; - return (ignore_registration || registration->enable_jit_by_default) && - global_jit_level > 0; + bool should_compile = + (ignore_registration || registration->enable_jit_by_default) && + global_jit_level > 0; + if (!should_compile) { + if (global_jit_level <= 0) { + VLOG(2) << "Rejecting " << node->name() << ": global jit disabled."; + } else { + VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled."; + } + } + return should_compile; }; return RunImpl(options, is_compilable); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index e9acbfb19e..f1137af3c1 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -40,20 +40,18 @@ class MarkForCompilationPass : public GraphOptimizationPass { Status Run(const GraphOptimizationPassOptions& options) override; - // Run() just calls RunImpl() if --tf_xla_auto_jit is enabled. To run the pass - // unconditionally, call RunImpl() directly. - // is_compilable_fn, if set, is a predicate that must be true for a node to - // be compiled. + private: Status RunImpl(const GraphOptimizationPassOptions& options, const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn = {}); + + friend class MarkForCompilationPassTestHelper; }; // Returns true iff 'ndef' is a call to a function that is compilable. A // function is compilable iff every operator in the function body is // compilable. bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef); - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 2c5f4fb774..a780d4a936 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" +#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" @@ -39,27 +39,6 @@ namespace { REGISTER_OP("UncompilableNullary").Output("o: float"); REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); -Status MarkForCompilation(std::unique_ptr<Graph>* graph, - FunctionLibraryDefinition* flib_def) { - // Assign all nodes to the CPU device. 
- static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; - for (Node* n : (*graph)->nodes()) { - n->set_assigned_device_name(kCpuDevice); - } - - GraphOptimizationPassOptions opt_options; - opt_options.graph = graph; - opt_options.flib_def = flib_def; - MarkForCompilationPass pass; - return pass.RunImpl(opt_options); -} - -Status MarkForCompilation(std::unique_ptr<Graph>* graph) { - FunctionDefLibrary flib; - FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); - return MarkForCompilation(graph, &flib_def); -} - std::unordered_map<string, string> GetClusters(const Graph& graph) { std::unordered_map<string, string> ids; for (Node* node : graph.nodes()) { @@ -88,7 +67,7 @@ TEST(XlaCompilationTest, Chains) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(4, clusters.size()); EXPECT_EQ(clusters["B"], clusters["C"]); @@ -113,7 +92,7 @@ TEST(XlaCompilationTest, UncompilableCycles) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -133,7 +112,7 @@ TEST(XlaCompilationTest, CompilableCycles) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(3, clusters.size()); @@ -156,7 +135,7 @@ TEST(XlaCompilationTest, Complex128Unsupported) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); } @@ -177,7 +156,7 @@ TEST(XlaCompilationTest, HalfSupported) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_FALSE(clusters.empty()); } @@ -206,7 +185,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(3, clusters.size()); // Everything should be compiled. } @@ -241,7 +220,8 @@ TEST(XlaCompilationTest, FunctionCalls) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def)); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -272,7 +252,7 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) { ops::UnaryOp("Shape", d, builder.opts().WithName("E")); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. 
} @@ -359,7 +339,7 @@ TEST(XlaCompilationTest, SymbolicGradients) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -384,7 +364,7 @@ TEST(XlaCompilationTest, Loops) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); TF_EXPECT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // Nothing should be compiled. In particular, 'd' and 'c' must not be @@ -411,7 +391,7 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: C = A + relu(A) @@ -442,7 +422,7 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: D = relu(A) + (A @ relu(A)) @@ -472,7 +452,7 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); // The computation is: C = A @ relu(A) @@ -512,7 +492,7 @@ TEST(XlaCompilationTest, Resources) { ops::UnaryOp("Relu", d, builder.opts().WithName("E")); TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(0, clusters.size()); // Nothing should be compiled. 
} @@ -542,7 +522,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { TF_EXPECT_OK(root.ToGraph(graph.get())); - Status status = MarkForCompilation(&graph); + Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph); EXPECT_FALSE(status.ok()); EXPECT_TRUE(str_util::StrContains(status.ToString(), "Edge from c to a would create a cycle.\n" @@ -570,7 +550,7 @@ TEST(XlaCompilationTest, Retval) { TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_EQ(2, clusters.size()); @@ -588,7 +568,7 @@ TEST(XlaCompilationTest, DontCountIdentityOps) { auto r = ops::_Retval(root.WithOpName("R"), c, 0); } TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -604,7 +584,7 @@ TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) { auto r = ops::_Retval(root.WithOpName("R"), b, 0); } TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); EXPECT_TRUE(clusters.empty()); @@ -618,7 +598,7 @@ TEST(XlaCompilationTest, ConstOp) { auto c = ops::Const(root.WithOpName("const"), 0.5f); c.node()->AddAttr(kXlaCompileAttr, true); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); EXPECT_EQ(1, GetClusters(*graph).size()); } @@ -629,7 +609,7 @@ TEST(XlaCompilationTest, ConstOp) { auto c = ops::Const(root.WithOpName("const"), string("string")); c.node()->AddAttr(kXlaCompileAttr, true); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); EXPECT_TRUE(GetClusters(*graph).empty()); } } @@ -644,7 +624,7 @@ TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map<string, string> clusters = GetClusters(*graph); @@ -667,7 +647,7 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map<string, string> clusters = GetClusters(*graph); @@ -699,7 +679,7 @@ TEST(XlaCompilationTest, ClusterControlTrigger) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(root.ToGraph(graph.get())); - TF_ASSERT_OK(MarkForCompilation(&graph)); + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); std::unordered_map<string, string> clusters = GetClusters(*graph); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc new file mode 100644 index 0000000000..a84b82e479 --- /dev/null +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -0,0 +1,40 @@ +/* Copyright 2018 The 
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" + +namespace tensorflow { +/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( + std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) { + // Assign all nodes to the CPU device. + static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : (*graph)->nodes()) { + n->set_assigned_device_name(kCpuDevice); + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = graph; + opt_options.flib_def = flib_def; + MarkForCompilationPass pass; + return pass.RunImpl(opt_options); +} + +/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( + std::unique_ptr<Graph>* graph) { + FunctionDefLibrary flib; + FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); + return MarkForCompilation(graph, &flib_def); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h new file mode 100644 index 0000000000..b9a0531cb0 --- /dev/null +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ +#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ + +#include "tensorflow/compiler/jit/mark_for_compilation_pass.h" + +namespace tensorflow { +class MarkForCompilationPassTestHelper { + public: + // Runs the MarkForCompilation pass on `graph` after assigning all nodes in + // `graph` to the CPU device. To make testing easier, ignores device + // registration, _XlaCompile attributes, input deadness and global jit level. + static Status MarkForCompilation(std::unique_ptr<Graph>* graph, + FunctionLibraryDefinition* flib_def); + + // Like `MarkForCompilation` but creates `flib_def` from the op registry. 
+ static Status MarkForCompilation(std::unique_ptr<Graph>* graph); +}; +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_ diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc new file mode 100644 index 0000000000..68ead39424 --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -0,0 +1,177 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/framework/memory_types.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace tensorflow { +namespace { +Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result, + gtl::ArraySlice<Node*> post_order) { + // Find nodes that have at least one user outside their cluster that expects + // hostmem output. These nodes should be cloned to outside the cluster to + // avoid the device-host copy we'd otherwise need. + + MemoryTypeVector input_mtypes, output_mtypes; + + for (Node* n : post_order) { + gtl::optional<StringPiece> from_cluster = GetXlaClusterForNode(*n); + if (!from_cluster) { + continue; + } + + // We assume the only XLA-auto-clusterable operations with side effects are + // resource variable updates. We can't execute these twice. + if (HasResourceInputOrOutput(*n)) { + continue; + } + + DeviceType device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(n->assigned_device_name(), &device_type)); + TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, + n->def(), &input_mtypes, + &output_mtypes)); + for (const Edge* e : n->out_edges()) { + Node* dst = e->dst(); + + if (e->IsControlEdge()) { + continue; + } + + bool edge_incurs_extra_device_to_host_copy; + if (output_mtypes[e->src_output()] == DEVICE_MEMORY) { + // If the output of the *TensorFlow* operation is in DEVICE_MEMORY then + // keep the node clustered -- XLA will also produce the output in device + // memory and we will get some benefit from clustering. + edge_incurs_extra_device_to_host_copy = false; + } else { + MemoryTypeVector dst_input_mtypes, dst_output_mtypes; + DeviceType dst_device_type(""); + TF_RETURN_IF_ERROR( + DeviceToDeviceType(dst->assigned_device_name(), &dst_device_type)); + TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type, + dst->def(), &dst_input_mtypes, + &dst_output_mtypes)); + edge_incurs_extra_device_to_host_copy = + dst_input_mtypes[e->dst_input()] == HOST_MEMORY; + } + + if (!edge_incurs_extra_device_to_host_copy) { + continue; + } + + // Check if `dst` is in a different cluster, unclustered, or about to be + // partially declustered (here we rely on the post-order traversal order). + // If yes, decluster `n` to avoid the device-to-host memcpy. 
+ gtl::optional<StringPiece> dst_cluster = + result->count(dst) ? gtl::nullopt : GetXlaClusterForNode(*dst); + if (from_cluster != dst_cluster) { + CHECK(result->insert(n).second); + break; + } + } + } + return Status::OK(); +} + +Status PartiallyDeclusterNode(Graph* graph, Node* n) { + StringPiece cluster_name = *GetXlaClusterForNode(*n); + gtl::InlinedVector<const Edge*, 6> out_edges_to_clone; + for (const Edge* out_edge : n->out_edges()) { + if (out_edge->IsControlEdge()) { + continue; + } + + Node* dst = out_edge->dst(); + gtl::optional<StringPiece> dst_cluster_name = GetXlaClusterForNode(*dst); + if (dst_cluster_name != cluster_name) { + out_edges_to_clone.push_back(out_edge); + } + } + + CHECK(!out_edges_to_clone.empty()) << n->DebugString(); + + NodeDef ndef = n->def(); + ndef.set_name(strings::StrCat(n->name(), "/declustered")); + RemoveFromXlaCluster(&ndef); + Status s; + Node* cloned_node = graph->AddNode(ndef, &s); + cloned_node->set_assigned_device_name(n->assigned_device_name()); + TF_RETURN_IF_ERROR(s); + + for (const Edge* in_edge : n->in_edges()) { + graph->AddEdge(in_edge->src(), in_edge->src_output(), cloned_node, + in_edge->dst_input()); + } + + for (const Edge* out_edge_to_clone : out_edges_to_clone) { + graph->AddEdge(cloned_node, out_edge_to_clone->src_output(), + out_edge_to_clone->dst(), out_edge_to_clone->dst_input()); + graph->RemoveEdge(out_edge_to_clone); + } + + return Status::OK(); +} +} // namespace + +Status PartiallyDeclusterPass::Run( + const GraphOptimizationPassOptions& options) { + // NB! In this pass we assume the only XLA-auto-clusterable operations that + // may have side effects are resource variable operations so we don't cluster + // those. The pass will have to be updated if this assumption becomes + // invalid. + + Graph* graph = options.graph->get(); + + // When deciding whether to decluster a particular node, we base our decision + // on if we've decided that some of its consumers have to be declustered too. + // Iterating the graph in post-order guarantees that consumers have been + // visited before producers. + std::vector<Node*> post_order; + GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(), + /*edge_filter=*/[](const Edge& edge) { + return !edge.src()->IsNextIteration(); + }); + + gtl::FlatSet<Node*> nodes_to_partially_decluster; + TF_RETURN_IF_ERROR(FindNodesToDecluster( + **options.graph, &nodes_to_partially_decluster, post_order)); + + if (VLOG_IS_ON(3)) { + for (Node* n : post_order) { + if (nodes_to_partially_decluster.count(n)) { + VLOG(3) << n->DebugString(); + } + } + } + + for (Node* n : post_order) { + if (nodes_to_partially_decluster.count(n)) { + TF_RETURN_IF_ERROR(PartiallyDeclusterNode(graph, n)); + } + } + + nodes_to_partially_decluster.clear(); + TF_RETURN_IF_ERROR(FindNodesToDecluster( + **options.graph, &nodes_to_partially_decluster, post_order)); + CHECK(nodes_to_partially_decluster.empty()); + + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h new file mode 100644 index 0000000000..6949b5028e --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ +#define TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { + +// Clones nodes from within a cluster to outside the cluster if profitable. +// +// Today this only clones to avoid device-to-host copies, but in the future we +// may consider other reasons to clone. For instance, we convert this: +// +// ..... +// | +// v +// A_Clustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// to: +// +// ..... +// | | +// | +-------------+ +// | | +// v v +// A_Clustered A_Unclustered ====> C_Unclustered +// | +// v +// B_Clustered +// +// where the ===> arrow has a hostmem source and destination and would entail a +// device to host copy if the source and destination were not in the same XLA +// cluster. +class PartiallyDeclusterPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_ diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc new file mode 100644 index 0000000000..08a956e4c6 --- /dev/null +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -0,0 +1,284 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/jit/partially_decluster_pass.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +REGISTER_OP("FakeNullary").Output("out: float"); + +REGISTER_OP("FakeBinary") + .Input("host_in: float") + .Input("device_in: float") + .Output("host_out: float") + .Output("device_out: float"); + +REGISTER_OP("FakeResourceVar").Output("out: resource"); + +REGISTER_OP("FakeResourceUpdate") + .Input("in: resource") + .Output("out: resource") + .Output("something_else: float"); + +class FakeBinaryOp : public OpKernel { + public: + explicit FakeBinaryOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { CHECK(false); } +}; + +class FakeResourceVarUpdateOp : public OpKernel { + public: + explicit FakeResourceVarUpdateOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { CHECK(false); } +}; + +REGISTER_KERNEL_BUILDER(Name("FakeBinary") + .Device(DEVICE_CPU) + .HostMemory("host_in") + .HostMemory("host_out"), + FakeBinaryOp); + +REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate") + .Device(DEVICE_CPU) + .HostMemory("something_else"), + FakeResourceVarUpdateOp); + +Status PartiallyDecluster(std::unique_ptr<Graph>* graph) { + FixupSourceAndSinkEdges(graph->get()); + // Assign all nodes to the CPU device. 
+ static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : (*graph)->nodes()) { + n->set_assigned_device_name(kCpuDevice); + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = graph; + PartiallyDeclusterPass pass; + return pass.Run(opt_options); +} + +const Node* FindNodeByName(const Graph& graph, const string& name) { + for (const Node* node : graph.nodes()) { + if (node->name() == name) { + return node; + } + } + return nullptr; +} + +bool GetInputsForNode(const Graph& graph, const string& node_name, + std::vector<Node*>* inputs) { + const Node* node = FindNodeByName(graph, node_name); + if (node == nullptr) { + return false; + } + for (const Edge* e : node->in_edges()) { + inputs->push_back(e->src()); + } + std::sort(inputs->begin(), inputs->end(), NodeComparatorName()); + return true; +} + +TEST(PartiallyDeclusterPassTest, ClusteredAndUnclustered) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + ops::BinaryOp("FakeBinary", clustered_producer, input, + builder.opts().WithName("UnclusteredConsumer")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector<Node*> unclustered_consumer_inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer", + &unclustered_consumer_inputs)); + ASSERT_EQ(unclustered_consumer_inputs.size(), 2); + EXPECT_EQ(unclustered_consumer_inputs[0]->name(), + "ClusteredProducer/declustered"); + EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input"); + + std::vector<Node*> clustered_consumer_inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredConsumer", + &clustered_consumer_inputs)); + ASSERT_EQ(clustered_consumer_inputs.size(), 2); + EXPECT_EQ(clustered_consumer_inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(clustered_consumer_inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DifferentClusters) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", clustered_producer, input, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector<Node*> inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + 
EXPECT_EQ(inputs[0]->name(), "ClusteredProducer/declustered"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DontDeclusterIfUserIsDeviceMem) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer")); + // The first input is hostmem and the second input is devicemem. + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", input, clustered_producer, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector<Node*> inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DontDuplicateResourceVarOps) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* resource_var = ops::SourceOp("FakeResourceVar", + builder.opts().WithName("ResourceVar")); + Node* clustered_producer = + ops::UnaryOp("FakeResourceUpdate", resource_var, + builder.opts().WithName("ClusteredProducer")); + Node* consumer_in_different_cluster = + ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input, + builder.opts().WithName("ConsumerInDifferentCluster")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", input, {clustered_producer, 1}, + builder.opts().WithName("ClusteredConsumer")); + clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector<Node*> inputs; + ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs)); + ASSERT_EQ(inputs.size(), 2); + EXPECT_EQ(inputs[0]->name(), "ClusteredProducer"); + EXPECT_EQ(inputs[1]->name(), "Input"); +} + +TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) { + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* input = + ops::SourceOp("FakeNullary", builder.opts().WithName("Input")); + Node* clustered_producer_0 = + ops::BinaryOp("FakeBinary", input, input, + builder.opts().WithName("ClusteredProducer0")); + Node* clustered_producer_1 = + ops::BinaryOp("FakeBinary", clustered_producer_0, input, + builder.opts().WithName("ClusteredProducer1")); + ops::BinaryOp("FakeBinary", clustered_producer_1, input, + builder.opts().WithName("UnclusteredConsumer")); + Node* clustered_consumer = + ops::BinaryOp("FakeBinary", {clustered_producer_1, 1}, input, + 
builder.opts().WithName("ClusteredConsumer")); + clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0"); + clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0"); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(PartiallyDecluster(&graph)); + std::vector<Node*> unclustered_consumer_inputs, declustered_producer_1_inputs; + + ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer", + &unclustered_consumer_inputs)); + ASSERT_EQ(unclustered_consumer_inputs.size(), 2); + EXPECT_EQ(unclustered_consumer_inputs[0]->name(), + "ClusteredProducer1/declustered"); + EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input"); + + ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredProducer1/declustered", + &declustered_producer_1_inputs)); + ASSERT_EQ(declustered_producer_1_inputs.size(), 2); + EXPECT_EQ(declustered_producer_1_inputs[0]->name(), + "ClusteredProducer0/declustered"); + EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index a5628b12a2..0a025a1fc0 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -185,4 +185,26 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { return Status::OK(); } +gtl::optional<StringPiece> GetXlaClusterForNode(const Node& node) { + const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr); + if (attr_value == nullptr) { + return gtl::nullopt; + } + Status s = AttrValueHasType(*attr_value, "string"); + if (!s.ok()) { + return gtl::nullopt; + } + return attr_value->s(); +} + +bool HasResourceInputOrOutput(const Node& node) { + return std::find(node.input_types().begin(), node.input_types().end(), + DT_RESOURCE) != node.input_types().end() || + std::find(node.output_types().begin(), node.output_types().end(), + DT_RESOURCE) != node.output_types().end(); +} + +void RemoveFromXlaCluster(NodeDef* node_def) { + node_def->mutable_attr()->erase(kXlaClusterAttr); +} } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index bcce082aaf..bff76da6f9 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/gtl/optional.h" namespace tensorflow { @@ -44,6 +45,16 @@ bool HasForwardedRefInput(const Node& node); // the enclosing graph. Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); +// Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, +// otherwise returns nullopt. +gtl::optional<StringPiece> GetXlaClusterForNode(const Node& node); + +// Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute). +void RemoveFromXlaCluster(NodeDef* node_def); + +// Returns true if `node` has a DT_RESOURCE typed input or output. 
+bool HasResourceInputOrOutput(const Node& node); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index f65f89ebf5..dd84fb34c1 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -78,7 +78,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, executable->Run(launch_context.arguments(), run_options); TF_RETURN_IF_ERROR(run_result.status()); - launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie()); + TF_RETURN_IF_ERROR(launch_context.PopulateOutputs( + ctx, result, run_result.ConsumeValueOrDie())); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 4ddeaebd3e..2a2691a6a4 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -216,6 +217,8 @@ XlaDevice::XlaDevice( transfer_as_literal_(transfer_as_literal), shape_representation_fn_(shape_representation_fn) { VLOG(1) << "Created XLA device " << jit_device_name << " " << this; + thread_pool_.reset(new thread::ThreadPool(options.env, "xla_device", + /*num_threads=*/1)); } XlaDevice::~XlaDevice() { @@ -262,10 +265,12 @@ Status XlaDevice::EnsureDeviceContextOk() { Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend, const string& name, - xla::StreamPool::Ptr* stream, + std::shared_ptr<se::Stream>* stream, bool* stream_was_changed) { if (!(*stream) || !(*stream)->ok()) { - TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_)); + xla::StreamPool::Ptr ptr; + TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_)); + *stream = std::shared_ptr<se::Stream>(std::move(ptr)); VLOG(1) << "XlaDevice " << this << " new " << name << " " << (*stream)->DebugStreamPointers(); *stream_was_changed = true; @@ -281,8 +286,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_, &need_new_device_context)); - se::Stream* host_to_device_stream = stream_.get(); - se::Stream* device_to_host_stream = stream_.get(); + std::shared_ptr<se::Stream> host_to_device_stream = stream_; + std::shared_ptr<se::Stream> device_to_host_stream = stream_; if (use_multiple_streams_) { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream", &host_to_device_stream_, @@ -290,8 +295,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() { TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream", &device_to_host_stream_, &need_new_device_context)); - host_to_device_stream = host_to_device_stream_.get(); - device_to_host_stream = device_to_host_stream_.get(); + host_to_device_stream = host_to_device_stream_; + device_to_host_stream = device_to_host_stream_; } if (!need_new_device_context) { @@ -304,9 +309,13 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() { if (device_context_) { device_context_->Unref(); } + // 
The XlaDeviceContext keeps a reference count to the streams, and the + // XlaDeviceContext remains live for the duration of a Executor run. This + // ensures that the streams remain live for the duration of a run, even if + // an error is encountered and the streams are replaced with new ones. device_context_ = new XlaDeviceContext( - stream_.get(), host_to_device_stream, device_to_host_stream, client(), - transfer_as_literal_, shape_representation_fn_); + stream_, host_to_device_stream, device_to_host_stream, client(), + transfer_as_literal_, shape_representation_fn_, thread_pool_.get()); VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext " << device_context_; @@ -371,6 +380,22 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, op_kernel->ComputeAsync(context, done); } +Status XlaDevice::Sync() { + VLOG(1) << "XlaDevice::Sync"; + std::shared_ptr<se::Stream> stream; + { + mutex_lock lock(mu_); + stream = stream_; + } + if (!stream) return Status::OK(); + + if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) { + return errors::Internal("XlaDevice::Sync() failed."); + } + VLOG(1) << "XlaDevice::Sync completed"; + return Status::OK(); +} + Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index d8906419b0..dbf35f349f 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/stream_pool.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/allocator.h" @@ -124,7 +123,7 @@ class XlaDevice : public LocalDevice { void Compute(OpKernel* op_kernel, OpKernelContext* context) override; void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; - Status Sync() override { return Status::OK(); } + Status Sync() override; Status FillContextMap(const Graph* graph, DeviceContextMap* device_context_map) override @@ -153,7 +152,7 @@ class XlaDevice : public LocalDevice { Allocator* GetAllocatorLocked(AllocatorAttributes attr) EXCLUSIVE_LOCKS_REQUIRED(mu_); Status EnsureStreamOkLocked(xla::Backend* backend, const string& name, - xla::StreamPool::Ptr* stream, + std::shared_ptr<se::Stream>* stream, bool* stream_was_changed) EXCLUSIVE_LOCKS_REQUIRED(mu_); xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked() @@ -174,17 +173,17 @@ class XlaDevice : public LocalDevice { // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and // computations enqueued by XLA. - xla::StreamPool::Ptr stream_ GUARDED_BY(mu_); + std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_); // If false, only stream_ is valid and all computation and transfers use // stream_. If true, computation is performed by stream_ and transfers are // performed by host_to_device/device_to_host_stream. const bool use_multiple_streams_; // If use_multiple_streams_, host to device transfers are performed using this // stream. 
- xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_); + std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_); // If use_multiple_streams_, device to host transfers are performed using this // stream. - xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_); + std::shared_ptr<se::Stream> device_to_host_stream_ GUARDED_BY(mu_); // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. const bool transfer_as_literal_; @@ -198,6 +197,9 @@ class XlaDevice : public LocalDevice { // Holds extra information for GPU and TPU devices, e.g. the device context. bool use_gpu_device_info_ GUARDED_BY(mu_) = false; std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_); + + // Thread pool used for running closures + std::unique_ptr<thread::ThreadPool> thread_pool_; }; // Builds OpKernel registrations on 'device' for the JIT operators diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 0100bf51ed..0a0c089241 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_context.h" +#include <memory> + +#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -48,17 +51,20 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } XlaTransferManager::XlaTransferManager( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, + std::shared_ptr<se::Stream> compute_stream, + std::shared_ptr<se::Stream> host_to_device_stream, + std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) - : stream_(compute_stream), - host_to_device_stream_(host_to_device_stream), - device_to_host_stream_(device_to_host_stream), + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool) + : stream_(std::move(compute_stream)), + host_to_device_stream_(std::move(host_to_device_stream)), + device_to_host_stream_(std::move(device_to_host_stream)), client_(client), transfer_manager_(client->backend().transfer_manager()), transfer_as_literal_(transfer_as_literal), - shape_representation_fn_(std::move(shape_representation_fn)) { + shape_representation_fn_(std::move(shape_representation_fn)), + thread_pool_(thread_pool) { CHECK(host_to_device_stream_ != nullptr); CHECK(device_to_host_stream_ != nullptr); CHECK(stream_ != nullptr); @@ -88,15 +94,15 @@ Status XlaTransferManager::TransferLiteralToDevice( if (UseMultipleStreams()) { // Initially wait for the compute stream so that memory allocations are // synchronized. 
- host_to_device_stream_->ThenWaitFor(stream_); + host_to_device_stream_->ThenWaitFor(stream_.get()); } TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync( - host_to_device_stream_, *literal, shaped_buffer)); + host_to_device_stream_.get(), *literal, shaped_buffer)); if (UseMultipleStreams()) { - se::Event event(stream_->parent()); - TF_RET_CHECK(event.Init()) << "Event failed to initialize!"; - host_to_device_stream_->ThenRecordEvent(&event); - xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event)); + auto event = std::make_shared<se::Event>(stream_->parent()); + TF_RET_CHECK(event->Init()) << "Event failed to initialize!"; + host_to_device_stream_->ThenRecordEvent(event.get()); + xla_tensor->SetDefinedOn(host_to_device_stream_.get(), std::move(event)); } // Unref the host tensor, and capture the literal shared_ptr too so it goes // out of scope when the lambda completes. @@ -116,7 +122,7 @@ void XlaTransferManager::TransferLiteralFromDevice( TensorReference ref(device_tensor); transfer_manager_->TransferLiteralFromDevice( - device_to_host_stream_, shaped_buffer, literal, + device_to_host_stream_.get(), shaped_buffer, literal, [=, &shaped_buffer, &literal](xla::Status status) { ref.Unref(); done([&]() -> Status { @@ -179,8 +185,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); if (status.ok()) { xla_tensor->set_host_tensor(*cpu_tensor); - host_to_device_stream_->ThenDoHostCallback( - [done]() { done(Status::OK()); }); + host_to_device_stream_->ThenDoHostCallback([this, done]() { + // We must not call the done closure directly from DoHostCallback + // to avoid a deadlock. If done() is the callback that ends an + // Executor's run, the Executor may call XlaDevice::Sync() inside the + // callback. This deadlocks, because XlaDevice::Sync() waits for all + // stream activity to complete. + thread_pool_->Schedule([done]() { done(Status::OK()); }); + }); return; } } else { @@ -192,7 +204,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, if (!block_status.ok()) { status = xla::InternalError( "Failed to complete data transfer on stream %p: %s", - host_to_device_stream_, block_status.error_message().c_str()); + host_to_device_stream_.get(), block_status.error_message().c_str()); } } xla_tensor->set_host_tensor(*cpu_tensor); @@ -225,9 +237,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); if (se::Event* event = - xla_tensor->GetDefinitionEvent(device_to_host_stream_)) { + xla_tensor->GetDefinitionEvent(device_to_host_stream_.get())) { device_to_host_stream_->ThenWaitFor(event); - xla_tensor->SetDefinedOn(device_to_host_stream_); + xla_tensor->SetDefinedOn(device_to_host_stream_.get()); } Status status; @@ -240,7 +252,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, Status block_status = device_to_host_stream_->BlockHostUntilDone(); if (!block_status.ok()) { status = xla::InternalError( - "Failed to complete data transfer on stream %p: %s", stream_, + "Failed to complete data transfer on stream %p: %s", stream_.get(), block_status.error_message().c_str()); } } @@ -278,14 +290,14 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, if (stream_ != device_to_device_stream) { // Initially wait for the compute stream so that memory allocations are // synchronized. 
- device_to_device_stream->ThenWaitFor(stream_); + device_to_device_stream->ThenWaitFor(stream_.get()); } } if (se::Event* event = - xla_src->GetDefinitionEvent(device_to_device_stream)) { + xla_src->GetDefinitionEvent(device_to_device_stream.get())) { device_to_device_stream->ThenWaitFor(event); - xla_src->SetDefinedOn(device_to_device_stream); + xla_src->SetDefinedOn(device_to_device_stream.get()); } auto from_iter = xla_src->shaped_buffer().buffers().begin(); @@ -297,28 +309,37 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, } if (UseMultipleStreams()) { - se::Event event(stream_->parent()); - CHECK(event.Init()); - device_to_device_stream->ThenRecordEvent(&event); - xla_dst->SetDefinedOn(device_to_device_stream, std::move(event)); + auto event = std::make_shared<se::Event>(stream_->parent()); + TF_RET_CHECK(event->Init()) << "Event failed to initialize"; + device_to_device_stream->ThenRecordEvent(event.get()); + xla_dst->SetDefinedOn(device_to_device_stream.get(), std::move(event)); } return Status::OK(); }(); if (!status.ok()) { return done(status); } else { - stream_->ThenDoHostCallback([=]() { done(Status::OK()); }); + stream_->ThenDoHostCallback([this, done]() { + // We must not call the done closure directly from DoHostCallback to avoid + // a deadlock. If done() is the callback that ends an Executor's run, the + // Executor may call XlaDevice::Sync() inside the callback. This + // deadlocks, because XlaDevice::Sync() waits for all stream activity to + // complete. + thread_pool_->Schedule([done]() { done(Status::OK()); }); + }); } } XlaDeviceContext::XlaDeviceContext( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, + std::shared_ptr<se::Stream> compute_stream, + std::shared_ptr<se::Stream> host_to_device_stream, + std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) - : manager_(compute_stream, host_to_device_stream, device_to_host_stream, - client, transfer_as_literal, - std::move(shape_representation_fn)) {} + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool) + : manager_(std::move(compute_stream), std::move(host_to_device_stream), + std::move(device_to_host_stream), client, transfer_as_literal, + std::move(shape_representation_fn), thread_pool) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 912f8d779e..2e7445340c 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -47,10 +47,12 @@ class XlaDeviceAllocator : public Allocator { class XlaTransferManager { public: explicit XlaTransferManager( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + std::shared_ptr<se::Stream> compute_stream, + std::shared_ptr<se::Stream> host_to_device_stream, + std::shared_ptr<se::Stream> device_to_host_stream, + xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, 
StatusCallback done) const; @@ -61,7 +63,7 @@ class XlaTransferManager { void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); - se::Stream* stream() const { return stream_; } + se::Stream* stream() const { return stream_.get(); } private: Status TransferLiteralToDevice(const Tensor& host_tensor, @@ -73,13 +75,13 @@ class XlaTransferManager { // The main compute stream of the device, used to synchronize the transfer // streams if they are set. - se::Stream* stream_; + std::shared_ptr<se::Stream> stream_; // The stream to use for transferring data from host to device. Can be // identical to stream_, but must not be nullptr. - se::Stream* host_to_device_stream_; + std::shared_ptr<se::Stream> host_to_device_stream_; // The stream to use for transferring data from device to host. Can be // identical to stream_, but must not be nullptr. - se::Stream* device_to_host_stream_; + std::shared_ptr<se::Stream> device_to_host_stream_; // For the underlying memory allocator and XLA's TransferManager. xla::LocalClient* client_; // Transfer manager, for marshalling data to and from the device. @@ -87,6 +89,9 @@ class XlaTransferManager { // True if we must use XLA's TransferManager for correct device transfers. const bool transfer_as_literal_; XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + + // Thread pool used for running closures + thread::ThreadPool* thread_pool_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -95,10 +100,12 @@ class XlaTransferManager { class XlaDeviceContext : public DeviceContext { public: explicit XlaDeviceContext( - se::Stream* compute_stream, se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, xla::LocalClient* client, - bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + std::shared_ptr<se::Stream> compute_stream, + std::shared_ptr<se::Stream> host_to_device_stream, + std::shared_ptr<se::Stream> device_to_host_stream, + xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + thread::ThreadPool* thread_pool); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 6134b8c694..4efbb2d5d7 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_launch_util.h" +#include <memory> + #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -182,7 +184,7 @@ void XlaComputationLaunchContext::PopulateInputs( } } -void XlaComputationLaunchContext::PopulateOutputs( +Status XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel, ScopedShapedBuffer output) { se::Stream* stream = @@ -211,6 +213,15 @@ void XlaComputationLaunchContext::PopulateOutputs( output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator()); } + std::shared_ptr<se::Event> definition_event; + if (use_multiple_streams_) { + definition_event = std::make_shared<se::Event>(stream->parent()); + if (!definition_event->Init()) { + return errors::Internal("Failed to initialize tensor definition event."); + } + stream->ThenRecordEvent(definition_event.get()); + } + // Copy XLA results to the OpOutputList. int output_num = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { @@ -228,12 +239,13 @@ void XlaComputationLaunchContext::PopulateOutputs( // reallocate the device buffer later. VLOG(1) << "Constant output tensor on device"; - OP_REQUIRES_OK( - ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); + TF_RETURN_IF_ERROR( + ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); Device* device = dynamic_cast<Device*>(ctx->device()); - OP_REQUIRES(ctx, device != nullptr, - errors::Internal("DeviceBase was not a Device.")); + if (device == nullptr) { + return errors::Internal("DeviceBase was not a Device."); + } ctx->op_device_context()->CopyCPUTensorToDevice( &const_tensor, device, output_tensor, [&](Status status) { TF_CHECK_OK(status); }); @@ -263,16 +275,13 @@ void XlaComputationLaunchContext::PopulateOutputs( se::DeviceMemoryBase buffer = output.buffer({output_num}); if (allocate_xla_tensors_) { Tensor* output_tensor; - OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor)); + TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor); if (xla_tensor) { xla_tensor->set_shaped_buffer(ScopedShapedBuffer( ExtractSubShapedBuffer(&output, output_num, xla_allocator_))); if (use_multiple_streams_) { - se::Event event(stream->parent()); - CHECK(event.Init()); - stream->ThenRecordEvent(&event); - xla_tensor->SetDefinedOn(stream, std::move(event)); + xla_tensor->SetDefinedOn(stream, definition_event); } } else { // xla_tensor wasn't valid, which must mean this is a zero-element @@ -298,41 +307,39 @@ void XlaComputationLaunchContext::PopulateOutputs( for (int i = 0; i < kernel->resource_updates.size(); ++i) { Allocator* allocator = ctx->device()->GetAllocator({}); const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; - OP_REQUIRES(ctx, - write.input_index >= 0 && write.input_index < ctx->num_inputs(), - errors::Internal("Invalid input index for variable write.")); + if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) { + return errors::Internal("Invalid input index for variable write."); + } se::DeviceMemoryBase buffer = output.buffer({output_num}); Var* variable = nullptr; // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, // not a Tensor. 
- OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>( - ctx, HandleFromInput(ctx, write.input_index), - &variable, [this, ctx, &write](Var** ptr) { - *ptr = new Var(write.type); - return Status::OK(); - })); + TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>( + ctx, HandleFromInput(ctx, write.input_index), &variable, + [&write](Var** ptr) { + *ptr = new Var(write.type); + return Status::OK(); + })); core::ScopedUnref s(variable); mutex_lock ml(*variable->mu()); - OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type, - errors::Internal("Mismatched type in variable write")); + if (variable->tensor()->dtype() != write.type) { + return errors::Internal("Mismatched type in variable write"); + } if (allocate_xla_tensors_) { Tensor output_tensor; - OP_REQUIRES_OK( - ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor)); + TF_RETURN_IF_ERROR( + ctx->allocate_temp(write.type, write.shape, &output_tensor)); XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor); CHECK(xla_tensor); xla_tensor->set_shaped_buffer( ExtractSubShapedBuffer(&output, output_num, xla_allocator_)); if (use_multiple_streams_) { - se::Event event(stream->parent()); - CHECK(event.Init()); - stream->ThenRecordEvent(&event); - xla_tensor->SetDefinedOn(stream, std::move(event)); + xla_tensor->SetDefinedOn(stream, definition_event); } *variable->tensor() = output_tensor; } else { @@ -343,6 +350,7 @@ void XlaComputationLaunchContext::PopulateOutputs( } ++output_num; } + return Status::OK(); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 1ea3fa4cf2..4232f514b3 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -93,9 +93,9 @@ class XlaComputationLaunchContext { const std::map<int, OptionalTensor>& variables); // Given the XLA output in `output`, populate all outputs of `ctx`. - void PopulateOutputs(OpKernelContext* ctx, - const XlaCompiler::CompilationResult* kernel, - xla::ScopedShapedBuffer output); + Status PopulateOutputs(OpKernelContext* ctx, + const XlaCompiler::CompilationResult* kernel, + xla::ScopedShapedBuffer output); // Return the argument list. Only valid after PopulateInputs() has been // called. diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index d777dfa5a3..92ba7de1b7 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -75,7 +75,7 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) { mutex_lock lock(mu_); - if (!definition_event_.has_value()) { + if (!definition_event_) { return nullptr; } @@ -87,10 +87,11 @@ se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) { return nullptr; } - return &*definition_event_; + return definition_event_.get(); } -void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) { +void XlaTensor::SetDefinedOn(se::Stream* stream, + std::shared_ptr<se::Event> event) { mutex_lock lock(mu_); definition_event_ = std::move(event); streams_defined_on_ = {stream}; diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index f7e401c731..8d36d0fa0a 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -16,6 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ #define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_ +#include <memory> + #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/core/framework/allocator.h" @@ -94,7 +96,7 @@ class XlaTensor { // Assert that the tensor's content is defined on 'stream' by the time 'event' // triggers. - void SetDefinedOn(se::Stream* stream, se::Event event); + void SetDefinedOn(se::Stream* stream, std::shared_ptr<se::Event> event); // Assert that the tensor's content is defined on 'stream'. This version does // not provide an event, and must be called *after* SetDefinedOn(Stream, @@ -116,7 +118,7 @@ class XlaTensor { // An optional event that is triggered when the tensor's content has been // defined. If this event is nullptr, it is assumed that the tensor's content // is always defined. - gtl::optional<se::Event> definition_event_; + std::shared_ptr<se::Event> definition_event_; // A list of all streams for which the tensor's content is defined for any // newly enqueued command. gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_); diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index f42fb92359..1bf8948ef6 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -31,7 +31,6 @@ std::vector<tensorflow::Flag>* flag_objects; std::once_flag flags_init; void SetDebugOptionsDefaults(DebugOptions* flags) { - flags->set_xla_enable_fast_math(true); flags->set_xla_llvm_enable_alias_scope_metadata(true); flags->set_xla_llvm_enable_noalias_metadata(true); flags->set_xla_llvm_enable_invariant_load_metadata(true); @@ -53,6 +52,11 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { // the heuristics needed to decide when to run on multiple streams. See // b/77879207. flags->set_xla_gpu_disable_multi_streaming(true); + + // TODO(jlebar): Disable fastmath once doing so is not a performance + // regression. 
+ flags->set_xla_cpu_enable_fast_math(true); + flags->set_xla_gpu_enable_fast_math(true); } // Allocates flag_values and flag_objects; this function must not be called more @@ -150,10 +154,16 @@ void AllocateFlags() { flag_values->mutable_xla_generate_hlo_text_to(), "Dump all HLO modules as text into the provided directory path."), tensorflow::Flag( - "xla_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_enable_fast_math), - flag_values->xla_enable_fast_math(), - "Enable unsafe fast-math optimizations in the compiler; " + "xla_cpu_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math), + flag_values->xla_cpu_enable_fast_math(), + "Enable unsafe fast-math optimizations in the CPU compiler; " + "this may produce faster code at the expense of some accuracy."), + tensorflow::Flag( + "xla_gpu_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_math), + flag_values->xla_gpu_enable_fast_math(), + "Enable unsafe fast-math optimizations in the GPU compiler; " + "this may produce faster code at the expense of some accuracy."), tensorflow::Flag( "xla_llvm_enable_alias_scope_metadata", diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 7d315fa0d3..7331d2b54c 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1234,6 +1234,20 @@ cc_library( ], ) +cc_library( + name = "scatter_expander", + srcs = ["scatter_expander.cc"], + hdrs = ["scatter_expander.h"], + deps = [ + ":hlo", + ":hlo_creation_utils", + ":hlo_pass", + ":while_util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + ], +) + tf_cc_test( name = "batchnorm_expander_test", size = "small", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 37834e1cc2..f7812d9661 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1705,6 +1705,10 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { reshape, HloInstruction::CreateReshape(reshape->shape(), operand->mutable_operand(0))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { + *operand->mutable_shape() = reshape->shape(); + return ReplaceInstruction(reshape, operand); + } if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) { auto opt_dims = ReshapeLeavesDimensionsUnmodified( @@ -2144,6 +2148,11 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { transpose->dimensions()))); } + if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) { + *operand->mutable_shape() = transpose->shape(); + return ReplaceInstruction(transpose, operand); + } + if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) { ReplaceWithBitcast(transpose); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 862cbeeba6..5837391d75 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1428,6 +1428,37 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); } +// Test transforming reshapes and transposes of rng.
+TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) { + HloComputation::Builder builder(TestName()); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f))); + HloInstruction* one = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f))); + HloInstruction* rng0 = builder.AddInstruction( + HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {2, 2}), + RandomDistribution::RNG_UNIFORM, {zero, one})); + + HloInstruction* transpose = builder.AddInstruction( + HloInstruction::CreateTranspose(rng0->shape(), rng0, {1, 0})); + Shape reshape_shape = builder + .AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(F32, {4}), transpose)) + ->shape(); + + auto computation = module().AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + // Verify that that reshape(transpose(rng)) is replace by a single rng of the + // same shape as the reshape. + EXPECT_THAT(computation->root_instruction(), op::Rng()); + EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(), + reshape_shape)); +} + // Test transforming reshapes to bitcasts under various conditions. TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { HloComputation::Builder builder(TestName()); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 118a11c8de..cfd26fc778 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -139,6 +139,7 @@ Status GatherComputationsByAllocationType( case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kFusion: // Map/reduce etc computations are always thread-local. diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index a23427f00c..985ff30e80 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -61,6 +61,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kFusion: return CallContext::kParallel; diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 36fb9b43aa..3e39c1bab1 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -312,7 +312,7 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, return Status::OK(); } -// We add copies for all the indices of the true and false computaiton roots, +// We add copies for all the indices of the true and false computation roots, // in order to resolve interference. We later rely on the CopyRemover to drop // the unnecessary ones. Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, @@ -648,7 +648,12 @@ class CopyRemover { // We can only perform copy elision if the resulting merged values have // totally ordered live ranges; otherwise the merged buffer would have // live range interference. 
- if (IsHead(*dest)) { + if (src->next == dest) { + // In the process of eliding copies, its possible for a copy to have the + // same source and destination buffer. In this case, the copy can be + // safely removed. + VLOG(2) << copy->name() << " source and destination buffers are same."; + } else if (IsHead(*dest)) { // The copy copies an arbitrary value in the source buffer (call it s_x) // and defines d_0, the first value in the destination buffer. After // merging, the values in the combined buffer must be strictly ordered diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index cd735256b8..892d0d7b54 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -2007,5 +2007,46 @@ ENTRY TestComputation { InsertCopies(module.get()); } +TEST_F(CopyInsertionTest, NestedWhiles) { + // Verify that only no unnecessary copies remain after copy insertion for + // trivial nested whiles (b/112472605). + const string& hlo_string = R"( +HloModule TestModule + +cond.inner { + ROOT param.cond.inner = pred[] parameter(0) +} + +body.inner { + param.body.inner = pred[] parameter(0) + ROOT neg = pred[] negate(param.body.inner) +} + +cond.outer { + ROOT param.cond.outer = pred[] parameter(0) +} + +body.outer { + param.cond.outer = pred[] parameter(0) + ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner +} + +ENTRY TestComputation { + entry_param = pred[] parameter(0) + ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<HloModule> module, + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest())); + InsertCopies(module.get()); + + // There should only be a single copy inserted, and it's in the entry + // computation. + EXPECT_EQ(CountCopies(*module), 1); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::While(op::Copy(op::Parameter()))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 3efe3e2f93..84779c60b0 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -20,7 +20,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") load( "//third_party/mkl:build_defs.bzl", - "if_mkl", + "mkl_deps", ) # Filegroup used to collect source files for dependency checking. 
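For context on the ThenDoHostCallback changes in xla_device_context.cc above: calling done() inline from a stream host callback can deadlock, because done() may be the closure that finishes an Executor run, and that closure may call XlaDevice::Sync(), which waits for all stream activity, including the very callback that is still executing. Handing done() to a thread pool lets the host callback return first. The snippet below is a simplified standalone illustration of the two variants; ThreadPool, ThenDoHostCallback, and StatusCallback here are stand-ins, not the TensorFlow APIs.

#include <functional>
#include <thread>

using StatusCallback = std::function<void()>;

// Stand-in that runs a host callback inline, the way a stream runs callbacks
// on its callback thread.
void ThenDoHostCallback(const std::function<void()>& cb) { cb(); }

// Stand-in for tensorflow::thread::ThreadPool.
struct ThreadPool {
  void Schedule(std::function<void()> fn) { std::thread(std::move(fn)).detach(); }
};

// Unsafe: if done() ends an Executor run and that triggers XlaDevice::Sync(),
// Sync() waits for this callback to finish while the callback waits on Sync().
void CompleteTransferUnsafe(const StatusCallback& done) {
  ThenDoHostCallback([done] { done(); });
}

// Safe: the host callback only enqueues work and returns, so the stream can
// drain and a concurrent Sync() can complete.
void CompleteTransferSafe(ThreadPool* pool, const StatusCallback& done) {
  ThenDoHostCallback([pool, done] { pool->Schedule([done] { done(); }); });
}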
@@ -86,6 +86,7 @@ cc_library( ":parallel_task_assignment", ":simple_orc_jit", "//tensorflow/compiler/tf2xla:cpu_function_runtime", + "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", @@ -497,10 +498,7 @@ cc_library( "//tensorflow/core:framework_lite", "//tensorflow/core/kernels:eigen_helpers", "//third_party/eigen3", - ] + if_mkl([ - "@mkl_dnn", - "//third_party/mkl:intel_binary_blob", - ]), + ] + mkl_deps(), ) cc_library( @@ -554,10 +552,7 @@ cc_library( "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", "//third_party/eigen3", - ] + if_mkl([ - "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ]), + ] + mkl_deps(), ) cc_library( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 2df959c4dc..35154af048 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -88,6 +88,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -299,6 +300,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false); pipeline.AddPass<CpuInstructionFusion>(); + pipeline.AddPass<ScatterExpander>(); + ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -356,7 +359,7 @@ llvm::TargetOptions CompilerTargetOptions( llvm::TargetOptions target_options; llvm_ir::SetTargetOptions( /*fast_math_enabled=*/module_config.debug_options() - .xla_enable_fast_math(), + .xla_cpu_enable_fast_math(), &target_options); return target_options; } @@ -523,7 +526,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_enable_fast_math(), + module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), pre_optimization_ir_hook, post_optimization_ir_hook); llvm_module->setDataLayout(jit->data_layout()); @@ -653,9 +656,9 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, // so we bail if the configs have conflicting flags. At the moment, the only // flag that needs to be consistent is fast-math. 
const bool fast_math_enabled = - modules[0]->config().debug_options().xla_enable_fast_math(); + modules[0]->config().debug_options().xla_cpu_enable_fast_math(); for (const auto& module : modules) { - if (module->config().debug_options().xla_enable_fast_math() != + if (module->config().debug_options().xla_cpu_enable_fast_math() != fast_math_enabled) { return InvalidArgument( "All HLO module configs must have the same value for " @@ -832,7 +835,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules, CompilerFunctor compiler_functor( target_machine.get(), &disassembler, opt_level, options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_enable_fast_math(), + module->config().debug_options().xla_cpu_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); std::unique_ptr<llvm::MemoryBuffer> object_file = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 946f5124b8..c376864c3e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -249,24 +249,11 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, HloExecutionProfile* hlo_execution_profile) { - if (GetRootPointsToSet().IsAmbiguous()) { - return Unimplemented("Points-to set of root instruction is ambiguous"); - } - - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - - std::vector<OwningDeviceMemory> owning_buffers; - std::vector<se::DeviceMemoryBase> unowning_buffers; TF_ASSIGN_OR_RETURN( - std::tie(unowning_buffers, owning_buffers), - CreateTempArray(memory_allocator, stream->parent()->device_ordinal(), - arguments)); - - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), unowning_buffers, hlo_execution_profile)); - - return CreateResultShapedBuffer(run_options, &owning_buffers); + auto result, + ExecuteAsyncOnStreamImpl(run_options, arguments, hlo_execution_profile)); + TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone()); + return std::move(result); } StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( @@ -277,6 +264,16 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( "Asynchronous execution on stream with hlo profiling is not yet " "supported on CPU."); } + return ExecuteAsyncOnStreamImpl(run_options, arguments, nullptr); +} + +StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, + HloExecutionProfile* hlo_execution_profile) { + if (GetRootPointsToSet().IsAmbiguous()) { + return Unimplemented("Points-to set of root instruction is ambiguous"); + } auto* host_stream = dynamic_cast<se::host::HostStream*>( run_options->stream()->implementation()); @@ -310,19 +307,20 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream( ServiceExecutableRunOptions run_options; std::vector<se::DeviceMemoryBase> unowning_buffers; std::shared_ptr<std::vector<OwningDeviceMemory>> buffers; + HloExecutionProfile* hlo_execution_profile; void operator()() { // Failing a CHECK here is not great, but I don't see an obvious way to // return a failed Status asynchronously. 
TF_CHECK_OK(executable->ExecuteComputeFunction( - &run_options.run_options(), unowning_buffers, - /*hlo_execution_profile=*/nullptr)); + &run_options.run_options(), unowning_buffers, hlo_execution_profile)); } }; host_stream->EnqueueTask( AsyncRunTask{this, *run_options, std::move(unowning_buffers), std::make_shared<std::vector<OwningDeviceMemory>>( - std::move(owning_buffers))}); + std::move(owning_buffers)), + hlo_execution_profile}); return std::move(result); } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 8af8a5dfec..96e53de57e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -85,6 +85,16 @@ class CpuExecutable : public Executable { const BufferAssignment& buffer_assignment() const { return *assignment_; } private: + // This is for sharing the code between ExecuteOnStream and + // ExecuteAsyncOnStream. + // + // Notice that it's tricky to use correctly, as the profile object (when it + // exists) must out-live the task. + StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamImpl( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments, + HloExecutionProfile* hlo_execution_profile); + // Creates an array suitable for passing as the "temps" argument to the JIT // compiled function pointer. // diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 645888de78..f2ac742b6e 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -1066,7 +1066,7 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( << config.GetCacheKey(); const bool enable_fast_math = - hlo_module_config_.debug_options().xla_enable_fast_math(); + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); @@ -1149,7 +1149,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer(); const bool enable_fast_math = - hlo_module_config_.debug_options().xla_enable_fast_math(); + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(); const bool optimize_for_size = options::OptimizeForSizeRequested(hlo_module_config_); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 09909b62ba..6f433b4f30 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -99,7 +99,7 @@ IrEmitter::IrEmitter( target_machine_features_(*target_machine_features) { b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() - .xla_enable_fast_math())); + .xla_cpu_enable_fast_math())); } StatusOr<llvm::Function*> IrEmitter::EmitComputation( @@ -158,11 +158,11 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; // Create and initialize new IrFunction. 
- compute_function_.reset( - new IrFunction(function_name, linkage, - options::OptimizeForSizeRequested(hlo_module_config_), - hlo_module_config_.debug_options().xla_enable_fast_math(), - module_, &b_, num_dynamic_loop_bounds_)); + compute_function_.reset(new IrFunction( + function_name, linkage, + options::OptimizeForSizeRequested(hlo_module_config_), + hlo_module_config_.debug_options().xla_cpu_enable_fast_math(), module_, + &b_, num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -577,7 +577,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{reduce_window->operand(0)}, - /*supported_types=*/{F32, BF16, S32})); + /*supported_types=*/{F32, BF16, S32, F16})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(reduce_window->window())) { diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc index d938f3a2c4..48e4471499 100644 --- a/tensorflow/compiler/xla/service/despecializer.cc +++ b/tensorflow/compiler/xla/service/despecializer.cc @@ -21,8 +21,33 @@ limitations under the License. namespace xla { +namespace { + +// Pass which strips control dependencies from all instructions in the module. +class ControlDepRemover : public HloPassInterface { + public: + ControlDepRemover() = default; + tensorflow::StringPiece name() const override { + return "control-dep-remover"; + } + + StatusOr<bool> Run(HloModule* module) override { + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + changed = changed || !instruction->control_predecessors().empty(); + TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); + } + } + return changed; + } +}; + +} // namespace + Despecializer::Despecializer() : pipeline_("despecializer") { // TODO(b/70588125): Also deal with window reversal in a fast way. + pipeline_.AddPass<ControlDepRemover>(); pipeline_.AddPass<Defuser>(); pipeline_.AddPass<ImplicitBroadcastRemover>(); pipeline_.AddPass<BFloat16MixedPrecisionRemoval>(); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a3f6e8d989..19575c7905 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,6 +1,7 @@ # Description: # GPU-specific components in XLA service implementation. 
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") licenses(["notice"]) # Apache 2.0 @@ -365,6 +366,7 @@ cc_library( ":gpu_executable", ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", @@ -652,6 +654,7 @@ cc_library( "//tensorflow/compiler/xla/service:llvm_compiler", "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:scatter_expander", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_constant_sinking", @@ -852,3 +855,35 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +cc_library( + name = "buffer_comparator", + srcs = ["buffer_comparator.cc"], + hdrs = ["buffer_comparator.h"], + deps = [ + ":gpu_executable", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +xla_test( + name = "buffer_comparator_test", + srcs = ["buffer_comparator_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":buffer_comparator", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc new file mode 100644 index 0000000000..6a285a6b98 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc @@ -0,0 +1,205 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" + +#include <cmath> +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace xla { +namespace gpu { + +static constexpr float kTolerance = 0.1f; + +static string GetCompHloText(size_t num_elements) { + // Implements the textual format of the comparison routine, as it's more + // readable. 
+ static constexpr char kF16CompHloText[] = R"( +HloModule CompareF16 + +MaxF32 { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %max = f32[] maximum(%lhs, %rhs) +} + +Canonicalize (aparam: f16[SIZE]) -> f32[SIZE] { + %min_constant = f32[] constant(-65505) + %max_constant = f32[] constant(65505) + %large_constant = f32[] constant(1048576) + %min_values = f32[SIZE] broadcast(%min_constant), dimensions={} + %max_values = f32[SIZE] broadcast(%max_constant), dimensions={} + %large_values = f32[SIZE] broadcast(%large_constant), dimensions={} + + %a = f16[SIZE] parameter(0) + %converted = f32[SIZE] convert(%a) + %clamped = f32[SIZE] clamp(%min_values, %converted, %max_values) + + // Since the clamp() above already took care of infs, only NaNs will cause + // is-finite() to return false. + %is_finite = pred[SIZE] is-finite(%clamped) + ROOT %result = f32[SIZE] select(%is_finite, %clamped, %large_values) +} + +ENTRY MaxDifference { + %one_constant = f32[] constant(1.0) + %zero_constant = f32[] constant(0.0) + + %ones = f32[SIZE] broadcast(%one_constant), dimensions={} + + %lhs = f16[SIZE] parameter(0) + %rhs = f16[SIZE] parameter(1) + %lhs_canonical = f32[SIZE] call(%lhs), to_apply=Canonicalize + %rhs_canonical = f32[SIZE] call(%rhs), to_apply=Canonicalize + %sub = f32[SIZE] subtract(%lhs_canonical, %rhs_canonical) + %sub_abs = f32[SIZE] abs(%sub) + %lhs_abs = f32[SIZE] abs(%lhs_canonical) + %rhs_abs = f32[SIZE] abs(%rhs_canonical) + %max = f32[SIZE] maximum(%lhs_abs, %rhs_abs) + %denominator = f32[SIZE] add(%max, %ones) + %error = f32[SIZE] divide(%sub_abs, %denominator) + ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32 +})"; + auto size_string = std::to_string(num_elements); + return tensorflow::str_util::StringReplace( + kF16CompHloText, "SIZE", {size_string.data(), size_string.size()}, true); +} + +StatusOr<F16BufferComparator> F16BufferComparator::Create( + se::DeviceMemory<Eigen::half> ref_buffer, Compiler* compiler, + DeviceMemoryAllocator* allocator, se::Stream* stream) { + auto stream_exec = stream->parent(); + int64 num_elements = ref_buffer.ElementCount(); + + // One may consider using hlo_runner to do all the compilation and execution. + // However, as of the time hlo_runner doesn't support injection for Compiler*, + // Stream*, or even the allocator. We may revisit this in the future if it + // proves to be a maintenance burden. 
+ TF_ASSIGN_OR_RETURN( + auto exec, ([&]() -> StatusOr<std::unique_ptr<Executable>> { + HloModuleConfig config; + DebugOptions debug_options; + debug_options.set_xla_backend_optimization_level(2); + config.set_debug_options(debug_options); + TF_ASSIGN_OR_RETURN( + auto module, ParseHloString(GetCompHloText(num_elements), config)); + TF_ASSIGN_OR_RETURN( + module, + compiler->RunHloPasses(std::move(module), stream_exec, nullptr)); + return compiler->RunBackend(std::move(module), stream_exec, nullptr); + }())); + + TF_ASSIGN_OR_RETURN( + auto shaped_buffer, ([&]() -> StatusOr<ScopedShapedBuffer> { + auto device_ordinal = stream_exec->device_ordinal(); + TF_ASSIGN_OR_RETURN( + auto owning_buffer, + allocator->Allocate(device_ordinal, ref_buffer.size())); + se::DeviceMemory<Eigen::half> buffer( + owning_buffer.AsDeviceMemoryBase()); + stream->ThenMemcpy(&buffer, ref_buffer, ref_buffer.size()); + Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements}); + ScopedShapedBuffer ret(shape, shape, allocator, device_ordinal); + ret.set_buffer(std::move(owning_buffer), {}); + return std::move(ret); + }())); + + return F16BufferComparator(stream, allocator, std::move(exec), + std::move(shaped_buffer)); +} + +StatusOr<bool> F16BufferComparator::CompareEqualImpl( + se::DeviceMemory<Eigen::half> test_buffer) { + if (ref_buffer_.root_buffer().size() != test_buffer.size()) { + return InternalError("Mismatched buffer size: %lld vs %lld", + ref_buffer_.root_buffer().size(), test_buffer.size()); + } + + int64 num_elements = test_buffer.ElementCount(); + + TF_ASSIGN_OR_RETURN( + auto result_buffer, ([&]() -> StatusOr<ScopedShapedBuffer> { + auto stream_exec = stream_->parent(); + Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements}); + auto device_ordinal = stream_exec->device_ordinal(); + ShapedBuffer shaped_test_buffer(shape, shape, stream_exec->platform(), + device_ordinal); + shaped_test_buffer.set_buffer(test_buffer, {}); + ExecutableRunOptions run_options; + run_options.set_device_ordinal(stream_exec->device_ordinal()); + run_options.set_stream(stream_); + run_options.set_allocator(allocator_); + ServiceExecutableRunOptions service_run_options(run_options); + return exec_->ExecuteOnStream( + &service_run_options, {&ref_buffer_, &shaped_test_buffer}, nullptr); + }())); + + float result; + CHECK(result_buffer.root_buffer().size() == sizeof(result)); + stream_->ThenMemcpy(&result, result_buffer.root_buffer(), sizeof(result)); + TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + return result < kTolerance; +} + +StatusOr<bool> F16BufferComparator::CompareEqual( + se::DeviceMemory<Eigen::half> test_buffer) { + TF_ASSIGN_OR_RETURN(auto result, CompareEqualImpl(test_buffer)); + if (result) { + return true; + } + // Host side code that does the same thing, but report some of the + // differences as well. 
+ int64 n = test_buffer.ElementCount(); + std::vector<half> host_ref_buffer(n), host_test_buffer(n); + stream_->ThenMemcpy(host_ref_buffer.data(), ref_buffer_.root_buffer(), + ref_buffer_.root_buffer().size()); + stream_->ThenMemcpy(host_test_buffer.data(), test_buffer, test_buffer.size()); + TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone()); + + const auto canonicalize = [](float a) -> float { + constexpr float kBigNumer = 1048576.; + constexpr float kMaxFp16Value = 65504.; + if (std::isnan(a)) { + return kBigNumer; + } + if (std::isinf(a)) { + if (a < 0) { + return -(kMaxFp16Value + 1); + } + return kMaxFp16Value + 1; + } + return a; + }; + int differences_seen = 0; + for (int64 i = 0; i < n && differences_seen < 10; i++) { + float original_ref = static_cast<float>(host_ref_buffer[i]); + float original_test = static_cast<float>(host_test_buffer[i]); + float ref = canonicalize(original_ref); + float test = canonicalize(original_test); + if (!(std::abs(ref - test) / (std::max(std::abs(ref), std::abs(test)) + 1) < + kTolerance)) { + differences_seen++; + LOG(ERROR) << "Difference at " << i << ": " << original_ref << " vs " + << original_test; + } + } + + return false; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.h b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h new file mode 100644 index 0000000000..bf2ba78cea --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ + +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A fp16 comparator that internally keeps a reference buffer, and compares it +// against other test buffers. +class F16BufferComparator { + public: + F16BufferComparator(const F16BufferComparator&) = delete; + F16BufferComparator(F16BufferComparator&&) = default; + + // Creates a new comparator. It internally allocates a buffer initialized by + // ref_buffer. + static StatusOr<F16BufferComparator> Create( + se::DeviceMemory<Eigen::half> ref_buffer, Compiler* compiler, + DeviceMemoryAllocator* allocator, se::Stream* stream); + + // Returns true if the internally allocated buffer "compares equal" to + // test_buffer. The definition of "equal" is: + // * All NaNs equal. + // * All infs are treated as 65505 or -65505, so that this checker is tolerant + // to fp16 overflows. 
+ // * With NaNs and infs taken care of, a and b compare equal iff: + // abs(a - b) / (max(abs(a), abs(b)) + 1) < tolerance + // + // See the implementation for the tolerance value. + StatusOr<bool> CompareEqual(se::DeviceMemory<Eigen::half> test_buffer); + + private: + F16BufferComparator(se::Stream* stream, DeviceMemoryAllocator* allocator, + std::unique_ptr<Executable> exec, + ScopedShapedBuffer ref_buffer) + : stream_(stream), + allocator_(allocator), + exec_(std::move(exec)), + ref_buffer_(std::move(ref_buffer)) {} + + StatusOr<bool> CompareEqualImpl(se::DeviceMemory<Eigen::half> test_buffer); + + se::Stream* stream_; + DeviceMemoryAllocator* allocator_; + std::unique_ptr<Executable> exec_; + ScopedShapedBuffer ref_buffer_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_ diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc new file mode 100644 index 0000000000..33761d1bd8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" + +#include <limits> +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class BufferComparatorTest : public testing::Test { + protected: + BufferComparatorTest() + : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()), + stream_exec_(backend_->default_stream_executor()), + allocator_(stream_exec_->platform(), {stream_exec_}), + compiler_(Compiler::GetForPlatform(stream_exec_->platform()) + .ConsumeValueOrDie()) {} + + // Take floats only for convenience. Still uses half internally. 
+ bool CompareEqualFloatBuffers(const std::vector<float>& lhs_float, + const std::vector<float>& rhs_float) { + std::vector<half> lhs(lhs_float.begin(), lhs_float.end()); + std::vector<half> rhs(rhs_float.begin(), rhs_float.end()); + se::Stream stream(stream_exec_); + stream.Init(); + + auto owning_lhs_buffer = + allocator_ + .Allocate(stream_exec_->device_ordinal(), lhs.size() * sizeof(half)) + .ConsumeValueOrDie(); + + auto owning_rhs_buffer = + allocator_ + .Allocate(stream_exec_->device_ordinal(), rhs.size() * sizeof(half)) + .ConsumeValueOrDie(); + + auto lhs_buffer = + se::DeviceMemory<Eigen::half>(owning_lhs_buffer.AsDeviceMemoryBase()); + auto rhs_buffer = + se::DeviceMemory<Eigen::half>(owning_rhs_buffer.AsDeviceMemoryBase()); + + stream.ThenMemcpy(&lhs_buffer, lhs.data(), lhs_buffer.size()); + stream.ThenMemcpy(&rhs_buffer, rhs.data(), rhs_buffer.size()); + + TF_CHECK_OK(stream.BlockHostUntilDone()); + + return F16BufferComparator::Create(lhs_buffer, compiler_, &allocator_, + &stream) + .ConsumeValueOrDie() + .CompareEqual(rhs_buffer) + .ConsumeValueOrDie(); + } + + std::unique_ptr<Backend> backend_; + se::StreamExecutor* stream_exec_; + StreamExecutorMemoryAllocator allocator_; + Compiler* compiler_; +}; + +TEST_F(BufferComparatorTest, TestNaNs) { + EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("")})); + // NaN values with different bit patterns should compare equal. + EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("1234")})); + EXPECT_FALSE(CompareEqualFloatBuffers({std::nanf("")}, {1.})); +} + +TEST_F(BufferComparatorTest, TestInfs) { + const auto inf = std::numeric_limits<float>::infinity(); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {std::nanf("")})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf})); + EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504})); + EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504})); + + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20})); + EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20})); +} + +TEST_F(BufferComparatorTest, TestNumbers) { + EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1})); + EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1})); + EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10})); + EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9})); +} + +TEST_F(BufferComparatorTest, TestMultiple) { + EXPECT_TRUE(CompareEqualFloatBuffers({20, 30, 40, 50, 60}, + {20.1, 30.1, 40.1, 50.1, 60.1})); + std::vector<float> lhs(200); + std::vector<float> rhs(200); + for (int i = 0; i < 200; i++) { + EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the same at index " << i; + lhs[i] = 3; + rhs[i] = 5; + EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs)) + << "should be the different at index " << i; + lhs[i] = 0; + rhs[i] = 0; + } +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 7348307ec8..7d93bdfc8b 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -30,7 +30,6 @@ namespace { using se::DeviceMemoryBase; using 
se::dnn::AlgorithmConfig; using se::dnn::AlgorithmDesc; -using tensorflow::gtl::nullopt; using tensorflow::gtl::optional; class ScratchAllocator : public se::ScratchAllocator { @@ -173,7 +172,7 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { // cache misses and doing extra work. Overall, caching doesn't seem worth the // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. -optional<std::tuple<int64, bool, int64>> +StatusOr<std::tuple<int64, bool, int64>> CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, @@ -206,45 +205,25 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( // Allocate space for the input, filter, and output of the convolution. We // use a ScratchAllocator for this instead of calling allocator_ directly so // that our allocations don't leak. - // - // We don't put any data in these buffers, because (in theory, anyway) the - // speed of a conv isn't affected by the data being convolved. ScratchAllocator input_output_allocator(device_ordinal, allocator); - StatusOr<DeviceMemoryBase> maybe_input_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(input_shape)); - StatusOr<DeviceMemoryBase> maybe_filter_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(filter_shape)); - StatusOr<DeviceMemoryBase> maybe_output_buf = - input_output_allocator.AllocateBytes(&stream, - ShapeUtil::ByteSizeOf(output_shape)); - if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() || - !maybe_output_buf.ok()) { - LOG(WARNING) - << "Couldn't allocate space for input/filter/output of convolution " - << instr->ToString() << ". Falling back to default algorithm."; - return nullopt; - } - - DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie(); - DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie(); - DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie(); + TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(input_shape))); + TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(filter_shape))); + TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf, + input_output_allocator.AllocateBytes( + &stream, ShapeUtil::ByteSizeOf(output_shape))); // Although we don't have evidence this matters, zero out the buffers before // autotuning. It's conceivable that using uninitialized memory as the inputs // might affect performance if e.g. the inputs contain denormals, and this is // easy enough. - if (!stream.ThenMemZero(&input_buf, input_buf.size()) - .ThenMemZero(&filter_buf, filter_buf.size()) - .ThenMemZero(&output_buf, output_buf.size()) - .BlockHostUntilDone() - .ok()) { - LOG(WARNING) - << "Couldn't zero out input/filter/output buffer for convolution " - << instr->ToString() << ". 
Falling back to default algorithm."; - return nullopt; - } + TF_RETURN_IF_ERROR(stream.ThenMemZero(&input_buf, input_buf.size()) + .ThenMemZero(&filter_buf, filter_buf.size()) + .ThenMemZero(&output_buf, output_buf.size()) + .BlockHostUntilDone()); const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo( input_shape, output_shape, dnums, stream_exec_); @@ -292,9 +271,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( best_result_bytes_used); } - LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString() - << " failed. Falling back to default algorithm."; - return nullopt; + return InternalError( + "All algorithms tried for convolution %s failed. Falling back to " + "default algorithm.", + instr->ToString().c_str()); } StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction( @@ -305,12 +285,13 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction( const auto& lhs_shape = instr->operand(0)->shape(); const auto& rhs_shape = instr->operand(1)->shape(); const auto& conv_result_shape = instr->shape().tuple_shapes(0); - optional<std::tuple<int64, bool, int64>> alg_scratch_and_tc; + StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc; if (call_target == kCudnnConvForwardCallTarget) { - alg_scratch_and_tc = PickBestAlgorithm( - CudnnConvKind::kForward, /*input_shape=*/lhs_shape, - /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, - instr->window(), instr->convolution_dimension_numbers(), instr); + alg_scratch_and_tc = + PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape, + /*filter_shape=*/rhs_shape, + /*output_shape=*/conv_result_shape, instr->window(), + instr->convolution_dimension_numbers(), instr); } else if (call_target == kCudnnConvBackwardInputCallTarget) { alg_scratch_and_tc = PickBestAlgorithm( CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, @@ -326,7 +307,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction( << instr->ToString(); } - if (!alg_scratch_and_tc.has_value()) { + if (!alg_scratch_and_tc.ok()) { + LOG(ERROR) << alg_scratch_and_tc.status(); return false; } @@ -334,7 +316,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction( bool tensor_ops_enabled; int64 scratch_bytes; - std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc; + std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = + alg_scratch_and_tc.ConsumeValueOrDie(); VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " << NumBytesToString(scratch_bytes) diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index bc5d1ce94a..8b7749628a 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -34,8 +35,9 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { // memory while timing the various convolution algorithms. 
If it's null, // we'll use the default allocator on the StreamExecutor. CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* allocator) - : stream_exec_(stream_exec), allocator_(allocator) {} + DeviceMemoryAllocator* allocator, + Compiler* compiler) + : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {} tensorflow::StringPiece name() const override { return "cudnn-convolution-algorithm-picker"; @@ -46,13 +48,14 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { private: StatusOr<bool> RunOnComputation(HloComputation* computation); StatusOr<bool> RunOnInstruction(HloInstruction* instr); - tensorflow::gtl::optional<std::tuple<int64, bool, int64>> PickBestAlgorithm( + StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); se::StreamExecutor* stream_exec_; // never null DeviceMemoryAllocator* allocator_; // may be null + Compiler* compiler_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 69ba91793d..9b6de115ad 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -210,11 +210,13 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp( return make_sqrt(); } - if (hlo_module_config_.debug_options().xla_enable_fast_math() && - IsFPLiteralWithValue(rhs, -.5)) { + if (IsFPLiteralWithValue(rhs, -.5)) { VLOG(10) << "emitting pow(A, -.5) as 1/sqrt(A): " << op->ToString(); // LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX // rsqrt.approx instruction. + // + // TODO(jlebar): Does this happen with fastmath disabled? If not, should + // we force-enable it? TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt()); return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); } @@ -274,16 +276,18 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2( StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh( PrimitiveType prim_type, llvm::Value* value) const { - // If we don't care much about precision, emit a fast approximation of - // tanh. - if (hlo_module_config_.debug_options().xla_enable_fast_math()) { - // Upcast F16 to F32 if necessary. - llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType(); - llvm::Value* input = b_->CreateFPCast(value, type); - llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); - return b_->CreateFPCast(fast_tanh, value->getType()); - } - return EmitLibdeviceMathCall("__nv_tanh", {value}, {prim_type}, prim_type); + // Emit a fast approximation of tanh instead of calling __nv_tanh. + // __nv_tanh is particularly bad because it contains branches, thus + // preventing LLVM's load-store vectorizer from working its magic across a + // function which contains tanh calls. + // + // This routine isn't numerically precise, but it's good enough for ML. + + // Upcast F16 to F32 if necessary. + llvm::Type* type = prim_type == F16 ? 
b_->getFloatTy() : value->getType(); + llvm::Value* input = b_->CreateFPCast(value, type); + llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); + return b_->CreateFPCast(fast_tanh, value->getType()); } llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 66aeb4efef..6675dbd3f9 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -64,7 +64,7 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, hlo_module_config_(hlo_module_config) { b_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_enable_fast_math())); + .xla_gpu_enable_fast_math())); } Status IrEmitter::DefaultAction(HloInstruction* hlo) { diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc index cf44458a2e..ff4ae1f9ef 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc @@ -180,7 +180,7 @@ std::unique_ptr<llvm::TargetMachine> GetTargetMachine( TargetOptions target_options = InitTargetOptionsFromCodeGenFlags(); llvm_ir::SetTargetOptions( /*fast_math_enabled=*/hlo_module_config.debug_options() - .xla_enable_fast_math(), + .xla_gpu_enable_fast_math(), &target_options); // Enable FMA synthesis. diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 76c9b6ab33..d937123357 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -72,6 +72,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/scatter_expander.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" @@ -130,8 +131,12 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. +// +// It takes a compiler pointer, as passes may compile and execute HLOs on the +// fly for cuDNN verification or other purposes. Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, - DeviceMemoryAllocator* device_allocator) { + DeviceMemoryAllocator* device_allocator, + Compiler* compiler) { { HloPassPipeline pipeline("optimization"); pipeline.AddInvariantChecker<HloVerifier>(); @@ -167,6 +172,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // elimination has to come after that pass. pipeline.AddPass<ZeroSizedHloElimination>(); + pipeline.AddPass<ScatterExpander>(); + pass.AddPass<AlgebraicSimplifier>( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); @@ -245,8 +252,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // the gte(customcall, 0) would probably already be into a fusion node. We // can't simplify across HloComputation boundaries, so in this case we // wouldn't be able to simplify away the new_tuple bits. 
- pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(stream_exec, - device_allocator); + pipeline.AddPass<CudnnConvolutionAlgorithmPicker>( + stream_exec, device_allocator, compiler); // Clean up new_tuple described above. pipeline.AddPass<TupleSimplifier>(); @@ -492,11 +499,15 @@ NVPTXCompiler::NVPTXCompiler() StatusOr<std::unique_ptr<HloModule>> NVPTXCompiler::RunHloPasses( std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec, DeviceMemoryAllocator* device_allocator) { + // We dump the post-optimization HLO in RunBackend so no need to dump it here. + VLOG(2) << "*** HLO Before Optimization"; + XLA_VLOG_LINES(2, module->ToString()); + XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses"); tracing::ScopedActivity activity("HLO Transforms", module->name(), /*is_expensive=*/true); TF_RETURN_IF_ERROR( - OptimizeHloModule(module.get(), stream_exec, device_allocator)); + OptimizeHloModule(module.get(), stream_exec, device_allocator, this)); return std::move(module); } @@ -548,6 +559,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend( // include headers, so no need for us to print them ourselves. XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); XLA_VLOG_LINES(2, buffer_assignment->ToString()); + VLOG(2) << "*** HLO After Optimization"; XLA_VLOG_LINES(2, module->ToString()); const string xla_dump_optimized_hlo_proto_to = module->config().debug_options().xla_dump_optimized_hlo_proto_to(); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 90d2be118d..858992a326 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -174,6 +174,29 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); } +StatusOr<HloInstruction*> MakeMapHlo( + tensorflow::gtl::ArraySlice<HloInstruction*> operands, + HloComputation* map_computation) { + CHECK(!operands.empty()) << "Map Hlo requires at least one operand."; + HloComputation* computation = operands.front()->parent(); + std::vector<const Shape*> operand_shapes; + int64 max_operand_rank = 0; + for (const HloInstruction* operand : operands) { + CHECK_EQ(computation, operand->parent()); + operand_shapes.push_back(&operand->shape()); + max_operand_rank = + std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + } + std::vector<int64> map_dims(max_operand_rank); + std::iota(map_dims.begin(), map_dims.end(), 0); + TF_ASSIGN_OR_RETURN( + Shape map_shape, + ShapeInference::InferMapShape( + operand_shapes, map_computation->ComputeProgramShape(), map_dims)); + return computation->AddInstruction( + HloInstruction::CreateMap(map_shape, operands, map_computation)); +} + StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) { CHECK_GT(n, 0); @@ -251,6 +274,38 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand, return MakeReshapeHlo(output_shape, operand); } +StatusOr<HloInstruction*> InsertDegenerateDims( + HloInstruction* operand, ArraySlice<int64> dims_to_insert) { + CHECK(c_is_sorted(dims_to_insert)); + + const Shape& operand_shape = operand->shape(); + int64 output_shape_rank = + operand_shape.dimensions_size() + dims_to_insert.size(); + for (auto dim_to_insert : dims_to_insert) { + CHECK_LT(dim_to_insert, output_shape_rank); + } + + std::vector<int64> output_shape_dim_bounds; + output_shape_dim_bounds.reserve(output_shape_rank); + int64 
operand_dims_idx = 0; + int64 dims_to_insert_idx = 0; + for (int64 i = 0; i < output_shape_rank; ++i) { + if (dims_to_insert_idx < dims_to_insert.size() && + i == dims_to_insert[dims_to_insert_idx]) { + output_shape_dim_bounds.push_back(1); + ++dims_to_insert_idx; + } else { + output_shape_dim_bounds.push_back( + operand_shape.dimensions(operand_dims_idx)); + ++operand_dims_idx; + } + } + + Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(), + output_shape_dim_bounds); + return MakeReshapeHlo(output_shape, operand); +} + StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand, int64 zeros_to_prepend, int64 zeros_to_append) { diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 49b1402d68..5ff8946fb0 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -102,6 +102,12 @@ StatusOr<HloInstruction*> MakeConcatHlo( StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dim_numbers); +// Creates a Map HLO instruction and adds it to the computation containing the +// operands. All operands must be in the same computation. +StatusOr<HloInstruction*> MakeMapHlo( + tensorflow::gtl::ArraySlice<HloInstruction*> operands, + HloComputation* map_computation); + // ----------------------------------------------------------------------------- // Some other miscellaneous helpers to generate common HLO patterns. All of // these add all the instructions they generate into the computation containing @@ -144,6 +150,16 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims( StatusOr<HloInstruction*> ElideDegenerateDims( HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_elide); +// Inserts (via reshape) a set of degenerate dimensions (dimensions containing +// exactly one element), `dims_to_insert` into `operand`. The dimensions in +// `dims_to_insert` refer to the dimensions in the result, and hence should be +// less than the rank of the result. Also, `dims_to_insert` must be sorted. +// +// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is +// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34]. +StatusOr<HloInstruction*> InsertDegenerateDims( + HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_insert); + // Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the // front and `zeros_to_append` zeros in the back. StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand, diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc index 71b44507cc..8e0d38b6a6 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -143,8 +143,47 @@ TokKind HloLexer::LexToken() { return TokKind::kLparen; case ')': return TokKind::kRparen; - case '/': - return LexComment(); + case '/': { + if (PeekCurrentChar() == '*') { + // This is the start of a /*...*/ delimited comment. Save the current + // location in case the comment is unterminated so the error message + // will point to the beginning of the comment. + const char* comment_start = current_ptr_; + current_ptr_++; + // Advance until '*/' is found. + while (true) { + int current = GetNextChar(); + if (current == '*' && PeekCurrentChar() == '/') { + // End of comment. 
+ current_ptr_++; + break; + } + if (current == kEOF) { + // Unterminated comment. + current_ptr_ = comment_start; + return TokKind::kError; + } + } + // Return no token for the comment. Keep lexing. + continue; + } else if (PeekCurrentChar() == '/') { + // This is the start of a '//' delimited comment. Throw away + // everything until end of line or file. The end-of-line character(s) + // are left unlexed in the buffer which is harmless because these are + // skipped later by the lexer. This approach enables support for + // different end-of-line encodings. + while (true) { + int current = PeekCurrentChar(); + if (current == kEOF || current == '\n' || current == '\r') { + break; + } + current_ptr_++; + } + continue; + } + // A lone '/' is an error. + return TokKind::kError; + } case '"': return LexString(); } @@ -357,16 +396,6 @@ tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { return StringPieceFromPointers(start, end); } -TokKind HloLexer::LexComment() { - auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); - static LazyRE2 comment_pattern = {R"(\/\*.*?\*\/)"}; - if (RE2::Consume(&consumable, *comment_pattern)) { - current_ptr_ = consumable.begin(); - return TokKind::kComment; - } - return TokKind::kError; -} - // Lexes quoted string with escaping characters. If matched, the quoted string // will be unescaped and stored to str_val_. TokKind HloLexer::LexString() { @@ -412,8 +441,6 @@ string TokKindToString(TokKind kind) { return "kRparen"; case TokKind::kArrow: return "kArrow"; - case TokKind::kComment: - return "kComment"; case TokKind::kw_HloModule: return "kw_HloModule"; case TokKind::kw_ENTRY: diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h index ceb674f25e..003ac34ace 100644 --- a/tensorflow/compiler/xla/service/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -105,7 +105,6 @@ class HloLexer { TokKind LexShape(); TokKind LexConstant(); TokKind LexNumberOrPattern(); - TokKind LexComment(); TokKind LexString(); const tensorflow::StringPiece buf_; diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index b57c940238..c577b4359a 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -231,6 +231,7 @@ HLO_MATCHER(Tanh); HLO_MATCHER(Trace); HLO_MATCHER(Transpose); HLO_MATCHER(Tuple); +HLO_MATCHER(TupleSelect); HLO_MATCHER(While); // The special cases below let you check additional information about the diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 84f2d3f5fb..1b256cd00e 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -166,7 +166,7 @@ class HloModuleGroupMetadata { // // Precondition: IsCompanionWhile(instruction) is true. const std::unordered_set<HloInstruction*>& Companions( - HloInstruction* instruction) const { + const HloInstruction* instruction) const { CHECK_EQ(companion_set_index_.count(instruction), 1); return companion_set(companion_set_index_.at(instruction)); } @@ -243,7 +243,7 @@ class HloModuleGroupMetadata { companion_sets_; // Map from each companion while instruction to the index into companion_set_. 
- tensorflow::gtl::FlatMap<HloInstruction*, int64> companion_set_index_; + tensorflow::gtl::FlatMap<const HloInstruction*, int64> companion_set_index_; // Map from computation to the instruction using it (a kWhile, kConditional). tensorflow::gtl::FlatMap<const HloComputation*, TrackedInstruction> diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc index 9fd0ade153..0dc5676148 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -37,24 +38,38 @@ namespace xla { std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors( HloInstruction* instruction) { - std::vector<HloInstruction*> predecessors; - - // Adds to the unique predecessors list and also add companion instructions - // if the given predecessor has those. + std::vector<HloInstruction*> + predecessors; // Use a vector to avoid non-determinism. + tensorflow::gtl::FlatSet<HloInstruction*> unique; + + // Adds to the unique predecessors list; if the predecessors is a companion + // instruction, also add companion instructions; if the predecessors is a + // cross-module all-reduce, also add the all-reduce instructions in the same + // group. auto add_unique_predecessor = [&](HloInstruction* predecessor) { - if (std::find(predecessors.begin(), predecessors.end(), predecessor) != - predecessors.end()) { + if (unique.find(predecessor) != unique.end()) { return; } - if (!metadata_.IsCompanionInstruction(predecessor)) { - predecessors.push_back(predecessor); + if (metadata_.IsCompanionInstruction(predecessor)) { + for (HloInstruction* instr : metadata_.Companions(predecessor)) { + if (unique.insert(instr).second) { + predecessors.push_back(instr); + } + } return; } - for (HloInstruction* companion : metadata_.Companions(predecessor)) { - predecessors.push_back(companion); + if (predecessor->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : + metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) { + if (unique.insert(instr).second) { + predecessors.push_back(instr); + } + } + return; } + unique.insert(predecessor); + predecessors.push_back(predecessor); }; - // If the given instruction is a companion instruction, we need to find the // predecessors of all of its companion instructions. If the instruction is an // all-reduce, we need to find the predecessors of all the peer all-reduce @@ -98,22 +113,37 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors( std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors( HloInstruction* instruction) { - std::vector<HloInstruction*> successors; - - // Adds to the unique successors list and also add companion instructions - // if the given successor has those. + std::vector<HloInstruction*> + successors; // Use a vector to avoid non-determinism. + tensorflow::gtl::FlatSet<HloInstruction*> unique; + + // Adds to the unique successors list; if the successor is a companion + // instruction, also add companion instructions; if the successor is a + // cross-module all-reduce, also add the all-reduce instructions in the same + // group. 
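+  // For example, if a successor is a cross-module all-reduce with
+  // all_reduce_id 7 (a hypothetical id), every all-reduce in group 7 is
+  // appended, and the `unique` set ensures each instruction is recorded at
+  // most once even if it is reachable both directly and through a companion
+  // or all-reduce group.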
auto add_unique_successor = [&](HloInstruction* successor) { - if (std::find(successors.begin(), successors.end(), successor) != - successors.end()) { + if (unique.find(successor) != unique.end()) { return; } - if (!metadata_.IsCompanionInstruction(successor)) { - successors.push_back(successor); + if (metadata_.IsCompanionInstruction(successor)) { + for (HloInstruction* instr : metadata_.Companions(successor)) { + if (unique.insert(instr).second) { + successors.push_back(instr); + } + } return; } - for (HloInstruction* companion : metadata_.Companions(successor)) { - successors.push_back(companion); + if (successor->IsCrossModuleAllReduce()) { + for (HloInstruction* instr : + metadata_.GetAllReduceGroup(*successor->all_reduce_id())) { + if (unique.insert(instr).second) { + successors.push_back(instr); + } + } + return; } + unique.insert(successor); + successors.push_back(successor); }; // If the given instruction is a companion instruction, we need to find the diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 2a8c6ecd92..4b3cd99dc0 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1824,7 +1824,6 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal, break; } case TokKind::kComma: - case TokKind::kComment: // Skip. lexer_.Lex(); break; diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 4cd21841f4..5990a3d478 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1560,6 +1560,81 @@ ENTRY consts { "last"); } +TEST_F(HloParserTest, Comments) { + const string original = R"(/* module description. */ +HloModule comments: + +ENTRY /*comment*/ c1 { + /* blah */ + ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/}) + /* comment */ +} + +/* something else */ + +)"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, MultilineComments) { + const string original = R"(HloModule multiline_comment: +ENTRY c1 { + /* + ROOT foo = f32[1]{0} constant({12345}) + */ + ROOT const1 = f32[1]{0} constant({12345}) +/* +a +b +c +d + +*/ +})"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, UnterminatedComment) { + const string original = R"(HloModule unterminated_comment: +ENTRY c1 { +/* unterminated + ROOT const1 = f32[1]{0} constant({12345}) +})"; + // Verify that the error message points to the beginning of the unterminated + // comment. 
+ ExpectHasSubstr(ParseHloString(original).status().error_message(), + "/* unterminated\n^"); +} + +TEST_F(HloParserTest, SlashSlashComments) { + const string original = R"(HloModule slash_slash_comment: +// Garbage +ENTRY c1 { + // Foo bar + ROOT const1 = f32[1]{0} constant({12345}) // Something else +})"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) { + const string original = + "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo " + "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + +TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) { + const string original = + "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo " + "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}"; + auto module = ParseHloString(original); + TF_ASSERT_OK(module.status()); +} + TEST_F(HloParserTest, MultipleEntries) { const string original = R"(HloModule multiple_entries: ENTRY c1 { diff --git a/tensorflow/compiler/xla/service/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h index 533429608b..4458c251de 100644 --- a/tensorflow/compiler/xla/service/hlo_token.h +++ b/tensorflow/compiler/xla/service/hlo_token.h @@ -44,7 +44,6 @@ enum class TokKind { kRparen, // ( ) kArrow, // -> - kComment, // /*xxx*/ // Keywords kw_HloModule, diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index 9b109022fb..db6b910b32 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -104,7 +104,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface { } // No "synchronize all activity" implemented for this platform at the moment. - bool SynchronizeAllActivity() override { return false; } + bool SynchronizeAllActivity() override { return true; } bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override { return false; } diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc new file mode 100644 index 0000000000..45ca731153 --- /dev/null +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -0,0 +1,350 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/scatter_expander.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_creation_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +using tensorflow::gtl::ArraySlice; + +// Transposes the given scatter_indices such that the index_vector_dim becomes +// the most-minor dimension. +static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast( + HloInstruction* scatter_indices, int64 index_vector_dim) { + const Shape& scatter_indices_shape = scatter_indices->shape(); + + if (scatter_indices_shape.dimensions_size() == index_vector_dim) { + return scatter_indices; + } + + if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) { + return scatter_indices; + } + + std::vector<int64> permutation; + permutation.reserve(scatter_indices_shape.dimensions_size()); + for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != index_vector_dim) { + permutation.push_back(i); + } + } + permutation.push_back(index_vector_dim); + return MakeTransposeHlo(scatter_indices, permutation); +} + +// Canonicalizes the scatter_indices tensor in order to keep them uniform while +// performing the scatter operation. +static StatusOr<HloInstruction*> CanonicalizeScatterIndices( + HloInstruction* scatter_indices, int64 index_vector_dim) { + // Transpose the non-index-vector dimensions to the front. + TF_ASSIGN_OR_RETURN( + HloInstruction * transposed_scatter_indices, + TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); + bool indices_are_scalar = + index_vector_dim == scatter_indices->shape().dimensions_size(); + + // The number of dimensions in scatter_indices that are index dimensions. + const int64 index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1; + + // If there is only one index (i.e. scatter_indices has rank 1 and this + // scatter is really just a dynamic update slice) add a leading degenerate + // dimension for uniformity. Otherwise create a "collapsed" leading dimension + // that subsumes all of the non-index-vector dimensions. + const Shape& shape = transposed_scatter_indices->shape(); + if (shape.dimensions_size() == index_dims_in_scatter_indices) { + return PrependDegenerateDims(transposed_scatter_indices, 1); + } else { + // Collapse all but the dimensions (0 or 1) in scatter_indices containing + // the index vectors. + return CollapseFirstNDims( + transposed_scatter_indices, + shape.dimensions_size() - index_dims_in_scatter_indices); + } +} + +// Permutes the `updates` tensor such that all the scatter dims appear in the +// major dimensions and all the window dimensions appear in the minor +// dimensions. 
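+//
+// For example, assuming `updates` has rank 4 and `update_window_dims` is
+// {1, 3}, the scatter dims are {0, 2} and the permutation computed below is
+// {0, 2, 1, 3}.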
+static StatusOr<HloInstruction*> PermuteScatterAndWindowDims( + HloInstruction* updates, ArraySlice<int64> update_window_dims) { + std::vector<int64> permutation; + const int64 updates_rank = ShapeUtil::Rank(updates->shape()); + permutation.reserve(updates_rank); + + for (int64 i = 0; i < updates_rank; ++i) { + bool is_scatter_dim = !c_binary_search(update_window_dims, i); + if (is_scatter_dim) { + permutation.push_back(i); + } + } + for (auto window_dim : update_window_dims) { + permutation.push_back(window_dim); + } + + return MakeTransposeHlo(updates, permutation); +} + +// Expands or contracts the scatter indices in the updates tensor. +static StatusOr<HloInstruction*> AdjustScatterDims( + const Shape& scatter_indices_shape, HloInstruction* updates, + int64 index_vector_dim) { + int64 num_scatter_dims = scatter_indices_shape.dimensions_size(); + if (index_vector_dim < scatter_indices_shape.dimensions_size()) { + --num_scatter_dims; + } + if (num_scatter_dims == 0) { + // If there are no scatter dims, this must be a dynamic-update-slice kind of + // scatter. In this case, we prepend a degenerate dimension to work + // uniformly in the while loop. + return PrependDegenerateDims(updates, 1); + } + return CollapseFirstNDims(updates, num_scatter_dims); +} + +// Expands an index vector from the scatter_indices tensor into a vector that +// can be used to dynamic-update-slice to perform the scatter update. +static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace( + HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers, + int64 operand_rank) { + HloComputation* computation = index_vector->parent(); + const Shape& index_shape = index_vector->shape(); + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); + + // We extract out individual components from the smaller index and concatenate + // them (interspersing zeros as needed) into the larger index. + std::vector<HloInstruction*> expanded_index_components; + + for (int i = 0; i < operand_rank; i++) { + int64 index_vector_dim_index = + FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i); + if (index_vector_dim_index != + dim_numbers.scatter_dims_to_operand_dims_size()) { + TF_ASSIGN_OR_RETURN( + HloInstruction * component_to_concat, + MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, + /*limit_indices=*/{index_vector_dim_index + 1}, + /*strides=*/{1})); + expanded_index_components.push_back(component_to_concat); + } else { + expanded_index_components.push_back(zero); + } + } + + return MakeConcatHlo(expanded_index_components, /*dimension=*/0); +} + +// Body of the while loop that performs the scatter operation using other HLOs. +static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody( + HloInstruction* scatter, HloInstruction* induction_var, + const std::vector<HloInstruction*>& loop_state) { + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + CHECK_EQ(loop_state.size(), 3); + HloInstruction* operand = loop_state[0]; + HloInstruction* scatter_indices = loop_state[1]; + HloInstruction* updates = loop_state[2]; + + bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1; + CHECK_EQ(has_scalar_indices, + dim_numbers.index_vector_dim() == + scatter->operand(1)->shape().dimensions_size()); + + // Build a vector form of the induction variable of the while loop. 
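+  // (The scalar induction variable is broadcast to a one-element vector so
+  // that it can be used as the start index of the dynamic-slices below; on
+  // iteration i it selects row i of scatter_indices and of updates.)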
+ TF_ASSIGN_OR_RETURN( + HloInstruction * induction_var_as_vector, + MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/{1})); + + // Pick the index to scatter from scatter_indices based on the induction_var + // and transform that to an index into the `operand` space. + HloInstruction* index_vector; + if (has_scalar_indices) { + TF_ASSIGN_OR_RETURN( + index_vector, + MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1})); + } else { + TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_scatter_indices, + PadVectorWithZeros(induction_var_as_vector, + /*zeros_to_prepend=*/0, /*zeros_to_append=*/1)); + int index_vector_size = scatter_indices->shape().dimensions(1); + TF_ASSIGN_OR_RETURN( + HloInstruction * index_vector_2d, + MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices, + {1, index_vector_size})); + TF_ASSIGN_OR_RETURN(index_vector, + ElideDegenerateDims(index_vector_2d, {0})); + } + TF_ASSIGN_OR_RETURN( + HloInstruction * scatter_slice_start, + ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers, + operand->shape().dimensions_size())); + + // Extract the slice to be used to update from `updates` tensor for the + // induction_var corresponding to this iteration of the while loop. + TF_ASSIGN_OR_RETURN( + HloInstruction * index_into_updates, + PadVectorWithZeros( + induction_var_as_vector, /*zeros_to_prepend=*/0, + /*zeros_to_append=*/updates->shape().dimensions_size() - 1)); + std::vector<int64> update_slice_bounds(updates->shape().dimensions().begin(), + updates->shape().dimensions().end()); + update_slice_bounds[0] = 1; + TF_ASSIGN_OR_RETURN( + HloInstruction * update_slice, + MakeDynamicSliceHlo(updates, index_into_updates, update_slice_bounds)); + TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter, + ElideDegenerateDims(update_slice, {0})); + TF_ASSIGN_OR_RETURN( + HloInstruction * update_slice_with_dims_inserted, + InsertDegenerateDims(update_slice_for_scatter, + AsInt64Slice(dim_numbers.inserted_window_dims()))); + + // Extact the slice to update from `operand` tensor. + const Shape& update_slice_shape = update_slice_with_dims_inserted->shape(); + TF_ASSIGN_OR_RETURN( + HloInstruction * operand_slice_to_update, + MakeDynamicSliceHlo(operand, scatter_slice_start, + AsInt64Slice(update_slice_shape.dimensions()))); + + // Compute the new value for the slice to be updated in `operand` tensor by + // combining the existing value and the update value using the update + // computation. + TF_ASSIGN_OR_RETURN( + HloInstruction * updated_operand_slice, + MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted}, + scatter->to_apply())); + + // Write the updated value of the slice into `operand` tensor. + TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand, + MakeDynamicUpdateSliceHlo(operand, updated_operand_slice, + scatter_slice_start)); + + return StatusOr<std::vector<HloInstruction*>>{ + {updated_operand, scatter_indices, updates}}; +} + +// High Level Algorithm. +// +// 1. Canonicalize the scatter_indices tensor such that it has rank 2, where +// each row is an index into the operand. +// 2. Canonicalize the updates tensor such that is has rank `num_window_dims+1` +// and the scatter dim is the most-major dimension. +// 3. Iterate over the set of indices in the canonicalized scatter_indices +// tensor using a while loop, updating the operand for each such index. Each +// iteration of this while loop performs the following: +// a. Pick the index from scatter_indices for this iteration. +// b. 
Transfrom this index into an index into the operand space. +// c. Extract the slice to be used to update from the updates tensor. +// d. Extract the slice to update from the operand tensor. +// e. Compute the new value for the slice to update by combining the slices +// from c. and d. using the update_computation of scatter. +// f. Write the updated value of the slice into the operand tensor. + +StatusOr<HloInstruction*> ScatterExpander::ExpandScatter( + HloInstruction* scatter) { + HloInstruction* operand = scatter->mutable_operand(0); + HloInstruction* scatter_indices = scatter->mutable_operand(1); + HloInstruction* updates = scatter->mutable_operand(2); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + + // If the updates tensor is empty, there is no need to update the operand. We + // can return the operand as is. + if (ShapeUtil::IsZeroElementArray(updates->shape())) { + return operand; + } + + // Compute the trip count for the while loop to be used for scatter. This + // should be the number of indices we should scatter into the operand. + const Shape& scatter_indices_shape = scatter_indices->shape(); + int64 scatter_loop_trip_count = 1; + for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + scatter_loop_trip_count *= scatter_indices_shape.dimensions(i); + } + } + if (!IsInt32(scatter_loop_trip_count)) { + return Unimplemented( + "Scatter operations with more than 2147483647 scatter indices are not " + "supported. This error occurred for %s.", + scatter->ToString().c_str()); + } + + // Canonicalize the scatter_indices, after which the size of its most-major + // dimension must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices, + CanonicalizeScatterIndices( + scatter_indices, dim_numbers.index_vector_dim())); + CHECK_EQ(scatter_loop_trip_count, + canonical_scatter_indices->shape().dimensions(0)); + + // Canonicalize the updates, after which the size of its most-major dimension + // must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN( + HloInstruction * canonical_updates, + PermuteScatterAndWindowDims( + updates, AsInt64Slice(dim_numbers.update_window_dims()))); + TF_ASSIGN_OR_RETURN( + HloInstruction * adjusted_canonical_updates, + AdjustScatterDims(scatter_indices->shape(), canonical_updates, + dim_numbers.index_vector_dim())); + CHECK_EQ(scatter_loop_trip_count, + adjusted_canonical_updates->shape().dimensions(0)); + + // The while loop that implements the scatter operation. 
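+  // Roughly, and ignoring the canonicalization reshapes above, the generated
+  // loop behaves like this sketch:
+  //
+  //   for (int64 i = 0; i < scatter_loop_trip_count; ++i) {
+  //     index  = canonical_scatter_indices[i];
+  //     update = adjusted_canonical_updates[i];
+  //     start  = ExpandIndexVectorIntoOperandSpace(index);
+  //     slice  = dynamic-slice(operand, start, /*sizes=*/update dims);
+  //     operand = dynamic-update-slice(
+  //         operand, map(update_computation, slice, update), start);
+  //   }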
+ StatusOr<std::vector<HloInstruction*>> scatter_loop_result_status = + WhileUtil::MakeCountedLoop( + scatter->parent(), scatter_loop_trip_count, + {operand, canonical_scatter_indices, adjusted_canonical_updates}, + [&](HloInstruction* induction_var, + const std::vector<HloInstruction*>& loop_state) { + return ScatterLoopBody(scatter, induction_var, loop_state); + }); + TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> scatter_loop_result, + scatter_loop_result_status); + return scatter_loop_result.front(); +} + +StatusOr<bool> ScatterExpander::Run(HloModule* module) { + std::vector<HloInstruction*> scatter_instrs; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + for (HloInstruction* instr : computation->instructions()) { + if (instr->opcode() == HloOpcode::kScatter) { + scatter_instrs.push_back(instr); + } + } + } + + for (auto instr : scatter_instrs) { + TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr)); + TF_RETURN_IF_ERROR( + instr->parent()->ReplaceInstruction(instr, expanded_root)); + } + + return !scatter_instrs.empty(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h new file mode 100644 index 0000000000..8f735e877d --- /dev/null +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +class ScatterExpander : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { return "scatter_expander"; } + StatusOr<bool> Run(HloModule* module) override; + + private: + StatusOr<HloInstruction*> ExpandScatter(HloInstruction* scatter); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 34869cc507..b69c346f1e 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -1014,12 +1014,13 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { } /* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { + if (!IsTuple(shape)) { + return 1; + } int64 count = 0; - ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) { - if (IsLeafIndex(shape, index)) { - ++count; - } - }); + for (const Shape& subshape : shape.tuple_shapes()) { + count += GetLeafCount(subshape); + } return count; } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 42d52aee78..0f8cffd466 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -709,6 +709,21 @@ xla_test( ], ) +xla_test( + name = "scatter_test", + srcs = ["scatter_test.cc"], + deps = [ + ":client_library_test_base", + ":hlo_test_base", + "//tensorflow/compiler/xla:execution_options_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + # Repeat dot_operation_runtime_test with single-threaded eigen. xla_test( name = "dot_operation_single_threaded_runtime_test", @@ -2061,6 +2076,8 @@ tf_cc_test( xla_test( name = "test_utils_test", srcs = ["test_utils_test.cc"], + # There is nothing backend specific in this test, so just pick an arbitrary backend. 
+ backends = ["cpu"], deps = [ ":local_client_test_base", ":test_utils", diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 4a6e8a3124..b04a3b105c 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -74,8 +74,9 @@ class ClientLibraryTestBase : public ::testing::Test { string TestName() const; void SetFastMathDisabled(bool disabled) { - execution_options_.mutable_debug_options()->set_xla_enable_fast_math( - !disabled); + auto* opts = execution_options_.mutable_debug_options(); + opts->set_xla_cpu_enable_fast_math(!disabled); + opts->set_xla_gpu_enable_fast_math(!disabled); } void SetSeed(uint64 seed) { execution_options_.set_seed(seed); } diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 73edad89dc..92c93f08b2 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -1464,5 +1464,24 @@ ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] { EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); } +XLA_TEST_F(HloTestBase, ReduceWindowF16) { + const string hlo_string = R"( +HloModule reduce-window + +%identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] { + %param0 = f16[] parameter(0) + ROOT %param1 = f16[] parameter(1) +} + +ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] { + %parameter.0 = f16[81,8]{1,0} parameter(0) + %parameter.1 = f16[] parameter(1) + ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window +} + +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc new file mode 100644 index 0000000000..922d70b752 --- /dev/null +++ b/tensorflow/compiler/xla/tests/scatter_test.cc @@ -0,0 +1,615 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class ScatterTest : public HloTestBase { + protected: + void RunTest(const string& hlo_text, Literal* operand, + Literal* scatter_indices, Literal* updates) { + RunTest(hlo_text, {operand, scatter_indices, updates}); + } + + void RunTest(const string& hlo_text, + tensorflow::gtl::ArraySlice<Literal*> args) { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseHloString(hlo_text, config)); + EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); + } +}; + +XLA_TEST_F(ScatterTest, TensorFlowScatterV1_Update) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> scatter_indices = + LiteralUtil::CreateR1<int32>({0, 2}); + std::unique_ptr<Literal> updates = + LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) { + const char* hlo_text = R"( +HloModule TensorFlowScatterV2 + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[3,2] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=1 +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> scatter_indices = + LiteralUtil::CreateR1<int32>({0, 2}); + std::unique_ptr<Literal> updates = + LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) { + const string hlo_text = R"( +HloModule TensorFlowScatter_Add + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, 
{7, 8, 9}}); + std::unique_ptr<Literal> scatter_indices = + LiteralUtil::CreateR1<int32>({0, 2}); + std::unique_ptr<Literal> updates = + LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) { + const string hlo_text = R"( +HloModule TensorFlowScatter_Mul + +mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=mul_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> scatter_indices = + LiteralUtil::CreateR1<int32>({0, 2}); + std::unique_ptr<Literal> updates = + LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) { + const string hlo_text = R"( +HloModule TensorFlowScatter_F32 + +add_f32 (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(f32[] lhs, f32[] rhs) +} + +ENTRY main { + operand = f32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = f32[2,3] parameter(2) + ROOT scatter = f32[3,3] scatter(operand, indices, updates), + to_apply=add_f32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>( + {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}}); + std::unique_ptr<Literal> scatter_indices = + LiteralUtil::CreateR1<int32>({2, 1}); + std::unique_ptr<Literal> updates = + LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) { + const char* hlo_text = R"( +HloModule TensorFlowScatter + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> scatter_indices = + LiteralUtil::CreateR1<int32>({1, 1}); + std::unique_ptr<Literal> updates = + LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) { + const char* hlo_text = R"( +HloModule TensorFlowScatterMultipleBatchDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,3,2] 
parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={1}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=2 +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> scatter_indices = + LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}}); + std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>( + {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterNd) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNd + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr<Literal> scatter_indices = + LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); + std::unique_ptr<Literal> updates = + LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) { + const char* hlo_text = R"( +HloModule TensorFlowScatterNdNonDefaultIndexVectorDim + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3,2] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3,3,2] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0,1}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, // + {{-4, 4}, {-5, 5}, {-6, 6}}, // + {{-7, 7}, {-8, 8}, {-9, 9}}}); + std::unique_ptr<Literal> scatter_indices = + LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}}); + std::unique_ptr<Literal> updates = + LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, DynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule DynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={0,1}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> scatter_indices = + LiteralUtil::CreateR1<int32>({1, 1}); + std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, 
BatchDynamicUpdateSlice) { + const char* hlo_text = R"( +HloModule BatchDynamicUpdateSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3] parameter(0) + indices = s32[2,2] parameter(1) + updates = s32[2,1,1] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=0 +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> scatter_indices = + LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}}); + std::unique_ptr<Literal> updates = + LiteralUtil::CreateR3<int32>({{{10}}, {{20}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, ZeroDimBounds) { + const char* hlo_text = R"( +HloModule TensorFlowScatter_ZeroDimBounds + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,0] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,0] parameter(2) + ROOT scatter = s32[3,0] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}}); + std::unique_ptr<Literal> scatter_indices = + LiteralUtil::CreateR1<int32>({0, 2}); + std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, NoUpdateWindowDims) { + const string hlo_text = R"( +HloModule Scatter_NoUpdateWindowDims + +add_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT add = s32[] add(s32[] lhs, s32[] rhs) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[2,2,1] parameter(1) + updates = s32[2,2] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=add_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=2 +} +)"; + std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2}); + std::unique_ptr<Literal> scatter_indices = + LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}}); + std::unique_ptr<Literal> updates = + LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, OutOfBoundsIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}}); + std::unique_ptr<Literal> updates = 
LiteralUtil::CreateR3<int32>( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = u32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<uint32>( + {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}}); + std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, NegativeIndex) { + const string hlo_text = R"( +HloModule BatchDynamicSlice + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3,3]{1,0} parameter(0) + indices = s32[6,2]{1,0} parameter(1) + updates = s32[6,1,1]{2,1,0} parameter(2) + ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, + index_vector_dim=1 +} +)"; + std::unique_ptr<Literal> operand = + LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); + std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>( + {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}}); + std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>( + {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, OneScalarIndex) { + const char* hlo_text = R"( +HloModule OneScalarIndex + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[2,3,2]{2,1,0} parameter(0) + index = s32[] parameter(1) + updates = s32[1,3,2]{2,1,0} parameter(2) + ROOT scatter = s32[2,3,2]{2,1,0} scatter(operand, index, updates), + to_apply=update_s32, + update_window_dims={0,1,2}, + inserted_window_dims={}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 +} +)"; + std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>( + {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); + std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1); + std::unique_ptr<Literal> updates = + LiteralUtil::CreateR3<int32>({{{10, 20}, {30, 40}, {50, 60}}}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, ScalarUpdate) { + const char* hlo_text = R"( +HloModule ScalarUpdate + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[4]{0} parameter(0) + index = s32[] parameter(1) + updates = s32[] parameter(2) + ROOT scatter = s32[4]{0} scatter(operand, index, updates), + to_apply=update_s32, + update_window_dims={}, + 
inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=0 +} +)"; + std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4}); + std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1); + std::unique_ptr<Literal> updates = LiteralUtil::CreateR0<int32>(25); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +XLA_TEST_F(ScatterTest, EmptyIndices) { + const string hlo_text = R"( +HloModule EmptyIndices + +update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + ROOT rhs = s32[] parameter(1) +} + +ENTRY main { + operand = s32[3] parameter(0) + indices = s32[0] parameter(1) + updates = s32[0] parameter(2) + ROOT scatter = s32[3] scatter(operand, indices, updates), + to_apply=update_s32, + update_window_dims={}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 +} +)"; + std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3}); + std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR1<int32>({}); + std::unique_ptr<Literal> updates = LiteralUtil::CreateR1<int32>({}); + RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 2647937013..faeec657b6 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -208,16 +208,12 @@ bool NeedsInitValue(const HloUse& use) { // Generate random values that are constrained to the input_shape minus the // output_shape so as not to produce wrapping slices, for instance. -std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex( - const Shape& input_shape, const Shape& slice_shape, - std::minstd_rand0* engine) { - const int64 rank = ShapeUtil::Rank(input_shape); - std::vector<int32> start_indices(rank); +std::unique_ptr<Literal> MakeRandomIndex( + tensorflow::gtl::ArraySlice<int64> index_space, std::minstd_rand0* engine) { + std::vector<int32> start_indices(index_space.size()); if (engine != nullptr) { - for (int i = 0; i < rank; ++i) { - const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); - std::uniform_int_distribution<int32> generator(0, upper_bound); + for (int i = 0; i < index_space.size(); ++i) { + std::uniform_int_distribution<int32> generator(0, index_space[i]); start_indices[i] = generator(*engine); } } @@ -267,37 +263,42 @@ std::vector<HloInstruction*> FindConstrainedUses( StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses, const HloInstruction& param, std::minstd_rand0* engine) { - HloInstruction* needs_index = nullptr; - HloInstruction* needs_constant = nullptr; + std::vector<int64> index_space; + bool needs_constant = false; ConstantType constant_type = ConstantType::kUnknown; for (HloInstruction* use : constrained_uses) { switch (use->opcode()) { case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - if (needs_index != nullptr) { - auto needs_index_shape = needs_index->shape(); - auto use_shape = use->shape(); - if (needs_index->opcode() == HloOpcode::kDynamicSlice) { - needs_index_shape = needs_index->operand(0)->shape(); - } - if (use->opcode() == HloOpcode::kDynamicSlice) { - use_shape = use->operand(0)->shape(); + case HloOpcode::kDynamicUpdateSlice: { + const Shape& indexed_shape = 
use->operand(0)->shape(); + const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice + ? use->shape() + : use->operand(1)->shape(); + const int64 rank = ShapeUtil::Rank(indexed_shape); + if (!index_space.empty()) { + TF_RET_CHECK(rank == index_space.size()); + for (int64 i = 0; i < rank; ++i) { + index_space[i] = std::min( + index_space[i], ShapeUtil::GetDimension(indexed_shape, i) - + ShapeUtil::GetDimension(slice_shape, i)); } - if (!ShapeUtil::Equal(needs_index_shape, use_shape)) { - return Unimplemented( - "Conflicting operand generation slice index constraints\n"); + } else { + index_space.resize(rank); + for (int64 i = 0; i < rank; ++i) { + index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); } } - needs_index = use; break; + } case HloOpcode::kReduce: case HloOpcode::kReduceWindow: - needs_constant = use; + needs_constant = true; constant_type = GetInitValue(*use->to_apply()); break; case HloOpcode::kSelectAndScatter: - needs_constant = use; + needs_constant = true; constant_type = GetInitValue(*use->scatter()); break; @@ -307,16 +308,14 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses( use->ToString().c_str()); } } - if (needs_index != nullptr && needs_constant != nullptr) { + if (!index_space.empty() && needs_constant) { return Unimplemented( - "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " - "constant: %s\n", - needs_index->ToString().c_str(), needs_constant->ToString().c_str()); + "Conflicting operand generation constraints. Dynamically indexes a " + "shape and is the init value of a reduction."); } - if (needs_index != nullptr) { - return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), - needs_index->shape(), engine); - } else if (needs_constant != nullptr) { + if (!index_space.empty()) { + return MakeRandomIndex(index_space, engine); + } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique(); @@ -356,8 +355,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments( auto engine = pseudo_random ? 
MakeUnique<std::minstd_rand0>() : nullptr; std::vector<std::unique_ptr<Literal>> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument( - *dataflow, *params[i], engine.get())); + arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get()) + .ValueOrDie(); } return std::move(arguments); } diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc index a2f0338e25..64d9e2031e 100644 --- a/tensorflow/compiler/xla/tests/test_utils_test.cc +++ b/tensorflow/compiler/xla/tests/test_utils_test.cc @@ -72,5 +72,60 @@ XLA_TEST_F(TestUtilsTest, Token) { TF_ASSERT_OK(MakeFakeArguments(module.get()).status()); } +XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) { + auto module = ParseHloString( + R"(HloModule index_space_module + + ENTRY IndexSpace { + index_param = s32[3]{0} parameter(0) + array_param.1 = f32[123,4,789]{0,1,2} parameter(1) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) + dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3} + ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2} + })") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 3); + const Literal& index_arg = *args[0]; + + EXPECT_EQ(index_arg.Get<int32>({0}), 0); + + EXPECT_GE(index_arg.Get<int32>({1}), 0); + EXPECT_LE(index_arg.Get<int32>({1}), 2); + + EXPECT_GE(index_arg.Get<int32>({2}), 0); + EXPECT_LE(index_arg.Get<int32>({2}), 3); +} + +XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) { + auto module = ParseHloString( + R"(HloModule index_space_module + + ENTRY IndexSpace { + index_param = s32[3]{0} parameter(0) + array_param.1 = f32[123,4,789]{0,1,2} parameter(1) + array_param.2 = f32[3,3000,5]{0,1,2} parameter(2) + update_param.1 = f32[1,2,3]{0,1,2} parameter(3) + update_param.2 = f32[3,2,2]{0,1,2} parameter(4) + + dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param) + ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param) + })") + .ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args, + MakeFakeArguments(module.get())); + ASSERT_EQ(args.size(), 5); + const Literal& index_arg = *args[0]; + + EXPECT_EQ(index_arg.Get<int32>({0}), 0); + + EXPECT_GE(index_arg.Get<int32>({1}), 0); + EXPECT_LE(index_arg.Get<int32>({1}), 2); + + EXPECT_GE(index_arg.Get<int32>({2}), 0); + EXPECT_LE(index_arg.Get<int32>({2}), 3); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 10c0adc670..3b72eb17c6 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -104,15 +104,6 @@ message DebugOptions { // interpretation of this value is left to the backends. int32 xla_backend_optimization_level = 31; - // When true, "unsafe" mathematical optimizations are enabled. These - // transformations include but are not limited to: - // - // - Reducing the precision of operations (e.g. using an approximate sin - // function, or transforming x/y into x * (1/y)). - // - Assuming that operations never produce or consume NaN or +/- Inf. - // - Assuming that +0 and -0 are indistinguishable. 
- bool xla_enable_fast_math = 32; - // Embed the compiler IR as a string in the executable. bool xla_embed_ir_in_executable = 33; @@ -194,6 +185,16 @@ message DebugOptions { // Maximum kernel unroll factor for the GPU backend. int32 xla_gpu_max_kernel_unroll_factor = 98; + // When true, "unsafe" mathematical optimizations are enabled. These + // transformations include but are not limited to: + // + // - Reducing the precision of operations (e.g. using an approximate sin + // function, or transforming x/y into x * (1/y)). + // - Assuming that operations never produce or consume NaN or +/- Inf. + // - Assuming that +0 and -0 are indistinguishable. + bool xla_cpu_enable_fast_math = 99; + bool xla_gpu_enable_fast_math = 100; + // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. map<string, string> xla_backend_extra_options = 500; |
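The replacement fields split fast math into per-backend switches, so code that previously flipped the single removed flag now sets the CPU and GPU variants independently on DebugOptions. A minimal sketch of how that looks from C++, assuming the standard proto3-generated setters for the two new fields (the helper name MakeFastMathDebugOptions is invented here purely for illustration):

#include "tensorflow/compiler/xla/xla.pb.h"

// Enable unsafe math optimizations for the CPU backend only; the GPU flag is
// explicitly kept false. The setters are the ones protoc generates for the
// xla_cpu_enable_fast_math (99) and xla_gpu_enable_fast_math (100) fields.
xla::DebugOptions MakeFastMathDebugOptions() {
  xla::DebugOptions opts;
  opts.set_xla_cpu_enable_fast_math(true);
  opts.set_xla_gpu_enable_fast_math(false);
  return opts;
}

Any caller of the removed field would migrate to whichever of the two setters matches the backend it actually targets.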