diff options
-rw-r--r-- | tensorflow/c/c_test_util.cc | 18 | ||||
-rw-r--r-- | tensorflow/c/c_test_util.h | 5 | ||||
-rw-r--r-- | tensorflow/c/while_loop_test.cc | 70 | ||||
-rw-r--r-- | tensorflow/cc/BUILD | 15 | ||||
-rw-r--r-- | tensorflow/cc/ops/while_loop.cc | 14 | ||||
-rw-r--r-- | tensorflow/cc/ops/while_loop_test.cc | 179 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor_test.cc | 2 |
8 files changed, 299 insertions, 6 deletions
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 9cd978c97e..d1f99fe1ef 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -27,6 +27,10 @@ static void Int32Deallocator(void* data, size_t, void* arg) { delete[] static_cast<int32_t*>(data); } +static void DoubleDeallocator(void* data, size_t, void* arg) { + delete[] static_cast<double*>(data); +} + TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { int64_t num_values = 1; for (int i = 0; i < num_dims; ++i) { @@ -63,6 +67,14 @@ TF_Tensor* Int32Tensor(int32_t v) { &Int32Deallocator, nullptr); } +TF_Tensor* DoubleTensor(double v) { + const int num_bytes = sizeof(double); + double* values = new double[1]; + values[0] = v; + return TF_NewTensor(TF_DOUBLE, nullptr, 0, values, num_bytes, + &DoubleDeallocator, nullptr); +} + // All the *Helper methods are used as a workaround for the restrictions that // one cannot call ASSERT_* methods in non-void-returning functions (when // exceptions are disabled during compilation) @@ -105,6 +117,12 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, return Const(tensor.get(), graph, s, name); } +TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, + const char* name) { + unique_tensor_ptr tensor(DoubleTensor(v), TF_DeleteTensor); + return Const(tensor.get(), graph, s, name); +} + void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name, TF_Operation** op, bool check) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index a927739d46..91f96b0e5d 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -42,6 +42,8 @@ TF_Tensor* Int32Tensor(const std::vector<int32_t>& values); TF_Tensor* Int32Tensor(int32_t v); +TF_Tensor* DoubleTensor(double v); + TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name = "feed"); @@ -51,6 +53,9 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s, TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s, const char* name = "scalar"); +TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar"); + TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name = "add"); diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc index ce4f86bb25..27be5d787f 100644 --- a/tensorflow/c/while_loop_test.cc +++ b/tensorflow/c/while_loop_test.cc @@ -288,20 +288,86 @@ TEST_F(CApiWhileLoopTest, NestedLoop) { ExpectOutputValue(1, 3); } -TEST_F(CApiWhileLoopTest, BadCondOutput) { +TEST_F(CApiWhileLoopTest, UnsetCondOutput) { Init(1); params_->body_outputs[0] = params_->body_inputs[0]; ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `cond_output` field isn't set"); } -TEST_F(CApiWhileLoopTest, BadBodyOutput) { +TEST_F(CApiWhileLoopTest, WrongCondOutputType) { + Init(1); + params_->cond_output = params_->cond_inputs[0]; + params_->body_outputs[0] = params_->body_inputs[0]; + ExpectError(TF_INVALID_ARGUMENT, + "BuildWhileLoop: 'cond' argument must return a boolean output, " + "got int32"); +} + +TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) { + Init(1); + // Try to reuse node from parent graph + params_->cond_output = inputs_[0]; + params_->body_outputs[0] = params_->body_inputs[0]; + // TODO(skyewm): this error message could be more informative. Add explicit + // checks for this case in the while loop implementation? + ExpectError(TF_INVALID_ARGUMENT, + "Requested return node 'p0' not found in graph def"); +} + +TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) { + Init(1); + CreateCondGraph(); + params_->cond_output.index = 100; + params_->body_outputs[0] = params_->body_inputs[0]; + ExpectError(TF_INVALID_ARGUMENT, + "Invalid return output 100 of node 'less_than', which has 1 " + "output(s)"); +} + +// TODO(skyewm): test bad cond output shape + +TEST_F(CApiWhileLoopTest, UnsetBodyOutput) { Init(1); CreateCondGraph(); ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `body_outputs[0]` field isn't set"); } +// TODO(skyewm): enable this when it works (currently doesn't error) +// TEST_F(CApiWhileLoopTest, WrongBodyOutputType) { +// Init(1); +// CreateCondGraph(); +// TF_Operation* double_scalar = +// ScalarConst(1.0, params_->body_graph, s_, "double_scalar"); +// params_->body_outputs[0] = {double_scalar, 0}; +// ExpectError(TF_INVALID_ARGUMENT, "bad body output type"); +// } + +TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) { + Init(1); + CreateCondGraph(); + // Try to reuse node from parent graph + params_->body_outputs[0] = inputs_[0]; + // TODO(skyewm): this error message could be more informative. Add explicit + // checks for this case in the while loop implementation? + ExpectError(TF_INVALID_ARGUMENT, + "Requested return node 'p0' not found in graph def"); +} + +// TODO(skyewm): enable this when it works (currently segfaults!) +// TEST_F(CApiWhileLoopTest, InvalidBodyOutputIndex) { +// Init(1); +// CreateCondGraph(); +// params_->body_outputs[0] = params_->body_inputs[0]; +// params_->body_outputs[0].index = 100; +// ExpectError(TF_INVALID_ARGUMENT, +// "Invalid return output 100 of node 'less_than', which has 1 " +// "output(s)"); +// } + +// TODO(skyewm): test bad body output shape + TEST_F(CApiWhileLoopTest, NullName) { Init(1); CreateCondGraph(); diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index d9071ba6e4..c6d5792f49 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -248,6 +248,21 @@ cc_library_with_android_deps( ], ) +tf_cc_test( + name = "ops_while_loop_test", + size = "small", + srcs = ["ops/while_loop_test.cc"], + deps = [ + ":cc_ops", + ":client_session", + ":testutil", + ":while_loop", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + cc_library( name = "grad_op_registry", srcs = ["framework/grad_op_registry.cc"], diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc index 27da77bbe0..e3e39da85e 100644 --- a/tensorflow/cc/ops/while_loop.cc +++ b/tensorflow/cc/ops/while_loop.cc @@ -102,11 +102,16 @@ Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond, scope.NewSubScope("cond").WithControlDependencies(inputs[0]); Output raw_cond_out; TF_RETURN_IF_ERROR(cond(cond_scope, inputs, &raw_cond_out)); + + TF_RETURN_IF_ERROR(scope.graph()->IsValidOutputTensor(raw_cond_out.node(), + raw_cond_out.index())); if (raw_cond_out.type() != DT_BOOL) { return errors::InvalidArgument( "BuildWhileLoop: 'cond' argument must return a boolean output, got ", DataTypeString(raw_cond_out.type())); } + // TODO(skyewm): check that raw_cond_out is scalar + *output = LoopCond(scope, raw_cond_out).output; return Status::OK(); } @@ -123,13 +128,18 @@ Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body, Scope body_scope = scope.NewSubScope("body").WithControlDependencies(inputs[0]); TF_RETURN_IF_ERROR(body(body_scope, inputs, outputs)); + const size_t num_loop_vars = inputs.size(); if (outputs->size() != num_loop_vars) { return errors::InvalidArgument( "BuildWhileLoop: 'body' argument expected to return ", num_loop_vars, - "outputs, got ", outputs->size()); + " output(s), got ", outputs->size()); + } + for (const Output& output : *outputs) { + TF_RETURN_IF_ERROR( + scope.graph()->IsValidOutputTensor(output.node(), output.index())); + // TODO(skyewm): check output types/shapes } - // TODO(skyewm): check output types/shapes return Status::OK(); } diff --git a/tensorflow/cc/ops/while_loop_test.cc b/tensorflow/cc/ops/while_loop_test.cc new file mode 100644 index 0000000000..77028b5c41 --- /dev/null +++ b/tensorflow/cc/ops/while_loop_test.cc @@ -0,0 +1,179 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/while_loop.h" +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +class WhileLoopTest : public ::testing::Test { + protected: + WhileLoopTest() : scope_(Scope::NewRootScope()) {} + + void Init(int num_inputs, DataType dtype = DT_INT32) { + for (int i = 0; i < num_inputs; ++i) { + inputs_.push_back(ops::Placeholder(scope_, dtype)); + } + } + + void CreateLoop(const ops::CondGraphBuilderFn& cond, + const ops::BodyGraphBuilderFn& body, + error::Code error_code = error::OK, + const string& error_msg = "") { + Status s = ops::BuildWhileLoop(scope_, inputs_, cond, body, "test_loop", + &outputs_); + EXPECT_EQ(s.code(), error_code); + EXPECT_EQ(s.error_message(), error_msg); + } + + template <typename T> + void Run(const std::vector<Input::Initializer>& input_values, + const std::vector<T>& expected_output_values) { + ClientSession session(scope_); + + DCHECK_EQ(input_values.size(), inputs_.size()); + ClientSession::FeedType feeds; + for (int i = 0; i < inputs_.size(); ++i) { + feeds.emplace(inputs_[i], input_values[i]); + } + + std::vector<Tensor> out_tensors; + TF_ASSERT_OK(session.Run(feeds, outputs_, &out_tensors)); + ASSERT_EQ(out_tensors.size(), outputs_.size()); + + DCHECK_EQ(expected_output_values.size(), out_tensors.size()); + for (int i = 0; i < out_tensors.size(); ++i) { + test::ExpectTensorEqual<T>( + out_tensors[i], test::AsTensor<T>({expected_output_values[i]}, {})); + } + } + + Scope scope_; + std::vector<Output> inputs_; + std::vector<Output> outputs_; +}; + +Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs, + Output* output) { + *output = ops::Less(s, inputs[0], 10); + return s.status(); +} + +Status AddOneBody(const Scope& s, const std::vector<Output>& inputs, + std::vector<Output>* outputs) { + outputs->push_back(ops::Add(s, inputs[0], 1)); + return s.status(); +} + +TEST_F(WhileLoopTest, Basic) { + // Create loop: while (i < 10) i += 1 + Init(1); + CreateLoop(LessThanTenCond, AddOneBody); + Run<int>({1}, {10}); + Run<int>({11}, {11}); +} + +TEST_F(WhileLoopTest, WrongCondOutputType) { + Init(1); + CreateLoop( + [](const Scope& s, const std::vector<Output>& inputs, Output* output) { + *output = ops::Placeholder(s, DT_FLOAT); + return s.status(); + }, + AddOneBody, error::INVALID_ARGUMENT, + "BuildWhileLoop: 'cond' argument must return a boolean output, got " + "float"); +} + +// TODO(skyewm): test bad cond output shape + +TEST_F(WhileLoopTest, NullCondOutputNode) { + Init(1); + // TODO(skyewm): improve error message + CreateLoop( + [](const Scope& s, const std::vector<Output>& inputs, Output* output) { + *output = {nullptr, 0}; + return s.status(); + }, + AddOneBody, error::INVALID_ARGUMENT, "Node is null"); +} + +TEST_F(WhileLoopTest, InvalidCondOutputIndex) { + Init(1); + CreateLoop( + [](const Scope& s, const std::vector<Output>& inputs, Output* output) { + auto less = ops::Less(s, inputs[0], 10); + *output = {less.node(), 100}; + return s.status(); + }, + AddOneBody, error::INVALID_ARGUMENT, + "Node 'cond/Less' (type: 'Less', num of outputs: 1) does not have output " + "100"); +} + +TEST_F(WhileLoopTest, UnsetCondOutput) { + Init(1); + CreateLoop([](const Scope& s, const std::vector<Output>& inputs, + Output* output) { return s.status(); }, + AddOneBody, error::INVALID_ARGUMENT, "Node is null"); +} + +// TODO(skyewm): test bad body output type +// TODO(skyewm): test bad body output shape + +TEST_F(WhileLoopTest, NullBodyOutputNode) { + Init(1); + // TODO(skyewm): improve error message + CreateLoop(LessThanTenCond, + [](const Scope& s, const std::vector<Output>& inputs, + std::vector<Output>* outputs) { + outputs->push_back({nullptr, 0}); + return s.status(); + }, + error::INVALID_ARGUMENT, "Node is null"); +} + +TEST_F(WhileLoopTest, InvalidBodyOutputIndex) { + Init(1); + CreateLoop(LessThanTenCond, + [](const Scope& s, const std::vector<Output>& inputs, + std::vector<Output>* outputs) { + auto add = ops::Add(s, inputs[0], 1); + outputs->emplace_back(add.node(), 100); + return s.status(); + }, + error::INVALID_ARGUMENT, + "Node 'body/Add' (type: 'Add', num of outputs: 1) does not have " + "output 100"); +} + +TEST_F(WhileLoopTest, UnsetBodyOutputs) { + Init(1); + CreateLoop( + LessThanTenCond, + [](const Scope& s, const std::vector<Output>& inputs, + std::vector<Output>* outputs) { return s.status(); }, + error::INVALID_ARGUMENT, + "BuildWhileLoop: 'body' argument expected to return 1 output(s), got 0"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index eb1a023ab2..8dcb6798c1 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -901,7 +901,7 @@ Status GraphConstructor::PopulateReturnTensors() { id.second != Graph::kControlSlot) { return errors::InvalidArgument("Invalid return output ", id.second, " of node '", id.first, "', which has ", - num_outputs, " outputs"); + num_outputs, " output(s)"); } return_tensors_->push_back({iter->second.node, id.second}); } else { diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index e448ce4927..1739fb554d 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -1607,7 +1607,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensorsErrors) { opts.return_tensors.push_back({"new_input", 2}); ExpectError("node { name: 'new_input' op: 'TestInput' }", opts, {"Invalid return output 2 of node 'new_input', which has 2 " - "outputs"}, + "output(s)"}, nullptr, &return_tensors); } |