diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-09-13 10:49:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-13 10:54:47 -0700 |
commit | 92362d0f0510d5bb1afa3c9cfd007cbf9cdf6d2a (patch) | |
tree | 6440266f60c78450188586892fc6bab7fa67d859 /tensorflow/cc/ops | |
parent | a4f6e7c1afd130d97759b99ba88e69138c59107c (diff) |
Add WhileContext class and add plumbing for creating them.
This change introduces WhileContext, which stores information about a
while loop and will be used in future changes to generate while loop
gradient graphs. Exit nodes in a while loop now have a pointer to
their associated WhileContext. This will be used to retrieve the
context for a given loop.
This change adds an optional parameter to BuildWhileLoop() to create a
WhileContext for the while loop (currently this is always true, but
gradients will generate while loops without associated contexts). This
change also adds a as-yet-unused option to BuildWhileLoop() to return
the predicate output.
PiperOrigin-RevId: 168562303
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r-- | tensorflow/cc/ops/while_loop.cc | 21 | ||||
-rw-r--r-- | tensorflow/cc/ops/while_loop.h | 7 | ||||
-rw-r--r-- | tensorflow/cc/ops/while_loop_test.cc | 26 |
3 files changed, 49 insertions, 5 deletions
diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc index e3e39da85e..e0251efb2a 100644 --- a/tensorflow/cc/ops/while_loop.cc +++ b/tensorflow/cc/ops/while_loop.cc @@ -172,7 +172,8 @@ Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body, Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs, const CondGraphBuilderFn& cond, const BodyGraphBuilderFn& body, const string& frame_name, - OutputList* outputs) { + OutputList* outputs, bool create_while_ctx, + Output* cond_output) { DCHECK(!inputs.empty()); DCHECK(outputs != nullptr); DCHECK(outputs->empty()); @@ -194,6 +195,7 @@ Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs, Output cond_out; TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out)); + if (cond_output != nullptr) *cond_output = cond_out; std::vector<Output> switch_trues(num_loop_vars); std::vector<Output> switch_falses(num_loop_vars); @@ -226,7 +228,22 @@ Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs, for (int i = 0; i < num_loop_vars; ++i) { (*outputs)[i] = internal::Exit(scope, switch_falses[i]); } - return scope.status(); + TF_RETURN_IF_ERROR(scope.status()); + + if (create_while_ctx) { + WhileContext* while_ctx; + TF_RETURN_IF_ERROR(scope.graph()->AddWhileContext( + frame_name, ToNodes(enter_outputs), ToNodes(*outputs), + ToOutputTensor(cond_out), ToOutputTensors(switch_trues), + ToOutputTensors(body_outputs), &while_ctx)); + + // Set while_ctx for all exit nodes. We currently don't require knowing the + // while_ctx for any other nodes. + for (int i = 0; i < num_loop_vars; ++i) { + (*outputs)[i].node()->set_while_ctx(while_ctx); + } + } + return Status::OK(); } } // namespace ops diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h index 253d5d8935..82181516d6 100644 --- a/tensorflow/cc/ops/while_loop.h +++ b/tensorflow/cc/ops/while_loop.h @@ -48,6 +48,10 @@ typedef std::function<Status(const Scope&, const std::vector<Output>& inputs, // unique name. This will be used as a prefix for created operations. // * outputs: output param that returns final loop variable outputs in non-error // case. Must be non-null and empty. +// * create_while_ctx: if true, a WhileContext is created and populated for this +// loop. See core/graph/while_context.h for more details. +// * cond_output: if non-null, the output of the predicate is returned. This +// will always be a LoopCond node. // // Returns an error if the while loop could not be fully constructed. // @@ -56,7 +60,8 @@ typedef std::function<Status(const Scope&, const std::vector<Output>& inputs, Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs, const CondGraphBuilderFn& cond, const BodyGraphBuilderFn& body, const string& frame_name, - OutputList* outputs); + OutputList* outputs, bool create_while_ctx = true, + Output* cond_output = nullptr); } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/ops/while_loop_test.cc b/tensorflow/cc/ops/while_loop_test.cc index 77028b5c41..e3f6523c19 100644 --- a/tensorflow/cc/ops/while_loop_test.cc +++ b/tensorflow/cc/ops/while_loop_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/while_context.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -38,8 +39,8 @@ class WhileLoopTest : public ::testing::Test { const ops::BodyGraphBuilderFn& body, error::Code error_code = error::OK, const string& error_msg = "") { - Status s = ops::BuildWhileLoop(scope_, inputs_, cond, body, "test_loop", - &outputs_); + Status s = + ops::BuildWhileLoop(scope_, inputs_, cond, body, kFrameName, &outputs_); EXPECT_EQ(s.code(), error_code); EXPECT_EQ(s.error_message(), error_msg); } @@ -69,8 +70,12 @@ class WhileLoopTest : public ::testing::Test { Scope scope_; std::vector<Output> inputs_; std::vector<Output> outputs_; + + static const char* const kFrameName; }; +const char* const WhileLoopTest::kFrameName = "test_loop"; + Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs, Output* output) { *output = ops::Less(s, inputs[0], 10); @@ -87,6 +92,23 @@ TEST_F(WhileLoopTest, Basic) { // Create loop: while (i < 10) i += 1 Init(1); CreateLoop(LessThanTenCond, AddOneBody); + + // Verify some output invariants + WhileContext* while_ctx; + for (int i = 0; i < outputs_.size(); ++i) { + Node* node = outputs_[i].node(); + ASSERT_TRUE(node->IsExit()) << "Output node " << i << ":\n" + << node->DebugString(); + ASSERT_TRUE(node->while_ctx() != nullptr) << i; + if (i == 0) { + while_ctx = node->while_ctx(); + EXPECT_EQ(while_ctx->frame_name(), kFrameName); + } else { + EXPECT_EQ(node->while_ctx(), while_ctx) << i; + } + } + + // Run the loop and test we get the expected results Run<int>({1}, {10}); Run<int>({11}, {11}); } |