diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2017-09-07 16:02:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-07 16:06:27 -0700 |
commit | 83ba41e0a38d211fcdb5e3b4e212ef296dc96490 (patch) | |
tree | 48526ff884ad5ecfa72c85902ad0290c00c0bc1f /tensorflow/cc/ops | |
parent | bf1c826a7a3dadd2c971d82e811253b0430f590d (diff) |
More C++ while loop validation
With this change, we call IsValidOutputTensor() on the returned
outputs from the condition and body functions. This will return a bad
status if no output or a null output is set, or if an output has a bad
index.
This also adds unit tests for related error cases to the C and C++
unit tests. They often produce different errors because the C
implementation goes through the graph constructor.
PiperOrigin-RevId: 167925641
Diffstat (limited to 'tensorflow/cc/ops')
-rw-r--r-- | tensorflow/cc/ops/while_loop.cc | 14 | ||||
-rw-r--r-- | tensorflow/cc/ops/while_loop_test.cc | 179 |
2 files changed, 191 insertions, 2 deletions
diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc index 27da77bbe0..e3e39da85e 100644 --- a/tensorflow/cc/ops/while_loop.cc +++ b/tensorflow/cc/ops/while_loop.cc @@ -102,11 +102,16 @@ Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond, scope.NewSubScope("cond").WithControlDependencies(inputs[0]); Output raw_cond_out; TF_RETURN_IF_ERROR(cond(cond_scope, inputs, &raw_cond_out)); + + TF_RETURN_IF_ERROR(scope.graph()->IsValidOutputTensor(raw_cond_out.node(), + raw_cond_out.index())); if (raw_cond_out.type() != DT_BOOL) { return errors::InvalidArgument( "BuildWhileLoop: 'cond' argument must return a boolean output, got ", DataTypeString(raw_cond_out.type())); } + // TODO(skyewm): check that raw_cond_out is scalar + *output = LoopCond(scope, raw_cond_out).output; return Status::OK(); } @@ -123,13 +128,18 @@ Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body, Scope body_scope = scope.NewSubScope("body").WithControlDependencies(inputs[0]); TF_RETURN_IF_ERROR(body(body_scope, inputs, outputs)); + const size_t num_loop_vars = inputs.size(); if (outputs->size() != num_loop_vars) { return errors::InvalidArgument( "BuildWhileLoop: 'body' argument expected to return ", num_loop_vars, - "outputs, got ", outputs->size()); + " output(s), got ", outputs->size()); + } + for (const Output& output : *outputs) { + TF_RETURN_IF_ERROR( + scope.graph()->IsValidOutputTensor(output.node(), output.index())); + // TODO(skyewm): check output types/shapes } - // TODO(skyewm): check output types/shapes return Status::OK(); } diff --git a/tensorflow/cc/ops/while_loop_test.cc b/tensorflow/cc/ops/while_loop_test.cc new file mode 100644 index 0000000000..77028b5c41 --- /dev/null +++ b/tensorflow/cc/ops/while_loop_test.cc @@ -0,0 +1,179 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/while_loop.h" +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +class WhileLoopTest : public ::testing::Test { + protected: + WhileLoopTest() : scope_(Scope::NewRootScope()) {} + + void Init(int num_inputs, DataType dtype = DT_INT32) { + for (int i = 0; i < num_inputs; ++i) { + inputs_.push_back(ops::Placeholder(scope_, dtype)); + } + } + + void CreateLoop(const ops::CondGraphBuilderFn& cond, + const ops::BodyGraphBuilderFn& body, + error::Code error_code = error::OK, + const string& error_msg = "") { + Status s = ops::BuildWhileLoop(scope_, inputs_, cond, body, "test_loop", + &outputs_); + EXPECT_EQ(s.code(), error_code); + EXPECT_EQ(s.error_message(), error_msg); + } + + template <typename T> + void Run(const std::vector<Input::Initializer>& input_values, + const std::vector<T>& expected_output_values) { + ClientSession session(scope_); + + DCHECK_EQ(input_values.size(), inputs_.size()); + ClientSession::FeedType feeds; + for (int i = 0; i < inputs_.size(); ++i) { + feeds.emplace(inputs_[i], input_values[i]); + } + + std::vector<Tensor> out_tensors; + TF_ASSERT_OK(session.Run(feeds, outputs_, &out_tensors)); + ASSERT_EQ(out_tensors.size(), outputs_.size()); + + DCHECK_EQ(expected_output_values.size(), out_tensors.size()); + for (int i = 0; i < out_tensors.size(); ++i) { + test::ExpectTensorEqual<T>( + out_tensors[i], test::AsTensor<T>({expected_output_values[i]}, {})); + } + } + + Scope scope_; + std::vector<Output> inputs_; + std::vector<Output> outputs_; +}; + +Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs, + Output* output) { + *output = ops::Less(s, inputs[0], 10); + return s.status(); +} + +Status AddOneBody(const Scope& s, const std::vector<Output>& inputs, + std::vector<Output>* outputs) { + outputs->push_back(ops::Add(s, inputs[0], 1)); + return s.status(); +} + +TEST_F(WhileLoopTest, Basic) { + // Create loop: while (i < 10) i += 1 + Init(1); + CreateLoop(LessThanTenCond, AddOneBody); + Run<int>({1}, {10}); + Run<int>({11}, {11}); +} + +TEST_F(WhileLoopTest, WrongCondOutputType) { + Init(1); + CreateLoop( + [](const Scope& s, const std::vector<Output>& inputs, Output* output) { + *output = ops::Placeholder(s, DT_FLOAT); + return s.status(); + }, + AddOneBody, error::INVALID_ARGUMENT, + "BuildWhileLoop: 'cond' argument must return a boolean output, got " + "float"); +} + +// TODO(skyewm): test bad cond output shape + +TEST_F(WhileLoopTest, NullCondOutputNode) { + Init(1); + // TODO(skyewm): improve error message + CreateLoop( + [](const Scope& s, const std::vector<Output>& inputs, Output* output) { + *output = {nullptr, 0}; + return s.status(); + }, + AddOneBody, error::INVALID_ARGUMENT, "Node is null"); +} + +TEST_F(WhileLoopTest, InvalidCondOutputIndex) { + Init(1); + CreateLoop( + [](const Scope& s, const std::vector<Output>& inputs, Output* output) { + auto less = ops::Less(s, inputs[0], 10); + *output = {less.node(), 100}; + return s.status(); + }, + AddOneBody, error::INVALID_ARGUMENT, + "Node 'cond/Less' (type: 'Less', num of outputs: 1) does not have output " + "100"); +} + +TEST_F(WhileLoopTest, UnsetCondOutput) { + Init(1); + CreateLoop([](const Scope& s, const std::vector<Output>& inputs, + Output* output) { return s.status(); }, + AddOneBody, error::INVALID_ARGUMENT, "Node is null"); +} + +// TODO(skyewm): test bad body output type +// TODO(skyewm): test bad body output shape + +TEST_F(WhileLoopTest, NullBodyOutputNode) { + Init(1); + // TODO(skyewm): improve error message + CreateLoop(LessThanTenCond, + [](const Scope& s, const std::vector<Output>& inputs, + std::vector<Output>* outputs) { + outputs->push_back({nullptr, 0}); + return s.status(); + }, + error::INVALID_ARGUMENT, "Node is null"); +} + +TEST_F(WhileLoopTest, InvalidBodyOutputIndex) { + Init(1); + CreateLoop(LessThanTenCond, + [](const Scope& s, const std::vector<Output>& inputs, + std::vector<Output>* outputs) { + auto add = ops::Add(s, inputs[0], 1); + outputs->emplace_back(add.node(), 100); + return s.status(); + }, + error::INVALID_ARGUMENT, + "Node 'body/Add' (type: 'Add', num of outputs: 1) does not have " + "output 100"); +} + +TEST_F(WhileLoopTest, UnsetBodyOutputs) { + Init(1); + CreateLoop( + LessThanTenCond, + [](const Scope& s, const std::vector<Output>& inputs, + std::vector<Output>* outputs) { return s.status(); }, + error::INVALID_ARGUMENT, + "BuildWhileLoop: 'body' argument expected to return 1 output(s), got 0"); +} + +} // namespace +} // namespace tensorflow |