diff options
Diffstat (limited to 'tensorflow/cc')
-rw-r--r-- | tensorflow/cc/framework/cc_op_gen.cc | 10 | ||||
-rw-r--r-- | tensorflow/cc/framework/scope.cc | 33 | ||||
-rw-r--r-- | tensorflow/cc/framework/scope.h | 4 | ||||
-rw-r--r-- | tensorflow/cc/framework/scope_internal.h | 5 |
4 files changed, 45 insertions, 7 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index a32d1b1eb5..39593370d1 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -853,11 +853,7 @@ void OpInfo::WriteClassDecl(WritableFile* h) const { } } - strings::StrAppend(&class_decl, "\n"); - - if (output_types.empty()) { - strings::StrAppend(&class_decl, " Operation operation;\n"); - } + strings::StrAppend(&class_decl, "\n Operation operation;\n"); for (int i = 0; i < output_types.size(); ++i) { strings::StrAppend(&class_decl, " ", output_types[i], " ", output_names[i], ";\n"); @@ -878,9 +874,11 @@ void OpInfo::GetOutput(string* out) const { string return_on_error = strings::StrCat("if (!", scope_str, ".ok()) return;"); + strings::StrAppend(out, " this->operation = Operation(ret);\n"); + // No outputs. if (graph_op_def.output_arg_size() == 0) { - strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n"); + strings::StrAppend(out, " return;\n"); return; } if (graph_op_def.output_arg_size() == 1) { diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 7f6ac4cae7..6abc9e268e 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -62,7 +62,7 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph, refiner_(refiner), scope_used_(nullptr), colocation_constraints_(), - disable_shape_inference_(false) {} + disable_shape_inference_(refiner_ == nullptr) {} Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); @@ -94,6 +94,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -110,6 +111,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -132,6 +134,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -163,6 +166,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -178,6 +182,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError) exit_on_error_(true), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -194,6 +199,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(kernel_label), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_(other.impl()->colocation_constraints_), disable_shape_inference_(other.impl()->disable_shape_inference_) {} @@ -210,12 +216,30 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate, exit_on_error_(other.impl()->exit_on_error_), kernel_label_(other.impl()->kernel_label_), device_(other.impl()->device_), + assigned_device_(other.impl()->assigned_device_), colocation_constraints_( clear_colocations ? std::unordered_set<string>() : other.impl()->GetColocationConstraints(colocate_with_op)), disable_shape_inference_(other.impl()->disable_shape_inference_) {} +Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice, + const string& assigned_device) + : graph_(other.impl()->graph_), + status_(other.impl()->status_), + name_map_(other.impl()->name_map_), + refiner_(other.impl()->refiner_), + scope_used_(other.impl()->scope_used_), + control_deps_(other.impl()->control_deps_), + name_(other.impl()->name_), + op_name_(other.impl()->op_name_), + exit_on_error_(other.impl()->exit_on_error_), + kernel_label_(other.impl()->kernel_label_), + device_(other.impl()->device_), + assigned_device_(assigned_device), + colocation_constraints_(other.impl()->colocation_constraints_), + disable_shape_inference_(other.impl()->disable_shape_inference_) {} + std::unordered_set<string> Scope::Impl::GetColocationConstraints( const Operation& colocate_with_op) const { std::unordered_set<string> current_constraints(colocation_constraints_); @@ -299,6 +323,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const { if (!impl()->device_.empty()) { builder->Device(impl()->device_); } + if (!impl()->assigned_device_.empty()) { + builder->AssignedDevice(impl()->assigned_device_); + } } string Scope::Impl::GetUniqueName(const string& prefix, @@ -394,6 +421,10 @@ Scope Scope::WithDevice(const string& device) const { return Scope(new Impl(*this, Impl::Tags::Device(), device)); } +Scope Scope::WithAssignedDevice(const string& assigned_device) const { + return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device)); +} + Scope Scope::ColocateWith(const Operation& op) const { return Scope(new Impl(*this, Impl::Tags::Colocate(), op, /* clear_colocations */ false)); diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 30c32bd44b..e307d8989b 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -133,6 +133,10 @@ class Scope { /// the device field set to 'device'. Scope WithDevice(const string& device) const; + /// Returns a new scope. All ops created within the returned scope will have + /// their assigned device set to `assigned_device`. + Scope WithAssignedDevice(const string& assigned_device) const; + /// Return a new scope. All ops created within the returned scope will be /// co-located on the device where op is placed. /// NOTE: This function is intended to be use internal libraries only for diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 58adaef2e9..514e02e841 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -26,6 +26,8 @@ class ShapeRefiner; // graph, status, name_map, and refiner. // This is intended to enable the C API (which are used by other language // bindings) to create a Scope and access C++ functionality (i.e. gradients). +// +// Shape inference is disabled if `refiner` is nullptr. Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner); class Scope::Impl { @@ -58,6 +60,7 @@ class Scope::Impl { enum class ExitOnError; enum class KernelLabel; enum class Colocate; + enum class AssignedDevice; }; Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner, @@ -74,6 +77,7 @@ class Scope::Impl { Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label); Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op, bool clear_colocations); + Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device); std::unordered_set<string> GetColocationConstraints( const Operation& colocate_with_op) const; @@ -107,6 +111,7 @@ class Scope::Impl { const bool exit_on_error_ = false; const string kernel_label_ = ""; const string device_ = ""; + const string assigned_device_ = ""; const std::unordered_set<string> colocation_constraints_; // If true, Scope::DoShapeInference() always returns Status:OK(). |