aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/ops
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-09-13 10:49:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-13 10:54:47 -0700
commit92362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2a (patch)
tree6440266f60c78450188586892fc6bab7fa67d859 /tensorflow/cc/ops
parenta4f6e7c1afd130d97759b99ba88e69138c59107c (diff)
Add WhileContext class and add plumbing for creating them.
This change introduces WhileContext, which stores information about a while loop and will be used in future changes to generate while loop gradient graphs. Exit nodes in a while loop now have a pointer to their associated WhileContext. This will be used to retrieve the context for a given loop. This change adds an optional parameter to BuildWhileLoop() to create a WhileContext for the while loop (currently this is always true, but gradients will generate while loops without associated contexts). This change also adds a as-yet-unused option to BuildWhileLoop() to return the predicate output. PiperOrigin-RevId: 168562303
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r--tensorflow/cc/ops/while_loop.cc21
-rw-r--r--tensorflow/cc/ops/while_loop.h7
-rw-r--r--tensorflow/cc/ops/while_loop_test.cc26
3 files changed, 49 insertions, 5 deletions
diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc
index e3e39da85e..e0251efb2a 100644
--- a/tensorflow/cc/ops/while_loop.cc
+++ b/tensorflow/cc/ops/while_loop.cc
@@ -172,7 +172,8 @@ Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body,
Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
const CondGraphBuilderFn& cond,
const BodyGraphBuilderFn& body, const string& frame_name,
- OutputList* outputs) {
+ OutputList* outputs, bool create_while_ctx,
+ Output* cond_output) {
DCHECK(!inputs.empty());
DCHECK(outputs != nullptr);
DCHECK(outputs->empty());
@@ -194,6 +195,7 @@ Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
Output cond_out;
TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out));
+ if (cond_output != nullptr) *cond_output = cond_out;
std::vector<Output> switch_trues(num_loop_vars);
std::vector<Output> switch_falses(num_loop_vars);
@@ -226,7 +228,22 @@ Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
for (int i = 0; i < num_loop_vars; ++i) {
(*outputs)[i] = internal::Exit(scope, switch_falses[i]);
}
- return scope.status();
+ TF_RETURN_IF_ERROR(scope.status());
+
+ if (create_while_ctx) {
+ WhileContext* while_ctx;
+ TF_RETURN_IF_ERROR(scope.graph()->AddWhileContext(
+ frame_name, ToNodes(enter_outputs), ToNodes(*outputs),
+ ToOutputTensor(cond_out), ToOutputTensors(switch_trues),
+ ToOutputTensors(body_outputs), &while_ctx));
+
+ // Set while_ctx for all exit nodes. We currently don't require knowing the
+ // while_ctx for any other nodes.
+ for (int i = 0; i < num_loop_vars; ++i) {
+ (*outputs)[i].node()->set_while_ctx(while_ctx);
+ }
+ }
+ return Status::OK();
}
} // namespace ops
diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h
index 253d5d8935..82181516d6 100644
--- a/tensorflow/cc/ops/while_loop.h
+++ b/tensorflow/cc/ops/while_loop.h
@@ -48,6 +48,10 @@ typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
// unique name. This will be used as a prefix for created operations.
// * outputs: output param that returns final loop variable outputs in non-error
// case. Must be non-null and empty.
+// * create_while_ctx: if true, a WhileContext is created and populated for this
+// loop. See core/graph/while_context.h for more details.
+// * cond_output: if non-null, the output of the predicate is returned. This
+// will always be a LoopCond node.
//
// Returns an error if the while loop could not be fully constructed.
//
@@ -56,7 +60,8 @@ typedef std::function<Status(const Scope&, const std::vector<Output>& inputs,
Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
const CondGraphBuilderFn& cond,
const BodyGraphBuilderFn& body, const string& frame_name,
- OutputList* outputs);
+ OutputList* outputs, bool create_while_ctx = true,
+ Output* cond_output = nullptr);
} // namespace ops
} // namespace tensorflow
diff --git a/tensorflow/cc/ops/while_loop_test.cc b/tensorflow/cc/ops/while_loop_test.cc
index 77028b5c41..e3f6523c19 100644
--- a/tensorflow/cc/ops/while_loop_test.cc
+++ b/tensorflow/cc/ops/while_loop_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -38,8 +39,8 @@ class WhileLoopTest : public ::testing::Test {
const ops::BodyGraphBuilderFn& body,
error::Code error_code = error::OK,
const string& error_msg = "") {
- Status s = ops::BuildWhileLoop(scope_, inputs_, cond, body, "test_loop",
- &outputs_);
+ Status s =
+ ops::BuildWhileLoop(scope_, inputs_, cond, body, kFrameName, &outputs_);
EXPECT_EQ(s.code(), error_code);
EXPECT_EQ(s.error_message(), error_msg);
}
@@ -69,8 +70,12 @@ class WhileLoopTest : public ::testing::Test {
Scope scope_;
std::vector<Output> inputs_;
std::vector<Output> outputs_;
+
+ static const char* const kFrameName;
};
+const char* const WhileLoopTest::kFrameName = "test_loop";
+
Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs,
Output* output) {
*output = ops::Less(s, inputs[0], 10);
@@ -87,6 +92,23 @@ TEST_F(WhileLoopTest, Basic) {
// Create loop: while (i < 10) i += 1
Init(1);
CreateLoop(LessThanTenCond, AddOneBody);
+
+ // Verify some output invariants
+ WhileContext* while_ctx;
+ for (int i = 0; i < outputs_.size(); ++i) {
+ Node* node = outputs_[i].node();
+ ASSERT_TRUE(node->IsExit()) << "Output node " << i << ":\n"
+ << node->DebugString();
+ ASSERT_TRUE(node->while_ctx() != nullptr) << i;
+ if (i == 0) {
+ while_ctx = node->while_ctx();
+ EXPECT_EQ(while_ctx->frame_name(), kFrameName);
+ } else {
+ EXPECT_EQ(node->while_ctx(), while_ctx) << i;
+ }
+ }
+
+ // Run the loop and test we get the expected results
Run<int>({1}, {10});
Run<int>({11}, {11});
}