-rw-r--r--  tensorflow/cc/framework/cc_op_gen.cc                                10
-rw-r--r--  tensorflow/cc/framework/scope.cc                                    33
-rw-r--r--  tensorflow/cc/framework/scope.h                                      4
-rw-r--r--  tensorflow/cc/framework/scope_internal.h                             5
-rw-r--r--  tensorflow/compiler/jit/BUILD                                        4
-rw-r--r--  tensorflow/compiler/jit/build_xla_ops_pass.cc                      180
-rw-r--r--  tensorflow/compiler/jit/build_xla_ops_pass_test.cc                  32
-rw-r--r--  tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc    2
-rw-r--r--  tensorflow/compiler/tf2xla/cc/BUILD                                  7
-rw-r--r--  tensorflow/core/graph/node_builder.cc                                7
-rw-r--r--  tensorflow/core/graph/node_builder.h                                 4
11 files changed, 174 insertions, 114 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index a32d1b1eb5..39593370d1 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -853,11 +853,7 @@ void OpInfo::WriteClassDecl(WritableFile* h) const {
     }
   }
 
-  strings::StrAppend(&class_decl, "\n");
-
-  if (output_types.empty()) {
-    strings::StrAppend(&class_decl, "  Operation operation;\n");
-  }
+  strings::StrAppend(&class_decl, "\n  Operation operation;\n");
   for (int i = 0; i < output_types.size(); ++i) {
     strings::StrAppend(&class_decl, "  ", output_types[i], " ", output_names[i],
                        ";\n");
@@ -878,9 +874,11 @@ void OpInfo::GetOutput(string* out) const {
   string return_on_error =
       strings::StrCat("if (!", scope_str, ".ok()) return;");
 
+  strings::StrAppend(out, "  this->operation = Operation(ret);\n");
+
   // No outputs.
   if (graph_op_def.output_arg_size() == 0) {
-    strings::StrAppend(out, "  this->operation = Operation(ret);\n  return;\n");
+    strings::StrAppend(out, "  return;\n");
    return;
   }
   if (graph_op_def.output_arg_size() == 1) {
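
With the cc_op_gen.cc change above, every generated C++ op wrapper declares and populates an `Operation operation;` member, where previously only zero-output ops had one. A minimal usage sketch, assuming a checkout that includes this change (the names `Example`, `root`, and `id` are illustrative, not part of the commit):

    #include "tensorflow/cc/framework/scope.h"
    #include "tensorflow/cc/ops/standard_ops.h"

    void Example() {
      using namespace tensorflow;
      Scope root = Scope::NewRootScope();
      auto id = ops::Identity(root.WithOpName("id"), ops::Const(root, 1.0f));
      // `operation` is now set even for ops that have outputs, so the
      // underlying node is reachable without going through an Output.
      Node* n = id.operation.node();
      (void)n;
    }
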
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 7f6ac4cae7..6abc9e268e 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -62,7 +62,7 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
       refiner_(refiner),
       scope_used_(nullptr),
       colocation_constraints_(),
-      disable_shape_inference_(false) {}
+      disable_shape_inference_(refiner_ == nullptr) {}
 
 Scope Scope::NewRootScope() {
   Graph* graph = new Graph(OpRegistry::Global());
@@ -94,6 +94,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
       exit_on_error_(other.impl()->exit_on_error_),
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
+      assigned_device_(other.impl()->assigned_device_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}
 
@@ -110,6 +111,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
       exit_on_error_(other.impl()->exit_on_error_),
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
+      assigned_device_(other.impl()->assigned_device_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}
 
@@ -132,6 +134,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
       exit_on_error_(other.impl()->exit_on_error_),
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
+      assigned_device_(other.impl()->assigned_device_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}
 
@@ -163,6 +166,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
       exit_on_error_(other.impl()->exit_on_error_),
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
+      assigned_device_(other.impl()->assigned_device_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}
 
@@ -178,6 +182,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
       exit_on_error_(true),
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
+      assigned_device_(other.impl()->assigned_device_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}
 
@@ -194,6 +199,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
       exit_on_error_(other.impl()->exit_on_error_),
       kernel_label_(kernel_label),
       device_(other.impl()->device_),
+      assigned_device_(other.impl()->assigned_device_),
       colocation_constraints_(other.impl()->colocation_constraints_),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}
 
@@ -210,12 +216,30 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
       exit_on_error_(other.impl()->exit_on_error_),
       kernel_label_(other.impl()->kernel_label_),
       device_(other.impl()->device_),
+      assigned_device_(other.impl()->assigned_device_),
       colocation_constraints_(
           clear_colocations
               ? std::unordered_set<string>()
               : other.impl()->GetColocationConstraints(colocate_with_op)),
       disable_shape_inference_(other.impl()->disable_shape_inference_) {}
 
+Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice,
+                  const string& assigned_device)
+    : graph_(other.impl()->graph_),
+      status_(other.impl()->status_),
+      name_map_(other.impl()->name_map_),
+      refiner_(other.impl()->refiner_),
+      scope_used_(other.impl()->scope_used_),
+      control_deps_(other.impl()->control_deps_),
+      name_(other.impl()->name_),
+      op_name_(other.impl()->op_name_),
+      exit_on_error_(other.impl()->exit_on_error_),
+      kernel_label_(other.impl()->kernel_label_),
+      device_(other.impl()->device_),
+      assigned_device_(assigned_device),
+      colocation_constraints_(other.impl()->colocation_constraints_),
+      disable_shape_inference_(other.impl()->disable_shape_inference_) {}
+
 std::unordered_set<string> Scope::Impl::GetColocationConstraints(
     const Operation& colocate_with_op) const {
   std::unordered_set<string> current_constraints(colocation_constraints_);
@@ -299,6 +323,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const {
   if (!impl()->device_.empty()) {
     builder->Device(impl()->device_);
   }
+  if (!impl()->assigned_device_.empty()) {
+    builder->AssignedDevice(impl()->assigned_device_);
+  }
 }
 
 string Scope::Impl::GetUniqueName(const string& prefix,
@@ -394,6 +421,10 @@ Scope Scope::WithDevice(const string& device) const {
   return Scope(new Impl(*this, Impl::Tags::Device(), device));
 }
 
+Scope Scope::WithAssignedDevice(const string& assigned_device) const {
+  return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device));
+}
+
 Scope Scope::ColocateWith(const Operation& op) const {
   return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
                         /* clear_colocations */ false));
diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h
index 30c32bd44b..e307d8989b 100644
--- a/tensorflow/cc/framework/scope.h
+++ b/tensorflow/cc/framework/scope.h
@@ -133,6 +133,10 @@ class Scope {
   /// the device field set to 'device'.
   Scope WithDevice(const string& device) const;
 
+  /// Returns a new scope. All ops created within the returned scope will have
+  /// their assigned device set to `assigned_device`.
+  Scope WithAssignedDevice(const string& assigned_device) const;
+
   /// Return a new scope. All ops created within the returned scope will be
   /// co-located on the device where op is placed.
   /// NOTE: This function is intended to be use internal libraries only for
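
Taken together, the scope.cc and scope.h changes let callers pin both the requested and the assigned device through the scope API. A hedged usage sketch (the device string and names are placeholders, not taken from this commit):

    #include "tensorflow/cc/framework/scope.h"
    #include "tensorflow/cc/ops/standard_ops.h"

    void Example() {
      using namespace tensorflow;
      const string kDev = "/job:localhost/replica:0/task:0/device:CPU:0";
      Scope root = Scope::NewRootScope();
      // WithDevice sets the requested device in the NodeDef; WithAssignedDevice
      // also sets Node::assigned_device_name() via NodeBuilder::AssignedDevice.
      Scope placed = root.WithDevice(kDev).WithAssignedDevice(kDev);
      auto c = ops::Const(placed.WithOpName("c"), 42.0f);
    }
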
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h
index 58adaef2e9..514e02e841 100644
--- a/tensorflow/cc/framework/scope_internal.h
+++ b/tensorflow/cc/framework/scope_internal.h
@@ -26,6 +26,8 @@ class ShapeRefiner;
 // graph, status, name_map, and refiner.
 // This is intended to enable the C API (which are used by other language
 // bindings) to create a Scope and access C++ functionality (i.e. gradients).
+//
+// Shape inference is disabled if `refiner` is nullptr.
 Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner);
 
 class Scope::Impl {
@@ -58,6 +60,7 @@ class Scope::Impl {
     enum class ExitOnError;
     enum class KernelLabel;
     enum class Colocate;
+    enum class AssignedDevice;
   };
 
   Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
@@ -74,6 +77,7 @@ class Scope::Impl {
   Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label);
   Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op,
        bool clear_colocations);
+  Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device);
 
   std::unordered_set<string> GetColocationConstraints(
       const Operation& colocate_with_op) const;
@@ -107,6 +111,7 @@ class Scope::Impl {
   const bool exit_on_error_ = false;
   const string kernel_label_ = "";
   const string device_ = "";
+  const string assigned_device_ = "";
   const std::unordered_set<string> colocation_constraints_;
 
   // If true, Scope::DoShapeInference() always returns Status:OK().
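
The new comment is load-bearing for the pass below: NewInternalScope with a null ShapeRefiner now yields a scope whose DoShapeInference() is a no-op (see the `disable_shape_inference_(refiner_ == nullptr)` change in scope.cc). A minimal sketch, assuming an existing graph (names are illustrative):

    #include "tensorflow/cc/framework/scope_internal.h"
    #include "tensorflow/core/framework/op.h"
    #include "tensorflow/core/graph/graph.h"

    void Example() {
      using namespace tensorflow;
      Graph graph(OpRegistry::Global());
      Status status;
      // With refiner == nullptr, shape inference is skipped, so node
      // construction no longer depends on registered shape functions.
      Scope scope = NewInternalScope(&graph, &status, /*refiner=*/nullptr);
    }
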
#include "tensorflow/core/public/version.h" namespace tensorflow { - -static Status BuildXlaCompileNode( - const string& nodename, const string& function_name, - const AttrValueMap& function_attr, const string& device_name, - const DataTypeVector& constant_dtypes, int num_resources, - const DataTypeVector& arg_dtypes, Graph* graph, Node** node) { - NodeDef def; - def.set_name(graph->NewName(nodename)); - def.set_op("_XlaCompile"); - def.set_device(device_name); - AddNodeAttr("Tconstants", constant_dtypes, &def); - AddNodeAttr("Targs", arg_dtypes, &def); - AddNodeAttr("Nresources", num_resources, &def); - NameAttrList function; - function.set_name(function_name); - *function.mutable_attr() = function_attr; - AddNodeAttr("function", function, &def); - - Status status; - *node = graph->AddNode(def, &status); - return status; +namespace { +void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) { + std::vector<const Edge*> out_edges(old_node->out_edges().begin(), + old_node->out_edges().end()); + for (const Edge* edge : out_edges) { + // TODO(sanjoy): This does not update NodeDef inputs. To be able to update + // NodeDef inputs we first need to fix encapsulate_subgraphs_pass to fix up + // the NodeDef inputs to the function call nodes. + g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input()); + g->RemoveEdge(edge); + } } -static Status BuildXlaRunNode(const string& nodename, const string& device_name, - const DataTypeVector& arg_dtypes, - const DataTypeVector& result_dtypes, Graph* graph, - Node** node) { - NodeDef def; - def.set_name(graph->NewName(nodename)); - def.set_op("_XlaRun"); - def.set_device(device_name); - AddNodeAttr("Targs", arg_dtypes, &def); - AddNodeAttr("Tresults", result_dtypes, &def); +struct XlaClusterInfo { + std::vector<Output> constant_inputs; + std::vector<Output> non_constant_inputs; + std::vector<Output> resource_inputs; + NameAttrList function; +}; - Status status; - *node = graph->AddNode(def, &status); - return status; +Output IncomingEdgeAsOutput(const Edge* e) { + return Output(e->src(), e->src_output()); } -static Status GetXlaAttrs(Node* node, int* num_constant_args, - int* num_resource_args, DataTypeVector* const_dtypes, - DataTypeVector* arg_dtypes) { +Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) { + int num_constant_inputs, num_resource_inputs; TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, num_constant_args)); + GetNodeAttr(n->attrs(), kXlaNumConstantArgsAttr, &num_constant_inputs)); TF_RETURN_IF_ERROR( - GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, num_resource_args)); + GetNodeAttr(n->attrs(), kXlaNumResourceArgsAttr, &num_resource_inputs)); - if (*num_constant_args < 0 || *num_resource_args < 0 || - *num_constant_args + *num_resource_args > node->num_inputs()) { + if (num_constant_inputs < 0 || num_resource_inputs < 0 || + num_constant_inputs + num_resource_inputs > n->num_inputs()) { return errors::InvalidArgument( "Invalid number of constant/resource arguments to XLA kernel."); } - const int num_nonconst_args = - node->num_inputs() - *num_constant_args - *num_resource_args; - - const DataTypeVector& input_types = node->input_types(); - std::copy(input_types.begin(), input_types.begin() + *num_constant_args, - std::back_inserter(*const_dtypes)); - std::copy(input_types.begin() + *num_constant_args, - input_types.begin() + *num_constant_args + num_nonconst_args, - std::back_inserter(*arg_dtypes)); - return Status::OK(); -} - -static void CopyIncomingEdges(Graph* g, Node* old_node, 
-  const int num_nonconst_args =
-      node->num_inputs() - *num_constant_args - *num_resource_args;
-
-  const DataTypeVector& input_types = node->input_types();
-  std::copy(input_types.begin(), input_types.begin() + *num_constant_args,
-            std::back_inserter(*const_dtypes));
-  std::copy(input_types.begin() + *num_constant_args,
-            input_types.begin() + *num_constant_args + num_nonconst_args,
-            std::back_inserter(*arg_dtypes));
-  return Status::OK();
-}
-
-static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node,
-                              int prefix_to_ignore) {
-  for (const Edge* edge : old_node->in_edges()) {
-    if (edge->IsControlEdge()) {
-      g->AddControlEdge(edge->src(), new_node);
-    } else if (edge->dst_input() >= prefix_to_ignore) {
-      g->AddEdge(edge->src(), edge->src_output(), new_node,
-                 edge->dst_input() - prefix_to_ignore);
-    }
-  }
-}
+  int num_non_constant_inputs =
+      n->num_inputs() - num_constant_inputs - num_resource_inputs;
 
-static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
-  std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
-                                     old_node->out_edges().end());
-  for (const Edge* edge : out_edges) {
-    // TODO(sanjoy): This does not update NodeDef inputs.
-    g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
-    g->RemoveEdge(edge);
-  }
-}
+  std::vector<const Edge*> input_edges_vector;
+  TF_RETURN_IF_ERROR(n->input_edges(&input_edges_vector));
+  absl::Span<const Edge*> input_edges(input_edges_vector);
 
-static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) {
-  int num_constant_args, num_resource_args;
-  DataTypeVector const_dtypes;
-  DataTypeVector arg_dtypes;
+  absl::c_transform(input_edges.subspan(0, num_constant_inputs),
+                    std::back_inserter(result->constant_inputs),
+                    IncomingEdgeAsOutput);
 
-  TF_RETURN_IF_ERROR(GetXlaAttrs(n, &num_constant_args, &num_resource_args,
-                                 &const_dtypes, &arg_dtypes));
+  absl::c_transform(
+      input_edges.subspan(num_constant_inputs, num_non_constant_inputs),
+      std::back_inserter(result->non_constant_inputs), IncomingEdgeAsOutput);
 
-  Node *compile_node, *run_node;
+  absl::c_transform(
+      input_edges.subspan(num_constant_inputs + num_non_constant_inputs,
+                          num_resource_inputs),
+      std::back_inserter(result->resource_inputs), IncomingEdgeAsOutput);
 
-  TF_RETURN_IF_ERROR(BuildXlaCompileNode(
-      n->name(), n->type_string(), n->def().attr(), n->requested_device(),
-      const_dtypes, num_resource_args, arg_dtypes, g, &compile_node));
+  result->function.set_name(n->type_string());
+  *result->function.mutable_attr() = n->def().attr();
+  return Status::OK();
+}
 
-  DataTypeVector arg_dtypes_with_resources = arg_dtypes;
-  for (int i = 0; i < num_resource_args; i++) {
-    arg_dtypes_with_resources.push_back(DT_RESOURCE);
+Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) {
+  for (const Edge* e : from->in_edges()) {
+    if (e->IsControlEdge()) {
+      g->AddControlEdge(e->src(), to);
+    }
   }
-  TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(),
-                                     arg_dtypes_with_resources,
-                                     n->output_types(), g, &run_node));
-
-  compile_node->set_assigned_device_name(n->assigned_device_name());
-  run_node->set_assigned_device_name(n->assigned_device_name());
+  return Status::OK();
+}
 
-  CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node,
-                    /*prefix_to_ignore=*/0);
-  CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node,
-                    /*prefix_to_ignore=*/num_constant_args);
+Status ReplaceNodeWithXlaCompileAndXlaRun(Graph* g, Node* n) {
+  Status status;
+  Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
+                   .NewSubScope(n->name())
+                   .WithDevice(n->requested_device())
+                   .WithAssignedDevice(n->assigned_device_name());
+
+  XlaClusterInfo cluster_info;
+  TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
+
+  ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
+                               /*constants=*/cluster_info.constant_inputs,
+                               /*args=*/cluster_info.non_constant_inputs,
+                               /*resources=*/cluster_info.resource_inputs,
+                               cluster_info.function);
+  TF_RETURN_IF_ERROR(
+      CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
 
-  // The compilation_key output.
-  g->AddEdge(compile_node, 0, run_node, n->num_inputs() - num_constant_args);
+  std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
+  absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
+  ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
+                       xla_compile.key, n->output_types());
 
-  MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
+  MoveOutgoingEdges(g, /*old_node=*/n,
+                    /*new_node=*/xla_run.operation.node());
   g->RemoveNode(n);
 
   return Status::OK();
 }
+}  // namespace
 
 Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
   Graph* graph = options.graph->get();
@@ -170,7 +150,7 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
     // Only compile nodes that are marked for compilation by the
     // compilation-marking pass (via 'attr_name').
     if (IsXlaCompiledKernel(*n)) {
-      TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndRun(graph, n));
+      TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(graph, n));
     }
   }
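
Schematically, ReplaceNodeWithXlaCompileAndXlaRun turns a clustered call node F(constants, args, resources) into the following pair (a summary of the code above, not additional ops):

    // key     = _XlaCompile(constants, args, resources, function=F)
    // outputs = _XlaRun(args + resources, key)

Incoming control edges of F are copied onto the _XlaCompile node, every outgoing edge of F moves to the _XlaRun node, and F itself is removed from the graph.
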
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
index b7cb4506b9..9d56db7b6b 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -56,18 +56,26 @@ Status BuildXlaOps(const Scope& s, std::unique_ptr<Graph>* result) {
 }
 
 Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
-                             const string& node_name, Node** result) {
+                             const string& node_name, int num_constant_args,
+                             int num_resource_args, Node** result) {
   NodeDef call_node;
   call_node.set_name(node_name);
   call_node.set_op(callee_name);
   AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node);
-  AddNodeAttr(kXlaNumConstantArgsAttr, 0, &call_node);
-  AddNodeAttr(kXlaNumResourceArgsAttr, 0, &call_node);
+  AddNodeAttr(kXlaNumConstantArgsAttr, num_constant_args, &call_node);
+  AddNodeAttr(kXlaNumResourceArgsAttr, num_resource_args, &call_node);
   Status s;
   *result = graph->AddNode(call_node, &s);
   return s;
 }
 
+Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
+                             const string& node_name, Node** result) {
+  return MakeXlaCompiledKernel(graph, callee_name, node_name,
+                               /*num_constant_args=*/0, /*num_resource_args=*/0,
+                               result);
+}
+
 Node* MakeWrite(const Scope& scope, const string& id) {
   Output var_handle =
       ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
@@ -108,5 +116,23 @@ TEST(BuildXlaOps, ControlDepsPreserved) {
   EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")))));
 }
 
+TEST(BuildXlaOps, CleanFailureOnBogusAttr) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+
+  FunctionDefLibrary flib_def =
+      CreateFunctionDefLibWithConstFunction("cluster_0");
+  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+  Node* call;
+  TF_ASSERT_OK(
+      MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", 100, 100, &call));
+  Node* write_op = MakeWrite(root, "write");
+  root.graph()->AddControlEdge(call, write_op);
+
+  std::unique_ptr<Graph> graph;
+  Status failure_status = BuildXlaOps(root, &graph);
+  ASSERT_FALSE(failure_status.ok());
+  EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT);
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
index 479038ac8e..22531a4ace 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
 #include "tensorflow/cc/ops/resource_variable_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
 #include "tensorflow/compiler/tf2xla/test_util.h"
 #include "tensorflow/core/framework/graph_to_functiondef.h"
 #include "tensorflow/core/graph/graph_constructor.h"
diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD
index ea8d1b3d14..adcdb6c8f7 100644
--- a/tensorflow/compiler/tf2xla/cc/BUILD
+++ b/tensorflow/compiler/tf2xla/cc/BUILD
@@ -30,14 +30,15 @@ cc_library(
 
 tf_gen_op_wrapper_cc(
     name = "xla_jit_op_gen",
-    out_ops_file = "ops/xla_jit_op",
+    include_internal_ops = 1,
+    out_ops_file = "ops/xla_jit_ops",
     deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
 )
 
 cc_library(
     name = "xla_jit_ops",
-    srcs = ["ops/xla_jit_op.cc"],
-    hdrs = ["ops/xla_jit_op.h"],
+    srcs = ["ops/xla_jit_ops.cc"],
+    hdrs = ["ops/xla_jit_ops.h"],
     deps = [
         "//tensorflow/cc:const_op",
         "//tensorflow/cc:ops",
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index a446e0d136..d92874909f 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -99,6 +99,11 @@ NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
   return *this;
 }
 
+NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
+  assigned_device_ = string(device);
+  return *this;
+}
+
 Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
   // In case of error, set *created_node to nullptr.
   if (created_node != nullptr) *created_node = nullptr;
@@ -115,6 +120,8 @@ Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
   Node* node = graph->AddNode(node_def, &status);
   if (!status.ok()) return status;
 
+  node->set_assigned_device_name(assigned_device_);
+
   for (size_t i = 0; i < inputs_.size(); ++i) {
     if (inputs_[i].node != nullptr) {  // Skip back edges.
       graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
index 4727ee7b56..d576985a23 100644
--- a/tensorflow/core/graph/node_builder.h
+++ b/tensorflow/core/graph/node_builder.h
@@ -100,6 +100,9 @@ class NodeBuilder {
   // "assigned device" in the Node).
   NodeBuilder& Device(StringPiece device_spec);
 
+  // Sets the device name in the "assigned device" field in tensorflow::Node.
+  NodeBuilder& AssignedDevice(StringPiece device);
+
   // Set the value of an attr.  attr_name must match the name of one of
   // attrs defined by the Op, and value must have the corresponding type
   // (see SetAttrValue() in ../framework/attr_value_util.h for legal
@@ -141,6 +144,7 @@ class NodeBuilder {
   std::vector<NodeOut> inputs_;
   std::vector<Node*> control_inputs_;
   std::vector<string> errors_;
+  string assigned_device_;
 };
 
 // IMPLEMENTATION -------------------------------------------------------------
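
A hedged sketch of the new NodeBuilder::AssignedDevice setter working alongside the existing Device setter (the node name, op, and device string are illustrative, not part of this commit):

    #include "tensorflow/core/graph/graph.h"
    #include "tensorflow/core/graph/node_builder.h"

    tensorflow::Status AddPlacedNoOp(tensorflow::Graph* graph) {
      using namespace tensorflow;
      const string kDev = "/job:localhost/replica:0/task:0/device:CPU:0";
      Node* node = nullptr;
      return NodeBuilder("placed_noop", "NoOp")
          .Device(kDev)          // requested device, stored in the NodeDef
          .AssignedDevice(kDev)  // assigned device, stored on the Node itself
          .Finalize(graph, &node);
    }
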