diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-09-13 10:49:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-13 10:54:47 -0700 |
commit | 92362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2a (patch) | |
tree | 6440266f60c78450188586892fc6bab7fa67d859 /tensorflow/core/graph/graph.cc | |
parent | a4f6e7c1afd130d97759b99ba88e69138c59107c (diff) |
Add WhileContext class and add plumbing for creating them.
This change introduces WhileContext, which stores information about a
while loop and will be used in future changes to generate while loop
gradient graphs. Exit nodes in a while loop now have a pointer to
their associated WhileContext. This will be used to retrieve the
context for a given loop.
This change adds an optional parameter to BuildWhileLoop() to create a
WhileContext for the while loop (currently this is always true, but
gradients will generate while loops without associated contexts). This
change also adds a as-yet-unused option to BuildWhileLoop() to return
the predicate output.
PiperOrigin-RevId: 168562303
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r-- | tensorflow/core/graph/graph.cc | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index a274c79970..599f802ee0 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/while_context.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -110,7 +111,8 @@ Node::Node() cost_id_(-1), class_(NC_UNINITIALIZED), props_(nullptr), - assigned_device_name_index_(0) {} + assigned_device_name_index_(0), + while_ctx_(nullptr) {} void Node::Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props) { @@ -582,6 +584,27 @@ int Graph::InternDeviceName(const string& device_name) { return index; } +Status Graph::AddWhileContext(StringPiece frame_name, + std::vector<Node*> enter_nodes, + std::vector<Node*> exit_nodes, + OutputTensor cond_output, + std::vector<OutputTensor> body_inputs, + std::vector<OutputTensor> body_outputs, + WhileContext** result) { + auto pair = while_ctxs_.insert(std::pair<string, WhileContext>( + frame_name.ToString(), + WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes), + cond_output, std::move(body_inputs), + std::move(body_outputs)))); + if (!pair.second) { + *result = nullptr; + return errors::InvalidArgument("WhileContext with frame name '", frame_name, + "' already exists"); + } + *result = &pair.first->second; + return Status::OK(); +} + string Edge::DebugString() const { return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(), src_output_, dst_->name().c_str(), dst_input_); |