author     Skye Wanderman-Milne <skyewm@google.com>          2017-09-27 13:28:30 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-09-27 13:32:17 -0700
commit     301b14c240fe99249dc2225132a7ebe5cbecbdc4 (patch)
tree       8de10cf0180ff41211b294c05cae199208487c85 /tensorflow/c/while_loop_test.cc
parent     545e3572f7d8928eeb220e8b55c71ad33a9343c6 (diff)
Basic while loop gradient functionality in C++
This change introduces the basic framework to create the gradient graph of a
while loop using the C++ API. This supports building the gradient graph as long
as the body function of the while loop contains no ops whose gradient function
requires a stack. In other words, it doesn't support gradient functions that
use the input values to the op (e.g. add will work, but multiply will not). It
also doesn't support nested while loops, and doesn't detect all error cases.

PiperOrigin-RevId: 170243281
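To make the stack limitation concrete: Add's gradient function only forwards the
incoming gradient, while Mul's gradient function needs the op's input values, and
inside a while loop those per-iteration values would have to be captured on a
stack, which this change does not implement. The sketch below is written against
the same CApiWhileLoopTest fixture and c_test_util helpers that appear in the
diff; the Mul helper and the DISABLED_GradientsWithMul test are illustrative
assumptions, not part of this change.

    // Hypothetical sketch, not part of this commit: a loop body whose op's
    // gradient function needs the op's inputs. Differentiating it through the
    // loop would require a stack of per-iteration values, which this change
    // does not yet support, so TF_AddGradients is not expected to succeed.
    // (The commit notes that not all error cases are detected, so the exact
    // failure mode is not guaranteed.)
    TEST_F(CApiWhileLoopTest, DISABLED_GradientsWithMul) {
      Init(1);

      // Same condition as the Gradients test in the diff: while (i < 10)
      TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_);
      TF_Operation* less_than =
          LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_);
      params_->cond_output = {less_than, 0};

      // Body "i *= 2" instead of "i += 1"; Mul's gradient reads its inputs.
      // Mul here is an assumed helper analogous to Add in c_test_util.
      TF_Operation* two = ScalarConst(2, params_->body_graph, s_);
      TF_Operation* mul =
          Mul(params_->body_inputs[0], {two, 0}, params_->body_graph, s_);
      params_->body_outputs[0] = {mul, 0};
      ExpectOK();

      // Building the backprop graph should not succeed for this body.
      TF_Output grad_output;
      TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(),
                      1, nullptr, s_, &grad_output);
      EXPECT_NE(TF_OK, TF_GetCode(s_));
    }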
Diffstat (limited to 'tensorflow/c/while_loop_test.cc')
-rw-r--r--  tensorflow/c/while_loop_test.cc | 39
1 file changed, 38 insertions(+), 1 deletion(-)
diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc
index 27be5d787f..4698560bbe 100644
--- a/tensorflow/c/while_loop_test.cc
+++ b/tensorflow/c/while_loop_test.cc
@@ -73,6 +73,11 @@ class CApiWhileLoopTest : public ::testing::Test {
   }
 
   void Run(std::initializer_list<int> input_values) {
+    Run(outputs_, input_values);
+  }
+
+  void Run(const std::vector<TF_Output>& run_outputs,
+           std::initializer_list<int> input_values) {
     DCHECK_EQ(inputs_.size(), input_values.size());
     std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size());
     int i = 0;
@@ -82,7 +87,7 @@ class CApiWhileLoopTest : public ::testing::Test {
     }
     csession_.reset(new CSession(graph_, s_));
     csession_->SetInputs(inputs);
-    csession_->SetOutputs(outputs_);
+    csession_->SetOutputs(run_outputs);
     csession_->Run(s_);
     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
   }
@@ -402,4 +407,36 @@ TEST_F(CApiWhileLoopTest, BadTypes) {
   TF_AbortWhile(params_.get());
 }
 
+// This is a basic test to make sure the C++ gradient code can handle while
+// loops created by the C API (which calls the C++ API under the hood). There
+// are more while loop gradient tests in cc/framework/while_gradients_test.cc.
+TEST_F(CApiWhileLoopTest, Gradients) {
+  Init(1);
+
+  // Create loop: while (i < 10) i += 1
+  TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_);
+  TF_Operation* less_than =
+      LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_);
+  DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+  params_->cond_output = {less_than, 0};
+
+  TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
+  TF_Operation* add =
+      Add(params_->body_inputs[0], {one, 0}, params_->body_graph, s_);
+  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+  params_->body_outputs[0] = {add, 0};
+
+  ExpectOK();
+
+  // Create backprop graph
+  TF_Output grad_output;
+  TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1,
+                  nullptr, s_, &grad_output);
+  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+  // Run gradient
+  Run({grad_output}, {0});
+  ExpectOutputValue(0, 1);
+}
+
 }  // namespace