diff options
author | 2017-04-13 11:42:52 -0800 | |
---|---|---|
committer | 2017-04-13 13:05:08 -0700 | |
commit | 908d5b6ede6ae829dff138a873eec397ef434cd6 (patch) | |
tree | de7898ec319637d2f6d4a78067715bd02808fb02 /tensorflow/cc/framework/scope.cc | |
parent | 59ccf014e89ff625dc3d9779e1fb54a980c4b6ac (diff) |
Add C++ gradients to c_api.
#6268
This CL does the following:
(1) Adds TF_AddGradients function to C_API which adds gradient nodes for the specified inputs.
(2) Adds internal constructor for Scope, need to create a scope from an existing graph in the c_api.
(3) Adds constructor for AddSymbolicGradients that assumes OnesLike when grad_inputs aren't provided.
(4) Improves error message when gradients aren't provided.
Change: 153092774
Diffstat (limited to 'tensorflow/cc/framework/scope.cc')
-rw-r--r-- | tensorflow/cc/framework/scope.cc | 59 |
1 files changed, 49 insertions, 10 deletions
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc index 571c6e1e57..8b7fc1406f 100644 --- a/tensorflow/cc/framework/scope.cc +++ b/tensorflow/cc/framework/scope.cc @@ -16,7 +16,7 @@ limitations under the License. #include <algorithm> #include <vector> -#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -25,6 +25,20 @@ limitations under the License. namespace tensorflow { class Scope::Impl { + public: + // A NameMap is used to keep track of suffixes for names used in a scope. A + // name that has not been used so far in a scope will get no suffix. Later + // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes + // can share the same NameMap. For instance, a new scope created using + // WithControlDependencies() should would share the same NameMap with the + // parent. + typedef std::unordered_map<string, int> NameMap; + + Impl(const std::shared_ptr<Graph>& graph, + const std::shared_ptr<Status>& status, + const std::shared_ptr<NameMap>& name_map, + const std::shared_ptr<ShapeRefiner>& refiner); + private: friend class Scope; @@ -40,14 +54,6 @@ class Scope::Impl { enum class Colocate; }; - // A NameMap is used to keep track of suffixes for names used in a scope. A - // name that has not been used so far in a scope will get no suffix. Later - // uses of the same name will get suffixes _1, _2, _3, etc. Multiple scopes - // can share the same NameMap. For instance, a new scope created using - // WithControlDependencies() should would share the same NameMap with the - // parent. - typedef std::unordered_map<string, int> NameMap; - Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner); Impl(const Scope& other, Tags::ScopeName, const string& name, bool copy_names); @@ -116,6 +122,17 @@ Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map, scope_used_(nullptr), colocation_constraints_() {} +Scope::Impl::Impl(const std::shared_ptr<Graph>& graph, + const std::shared_ptr<Status>& status, + const std::shared_ptr<NameMap>& name_map, + const std::shared_ptr<ShapeRefiner>& refiner) + : graph_(graph), + status_(status), + name_map_(name_map), + refiner_(refiner), + scope_used_(nullptr), + colocation_constraints_() {} + Scope Scope::NewRootScope() { Graph* graph = new Graph(OpRegistry::Global()); ShapeRefiner* refiner = @@ -277,7 +294,7 @@ std::shared_ptr<Graph> Scope::graph_as_shared_ptr() const { return impl()->graph_; } -Status Scope::status() const { return *impl()->status_; }; +Status Scope::status() const { return *impl()->status_; } const std::vector<Operation>& Scope::control_deps() const { return impl()->control_deps_; @@ -464,4 +481,26 @@ CompositeOpScopes Scope::GetCompositeOpScopes( } } +class InternalScope { + public: + // NewScope doesn't take ownership of the inputs. + static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) { + Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap; + for (const Node* node : graph->nodes()) { + (*name_map)[node->name()] = 0; + } + // We provide null destructors for these shared ptrs (except for name_map) + // since the caller owns them and doesn't want the scope to destroy them. + return Scope(new Scope::Impl( + std::shared_ptr<Graph>(graph, [](Graph*) {}), + std::shared_ptr<Status>(status, [](Status*) {}), + std::shared_ptr<Scope::Impl::NameMap>(name_map), + std::shared_ptr<ShapeRefiner>(refiner, [](ShapeRefiner*) {}))); + } +}; + +Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) { + return InternalScope::NewScope(graph, status, refiner); +} + } // namespace tensorflow |