diff options
author | Manjunath Kudlur <keveman@google.com> | 2016-07-15 14:28:59 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-07-15 15:33:32 -0700 |
commit | 25ac3dabfa3af7a313eb46b03690117c85030cc2 (patch) | |
tree | 06010c7cc7d25a538880c5f0d53df079e27093fd /tensorflow/cc/framework/scope.cc | |
parent | 194efde51895e0251d39c72c969dff1a50b67d35 (diff) |
Improvements to the C++ graph building API.
TESTED:
- passed opensource_build: http://ci.tensorflow.org/job/tensorflow-cl-presubmit-multijob/2780/
Change: 127585603
Diffstat (limited to 'tensorflow/cc/framework/scope.cc')
-rw-r--r-- | tensorflow/cc/framework/scope.cc | 347 |
1 files changed, 347 insertions, 0 deletions
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc new file mode 100644 index 0000000000..e828796980 --- /dev/null +++ b/tensorflow/cc/framework/scope.cc @@ -0,0 +1,347 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <algorithm> +#include <vector> + +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" + +namespace tensorflow { + +Scope::Scope(Graph* graph, Status* status, Scope::NameMap* name_map) + : graph_(graph), + status_(status), + name_map_(name_map), + scope_used_(nullptr) {} + +Scope Scope::NewRootScope() { + return Scope(new Graph(OpRegistry::Global()), new Status, new Scope::NameMap); +} + +Scope::Scope(const Scope& other, Scope::Tags::ScopeName, const string& name, + bool copy_names) + : graph_(other.graph_), + status_(other.status_), + name_map_(copy_names ? other.name_map_ + : std::shared_ptr<NameMap>(new NameMap)), + scope_used_(nullptr), + control_deps_(other.control_deps_), + name_(name), + op_name_(""), + exit_on_error_(other.exit_on_error_), + kernel_label_(other.kernel_label_), + device_(other.device_), + colocation_constraints_(other.colocation_constraints_) {} + +Scope::Scope(const Scope& other, Scope::Tags::OpName, const string& name, + const string& op_name) + : graph_(other.graph_), + status_(other.status_), + name_map_(other.name_map_), + scope_used_(other.scope_used_), + control_deps_(other.control_deps_), + name_(name), + op_name_(op_name), + exit_on_error_(other.exit_on_error_), + kernel_label_(other.kernel_label_), + device_(other.device_), + colocation_constraints_(other.colocation_constraints_) {} + +Scope::Scope(const Scope& other, Scope::Tags::ControlDeps, + std::vector<ops::Operation> control_deps, bool clear_control_deps) + : graph_(other.graph_), + status_(other.status_), + name_map_(other.name_map_), + scope_used_(other.scope_used_), + control_deps_(clear_control_deps + ? std::vector<ops::Operation>() + : (control_deps.insert(control_deps.begin(), + other.control_deps_.begin(), + other.control_deps_.end()), + control_deps)), + name_(other.name_), + op_name_(other.op_name_), + exit_on_error_(other.exit_on_error_), + kernel_label_(other.kernel_label_), + device_(other.device_), + colocation_constraints_(other.colocation_constraints_) {} + +Scope::Scope(const Scope& other, Scope::Tags::Device, const string& device) + : graph_(other.graph_), + status_(other.status_), + name_map_(other.name_map_), + scope_used_(other.scope_used_), + control_deps_(other.control_deps_), + name_(other.name_), + op_name_(other.op_name_), + exit_on_error_(other.exit_on_error_), + kernel_label_(other.kernel_label_), + device_(device), + colocation_constraints_(other.colocation_constraints_) {} + +Scope::Scope(const Scope& other, Scope::Tags::SingleUseScope, + const string& op_name) + : graph_(other.graph_), + status_(other.status_), + name_map_(other.name_map_), + scope_used_(new bool(false)), + control_deps_(other.control_deps_), + name_(other.name_), + op_name_(op_name), + exit_on_error_(other.exit_on_error_), + kernel_label_(other.kernel_label_), + device_(other.device_), + colocation_constraints_(other.colocation_constraints_) {} + +Scope::Scope(const Scope& other, Scope::Tags::ExitOnError) + : graph_(other.graph_), + status_(other.status_), + name_map_(other.name_map_), + scope_used_(other.scope_used_), + control_deps_(other.control_deps_), + name_(other.name_), + op_name_(other.op_name_), + exit_on_error_(true), + kernel_label_(other.kernel_label_), + device_(other.device_), + colocation_constraints_(other.colocation_constraints_) {} + +Scope::Scope(const Scope& other, Scope::Tags::KernelLabel, + const string& kernel_label) + : graph_(other.graph_), + status_(other.status_), + name_map_(other.name_map_), + scope_used_(other.scope_used_), + control_deps_(other.control_deps_), + name_(other.name_), + op_name_(other.op_name_), + exit_on_error_(other.exit_on_error_), + kernel_label_(kernel_label), + device_(other.device_), + colocation_constraints_(other.colocation_constraints_) {} + +Scope::Scope(const Scope& other, Scope::Tags::Colocate, + const ops::Operation& colocate_with_op, bool clear_colocations) + : graph_(other.graph_), + status_(other.status_), + name_map_(other.name_map_), + scope_used_(other.scope_used_), + control_deps_(other.control_deps_), + name_(other.name_), + op_name_(other.op_name_), + exit_on_error_(other.exit_on_error_), + kernel_label_(other.kernel_label_), + device_(other.device_), + colocation_constraints_( + clear_colocations + ? std::unordered_set<string>() + : other.GetColocationConstraints(colocate_with_op)) {} + +std::unordered_set<string> Scope::GetColocationConstraints( + const ops::Operation& colocate_with_op) const { + std::unordered_set<string> current_constraints(colocation_constraints_); + const NodeDef& node_def = colocate_with_op.node()->def(); + if (node_def.attr().find("_class") != node_def.attr().end()) { + const AttrValue& loc = node_def.attr().find("_class")->second; + if (loc.value_case() == AttrValue::kList && loc.list().s_size() > 0) { + for (int i = 0; i < loc.list().s_size(); ++i) { + // Filter out the ones that don't have "loc:@" prefix + if (loc.list().s(i).find("loc:@") == 0) { + // Skip the "loc:@" prefix + current_constraints.insert(loc.list().s(i).substr(5)); + } + } + } + } else { + current_constraints.insert(colocate_with_op.node()->name()); + } + return current_constraints; +} + +void Scope::UpdateStatus(const Status s) const { + status_->Update(s); + if (exit_on_error_ && !status_->ok()) { + LOG(FATAL) << status_; + } +} + +Status Scope::ToGraphDef(GraphDef* gdef) const { + if (!status_->ok()) { + return *status_; + } + graph()->ToGraphDef(gdef); + return Status::OK(); +} + +Status Scope::ToGraph(Graph* g) const { + if (status_->ok()) { + GraphDef graph_def; + graph()->ToGraphDef(&graph_def); + GraphConstructorOptions opts; + UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g)); + } + return *status_; +} + +void Scope::UpdateBuilder(NodeBuilder* builder) const { + std::vector<Node*> control_inputs; + for (const auto& op : control_deps_) { + control_inputs.push_back(op.node()); + } + builder->ControlInputs(control_inputs); + + if (!kernel_label_.empty()) { + builder->Attr("_kernel", kernel_label_); + } + + if (!colocation_constraints_.empty()) { + std::vector<string> constraints(colocation_constraints_.begin(), + colocation_constraints_.end()); + // Sort the set. + std::sort(constraints.begin(), constraints.end()); + // Add loc:@ prefix + std::transform(constraints.begin(), constraints.end(), constraints.begin(), + [](const string& s) { return strings::StrCat("loc:@", s); }); + builder->Attr("_class", constraints); + } + if (!device_.empty()) { + builder->Device(device_); + } +} + +string Scope::GetUniqueName(const string& prefix, bool check_single_use) const { + if (check_single_use && single_use_scope()) { + if (*scope_used_) { + *status_ = + errors::AlreadyExists(prefix, " already exists in the current scope"); + return ""; + } + *scope_used_ = true; + return prefix; + } + auto entry = name_map_->find(prefix); + string unique_name = prefix; + if (entry == name_map_->end()) { + name_map_->insert({prefix, 0}); + } else { + unique_name = strings::StrCat(unique_name, "_", ++entry->second); + } + return unique_name; +} + +string Scope::GetNameForOp(const string& default_name) const { + const string unique_name = + GetUniqueName(default_name, true /* check_single_use */); + const string sep = name_.empty() || unique_name.empty() ? "" : "/"; + return strings::StrCat(name_, sep, unique_name); +} + +string Scope::GetUniqueNameForOp(const string& default_name) const { + if (single_use_scope()) { + if (op_name_.empty() || *scope_used_) { + *status_ = + errors::InvalidArgument("Cannot get a unique name in this scope"); + return ""; + } + *scope_used_ = true; + return op_name_; + } + return op_name_.empty() ? GetNameForOp(default_name) : GetNameForOp(op_name_); +} + +Scope Scope::NewSubScope(const string& child_scope_name) const { + if (child_scope_name.empty()) { + return Scope(*this, Scope::Tags::ScopeName(), name_, true /* copy_names */); + } + const string unique_name = + GetUniqueName(child_scope_name, false /* check_single_use */); + const string sep = name_.empty() || unique_name.empty() ? "" : "/"; + return Scope(*this, Scope::Tags::ScopeName(), + strings::StrCat(name_, sep, unique_name), + false /* copy_names */); +} + +Scope Scope::WithOpName(const string& op_name) const { + if (single_use_scope()) { + UpdateStatus(errors::InvalidArgument("Cannot set op name ", op_name, + " on this scope")); + return *this; + } + return Scope(*this, Scope::Tags::OpName(), name_, op_name); +} + +Scope Scope::WithControlDependencies( + const gtl::ArraySlice<ops::Operation>& control_deps) const { + return Scope( + *this, Scope::Tags::ControlDeps(), + std::vector<ops::Operation>(control_deps.begin(), control_deps.end()), + /* clear_control_deps */ false); +} + +Scope Scope::WithControlDependencies(const ops::Output& control_dep) const { + return Scope(*this, Scope::Tags::ControlDeps(), + std::vector<ops::Operation>(1, control_dep.op()), + /* clear_control_deps */ false); +} + +Scope Scope::WithNoControlDependencies() const { + return Scope(*this, Scope::Tags::ControlDeps(), std::vector<ops::Operation>(), + /* clear_control_deps */ true); +} + +Scope Scope::WithDevice(const string& device) const { + return Scope(*this, Scope::Tags::Device(), device); +} + +Scope Scope::ColocateWith(const ops::Operation& op) const { + return Scope(*this, Scope::Tags::Colocate(), op, + /* clear_colocations */ false); +} + +Scope Scope::ClearColocation() const { + return Scope(*this, Scope::Tags::Colocate(), ops::Operation(), + /* clear_colocations */ true); +} + +Scope Scope::ExitOnError() const { + return Scope(*this, Scope::Tags::ExitOnError()); +} + +Scope Scope::WithKernelLabel(const string& kernel_label) const { + return Scope(*this, Scope::Tags::KernelLabel(), kernel_label); +} + +CompositeOpScopes Scope::GetCompositeOpScopes( + const string& composite_op_name) const { + if (op_name_.empty() && composite_op_name.empty()) { + UpdateStatus(errors::InvalidArgument( + "Cannot create composite op scopes with empty name")); + return {*this, *this}; + } + if (!single_use_scope()) { + Scope child = NewSubScope(op_name_.empty() ? composite_op_name : op_name_); + const string child_op_sep = name_.empty() ? "" : "_"; + return {child, Scope(child, Scope::Tags::SingleUseScope(), + strings::StrCat(name_, child_op_sep, child.name_))}; + } else { + return { + Scope(*this, Scope::Tags::ScopeName(), op_name_, true /* copy_names */), + *this}; + } +} + +} // namespace tensorflow |