aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph/graph.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-09-13 10:49:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-13 10:54:47 -0700
commit92362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2a (patch)
tree6440266f60c78450188586892fc6bab7fa67d859 /tensorflow/core/graph/graph.cc
parenta4f6e7c1afd130d97759b99ba88e69138c59107c (diff)
Add WhileContext class and add plumbing for creating them.
This change introduces WhileContext, which stores information about a while loop and will be used in future changes to generate while loop gradient graphs. Exit nodes in a while loop now have a pointer to their associated WhileContext. This will be used to retrieve the context for a given loop. This change adds an optional parameter to BuildWhileLoop() to create a WhileContext for the while loop (currently this is always true, but gradients will generate while loops without associated contexts). This change also adds a as-yet-unused option to BuildWhileLoop() to return the predicate output. PiperOrigin-RevId: 168562303
Diffstat (limited to 'tensorflow/core/graph/graph.cc')
-rw-r--r--tensorflow/core/graph/graph.cc25
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index a274c79970..599f802ee0 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -110,7 +111,8 @@ Node::Node()
cost_id_(-1),
class_(NC_UNINITIALIZED),
props_(nullptr),
- assigned_device_name_index_(0) {}
+ assigned_device_name_index_(0),
+ while_ctx_(nullptr) {}
void Node::Initialize(int id, int cost_id,
std::shared_ptr<NodeProperties> props) {
@@ -582,6 +584,27 @@ int Graph::InternDeviceName(const string& device_name) {
return index;
}
+Status Graph::AddWhileContext(StringPiece frame_name,
+ std::vector<Node*> enter_nodes,
+ std::vector<Node*> exit_nodes,
+ OutputTensor cond_output,
+ std::vector<OutputTensor> body_inputs,
+ std::vector<OutputTensor> body_outputs,
+ WhileContext** result) {
+ auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
+ frame_name.ToString(),
+ WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
+ cond_output, std::move(body_inputs),
+ std::move(body_outputs))));
+ if (!pair.second) {
+ *result = nullptr;
+ return errors::InvalidArgument("WhileContext with frame name '", frame_name,
+ "' already exists");
+ }
+ *result = &pair.first->second;
+ return Status::OK();
+}
+
string Edge::DebugString() const {
return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
src_output_, dst_->name().c_str(), dst_input_);