aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/while_loop_test.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-09-07 16:02:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-07 16:06:27 -0700
commit83ba41e0a38d211fcdb5e3b4e212ef296dc96490 (patch)
tree48526ff884ad5ecfa72c85902ad0290c00c0bc1f /tensorflow/c/while_loop_test.cc
parentbf1c826a7a3dadd2c971d82e811253b0430f590d (diff)
More C++ while loop validation
With this change, we call IsValidOutputTensor() on the returned outputs from the condition and body functions. This will return a bad status if no output or a null output is set, or if an output has a bad index. This also adds unit tests for related error cases to the C and C++ unit tests. They often produce different errors because the C implementation goes through the graph constructor. PiperOrigin-RevId: 167925641
Diffstat (limited to 'tensorflow/c/while_loop_test.cc')
-rw-r--r--tensorflow/c/while_loop_test.cc70
1 files changed, 68 insertions, 2 deletions
diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc
index ce4f86bb25..27be5d787f 100644
--- a/tensorflow/c/while_loop_test.cc
+++ b/tensorflow/c/while_loop_test.cc
@@ -288,20 +288,86 @@ TEST_F(CApiWhileLoopTest, NestedLoop) {
ExpectOutputValue(1, 3);
}
-TEST_F(CApiWhileLoopTest, BadCondOutput) {
+TEST_F(CApiWhileLoopTest, UnsetCondOutput) {
Init(1);
params_->body_outputs[0] = params_->body_inputs[0];
ExpectError(TF_INVALID_ARGUMENT,
"TF_WhileParams `cond_output` field isn't set");
}
-TEST_F(CApiWhileLoopTest, BadBodyOutput) {
+TEST_F(CApiWhileLoopTest, WrongCondOutputType) {
+ Init(1);
+ params_->cond_output = params_->cond_inputs[0];
+ params_->body_outputs[0] = params_->body_inputs[0];
+ ExpectError(TF_INVALID_ARGUMENT,
+ "BuildWhileLoop: 'cond' argument must return a boolean output, "
+ "got int32");
+}
+
+TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) {
+ Init(1);
+ // Try to reuse node from parent graph
+ params_->cond_output = inputs_[0];
+ params_->body_outputs[0] = params_->body_inputs[0];
+ // TODO(skyewm): this error message could be more informative. Add explicit
+ // checks for this case in the while loop implementation?
+ ExpectError(TF_INVALID_ARGUMENT,
+ "Requested return node 'p0' not found in graph def");
+}
+
+TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) {
+ Init(1);
+ CreateCondGraph();
+ params_->cond_output.index = 100;
+ params_->body_outputs[0] = params_->body_inputs[0];
+ ExpectError(TF_INVALID_ARGUMENT,
+ "Invalid return output 100 of node 'less_than', which has 1 "
+ "output(s)");
+}
+
+// TODO(skyewm): test bad cond output shape
+
+TEST_F(CApiWhileLoopTest, UnsetBodyOutput) {
Init(1);
CreateCondGraph();
ExpectError(TF_INVALID_ARGUMENT,
"TF_WhileParams `body_outputs[0]` field isn't set");
}
+// TODO(skyewm): enable this when it works (currently doesn't error)
+// TEST_F(CApiWhileLoopTest, WrongBodyOutputType) {
+// Init(1);
+// CreateCondGraph();
+// TF_Operation* double_scalar =
+// ScalarConst(1.0, params_->body_graph, s_, "double_scalar");
+// params_->body_outputs[0] = {double_scalar, 0};
+// ExpectError(TF_INVALID_ARGUMENT, "bad body output type");
+// }
+
+TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) {
+ Init(1);
+ CreateCondGraph();
+ // Try to reuse node from parent graph
+ params_->body_outputs[0] = inputs_[0];
+ // TODO(skyewm): this error message could be more informative. Add explicit
+ // checks for this case in the while loop implementation?
+ ExpectError(TF_INVALID_ARGUMENT,
+ "Requested return node 'p0' not found in graph def");
+}
+
+// TODO(skyewm): enable this when it works (currently segfaults!)
+// TEST_F(CApiWhileLoopTest, InvalidBodyOutputIndex) {
+// Init(1);
+// CreateCondGraph();
+// params_->body_outputs[0] = params_->body_inputs[0];
+// params_->body_outputs[0].index = 100;
+// ExpectError(TF_INVALID_ARGUMENT,
+// "Invalid return output 100 of node 'less_than', which has 1 "
+// "output(s)");
+// }
+
+// TODO(skyewm): test bad body output shape
+
TEST_F(CApiWhileLoopTest, NullName) {
Init(1);
CreateCondGraph();