Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--  tensorflow/compiler/jit/BUILD | 7
-rw-r--r--  tensorflow/compiler/jit/jit_compilation_pass_registration.cc | 8
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_launch_op.cc | 3
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass.cc | 102
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass.h | 8
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 70
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc | 40
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h | 35
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass.cc | 177
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass.h | 58
-rw-r--r--  tensorflow/compiler/jit/partially_decluster_pass_test.cc | 284
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.cc | 22
-rw-r--r--  tensorflow/compiler/jit/xla_cluster_util.h | 11
-rw-r--r--  tensorflow/compiler/jit/xla_compile_on_demand_op.cc | 3
-rw-r--r--  tensorflow/compiler/jit/xla_device.cc | 41
-rw-r--r--  tensorflow/compiler/jit/xla_device.h | 14
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.cc | 89
-rw-r--r--  tensorflow/compiler/jit/xla_device_context.h | 31
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.cc | 62
-rw-r--r--  tensorflow/compiler/jit/xla_launch_util.h | 6
-rw-r--r--  tensorflow/compiler/jit/xla_tensor.cc | 7
-rw-r--r--  tensorflow/compiler/jit/xla_tensor.h | 6
-rw-r--r--  tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc | 20
-rw-r--r--  tensorflow/compiler/xla/service/BUILD | 14
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 31
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/call_graph.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion_test.cc | 41
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 13
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 13
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_executable.cc | 38
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_executable.h | 10
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/despecializer.cc | 25
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD | 35
-rw-r--r--  tensorflow/compiler/xla/service/gpu/buffer_comparator.cc | 205
-rw-r--r--  tensorflow/compiler/xla/service/gpu/buffer_comparator.h | 71
-rw-r--r--  tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc | 126
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc | 73
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h | 9
-rw-r--r--  tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc | 28
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc | 20
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.cc | 55
-rw-r--r--  tensorflow/compiler/xla/service/hlo_creation_utils.h | 16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_lexer.cc | 55
-rw-r--r--  tensorflow/compiler/xla/service/hlo_lexer.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_matchers.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_metadata.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module_group_util.cc | 72
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc | 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc | 75
-rw-r--r--  tensorflow/compiler/xla/service/hlo_token.h | 1
-rw-r--r--  tensorflow/compiler/xla/service/interpreter/executor.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/scatter_expander.cc | 350
-rw-r--r--  tensorflow/compiler/xla/service/scatter_expander.h | 34
-rw-r--r--  tensorflow/compiler/xla/shape_util.cc | 11
-rw-r--r--  tensorflow/compiler/xla/tests/BUILD | 17
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.h | 5
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_window_test.cc | 19
-rw-r--r--  tensorflow/compiler/xla/tests/scatter_test.cc | 615
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils.cc | 71
-rw-r--r--  tensorflow/compiler/xla/tests/test_utils_test.cc | 55
-rw-r--r--  tensorflow/compiler/xla/xla.proto | 19
68 files changed, 2977 insertions, 400 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 55b98da472..e059f77563 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -314,12 +314,16 @@ cc_library(
"deadness_analysis_internal.h",
"encapsulate_subgraphs_pass.cc",
"mark_for_compilation_pass.cc",
+ "mark_for_compilation_pass_test_helper.cc",
+ "partially_decluster_pass.cc",
],
hdrs = [
"build_xla_launch_ops_pass.h",
"deadness_analysis.h",
"encapsulate_subgraphs_pass.h",
"mark_for_compilation_pass.h",
+ "mark_for_compilation_pass_test_helper.h",
+ "partially_decluster_pass.h",
],
deps = [
":common",
@@ -354,6 +358,7 @@ cc_library(
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
+ "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
],
@@ -418,10 +423,12 @@ tf_cc_test(
srcs = [
"encapsulate_subgraphs_pass_test.cc",
"mark_for_compilation_pass_test.cc",
+ "partially_decluster_pass_test.cc",
],
deps = [
":common",
":compilation_passes",
+ ":xla_cluster_util",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
index 4d49a14b24..c37b6112cc 100644
--- a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
+#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
namespace tensorflow {
@@ -23,15 +24,18 @@ namespace tensorflow {
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
MarkForCompilationPass);
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
+ PartiallyDeclusterPass);
+
// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
// also need to run it after the graph has been rewritten to have _Send nodes added
// for fetches. Before the _Send nodes are added, fetch nodes are identified by
// name, and encapsulation might remove that node from the graph.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
EncapsulateSubgraphsPass);
// Must run after EncapsulateSubgraphsPass.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 40,
BuildXlaLaunchOpsPass);
} // namespace tensorflow
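
The phase numbers passed to REGISTER_OPTIMIZATION order passes within a grouping, which is why EncapsulateSubgraphsPass and BuildXlaLaunchOpsPass move from phases 20 and 30 to 30 and 40: this frees phase 20 for PartiallyDeclusterPass. As a minimal sketch (not part of this change; MyDebugDumpPass is a hypothetical name), a new pass could be slotted between PartiallyDeclusterPass and EncapsulateSubgraphsPass like so:

#include "tensorflow/core/common_runtime/optimization_registry.h"

namespace tensorflow {

// Hypothetical pass, used here only to illustrate phase ordering.
class MyDebugDumpPass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override {
    // Runs after PartiallyDeclusterPass (phase 20) and before
    // EncapsulateSubgraphsPass (phase 30), because passes registered in the
    // same grouping execute in ascending phase order.
    VLOG(1) << "Graph has " << (*options.graph)->num_nodes() << " nodes";
    return Status::OK();
  }
};

REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 25,
                      MyDebugDumpPass);

}  // namespace tensorflow
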
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 37a2f3b5ac..7f4370b5b0 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -210,7 +210,8 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
- launch_context.PopulateOutputs(ctx, kernel, run_result.ConsumeValueOrDie());
+ OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
+ ctx, kernel, run_result.ConsumeValueOrDie()));
VLOG(1) << "Done";
}
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 45d422943c..90d5d56998 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -65,6 +65,7 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) {
// XLA cluster so it can't implement the forward-tensor-ref semantic. Leave
// such nodes out of XLA clusters.
if (HasForwardedRefInput(node)) {
+ VLOG(2) << "Rejecting " << node.name() << ": Identity with unsafe cast.";
return false;
}
@@ -84,14 +85,13 @@ bool IsCompilableCall(const NodeDef& call_def,
bool IsCompilableWhile(const Node& while_node,
const DeviceType& jit_device_type, int depth,
FunctionLibraryRuntime* lib_runtime) {
- VLOG(2) << "Loop marking: " << while_node.type_string();
-
const NameAttrList* name_attr;
NodeDef call;
Status status;
status = GetNodeAttr(while_node.attrs(), "cond", &name_attr);
if (!status.ok()) {
- VLOG(2) << "Missing 'cond' attribute on While node.";
+ VLOG(2) << "Rejecting While " << while_node.name()
+ << ": missing 'cond' attribute on While node.";
return false;
}
const string cond_func = name_attr->name();
@@ -99,12 +99,14 @@ bool IsCompilableWhile(const Node& while_node,
call.set_op(cond_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
- VLOG(2) << "Can't compile loop condition: " << cond_func;
+ VLOG(2) << "Rejecting While " << while_node.name()
+ << ": can't compile loop condition: " << cond_func;
return false;
}
status = GetNodeAttr(while_node.attrs(), "body", &name_attr);
if (!status.ok()) {
- VLOG(2) << "Missing 'body' attribute on While node.";
+ VLOG(2) << "Rejecting While " << while_node.name()
+ << ": missing 'body' attribute on While node.";
return false;
}
const string body_func = name_attr->name();
@@ -112,10 +114,10 @@ bool IsCompilableWhile(const Node& while_node,
call.set_op(body_func);
*call.mutable_attr() = name_attr->attr();
if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
- VLOG(2) << "Can't compile loop body: " << body_func;
+ VLOG(2) << "Rejecting While " << while_node.name()
+ << ": can't compile loop body: " << body_func;
return false;
}
- VLOG(2) << "Loop is compilable.";
return true;
}
@@ -125,10 +127,9 @@ bool IsCompilableWhile(const Node& while_node,
bool IsCompilableCall(const NodeDef& call_def,
const DeviceType& jit_device_type, int depth,
FunctionLibraryRuntime* lib_runtime) {
- VLOG(2) << "Function marking: " << call_def.op();
-
if (depth > kMaxRecursionDepth) {
- VLOG(2) << "Function depth limit exceeded";
+ VLOG(2) << "Rejecting " << call_def.op()
+ << ": function depth limit exceeded.";
return false;
}
@@ -136,7 +137,8 @@ bool IsCompilableCall(const NodeDef& call_def,
Status status =
lib_runtime->Instantiate(call_def.op(), AttrSlice(call_def), &handle);
if (!status.ok()) {
- VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status;
+ VLOG(2) << "Rejecting " << call_def.op()
+ << ": could not instantiate: " << status;
return false;
}
const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
@@ -150,7 +152,8 @@ bool IsCompilableCall(const NodeDef& call_def,
// tf2xla to translate the TF graph into XLA. So we avoid this for now.
//
// TODO(b/36139787): Create a mechanism to set inlining hints.
- VLOG(2) << "Can't compile noinline function: " << fdef.DebugString();
+ VLOG(2) << "Rejecting " << call_def.op()
+ << ": can't compile noinline function.";
return false;
}
@@ -164,23 +167,14 @@ bool IsCompilableCall(const NodeDef& call_def,
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, depth + 1,
lib_runtime)) {
- VLOG(2) << "Function marking failed: unsupported op " << node->name()
- << ": " << node->def().ShortDebugString();
+ VLOG(2) << "Rejecting " << call_def.op() << ": unsupported op "
+ << node->name() << ": " << node->def().ShortDebugString();
return false;
}
}
- VLOG(2) << "Function is compilable: " << call_def.op();
return true;
}
-// Tests whether `node` has a DT_RESOURCE typed input or output.
-bool HasResourceInputOrOutput(const Node& node) {
- return std::find(node.input_types().begin(), node.input_types().end(),
- DT_RESOURCE) != node.input_types().end() ||
- std::find(node.output_types().begin(), node.output_types().end(),
- DT_RESOURCE) != node.output_types().end();
-}
-
// Returns true if the op can be decomposed into XLA ops for which
// there are fusable elemental implementations.
//
@@ -357,24 +351,27 @@ Status FindCompilationCandidates(
}
std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID());
+ if (fuel >= std::numeric_limits<int64>::max() / 2) {
+ // The assumption is that if fuel started out as INT64_MAX, it will forever
+ // stay greater than INT64_MAX / 2.
+ VLOG(2) << "Starting fuel: infinity";
+ } else {
+ VLOG(2) << "Starting fuel: " << fuel;
+ }
+
for (Node* node : sorted_nodes) {
- VLOG(2) << "Fuel: " << fuel;
if (fuel <= 0) {
- VLOG(2)
+ VLOG(1)
<< "Hit fuel limit; not marking any remaining ops as clusterable.";
break;
}
- VLOG(2) << "FindCompilationCandidates(): Processing "
- << node->DebugString();
-
DeviceType device_type("");
TF_RETURN_IF_ERROR(
DeviceToDeviceType(node->assigned_device_name(), &device_type));
if (is_compilable_fn && !is_compilable_fn(node, device_type)) {
- VLOG(2) << "Compilation rejected node: not compilable " << node->name()
- << ": " << node->type_string();
+ // is_compilable_fn has already logged the reason if it returned false.
continue;
}
@@ -384,14 +381,14 @@ Status FindCompilationCandidates(
DeviceType jit_device_type(registration->compilation_device_name);
if (!HasXLAKernel(*node, jit_device_type) &&
!IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime)) {
- VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
- << ": " << node->type_string();
+ VLOG(2) << "Rejecting " << node->name() << ": unsupported op "
+ << node->type_string();
continue;
}
if (!registration->compile_resource_ops &&
HasResourceInputOrOutput(*node)) {
- VLOG(2) << "Compilation rejected node: resource input/output "
- << node->name() << ": " << node->type_string();
+ VLOG(2) << "Rejecting: " << node->name() << ": resource input/output "
+ << node->type_string();
continue;
}
if (node->type_string() == "While" &&
@@ -401,15 +398,11 @@ Status FindCompilationCandidates(
// _Arg nodes in a top-level function represent feeds.
// Do not compile them.
if (node->type_string() == "_Arg") {
- VLOG(2) << "Skipping jit compilation for '_Arg'-typed node "
- << node->DebugString();
continue;
}
// _Retval nodes in a top-level function represent fetches.
// Do not compile them.
if (node->type_string() == "_Retval") {
- VLOG(2) << "Compilation rejected node: return value " << node->name()
- << ": " << node->type_string();
continue;
}
candidates->insert(node);
@@ -475,6 +468,7 @@ Status MarkForCompilationPass::Run(
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(device_type.type(),
&registration)) {
+ VLOG(2) << "Rejecting " << node->name() << ": could not find JIT device.";
return false;
}
@@ -484,21 +478,36 @@ Status MarkForCompilationPass::Run(
// If there is a _XlaCompile annotation, use its value.
bool compile = false;
Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile);
- if (status.ok()) return compile;
+ if (status.ok()) {
+ if (!compile) {
+ VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
+ << kXlaCompileAttr << ") is false.";
+ }
+ return compile;
+ }
status = fld->GetAttr(*node, kXlaCompileAttr, &compile);
- if (status.ok()) return compile;
+ if (status.ok()) {
+ if (!compile) {
+ VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr("
+ << kXlaCompileAttr << ") on callee is false.";
+ }
+ return compile;
+ }
// If inputs to `node` can have conflicting deadness (i.e. some are alive
// and some are dead) then don't compile it. XLA cannot represent the
// deadness semantics of these nodes correctly and auto-clustering these
// nodes can cause deadness to propagate to nodes that should be live.
if (node->IsMerge() || deadness->HasInputsWithMismatchingDeadness(*node)) {
+ VLOG(2) << "Rejecting " << node->name() << ": mismatching deadness.";
return false;
}
// Check for fusable ops only if requested.
if (global_jit_level > 0 && fusion_only && !IsXlaFusable(node->def())) {
+ VLOG(2) << "Rejecting " << node->name()
+ << ": not fusable op but fusion_only enabled.";
return false;
}
@@ -506,8 +515,17 @@ Status MarkForCompilationPass::Run(
// Ignore enable_jit_by_default if global jit compilation for CPU
// is explicitly requested via tf_xla_cpu_global_jit flag
bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU;
- return (ignore_registration || registration->enable_jit_by_default) &&
- global_jit_level > 0;
+ bool should_compile =
+ (ignore_registration || registration->enable_jit_by_default) &&
+ global_jit_level > 0;
+ if (!should_compile) {
+ if (global_jit_level <= 0) {
+ VLOG(2) << "Rejecting " << node->name() << ": global jit disabled.";
+ } else {
+ VLOG(2) << "Rejecting " << node->name() << ": JIT for device disabled.";
+ }
+ }
+ return should_compile;
};
return RunImpl(options, is_compilable);
}
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h
index e9acbfb19e..f1137af3c1 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.h
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h
@@ -40,20 +40,18 @@ class MarkForCompilationPass : public GraphOptimizationPass {
Status Run(const GraphOptimizationPassOptions& options) override;
- // Run() just calls RunImpl() if --tf_xla_auto_jit is enabled. To run the pass
- // unconditionally, call RunImpl() directly.
- // is_compilable_fn, if set, is a predicate that must be true for a node to
- // be compiled.
+ private:
Status RunImpl(const GraphOptimizationPassOptions& options,
const std::function<bool(const Node*, const DeviceType&)>&
is_compilable_fn = {});
+
+ friend class MarkForCompilationPassTestHelper;
};
// Returns true iff 'ndef' is a call to a function that is compilable. A
// function is compilable iff every operator in the function body is
// compilable.
bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef);
-
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 2c5f4fb774..a780d4a936 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
+#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
@@ -39,27 +39,6 @@ namespace {
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
-Status MarkForCompilation(std::unique_ptr<Graph>* graph,
- FunctionLibraryDefinition* flib_def) {
- // Assign all nodes to the CPU device.
- static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
- for (Node* n : (*graph)->nodes()) {
- n->set_assigned_device_name(kCpuDevice);
- }
-
- GraphOptimizationPassOptions opt_options;
- opt_options.graph = graph;
- opt_options.flib_def = flib_def;
- MarkForCompilationPass pass;
- return pass.RunImpl(opt_options);
-}
-
-Status MarkForCompilation(std::unique_ptr<Graph>* graph) {
- FunctionDefLibrary flib;
- FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
- return MarkForCompilation(graph, &flib_def);
-}
-
std::unordered_map<string, string> GetClusters(const Graph& graph) {
std::unordered_map<string, string> ids;
for (Node* node : graph.nodes()) {
@@ -88,7 +67,7 @@ TEST(XlaCompilationTest, Chains) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(4, clusters.size());
EXPECT_EQ(clusters["B"], clusters["C"]);
@@ -113,7 +92,7 @@ TEST(XlaCompilationTest, UncompilableCycles) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
@@ -133,7 +112,7 @@ TEST(XlaCompilationTest, CompilableCycles) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(3, clusters.size());
@@ -156,7 +135,7 @@ TEST(XlaCompilationTest, Complex128Unsupported) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
}
@@ -177,7 +156,7 @@ TEST(XlaCompilationTest, HalfSupported) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_FALSE(clusters.empty());
}
@@ -206,7 +185,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(3, clusters.size()); // Everything should be compiled.
}
@@ -241,7 +220,8 @@ TEST(XlaCompilationTest, FunctionCalls) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def));
+ TF_ASSERT_OK(
+ MarkForCompilationPassTestHelper::MarkForCompilation(&graph, &flib_def));
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
@@ -272,7 +252,7 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
ops::UnaryOp("Shape", d, builder.opts().WithName("E"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
}
@@ -359,7 +339,7 @@ TEST(XlaCompilationTest, SymbolicGradients) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
@@ -384,7 +364,7 @@ TEST(XlaCompilationTest, Loops) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_EXPECT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// Nothing should be compiled. In particular, 'd' and 'c' must not be
@@ -411,7 +391,7 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// The computation is: C = A + relu(A)
@@ -442,7 +422,7 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// The computation is: D = relu(A) + (A @ relu(A))
@@ -472,7 +452,7 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) {
TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
// The computation is: C = A @ relu(A)
@@ -512,7 +492,7 @@ TEST(XlaCompilationTest, Resources) {
ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
}
@@ -542,7 +522,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
TF_EXPECT_OK(root.ToGraph(graph.get()));
- Status status = MarkForCompilation(&graph);
+ Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph);
EXPECT_FALSE(status.ok());
EXPECT_TRUE(str_util::StrContains(status.ToString(),
"Edge from c to a would create a cycle.\n"
@@ -570,7 +550,7 @@ TEST(XlaCompilationTest, Retval) {
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_EQ(2, clusters.size());
@@ -588,7 +568,7 @@ TEST(XlaCompilationTest, DontCountIdentityOps) {
auto r = ops::_Retval(root.WithOpName("R"), c, 0);
}
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
@@ -604,7 +584,7 @@ TEST(XlaCompilationTest, DontCountIdentityOpsWithLocalJit) {
auto r = ops::_Retval(root.WithOpName("R"), b, 0);
}
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.empty());
@@ -618,7 +598,7 @@ TEST(XlaCompilationTest, ConstOp) {
auto c = ops::Const(root.WithOpName("const"), 0.5f);
c.node()->AddAttr(kXlaCompileAttr, true);
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
EXPECT_EQ(1, GetClusters(*graph).size());
}
@@ -629,7 +609,7 @@ TEST(XlaCompilationTest, ConstOp) {
auto c = ops::Const(root.WithOpName("const"), string("string"));
c.node()->AddAttr(kXlaCompileAttr, true);
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
EXPECT_TRUE(GetClusters(*graph).empty());
}
}
@@ -644,7 +624,7 @@ TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
@@ -667,7 +647,7 @@ TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
@@ -699,7 +679,7 @@ TEST(XlaCompilationTest, ClusterControlTrigger) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(root.ToGraph(graph.get()));
- TF_ASSERT_OK(MarkForCompilation(&graph));
+ TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::unordered_map<string, string> clusters = GetClusters(*graph);
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
new file mode 100644
index 0000000000..a84b82e479
--- /dev/null
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc
@@ -0,0 +1,40 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+
+namespace tensorflow {
+/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
+ std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def) {
+ // Assign all nodes to the CPU device.
+ static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
+ for (Node* n : (*graph)->nodes()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
+
+ GraphOptimizationPassOptions opt_options;
+ opt_options.graph = graph;
+ opt_options.flib_def = flib_def;
+ MarkForCompilationPass pass;
+ return pass.RunImpl(opt_options);
+}
+
+/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation(
+ std::unique_ptr<Graph>* graph) {
+ FunctionDefLibrary flib;
+ FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
+ return MarkForCompilation(graph, &flib_def);
+}
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
new file mode 100644
index 0000000000..b9a0531cb0
--- /dev/null
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h
@@ -0,0 +1,35 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_
+#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_
+
+#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
+
+namespace tensorflow {
+class MarkForCompilationPassTestHelper {
+ public:
+ // Runs the MarkForCompilation pass on `graph` after assigning all nodes in
+ // `graph` to the CPU device. To make testing easier, ignores device
+ // registration, _XlaCompile attributes, input deadness and global jit level.
+ static Status MarkForCompilation(std::unique_ptr<Graph>* graph,
+ FunctionLibraryDefinition* flib_def);
+
+ // Like `MarkForCompilation` but creates `flib_def` from the op registry.
+ static Status MarkForCompilation(std::unique_ptr<Graph>* graph);
+};
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_TEST_HELPER_H_
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
new file mode 100644
index 0000000000..68ead39424
--- /dev/null
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -0,0 +1,177 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/partially_decluster_pass.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/core/framework/memory_types.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+
+namespace tensorflow {
+namespace {
+Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
+ gtl::ArraySlice<Node*> post_order) {
+ // Find nodes that have at least one user outside their cluster that expects
+ // hostmem output. These nodes should be cloned to outside the cluster to
+ // avoid the device-host copy we'd otherwise need.
+
+ MemoryTypeVector input_mtypes, output_mtypes;
+
+ for (Node* n : post_order) {
+ gtl::optional<StringPiece> from_cluster = GetXlaClusterForNode(*n);
+ if (!from_cluster) {
+ continue;
+ }
+
+ // We assume the only XLA-auto-clusterable operations with side effects are
+ // resource variable updates. We can't execute these twice.
+ if (HasResourceInputOrOutput(*n)) {
+ continue;
+ }
+
+ DeviceType device_type("");
+ TF_RETURN_IF_ERROR(
+ DeviceToDeviceType(n->assigned_device_name(), &device_type));
+ TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
+ n->def(), &input_mtypes,
+ &output_mtypes));
+ for (const Edge* e : n->out_edges()) {
+ Node* dst = e->dst();
+
+ if (e->IsControlEdge()) {
+ continue;
+ }
+
+ bool edge_incurs_extra_device_to_host_copy;
+ if (output_mtypes[e->src_output()] == DEVICE_MEMORY) {
+ // If the output of the *TensorFlow* operation is in DEVICE_MEMORY then
+ // keep the node clustered -- XLA will also produce the output in device
+ // memory and we will get some benefit from clustering.
+ edge_incurs_extra_device_to_host_copy = false;
+ } else {
+ MemoryTypeVector dst_input_mtypes, dst_output_mtypes;
+ DeviceType dst_device_type("");
+ TF_RETURN_IF_ERROR(
+ DeviceToDeviceType(dst->assigned_device_name(), &dst_device_type));
+ TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), dst_device_type,
+ dst->def(), &dst_input_mtypes,
+ &dst_output_mtypes));
+ edge_incurs_extra_device_to_host_copy =
+ dst_input_mtypes[e->dst_input()] == HOST_MEMORY;
+ }
+
+ if (!edge_incurs_extra_device_to_host_copy) {
+ continue;
+ }
+
+ // Check if `dst` is in a different cluster, unclustered, or about to be
+ // partially declustered (here we rely on the post-order traversal order).
+ // If yes, decluster `n` to avoid the device-to-host memcpy.
+ gtl::optional<StringPiece> dst_cluster =
+ result->count(dst) ? gtl::nullopt : GetXlaClusterForNode(*dst);
+ if (from_cluster != dst_cluster) {
+ CHECK(result->insert(n).second);
+ break;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status PartiallyDeclusterNode(Graph* graph, Node* n) {
+ StringPiece cluster_name = *GetXlaClusterForNode(*n);
+ gtl::InlinedVector<const Edge*, 6> out_edges_to_clone;
+ for (const Edge* out_edge : n->out_edges()) {
+ if (out_edge->IsControlEdge()) {
+ continue;
+ }
+
+ Node* dst = out_edge->dst();
+ gtl::optional<StringPiece> dst_cluster_name = GetXlaClusterForNode(*dst);
+ if (dst_cluster_name != cluster_name) {
+ out_edges_to_clone.push_back(out_edge);
+ }
+ }
+
+ CHECK(!out_edges_to_clone.empty()) << n->DebugString();
+
+ NodeDef ndef = n->def();
+ ndef.set_name(strings::StrCat(n->name(), "/declustered"));
+ RemoveFromXlaCluster(&ndef);
+ Status s;
+ Node* cloned_node = graph->AddNode(ndef, &s);
+ TF_RETURN_IF_ERROR(s);
+ cloned_node->set_assigned_device_name(n->assigned_device_name());
+
+ for (const Edge* in_edge : n->in_edges()) {
+ graph->AddEdge(in_edge->src(), in_edge->src_output(), cloned_node,
+ in_edge->dst_input());
+ }
+
+ for (const Edge* out_edge_to_clone : out_edges_to_clone) {
+ graph->AddEdge(cloned_node, out_edge_to_clone->src_output(),
+ out_edge_to_clone->dst(), out_edge_to_clone->dst_input());
+ graph->RemoveEdge(out_edge_to_clone);
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+Status PartiallyDeclusterPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ // NB! In this pass we assume the only XLA-auto-clusterable operations that
+ // may have side effects are resource variable operations, so we do not
+ // decluster those. The pass will have to be updated if this assumption becomes
+ // invalid.
+
+ Graph* graph = options.graph->get();
+
+ // When deciding whether to decluster a particular node, we base our decision
+ // on whether we've decided that some of its consumers have to be declustered too.
+ // Iterating the graph in post-order guarantees that consumers have been
+ // visited before producers.
+ std::vector<Node*> post_order;
+ GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
+ /*edge_filter=*/[](const Edge& edge) {
+ return !edge.src()->IsNextIteration();
+ });
+
+ gtl::FlatSet<Node*> nodes_to_partially_decluster;
+ TF_RETURN_IF_ERROR(FindNodesToDecluster(
+ **options.graph, &nodes_to_partially_decluster, post_order));
+
+ if (VLOG_IS_ON(3)) {
+ for (Node* n : post_order) {
+ if (nodes_to_partially_decluster.count(n)) {
+ VLOG(3) << n->DebugString();
+ }
+ }
+ }
+
+ for (Node* n : post_order) {
+ if (nodes_to_partially_decluster.count(n)) {
+ TF_RETURN_IF_ERROR(PartiallyDeclusterNode(graph, n));
+ }
+ }
+
+ nodes_to_partially_decluster.clear();
+ TF_RETURN_IF_ERROR(FindNodesToDecluster(
+ **options.graph, &nodes_to_partially_decluster, post_order));
+ CHECK(nodes_to_partially_decluster.empty());
+
+ return Status::OK();
+}
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h
new file mode 100644
index 0000000000..6949b5028e
--- /dev/null
+++ b/tensorflow/compiler/jit/partially_decluster_pass.h
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+
+namespace tensorflow {
+
+// Clones nodes from within a cluster to outside the cluster if profitable.
+//
+// Today this only clones to avoid device-to-host copies, but in the future we
+// may consider other reasons to clone. For instance, we convert this:
+//
+// .....
+// |
+// v
+// A_Clustered ====> C_Unclustered
+// |
+// v
+// B_Clustered
+//
+// to:
+//
+// .....
+// | |
+// | +-------------+
+// | |
+// v v
+// A_Clustered A_Unclustered ====> C_Unclustered
+// |
+// v
+// B_Clustered
+//
+// where the ===> arrow has a hostmem source and destination and would entail a
+// device to host copy if the source and destination were not in the same XLA
+// cluster.
+class PartiallyDeclusterPass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_PARTIALLY_DECLUSTER_PASS_H_
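
The "===>" arrow in the comment above marks an edge whose TensorFlow kernels produce and consume the value in host memory; if the producer stays clustered, XLA materializes the value in device memory and an extra device-to-host copy is needed. A condensed, standalone sketch of that predicate (not part of this change; error handling is elided, the function name is illustrative, and FindNodesToDecluster above implements the real check):

#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Returns true if routing `e` out of an XLA cluster would add a
// device-to-host copy: the TensorFlow kernel would have produced the value in
// host memory and the consumer expects it in host memory, but the XLA cluster
// produces it in device memory.
bool EdgeWouldIncurDeviceToHostCopy(const Graph& graph, const Edge* e,
                                    const DeviceType& device_type) {
  MemoryTypeVector src_in, src_out, dst_in, dst_out;
  if (!MemoryTypesForNode(graph.op_registry(), device_type, e->src()->def(),
                          &src_in, &src_out)
           .ok() ||
      !MemoryTypesForNode(graph.op_registry(), device_type, e->dst()->def(),
                          &dst_in, &dst_out)
           .ok()) {
    return false;  // Be conservative if memory types cannot be determined.
  }
  return src_out[e->src_output()] == HOST_MEMORY &&
         dst_in[e->dst_input()] == HOST_MEMORY;
}

}  // namespace tensorflow
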
diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
new file mode 100644
index 0000000000..08a956e4c6
--- /dev/null
+++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc
@@ -0,0 +1,284 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/partially_decluster_pass.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/sendrecv_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+REGISTER_OP("FakeNullary").Output("out: float");
+
+REGISTER_OP("FakeBinary")
+ .Input("host_in: float")
+ .Input("device_in: float")
+ .Output("host_out: float")
+ .Output("device_out: float");
+
+REGISTER_OP("FakeResourceVar").Output("out: resource");
+
+REGISTER_OP("FakeResourceUpdate")
+ .Input("in: resource")
+ .Output("out: resource")
+ .Output("something_else: float");
+
+class FakeBinaryOp : public OpKernel {
+ public:
+ explicit FakeBinaryOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* ctx) override { CHECK(false); }
+};
+
+class FakeResourceVarUpdateOp : public OpKernel {
+ public:
+ explicit FakeResourceVarUpdateOp(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* ctx) override { CHECK(false); }
+};
+
+REGISTER_KERNEL_BUILDER(Name("FakeBinary")
+ .Device(DEVICE_CPU)
+ .HostMemory("host_in")
+ .HostMemory("host_out"),
+ FakeBinaryOp);
+
+REGISTER_KERNEL_BUILDER(Name("FakeResourceVarUpdate")
+ .Device(DEVICE_CPU)
+ .HostMemory("something_else"),
+ FakeResourceVarUpdateOp);
+
+Status PartiallyDecluster(std::unique_ptr<Graph>* graph) {
+ FixupSourceAndSinkEdges(graph->get());
+ // Assign all nodes to the CPU device.
+ static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
+ for (Node* n : (*graph)->nodes()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
+
+ GraphOptimizationPassOptions opt_options;
+ opt_options.graph = graph;
+ PartiallyDeclusterPass pass;
+ return pass.Run(opt_options);
+}
+
+const Node* FindNodeByName(const Graph& graph, const string& name) {
+ for (const Node* node : graph.nodes()) {
+ if (node->name() == name) {
+ return node;
+ }
+ }
+ return nullptr;
+}
+
+bool GetInputsForNode(const Graph& graph, const string& node_name,
+ std::vector<Node*>* inputs) {
+ const Node* node = FindNodeByName(graph, node_name);
+ if (node == nullptr) {
+ return false;
+ }
+ for (const Edge* e : node->in_edges()) {
+ inputs->push_back(e->src());
+ }
+ std::sort(inputs->begin(), inputs->end(), NodeComparatorName());
+ return true;
+}
+
+TEST(PartiallyDeclusterPassTest, ClusteredAndUnclustered) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* input =
+ ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
+ Node* clustered_producer =
+ ops::BinaryOp("FakeBinary", input, input,
+ builder.opts().WithName("ClusteredProducer"));
+ ops::BinaryOp("FakeBinary", clustered_producer, input,
+ builder.opts().WithName("UnclusteredConsumer"));
+ Node* clustered_consumer =
+ ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input,
+ builder.opts().WithName("ClusteredConsumer"));
+ clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
+ clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
+ TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+ std::vector<Node*> unclustered_consumer_inputs;
+ ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer",
+ &unclustered_consumer_inputs));
+ ASSERT_EQ(unclustered_consumer_inputs.size(), 2);
+ EXPECT_EQ(unclustered_consumer_inputs[0]->name(),
+ "ClusteredProducer/declustered");
+ EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input");
+
+ std::vector<Node*> clustered_consumer_inputs;
+ ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredConsumer",
+ &clustered_consumer_inputs));
+ ASSERT_EQ(clustered_consumer_inputs.size(), 2);
+ EXPECT_EQ(clustered_consumer_inputs[0]->name(), "ClusteredProducer");
+ EXPECT_EQ(clustered_consumer_inputs[1]->name(), "Input");
+}
+
+TEST(PartiallyDeclusterPassTest, DifferentClusters) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* input =
+ ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
+ Node* clustered_producer =
+ ops::BinaryOp("FakeBinary", input, input,
+ builder.opts().WithName("ClusteredProducer"));
+ Node* consumer_in_different_cluster =
+ ops::BinaryOp("FakeBinary", clustered_producer, input,
+ builder.opts().WithName("ConsumerInDifferentCluster"));
+ Node* clustered_consumer =
+ ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
+ builder.opts().WithName("ClusteredConsumer"));
+ clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
+ clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
+ consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
+ TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+ std::vector<Node*> inputs;
+ ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
+ ASSERT_EQ(inputs.size(), 2);
+ EXPECT_EQ(inputs[0]->name(), "ClusteredProducer/declustered");
+ EXPECT_EQ(inputs[1]->name(), "Input");
+}
+
+TEST(PartiallyDeclusterPassTest, DontDeclusterIfUserIsDeviceMem) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* input =
+ ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
+ Node* clustered_producer =
+ ops::BinaryOp("FakeBinary", input, input,
+ builder.opts().WithName("ClusteredProducer"));
+ // The first input is hostmem and the second input is devicemem.
+ Node* consumer_in_different_cluster =
+ ops::BinaryOp("FakeBinary", input, clustered_producer,
+ builder.opts().WithName("ConsumerInDifferentCluster"));
+ Node* clustered_consumer =
+ ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
+ builder.opts().WithName("ClusteredConsumer"));
+ clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
+ clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
+ consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
+ TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+ std::vector<Node*> inputs;
+ ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
+ ASSERT_EQ(inputs.size(), 2);
+ EXPECT_EQ(inputs[0]->name(), "ClusteredProducer");
+ EXPECT_EQ(inputs[1]->name(), "Input");
+}
+
+TEST(PartiallyDeclusterPassTest, DontDuplicateResourceVarOps) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* input =
+ ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
+ Node* resource_var = ops::SourceOp("FakeResourceVar",
+ builder.opts().WithName("ResourceVar"));
+ Node* clustered_producer =
+ ops::UnaryOp("FakeResourceUpdate", resource_var,
+ builder.opts().WithName("ClusteredProducer"));
+ Node* consumer_in_different_cluster =
+ ops::BinaryOp("FakeBinary", {clustered_producer, 1}, input,
+ builder.opts().WithName("ConsumerInDifferentCluster"));
+ Node* clustered_consumer =
+ ops::BinaryOp("FakeBinary", input, {clustered_producer, 1},
+ builder.opts().WithName("ClusteredConsumer"));
+ clustered_producer->AddAttr(kXlaClusterAttr, "cluster_0");
+ clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
+ consumer_in_different_cluster->AddAttr(kXlaClusterAttr, "cluster_1");
+ TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+ std::vector<Node*> inputs;
+ ASSERT_TRUE(GetInputsForNode(*graph, "ConsumerInDifferentCluster", &inputs));
+ ASSERT_EQ(inputs.size(), 2);
+ EXPECT_EQ(inputs[0]->name(), "ClusteredProducer");
+ EXPECT_EQ(inputs[1]->name(), "Input");
+}
+
+TEST(PartiallyDeclusterPassTest, DeclusterDependentNodes) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* input =
+ ops::SourceOp("FakeNullary", builder.opts().WithName("Input"));
+ Node* clustered_producer_0 =
+ ops::BinaryOp("FakeBinary", input, input,
+ builder.opts().WithName("ClusteredProducer0"));
+ Node* clustered_producer_1 =
+ ops::BinaryOp("FakeBinary", clustered_producer_0, input,
+ builder.opts().WithName("ClusteredProducer1"));
+ ops::BinaryOp("FakeBinary", clustered_producer_1, input,
+ builder.opts().WithName("UnclusteredConsumer"));
+ Node* clustered_consumer =
+ ops::BinaryOp("FakeBinary", {clustered_producer_1, 1}, input,
+ builder.opts().WithName("ClusteredConsumer"));
+ clustered_producer_0->AddAttr(kXlaClusterAttr, "cluster_0");
+ clustered_producer_1->AddAttr(kXlaClusterAttr, "cluster_0");
+ clustered_consumer->AddAttr(kXlaClusterAttr, "cluster_0");
+ TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
+ }
+
+ TF_ASSERT_OK(PartiallyDecluster(&graph));
+ std::vector<Node*> unclustered_consumer_inputs, declustered_producer_1_inputs;
+
+ ASSERT_TRUE(GetInputsForNode(*graph, "UnclusteredConsumer",
+ &unclustered_consumer_inputs));
+ ASSERT_EQ(unclustered_consumer_inputs.size(), 2);
+ EXPECT_EQ(unclustered_consumer_inputs[0]->name(),
+ "ClusteredProducer1/declustered");
+ EXPECT_EQ(unclustered_consumer_inputs[1]->name(), "Input");
+
+ ASSERT_TRUE(GetInputsForNode(*graph, "ClusteredProducer1/declustered",
+ &declustered_producer_1_inputs));
+ ASSERT_EQ(declustered_producer_1_inputs.size(), 2);
+ EXPECT_EQ(declustered_producer_1_inputs[0]->name(),
+ "ClusteredProducer0/declustered");
+ EXPECT_EQ(declustered_producer_1_inputs[1]->name(), "Input");
+}
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc
index a5628b12a2..0a025a1fc0 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.cc
+++ b/tensorflow/compiler/jit/xla_cluster_util.cc
@@ -185,4 +185,26 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) {
return Status::OK();
}
+gtl::optional<StringPiece> GetXlaClusterForNode(const Node& node) {
+ const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr);
+ if (attr_value == nullptr) {
+ return gtl::nullopt;
+ }
+ Status s = AttrValueHasType(*attr_value, "string");
+ if (!s.ok()) {
+ return gtl::nullopt;
+ }
+ return attr_value->s();
+}
+
+bool HasResourceInputOrOutput(const Node& node) {
+ return std::find(node.input_types().begin(), node.input_types().end(),
+ DT_RESOURCE) != node.input_types().end() ||
+ std::find(node.output_types().begin(), node.output_types().end(),
+ DT_RESOURCE) != node.output_types().end();
+}
+
+void RemoveFromXlaCluster(NodeDef* node_def) {
+ node_def->mutable_attr()->erase(kXlaClusterAttr);
+}
} // namespace tensorflow
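
The three helpers added here are used by the new pass above, but they also compose on their own. A small usage sketch (not part of this change; the function name is illustrative) that finds clustered nodes touching DT_RESOURCE values and clears their cluster attribute:

#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

void StripClustersFromResourceNodes(Graph* graph) {
  for (Node* n : graph->nodes()) {
    gtl::optional<StringPiece> cluster = GetXlaClusterForNode(*n);
    if (!cluster || !HasResourceInputOrOutput(*n)) {
      continue;
    }
    VLOG(2) << "Removing " << n->name() << " from cluster " << *cluster;
    NodeDef node_def = n->def();
    RemoveFromXlaCluster(&node_def);
    // In a real pass the modified NodeDef would still have to replace the
    // original node (add a new node and rewire its edges), as
    // PartiallyDeclusterNode does in partially_decluster_pass.cc.
  }
}

}  // namespace tensorflow
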
diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h
index bcce082aaf..bff76da6f9 100644
--- a/tensorflow/compiler/jit/xla_cluster_util.h
+++ b/tensorflow/compiler/jit/xla_cluster_util.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/gtl/optional.h"
namespace tensorflow {
@@ -44,6 +45,16 @@ bool HasForwardedRefInput(const Node& node);
// the enclosing graph.
Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles);
+// Returns the XLA cluster in which `node` is placed if it is in an XLA cluster,
+// otherwise returns nullopt.
+gtl::optional<StringPiece> GetXlaClusterForNode(const Node& node);
+
+// Removes `node_def` from its XLA cluster (by clearing its _XlaCluster attribute).
+void RemoveFromXlaCluster(NodeDef* node_def);
+
+// Returns true if `node` has a DT_RESOURCE typed input or output.
+bool HasResourceInputOrOutput(const Node& node);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index f65f89ebf5..dd84fb34c1 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -78,7 +78,8 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
executable->Run(launch_context.arguments(), run_options);
TF_RETURN_IF_ERROR(run_result.status());
- launch_context.PopulateOutputs(ctx, result, run_result.ConsumeValueOrDie());
+ TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
+ ctx, result, run_result.ConsumeValueOrDie()));
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 4ddeaebd3e..2a2691a6a4 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@@ -216,6 +217,8 @@ XlaDevice::XlaDevice(
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name << " " << this;
+ thread_pool_.reset(new thread::ThreadPool(options.env, "xla_device",
+ /*num_threads=*/1));
}
XlaDevice::~XlaDevice() {
@@ -262,10 +265,12 @@ Status XlaDevice::EnsureDeviceContextOk() {
Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend,
const string& name,
- xla::StreamPool::Ptr* stream,
+ std::shared_ptr<se::Stream>* stream,
bool* stream_was_changed) {
if (!(*stream) || !(*stream)->ok()) {
- TF_ASSIGN_OR_RETURN(*stream, backend->BorrowStream(device_ordinal_));
+ xla::StreamPool::Ptr ptr;
+ TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_));
+ *stream = std::shared_ptr<se::Stream>(std::move(ptr));
VLOG(1) << "XlaDevice " << this << " new " << name << " "
<< (*stream)->DebugStreamPointers();
*stream_was_changed = true;
@@ -281,8 +286,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "stream", &stream_,
&need_new_device_context));
- se::Stream* host_to_device_stream = stream_.get();
- se::Stream* device_to_host_stream = stream_.get();
+ std::shared_ptr<se::Stream> host_to_device_stream = stream_;
+ std::shared_ptr<se::Stream> device_to_host_stream = stream_;
if (use_multiple_streams_) {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "host_to_device_stream",
&host_to_device_stream_,
@@ -290,8 +295,8 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
TF_RETURN_IF_ERROR(EnsureStreamOkLocked(backend, "device_to_host_stream",
&device_to_host_stream_,
&need_new_device_context));
- host_to_device_stream = host_to_device_stream_.get();
- device_to_host_stream = device_to_host_stream_.get();
+ host_to_device_stream = host_to_device_stream_;
+ device_to_host_stream = device_to_host_stream_;
}
if (!need_new_device_context) {
@@ -304,9 +309,13 @@ xla::StatusOr<XlaDeviceContext*> XlaDevice::GetDeviceContextLocked() {
if (device_context_) {
device_context_->Unref();
}
+    // The XlaDeviceContext holds references to the streams and remains live
+    // for the duration of an Executor run. This ensures that the streams stay
+    // alive for the whole run, even if an error is encountered and the
+    // streams are replaced with new ones.
device_context_ = new XlaDeviceContext(
- stream_.get(), host_to_device_stream, device_to_host_stream, client(),
- transfer_as_literal_, shape_representation_fn_);
+ stream_, host_to_device_stream, device_to_host_stream, client(),
+ transfer_as_literal_, shape_representation_fn_, thread_pool_.get());
VLOG(1) << "XlaDevice " << this << " new XlaDeviceContext "
<< device_context_;
@@ -371,6 +380,22 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
op_kernel->ComputeAsync(context, done);
}
+Status XlaDevice::Sync() {
+ VLOG(1) << "XlaDevice::Sync";
+ std::shared_ptr<se::Stream> stream;
+ {
+ mutex_lock lock(mu_);
+ stream = stream_;
+ }
+ if (!stream) return Status::OK();
+
+ if (!stream->parent()->SynchronizeAllActivity() || !stream->ok()) {
+ return errors::Internal("XlaDevice::Sync() failed.");
+ }
+ VLOG(1) << "XlaDevice::Sync completed";
+ return Status::OK();
+}
+
Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) {
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index d8906419b0..dbf35f349f 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
-#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/allocator.h"
@@ -124,7 +123,7 @@ class XlaDevice : public LocalDevice {
void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override;
- Status Sync() override { return Status::OK(); }
+ Status Sync() override;
Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override
@@ -153,7 +152,7 @@ class XlaDevice : public LocalDevice {
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status EnsureStreamOkLocked(xla::Backend* backend, const string& name,
- xla::StreamPool::Ptr* stream,
+ std::shared_ptr<se::Stream>* stream,
bool* stream_was_changed)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
xla::StatusOr<XlaDeviceContext*> GetDeviceContextLocked()
@@ -174,17 +173,17 @@ class XlaDevice : public LocalDevice {
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
// computations enqueued by XLA.
- xla::StreamPool::Ptr stream_ GUARDED_BY(mu_);
+ std::shared_ptr<se::Stream> stream_ GUARDED_BY(mu_);
// If false, only stream_ is valid and all computation and transfers use
// stream_. If true, computation is performed by stream_ and transfers are
// performed by host_to_device/device_to_host_stream.
const bool use_multiple_streams_;
// If use_multiple_streams_, host to device transfers are performed using this
// stream.
- xla::StreamPool::Ptr host_to_device_stream_ GUARDED_BY(mu_);
+ std::shared_ptr<se::Stream> host_to_device_stream_ GUARDED_BY(mu_);
// If use_multiple_streams_, device to host transfers are performed using this
// stream.
- xla::StreamPool::Ptr device_to_host_stream_ GUARDED_BY(mu_);
+ std::shared_ptr<se::Stream> device_to_host_stream_ GUARDED_BY(mu_);
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
const bool transfer_as_literal_;
@@ -198,6 +197,9 @@ class XlaDevice : public LocalDevice {
// Holds extra information for GPU and TPU devices, e.g. the device context.
bool use_gpu_device_info_ GUARDED_BY(mu_) = false;
std::unique_ptr<GpuDeviceInfo> gpu_device_info_ GUARDED_BY(mu_);
+
+  // Thread pool used for running closures, such as the `done` callbacks that
+  // the XlaDeviceContext schedules off the stream host-callback threads.
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
};
// Builds OpKernel registrations on 'device' for the JIT operators
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 0100bf51ed..0a0c089241 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device_context.h"
+#include <memory>
+
+#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -48,17 +51,20 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
XlaTransferManager::XlaTransferManager(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
- : stream_(compute_stream),
- host_to_device_stream_(host_to_device_stream),
- device_to_host_stream_(device_to_host_stream),
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool)
+ : stream_(std::move(compute_stream)),
+ host_to_device_stream_(std::move(host_to_device_stream)),
+ device_to_host_stream_(std::move(device_to_host_stream)),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal),
- shape_representation_fn_(std::move(shape_representation_fn)) {
+ shape_representation_fn_(std::move(shape_representation_fn)),
+ thread_pool_(thread_pool) {
CHECK(host_to_device_stream_ != nullptr);
CHECK(device_to_host_stream_ != nullptr);
CHECK(stream_ != nullptr);
@@ -88,15 +94,15 @@ Status XlaTransferManager::TransferLiteralToDevice(
if (UseMultipleStreams()) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
- host_to_device_stream_->ThenWaitFor(stream_);
+ host_to_device_stream_->ThenWaitFor(stream_.get());
}
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
- host_to_device_stream_, *literal, shaped_buffer));
+ host_to_device_stream_.get(), *literal, shaped_buffer));
if (UseMultipleStreams()) {
- se::Event event(stream_->parent());
- TF_RET_CHECK(event.Init()) << "Event failed to initialize!";
- host_to_device_stream_->ThenRecordEvent(&event);
- xla_tensor->SetDefinedOn(host_to_device_stream_, std::move(event));
+ auto event = std::make_shared<se::Event>(stream_->parent());
+ TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
+ host_to_device_stream_->ThenRecordEvent(event.get());
+ xla_tensor->SetDefinedOn(host_to_device_stream_.get(), std::move(event));
}
// Unref the host tensor, and capture the literal shared_ptr too so it goes
// out of scope when the lambda completes.
@@ -116,7 +122,7 @@ void XlaTransferManager::TransferLiteralFromDevice(
TensorReference ref(device_tensor);
transfer_manager_->TransferLiteralFromDevice(
- device_to_host_stream_, shaped_buffer, literal,
+ device_to_host_stream_.get(), shaped_buffer, literal,
[=, &shaped_buffer, &literal](xla::Status status) {
ref.Unref();
done([&]() -> Status {
@@ -179,8 +185,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
if (status.ok()) {
xla_tensor->set_host_tensor(*cpu_tensor);
- host_to_device_stream_->ThenDoHostCallback(
- [done]() { done(Status::OK()); });
+ host_to_device_stream_->ThenDoHostCallback([this, done]() {
+ // We must not call the done closure directly from DoHostCallback
+ // to avoid a deadlock. If done() is the callback that ends an
+ // Executor's run, the Executor may call XlaDevice::Sync() inside the
+ // callback. This deadlocks, because XlaDevice::Sync() waits for all
+ // stream activity to complete.
+ thread_pool_->Schedule([done]() { done(Status::OK()); });
+ });
return;
}
} else {
@@ -192,7 +204,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
if (!block_status.ok()) {
status = xla::InternalError(
"Failed to complete data transfer on stream %p: %s",
- host_to_device_stream_, block_status.error_message().c_str());
+ host_to_device_stream_.get(), block_status.error_message().c_str());
}
}
xla_tensor->set_host_tensor(*cpu_tensor);
@@ -225,9 +237,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
if (se::Event* event =
- xla_tensor->GetDefinitionEvent(device_to_host_stream_)) {
+ xla_tensor->GetDefinitionEvent(device_to_host_stream_.get())) {
device_to_host_stream_->ThenWaitFor(event);
- xla_tensor->SetDefinedOn(device_to_host_stream_);
+ xla_tensor->SetDefinedOn(device_to_host_stream_.get());
}
Status status;
@@ -240,7 +252,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
Status block_status = device_to_host_stream_->BlockHostUntilDone();
if (!block_status.ok()) {
status = xla::InternalError(
- "Failed to complete data transfer on stream %p: %s", stream_,
+ "Failed to complete data transfer on stream %p: %s", stream_.get(),
block_status.error_message().c_str());
}
}
@@ -278,14 +290,14 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
if (stream_ != device_to_device_stream) {
// Initially wait for the compute stream so that memory allocations are
// synchronized.
- device_to_device_stream->ThenWaitFor(stream_);
+ device_to_device_stream->ThenWaitFor(stream_.get());
}
}
if (se::Event* event =
- xla_src->GetDefinitionEvent(device_to_device_stream)) {
+ xla_src->GetDefinitionEvent(device_to_device_stream.get())) {
device_to_device_stream->ThenWaitFor(event);
- xla_src->SetDefinedOn(device_to_device_stream);
+ xla_src->SetDefinedOn(device_to_device_stream.get());
}
auto from_iter = xla_src->shaped_buffer().buffers().begin();
@@ -297,28 +309,37 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
}
if (UseMultipleStreams()) {
- se::Event event(stream_->parent());
- CHECK(event.Init());
- device_to_device_stream->ThenRecordEvent(&event);
- xla_dst->SetDefinedOn(device_to_device_stream, std::move(event));
+ auto event = std::make_shared<se::Event>(stream_->parent());
+ TF_RET_CHECK(event->Init()) << "Event failed to initialize";
+ device_to_device_stream->ThenRecordEvent(event.get());
+ xla_dst->SetDefinedOn(device_to_device_stream.get(), std::move(event));
}
return Status::OK();
}();
if (!status.ok()) {
return done(status);
} else {
- stream_->ThenDoHostCallback([=]() { done(Status::OK()); });
+ stream_->ThenDoHostCallback([this, done]() {
+ // We must not call the done closure directly from DoHostCallback to avoid
+ // a deadlock. If done() is the callback that ends an Executor's run, the
+ // Executor may call XlaDevice::Sync() inside the callback. This
+ // deadlocks, because XlaDevice::Sync() waits for all stream activity to
+ // complete.
+ thread_pool_->Schedule([done]() { done(Status::OK()); });
+ });
}
}
XlaDeviceContext::XlaDeviceContext(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream, xla::LocalClient* client,
bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
- : manager_(compute_stream, host_to_device_stream, device_to_host_stream,
- client, transfer_as_literal,
- std::move(shape_representation_fn)) {}
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool)
+ : manager_(std::move(compute_stream), std::move(host_to_device_stream),
+ std::move(device_to_host_stream), client, transfer_as_literal,
+ std::move(shape_representation_fn), thread_pool) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
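
Distilled outside the patch, the deadlock-avoidance pattern used by both callbacks above is simply a hop off the stream's host-callback thread before invoking the caller's completion closure (a sketch; `stream`, `thread_pool`, and `done` stand in for the members used above):

  stream->ThenDoHostCallback([thread_pool, done]() {
    // Running `done` here could re-enter XlaDevice::Sync(), which waits for
    // all stream activity, including this very callback; scheduling the
    // closure on a separate thread breaks the cycle.
    thread_pool->Schedule([done]() { done(Status::OK()); });
  });
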
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 912f8d779e..2e7445340c 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -47,10 +47,12 @@ class XlaDeviceAllocator : public Allocator {
class XlaTransferManager {
public:
explicit XlaTransferManager(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
- bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream,
+ xla::LocalClient* client, bool transfer_as_literal,
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done) const;
@@ -61,7 +63,7 @@ class XlaTransferManager {
void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor,
const StatusCallback& done);
- se::Stream* stream() const { return stream_; }
+ se::Stream* stream() const { return stream_.get(); }
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
@@ -73,13 +75,13 @@ class XlaTransferManager {
// The main compute stream of the device, used to synchronize the transfer
// streams if they are set.
- se::Stream* stream_;
+ std::shared_ptr<se::Stream> stream_;
// The stream to use for transferring data from host to device. Can be
  // identical to stream_, but must not be nullptr.
- se::Stream* host_to_device_stream_;
+ std::shared_ptr<se::Stream> host_to_device_stream_;
// The stream to use for transferring data from device to host. Can be
  // identical to stream_, but must not be nullptr.
- se::Stream* device_to_host_stream_;
+ std::shared_ptr<se::Stream> device_to_host_stream_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
// Transfer manager, for marshalling data to and from the device.
@@ -87,6 +89,9 @@ class XlaTransferManager {
// True if we must use XLA's TransferManager for correct device transfers.
const bool transfer_as_literal_;
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+
+  // Thread pool used for running closures; `done` callbacks are scheduled
+  // here rather than run directly on the stream's host-callback thread.
+ thread::ThreadPool* thread_pool_;
};
// DeviceContext for operators assigned to XlaDevice devices. The
@@ -95,10 +100,12 @@ class XlaTransferManager {
class XlaDeviceContext : public DeviceContext {
public:
explicit XlaDeviceContext(
- se::Stream* compute_stream, se::Stream* host_to_device_stream,
- se::Stream* device_to_host_stream, xla::LocalClient* client,
- bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+ std::shared_ptr<se::Stream> compute_stream,
+ std::shared_ptr<se::Stream> host_to_device_stream,
+ std::shared_ptr<se::Stream> device_to_host_stream,
+ xla::LocalClient* client, bool transfer_as_literal,
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ thread::ThreadPool* thread_pool);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 6134b8c694..4efbb2d5d7 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_launch_util.h"
+#include <memory>
+
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@@ -182,7 +184,7 @@ void XlaComputationLaunchContext::PopulateInputs(
}
}
-void XlaComputationLaunchContext::PopulateOutputs(
+Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
ScopedShapedBuffer output) {
se::Stream* stream =
@@ -211,6 +213,15 @@ void XlaComputationLaunchContext::PopulateOutputs(
output = ScopedShapedBuffer(std::move(buffer), output.memory_allocator());
}
+ std::shared_ptr<se::Event> definition_event;
+ if (use_multiple_streams_) {
+ definition_event = std::make_shared<se::Event>(stream->parent());
+ if (!definition_event->Init()) {
+ return errors::Internal("Failed to initialize tensor definition event.");
+ }
+ stream->ThenRecordEvent(definition_event.get());
+ }
+
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
@@ -228,12 +239,13 @@ void XlaComputationLaunchContext::PopulateOutputs(
// reallocate the device buffer later.
VLOG(1) << "Constant output tensor on device";
- OP_REQUIRES_OK(
- ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
Device* device = dynamic_cast<Device*>(ctx->device());
- OP_REQUIRES(ctx, device != nullptr,
- errors::Internal("DeviceBase was not a Device."));
+ if (device == nullptr) {
+ return errors::Internal("DeviceBase was not a Device.");
+ }
ctx->op_device_context()->CopyCPUTensorToDevice(
&const_tensor, device, output_tensor,
[&](Status status) { TF_CHECK_OK(status); });
@@ -263,16 +275,13 @@ void XlaComputationLaunchContext::PopulateOutputs(
se::DeviceMemoryBase buffer = output.buffer({output_num});
if (allocate_xla_tensors_) {
Tensor* output_tensor;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor));
+ TF_RETURN_IF_ERROR(ctx->allocate_output(i, shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor);
if (xla_tensor) {
xla_tensor->set_shaped_buffer(ScopedShapedBuffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_)));
if (use_multiple_streams_) {
- se::Event event(stream->parent());
- CHECK(event.Init());
- stream->ThenRecordEvent(&event);
- xla_tensor->SetDefinedOn(stream, std::move(event));
+ xla_tensor->SetDefinedOn(stream, definition_event);
}
} else {
// xla_tensor wasn't valid, which must mean this is a zero-element
@@ -298,41 +307,39 @@ void XlaComputationLaunchContext::PopulateOutputs(
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
- OP_REQUIRES(ctx,
- write.input_index >= 0 && write.input_index < ctx->num_inputs(),
- errors::Internal("Invalid input index for variable write."));
+ if (write.input_index < 0 || write.input_index >= ctx->num_inputs()) {
+ return errors::Internal("Invalid input index for variable write.");
+ }
se::DeviceMemoryBase buffer = output.buffer({output_num});
Var* variable = nullptr;
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
- OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
- ctx, HandleFromInput(ctx, write.input_index),
- &variable, [this, ctx, &write](Var** ptr) {
- *ptr = new Var(write.type);
- return Status::OK();
- }));
+ TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
+ ctx, HandleFromInput(ctx, write.input_index), &variable,
+ [&write](Var** ptr) {
+ *ptr = new Var(write.type);
+ return Status::OK();
+ }));
core::ScopedUnref s(variable);
mutex_lock ml(*variable->mu());
- OP_REQUIRES(ctx, variable->tensor()->dtype() == write.type,
- errors::Internal("Mismatched type in variable write"));
+ if (variable->tensor()->dtype() != write.type) {
+ return errors::Internal("Mismatched type in variable write");
+ }
if (allocate_xla_tensors_) {
Tensor output_tensor;
- OP_REQUIRES_OK(
- ctx, ctx->allocate_temp(write.type, write.shape, &output_tensor));
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_temp(write.type, write.shape, &output_tensor));
XlaTensor* xla_tensor = XlaTensor::FromTensor(&output_tensor);
CHECK(xla_tensor);
xla_tensor->set_shaped_buffer(
ExtractSubShapedBuffer(&output, output_num, xla_allocator_));
if (use_multiple_streams_) {
- se::Event event(stream->parent());
- CHECK(event.Init());
- stream->ThenRecordEvent(&event);
- xla_tensor->SetDefinedOn(stream, std::move(event));
+ xla_tensor->SetDefinedOn(stream, definition_event);
}
*variable->tensor() = output_tensor;
} else {
@@ -343,6 +350,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
}
++output_num;
}
+ return Status::OK();
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 1ea3fa4cf2..4232f514b3 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -93,9 +93,9 @@ class XlaComputationLaunchContext {
const std::map<int, OptionalTensor>& variables);
// Given the XLA output in `output`, populate all outputs of `ctx`.
- void PopulateOutputs(OpKernelContext* ctx,
- const XlaCompiler::CompilationResult* kernel,
- xla::ScopedShapedBuffer output);
+ Status PopulateOutputs(OpKernelContext* ctx,
+ const XlaCompiler::CompilationResult* kernel,
+ xla::ScopedShapedBuffer output);
// Return the argument list. Only valid after PopulateInputs() has been
// called.
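
With PopulateOutputs returning a Status, callers report failures through their own error path rather than relying on OP_REQUIRES* firing inside the helper. A hedged sketch of a synchronous caller (`launch_context`, `kernel`, and `run_result` are placeholders, not names from this hunk):

  OP_REQUIRES_OK(ctx, launch_context.PopulateOutputs(
                          ctx, kernel, run_result.ConsumeValueOrDie()));
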
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc
index d777dfa5a3..92ba7de1b7 100644
--- a/tensorflow/compiler/jit/xla_tensor.cc
+++ b/tensorflow/compiler/jit/xla_tensor.cc
@@ -75,7 +75,7 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape,
se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
mutex_lock lock(mu_);
- if (!definition_event_.has_value()) {
+ if (!definition_event_) {
return nullptr;
}
@@ -87,10 +87,11 @@ se::Event* XlaTensor::GetDefinitionEvent(se::Stream* stream) {
return nullptr;
}
- return &*definition_event_;
+ return definition_event_.get();
}
-void XlaTensor::SetDefinedOn(se::Stream* stream, se::Event event) {
+void XlaTensor::SetDefinedOn(se::Stream* stream,
+ std::shared_ptr<se::Event> event) {
mutex_lock lock(mu_);
definition_event_ = std::move(event);
streams_defined_on_ = {stream};
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index f7e401c731..8d36d0fa0a 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
#define TENSORFLOW_COMPILER_JIT_XLA_TENSOR_H_
+#include <memory>
+
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/core/framework/allocator.h"
@@ -94,7 +96,7 @@ class XlaTensor {
// Assert that the tensor's content is defined on 'stream' by the time 'event'
// triggers.
- void SetDefinedOn(se::Stream* stream, se::Event event);
+ void SetDefinedOn(se::Stream* stream, std::shared_ptr<se::Event> event);
// Assert that the tensor's content is defined on 'stream'. This version does
// not provide an event, and must be called *after* SetDefinedOn(Stream,
@@ -116,7 +118,7 @@ class XlaTensor {
// An optional event that is triggered when the tensor's content has been
// defined. If this event is nullptr, it is assumed that the tensor's content
// is always defined.
- gtl::optional<se::Event> definition_event_;
+ std::shared_ptr<se::Event> definition_event_;
// A list of all streams for which the tensor's content is defined for any
// newly enqueued command.
gtl::InlinedVector<se::Stream*, 2> streams_defined_on_ GUARDED_BY(mu_);
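
Because the definition event is now held by shared_ptr, a single recorded event can mark several tensors as defined, which is exactly what the PopulateOutputs change above relies on. A distilled sketch (names are placeholders):

  auto definition_event = std::make_shared<se::Event>(stream->parent());
  CHECK(definition_event->Init());
  stream->ThenRecordEvent(definition_event.get());
  for (XlaTensor* xla_tensor : output_tensors) {
    // Every output shares the same event object.
    xla_tensor->SetDefinedOn(stream, definition_event);
  }
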
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index f42fb92359..1bf8948ef6 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -31,7 +31,6 @@ std::vector<tensorflow::Flag>* flag_objects;
std::once_flag flags_init;
void SetDebugOptionsDefaults(DebugOptions* flags) {
- flags->set_xla_enable_fast_math(true);
flags->set_xla_llvm_enable_alias_scope_metadata(true);
flags->set_xla_llvm_enable_noalias_metadata(true);
flags->set_xla_llvm_enable_invariant_load_metadata(true);
@@ -53,6 +52,11 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
// the heuristics needed to decide when to run on multiple streams. See
// b/77879207.
flags->set_xla_gpu_disable_multi_streaming(true);
+
+ // TODO(jlebar): Disable fastmath once doing so is not a performance
+ // regression.
+ flags->set_xla_cpu_enable_fast_math(true);
+ flags->set_xla_gpu_enable_fast_math(true);
}
// Allocates flag_values and flag_objects; this function must not be called more
@@ -150,10 +154,16 @@ void AllocateFlags() {
flag_values->mutable_xla_generate_hlo_text_to(),
"Dump all HLO modules as text into the provided directory path."),
tensorflow::Flag(
- "xla_enable_fast_math",
- bool_setter_for(&DebugOptions::set_xla_enable_fast_math),
- flag_values->xla_enable_fast_math(),
- "Enable unsafe fast-math optimizations in the compiler; "
+ "xla_cpu_enable_fast_math",
+ bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math),
+ flag_values->xla_cpu_enable_fast_math(),
+ "Enable unsafe fast-math optimizations in the CPU compiler; "
+ "this may produce faster code at the expense of some accuracy."),
+ tensorflow::Flag(
+ "xla_gpu_enable_fast_math",
+          bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_math),
+          flag_values->xla_gpu_enable_fast_math(),
+ "Enable unsafe fast-math optimizations in the GPU compiler; "
"this may produce faster code at the expense of some accuracy."),
tensorflow::Flag(
"xla_llvm_enable_alias_scope_metadata",
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 7d315fa0d3..7331d2b54c 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1234,6 +1234,20 @@ cc_library(
],
)
+cc_library(
+ name = "scatter_expander",
+ srcs = ["scatter_expander.cc"],
+ hdrs = ["scatter_expander.h"],
+ deps = [
+ ":hlo",
+ ":hlo_creation_utils",
+ ":hlo_pass",
+ ":while_util",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ ],
+)
+
tf_cc_test(
name = "batchnorm_expander_test",
size = "small",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 37834e1cc2..f7812d9661 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -1705,6 +1705,10 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
reshape, HloInstruction::CreateReshape(reshape->shape(),
operand->mutable_operand(0)));
}
+ if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
+ *operand->mutable_shape() = reshape->shape();
+ return ReplaceInstruction(reshape, operand);
+ }
if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
auto opt_dims = ReshapeLeavesDimensionsUnmodified(
@@ -2144,6 +2148,11 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
transpose->dimensions())));
}
+ if (operand->opcode() == HloOpcode::kRng && operand->user_count() == 1) {
+ *operand->mutable_shape() = transpose->shape();
+ return ReplaceInstruction(transpose, operand);
+ }
+
if (is_layout_sensitive_ && TransposeIsBitcast(transpose)) {
ReplaceWithBitcast(transpose);
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 862cbeeba6..5837391d75 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1428,6 +1428,37 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
EXPECT_THAT(computation->root_instruction(), op::Reshape(param0));
}
+// Test transforming reshapes and transposes of rng.
+TEST_F(AlgebraicSimplifierTest, ReshapeOfTransposeOfRngToRng) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ HloInstruction* one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* rng0 = builder.AddInstruction(
+ HloInstruction::CreateRng(ShapeUtil::MakeShape(F32, {2, 2}),
+ RandomDistribution::RNG_UNIFORM, {zero, one}));
+
+ HloInstruction* transpose = builder.AddInstruction(
+ HloInstruction::CreateTranspose(rng0->shape(), rng0, {1, 0}));
+ Shape reshape_shape = builder
+ .AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {4}), transpose))
+ ->shape();
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ bitcasting_callback());
+ EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+  // Verify that reshape(transpose(rng)) is replaced by a single rng of the
+ // same shape as the reshape.
+ EXPECT_THAT(computation->root_instruction(), op::Rng());
+ EXPECT_TRUE(ShapeUtil::Equal(computation->root_instruction()->shape(),
+ reshape_shape));
+}
+
// Test transforming reshapes to bitcasts under various conditions.
TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
HloComputation::Builder builder(TestName());
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 118a11c8de..cfd26fc778 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -139,6 +139,7 @@ Status GatherComputationsByAllocationType(
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
+ case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kFusion:
// Map/reduce etc computations are always thread-local.
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index a23427f00c..985ff30e80 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -61,6 +61,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) {
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
+ case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kFusion:
return CallContext::kParallel;
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 36fb9b43aa..3e39c1bab1 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -312,7 +312,7 @@ Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
return Status::OK();
}
-// We add copies for all the indices of the true and false computaiton roots,
+// We add copies for all the indices of the true and false computation roots,
// in order to resolve interference. We later rely on the CopyRemover to drop
// the unnecessary ones.
Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
@@ -648,7 +648,12 @@ class CopyRemover {
// We can only perform copy elision if the resulting merged values have
// totally ordered live ranges; otherwise the merged buffer would have
// live range interference.
- if (IsHead(*dest)) {
+ if (src->next == dest) {
+    // In the process of eliding copies, it's possible for a copy to have the
+ // same source and destination buffer. In this case, the copy can be
+ // safely removed.
+ VLOG(2) << copy->name() << " source and destination buffers are same.";
+ } else if (IsHead(*dest)) {
// The copy copies an arbitrary value in the source buffer (call it s_x)
// and defines d_0, the first value in the destination buffer. After
// merging, the values in the combined buffer must be strictly ordered
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index cd735256b8..892d0d7b54 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -2007,5 +2007,46 @@ ENTRY TestComputation {
InsertCopies(module.get());
}
+TEST_F(CopyInsertionTest, NestedWhiles) {
+  // Verify that no unnecessary copies remain after copy insertion for
+ // trivial nested whiles (b/112472605).
+ const string& hlo_string = R"(
+HloModule TestModule
+
+cond.inner {
+ ROOT param.cond.inner = pred[] parameter(0)
+}
+
+body.inner {
+ param.body.inner = pred[] parameter(0)
+ ROOT neg = pred[] negate(param.body.inner)
+}
+
+cond.outer {
+ ROOT param.cond.outer = pred[] parameter(0)
+}
+
+body.outer {
+ param.cond.outer = pred[] parameter(0)
+ ROOT while = pred[] while(param.cond.outer), condition=cond.inner, body=body.inner
+}
+
+ENTRY TestComputation {
+ entry_param = pred[] parameter(0)
+ ROOT while = pred[] while(entry_param), condition=cond.outer, body=body.outer
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<HloModule> module,
+ HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
+ InsertCopies(module.get());
+
+ // There should only be a single copy inserted, and it's in the entry
+ // computation.
+ EXPECT_EQ(CountCopies(*module), 1);
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::While(op::Copy(op::Parameter())));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 3efe3e2f93..84779c60b0 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -20,7 +20,7 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS")
load(
"//third_party/mkl:build_defs.bzl",
- "if_mkl",
+ "mkl_deps",
)
# Filegroup used to collect source files for dependency checking.
@@ -86,6 +86,7 @@ cc_library(
":parallel_task_assignment",
":simple_orc_jit",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
+ "//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",
@@ -497,10 +498,7 @@ cc_library(
"//tensorflow/core:framework_lite",
"//tensorflow/core/kernels:eigen_helpers",
"//third_party/eigen3",
- ] + if_mkl([
- "@mkl_dnn",
- "//third_party/mkl:intel_binary_blob",
- ]),
+ ] + mkl_deps(),
)
cc_library(
@@ -554,10 +552,7 @@ cc_library(
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
- ] + if_mkl([
- "//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ]),
+ ] + mkl_deps(),
)
cc_library(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 2df959c4dc..35154af048 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -88,6 +88,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
+#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
@@ -299,6 +300,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
pipeline.AddPass<CpuInstructionFusion>();
+ pipeline.AddPass<ScatterExpander>();
+
ReducePrecisionInsertion::AddPasses(
&pipeline, module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
@@ -356,7 +359,7 @@ llvm::TargetOptions CompilerTargetOptions(
llvm::TargetOptions target_options;
llvm_ir::SetTargetOptions(
/*fast_math_enabled=*/module_config.debug_options()
- .xla_enable_fast_math(),
+ .xla_cpu_enable_fast_math(),
&target_options);
return target_options;
}
@@ -523,7 +526,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
CompilerTargetOptions(module->config()),
CodeGenOptLevel(module->config()),
options::OptimizeForSizeRequested(module->config()),
- module->config().debug_options().xla_enable_fast_math(),
+ module->config().debug_options().xla_cpu_enable_fast_math(),
module->config().debug_options().xla_llvm_disable_expensive_passes(),
pre_optimization_ir_hook, post_optimization_ir_hook);
llvm_module->setDataLayout(jit->data_layout());
@@ -653,9 +656,9 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
// so we bail if the configs have conflicting flags. At the moment, the only
// flag that needs to be consistent is fast-math.
const bool fast_math_enabled =
- modules[0]->config().debug_options().xla_enable_fast_math();
+ modules[0]->config().debug_options().xla_cpu_enable_fast_math();
for (const auto& module : modules) {
- if (module->config().debug_options().xla_enable_fast_math() !=
+ if (module->config().debug_options().xla_cpu_enable_fast_math() !=
fast_math_enabled) {
return InvalidArgument(
"All HLO module configs must have the same value for "
@@ -832,7 +835,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
CompilerFunctor compiler_functor(
target_machine.get(), &disassembler, opt_level,
options::OptimizeForSizeRequested(module->config()),
- module->config().debug_options().xla_enable_fast_math(),
+ module->config().debug_options().xla_cpu_enable_fast_math(),
module->config().debug_options().xla_llvm_disable_expensive_passes(),
pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook);
std::unique_ptr<llvm::MemoryBuffer> object_file =
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 946f5124b8..c376864c3e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -249,24 +249,11 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) {
- if (GetRootPointsToSet().IsAmbiguous()) {
- return Unimplemented("Points-to set of root instruction is ambiguous");
- }
-
- se::Stream* stream = run_options->stream();
- DeviceMemoryAllocator* memory_allocator = run_options->allocator();
-
- std::vector<OwningDeviceMemory> owning_buffers;
- std::vector<se::DeviceMemoryBase> unowning_buffers;
TF_ASSIGN_OR_RETURN(
- std::tie(unowning_buffers, owning_buffers),
- CreateTempArray(memory_allocator, stream->parent()->device_ordinal(),
- arguments));
-
- TF_RETURN_IF_ERROR(ExecuteComputeFunction(
- &run_options->run_options(), unowning_buffers, hlo_execution_profile));
-
- return CreateResultShapedBuffer(run_options, &owning_buffers);
+ auto result,
+ ExecuteAsyncOnStreamImpl(run_options, arguments, hlo_execution_profile));
+ TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
+ return std::move(result);
}
StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
@@ -277,6 +264,16 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
"Asynchronous execution on stream with hlo profiling is not yet "
"supported on CPU.");
}
+ return ExecuteAsyncOnStreamImpl(run_options, arguments, nullptr);
+}
+
+StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStreamImpl(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ if (GetRootPointsToSet().IsAmbiguous()) {
+ return Unimplemented("Points-to set of root instruction is ambiguous");
+ }
auto* host_stream = dynamic_cast<se::host::HostStream*>(
run_options->stream()->implementation());
@@ -310,19 +307,20 @@ StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
ServiceExecutableRunOptions run_options;
std::vector<se::DeviceMemoryBase> unowning_buffers;
std::shared_ptr<std::vector<OwningDeviceMemory>> buffers;
+ HloExecutionProfile* hlo_execution_profile;
void operator()() {
// Failing a CHECK here is not great, but I don't see an obvious way to
// return a failed Status asynchronously.
TF_CHECK_OK(executable->ExecuteComputeFunction(
- &run_options.run_options(), unowning_buffers,
- /*hlo_execution_profile=*/nullptr));
+ &run_options.run_options(), unowning_buffers, hlo_execution_profile));
}
};
host_stream->EnqueueTask(
AsyncRunTask{this, *run_options, std::move(unowning_buffers),
std::make_shared<std::vector<OwningDeviceMemory>>(
- std::move(owning_buffers))});
+ std::move(owning_buffers)),
+ hlo_execution_profile});
return std::move(result);
}
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 8af8a5dfec..96e53de57e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -85,6 +85,16 @@ class CpuExecutable : public Executable {
const BufferAssignment& buffer_assignment() const { return *assignment_; }
private:
+ // This is for sharing the code between ExecuteOnStream and
+ // ExecuteAsyncOnStream.
+ //
+ // Notice that it's tricky to use correctly, as the profile object (when it
+  // exists) must outlive the task.
+ StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStreamImpl(
+ const ServiceExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile);
+
// Creates an array suitable for passing as the "temps" argument to the JIT
// compiled function pointer.
//
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 645888de78..f2ac742b6e 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -1066,7 +1066,7 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled(
<< config.GetCacheKey();
const bool enable_fast_math =
- hlo_module_config_.debug_options().xla_enable_fast_math();
+ hlo_module_config_.debug_options().xla_cpu_enable_fast_math();
const bool optimize_for_size =
options::OptimizeForSizeRequested(hlo_module_config_);
@@ -1149,7 +1149,7 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
const bool enable_fast_math =
- hlo_module_config_.debug_options().xla_enable_fast_math();
+ hlo_module_config_.debug_options().xla_cpu_enable_fast_math();
const bool optimize_for_size =
options::OptimizeForSizeRequested(hlo_module_config_);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 09909b62ba..6f433b4f30 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -99,7 +99,7 @@ IrEmitter::IrEmitter(
target_machine_features_(*target_machine_features) {
b_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config_.debug_options()
- .xla_enable_fast_math()));
+ .xla_cpu_enable_fast_math()));
}
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
@@ -158,11 +158,11 @@ void IrEmitter::InitializeIrFunction(const string& function_name) {
is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
: llvm::GlobalValue::InternalLinkage;
// Create and initialize new IrFunction.
- compute_function_.reset(
- new IrFunction(function_name, linkage,
- options::OptimizeForSizeRequested(hlo_module_config_),
- hlo_module_config_.debug_options().xla_enable_fast_math(),
- module_, &b_, num_dynamic_loop_bounds_));
+ compute_function_.reset(new IrFunction(
+ function_name, linkage,
+ options::OptimizeForSizeRequested(hlo_module_config_),
+ hlo_module_config_.debug_options().xla_cpu_enable_fast_math(), module_,
+ &b_, num_dynamic_loop_bounds_));
}
IrEmitter::~IrEmitter() {}
@@ -577,7 +577,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*reduce_window,
/*operands=*/{reduce_window->operand(0)},
- /*supported_types=*/{F32, BF16, S32}));
+ /*supported_types=*/{F32, BF16, S32, F16}));
// TODO(b/31410564): Implement dilation for reduce-window.
if (window_util::HasDilation(reduce_window->window())) {
diff --git a/tensorflow/compiler/xla/service/despecializer.cc b/tensorflow/compiler/xla/service/despecializer.cc
index d938f3a2c4..48e4471499 100644
--- a/tensorflow/compiler/xla/service/despecializer.cc
+++ b/tensorflow/compiler/xla/service/despecializer.cc
@@ -21,8 +21,33 @@ limitations under the License.
namespace xla {
+namespace {
+
+// Pass which strips control dependencies from all instructions in the module.
+class ControlDepRemover : public HloPassInterface {
+ public:
+ ControlDepRemover() = default;
+ tensorflow::StringPiece name() const override {
+ return "control-dep-remover";
+ }
+
+ StatusOr<bool> Run(HloModule* module) override {
+ bool changed = false;
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ changed = changed || !instruction->control_predecessors().empty();
+ TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
+ }
+ }
+ return changed;
+ }
+};
+
+} // namespace
+
Despecializer::Despecializer() : pipeline_("despecializer") {
// TODO(b/70588125): Also deal with window reversal in a fast way.
+ pipeline_.AddPass<ControlDepRemover>();
pipeline_.AddPass<Defuser>();
pipeline_.AddPass<ImplicitBroadcastRemover>();
pipeline_.AddPass<BFloat16MixedPrecisionRemoval>();
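
A minimal usage sketch of the despecializer with the new control-dependency stripping at the front of the pipeline (assuming `module` is a std::unique_ptr<HloModule>; illustrative, not code from the patch):

  Despecializer despecializer;
  TF_ASSIGN_OR_RETURN(bool changed, despecializer.Run(module.get()));
  VLOG(1) << "Despecializer changed module: " << changed;
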
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index a3f6e8d989..19575c7905 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1,6 +1,7 @@
# Description:
# GPU-specific components in XLA service implementation.
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
licenses(["notice"]) # Apache 2.0
@@ -365,6 +366,7 @@ cc_library(
":gpu_executable",
":ir_emission_utils",
"//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
@@ -652,6 +654,7 @@ cc_library(
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
+ "//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/compiler/xla/service:while_loop_constant_sinking",
@@ -852,3 +855,35 @@ tf_cc_test(
"//tensorflow/core:test",
],
)
+
+cc_library(
+ name = "buffer_comparator",
+ srcs = ["buffer_comparator.cc"],
+ hdrs = ["buffer_comparator.h"],
+ deps = [
+ ":gpu_executable",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/service:hlo_runner",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+xla_test(
+ name = "buffer_comparator_test",
+ srcs = ["buffer_comparator_test.cc"],
+ backends = [
+ "cpu",
+ "gpu",
+ ],
+ deps = [
+ ":buffer_comparator",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
new file mode 100644
index 0000000000..6a285a6b98
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
@@ -0,0 +1,205 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
+
+#include <cmath>
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace xla {
+namespace gpu {
+
+static constexpr float kTolerance = 0.1f;
+
+static string GetCompHloText(size_t num_elements) {
+  // The comparison routine is expressed as HLO text, which is more readable
+  // than building the computation programmatically.
+ static constexpr char kF16CompHloText[] = R"(
+HloModule CompareF16
+
+MaxF32 {
+ %lhs = f32[] parameter(0)
+ %rhs = f32[] parameter(1)
+ ROOT %max = f32[] maximum(%lhs, %rhs)
+}
+
+Canonicalize (aparam: f16[SIZE]) -> f32[SIZE] {
+ %min_constant = f32[] constant(-65505)
+ %max_constant = f32[] constant(65505)
+ %large_constant = f32[] constant(1048576)
+ %min_values = f32[SIZE] broadcast(%min_constant), dimensions={}
+ %max_values = f32[SIZE] broadcast(%max_constant), dimensions={}
+ %large_values = f32[SIZE] broadcast(%large_constant), dimensions={}
+
+ %a = f16[SIZE] parameter(0)
+ %converted = f32[SIZE] convert(%a)
+ %clamped = f32[SIZE] clamp(%min_values, %converted, %max_values)
+
+ // Since the clamp() above already took care of infs, only NaNs will cause
+ // is-finite() to return false.
+ %is_finite = pred[SIZE] is-finite(%clamped)
+ ROOT %result = f32[SIZE] select(%is_finite, %clamped, %large_values)
+}
+
+ENTRY MaxDifference {
+ %one_constant = f32[] constant(1.0)
+ %zero_constant = f32[] constant(0.0)
+
+ %ones = f32[SIZE] broadcast(%one_constant), dimensions={}
+
+ %lhs = f16[SIZE] parameter(0)
+ %rhs = f16[SIZE] parameter(1)
+ %lhs_canonical = f32[SIZE] call(%lhs), to_apply=Canonicalize
+ %rhs_canonical = f32[SIZE] call(%rhs), to_apply=Canonicalize
+ %sub = f32[SIZE] subtract(%lhs_canonical, %rhs_canonical)
+ %sub_abs = f32[SIZE] abs(%sub)
+ %lhs_abs = f32[SIZE] abs(%lhs_canonical)
+ %rhs_abs = f32[SIZE] abs(%rhs_canonical)
+ %max = f32[SIZE] maximum(%lhs_abs, %rhs_abs)
+ %denominator = f32[SIZE] add(%max, %ones)
+ %error = f32[SIZE] divide(%sub_abs, %denominator)
+ ROOT %max_diff = f32[] reduce(%error, %zero_constant), dimensions={0}, to_apply=MaxF32
+})";
+ auto size_string = std::to_string(num_elements);
+ return tensorflow::str_util::StringReplace(
+ kF16CompHloText, "SIZE", {size_string.data(), size_string.size()}, true);
+}
+
+StatusOr<F16BufferComparator> F16BufferComparator::Create(
+ se::DeviceMemory<Eigen::half> ref_buffer, Compiler* compiler,
+ DeviceMemoryAllocator* allocator, se::Stream* stream) {
+ auto stream_exec = stream->parent();
+ int64 num_elements = ref_buffer.ElementCount();
+
+ // One may consider using hlo_runner to do all the compilation and execution.
+  // However, hlo_runner currently doesn't support injecting a Compiler*, a
+  // Stream*, or even the allocator. We may revisit this in the future if it
+ // proves to be a maintenance burden.
+ TF_ASSIGN_OR_RETURN(
+ auto exec, ([&]() -> StatusOr<std::unique_ptr<Executable>> {
+ HloModuleConfig config;
+ DebugOptions debug_options;
+ debug_options.set_xla_backend_optimization_level(2);
+ config.set_debug_options(debug_options);
+ TF_ASSIGN_OR_RETURN(
+ auto module, ParseHloString(GetCompHloText(num_elements), config));
+ TF_ASSIGN_OR_RETURN(
+ module,
+ compiler->RunHloPasses(std::move(module), stream_exec, nullptr));
+ return compiler->RunBackend(std::move(module), stream_exec, nullptr);
+ }()));
+
+ TF_ASSIGN_OR_RETURN(
+ auto shaped_buffer, ([&]() -> StatusOr<ScopedShapedBuffer> {
+ auto device_ordinal = stream_exec->device_ordinal();
+ TF_ASSIGN_OR_RETURN(
+ auto owning_buffer,
+ allocator->Allocate(device_ordinal, ref_buffer.size()));
+ se::DeviceMemory<Eigen::half> buffer(
+ owning_buffer.AsDeviceMemoryBase());
+ stream->ThenMemcpy(&buffer, ref_buffer, ref_buffer.size());
+ Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements});
+ ScopedShapedBuffer ret(shape, shape, allocator, device_ordinal);
+ ret.set_buffer(std::move(owning_buffer), {});
+ return std::move(ret);
+ }()));
+
+ return F16BufferComparator(stream, allocator, std::move(exec),
+ std::move(shaped_buffer));
+}
+
+StatusOr<bool> F16BufferComparator::CompareEqualImpl(
+ se::DeviceMemory<Eigen::half> test_buffer) {
+ if (ref_buffer_.root_buffer().size() != test_buffer.size()) {
+ return InternalError("Mismatched buffer size: %lld vs %lld",
+ ref_buffer_.root_buffer().size(), test_buffer.size());
+ }
+
+ int64 num_elements = test_buffer.ElementCount();
+
+ TF_ASSIGN_OR_RETURN(
+ auto result_buffer, ([&]() -> StatusOr<ScopedShapedBuffer> {
+ auto stream_exec = stream_->parent();
+ Shape shape = ShapeUtil::MakeShape(xla::F16, {num_elements});
+ auto device_ordinal = stream_exec->device_ordinal();
+ ShapedBuffer shaped_test_buffer(shape, shape, stream_exec->platform(),
+ device_ordinal);
+ shaped_test_buffer.set_buffer(test_buffer, {});
+ ExecutableRunOptions run_options;
+ run_options.set_device_ordinal(stream_exec->device_ordinal());
+ run_options.set_stream(stream_);
+ run_options.set_allocator(allocator_);
+ ServiceExecutableRunOptions service_run_options(run_options);
+ return exec_->ExecuteOnStream(
+ &service_run_options, {&ref_buffer_, &shaped_test_buffer}, nullptr);
+ }()));
+
+ float result;
+ CHECK(result_buffer.root_buffer().size() == sizeof(result));
+ stream_->ThenMemcpy(&result, result_buffer.root_buffer(), sizeof(result));
+ TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone());
+ return result < kTolerance;
+}
+
+StatusOr<bool> F16BufferComparator::CompareEqual(
+ se::DeviceMemory<Eigen::half> test_buffer) {
+ TF_ASSIGN_OR_RETURN(auto result, CompareEqualImpl(test_buffer));
+ if (result) {
+ return true;
+ }
+  // Host-side code that does the same thing, but reports some of the
+ // differences as well.
+ int64 n = test_buffer.ElementCount();
+ std::vector<half> host_ref_buffer(n), host_test_buffer(n);
+ stream_->ThenMemcpy(host_ref_buffer.data(), ref_buffer_.root_buffer(),
+ ref_buffer_.root_buffer().size());
+ stream_->ThenMemcpy(host_test_buffer.data(), test_buffer, test_buffer.size());
+ TF_RETURN_IF_ERROR(stream_->BlockHostUntilDone());
+
+ const auto canonicalize = [](float a) -> float {
+    constexpr float kBigNumber = 1048576.;
+ constexpr float kMaxFp16Value = 65504.;
+ if (std::isnan(a)) {
+      return kBigNumber;
+ }
+ if (std::isinf(a)) {
+ if (a < 0) {
+ return -(kMaxFp16Value + 1);
+ }
+ return kMaxFp16Value + 1;
+ }
+ return a;
+ };
+ int differences_seen = 0;
+ for (int64 i = 0; i < n && differences_seen < 10; i++) {
+ float original_ref = static_cast<float>(host_ref_buffer[i]);
+ float original_test = static_cast<float>(host_test_buffer[i]);
+ float ref = canonicalize(original_ref);
+ float test = canonicalize(original_test);
+ if (!(std::abs(ref - test) / (std::max(std::abs(ref), std::abs(test)) + 1) <
+ kTolerance)) {
+ differences_seen++;
+ LOG(ERROR) << "Difference at " << i << ": " << original_ref << " vs "
+ << original_test;
+ }
+ }
+
+ return false;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.h b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h
new file mode 100644
index 0000000000..bf2ba78cea
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.h
@@ -0,0 +1,71 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_
+
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// An fp16 comparator that internally keeps a reference buffer and compares it
+// against other test buffers.
+class F16BufferComparator {
+ public:
+ F16BufferComparator(const F16BufferComparator&) = delete;
+ F16BufferComparator(F16BufferComparator&&) = default;
+
+ // Creates a new comparator. It internally allocates a buffer initialized by
+ // ref_buffer.
+ static StatusOr<F16BufferComparator> Create(
+ se::DeviceMemory<Eigen::half> ref_buffer, Compiler* compiler,
+ DeviceMemoryAllocator* allocator, se::Stream* stream);
+
+ // Returns true if the internally allocated buffer "compares equal" to
+ // test_buffer. The definition of "equal" is:
+  // * All NaNs compare equal.
+ // * All infs are treated as 65505 or -65505, so that this checker is tolerant
+ // to fp16 overflows.
+ // * With NaNs and infs taken care of, a and b compare equal iff:
+ // abs(a - b) / (max(abs(a), abs(b)) + 1) < tolerance
+ //
+ // See the implementation for the tolerance value.
+ StatusOr<bool> CompareEqual(se::DeviceMemory<Eigen::half> test_buffer);
+
+ private:
+ F16BufferComparator(se::Stream* stream, DeviceMemoryAllocator* allocator,
+ std::unique_ptr<Executable> exec,
+ ScopedShapedBuffer ref_buffer)
+ : stream_(stream),
+ allocator_(allocator),
+ exec_(std::move(exec)),
+ ref_buffer_(std::move(ref_buffer)) {}
+
+ StatusOr<bool> CompareEqualImpl(se::DeviceMemory<Eigen::half> test_buffer);
+
+ se::Stream* stream_;
+ DeviceMemoryAllocator* allocator_;
+ std::unique_ptr<Executable> exec_;
+ ScopedShapedBuffer ref_buffer_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_COMPARATOR_H_
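A minimal host-side sketch of the comparison rule described in the header comment above. The tolerance used here is illustrative only (the real constant, kTolerance, is defined in buffer_comparator.cc and is not shown in this hunk); the rest mirrors the canonicalize-then-relative-error logic of F16BufferComparator::CompareEqual.

    #include <algorithm>
    #include <cmath>

    // Illustrative tolerance; see buffer_comparator.cc for the actual value.
    constexpr float kIllustrativeTolerance = 0.1f;

    float Canonicalize(float a) {
      constexpr float kBigNumber = 1048576.f;   // all NaNs collapse to one value
      constexpr float kMaxFp16Value = 65504.f;  // infs clamp to just past the fp16 range
      if (std::isnan(a)) return kBigNumber;
      if (std::isinf(a)) return a < 0 ? -(kMaxFp16Value + 1) : kMaxFp16Value + 1;
      return a;
    }

    bool ElementsCompareEqual(float ref, float test) {
      float a = Canonicalize(ref);
      float b = Canonicalize(test);
      return std::abs(a - b) / (std::max(std::abs(a), std::abs(b)) + 1) <
             kIllustrativeTolerance;
    }

Under this rule, 9 and 10 compare equal whenever the tolerance exceeds |9 - 10| / (10 + 1), which is about 0.09, while 0 and 1 stay unequal for any tolerance below 0.5; this is consistent with the expectations in buffer_comparator_test.cc below.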
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc
new file mode 100644
index 0000000000..33761d1bd8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator_test.cc
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
+
+#include <limits>
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+class BufferComparatorTest : public testing::Test {
+ protected:
+ BufferComparatorTest()
+ : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()),
+ stream_exec_(backend_->default_stream_executor()),
+ allocator_(stream_exec_->platform(), {stream_exec_}),
+ compiler_(Compiler::GetForPlatform(stream_exec_->platform())
+ .ConsumeValueOrDie()) {}
+
+  // Takes floats only for convenience; still uses half internally.
+ bool CompareEqualFloatBuffers(const std::vector<float>& lhs_float,
+ const std::vector<float>& rhs_float) {
+ std::vector<half> lhs(lhs_float.begin(), lhs_float.end());
+ std::vector<half> rhs(rhs_float.begin(), rhs_float.end());
+ se::Stream stream(stream_exec_);
+ stream.Init();
+
+ auto owning_lhs_buffer =
+ allocator_
+ .Allocate(stream_exec_->device_ordinal(), lhs.size() * sizeof(half))
+ .ConsumeValueOrDie();
+
+ auto owning_rhs_buffer =
+ allocator_
+ .Allocate(stream_exec_->device_ordinal(), rhs.size() * sizeof(half))
+ .ConsumeValueOrDie();
+
+ auto lhs_buffer =
+ se::DeviceMemory<Eigen::half>(owning_lhs_buffer.AsDeviceMemoryBase());
+ auto rhs_buffer =
+ se::DeviceMemory<Eigen::half>(owning_rhs_buffer.AsDeviceMemoryBase());
+
+ stream.ThenMemcpy(&lhs_buffer, lhs.data(), lhs_buffer.size());
+ stream.ThenMemcpy(&rhs_buffer, rhs.data(), rhs_buffer.size());
+
+ TF_CHECK_OK(stream.BlockHostUntilDone());
+
+ return F16BufferComparator::Create(lhs_buffer, compiler_, &allocator_,
+ &stream)
+ .ConsumeValueOrDie()
+ .CompareEqual(rhs_buffer)
+ .ConsumeValueOrDie();
+ }
+
+ std::unique_ptr<Backend> backend_;
+ se::StreamExecutor* stream_exec_;
+ StreamExecutorMemoryAllocator allocator_;
+ Compiler* compiler_;
+};
+
+TEST_F(BufferComparatorTest, TestNaNs) {
+ EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("")}));
+ // NaN values with different bit patterns should compare equal.
+ EXPECT_TRUE(CompareEqualFloatBuffers({std::nanf("")}, {std::nanf("1234")}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({std::nanf("")}, {1.}));
+}
+
+TEST_F(BufferComparatorTest, TestInfs) {
+ const auto inf = std::numeric_limits<float>::infinity();
+ EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {std::nanf("")}));
+ EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {inf}));
+ EXPECT_TRUE(CompareEqualFloatBuffers({inf}, {65504}));
+ EXPECT_TRUE(CompareEqualFloatBuffers({-inf}, {-65504}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-65504}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {65504}));
+
+ EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {20}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({inf}, {-20}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {20}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({-inf}, {-20}));
+}
+
+TEST_F(BufferComparatorTest, TestNumbers) {
+ EXPECT_TRUE(CompareEqualFloatBuffers({20}, {20.1}));
+ EXPECT_FALSE(CompareEqualFloatBuffers({0}, {1}));
+ EXPECT_TRUE(CompareEqualFloatBuffers({0.9}, {1}));
+ EXPECT_TRUE(CompareEqualFloatBuffers({9}, {10}));
+ EXPECT_TRUE(CompareEqualFloatBuffers({10}, {9}));
+}
+
+TEST_F(BufferComparatorTest, TestMultiple) {
+ EXPECT_TRUE(CompareEqualFloatBuffers({20, 30, 40, 50, 60},
+ {20.1, 30.1, 40.1, 50.1, 60.1}));
+ std::vector<float> lhs(200);
+ std::vector<float> rhs(200);
+ for (int i = 0; i < 200; i++) {
+ EXPECT_TRUE(CompareEqualFloatBuffers(lhs, rhs))
+ << "should be the same at index " << i;
+ lhs[i] = 3;
+ rhs[i] = 5;
+ EXPECT_FALSE(CompareEqualFloatBuffers(lhs, rhs))
+        << "should be different at index " << i;
+ lhs[i] = 0;
+ rhs[i] = 0;
+ }
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 7348307ec8..7d93bdfc8b 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -30,7 +30,6 @@ namespace {
using se::DeviceMemoryBase;
using se::dnn::AlgorithmConfig;
using se::dnn::AlgorithmDesc;
-using tensorflow::gtl::nullopt;
using tensorflow::gtl::optional;
class ScratchAllocator : public se::ScratchAllocator {
@@ -173,7 +172,7 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
// cache misses and doing extra work. Overall, caching doesn't seem worth the
// trouble, but we may want to revisit this if we ever find a model where
// caching would speed up compilation a lot.
-optional<std::tuple<int64, bool, int64>>
+StatusOr<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
@@ -206,45 +205,25 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
// Allocate space for the input, filter, and output of the convolution. We
// use a ScratchAllocator for this instead of calling allocator_ directly so
// that our allocations don't leak.
- //
- // We don't put any data in these buffers, because (in theory, anyway) the
- // speed of a conv isn't affected by the data being convolved.
ScratchAllocator input_output_allocator(device_ordinal, allocator);
- StatusOr<DeviceMemoryBase> maybe_input_buf =
- input_output_allocator.AllocateBytes(&stream,
- ShapeUtil::ByteSizeOf(input_shape));
- StatusOr<DeviceMemoryBase> maybe_filter_buf =
- input_output_allocator.AllocateBytes(&stream,
- ShapeUtil::ByteSizeOf(filter_shape));
- StatusOr<DeviceMemoryBase> maybe_output_buf =
- input_output_allocator.AllocateBytes(&stream,
- ShapeUtil::ByteSizeOf(output_shape));
- if (!maybe_input_buf.ok() || !maybe_filter_buf.ok() ||
- !maybe_output_buf.ok()) {
- LOG(WARNING)
- << "Couldn't allocate space for input/filter/output of convolution "
- << instr->ToString() << ". Falling back to default algorithm.";
- return nullopt;
- }
-
- DeviceMemoryBase input_buf = maybe_input_buf.ValueOrDie();
- DeviceMemoryBase filter_buf = maybe_filter_buf.ValueOrDie();
- DeviceMemoryBase output_buf = maybe_output_buf.ValueOrDie();
+ TF_ASSIGN_OR_RETURN(DeviceMemoryBase input_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(input_shape)));
+ TF_ASSIGN_OR_RETURN(DeviceMemoryBase filter_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(filter_shape)));
+ TF_ASSIGN_OR_RETURN(DeviceMemoryBase output_buf,
+ input_output_allocator.AllocateBytes(
+ &stream, ShapeUtil::ByteSizeOf(output_shape)));
// Although we don't have evidence this matters, zero out the buffers before
// autotuning. It's conceivable that using uninitialized memory as the inputs
// might affect performance if e.g. the inputs contain denormals, and this is
// easy enough.
- if (!stream.ThenMemZero(&input_buf, input_buf.size())
- .ThenMemZero(&filter_buf, filter_buf.size())
- .ThenMemZero(&output_buf, output_buf.size())
- .BlockHostUntilDone()
- .ok()) {
- LOG(WARNING)
- << "Couldn't zero out input/filter/output buffer for convolution "
- << instr->ToString() << ". Falling back to default algorithm.";
- return nullopt;
- }
+ TF_RETURN_IF_ERROR(stream.ThenMemZero(&input_buf, input_buf.size())
+ .ThenMemZero(&filter_buf, filter_buf.size())
+ .ThenMemZero(&output_buf, output_buf.size())
+ .BlockHostUntilDone());
const bool use_winograd_nonfused = ShouldIncludeWinogradNonfusedAlgo(
input_shape, output_shape, dnums, stream_exec_);
@@ -292,9 +271,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
best_result_bytes_used);
}
- LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString()
- << " failed. Falling back to default algorithm.";
- return nullopt;
+ return InternalError(
+ "All algorithms tried for convolution %s failed. Falling back to "
+ "default algorithm.",
+ instr->ToString().c_str());
}
StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
@@ -305,12 +285,13 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
const auto& lhs_shape = instr->operand(0)->shape();
const auto& rhs_shape = instr->operand(1)->shape();
const auto& conv_result_shape = instr->shape().tuple_shapes(0);
- optional<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
+ StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
if (call_target == kCudnnConvForwardCallTarget) {
- alg_scratch_and_tc = PickBestAlgorithm(
- CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
- /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape,
- instr->window(), instr->convolution_dimension_numbers(), instr);
+ alg_scratch_and_tc =
+ PickBestAlgorithm(CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
+ /*filter_shape=*/rhs_shape,
+ /*output_shape=*/conv_result_shape, instr->window(),
+ instr->convolution_dimension_numbers(), instr);
} else if (call_target == kCudnnConvBackwardInputCallTarget) {
alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
@@ -326,7 +307,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
<< instr->ToString();
}
- if (!alg_scratch_and_tc.has_value()) {
+ if (!alg_scratch_and_tc.ok()) {
+ LOG(ERROR) << alg_scratch_and_tc.status();
return false;
}
@@ -334,7 +316,8 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
bool tensor_ops_enabled;
int64 scratch_bytes;
- std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc;
+ std::tie(algorithm, tensor_ops_enabled, scratch_bytes) =
+ alg_scratch_and_tc.ConsumeValueOrDie();
VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and "
<< NumBytesToString(scratch_bytes)
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index bc5d1ce94a..8b7749628a 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
+#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -34,8 +35,9 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
// memory while timing the various convolution algorithms. If it's null,
// we'll use the default allocator on the StreamExecutor.
CudnnConvolutionAlgorithmPicker(se::StreamExecutor* stream_exec,
- DeviceMemoryAllocator* allocator)
- : stream_exec_(stream_exec), allocator_(allocator) {}
+ DeviceMemoryAllocator* allocator,
+ Compiler* compiler)
+ : stream_exec_(stream_exec), allocator_(allocator), compiler_(compiler) {}
tensorflow::StringPiece name() const override {
return "cudnn-convolution-algorithm-picker";
@@ -46,13 +48,14 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
private:
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
- tensorflow::gtl::optional<std::tuple<int64, bool, int64>> PickBestAlgorithm(
+ StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
const ConvolutionDimensionNumbers& dnums, HloInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
+ Compiler* compiler_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 69ba91793d..9b6de115ad 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -210,11 +210,13 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
return make_sqrt();
}
- if (hlo_module_config_.debug_options().xla_enable_fast_math() &&
- IsFPLiteralWithValue(rhs, -.5)) {
+ if (IsFPLiteralWithValue(rhs, -.5)) {
VLOG(10) << "emitting pow(A, -.5) as 1/sqrt(A): " << op->ToString();
// LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX
// rsqrt.approx instruction.
+ //
+ // TODO(jlebar): Does this happen with fastmath disabled? If not, should
+ // we force-enable it?
TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt());
return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt);
}
@@ -274,16 +276,18 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
PrimitiveType prim_type, llvm::Value* value) const {
- // If we don't care much about precision, emit a fast approximation of
- // tanh.
- if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
- // Upcast F16 to F32 if necessary.
- llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
- llvm::Value* input = b_->CreateFPCast(value, type);
- llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
- return b_->CreateFPCast(fast_tanh, value->getType());
- }
- return EmitLibdeviceMathCall("__nv_tanh", {value}, {prim_type}, prim_type);
+ // Emit a fast approximation of tanh instead of calling __nv_tanh.
+ // __nv_tanh is particularly bad because it contains branches, thus
+ // preventing LLVM's load-store vectorizer from working its magic across a
+ // function which contains tanh calls.
+ //
+ // This routine isn't numerically precise, but it's good enough for ML.
+
+ // Upcast F16 to F32 if necessary.
+ llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
+ llvm::Value* input = b_->CreateFPCast(value, type);
+ llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
+ return b_->CreateFPCast(fast_tanh, value->getType());
}
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
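To make the trade-off in the comment above concrete: a fast tanh approximation replaces the branchy libdevice call with a short run of straight-line arithmetic. The sketch below uses a generic [3/2] Pade approximant clamped to [-1, 1]; it is not the polynomial that llvm_ir::EmitFastTanh actually emits, only an illustration of the accuracy-versus-branching compromise.

    #include <algorithm>

    float FastTanhSketch(float x) {
      // Pade [3/2] approximant of tanh; within a few percent on [-3, 3].
      float approx = x * (15.0f + x * x) / (15.0f + 6.0f * x * x);
      // Clamp so the approximation saturates like tanh for large |x|.
      return std::max(-1.0f, std::min(1.0f, approx));
    }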
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 66aeb4efef..6675dbd3f9 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -64,7 +64,7 @@ IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
hlo_module_config_(hlo_module_config) {
b_.setFastMathFlags(llvm_ir::GetFastMathFlags(
/*fast_math_enabled=*/hlo_module_config.debug_options()
- .xla_enable_fast_math()));
+ .xla_gpu_enable_fast_math()));
}
Status IrEmitter::DefaultAction(HloInstruction* hlo) {
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
index cf44458a2e..ff4ae1f9ef 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc
@@ -180,7 +180,7 @@ std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
TargetOptions target_options = InitTargetOptionsFromCodeGenFlags();
llvm_ir::SetTargetOptions(
/*fast_math_enabled=*/hlo_module_config.debug_options()
- .xla_enable_fast_math(),
+ .xla_gpu_enable_fast_math(),
&target_options);
// Enable FMA synthesis.
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 76c9b6ab33..d937123357 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -72,6 +72,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
+#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
@@ -130,8 +131,12 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) {
}
// Runs optimization passes on the given HLO module.
+//
+// It takes a compiler pointer, as passes may compile and execute HLOs on the
+// fly for cuDNN verification or other purposes.
Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
- DeviceMemoryAllocator* device_allocator) {
+ DeviceMemoryAllocator* device_allocator,
+ Compiler* compiler) {
{
HloPassPipeline pipeline("optimization");
pipeline.AddInvariantChecker<HloVerifier>();
@@ -167,6 +172,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// elimination has to come after that pass.
pipeline.AddPass<ZeroSizedHloElimination>();
+ pipeline.AddPass<ScatterExpander>();
+
pass.AddPass<AlgebraicSimplifier>(
/*is_layout_sensitive=*/false,
[](const Shape&, const Shape&) { return false; });
@@ -245,8 +252,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// the gte(customcall, 0) would probably already be into a fusion node. We
// can't simplify across HloComputation boundaries, so in this case we
// wouldn't be able to simplify away the new_tuple bits.
- pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(stream_exec,
- device_allocator);
+ pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(
+ stream_exec, device_allocator, compiler);
// Clean up new_tuple described above.
pipeline.AddPass<TupleSimplifier>();
@@ -492,11 +499,15 @@ NVPTXCompiler::NVPTXCompiler()
StatusOr<std::unique_ptr<HloModule>> NVPTXCompiler::RunHloPasses(
std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
DeviceMemoryAllocator* device_allocator) {
+ // We dump the post-optimization HLO in RunBackend so no need to dump it here.
+ VLOG(2) << "*** HLO Before Optimization";
+ XLA_VLOG_LINES(2, module->ToString());
+
XLA_SCOPED_LOGGING_TIMER("NVPTXCompiler::RunHloPasses");
tracing::ScopedActivity activity("HLO Transforms", module->name(),
/*is_expensive=*/true);
TF_RETURN_IF_ERROR(
- OptimizeHloModule(module.get(), stream_exec, device_allocator));
+ OptimizeHloModule(module.get(), stream_exec, device_allocator, this));
return std::move(module);
}
@@ -548,6 +559,7 @@ StatusOr<std::unique_ptr<Executable>> NVPTXCompiler::RunBackend(
// include headers, so no need for us to print them ourselves.
XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString());
XLA_VLOG_LINES(2, buffer_assignment->ToString());
+ VLOG(2) << "*** HLO After Optimization";
XLA_VLOG_LINES(2, module->ToString());
const string xla_dump_optimized_hlo_proto_to =
module->config().debug_options().xla_dump_optimized_hlo_proto_to();
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index 90d2be118d..858992a326 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -174,6 +174,29 @@ StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers));
}
+StatusOr<HloInstruction*> MakeMapHlo(
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* map_computation) {
+ CHECK(!operands.empty()) << "Map Hlo requires at least one operand.";
+ HloComputation* computation = operands.front()->parent();
+ std::vector<const Shape*> operand_shapes;
+ int64 max_operand_rank = 0;
+ for (const HloInstruction* operand : operands) {
+ CHECK_EQ(computation, operand->parent());
+ operand_shapes.push_back(&operand->shape());
+ max_operand_rank =
+ std::max(max_operand_rank, ShapeUtil::Rank(operand->shape()));
+ }
+ std::vector<int64> map_dims(max_operand_rank);
+ std::iota(map_dims.begin(), map_dims.end(), 0);
+ TF_ASSIGN_OR_RETURN(
+ Shape map_shape,
+ ShapeInference::InferMapShape(
+ operand_shapes, map_computation->ComputeProgramShape(), map_dims));
+ return computation->AddInstruction(
+ HloInstruction::CreateMap(map_shape, operands, map_computation));
+}
+
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
CHECK_GT(n, 0);
@@ -251,6 +274,38 @@ StatusOr<HloInstruction*> ElideDegenerateDims(HloInstruction* operand,
return MakeReshapeHlo(output_shape, operand);
}
+StatusOr<HloInstruction*> InsertDegenerateDims(
+ HloInstruction* operand, ArraySlice<int64> dims_to_insert) {
+ CHECK(c_is_sorted(dims_to_insert));
+
+ const Shape& operand_shape = operand->shape();
+ int64 output_shape_rank =
+ operand_shape.dimensions_size() + dims_to_insert.size();
+ for (auto dim_to_insert : dims_to_insert) {
+ CHECK_LT(dim_to_insert, output_shape_rank);
+ }
+
+ std::vector<int64> output_shape_dim_bounds;
+ output_shape_dim_bounds.reserve(output_shape_rank);
+ int64 operand_dims_idx = 0;
+ int64 dims_to_insert_idx = 0;
+ for (int64 i = 0; i < output_shape_rank; ++i) {
+ if (dims_to_insert_idx < dims_to_insert.size() &&
+ i == dims_to_insert[dims_to_insert_idx]) {
+ output_shape_dim_bounds.push_back(1);
+ ++dims_to_insert_idx;
+ } else {
+ output_shape_dim_bounds.push_back(
+ operand_shape.dimensions(operand_dims_idx));
+ ++operand_dims_idx;
+ }
+ }
+
+ Shape output_shape = ShapeUtil::MakeShape(operand_shape.element_type(),
+ output_shape_dim_bounds);
+ return MakeReshapeHlo(output_shape, operand);
+}
+
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
int64 zeros_to_prepend,
int64 zeros_to_append) {
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index 49b1402d68..5ff8946fb0 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -102,6 +102,12 @@ StatusOr<HloInstruction*> MakeConcatHlo(
StatusOr<HloInstruction*> MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dim_numbers);
+// Creates a Map HLO instruction and adds it to the computation containing the
+// operands. All operands must be in the same computation.
+StatusOr<HloInstruction*> MakeMapHlo(
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* map_computation);
+
// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns. All of
// these add all the instructions they generate into the computation containing
@@ -144,6 +150,16 @@ StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
StatusOr<HloInstruction*> ElideDegenerateDims(
HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_elide);
+// Inserts (via reshape) a set of degenerate dimensions (dimensions containing
+// exactly one element), `dims_to_insert` into `operand`. The dimensions in
+// `dims_to_insert` refer to the dimensions in the result, and hence should be
+// less than the rank of the result. Also, `dims_to_insert` must be sorted.
+//
+// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is
+// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34].
+StatusOr<HloInstruction*> InsertDegenerateDims(
+ HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dims_to_insert);
+
// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the
// front and `zeros_to_append` zeros in the back.
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
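The shape arithmetic behind InsertDegenerateDims is easy to see in isolation. The sketch below works on plain dimension vectors rather than HLO shapes and reproduces the worked example from the header comment; it is illustrative only and not part of the patch.

    #include <cstdint>
    #include <vector>

    std::vector<int64_t> InsertDegenerateDimBounds(
        const std::vector<int64_t>& dims,
        const std::vector<int64_t>& dims_to_insert) {
      // dims_to_insert must be sorted; entries are positions in the result.
      std::vector<int64_t> out;
      out.reserve(dims.size() + dims_to_insert.size());
      size_t src = 0, ins = 0;
      for (size_t i = 0; i < dims.size() + dims_to_insert.size(); ++i) {
        if (ins < dims_to_insert.size() &&
            static_cast<int64_t>(i) == dims_to_insert[ins]) {
          out.push_back(1);  // the inserted degenerate dimension
          ++ins;
        } else {
          out.push_back(dims[src++]);
        }
      }
      return out;
    }

    // InsertDegenerateDimBounds({12, 21, 8, 34}, {0, 2}) yields
    // {1, 12, 1, 21, 8, 34}, matching the f32[12,21,8,34] example above.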
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc
index 71b44507cc..8e0d38b6a6 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -143,8 +143,47 @@ TokKind HloLexer::LexToken() {
return TokKind::kLparen;
case ')':
return TokKind::kRparen;
- case '/':
- return LexComment();
+ case '/': {
+ if (PeekCurrentChar() == '*') {
+ // This is the start of a /*...*/ delimited comment. Save the current
+ // location in case the comment is unterminated so the error message
+ // will point to the beginning of the comment.
+ const char* comment_start = current_ptr_;
+ current_ptr_++;
+ // Advance until '*/' is found.
+ while (true) {
+ int current = GetNextChar();
+ if (current == '*' && PeekCurrentChar() == '/') {
+ // End of comment.
+ current_ptr_++;
+ break;
+ }
+ if (current == kEOF) {
+ // Unterminated comment.
+ current_ptr_ = comment_start;
+ return TokKind::kError;
+ }
+ }
+ // Return no token for the comment. Keep lexing.
+ continue;
+ } else if (PeekCurrentChar() == '/') {
+ // This is the start of a '//' delimited comment. Throw away
+ // everything until end of line or file. The end-of-line character(s)
+ // are left unlexed in the buffer which is harmless because these are
+ // skipped later by the lexer. This approach enables support for
+ // different end-of-line encodings.
+ while (true) {
+ int current = PeekCurrentChar();
+ if (current == kEOF || current == '\n' || current == '\r') {
+ break;
+ }
+ current_ptr_++;
+ }
+ continue;
+ }
+ // A lone '/' is an error.
+ return TokKind::kError;
+ }
case '"':
return LexString();
}
@@ -357,16 +396,6 @@ tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const {
return StringPieceFromPointers(start, end);
}
-TokKind HloLexer::LexComment() {
- auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
- static LazyRE2 comment_pattern = {R"(\/\*.*?\*\/)"};
- if (RE2::Consume(&consumable, *comment_pattern)) {
- current_ptr_ = consumable.begin();
- return TokKind::kComment;
- }
- return TokKind::kError;
-}
-
// Lexes quoted string with escaping characters. If matched, the quoted string
// will be unescaped and stored to str_val_.
TokKind HloLexer::LexString() {
@@ -412,8 +441,6 @@ string TokKindToString(TokKind kind) {
return "kRparen";
case TokKind::kArrow:
return "kArrow";
- case TokKind::kComment:
- return "kComment";
case TokKind::kw_HloModule:
return "kw_HloModule";
case TokKind::kw_ENTRY:
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h
index ceb674f25e..003ac34ace 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.h
+++ b/tensorflow/compiler/xla/service/hlo_lexer.h
@@ -105,7 +105,6 @@ class HloLexer {
TokKind LexShape();
TokKind LexConstant();
TokKind LexNumberOrPattern();
- TokKind LexComment();
TokKind LexString();
const tensorflow::StringPiece buf_;
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index b57c940238..c577b4359a 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -231,6 +231,7 @@ HLO_MATCHER(Tanh);
HLO_MATCHER(Trace);
HLO_MATCHER(Transpose);
HLO_MATCHER(Tuple);
+HLO_MATCHER(TupleSelect);
HLO_MATCHER(While);
// The special cases below let you check additional information about the
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 84f2d3f5fb..1b256cd00e 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -166,7 +166,7 @@ class HloModuleGroupMetadata {
//
// Precondition: IsCompanionWhile(instruction) is true.
const std::unordered_set<HloInstruction*>& Companions(
- HloInstruction* instruction) const {
+ const HloInstruction* instruction) const {
CHECK_EQ(companion_set_index_.count(instruction), 1);
return companion_set(companion_set_index_.at(instruction));
}
@@ -243,7 +243,7 @@ class HloModuleGroupMetadata {
companion_sets_;
// Map from each companion while instruction to the index into companion_set_.
- tensorflow::gtl::FlatMap<HloInstruction*, int64> companion_set_index_;
+ tensorflow::gtl::FlatMap<const HloInstruction*, int64> companion_set_index_;
// Map from computation to the instruction using it (a kWhile, kConditional).
tensorflow::gtl::FlatMap<const HloComputation*, TrackedInstruction>
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index 9fd0ade153..0dc5676148 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -37,24 +38,38 @@ namespace xla {
std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
HloInstruction* instruction) {
- std::vector<HloInstruction*> predecessors;
-
- // Adds to the unique predecessors list and also add companion instructions
- // if the given predecessor has those.
+ std::vector<HloInstruction*>
+ predecessors; // Use a vector to avoid non-determinism.
+ tensorflow::gtl::FlatSet<HloInstruction*> unique;
+
+  // Adds to the unique predecessors list; if the predecessor is a companion
+  // instruction, also add companion instructions; if the predecessor is a
+ // cross-module all-reduce, also add the all-reduce instructions in the same
+ // group.
auto add_unique_predecessor = [&](HloInstruction* predecessor) {
- if (std::find(predecessors.begin(), predecessors.end(), predecessor) !=
- predecessors.end()) {
+ if (unique.find(predecessor) != unique.end()) {
return;
}
- if (!metadata_.IsCompanionInstruction(predecessor)) {
- predecessors.push_back(predecessor);
+ if (metadata_.IsCompanionInstruction(predecessor)) {
+ for (HloInstruction* instr : metadata_.Companions(predecessor)) {
+ if (unique.insert(instr).second) {
+ predecessors.push_back(instr);
+ }
+ }
return;
}
- for (HloInstruction* companion : metadata_.Companions(predecessor)) {
- predecessors.push_back(companion);
+ if (predecessor->IsCrossModuleAllReduce()) {
+ for (HloInstruction* instr :
+ metadata_.GetAllReduceGroup(*predecessor->all_reduce_id())) {
+ if (unique.insert(instr).second) {
+ predecessors.push_back(instr);
+ }
+ }
+ return;
}
+ unique.insert(predecessor);
+ predecessors.push_back(predecessor);
};
-
// If the given instruction is a companion instruction, we need to find the
// predecessors of all of its companion instructions. If the instruction is an
// all-reduce, we need to find the predecessors of all the peer all-reduce
@@ -98,22 +113,37 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
HloInstruction* instruction) {
- std::vector<HloInstruction*> successors;
-
- // Adds to the unique successors list and also add companion instructions
- // if the given successor has those.
+ std::vector<HloInstruction*>
+ successors; // Use a vector to avoid non-determinism.
+ tensorflow::gtl::FlatSet<HloInstruction*> unique;
+
+ // Adds to the unique successors list; if the successor is a companion
+ // instruction, also add companion instructions; if the successor is a
+ // cross-module all-reduce, also add the all-reduce instructions in the same
+ // group.
auto add_unique_successor = [&](HloInstruction* successor) {
- if (std::find(successors.begin(), successors.end(), successor) !=
- successors.end()) {
+ if (unique.find(successor) != unique.end()) {
return;
}
- if (!metadata_.IsCompanionInstruction(successor)) {
- successors.push_back(successor);
+ if (metadata_.IsCompanionInstruction(successor)) {
+ for (HloInstruction* instr : metadata_.Companions(successor)) {
+ if (unique.insert(instr).second) {
+ successors.push_back(instr);
+ }
+ }
return;
}
- for (HloInstruction* companion : metadata_.Companions(successor)) {
- successors.push_back(companion);
+ if (successor->IsCrossModuleAllReduce()) {
+ for (HloInstruction* instr :
+ metadata_.GetAllReduceGroup(*successor->all_reduce_id())) {
+ if (unique.insert(instr).second) {
+ successors.push_back(instr);
+ }
+ }
+ return;
}
+ unique.insert(successor);
+ successors.push_back(successor);
};
// If the given instruction is a companion instruction, we need to find the
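Both GlobalPredecessors and GlobalSuccessors now use the same ordered-dedup idiom: a hash set answers "have we seen this node?" in O(1), while a parallel vector keeps a deterministic insertion order, replacing the earlier linear std::find scan. A standalone sketch of the idiom (std::unordered_set stands in for tensorflow::gtl::FlatSet):

    #include <unordered_set>
    #include <vector>

    void AddUnique(int item, std::vector<int>* ordered,
                   std::unordered_set<int>* seen) {
      // insert().second is true only the first time `item` is seen.
      if (seen->insert(item).second) {
        ordered->push_back(item);
      }
    }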
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 2a8c6ecd92..4b3cd99dc0 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1824,7 +1824,6 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
break;
}
case TokKind::kComma:
- case TokKind::kComment:
// Skip.
lexer_.Lex();
break;
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 4cd21841f4..5990a3d478 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1560,6 +1560,81 @@ ENTRY consts {
"last");
}
+TEST_F(HloParserTest, Comments) {
+ const string original = R"(/* module description. */
+HloModule comments:
+
+ENTRY /*comment*/ c1 {
+ /* blah */
+ ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/})
+ /* comment */
+}
+
+/* something else */
+
+)";
+ auto module = ParseHloString(original);
+ TF_ASSERT_OK(module.status());
+}
+
+TEST_F(HloParserTest, MultilineComments) {
+ const string original = R"(HloModule multiline_comment:
+ENTRY c1 {
+ /*
+ ROOT foo = f32[1]{0} constant({12345})
+ */
+ ROOT const1 = f32[1]{0} constant({12345})
+/*
+a
+b
+c
+d
+
+*/
+})";
+ auto module = ParseHloString(original);
+ TF_ASSERT_OK(module.status());
+}
+
+TEST_F(HloParserTest, UnterminatedComment) {
+ const string original = R"(HloModule unterminated_comment:
+ENTRY c1 {
+/* unterminated
+ ROOT const1 = f32[1]{0} constant({12345})
+})";
+ // Verify that the error message points to the beginning of the unterminated
+ // comment.
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
+ "/* unterminated\n^");
+}
+
+TEST_F(HloParserTest, SlashSlashComments) {
+ const string original = R"(HloModule slash_slash_comment:
+// Garbage
+ENTRY c1 {
+ // Foo bar
+ ROOT const1 = f32[1]{0} constant({12345}) // Something else
+})";
+ auto module = ParseHloString(original);
+ TF_ASSERT_OK(module.status());
+}
+
+TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) {
+ const string original =
+ "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo "
+ "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}";
+ auto module = ParseHloString(original);
+ TF_ASSERT_OK(module.status());
+}
+
+TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) {
+ const string original =
+ "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo "
+ "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}";
+ auto module = ParseHloString(original);
+ TF_ASSERT_OK(module.status());
+}
+
TEST_F(HloParserTest, MultipleEntries) {
const string original = R"(HloModule multiple_entries:
ENTRY c1 {
diff --git a/tensorflow/compiler/xla/service/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h
index 533429608b..4458c251de 100644
--- a/tensorflow/compiler/xla/service/hlo_token.h
+++ b/tensorflow/compiler/xla/service/hlo_token.h
@@ -44,7 +44,6 @@ enum class TokKind {
kRparen, // ( )
kArrow, // ->
- kComment, // /*xxx*/
// Keywords
kw_HloModule,
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h
index 9b109022fb..db6b910b32 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.h
+++ b/tensorflow/compiler/xla/service/interpreter/executor.h
@@ -104,7 +104,7 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
}
// No "synchronize all activity" implemented for this platform at the moment.
- bool SynchronizeAllActivity() override { return false; }
+ bool SynchronizeAllActivity() override { return true; }
bool SynchronousMemZero(DeviceMemoryBase *location, uint64 size) override {
return false;
}
diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc
new file mode 100644
index 0000000000..45ca731153
--- /dev/null
+++ b/tensorflow/compiler/xla/service/scatter_expander.cc
@@ -0,0 +1,350 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/scatter_expander.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/while_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+using tensorflow::gtl::ArraySlice;
+
+// Transposes the given scatter_indices such that the index_vector_dim becomes
+// the most-minor dimension.
+static StatusOr<HloInstruction*> TransposeIndexVectorDimToLast(
+ HloInstruction* scatter_indices, int64 index_vector_dim) {
+ const Shape& scatter_indices_shape = scatter_indices->shape();
+
+ if (scatter_indices_shape.dimensions_size() == index_vector_dim) {
+ return scatter_indices;
+ }
+
+ if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) {
+ return scatter_indices;
+ }
+
+ std::vector<int64> permutation;
+ permutation.reserve(scatter_indices_shape.dimensions_size());
+ for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
+ if (i != index_vector_dim) {
+ permutation.push_back(i);
+ }
+ }
+ permutation.push_back(index_vector_dim);
+ return MakeTransposeHlo(scatter_indices, permutation);
+}
+
+// Canonicalizes the scatter_indices tensor so that the indices can be handled
+// uniformly while performing the scatter operation.
+static StatusOr<HloInstruction*> CanonicalizeScatterIndices(
+ HloInstruction* scatter_indices, int64 index_vector_dim) {
+ // Transpose the non-index-vector dimensions to the front.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * transposed_scatter_indices,
+ TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim));
+ bool indices_are_scalar =
+ index_vector_dim == scatter_indices->shape().dimensions_size();
+
+ // The number of dimensions in scatter_indices that are index dimensions.
+ const int64 index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1;
+
+ // If there is only one index (i.e. scatter_indices has rank 1 and this
+ // scatter is really just a dynamic update slice) add a leading degenerate
+ // dimension for uniformity. Otherwise create a "collapsed" leading dimension
+ // that subsumes all of the non-index-vector dimensions.
+ const Shape& shape = transposed_scatter_indices->shape();
+ if (shape.dimensions_size() == index_dims_in_scatter_indices) {
+ return PrependDegenerateDims(transposed_scatter_indices, 1);
+ } else {
+ // Collapse all but the dimensions (0 or 1) in scatter_indices containing
+ // the index vectors.
+ return CollapseFirstNDims(
+ transposed_scatter_indices,
+ shape.dimensions_size() - index_dims_in_scatter_indices);
+ }
+}
+
+// Permutes the `updates` tensor such that all the scatter dims appear in the
+// major dimensions and all the window dimensions appear in the minor
+// dimensions.
+static StatusOr<HloInstruction*> PermuteScatterAndWindowDims(
+ HloInstruction* updates, ArraySlice<int64> update_window_dims) {
+ std::vector<int64> permutation;
+ const int64 updates_rank = ShapeUtil::Rank(updates->shape());
+ permutation.reserve(updates_rank);
+
+ for (int64 i = 0; i < updates_rank; ++i) {
+ bool is_scatter_dim = !c_binary_search(update_window_dims, i);
+ if (is_scatter_dim) {
+ permutation.push_back(i);
+ }
+ }
+ for (auto window_dim : update_window_dims) {
+ permutation.push_back(window_dim);
+ }
+
+ return MakeTransposeHlo(updates, permutation);
+}
+
+// Expands or contracts the scatter indices in the updates tensor.
+static StatusOr<HloInstruction*> AdjustScatterDims(
+ const Shape& scatter_indices_shape, HloInstruction* updates,
+ int64 index_vector_dim) {
+ int64 num_scatter_dims = scatter_indices_shape.dimensions_size();
+ if (index_vector_dim < scatter_indices_shape.dimensions_size()) {
+ --num_scatter_dims;
+ }
+ if (num_scatter_dims == 0) {
+ // If there are no scatter dims, this must be a dynamic-update-slice kind of
+ // scatter. In this case, we prepend a degenerate dimension to work
+ // uniformly in the while loop.
+ return PrependDegenerateDims(updates, 1);
+ }
+ return CollapseFirstNDims(updates, num_scatter_dims);
+}
+
+// Expands an index vector from the scatter_indices tensor into a vector that
+// can be used to dynamic-update-slice to perform the scatter update.
+static StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
+ HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers,
+ int64 operand_rank) {
+ HloComputation* computation = index_vector->parent();
+ const Shape& index_shape = index_vector->shape();
+ HloInstruction* zero =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1})));
+
+ // We extract out individual components from the smaller index and concatenate
+ // them (interspersing zeros as needed) into the larger index.
+ std::vector<HloInstruction*> expanded_index_components;
+
+ for (int i = 0; i < operand_rank; i++) {
+ int64 index_vector_dim_index =
+ FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i);
+ if (index_vector_dim_index !=
+ dim_numbers.scatter_dims_to_operand_dims_size()) {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * component_to_concat,
+ MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index},
+ /*limit_indices=*/{index_vector_dim_index + 1},
+ /*strides=*/{1}));
+ expanded_index_components.push_back(component_to_concat);
+ } else {
+ expanded_index_components.push_back(zero);
+ }
+ }
+
+ return MakeConcatHlo(expanded_index_components, /*dimension=*/0);
+}
+
+// Body of the while loop that performs the scatter operation using other HLOs.
+static StatusOr<std::vector<HloInstruction*>> ScatterLoopBody(
+ HloInstruction* scatter, HloInstruction* induction_var,
+ const std::vector<HloInstruction*>& loop_state) {
+ const ScatterDimensionNumbers& dim_numbers =
+ scatter->scatter_dimension_numbers();
+ CHECK_EQ(loop_state.size(), 3);
+ HloInstruction* operand = loop_state[0];
+ HloInstruction* scatter_indices = loop_state[1];
+ HloInstruction* updates = loop_state[2];
+
+ bool has_scalar_indices = scatter_indices->shape().dimensions_size() == 1;
+ CHECK_EQ(has_scalar_indices,
+ dim_numbers.index_vector_dim() ==
+ scatter->operand(1)->shape().dimensions_size());
+
+ // Build a vector form of the induction variable of the while loop.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * induction_var_as_vector,
+ MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{},
+ /*result_shape_bounds=*/{1}));
+
+ // Pick the index to scatter from scatter_indices based on the induction_var
+ // and transform that to an index into the `operand` space.
+ HloInstruction* index_vector;
+ if (has_scalar_indices) {
+ TF_ASSIGN_OR_RETURN(
+ index_vector,
+ MakeDynamicSliceHlo(scatter_indices, induction_var_as_vector, {1}));
+ } else {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * index_into_scatter_indices,
+ PadVectorWithZeros(induction_var_as_vector,
+ /*zeros_to_prepend=*/0, /*zeros_to_append=*/1));
+ int index_vector_size = scatter_indices->shape().dimensions(1);
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * index_vector_2d,
+ MakeDynamicSliceHlo(scatter_indices, index_into_scatter_indices,
+ {1, index_vector_size}));
+ TF_ASSIGN_OR_RETURN(index_vector,
+ ElideDegenerateDims(index_vector_2d, {0}));
+ }
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * scatter_slice_start,
+ ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers,
+ operand->shape().dimensions_size()));
+
+ // Extract the slice to be used to update from `updates` tensor for the
+ // induction_var corresponding to this iteration of the while loop.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * index_into_updates,
+ PadVectorWithZeros(
+ induction_var_as_vector, /*zeros_to_prepend=*/0,
+ /*zeros_to_append=*/updates->shape().dimensions_size() - 1));
+ std::vector<int64> update_slice_bounds(updates->shape().dimensions().begin(),
+ updates->shape().dimensions().end());
+ update_slice_bounds[0] = 1;
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * update_slice,
+ MakeDynamicSliceHlo(updates, index_into_updates, update_slice_bounds));
+ TF_ASSIGN_OR_RETURN(HloInstruction * update_slice_for_scatter,
+ ElideDegenerateDims(update_slice, {0}));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * update_slice_with_dims_inserted,
+ InsertDegenerateDims(update_slice_for_scatter,
+ AsInt64Slice(dim_numbers.inserted_window_dims())));
+
+  // Extract the slice to update from `operand` tensor.
+ const Shape& update_slice_shape = update_slice_with_dims_inserted->shape();
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * operand_slice_to_update,
+ MakeDynamicSliceHlo(operand, scatter_slice_start,
+ AsInt64Slice(update_slice_shape.dimensions())));
+
+ // Compute the new value for the slice to be updated in `operand` tensor by
+ // combining the existing value and the update value using the update
+ // computation.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * updated_operand_slice,
+ MakeMapHlo({operand_slice_to_update, update_slice_with_dims_inserted},
+ scatter->to_apply()));
+
+ // Write the updated value of the slice into `operand` tensor.
+ TF_ASSIGN_OR_RETURN(HloInstruction * updated_operand,
+ MakeDynamicUpdateSliceHlo(operand, updated_operand_slice,
+ scatter_slice_start));
+
+ return StatusOr<std::vector<HloInstruction*>>{
+ {updated_operand, scatter_indices, updates}};
+}
+
+// High Level Algorithm.
+//
+// 1. Canonicalize the scatter_indices tensor such that it has rank 2, where
+// each row is an index into the operand.
+// 2. Canonicalize the updates tensor such that it has rank `num_window_dims+1`
+// and the scatter dim is the most-major dimension.
+// 3. Iterate over the set of indices in the canonicalized scatter_indices
+// tensor using a while loop, updating the operand for each such index. Each
+// iteration of this while loop performs the following:
+// a. Pick the index from scatter_indices for this iteration.
+//      b. Transform this index into an index into the operand space.
+// c. Extract the slice to be used to update from the updates tensor.
+// d. Extract the slice to update from the operand tensor.
+// e. Compute the new value for the slice to update by combining the slices
+// from c. and d. using the update_computation of scatter.
+// f. Write the updated value of the slice into the operand tensor.
+
+StatusOr<HloInstruction*> ScatterExpander::ExpandScatter(
+ HloInstruction* scatter) {
+ HloInstruction* operand = scatter->mutable_operand(0);
+ HloInstruction* scatter_indices = scatter->mutable_operand(1);
+ HloInstruction* updates = scatter->mutable_operand(2);
+ const ScatterDimensionNumbers& dim_numbers =
+ scatter->scatter_dimension_numbers();
+
+ // If the updates tensor is empty, there is no need to update the operand. We
+ // can return the operand as is.
+ if (ShapeUtil::IsZeroElementArray(updates->shape())) {
+ return operand;
+ }
+
+  // Compute the trip count for the while loop to be used for scatter. This
+  // is the number of indices we scatter into the operand.
+ const Shape& scatter_indices_shape = scatter_indices->shape();
+ int64 scatter_loop_trip_count = 1;
+ for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) {
+ if (i != dim_numbers.index_vector_dim()) {
+ scatter_loop_trip_count *= scatter_indices_shape.dimensions(i);
+ }
+ }
+ if (!IsInt32(scatter_loop_trip_count)) {
+ return Unimplemented(
+ "Scatter operations with more than 2147483647 scatter indices are not "
+ "supported. This error occurred for %s.",
+ scatter->ToString().c_str());
+ }
+
+  // Canonicalize the scatter_indices, after which the size of its most-major
+  // dimension must be the same as the while loop trip count.
+ TF_ASSIGN_OR_RETURN(HloInstruction * canonical_scatter_indices,
+ CanonicalizeScatterIndices(
+ scatter_indices, dim_numbers.index_vector_dim()));
+ CHECK_EQ(scatter_loop_trip_count,
+ canonical_scatter_indices->shape().dimensions(0));
+
+  // Canonicalize the updates, after which the size of its most-major dimension
+  // must be the same as the while loop trip count.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * canonical_updates,
+ PermuteScatterAndWindowDims(
+ updates, AsInt64Slice(dim_numbers.update_window_dims())));
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * adjusted_canonical_updates,
+ AdjustScatterDims(scatter_indices->shape(), canonical_updates,
+ dim_numbers.index_vector_dim()));
+ CHECK_EQ(scatter_loop_trip_count,
+ adjusted_canonical_updates->shape().dimensions(0));
+
+ // The while loop that implements the scatter operation.
+ StatusOr<std::vector<HloInstruction*>> scatter_loop_result_status =
+ WhileUtil::MakeCountedLoop(
+ scatter->parent(), scatter_loop_trip_count,
+ {operand, canonical_scatter_indices, adjusted_canonical_updates},
+ [&](HloInstruction* induction_var,
+ const std::vector<HloInstruction*>& loop_state) {
+ return ScatterLoopBody(scatter, induction_var, loop_state);
+ });
+ TF_ASSIGN_OR_RETURN(std::vector<HloInstruction*> scatter_loop_result,
+ scatter_loop_result_status);
+ return scatter_loop_result.front();
+}
+
+StatusOr<bool> ScatterExpander::Run(HloModule* module) {
+ std::vector<HloInstruction*> scatter_instrs;
+ for (HloComputation* computation : module->MakeNonfusionComputations()) {
+ for (HloInstruction* instr : computation->instructions()) {
+ if (instr->opcode() == HloOpcode::kScatter) {
+ scatter_instrs.push_back(instr);
+ }
+ }
+ }
+
+ for (auto instr : scatter_instrs) {
+ TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr));
+ TF_RETURN_IF_ERROR(
+ instr->parent()->ReplaceInstruction(instr, expanded_root));
+ }
+
+ return !scatter_instrs.empty();
+}
+
+} // namespace xla
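For intuition about what the expanded while loop computes, here is a host-side sketch of the scatter semantics for the simplest case: a 1-D operand, scalar indices, and addition as the update computation. The real expander builds this loop out of dynamic-slice, map, and dynamic-update-slice HLOs and handles arbitrary ranks and window dimensions; the sketch is illustrative only.

    #include <cstdint>
    #include <vector>

    std::vector<float> ScatterAdd1D(std::vector<float> operand,
                                    const std::vector<int64_t>& scatter_indices,
                                    const std::vector<float>& updates) {
      for (size_t i = 0; i < scatter_indices.size(); ++i) {
        int64_t idx = scatter_indices[i];
        // Out-of-bounds indices are simply skipped in this sketch.
        if (idx >= 0 && idx < static_cast<int64_t>(operand.size())) {
          // The update computation (scatter->to_apply()), here plain addition.
          operand[idx] += updates[i];
        }
      }
      return operand;
    }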
diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h
new file mode 100644
index 0000000000..8f735e877d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/scatter_expander.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+class ScatterExpander : public HloPassInterface {
+ public:
+ tensorflow::StringPiece name() const override { return "scatter_expander"; }
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ StatusOr<HloInstruction*> ExpandScatter(HloInstruction* scatter);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 34869cc507..b69c346f1e 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1014,12 +1014,13 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
}
/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) {
+ if (!IsTuple(shape)) {
+ return 1;
+ }
int64 count = 0;
- ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) {
- if (IsLeafIndex(shape, index)) {
- ++count;
- }
- });
+ for (const Shape& subshape : shape.tuple_shapes()) {
+ count += GetLeafCount(subshape);
+ }
return count;
}
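
The rewrite replaces the ForEachSubshape/IsLeafIndex walk with direct recursion over tuple elements: a non-tuple shape counts as one leaf, and a tuple's leaf count is the sum over its elements. A small illustrative check (a hypothetical test snippet, not part of this change):

    // A tuple of (f32[2], (f32[], f32[3])) has three array leaves; a bare
    // array shape counts as a single leaf.
    Shape leaf = ShapeUtil::MakeShape(F32, {2});
    Shape inner = ShapeUtil::MakeTupleShape(
        {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {3})});
    Shape nested = ShapeUtil::MakeTupleShape({leaf, inner});
    EXPECT_EQ(3, ShapeUtil::GetLeafCount(nested));
    EXPECT_EQ(1, ShapeUtil::GetLeafCount(leaf));
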
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 42d52aee78..0f8cffd466 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -709,6 +709,21 @@ xla_test(
],
)
+xla_test(
+ name = "scatter_test",
+ srcs = ["scatter_test.cc"],
+ deps = [
+ ":client_library_test_base",
+ ":hlo_test_base",
+ "//tensorflow/compiler/xla:execution_options_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/service:hlo_parser",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ ],
+)
+
# Repeat dot_operation_runtime_test with single-threaded eigen.
xla_test(
name = "dot_operation_single_threaded_runtime_test",
@@ -2061,6 +2076,8 @@ tf_cc_test(
xla_test(
name = "test_utils_test",
srcs = ["test_utils_test.cc"],
+ # There is nothing backend-specific in this test, so just pick an arbitrary backend.
+ backends = ["cpu"],
deps = [
":local_client_test_base",
":test_utils",
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 4a6e8a3124..b04a3b105c 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -74,8 +74,9 @@ class ClientLibraryTestBase : public ::testing::Test {
string TestName() const;
void SetFastMathDisabled(bool disabled) {
- execution_options_.mutable_debug_options()->set_xla_enable_fast_math(
- !disabled);
+ auto* opts = execution_options_.mutable_debug_options();
+ opts->set_xla_cpu_enable_fast_math(!disabled);
+ opts->set_xla_gpu_enable_fast_math(!disabled);
}
void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 73edad89dc..92c93f08b2 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -1464,5 +1464,24 @@ ENTRY %reduce-window (parameter.0: s32[81,8], parameter.1: s32[]) -> s32[82,8] {
EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt));
}
+XLA_TEST_F(HloTestBase, ReduceWindowF16) {
+ const string hlo_string = R"(
+HloModule reduce-window
+
+%identity.pad_to_reduce_window (param0: f16[], param1: f16[]) -> f16[] {
+ %param0 = f16[] parameter(0)
+ ROOT %param1 = f16[] parameter(1)
+}
+
+ENTRY %reduce-window (parameter.0: f16[81,8], parameter.1: f16[]) -> f16[82,8] {
+ %parameter.0 = f16[81,8]{1,0} parameter(0)
+ %parameter.1 = f16[] parameter(1)
+ ROOT %reduce-window = f16[82,8]{1,0} reduce-window(f16[81,8]{1,0} %parameter.0, f16[] %parameter.1), window={size=1x1 pad=0_1x0_0}, to_apply=%identity.pad_to_reduce_window
+}
+
+)";
+ EXPECT_TRUE(RunAndCompare(hlo_string, tensorflow::gtl::nullopt));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc
new file mode 100644
index 0000000000..922d70b752
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/scatter_test.cc
@@ -0,0 +1,615 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+
+namespace xla {
+namespace {
+
+using tensorflow::gtl::nullopt;
+
+class ScatterTest : public HloTestBase {
+ protected:
+ void RunTest(const string& hlo_text, Literal* operand,
+ Literal* scatter_indices, Literal* updates) {
+ RunTest(hlo_text, {operand, scatter_indices, updates});
+ }
+
+ void RunTest(const string& hlo_text,
+ tensorflow::gtl::ArraySlice<Literal*> args) {
+ HloModuleConfig config;
+ config.set_debug_options(GetDebugOptionsForTest());
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text, config));
+ EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt));
+ }
+};
+
+XLA_TEST_F(ScatterTest, TensorFlowScatterV1_Update) {
+ const string hlo_text = R"(
+HloModule TensorFlowScatterV1
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterV2
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[3,2] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={0},
+ inserted_window_dims={1},
+ scatter_dims_to_operand_dims={1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) {
+ const string hlo_text = R"(
+HloModule TensorFlowScatter_Add
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
+ const string hlo_text = R"(
+HloModule TensorFlowScatter_Mul
+
+mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=mul_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) {
+ const string hlo_text = R"(
+HloModule TensorFlowScatter_F32
+
+add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(f32[] lhs, f32[] rhs)
+}
+
+ENTRY main {
+ operand = f32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = f32[2,3] parameter(2)
+ ROOT scatter = f32[3,3] scatter(operand, indices, updates),
+ to_apply=add_f32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+ {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({2, 1});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatter
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterMultipleBatchDims
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2,2] parameter(1)
+ updates = s32[2,3,2] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={1},
+ inserted_window_dims={1},
+ scatter_dims_to_operand_dims={1},
+ index_vector_dim=2
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatterNd) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterNd
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[2,2] parameter(1)
+ updates = s32[2,2] parameter(2)
+ ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0,1},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterNdNonDefaultIndexVectorDim
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[2,2] parameter(1)
+ updates = s32[2,2] parameter(2)
+ ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0,1},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, DynamicUpdateSlice) {
+ const char* hlo_text = R"(
+HloModule DynamicUpdateSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[1,1] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={0,1},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) {
+ const char* hlo_text = R"(
+HloModule BatchDynamicUpdateSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2,2] parameter(1)
+ updates = s32[2,1,1] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, ZeroDimBounds) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatter_ZeroDimBounds
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,0] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,0] parameter(2)
+ ROOT scatter = s32[3,0] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, NoUpdateWindowDims) {
+ const string hlo_text = R"(
+HloModule Scatter_NoUpdateWindowDims
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3] parameter(0)
+ indices = s32[2,2,1] parameter(1)
+ updates = s32[2,2] parameter(2)
+ ROOT scatter = s32[3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=2
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, OutOfBoundsIndex) {
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = s32[6,2]{1,0} parameter(1)
+ updates = s32[6,1,1]{2,1,0} parameter(2)
+ ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) {
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = u32[6,2]{1,0} parameter(1)
+ updates = s32[6,1,1]{2,1,0} parameter(2)
+ ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<uint32>(
+ {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, NegativeIndex) {
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = s32[6,2]{1,0} parameter(1)
+ updates = s32[6,1,1]{2,1,0} parameter(2)
+ ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, OneScalarIndex) {
+ const char* hlo_text = R"(
+HloModule OneScalarIndex
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[2,3,2]{2,1,0} parameter(0)
+ index = s32[] parameter(1)
+ updates = s32[1,3,2]{2,1,0} parameter(2)
+ ROOT scatter = s32[2,3,2]{2,1,0} scatter(operand, index, updates),
+ to_apply=update_s32,
+ update_window_dims={0,1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
+ {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR3<int32>({{{10, 20}, {30, 40}, {50, 60}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, ScalarUpdate) {
+ const char* hlo_text = R"(
+HloModule ScalarUpdate
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[4]{0} parameter(0)
+ index = s32[] parameter(1)
+ updates = s32[] parameter(2)
+ ROOT scatter = s32[4]{0} scatter(operand, index, updates),
+ to_apply=update_s32,
+ update_window_dims={},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR0<int32>(25);
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, EmptyIndices) {
+ const string hlo_text = R"(
+HloModule EmptyIndices
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3] parameter(0)
+ indices = s32[0] parameter(1)
+ updates = s32[0] parameter(2)
+ ROOT scatter = s32[3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR1<int32>({});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+} // namespace
+} // namespace xla
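
As a worked example of what these tests exercise: in TensorFlowScatterV1_Update, indices {0, 2} select rows 0 and 2 of the 3x3 operand, and the update computation returns its second argument, so scatter overwrites those rows with the update rows:

    // Expected result, derived from the TensorFlowScatterV1_Update inputs above:
    //   {{10, 20, 30},
    //    { 4,  5,  6},
    //    {70, 80, 90}}

In TensorFlowScatter_RepeatedIndices both index entries point at row 1 and the add_s32 computation accumulates, so row 1 becomes {4 + 10 + 70, 5 + 20 + 80, 6 + 30 + 90} = {84, 105, 126} while the other rows are left untouched.
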
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 2647937013..faeec657b6 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -208,16 +208,12 @@ bool NeedsInitValue(const HloUse& use) {
// Generate random values that are constrained to the input_shape minus the
// output_shape so as not to produce wrapping slices, for instance.
-std::unique_ptr<Literal> MakeRandomNonwrappingSliceIndex(
- const Shape& input_shape, const Shape& slice_shape,
- std::minstd_rand0* engine) {
- const int64 rank = ShapeUtil::Rank(input_shape);
- std::vector<int32> start_indices(rank);
+std::unique_ptr<Literal> MakeRandomIndex(
+ tensorflow::gtl::ArraySlice<int64> index_space, std::minstd_rand0* engine) {
+ std::vector<int32> start_indices(index_space.size());
if (engine != nullptr) {
- for (int i = 0; i < rank; ++i) {
- const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) -
- ShapeUtil::GetDimension(slice_shape, i);
- std::uniform_int_distribution<int32> generator(0, upper_bound);
+ for (int i = 0; i < index_space.size(); ++i) {
+ std::uniform_int_distribution<int32> generator(0, index_space[i]);
start_indices[i] = generator(*engine);
}
}
@@ -267,37 +263,42 @@ std::vector<HloInstruction*> FindConstrainedUses(
StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
const tensorflow::gtl::ArraySlice<HloInstruction*> constrained_uses,
const HloInstruction& param, std::minstd_rand0* engine) {
- HloInstruction* needs_index = nullptr;
- HloInstruction* needs_constant = nullptr;
+ std::vector<int64> index_space;
+ bool needs_constant = false;
ConstantType constant_type = ConstantType::kUnknown;
for (HloInstruction* use : constrained_uses) {
switch (use->opcode()) {
case HloOpcode::kDynamicSlice:
- case HloOpcode::kDynamicUpdateSlice:
- if (needs_index != nullptr) {
- auto needs_index_shape = needs_index->shape();
- auto use_shape = use->shape();
- if (needs_index->opcode() == HloOpcode::kDynamicSlice) {
- needs_index_shape = needs_index->operand(0)->shape();
- }
- if (use->opcode() == HloOpcode::kDynamicSlice) {
- use_shape = use->operand(0)->shape();
+ case HloOpcode::kDynamicUpdateSlice: {
+ const Shape& indexed_shape = use->operand(0)->shape();
+ const Shape& slice_shape = use->opcode() == HloOpcode::kDynamicSlice
+ ? use->shape()
+ : use->operand(1)->shape();
+ const int64 rank = ShapeUtil::Rank(indexed_shape);
+ if (!index_space.empty()) {
+ TF_RET_CHECK(rank == index_space.size());
+ for (int64 i = 0; i < rank; ++i) {
+ index_space[i] = std::min(
+ index_space[i], ShapeUtil::GetDimension(indexed_shape, i) -
+ ShapeUtil::GetDimension(slice_shape, i));
}
- if (!ShapeUtil::Equal(needs_index_shape, use_shape)) {
- return Unimplemented(
- "Conflicting operand generation slice index constraints\n");
+ } else {
+ index_space.resize(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ index_space[i] = ShapeUtil::GetDimension(indexed_shape, i) -
+ ShapeUtil::GetDimension(slice_shape, i);
}
}
- needs_index = use;
break;
+ }
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
- needs_constant = use;
+ needs_constant = true;
constant_type = GetInitValue(*use->to_apply());
break;
case HloOpcode::kSelectAndScatter:
- needs_constant = use;
+ needs_constant = true;
constant_type = GetInitValue(*use->scatter());
break;
@@ -307,16 +308,14 @@ StatusOr<std::unique_ptr<Literal>> CreateLiteralForConstrainedUses(
use->ToString().c_str());
}
}
- if (needs_index != nullptr && needs_constant != nullptr) {
+ if (!index_space.empty() && needs_constant) {
return Unimplemented(
- "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds "
- "constant: %s\n",
- needs_index->ToString().c_str(), needs_constant->ToString().c_str());
+ "Conflicting operand generation constraints. Dynamically indexes a "
+ "shape and is the init value of a reduction.");
}
- if (needs_index != nullptr) {
- return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(),
- needs_index->shape(), engine);
- } else if (needs_constant != nullptr) {
+ if (!index_space.empty()) {
+ return MakeRandomIndex(index_space, engine);
+ } else if (needs_constant) {
switch (constant_type) {
case ConstantType::kZero:
return LiteralUtil::Zero(param.shape().element_type()).CloneToUnique();
@@ -356,8 +355,8 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
auto engine = pseudo_random ? MakeUnique<std::minstd_rand0>() : nullptr;
std::vector<std::unique_ptr<Literal>> arguments(params.size());
for (int i = 0; i < params.size(); ++i) {
- TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument(
- *dataflow, *params[i], engine.get()));
+ arguments[i] = MakeConstrainedArgument(*dataflow, *params[i], engine.get())
+ .ValueOrDie();
}
return std::move(arguments);
}
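
The net effect of the CreateLiteralForConstrainedUses change is that the valid start-index range is intersected across all dynamic-slice and dynamic-update-slice uses of the same index parameter, taking a per-dimension minimum of (indexed dimension - slice dimension). For the MultipleIndexSpacesForDynamicSlices test added below, that intersection works out as:

    // array_param.1 = f32[123,4,789], slice sizes {1,2,3} -> bounds {122, 2, 786}
    // array_param.2 = f32[3,3000,5],  slice sizes {3,2,2} -> bounds {0, 2998, 3}
    // per-dimension minimum                               -> bounds {0, 2, 3}
    //
    // So the generated index must be exactly 0 in dimension 0, within [0, 2]
    // in dimension 1, and within [0, 3] in dimension 2, which is precisely
    // what the test's EXPECT_EQ/EXPECT_GE/EXPECT_LE assertions check.
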
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index a2f0338e25..64d9e2031e 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -72,5 +72,60 @@ XLA_TEST_F(TestUtilsTest, Token) {
TF_ASSERT_OK(MakeFakeArguments(module.get()).status());
}
+XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicSlices) {
+ auto module = ParseHloString(
+ R"(HloModule index_space_module
+
+ ENTRY IndexSpace {
+ index_param = s32[3]{0} parameter(0)
+ array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
+ array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
+ dynamic-slice.1 = f32[1,2,3] dynamic-slice(array_param.1, index_param), dynamic_slice_sizes={1,2,3}
+ ROOT dynamic-slice.2 = f32[3,2,2] dynamic-slice(array_param.2, index_param), dynamic_slice_sizes={3,2,2}
+ })")
+ .ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ MakeFakeArguments(module.get()));
+ ASSERT_EQ(args.size(), 3);
+ const Literal& index_arg = *args[0];
+
+ EXPECT_EQ(index_arg.Get<int32>({0}), 0);
+
+ EXPECT_GE(index_arg.Get<int32>({1}), 0);
+ EXPECT_LE(index_arg.Get<int32>({1}), 2);
+
+ EXPECT_GE(index_arg.Get<int32>({2}), 0);
+ EXPECT_LE(index_arg.Get<int32>({2}), 3);
+}
+
+XLA_TEST_F(TestUtilsTest, MultipleIndexSpacesForDynamicUpdateSlices) {
+ auto module = ParseHloString(
+ R"(HloModule index_space_module
+
+ ENTRY IndexSpace {
+ index_param = s32[3]{0} parameter(0)
+ array_param.1 = f32[123,4,789]{0,1,2} parameter(1)
+ array_param.2 = f32[3,3000,5]{0,1,2} parameter(2)
+ update_param.1 = f32[1,2,3]{0,1,2} parameter(3)
+ update_param.2 = f32[3,2,2]{0,1,2} parameter(4)
+
+ dynamic-update-slice.1 = f32[123,4,789] dynamic-update-slice(array_param.1, update_param.1, index_param)
+ ROOT dynamic-update-slice.2 = f32[3,3000,5] dynamic-update-slice(array_param.2, update_param.2, index_param)
+ })")
+ .ValueOrDie();
+ TF_ASSERT_OK_AND_ASSIGN(std::vector<std::unique_ptr<Literal>> args,
+ MakeFakeArguments(module.get()));
+ ASSERT_EQ(args.size(), 5);
+ const Literal& index_arg = *args[0];
+
+ EXPECT_EQ(index_arg.Get<int32>({0}), 0);
+
+ EXPECT_GE(index_arg.Get<int32>({1}), 0);
+ EXPECT_LE(index_arg.Get<int32>({1}), 2);
+
+ EXPECT_GE(index_arg.Get<int32>({2}), 0);
+ EXPECT_LE(index_arg.Get<int32>({2}), 3);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
index 10c0adc670..3b72eb17c6 100644
--- a/tensorflow/compiler/xla/xla.proto
+++ b/tensorflow/compiler/xla/xla.proto
@@ -104,15 +104,6 @@ message DebugOptions {
// interpretation of this value is left to the backends.
int32 xla_backend_optimization_level = 31;
- // When true, "unsafe" mathematical optimizations are enabled. These
- // transformations include but are not limited to:
- //
- // - Reducing the precision of operations (e.g. using an approximate sin
- // function, or transforming x/y into x * (1/y)).
- // - Assuming that operations never produce or consume NaN or +/- Inf.
- // - Assuming that +0 and -0 are indistinguishable.
- bool xla_enable_fast_math = 32;
-
// Embed the compiler IR as a string in the executable.
bool xla_embed_ir_in_executable = 33;
@@ -194,6 +185,16 @@ message DebugOptions {
// Maximum kernel unroll factor for the GPU backend.
int32 xla_gpu_max_kernel_unroll_factor = 98;
+ // When true, "unsafe" mathematical optimizations are enabled. These
+ // transformations include but are not limited to:
+ //
+ // - Reducing the precision of operations (e.g. using an approximate sin
+ // function, or transforming x/y into x * (1/y)).
+ // - Assuming that operations never produce or consume NaN or +/- Inf.
+ // - Assuming that +0 and -0 are indistinguishable.
+ bool xla_cpu_enable_fast_math = 99;
+ bool xla_gpu_enable_fast_math = 100;
+
// Extra options to pass to the compilation backend; specific interpretation
// of these values is left to the backend.
map<string, string> xla_backend_extra_options = 500;
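
Since the single xla_enable_fast_math flag is gone, callers now toggle the CPU and GPU flags separately; a minimal sketch mirroring the SetFastMathDisabled change above (the execution_options variable is illustrative):

    // Disable "unsafe" math transformations on both backends.
    DebugOptions* opts = execution_options.mutable_debug_options();
    opts->set_xla_cpu_enable_fast_math(false);
    opts->set_xla_gpu_enable_fast_math(false);
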