From 301b14c240fe99249dc2225132a7ebe5cbecbdc4 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 27 Sep 2017 13:28:30 -0700 Subject: Basic while loop gradient functionality in C++ This change introduces the basic framework to create the gradient graph of a while loop using the C++ API. This supports building the gradient graph as long as the body function of the while loop contains no ops whose gradient function requires a stack. In other words, it doesn't support gradient functions that use the input values to the op (e.g. add will work, but multiply will not). It also doesn't support nested while loops, and doesn't detect all error cases. PiperOrigin-RevId: 170243281 --- tensorflow/c/while_loop_test.cc | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) (limited to 'tensorflow/c/while_loop_test.cc') diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc index 27be5d787f..4698560bbe 100644 --- a/tensorflow/c/while_loop_test.cc +++ b/tensorflow/c/while_loop_test.cc @@ -73,6 +73,11 @@ class CApiWhileLoopTest : public ::testing::Test { } void Run(std::initializer_list input_values) { + Run(outputs_, input_values); + } + + void Run(const std::vector& run_outputs, + std::initializer_list input_values) { DCHECK_EQ(inputs_.size(), input_values.size()); std::vector> inputs(inputs_.size()); int i = 0; @@ -82,7 +87,7 @@ class CApiWhileLoopTest : public ::testing::Test { } csession_.reset(new CSession(graph_, s_)); csession_->SetInputs(inputs); - csession_->SetOutputs(outputs_); + csession_->SetOutputs(run_outputs); csession_->Run(s_); ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); } @@ -402,4 +407,36 @@ TEST_F(CApiWhileLoopTest, BadTypes) { TF_AbortWhile(params_.get()); } +// This is a basic test to make sure the C++ gradient code can handle while +// loops created by the C API (which calls the C++ API under the hood). There +// are more while loop gradient tests in cc/framework/while_gradients_test.cc. +TEST_F(CApiWhileLoopTest, Gradients) { + Init(1); + + // Create loop: while (i < 10) i += 1 + TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_); + TF_Operation* less_than = + LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_); + DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->cond_output = {less_than, 0}; + + TF_Operation* one = ScalarConst(1, params_->body_graph, s_); + TF_Operation* add = + Add(params_->body_inputs[0], {one, 0}, params_->body_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->body_outputs[0] = {add, 0}; + + ExpectOK(); + + // Create backprop graph + TF_Output grad_output; + TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1, + nullptr, s_, &grad_output); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Run gradient + Run({grad_output}, {0}); + ExpectOutputValue(0, 1); +} + } // namespace -- cgit v1.2.3