-rw-r--r--  tensorflow/cc/framework/cc_op_gen.cc                                 10
-rw-r--r--  tensorflow/cc/framework/scope.cc                                     33
-rw-r--r--  tensorflow/cc/framework/scope.h                                       4
-rw-r--r--  tensorflow/cc/framework/scope_internal.h                              5
-rw-r--r--  tensorflow/compiler/jit/BUILD                                         4
-rw-r--r--  tensorflow/compiler/jit/build_xla_ops_pass.cc                       180
-rw-r--r--  tensorflow/compiler/jit/build_xla_ops_pass_test.cc                   32
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc     2
-rw-r--r--  tensorflow/compiler/tf2xla/cc/BUILD                                   7
-rw-r--r--  tensorflow/core/graph/node_builder.cc                                 7
-rw-r--r--  tensorflow/core/graph/node_builder.h                                  4
11 files changed, 174 insertions, 114 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index a32d1b1eb5..39593370d1 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -853,11 +853,7 @@ void OpInfo::WriteClassDecl(WritableFile* h) const {
}
}
- strings::StrAppend(&class_decl, "\n");
-
- if (output_types.empty()) {
- strings::StrAppend(&class_decl, " Operation operation;\n");
- }
+ strings::StrAppend(&class_decl, "\n Operation operation;\n");
for (int i = 0; i < output_types.size(); ++i) {
strings::StrAppend(&class_decl, " ", output_types[i], " ", output_names[i],
";\n");
@@ -878,9 +874,11 @@ void OpInfo::GetOutput(string* out) const {
string return_on_error =
strings::StrCat("if (!", scope_str, ".ok()) return;");
+ strings::StrAppend(out, " this->operation = Operation(ret);\n");
+
// No outputs.
if (graph_op_def.output_arg_size() == 0) {
- strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n");
+ strings::StrAppend(out, " return;\n");
return;
}
if (graph_op_def.output_arg_size() == 1) {
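With this change every generated wrapper class carries an `operation` member, not just zero-output ops. For context, the generated class declaration ends up roughly of the following shape; this is an illustrative sketch only, not the literal generator output:

// Roughly what cc_op_gen now emits for a single-output op (sketch only).
class Identity {
 public:
  Identity(const ::tensorflow::Scope& scope, ::tensorflow::Input input);
  operator ::tensorflow::Output() const { return output; }
  operator ::tensorflow::Input() const { return output; }
  ::tensorflow::Node* node() const { return output.node(); }

  Operation operation;            // now emitted for every generated op
  ::tensorflow::Output output;
};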
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 7f6ac4cae7..6abc9e268e 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -62,7 +62,7 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
refiner_(refiner),
scope_used_(nullptr),
colocation_constraints_(),
- disable_shape_inference_(false) {}
+ disable_shape_inference_(refiner_ == nullptr) {}
Scope Scope::NewRootScope() {
Graph* graph = new Graph(OpRegistry::Global());
@@ -94,6 +94,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -110,6 +111,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -132,6 +134,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -163,6 +166,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -178,6 +182,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
exit_on_error_(true),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -194,6 +199,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(kernel_label),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -210,12 +216,30 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(
clear_colocations
? std::unordered_set<string>()
: other.impl()->GetColocationConstraints(colocate_with_op)),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
+Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice,
+ const string& assigned_device)
+ : graph_(other.impl()->graph_),
+ status_(other.impl()->status_),
+ name_map_(other.impl()->name_map_),
+ refiner_(other.impl()->refiner_),
+ scope_used_(other.impl()->scope_used_),
+ control_deps_(other.impl()->control_deps_),
+ name_(other.impl()->name_),
+ op_name_(other.impl()->op_name_),
+ exit_on_error_(other.impl()->exit_on_error_),
+ kernel_label_(other.impl()->kernel_label_),
+ device_(other.impl()->device_),
+ assigned_device_(assigned_device),
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
+
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
const Operation& colocate_with_op) const {
std::unordered_set<string> current_constraints(colocation_constraints_);
@@ -299,6 +323,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const {
if (!impl()->device_.empty()) {
builder->Device(impl()->device_);
}
+ if (!impl()->assigned_device_.empty()) {
+ builder->AssignedDevice(impl()->assigned_device_);
+ }
}
string Scope::Impl::GetUniqueName(const string& prefix,
@@ -394,6 +421,10 @@ Scope Scope::WithDevice(const string& device) const {
return Scope(new Impl(*this, Impl::Tags::Device(), device));
}
+Scope Scope::WithAssignedDevice(const string& assigned_device) const {
+ return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device));
+}
+
Scope Scope::ColocateWith(const Operation& op) const {
return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
/* clear_colocations */ false));
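As a quick illustration of the new Scope::WithAssignedDevice above (a sketch only, not part of this change; the device strings are made up), a scope can carry both a requested device and an assigned device, and ops built under it pick up both fields without running the placer:

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

void Example() {
  using namespace tensorflow;  // sketch only
  Scope root = Scope::NewRootScope();
  Scope s = root.WithDevice("/job:worker/device:CPU:0")
                .WithAssignedDevice("/job:worker/replica:0/task:0/device:CPU:0");
  // The constant both requests a device and is already placed on one.
  Output c = ops::Const(s.WithOpName("c"), 42.0f);
  // c.node()->assigned_device_name() is set via NodeBuilder::AssignedDevice.
}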
diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h
index 30c32bd44b..e307d8989b 100644
--- a/tensorflow/cc/framework/scope.h
+++ b/tensorflow/cc/framework/scope.h
@@ -133,6 +133,10 @@ class Scope {
/// the device field set to 'device'.
Scope WithDevice(const string& device) const;
+ /// Returns a new scope. All ops created within the returned scope will have
+ /// their assigned device set to `assigned_device`.
+ Scope WithAssignedDevice(const string& assigned_device) const;
+
/// Return a new scope. All ops created within the returned scope will be
/// co-located on the device where op is placed.
/// NOTE: This function is intended to be used by internal libraries only for
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h
index 58adaef2e9..514e02e841 100644
--- a/tensorflow/cc/framework/scope_internal.h
+++ b/tensorflow/cc/framework/scope_internal.h
@@ -26,6 +26,8 @@ class ShapeRefiner;
// graph, status, name_map, and refiner.
// This is intended to enable the C API (which is used by other language
// bindings) to create a Scope and access C++ functionality (i.e. gradients).
+//
+// Shape inference is disabled if `refiner` is nullptr.
Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner);
class Scope::Impl {
@@ -58,6 +60,7 @@ class Scope::Impl {
enum class ExitOnError;
enum class KernelLabel;
enum class Colocate;
+ enum class AssignedDevice;
};
Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
@@ -74,6 +77,7 @@ class Scope::Impl {
Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label);
Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op,
bool clear_colocations);
+ Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device);
std::unordered_set<string> GetColocationConstraints(
const Operation& colocate_with_op) const;
@@ -107,6 +111,7 @@ class Scope::Impl {
const bool exit_on_error_ = false;
const string kernel_label_ = "";
const string device_ = "";
+ const string assigned_device_ = "";
const std::unordered_set<string> colocation_constraints_;
// If true, Scope::DoShapeInference() always returns Status::OK().
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 29b60d1dbe..f20270931f 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -385,12 +385,16 @@ cc_library(
":shape_inference_helpers",
":union_find",
":xla_cluster_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:scope_internal",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
index 9e3fd93cda..5974696b77 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -14,8 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
+#include "absl/algorithm/container.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -31,132 +35,108 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
namespace tensorflow {
-
-static Status BuildXlaCompileNode(
- const string& nodename, const string& function_name,
- const AttrValueMap& function_attr, const string& device_name,
- const DataTypeVector& constant_dtypes, int num_resources,
- const DataTypeVector& arg_dtypes, Graph* graph, Node** node) {
- NodeDef def;
- def.set_name(graph->NewName(nodename));
- def.set_op("_XlaCompile");
- def.set_device(device_name);
- AddNodeAttr("Tconstants", constant_dtypes, &def);
- AddNodeAttr("Targs", arg_dtypes, &def);
- AddNodeAttr("Nresources", num_resources, &def);
- NameAttrList function;
- function.set_name(function_name);
- *function.mutable_attr() = function_attr;
- AddNodeAttr("function", function, &def);
-
- Status status;
- *node = graph->AddNode(def, &status);
- return status;
+namespace {
+void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
+ std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
+ old_node->out_edges().end());
+ for (const Edge* edge : out_edges) {
+ // TODO(sanjoy): This does not update NodeDef inputs. To be able to update
+ // NodeDef inputs we first need to fix encapsulate_subgraphs_pass to fix up
+ // the NodeDef inputs to the function call nodes.
+ g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
+ g->RemoveEdge(edge);
+ }
}
-static Status BuildXlaRunNode(const string& nodename, const string& device_name,
- const DataTypeVector& arg_dtypes,
- const DataTypeVector& result_dtypes, Graph* graph,
- Node** node) {
- NodeDef def;
- def.set_name(graph->NewName(nodename));
- def.set_op("_XlaRun");
- def.set_device(device_name);
- AddNodeAttr("Targs", arg_dtypes, &def);
- AddNodeAttr("Tresults", result_dtypes, &def);
+struct XlaClusterInfo {
+ std::vector<Output> constant_inputs;
+ std::vector<Output> non_constant_inputs;
+ std::vector<Output> resource_inputs;
+ NameAttrList function;
+};
- Status status;
- *node = graph->AddNode(def, &status);
- return status;
+Output IncomingEdgeAsOutput(const Edge* e) {
+ return Output(e->src(), e->src_output());
}
-static Status GetXlaAttrs(Node* node, int* num_constant_args,
- int* num_resource_args, DataTypeVector* const_dtypes,
- DataTypeVector* arg_dtypes) {
+Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) {
+ int num_constant_inputs, num_resource_inputs;
TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, num_constant_args));
+ GetNodeAttr(n->attrs(), kXlaNumConstantArgsAttr, &num_constant_inputs));
TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, num_resource_args));
+ GetNodeAttr(n->attrs(), kXlaNumResourceArgsAttr, &num_resource_inputs));
- if (*num_constant_args < 0 || *num_resource_args < 0 ||
- *num_constant_args + *num_resource_args > node->num_inputs()) {
+ if (num_constant_inputs < 0 || num_resource_inputs < 0 ||
+ num_constant_inputs + num_resource_inputs > n->num_inputs()) {
return errors::InvalidArgument(
"Invalid number of constant/resource arguments to XLA kernel.");
}
- const int num_nonconst_args =
- node->num_inputs() - *num_constant_args - *num_resource_args;
-
- const DataTypeVector& input_types = node->input_types();
- std::copy(input_types.begin(), input_types.begin() + *num_constant_args,
- std::back_inserter(*const_dtypes));
- std::copy(input_types.begin() + *num_constant_args,
- input_types.begin() + *num_constant_args + num_nonconst_args,
- std::back_inserter(*arg_dtypes));
- return Status::OK();
-}
-
-static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node,
- int prefix_to_ignore) {
- for (const Edge* edge : old_node->in_edges()) {
- if (edge->IsControlEdge()) {
- g->AddControlEdge(edge->src(), new_node);
- } else if (edge->dst_input() >= prefix_to_ignore) {
- g->AddEdge(edge->src(), edge->src_output(), new_node,
- edge->dst_input() - prefix_to_ignore);
- }
- }
-}
+ int num_non_constant_inputs =
+ n->num_inputs() - num_constant_inputs - num_resource_inputs;
-static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
- std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
- old_node->out_edges().end());
- for (const Edge* edge : out_edges) {
- // TODO(sanjoy): This does not update NodeDef inputs.
- g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
- g->RemoveEdge(edge);
- }
-}
+ std::vector<const Edge*> input_edges_vector;
+ TF_RETURN_IF_ERROR(n->input_edges(&input_edges_vector));
+ absl::Span<const Edge*> input_edges(input_edges_vector);
-static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) {
- int num_constant_args, num_resource_args;
- DataTypeVector const_dtypes;
- DataTypeVector arg_dtypes;
+ absl::c_transform(input_edges.subspan(0, num_constant_inputs),
+ std::back_inserter(result->constant_inputs),
+ IncomingEdgeAsOutput);
- TF_RETURN_IF_ERROR(GetXlaAttrs(n, &num_constant_args, &num_resource_args,
- &const_dtypes, &arg_dtypes));
+ absl::c_transform(
+ input_edges.subspan(num_constant_inputs, num_non_constant_inputs),
+ std::back_inserter(result->non_constant_inputs), IncomingEdgeAsOutput);
- Node *compile_node, *run_node;
+ absl::c_transform(
+ input_edges.subspan(num_constant_inputs + num_non_constant_inputs,
+ num_resource_inputs),
+ std::back_inserter(result->resource_inputs), IncomingEdgeAsOutput);
- TF_RETURN_IF_ERROR(BuildXlaCompileNode(
- n->name(), n->type_string(), n->def().attr(), n->requested_device(),
- const_dtypes, num_resource_args, arg_dtypes, g, &compile_node));
+ result->function.set_name(n->type_string());
+ *result->function.mutable_attr() = n->def().attr();
+ return Status::OK();
+}
- DataTypeVector arg_dtypes_with_resources = arg_dtypes;
- for (int i = 0; i < num_resource_args; i++) {
- arg_dtypes_with_resources.push_back(DT_RESOURCE);
+Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) {
+ for (const Edge* e : from->in_edges()) {
+ if (e->IsControlEdge()) {
+ g->AddControlEdge(e->src(), to);
+ }
}
- TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(),
- arg_dtypes_with_resources,
- n->output_types(), g, &run_node));
-
- compile_node->set_assigned_device_name(n->assigned_device_name());
- run_node->set_assigned_device_name(n->assigned_device_name());
+ return Status::OK();
+}
- CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node,
- /*prefix_to_ignore=*/0);
- CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node,
- /*prefix_to_ignore=*/num_constant_args);
+Status ReplaceNodeWithXlaCompileAndXlaRun(Graph* g, Node* n) {
+ Status status;
+ Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
+ .NewSubScope(n->name())
+ .WithDevice(n->requested_device())
+ .WithAssignedDevice(n->assigned_device_name());
+
+ XlaClusterInfo cluster_info;
+ TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
+
+ ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
+ /*constants=*/cluster_info.constant_inputs,
+ /*args=*/cluster_info.non_constant_inputs,
+ /*resources=*/cluster_info.resource_inputs,
+ cluster_info.function);
+ TF_RETURN_IF_ERROR(
+ CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
- // The compilation_key output.
- g->AddEdge(compile_node, 0, run_node, n->num_inputs() - num_constant_args);
+ std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
+ absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
+ ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
+ xla_compile.key, n->output_types());
- MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
+ MoveOutgoingEdges(g, /*old_node=*/n,
+ /*new_node=*/xla_run.operation.node());
g->RemoveNode(n);
return Status::OK();
}
+} // namespace
Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
Graph* graph = options.graph->get();
@@ -170,7 +150,7 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
// Only compile nodes that are marked for compilation by the
// compilation-marking pass (via 'attr_name').
if (IsXlaCompiledKernel(*n)) {
- TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndRun(graph, n));
+ TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(graph, n));
}
}
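The rewritten pass above builds _XlaCompile/_XlaRun through the generated C++ op wrappers on an internal scope instead of hand-assembling NodeDefs. A condensed sketch of that pattern for splicing a node into an existing Graph (illustrative only; AddIdentityLike and the op used here are placeholders):

#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/graph/graph.h"

tensorflow::Status AddIdentityLike(tensorflow::Graph* g, tensorflow::Output in,
                                   tensorflow::Node* like) {
  using namespace tensorflow;  // sketch only
  Status status;
  // A null ShapeRefiner disables shape inference, per the new note in
  // scope_internal.h, so nodes can be added without re-running inference.
  Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
                   .NewSubScope(like->name())
                   .WithDevice(like->requested_device())
                   .WithAssignedDevice(like->assigned_device_name());
  ops::Identity id(root.WithOpName("identity"), in);
  (void)id;
  return status;
}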
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
index b7cb4506b9..9d56db7b6b 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -56,18 +56,26 @@ Status BuildXlaOps(const Scope& s, std::unique_ptr<Graph>* result) {
}
Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
- const string& node_name, Node** result) {
+ const string& node_name, int num_constant_args,
+ int num_resource_args, Node** result) {
NodeDef call_node;
call_node.set_name(node_name);
call_node.set_op(callee_name);
AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node);
- AddNodeAttr(kXlaNumConstantArgsAttr, 0, &call_node);
- AddNodeAttr(kXlaNumResourceArgsAttr, 0, &call_node);
+ AddNodeAttr(kXlaNumConstantArgsAttr, num_constant_args, &call_node);
+ AddNodeAttr(kXlaNumResourceArgsAttr, num_resource_args, &call_node);
Status s;
*result = graph->AddNode(call_node, &s);
return s;
}
+Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
+ const string& node_name, Node** result) {
+ return MakeXlaCompiledKernel(graph, callee_name, node_name,
+ /*num_constant_args=*/0, /*num_resource_args=*/0,
+ result);
+}
+
Node* MakeWrite(const Scope& scope, const string& id) {
Output var_handle =
ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
@@ -108,5 +116,23 @@ TEST(BuildXlaOps, ControlDepsPreserved) {
EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")))));
}
+TEST(BuildXlaOps, CleanFailureOnBogusAttr) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("cluster_0");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+ Node* call;
+ TF_ASSERT_OK(
+ MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", 100, 100, &call));
+ Node* write_op = MakeWrite(root, "write");
+ root.graph()->AddControlEdge(call, write_op);
+
+ std::unique_ptr<Graph> graph;
+ Status failure_status = BuildXlaOps(root, &graph);
+ ASSERT_FALSE(failure_status.ok());
+ EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
index 479038ac8e..22531a4ace 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/test_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph_constructor.h"
diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD
index ea8d1b3d14..adcdb6c8f7 100644
--- a/tensorflow/compiler/tf2xla/cc/BUILD
+++ b/tensorflow/compiler/tf2xla/cc/BUILD
@@ -30,14 +30,15 @@ cc_library(
tf_gen_op_wrapper_cc(
name = "xla_jit_op_gen",
- out_ops_file = "ops/xla_jit_op",
+ include_internal_ops = 1,
+ out_ops_file = "ops/xla_jit_ops",
deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
)
cc_library(
name = "xla_jit_ops",
- srcs = ["ops/xla_jit_op.cc"],
- hdrs = ["ops/xla_jit_op.h"],
+ srcs = ["ops/xla_jit_ops.cc"],
+ hdrs = ["ops/xla_jit_ops.h"],
deps = [
"//tensorflow/cc:const_op",
"//tensorflow/cc:ops",
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index a446e0d136..d92874909f 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -99,6 +99,11 @@ NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
return *this;
}
+NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
+ assigned_device_ = string(device);
+ return *this;
+}
+
Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
// In case of error, set *created_node to nullptr.
if (created_node != nullptr) *created_node = nullptr;
@@ -115,6 +120,8 @@ Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
Node* node = graph->AddNode(node_def, &status);
if (!status.ok()) return status;
+ node->set_assigned_device_name(assigned_device_);
+
for (size_t i = 0; i < inputs_.size(); ++i) {
if (inputs_[i].node != nullptr) { // Skip back edges.
graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
index 4727ee7b56..d576985a23 100644
--- a/tensorflow/core/graph/node_builder.h
+++ b/tensorflow/core/graph/node_builder.h
@@ -100,6 +100,9 @@ class NodeBuilder {
// "assigned device" in the Node).
NodeBuilder& Device(StringPiece device_spec);
+ // Sets the device name in the "assigned device" field in tensorflow::Node.
+ NodeBuilder& AssignedDevice(StringPiece device);
+
// Set the value of an attr. attr_name must match the name of one of
// attrs defined by the Op, and value must have the corresponding type
// (see SetAttrValue() in ../framework/attr_value_util.h for legal
@@ -141,6 +144,7 @@ class NodeBuilder {
std::vector<NodeOut> inputs_;
std::vector<Node*> control_inputs_;
std::vector<string> errors_;
+ string assigned_device_;
};
// IMPLEMENTATION -------------------------------------------------------------
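Finally, a minimal sketch of the new NodeBuilder::AssignedDevice (illustrative only; the node name and device strings are made up). Device() still fills the requested-device field in the NodeDef, while AssignedDevice() pre-populates the Node's assigned device at Finalize time:

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"

tensorflow::Status AddPinnedNoOp(tensorflow::Graph* g, tensorflow::Node** out) {
  return tensorflow::NodeBuilder("pinned_no_op", "NoOp")
      .Device("/job:worker/device:CPU:0")                             // requested device (NodeDef)
      .AssignedDevice("/job:worker/replica:0/task:0/device:CPU:0")    // placement on the Node
      .Finalize(g, out);
}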