aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/ops
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-09-07 16:02:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-07 16:06:27 -0700
commit83ba41e0a38d211fcdb5e3b4e212ef296dc96490 (patch)
tree48526ff884ad5ecfa72c85902ad0290c00c0bc1f /tensorflow/cc/ops
parentbf1c826a7a3dadd2c971d82e811253b0430f590d (diff)
More C++ while loop validation
With this change, we call IsValidOutputTensor() on the returned outputs from the condition and body functions. This will return a bad status if no output or a null output is set, or if an output has a bad index. This also adds unit tests for related error cases to the C and C++ unit tests. They often produce different errors because the C implementation goes through the graph constructor. PiperOrigin-RevId: 167925641
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r--tensorflow/cc/ops/while_loop.cc14
-rw-r--r--tensorflow/cc/ops/while_loop_test.cc179
2 files changed, 191 insertions, 2 deletions
diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc
index 27da77bbe0..e3e39da85e 100644
--- a/tensorflow/cc/ops/while_loop.cc
+++ b/tensorflow/cc/ops/while_loop.cc
@@ -102,11 +102,16 @@ Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond,
scope.NewSubScope("cond").WithControlDependencies(inputs[0]);
Output raw_cond_out;
TF_RETURN_IF_ERROR(cond(cond_scope, inputs, &raw_cond_out));
+
+ TF_RETURN_IF_ERROR(scope.graph()->IsValidOutputTensor(raw_cond_out.node(),
+ raw_cond_out.index()));
if (raw_cond_out.type() != DT_BOOL) {
return errors::InvalidArgument(
"BuildWhileLoop: 'cond' argument must return a boolean output, got ",
DataTypeString(raw_cond_out.type()));
}
+ // TODO(skyewm): check that raw_cond_out is scalar
+
*output = LoopCond(scope, raw_cond_out).output;
return Status::OK();
}
@@ -123,13 +128,18 @@ Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body,
Scope body_scope =
scope.NewSubScope("body").WithControlDependencies(inputs[0]);
TF_RETURN_IF_ERROR(body(body_scope, inputs, outputs));
+
const size_t num_loop_vars = inputs.size();
if (outputs->size() != num_loop_vars) {
return errors::InvalidArgument(
"BuildWhileLoop: 'body' argument expected to return ", num_loop_vars,
- "outputs, got ", outputs->size());
+ " output(s), got ", outputs->size());
+ }
+ for (const Output& output : *outputs) {
+ TF_RETURN_IF_ERROR(
+ scope.graph()->IsValidOutputTensor(output.node(), output.index()));
+ // TODO(skyewm): check output types/shapes
}
- // TODO(skyewm): check output types/shapes
return Status::OK();
}
diff --git a/tensorflow/cc/ops/while_loop_test.cc b/tensorflow/cc/ops/while_loop_test.cc
new file mode 100644
index 0000000000..77028b5c41
--- /dev/null
+++ b/tensorflow/cc/ops/while_loop_test.cc
@@ -0,0 +1,179 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/while_loop.h"
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace {
+
+class WhileLoopTest : public ::testing::Test {
+ protected:
+ WhileLoopTest() : scope_(Scope::NewRootScope()) {}
+
+ void Init(int num_inputs, DataType dtype = DT_INT32) {
+ for (int i = 0; i < num_inputs; ++i) {
+ inputs_.push_back(ops::Placeholder(scope_, dtype));
+ }
+ }
+
+ void CreateLoop(const ops::CondGraphBuilderFn& cond,
+ const ops::BodyGraphBuilderFn& body,
+ error::Code error_code = error::OK,
+ const string& error_msg = "") {
+ Status s = ops::BuildWhileLoop(scope_, inputs_, cond, body, "test_loop",
+ &outputs_);
+ EXPECT_EQ(s.code(), error_code);
+ EXPECT_EQ(s.error_message(), error_msg);
+ }
+
+ template <typename T>
+ void Run(const std::vector<Input::Initializer>& input_values,
+ const std::vector<T>& expected_output_values) {
+ ClientSession session(scope_);
+
+ DCHECK_EQ(input_values.size(), inputs_.size());
+ ClientSession::FeedType feeds;
+ for (int i = 0; i < inputs_.size(); ++i) {
+ feeds.emplace(inputs_[i], input_values[i]);
+ }
+
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(feeds, outputs_, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), outputs_.size());
+
+ DCHECK_EQ(expected_output_values.size(), out_tensors.size());
+ for (int i = 0; i < out_tensors.size(); ++i) {
+ test::ExpectTensorEqual<T>(
+ out_tensors[i], test::AsTensor<T>({expected_output_values[i]}, {}));
+ }
+ }
+
+ Scope scope_;
+ std::vector<Output> inputs_;
+ std::vector<Output> outputs_;
+};
+
+Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs,
+ Output* output) {
+ *output = ops::Less(s, inputs[0], 10);
+ return s.status();
+}
+
+Status AddOneBody(const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ outputs->push_back(ops::Add(s, inputs[0], 1));
+ return s.status();
+}
+
+TEST_F(WhileLoopTest, Basic) {
+ // Create loop: while (i < 10) i += 1
+ Init(1);
+ CreateLoop(LessThanTenCond, AddOneBody);
+ Run<int>({1}, {10});
+ Run<int>({11}, {11});
+}
+
+TEST_F(WhileLoopTest, WrongCondOutputType) {
+ Init(1);
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ *output = ops::Placeholder(s, DT_FLOAT);
+ return s.status();
+ },
+ AddOneBody, error::INVALID_ARGUMENT,
+ "BuildWhileLoop: 'cond' argument must return a boolean output, got "
+ "float");
+}
+
+// TODO(skyewm): test bad cond output shape
+
+TEST_F(WhileLoopTest, NullCondOutputNode) {
+ Init(1);
+ // TODO(skyewm): improve error message
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ *output = {nullptr, 0};
+ return s.status();
+ },
+ AddOneBody, error::INVALID_ARGUMENT, "Node is null");
+}
+
+TEST_F(WhileLoopTest, InvalidCondOutputIndex) {
+ Init(1);
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ auto less = ops::Less(s, inputs[0], 10);
+ *output = {less.node(), 100};
+ return s.status();
+ },
+ AddOneBody, error::INVALID_ARGUMENT,
+ "Node 'cond/Less' (type: 'Less', num of outputs: 1) does not have output "
+ "100");
+}
+
+TEST_F(WhileLoopTest, UnsetCondOutput) {
+ Init(1);
+ CreateLoop([](const Scope& s, const std::vector<Output>& inputs,
+ Output* output) { return s.status(); },
+ AddOneBody, error::INVALID_ARGUMENT, "Node is null");
+}
+
+// TODO(skyewm): test bad body output type
+// TODO(skyewm): test bad body output shape
+
+TEST_F(WhileLoopTest, NullBodyOutputNode) {
+ Init(1);
+ // TODO(skyewm): improve error message
+ CreateLoop(LessThanTenCond,
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ outputs->push_back({nullptr, 0});
+ return s.status();
+ },
+ error::INVALID_ARGUMENT, "Node is null");
+}
+
+TEST_F(WhileLoopTest, InvalidBodyOutputIndex) {
+ Init(1);
+ CreateLoop(LessThanTenCond,
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ auto add = ops::Add(s, inputs[0], 1);
+ outputs->emplace_back(add.node(), 100);
+ return s.status();
+ },
+ error::INVALID_ARGUMENT,
+ "Node 'body/Add' (type: 'Add', num of outputs: 1) does not have "
+ "output 100");
+}
+
+TEST_F(WhileLoopTest, UnsetBodyOutputs) {
+ Init(1);
+ CreateLoop(
+ LessThanTenCond,
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) { return s.status(); },
+ error::INVALID_ARGUMENT,
+ "BuildWhileLoop: 'body' argument expected to return 1 output(s), got 0");
+}
+
+} // namespace
+} // namespace tensorflow