aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-10-01 15:41:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 15:52:53 -0700
commitdc4ac1b84c9c74655f04254779516f9968a5c385 (patch)
treefdf43c0b81c93e5b5116d0087af46087e2cea67a /tensorflow/compiler
parent55d96e8ea93407da156c156702a38fd8b5d06b2a (diff)
Clean up the build_xla_ops to use the generated C++ TF op wrappers.
This cleanup will make the future CL implementing lazy compilation simpler. Includes some supporting changes: - Teach NewInternalScope to create a scope that doesn't do shape inference. We need this because we don't have a ShapeRefiner that has been run over the entire graph available in the build_xla_ops pass. - Add a WithAssignedDevice modifier to tensorflow::Scope. - Make cc_op_gen write out an Operation field for nodes which may not necessarily have any outputs. We already did this in most cases, but we weren't doing it for nodes that have possibly-empty list outputs. - Minor change renaming ops/xla_jit_op.cc to ops/xla_jit_ops.cc, now that we have more than one XLA JIT op. PiperOrigin-RevId: 215293817
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/jit/BUILD4
-rw-r--r--tensorflow/compiler/jit/build_xla_ops_pass.cc180
-rw-r--r--tensorflow/compiler/jit/build_xla_ops_pass_test.cc32
-rw-r--r--tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc2
-rw-r--r--tensorflow/compiler/tf2xla/cc/BUILD7
5 files changed, 118 insertions, 107 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 29b60d1dbe..f20270931f 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -385,12 +385,16 @@ cc_library(
":shape_inference_helpers",
":union_find",
":xla_cluster_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:scope_internal",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
index 9e3fd93cda..5974696b77 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -14,8 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
+#include "absl/algorithm/container.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -31,132 +35,108 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
namespace tensorflow {
-
-static Status BuildXlaCompileNode(
- const string& nodename, const string& function_name,
- const AttrValueMap& function_attr, const string& device_name,
- const DataTypeVector& constant_dtypes, int num_resources,
- const DataTypeVector& arg_dtypes, Graph* graph, Node** node) {
- NodeDef def;
- def.set_name(graph->NewName(nodename));
- def.set_op("_XlaCompile");
- def.set_device(device_name);
- AddNodeAttr("Tconstants", constant_dtypes, &def);
- AddNodeAttr("Targs", arg_dtypes, &def);
- AddNodeAttr("Nresources", num_resources, &def);
- NameAttrList function;
- function.set_name(function_name);
- *function.mutable_attr() = function_attr;
- AddNodeAttr("function", function, &def);
-
- Status status;
- *node = graph->AddNode(def, &status);
- return status;
+namespace {
+void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
+ std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
+ old_node->out_edges().end());
+ for (const Edge* edge : out_edges) {
+ // TODO(sanjoy): This does not update NodeDef inputs. To be able to update
+ // NodeDef inputs we first need to fix encapsulate_subgraphs_pass to fix up
+ // the NodeDef inputs to the function call nodes.
+ g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
+ g->RemoveEdge(edge);
+ }
}
-static Status BuildXlaRunNode(const string& nodename, const string& device_name,
- const DataTypeVector& arg_dtypes,
- const DataTypeVector& result_dtypes, Graph* graph,
- Node** node) {
- NodeDef def;
- def.set_name(graph->NewName(nodename));
- def.set_op("_XlaRun");
- def.set_device(device_name);
- AddNodeAttr("Targs", arg_dtypes, &def);
- AddNodeAttr("Tresults", result_dtypes, &def);
+struct XlaClusterInfo {
+ std::vector<Output> constant_inputs;
+ std::vector<Output> non_constant_inputs;
+ std::vector<Output> resource_inputs;
+ NameAttrList function;
+};
- Status status;
- *node = graph->AddNode(def, &status);
- return status;
+Output IncomingEdgeAsOutput(const Edge* e) {
+ return Output(e->src(), e->src_output());
}
-static Status GetXlaAttrs(Node* node, int* num_constant_args,
- int* num_resource_args, DataTypeVector* const_dtypes,
- DataTypeVector* arg_dtypes) {
+Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) {
+ int num_constant_inputs, num_resource_inputs;
TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, num_constant_args));
+ GetNodeAttr(n->attrs(), kXlaNumConstantArgsAttr, &num_constant_inputs));
TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, num_resource_args));
+ GetNodeAttr(n->attrs(), kXlaNumResourceArgsAttr, &num_resource_inputs));
- if (*num_constant_args < 0 || *num_resource_args < 0 ||
- *num_constant_args + *num_resource_args > node->num_inputs()) {
+ if (num_constant_inputs < 0 || num_resource_inputs < 0 ||
+ num_constant_inputs + num_resource_inputs > n->num_inputs()) {
return errors::InvalidArgument(
"Invalid number of constant/resource arguments to XLA kernel.");
}
- const int num_nonconst_args =
- node->num_inputs() - *num_constant_args - *num_resource_args;
-
- const DataTypeVector& input_types = node->input_types();
- std::copy(input_types.begin(), input_types.begin() + *num_constant_args,
- std::back_inserter(*const_dtypes));
- std::copy(input_types.begin() + *num_constant_args,
- input_types.begin() + *num_constant_args + num_nonconst_args,
- std::back_inserter(*arg_dtypes));
- return Status::OK();
-}
-
-static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node,
- int prefix_to_ignore) {
- for (const Edge* edge : old_node->in_edges()) {
- if (edge->IsControlEdge()) {
- g->AddControlEdge(edge->src(), new_node);
- } else if (edge->dst_input() >= prefix_to_ignore) {
- g->AddEdge(edge->src(), edge->src_output(), new_node,
- edge->dst_input() - prefix_to_ignore);
- }
- }
-}
+ int num_non_constant_inputs =
+ n->num_inputs() - num_constant_inputs - num_resource_inputs;
-static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
- std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
- old_node->out_edges().end());
- for (const Edge* edge : out_edges) {
- // TODO(sanjoy): This does not update NodeDef inputs.
- g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
- g->RemoveEdge(edge);
- }
-}
+ std::vector<const Edge*> input_edges_vector;
+ TF_RETURN_IF_ERROR(n->input_edges(&input_edges_vector));
+ absl::Span<const Edge*> input_edges(input_edges_vector);
-static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) {
- int num_constant_args, num_resource_args;
- DataTypeVector const_dtypes;
- DataTypeVector arg_dtypes;
+ absl::c_transform(input_edges.subspan(0, num_constant_inputs),
+ std::back_inserter(result->constant_inputs),
+ IncomingEdgeAsOutput);
- TF_RETURN_IF_ERROR(GetXlaAttrs(n, &num_constant_args, &num_resource_args,
- &const_dtypes, &arg_dtypes));
+ absl::c_transform(
+ input_edges.subspan(num_constant_inputs, num_non_constant_inputs),
+ std::back_inserter(result->non_constant_inputs), IncomingEdgeAsOutput);
- Node *compile_node, *run_node;
+ absl::c_transform(
+ input_edges.subspan(num_constant_inputs + num_non_constant_inputs,
+ num_resource_inputs),
+ std::back_inserter(result->resource_inputs), IncomingEdgeAsOutput);
- TF_RETURN_IF_ERROR(BuildXlaCompileNode(
- n->name(), n->type_string(), n->def().attr(), n->requested_device(),
- const_dtypes, num_resource_args, arg_dtypes, g, &compile_node));
+ result->function.set_name(n->type_string());
+ *result->function.mutable_attr() = n->def().attr();
+ return Status::OK();
+}
- DataTypeVector arg_dtypes_with_resources = arg_dtypes;
- for (int i = 0; i < num_resource_args; i++) {
- arg_dtypes_with_resources.push_back(DT_RESOURCE);
+Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) {
+ for (const Edge* e : from->in_edges()) {
+ if (e->IsControlEdge()) {
+ g->AddControlEdge(e->src(), to);
+ }
}
- TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(),
- arg_dtypes_with_resources,
- n->output_types(), g, &run_node));
-
- compile_node->set_assigned_device_name(n->assigned_device_name());
- run_node->set_assigned_device_name(n->assigned_device_name());
+ return Status::OK();
+}
- CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node,
- /*prefix_to_ignore=*/0);
- CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node,
- /*prefix_to_ignore=*/num_constant_args);
+Status ReplaceNodeWithXlaCompileAndXlaRun(Graph* g, Node* n) {
+ Status status;
+ Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
+ .NewSubScope(n->name())
+ .WithDevice(n->requested_device())
+ .WithAssignedDevice(n->assigned_device_name());
+
+ XlaClusterInfo cluster_info;
+ TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
+
+ ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
+ /*constants=*/cluster_info.constant_inputs,
+ /*args=*/cluster_info.non_constant_inputs,
+ /*resources=*/cluster_info.resource_inputs,
+ cluster_info.function);
+ TF_RETURN_IF_ERROR(
+ CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
- // The compilation_key output.
- g->AddEdge(compile_node, 0, run_node, n->num_inputs() - num_constant_args);
+ std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
+ absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
+ ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
+ xla_compile.key, n->output_types());
- MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
+ MoveOutgoingEdges(g, /*old_node=*/n,
+ /*new_node=*/xla_run.operation.node());
g->RemoveNode(n);
return Status::OK();
}
+} // namespace
Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
Graph* graph = options.graph->get();
@@ -170,7 +150,7 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
// Only compile nodes that are marked for compilation by the
// compilation-marking pass (via 'attr_name').
if (IsXlaCompiledKernel(*n)) {
- TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndRun(graph, n));
+ TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(graph, n));
}
}
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
index b7cb4506b9..9d56db7b6b 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -56,18 +56,26 @@ Status BuildXlaOps(const Scope& s, std::unique_ptr<Graph>* result) {
}
Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
- const string& node_name, Node** result) {
+ const string& node_name, int num_constant_args,
+ int num_resource_args, Node** result) {
NodeDef call_node;
call_node.set_name(node_name);
call_node.set_op(callee_name);
AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node);
- AddNodeAttr(kXlaNumConstantArgsAttr, 0, &call_node);
- AddNodeAttr(kXlaNumResourceArgsAttr, 0, &call_node);
+ AddNodeAttr(kXlaNumConstantArgsAttr, num_constant_args, &call_node);
+ AddNodeAttr(kXlaNumResourceArgsAttr, num_resource_args, &call_node);
Status s;
*result = graph->AddNode(call_node, &s);
return s;
}
+Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
+ const string& node_name, Node** result) {
+ return MakeXlaCompiledKernel(graph, callee_name, node_name,
+ /*num_constant_args=*/0, /*num_resource_args=*/0,
+ result);
+}
+
Node* MakeWrite(const Scope& scope, const string& id) {
Output var_handle =
ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
@@ -108,5 +116,23 @@ TEST(BuildXlaOps, ControlDepsPreserved) {
EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")))));
}
+TEST(BuildXlaOps, CleanFailureOnBogusAttr) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("cluster_0");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+ Node* call;
+ TF_ASSERT_OK(
+ MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", 100, 100, &call));
+ Node* write_op = MakeWrite(root, "write");
+ root.graph()->AddControlEdge(call, write_op);
+
+ std::unique_ptr<Graph> graph;
+ Status failure_status = BuildXlaOps(root, &graph);
+ ASSERT_FALSE(failure_status.ok());
+ EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
index 479038ac8e..22531a4ace 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/test_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph_constructor.h"
diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD
index ea8d1b3d14..adcdb6c8f7 100644
--- a/tensorflow/compiler/tf2xla/cc/BUILD
+++ b/tensorflow/compiler/tf2xla/cc/BUILD
@@ -30,14 +30,15 @@ cc_library(
tf_gen_op_wrapper_cc(
name = "xla_jit_op_gen",
- out_ops_file = "ops/xla_jit_op",
+ include_internal_ops = 1,
+ out_ops_file = "ops/xla_jit_ops",
deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
)
cc_library(
name = "xla_jit_ops",
- srcs = ["ops/xla_jit_op.cc"],
- hdrs = ["ops/xla_jit_op.h"],
+ srcs = ["ops/xla_jit_ops.cc"],
+ hdrs = ["ops/xla_jit_ops.h"],
deps = [
"//tensorflow/cc:const_op",
"//tensorflow/cc:ops",