author     Sanjoy Das <sanjoy@google.com>  2018-10-01 15:41:29 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-01 15:52:53 -0700
commit     dc4ac1b84c9c74655f04254779516f9968a5c385 (patch)
tree       fdf43c0b81c93e5b5116d0087af46087e2cea67a
parent     55d96e8ea93407da156c156702a38fd8b5d06b2a (diff)
Clean up the build_xla_ops pass to use the generated C++ TF op wrappers.
This cleanup will make the future CL implementing lazy compilation simpler.

Includes some supporting changes:

 - Teach NewInternalScope to create a scope that doesn't do shape inference.
   We need this because we don't have a ShapeRefiner that has been run over
   the entire graph available in the build_xla_ops pass.

 - Add a WithAssignedDevice modifier to tensorflow::Scope.

 - Make cc_op_gen write out an Operation field for nodes which may not
   necessarily have any outputs.  We already did this in most cases, but we
   weren't doing it for nodes that have possibly-empty list outputs.

 - Minor change renaming ops/xla_jit_op.cc to ops/xla_jit_ops.cc, now that we
   have more than one XLA JIT op.

PiperOrigin-RevId: 215293817
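For illustration, the new helpers come together roughly like this in the rewritten pass. This is a minimal sketch distilled from the build_xla_ops_pass.cc hunks below; the graph `g`, the clustered node `n`, the input vectors (`constant_inputs`, `non_constant_inputs`, `resource_inputs`, `run_args`) and the `NameAttrList function` are placeholders for the bookkeeping done by GetXlaClusterInfo in the actual pass:

    // Build ops without shape inference: passing a null ShapeRefiner to
    // NewInternalScope now disables Scope::DoShapeInference().
    Status status;
    Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
                     .NewSubScope(n->name())
                     .WithDevice(n->requested_device())
                     .WithAssignedDevice(n->assigned_device_name());

    // The generated wrappers replace hand-rolled NodeDef construction.
    ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
                                 /*constants=*/constant_inputs,
                                 /*args=*/non_constant_inputs,
                                 /*resources=*/resource_inputs, function);
    ops::_XlaRun xla_run(root.WithOpName("xla_run"), run_args,
                         xla_compile.key, n->output_types());

    // Because cc_op_gen now always emits an Operation member, the underlying
    // Node* stays reachable even for wrappers whose output lists may be empty.
    Node* run_node = xla_run.operation.node();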
-rw-r--r--  tensorflow/cc/framework/cc_op_gen.cc                               | 10
-rw-r--r--  tensorflow/cc/framework/scope.cc                                   | 33
-rw-r--r--  tensorflow/cc/framework/scope.h                                    |  4
-rw-r--r--  tensorflow/cc/framework/scope_internal.h                           |  5
-rw-r--r--  tensorflow/compiler/jit/BUILD                                      |  4
-rw-r--r--  tensorflow/compiler/jit/build_xla_ops_pass.cc                      | 180
-rw-r--r--  tensorflow/compiler/jit/build_xla_ops_pass_test.cc                 | 32
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc  |  2
-rw-r--r--  tensorflow/compiler/tf2xla/cc/BUILD                                |  7
-rw-r--r--  tensorflow/core/graph/node_builder.cc                              |  7
-rw-r--r--  tensorflow/core/graph/node_builder.h                               |  4
11 files changed, 174 insertions, 114 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index a32d1b1eb5..39593370d1 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -853,11 +853,7 @@ void OpInfo::WriteClassDecl(WritableFile* h) const {
}
}
- strings::StrAppend(&class_decl, "\n");
-
- if (output_types.empty()) {
- strings::StrAppend(&class_decl, " Operation operation;\n");
- }
+ strings::StrAppend(&class_decl, "\n Operation operation;\n");
for (int i = 0; i < output_types.size(); ++i) {
strings::StrAppend(&class_decl, " ", output_types[i], " ", output_names[i],
";\n");
@@ -878,9 +874,11 @@ void OpInfo::GetOutput(string* out) const {
string return_on_error =
strings::StrCat("if (!", scope_str, ".ok()) return;");
+ strings::StrAppend(out, " this->operation = Operation(ret);\n");
+
// No outputs.
if (graph_op_def.output_arg_size() == 0) {
- strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n");
+ strings::StrAppend(out, " return;\n");
return;
}
if (graph_op_def.output_arg_size() == 1) {
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 7f6ac4cae7..6abc9e268e 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -62,7 +62,7 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
refiner_(refiner),
scope_used_(nullptr),
colocation_constraints_(),
- disable_shape_inference_(false) {}
+ disable_shape_inference_(refiner_ == nullptr) {}
Scope Scope::NewRootScope() {
Graph* graph = new Graph(OpRegistry::Global());
@@ -94,6 +94,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -110,6 +111,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -132,6 +134,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -163,6 +166,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -178,6 +182,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
exit_on_error_(true),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -194,6 +199,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(kernel_label),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -210,12 +216,30 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(
clear_colocations
? std::unordered_set<string>()
: other.impl()->GetColocationConstraints(colocate_with_op)),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
+Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice,
+ const string& assigned_device)
+ : graph_(other.impl()->graph_),
+ status_(other.impl()->status_),
+ name_map_(other.impl()->name_map_),
+ refiner_(other.impl()->refiner_),
+ scope_used_(other.impl()->scope_used_),
+ control_deps_(other.impl()->control_deps_),
+ name_(other.impl()->name_),
+ op_name_(other.impl()->op_name_),
+ exit_on_error_(other.impl()->exit_on_error_),
+ kernel_label_(other.impl()->kernel_label_),
+ device_(other.impl()->device_),
+ assigned_device_(assigned_device),
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
+
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
const Operation& colocate_with_op) const {
std::unordered_set<string> current_constraints(colocation_constraints_);
@@ -299,6 +323,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const {
if (!impl()->device_.empty()) {
builder->Device(impl()->device_);
}
+ if (!impl()->assigned_device_.empty()) {
+ builder->AssignedDevice(impl()->assigned_device_);
+ }
}
string Scope::Impl::GetUniqueName(const string& prefix,
@@ -394,6 +421,10 @@ Scope Scope::WithDevice(const string& device) const {
return Scope(new Impl(*this, Impl::Tags::Device(), device));
}
+Scope Scope::WithAssignedDevice(const string& assigned_device) const {
+ return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device));
+}
+
Scope Scope::ColocateWith(const Operation& op) const {
return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
/* clear_colocations */ false));
diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h
index 30c32bd44b..e307d8989b 100644
--- a/tensorflow/cc/framework/scope.h
+++ b/tensorflow/cc/framework/scope.h
@@ -133,6 +133,10 @@ class Scope {
/// the device field set to 'device'.
Scope WithDevice(const string& device) const;
+ /// Returns a new scope. All ops created within the returned scope will have
+ /// their assigned device set to `assigned_device`.
+ Scope WithAssignedDevice(const string& assigned_device) const;
+
/// Return a new scope. All ops created within the returned scope will be
/// co-located on the device where op is placed.
/// NOTE: This function is intended to be use internal libraries only for
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h
index 58adaef2e9..514e02e841 100644
--- a/tensorflow/cc/framework/scope_internal.h
+++ b/tensorflow/cc/framework/scope_internal.h
@@ -26,6 +26,8 @@ class ShapeRefiner;
// graph, status, name_map, and refiner.
// This is intended to enable the C API (which are used by other language
// bindings) to create a Scope and access C++ functionality (i.e. gradients).
+//
+// Shape inference is disabled if `refiner` is nullptr.
Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner);
class Scope::Impl {
@@ -58,6 +60,7 @@ class Scope::Impl {
enum class ExitOnError;
enum class KernelLabel;
enum class Colocate;
+ enum class AssignedDevice;
};
Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
@@ -74,6 +77,7 @@ class Scope::Impl {
Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label);
Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op,
bool clear_colocations);
+ Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device);
std::unordered_set<string> GetColocationConstraints(
const Operation& colocate_with_op) const;
@@ -107,6 +111,7 @@ class Scope::Impl {
const bool exit_on_error_ = false;
const string kernel_label_ = "";
const string device_ = "";
+ const string assigned_device_ = "";
const std::unordered_set<string> colocation_constraints_;
// If true, Scope::DoShapeInference() always returns Status:OK().
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 29b60d1dbe..f20270931f 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -385,12 +385,16 @@ cc_library(
":shape_inference_helpers",
":union_find",
":xla_cluster_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:scope_internal",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
index 9e3fd93cda..5974696b77 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -14,8 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
+#include "absl/algorithm/container.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -31,132 +35,108 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
namespace tensorflow {
-
-static Status BuildXlaCompileNode(
- const string& nodename, const string& function_name,
- const AttrValueMap& function_attr, const string& device_name,
- const DataTypeVector& constant_dtypes, int num_resources,
- const DataTypeVector& arg_dtypes, Graph* graph, Node** node) {
- NodeDef def;
- def.set_name(graph->NewName(nodename));
- def.set_op("_XlaCompile");
- def.set_device(device_name);
- AddNodeAttr("Tconstants", constant_dtypes, &def);
- AddNodeAttr("Targs", arg_dtypes, &def);
- AddNodeAttr("Nresources", num_resources, &def);
- NameAttrList function;
- function.set_name(function_name);
- *function.mutable_attr() = function_attr;
- AddNodeAttr("function", function, &def);
-
- Status status;
- *node = graph->AddNode(def, &status);
- return status;
+namespace {
+void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
+ std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
+ old_node->out_edges().end());
+ for (const Edge* edge : out_edges) {
+ // TODO(sanjoy): This does not update NodeDef inputs. To be able to update
+ // NodeDef inputs we first need to fix encapsulate_subgraphs_pass to fix up
+ // the NodeDef inputs to the function call nodes.
+ g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
+ g->RemoveEdge(edge);
+ }
}
-static Status BuildXlaRunNode(const string& nodename, const string& device_name,
- const DataTypeVector& arg_dtypes,
- const DataTypeVector& result_dtypes, Graph* graph,
- Node** node) {
- NodeDef def;
- def.set_name(graph->NewName(nodename));
- def.set_op("_XlaRun");
- def.set_device(device_name);
- AddNodeAttr("Targs", arg_dtypes, &def);
- AddNodeAttr("Tresults", result_dtypes, &def);
+struct XlaClusterInfo {
+ std::vector<Output> constant_inputs;
+ std::vector<Output> non_constant_inputs;
+ std::vector<Output> resource_inputs;
+ NameAttrList function;
+};
- Status status;
- *node = graph->AddNode(def, &status);
- return status;
+Output IncomingEdgeAsOutput(const Edge* e) {
+ return Output(e->src(), e->src_output());
}
-static Status GetXlaAttrs(Node* node, int* num_constant_args,
- int* num_resource_args, DataTypeVector* const_dtypes,
- DataTypeVector* arg_dtypes) {
+Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) {
+ int num_constant_inputs, num_resource_inputs;
TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, num_constant_args));
+ GetNodeAttr(n->attrs(), kXlaNumConstantArgsAttr, &num_constant_inputs));
TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, num_resource_args));
+ GetNodeAttr(n->attrs(), kXlaNumResourceArgsAttr, &num_resource_inputs));
- if (*num_constant_args < 0 || *num_resource_args < 0 ||
- *num_constant_args + *num_resource_args > node->num_inputs()) {
+ if (num_constant_inputs < 0 || num_resource_inputs < 0 ||
+ num_constant_inputs + num_resource_inputs > n->num_inputs()) {
return errors::InvalidArgument(
"Invalid number of constant/resource arguments to XLA kernel.");
}
- const int num_nonconst_args =
- node->num_inputs() - *num_constant_args - *num_resource_args;
-
- const DataTypeVector& input_types = node->input_types();
- std::copy(input_types.begin(), input_types.begin() + *num_constant_args,
- std::back_inserter(*const_dtypes));
- std::copy(input_types.begin() + *num_constant_args,
- input_types.begin() + *num_constant_args + num_nonconst_args,
- std::back_inserter(*arg_dtypes));
- return Status::OK();
-}
-
-static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node,
- int prefix_to_ignore) {
- for (const Edge* edge : old_node->in_edges()) {
- if (edge->IsControlEdge()) {
- g->AddControlEdge(edge->src(), new_node);
- } else if (edge->dst_input() >= prefix_to_ignore) {
- g->AddEdge(edge->src(), edge->src_output(), new_node,
- edge->dst_input() - prefix_to_ignore);
- }
- }
-}
+ int num_non_constant_inputs =
+ n->num_inputs() - num_constant_inputs - num_resource_inputs;
-static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
- std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
- old_node->out_edges().end());
- for (const Edge* edge : out_edges) {
- // TODO(sanjoy): This does not update NodeDef inputs.
- g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
- g->RemoveEdge(edge);
- }
-}
+ std::vector<const Edge*> input_edges_vector;
+ TF_RETURN_IF_ERROR(n->input_edges(&input_edges_vector));
+ absl::Span<const Edge*> input_edges(input_edges_vector);
-static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) {
- int num_constant_args, num_resource_args;
- DataTypeVector const_dtypes;
- DataTypeVector arg_dtypes;
+ absl::c_transform(input_edges.subspan(0, num_constant_inputs),
+ std::back_inserter(result->constant_inputs),
+ IncomingEdgeAsOutput);
- TF_RETURN_IF_ERROR(GetXlaAttrs(n, &num_constant_args, &num_resource_args,
- &const_dtypes, &arg_dtypes));
+ absl::c_transform(
+ input_edges.subspan(num_constant_inputs, num_non_constant_inputs),
+ std::back_inserter(result->non_constant_inputs), IncomingEdgeAsOutput);
- Node *compile_node, *run_node;
+ absl::c_transform(
+ input_edges.subspan(num_constant_inputs + num_non_constant_inputs,
+ num_resource_inputs),
+ std::back_inserter(result->resource_inputs), IncomingEdgeAsOutput);
- TF_RETURN_IF_ERROR(BuildXlaCompileNode(
- n->name(), n->type_string(), n->def().attr(), n->requested_device(),
- const_dtypes, num_resource_args, arg_dtypes, g, &compile_node));
+ result->function.set_name(n->type_string());
+ *result->function.mutable_attr() = n->def().attr();
+ return Status::OK();
+}
- DataTypeVector arg_dtypes_with_resources = arg_dtypes;
- for (int i = 0; i < num_resource_args; i++) {
- arg_dtypes_with_resources.push_back(DT_RESOURCE);
+Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) {
+ for (const Edge* e : from->in_edges()) {
+ if (e->IsControlEdge()) {
+ g->AddControlEdge(e->src(), to);
+ }
}
- TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(),
- arg_dtypes_with_resources,
- n->output_types(), g, &run_node));
-
- compile_node->set_assigned_device_name(n->assigned_device_name());
- run_node->set_assigned_device_name(n->assigned_device_name());
+ return Status::OK();
+}
- CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node,
- /*prefix_to_ignore=*/0);
- CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node,
- /*prefix_to_ignore=*/num_constant_args);
+Status ReplaceNodeWithXlaCompileAndXlaRun(Graph* g, Node* n) {
+ Status status;
+ Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
+ .NewSubScope(n->name())
+ .WithDevice(n->requested_device())
+ .WithAssignedDevice(n->assigned_device_name());
+
+ XlaClusterInfo cluster_info;
+ TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
+
+ ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
+ /*constants=*/cluster_info.constant_inputs,
+ /*args=*/cluster_info.non_constant_inputs,
+ /*resources=*/cluster_info.resource_inputs,
+ cluster_info.function);
+ TF_RETURN_IF_ERROR(
+ CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
- // The compilation_key output.
- g->AddEdge(compile_node, 0, run_node, n->num_inputs() - num_constant_args);
+ std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
+ absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
+ ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
+ xla_compile.key, n->output_types());
- MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
+ MoveOutgoingEdges(g, /*old_node=*/n,
+ /*new_node=*/xla_run.operation.node());
g->RemoveNode(n);
return Status::OK();
}
+} // namespace
Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
Graph* graph = options.graph->get();
@@ -170,7 +150,7 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
// Only compile nodes that are marked for compilation by the
// compilation-marking pass (via 'attr_name').
if (IsXlaCompiledKernel(*n)) {
- TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndRun(graph, n));
+ TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(graph, n));
}
}
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
index b7cb4506b9..9d56db7b6b 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -56,18 +56,26 @@ Status BuildXlaOps(const Scope& s, std::unique_ptr<Graph>* result) {
}
Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
- const string& node_name, Node** result) {
+ const string& node_name, int num_constant_args,
+ int num_resource_args, Node** result) {
NodeDef call_node;
call_node.set_name(node_name);
call_node.set_op(callee_name);
AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node);
- AddNodeAttr(kXlaNumConstantArgsAttr, 0, &call_node);
- AddNodeAttr(kXlaNumResourceArgsAttr, 0, &call_node);
+ AddNodeAttr(kXlaNumConstantArgsAttr, num_constant_args, &call_node);
+ AddNodeAttr(kXlaNumResourceArgsAttr, num_resource_args, &call_node);
Status s;
*result = graph->AddNode(call_node, &s);
return s;
}
+Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
+ const string& node_name, Node** result) {
+ return MakeXlaCompiledKernel(graph, callee_name, node_name,
+ /*num_constant_args=*/0, /*num_resource_args=*/0,
+ result);
+}
+
Node* MakeWrite(const Scope& scope, const string& id) {
Output var_handle =
ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
@@ -108,5 +116,23 @@ TEST(BuildXlaOps, ControlDepsPreserved) {
EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")))));
}
+TEST(BuildXlaOps, CleanFailureOnBogusAttr) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("cluster_0");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+ Node* call;
+ TF_ASSERT_OK(
+ MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", 100, 100, &call));
+ Node* write_op = MakeWrite(root, "write");
+ root.graph()->AddControlEdge(call, write_op);
+
+ std::unique_ptr<Graph> graph;
+ Status failure_status = BuildXlaOps(root, &graph);
+ ASSERT_FALSE(failure_status.ok());
+ EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
index 479038ac8e..22531a4ace 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/test_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph_constructor.h"
diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD
index ea8d1b3d14..adcdb6c8f7 100644
--- a/tensorflow/compiler/tf2xla/cc/BUILD
+++ b/tensorflow/compiler/tf2xla/cc/BUILD
@@ -30,14 +30,15 @@ cc_library(
tf_gen_op_wrapper_cc(
name = "xla_jit_op_gen",
- out_ops_file = "ops/xla_jit_op",
+ include_internal_ops = 1,
+ out_ops_file = "ops/xla_jit_ops",
deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
)
cc_library(
name = "xla_jit_ops",
- srcs = ["ops/xla_jit_op.cc"],
- hdrs = ["ops/xla_jit_op.h"],
+ srcs = ["ops/xla_jit_ops.cc"],
+ hdrs = ["ops/xla_jit_ops.h"],
deps = [
"//tensorflow/cc:const_op",
"//tensorflow/cc:ops",
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index a446e0d136..d92874909f 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -99,6 +99,11 @@ NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
return *this;
}
+NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
+ assigned_device_ = string(device);
+ return *this;
+}
+
Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
// In case of error, set *created_node to nullptr.
if (created_node != nullptr) *created_node = nullptr;
@@ -115,6 +120,8 @@ Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
Node* node = graph->AddNode(node_def, &status);
if (!status.ok()) return status;
+ node->set_assigned_device_name(assigned_device_);
+
for (size_t i = 0; i < inputs_.size(); ++i) {
if (inputs_[i].node != nullptr) { // Skip back edges.
graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
index 4727ee7b56..d576985a23 100644
--- a/tensorflow/core/graph/node_builder.h
+++ b/tensorflow/core/graph/node_builder.h
@@ -100,6 +100,9 @@ class NodeBuilder {
// "assigned device" in the Node).
NodeBuilder& Device(StringPiece device_spec);
+ // Sets the device name in the "assigned device" field in tensorflow::Node.
+ NodeBuilder& AssignedDevice(StringPiece device);
+
// Set the value of an attr. attr_name must match the name of one of
// attrs defined by the Op, and value must have the corresponding type
// (see SetAttrValue() in ../framework/attr_value_util.h for legal
@@ -141,6 +144,7 @@ class NodeBuilder {
std::vector<NodeOut> inputs_;
std::vector<Node*> control_inputs_;
std::vector<string> errors_;
+ string assigned_device_;
};
// IMPLEMENTATION -------------------------------------------------------------