aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/c_test_util.cc18
-rw-r--r--tensorflow/c/c_test_util.h5
-rw-r--r--tensorflow/c/while_loop_test.cc70
-rw-r--r--tensorflow/cc/BUILD15
-rw-r--r--tensorflow/cc/ops/while_loop.cc14
-rw-r--r--tensorflow/cc/ops/while_loop_test.cc179
-rw-r--r--tensorflow/core/graph/graph_constructor.cc2
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc2
8 files changed, 299 insertions, 6 deletions
diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc
index 9cd978c97e..d1f99fe1ef 100644
--- a/tensorflow/c/c_test_util.cc
+++ b/tensorflow/c/c_test_util.cc
@@ -27,6 +27,10 @@ static void Int32Deallocator(void* data, size_t, void* arg) {
delete[] static_cast<int32_t*>(data);
}
+static void DoubleDeallocator(void* data, size_t, void* arg) {
+ delete[] static_cast<double*>(data);
+}
+
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
int64_t num_values = 1;
for (int i = 0; i < num_dims; ++i) {
@@ -63,6 +67,14 @@ TF_Tensor* Int32Tensor(int32_t v) {
&Int32Deallocator, nullptr);
}
+TF_Tensor* DoubleTensor(double v) {
+ const int num_bytes = sizeof(double);
+ double* values = new double[1];
+ values[0] = v;
+ return TF_NewTensor(TF_DOUBLE, nullptr, 0, values, num_bytes,
+ &DoubleDeallocator, nullptr);
+}
+
// All the *Helper methods are used as a workaround for the restrictions that
// one cannot call ASSERT_* methods in non-void-returning functions (when
// exceptions are disabled during compilation)
@@ -105,6 +117,12 @@ TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
return Const(tensor.get(), graph, s, name);
}
+TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s,
+ const char* name) {
+ unique_tensor_ptr tensor(DoubleTensor(v), TF_DeleteTensor);
+ return Const(tensor.get(), graph, s, name);
+}
+
void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s,
const char* name, TF_Operation** op, bool check) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h
index a927739d46..91f96b0e5d 100644
--- a/tensorflow/c/c_test_util.h
+++ b/tensorflow/c/c_test_util.h
@@ -42,6 +42,8 @@ TF_Tensor* Int32Tensor(const std::vector<int32_t>& values);
TF_Tensor* Int32Tensor(int32_t v);
+TF_Tensor* DoubleTensor(double v);
+
TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
const char* name = "feed");
@@ -51,6 +53,9 @@ TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
TF_Operation* ScalarConst(int32_t v, TF_Graph* graph, TF_Status* s,
const char* name = "scalar");
+TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s,
+ const char* name = "scalar");
+
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name = "add");
diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc
index ce4f86bb25..27be5d787f 100644
--- a/tensorflow/c/while_loop_test.cc
+++ b/tensorflow/c/while_loop_test.cc
@@ -288,20 +288,86 @@ TEST_F(CApiWhileLoopTest, NestedLoop) {
ExpectOutputValue(1, 3);
}
-TEST_F(CApiWhileLoopTest, BadCondOutput) {
+TEST_F(CApiWhileLoopTest, UnsetCondOutput) {
Init(1);
params_->body_outputs[0] = params_->body_inputs[0];
ExpectError(TF_INVALID_ARGUMENT,
"TF_WhileParams `cond_output` field isn't set");
}
-TEST_F(CApiWhileLoopTest, BadBodyOutput) {
+TEST_F(CApiWhileLoopTest, WrongCondOutputType) {
+ Init(1);
+ params_->cond_output = params_->cond_inputs[0];
+ params_->body_outputs[0] = params_->body_inputs[0];
+ ExpectError(TF_INVALID_ARGUMENT,
+ "BuildWhileLoop: 'cond' argument must return a boolean output, "
+ "got int32");
+}
+
+TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) {
+ Init(1);
+ // Try to reuse node from parent graph
+ params_->cond_output = inputs_[0];
+ params_->body_outputs[0] = params_->body_inputs[0];
+ // TODO(skyewm): this error message could be more informative. Add explicit
+ // checks for this case in the while loop implementation?
+ ExpectError(TF_INVALID_ARGUMENT,
+ "Requested return node 'p0' not found in graph def");
+}
+
+TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) {
+ Init(1);
+ CreateCondGraph();
+ params_->cond_output.index = 100;
+ params_->body_outputs[0] = params_->body_inputs[0];
+ ExpectError(TF_INVALID_ARGUMENT,
+ "Invalid return output 100 of node 'less_than', which has 1 "
+ "output(s)");
+}
+
+// TODO(skyewm): test bad cond output shape
+
+TEST_F(CApiWhileLoopTest, UnsetBodyOutput) {
Init(1);
CreateCondGraph();
ExpectError(TF_INVALID_ARGUMENT,
"TF_WhileParams `body_outputs[0]` field isn't set");
}
+// TODO(skyewm): enable this when it works (currently doesn't error)
+// TEST_F(CApiWhileLoopTest, WrongBodyOutputType) {
+// Init(1);
+// CreateCondGraph();
+// TF_Operation* double_scalar =
+// ScalarConst(1.0, params_->body_graph, s_, "double_scalar");
+// params_->body_outputs[0] = {double_scalar, 0};
+// ExpectError(TF_INVALID_ARGUMENT, "bad body output type");
+// }
+
+TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) {
+ Init(1);
+ CreateCondGraph();
+ // Try to reuse node from parent graph
+ params_->body_outputs[0] = inputs_[0];
+ // TODO(skyewm): this error message could be more informative. Add explicit
+ // checks for this case in the while loop implementation?
+ ExpectError(TF_INVALID_ARGUMENT,
+ "Requested return node 'p0' not found in graph def");
+}
+
+// TODO(skyewm): enable this when it works (currently segfaults!)
+// TEST_F(CApiWhileLoopTest, InvalidBodyOutputIndex) {
+// Init(1);
+// CreateCondGraph();
+// params_->body_outputs[0] = params_->body_inputs[0];
+// params_->body_outputs[0].index = 100;
+// ExpectError(TF_INVALID_ARGUMENT,
+// "Invalid return output 100 of node 'less_than', which has 1 "
+// "output(s)");
+// }
+
+// TODO(skyewm): test bad body output shape
+
TEST_F(CApiWhileLoopTest, NullName) {
Init(1);
CreateCondGraph();
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index d9071ba6e4..c6d5792f49 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -248,6 +248,21 @@ cc_library_with_android_deps(
],
)
+tf_cc_test(
+ name = "ops_while_loop_test",
+ size = "small",
+ srcs = ["ops/while_loop_test.cc"],
+ deps = [
+ ":cc_ops",
+ ":client_session",
+ ":testutil",
+ ":while_loop",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
cc_library(
name = "grad_op_registry",
srcs = ["framework/grad_op_registry.cc"],
diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc
index 27da77bbe0..e3e39da85e 100644
--- a/tensorflow/cc/ops/while_loop.cc
+++ b/tensorflow/cc/ops/while_loop.cc
@@ -102,11 +102,16 @@ Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond,
scope.NewSubScope("cond").WithControlDependencies(inputs[0]);
Output raw_cond_out;
TF_RETURN_IF_ERROR(cond(cond_scope, inputs, &raw_cond_out));
+
+ TF_RETURN_IF_ERROR(scope.graph()->IsValidOutputTensor(raw_cond_out.node(),
+ raw_cond_out.index()));
if (raw_cond_out.type() != DT_BOOL) {
return errors::InvalidArgument(
"BuildWhileLoop: 'cond' argument must return a boolean output, got ",
DataTypeString(raw_cond_out.type()));
}
+ // TODO(skyewm): check that raw_cond_out is scalar
+
*output = LoopCond(scope, raw_cond_out).output;
return Status::OK();
}
@@ -123,13 +128,18 @@ Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body,
Scope body_scope =
scope.NewSubScope("body").WithControlDependencies(inputs[0]);
TF_RETURN_IF_ERROR(body(body_scope, inputs, outputs));
+
const size_t num_loop_vars = inputs.size();
if (outputs->size() != num_loop_vars) {
return errors::InvalidArgument(
"BuildWhileLoop: 'body' argument expected to return ", num_loop_vars,
- "outputs, got ", outputs->size());
+ " output(s), got ", outputs->size());
+ }
+ for (const Output& output : *outputs) {
+ TF_RETURN_IF_ERROR(
+ scope.graph()->IsValidOutputTensor(output.node(), output.index()));
+ // TODO(skyewm): check output types/shapes
}
- // TODO(skyewm): check output types/shapes
return Status::OK();
}
diff --git a/tensorflow/cc/ops/while_loop_test.cc b/tensorflow/cc/ops/while_loop_test.cc
new file mode 100644
index 0000000000..77028b5c41
--- /dev/null
+++ b/tensorflow/cc/ops/while_loop_test.cc
@@ -0,0 +1,179 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/while_loop.h"
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace {
+
+class WhileLoopTest : public ::testing::Test {
+ protected:
+ WhileLoopTest() : scope_(Scope::NewRootScope()) {}
+
+ void Init(int num_inputs, DataType dtype = DT_INT32) {
+ for (int i = 0; i < num_inputs; ++i) {
+ inputs_.push_back(ops::Placeholder(scope_, dtype));
+ }
+ }
+
+ void CreateLoop(const ops::CondGraphBuilderFn& cond,
+ const ops::BodyGraphBuilderFn& body,
+ error::Code error_code = error::OK,
+ const string& error_msg = "") {
+ Status s = ops::BuildWhileLoop(scope_, inputs_, cond, body, "test_loop",
+ &outputs_);
+ EXPECT_EQ(s.code(), error_code);
+ EXPECT_EQ(s.error_message(), error_msg);
+ }
+
+ template <typename T>
+ void Run(const std::vector<Input::Initializer>& input_values,
+ const std::vector<T>& expected_output_values) {
+ ClientSession session(scope_);
+
+ DCHECK_EQ(input_values.size(), inputs_.size());
+ ClientSession::FeedType feeds;
+ for (int i = 0; i < inputs_.size(); ++i) {
+ feeds.emplace(inputs_[i], input_values[i]);
+ }
+
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(feeds, outputs_, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), outputs_.size());
+
+ DCHECK_EQ(expected_output_values.size(), out_tensors.size());
+ for (int i = 0; i < out_tensors.size(); ++i) {
+ test::ExpectTensorEqual<T>(
+ out_tensors[i], test::AsTensor<T>({expected_output_values[i]}, {}));
+ }
+ }
+
+ Scope scope_;
+ std::vector<Output> inputs_;
+ std::vector<Output> outputs_;
+};
+
+Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs,
+ Output* output) {
+ *output = ops::Less(s, inputs[0], 10);
+ return s.status();
+}
+
+Status AddOneBody(const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ outputs->push_back(ops::Add(s, inputs[0], 1));
+ return s.status();
+}
+
+TEST_F(WhileLoopTest, Basic) {
+ // Create loop: while (i < 10) i += 1
+ Init(1);
+ CreateLoop(LessThanTenCond, AddOneBody);
+ Run<int>({1}, {10});
+ Run<int>({11}, {11});
+}
+
+TEST_F(WhileLoopTest, WrongCondOutputType) {
+ Init(1);
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ *output = ops::Placeholder(s, DT_FLOAT);
+ return s.status();
+ },
+ AddOneBody, error::INVALID_ARGUMENT,
+ "BuildWhileLoop: 'cond' argument must return a boolean output, got "
+ "float");
+}
+
+// TODO(skyewm): test bad cond output shape
+
+TEST_F(WhileLoopTest, NullCondOutputNode) {
+ Init(1);
+ // TODO(skyewm): improve error message
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ *output = {nullptr, 0};
+ return s.status();
+ },
+ AddOneBody, error::INVALID_ARGUMENT, "Node is null");
+}
+
+TEST_F(WhileLoopTest, InvalidCondOutputIndex) {
+ Init(1);
+ CreateLoop(
+ [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
+ auto less = ops::Less(s, inputs[0], 10);
+ *output = {less.node(), 100};
+ return s.status();
+ },
+ AddOneBody, error::INVALID_ARGUMENT,
+ "Node 'cond/Less' (type: 'Less', num of outputs: 1) does not have output "
+ "100");
+}
+
+TEST_F(WhileLoopTest, UnsetCondOutput) {
+ Init(1);
+ CreateLoop([](const Scope& s, const std::vector<Output>& inputs,
+ Output* output) { return s.status(); },
+ AddOneBody, error::INVALID_ARGUMENT, "Node is null");
+}
+
+// TODO(skyewm): test bad body output type
+// TODO(skyewm): test bad body output shape
+
+TEST_F(WhileLoopTest, NullBodyOutputNode) {
+ Init(1);
+ // TODO(skyewm): improve error message
+ CreateLoop(LessThanTenCond,
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ outputs->push_back({nullptr, 0});
+ return s.status();
+ },
+ error::INVALID_ARGUMENT, "Node is null");
+}
+
+TEST_F(WhileLoopTest, InvalidBodyOutputIndex) {
+ Init(1);
+ CreateLoop(LessThanTenCond,
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) {
+ auto add = ops::Add(s, inputs[0], 1);
+ outputs->emplace_back(add.node(), 100);
+ return s.status();
+ },
+ error::INVALID_ARGUMENT,
+ "Node 'body/Add' (type: 'Add', num of outputs: 1) does not have "
+ "output 100");
+}
+
+TEST_F(WhileLoopTest, UnsetBodyOutputs) {
+ Init(1);
+ CreateLoop(
+ LessThanTenCond,
+ [](const Scope& s, const std::vector<Output>& inputs,
+ std::vector<Output>* outputs) { return s.status(); },
+ error::INVALID_ARGUMENT,
+ "BuildWhileLoop: 'body' argument expected to return 1 output(s), got 0");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index eb1a023ab2..8dcb6798c1 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -901,7 +901,7 @@ Status GraphConstructor::PopulateReturnTensors() {
id.second != Graph::kControlSlot) {
return errors::InvalidArgument("Invalid return output ", id.second,
" of node '", id.first, "', which has ",
- num_outputs, " outputs");
+ num_outputs, " output(s)");
}
return_tensors_->push_back({iter->second.node, id.second});
} else {
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index e448ce4927..1739fb554d 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -1607,7 +1607,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ReturnTensorsErrors) {
opts.return_tensors.push_back({"new_input", 2});
ExpectError("node { name: 'new_input' op: 'TestInput' }", opts,
{"Invalid return output 2 of node 'new_input', which has 2 "
- "outputs"},
+ "output(s)"},
nullptr, &return_tensors);
}