aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_test.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-02-07 08:16:06 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-07 08:29:31 -0800
commit661058e52460b36417573e2c5a73de9a8b9e5edb (patch)
tree4e0505510d82eba8108d1f55ec867d022ee94a01 /tensorflow/c/c_api_test.cc
parent1bed3d798e79ccfea7f0e5bb782f86466bb6bffb (diff)
C API methods for creating while loops
Change: 146788176
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r--tensorflow/c/c_api_test.cc362
1 files changed, 354 insertions, 8 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 5591409d99..fac486d6eb 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -269,15 +269,17 @@ static TF_Tensor* Int32Tensor(int32 v) {
&Int32Deallocator, nullptr);
}
-TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s) {
- TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", "feed");
+TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
+ const char* name = "feed") {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
TF_SetAttrType(desc, "dtype", TF_INT32);
return TF_FinishOperation(desc, s);
}
-TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s) {
+TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s,
+ const char* name = "scalar") {
unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);
- TF_OperationDescription* desc = TF_NewOperation(graph, "Const", "scalar");
+ TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
TF_SetAttrTensor(desc, "value", tensor.get(), s);
if (TF_GetCode(s) != TF_OK) return nullptr;
TF_SetAttrType(desc, "dtype", TF_INT32);
@@ -285,13 +287,21 @@ TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s) {
}
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
- TF_Status* s) {
- TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", "add");
+ TF_Status* s, const char* name = "add") {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
TF_AddInputList(desc, add_inputs, 2);
return TF_FinishOperation(desc, s);
}
+TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
+ const char* name = "add") {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
+ TF_Output inputs[2] = {l, r};
+ TF_AddInputList(desc, inputs, 2);
+ return TF_FinishOperation(desc, s);
+}
+
TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg");
TF_Output neg_input = {n, 0};
@@ -299,6 +309,14 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
return TF_FinishOperation(desc, s);
}
+TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
+ TF_Status* s) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "Less", "less_than");
+ TF_AddInput(desc, l);
+ TF_AddInput(desc, r);
+ return TF_FinishOperation(desc, s);
+}
+
bool IsPlaceholder(const NodeDef& node_def) {
if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
return false;
@@ -667,6 +685,28 @@ TEST(CAPI, Graph) {
TF_DeleteStatus(s);
}
+/*
+TODO(skyewm): this test currently DCHECKs, change to bad status
+
+TEST(CAPI, InputFromDifferentGraphError) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* g1 = TF_NewGraph();
+ TF_Graph* g2 = TF_NewGraph();
+
+ TF_Operation* feed = Placeholder(g1, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ // Attempt to create node in g2 with input from g1
+ Neg(feed, g2, s);
+ EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
+ EXPECT_STREQ("foo", TF_Message(s));
+
+ TF_DeleteGraph(g1);
+ TF_DeleteGraph(g2);
+ TF_DeleteStatus(s);
+}
+*/
+
TEST(CAPI, ImportGraphDef) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
@@ -793,8 +833,7 @@ class CSession {
TF_DeleteStatus(s);
}
- void SetInputs(
- std::initializer_list<std::pair<TF_Operation*, TF_Tensor*>> inputs) {
+ void SetInputs(std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs) {
DeleteInputValues();
inputs_.clear();
for (const auto& p : inputs) {
@@ -811,6 +850,11 @@ class CSession {
}
}
+ void SetOutputs(const std::vector<TF_Output>& outputs) {
+ ResetOutputValues();
+ outputs_ = outputs;
+ }
+
void SetTargets(std::initializer_list<TF_Operation*> targets) {
targets_.clear();
for (TF_Operation* t : targets) {
@@ -1068,6 +1112,308 @@ TEST(CAPI, SavedModelNullArgsAreValid) {
TF_DeleteStatus(s);
}
+class CApiWhileLoopTest : public ::testing::Test {
+ protected:
+ CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}
+
+ ~CApiWhileLoopTest() override {
+ TF_DeleteGraph(graph_);
+ TF_DeleteStatus(s_);
+ }
+
+ void Init(int ninputs) {
+ DCHECK(inputs_.empty());
+ DCHECK_GT(ninputs, 0);
+
+ for (int i = 0; i < ninputs; ++i) {
+ TF_Operation* placeholder = Placeholder(
+ graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str());
+ DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ inputs_.push_back({placeholder, 0});
+ }
+
+ original_graph_description_ = GraphDebugString();
+
+ params_.reset(new TF_WhileParams(
+ TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_)));
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ ASSERT_EQ(original_graph_description_, GraphDebugString())
+ << "TF_NewWhile() altered graph";
+
+ params_->name = "test_loop";
+
+ // Initialize outputs_ so we can easily detect errors/bugs
+ outputs_.resize(ninputs, {nullptr, -1});
+ }
+
+ void ExpectOK() {
+ TF_FinishWhile(params_.get(), s_, &outputs_[0]);
+ EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ }
+
+ void ExpectError(TF_Code expected_code, const string& expected_msg) {
+ TF_FinishWhile(params_.get(), s_, &outputs_[0]);
+ EXPECT_EQ(expected_code, TF_GetCode(s_));
+ EXPECT_EQ(expected_msg, TF_Message(s_));
+ // TODO(skyewm): this assert is currently broken. Fix or remove guarantee.
+ // ASSERT_EQ(original_graph_description_, GraphDebugString()) <<
+ // "TF_FinishWhile() altered graph on error";
+ }
+
+ void Run(std::initializer_list<int> input_values) {
+ DCHECK_EQ(inputs_.size(), input_values.size());
+ std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size());
+ int i = 0;
+ for (int v : input_values) {
+ inputs[i] = {inputs_[i].oper, Int32Tensor(v)};
+ ++i;
+ }
+ csession_.reset(new CSession(graph_, s_));
+ csession_->SetInputs(inputs);
+ csession_->SetOutputs(outputs_);
+ csession_->Run(s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ }
+
+ void ExpectOutputValue(int idx, int expected_value) {
+ TF_Tensor* out = csession_->output_tensor(idx);
+ ASSERT_TRUE(out != nullptr);
+ EXPECT_EQ(TF_INT32, TF_TensorType(out));
+ EXPECT_EQ(0, TF_NumDims(out));
+ ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
+ int32* data = static_cast<int32*>(TF_TensorData(out));
+ EXPECT_EQ(expected_value, *data);
+ }
+
+ // Create a valid conditonal graph. Useful for testing unrelated errors.
+ void CreateCondGraph() {
+ TF_Operation* one = ScalarConst(1, params_->cond_graph, s_);
+ TF_Operation* less_than =
+ LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_);
+ DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ params_->cond_output = {less_than, 0};
+ }
+
+ string GraphDebugString() const {
+ TF_Buffer* buf = TF_NewBuffer();
+ TF_GraphToGraphDef(graph_, buf, s_);
+ DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ GraphDef def;
+ bool success = def.ParseFromArray(buf->data, buf->length);
+ DCHECK(success);
+ TF_DeleteBuffer(buf);
+ return def.DebugString();
+ }
+
+ TF_Status* s_;
+ TF_Graph* graph_;
+ std::vector<TF_Output> inputs_; // The inputs to the while loop
+ std::vector<TF_Output> outputs_; // The final outputs of the while loop
+ std::unique_ptr<TF_WhileParams> params_;
+ std::unique_ptr<CSession> csession_;
+
+ private:
+ // Used to verify that errors don't change graph_
+ string original_graph_description_;
+};
+
+TEST_F(CApiWhileLoopTest, BasicLoop) {
+ Init(2);
+
+ // Validate TF_WhileParams returned by TF_NewWhile()
+ EXPECT_TRUE(params_->body_graph != nullptr);
+ EXPECT_TRUE(params_->cond_graph != nullptr);
+
+ EXPECT_EQ(params_->ninputs, 2);
+
+ ASSERT_TRUE(params_->cond_inputs != nullptr);
+ ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr);
+ EXPECT_TRUE(params_->cond_inputs[1].oper != nullptr);
+
+ ASSERT_TRUE(params_->body_inputs != nullptr);
+ EXPECT_TRUE(params_->body_inputs[0].oper != nullptr);
+ EXPECT_TRUE(params_->body_inputs[1].oper != nullptr);
+
+ ASSERT_TRUE(params_->body_outputs != nullptr);
+
+ // Create loop: while (input1 < input2) input1 += input2 + 1
+ TF_Operation* less_than =
+ LessThan(params_->cond_inputs[0], params_->cond_inputs[1],
+ params_->cond_graph, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ params_->cond_output = {less_than, 0};
+
+ TF_Operation* add1 = Add(params_->body_inputs[0], params_->body_inputs[1],
+ params_->body_graph, s_, "add1");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_Operation* add2 = Add(add1, one, params_->body_graph, s_, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ params_->body_outputs[0] = {add2, 0};
+ params_->body_outputs[1] = params_->body_inputs[1];
+
+ // Finalize while loop
+ ExpectOK();
+
+ // Validate while loop outputs returned by TF_FinishWhile()
+ EXPECT_TRUE(outputs_[0].oper != nullptr);
+ EXPECT_GE(outputs_[0].index, 0);
+ EXPECT_TRUE(outputs_[1].oper != nullptr);
+ EXPECT_GE(outputs_[1].index, 0);
+
+ // Run the graph
+ Run({-9, 2});
+ ExpectOutputValue(0, 3);
+ ExpectOutputValue(1, 2);
+}
+
+TEST_F(CApiWhileLoopTest, NestedLoop) {
+ Init(2);
+ // Create nested loop:
+ // while (input1 < 6) {
+ // inner_input1 = input1
+ // while (inner_input1 < 3) {
+ // input2 += 1
+ // inner_input1 += 2
+ // }
+ // input1 += input2
+ // }
+ //
+ // Expected execution with initial values input1 = input2 = 0:
+ //
+ // outer inner inner_
+ // step# step# input1 input2 input1
+ // ------------------------------------
+ // 0 0 0 0 0
+ // 0 1 0 1 2
+ // 0 2 0 2 4
+ // 0 - 2 2 -
+ // 1 0 2 2 2
+ // 1 1 2 3 4
+ // 1 - 5 3 -
+ // 2 0 5 3 5
+ // 2 - 8 3 -
+
+ // Create outer cond graph
+ TF_Operation* six = ScalarConst(6, params_->cond_graph, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_Operation* less_than =
+ LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ params_->cond_output = {less_than, 0};
+
+ // Create outer body graph
+ // Init inner graph
+ TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]};
+ TF_WhileParams inner_params =
+ TF_NewWhile(params_->body_graph, inner_inputs, 2, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ inner_params.name = "inner_loop";
+
+ // Create inner cond graph
+ TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_Operation* inner_less_than = LessThan(
+ inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ inner_params.cond_output = {inner_less_than, 0};
+
+ // Create inner body graph
+ TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_Operation* input2_add =
+ Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ inner_params.body_outputs[1] = {input2_add, 0};
+
+ TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two,
+ inner_params.body_graph, s_, "add2");
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ inner_params.body_outputs[0] = {inner_input1_add, 0};
+
+ // Finalize inner graph
+ TF_Output inner_outputs[2] = {{nullptr, -1}};
+ TF_FinishWhile(&inner_params, s_, inner_outputs);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+
+ TF_Operation* input1_add =
+ Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_);
+ ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
+ params_->body_outputs[0] = {input1_add, 0};
+
+ params_->body_outputs[1] = inner_outputs[1];
+
+ // Finalize outer graph
+ ExpectOK();
+
+ // Check for a few expected nodes
+ const char* node_name = "test_loop/cond/scalar";
+ EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
+ node_name = "test_loop/body/add";
+ EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
+ node_name = "test_loop/body/inner_loop/body/one";
+ EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
+ node_name = "test_loop/body/inner_loop/cond/less_than";
+ EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
+
+ // Run the graph
+ Run({0, 0});
+ ExpectOutputValue(0, 8);
+ ExpectOutputValue(1, 3);
+}
+
+TEST_F(CApiWhileLoopTest, BadCondOutput) {
+ Init(1);
+ params_->body_outputs[0] = params_->body_inputs[0];
+ ExpectError(TF_INVALID_ARGUMENT,
+ "TF_WhileParams `cond_output` field isn't set");
+}
+
+TEST_F(CApiWhileLoopTest, BadBodyOutput) {
+ Init(1);
+ CreateCondGraph();
+ ExpectError(TF_INVALID_ARGUMENT,
+ "TF_WhileParams `body_outputs[0]` field isn't set");
+}
+
+TEST_F(CApiWhileLoopTest, NullName) {
+ Init(1);
+ CreateCondGraph();
+ params_->body_outputs[0] = params_->body_inputs[0];
+ params_->name = nullptr;
+ ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `name` field is null");
+}
+
+TEST_F(CApiWhileLoopTest, WrongGraph) {
+ Init(1);
+ CreateCondGraph();
+ // Set body output to output from outer graph
+ params_->body_outputs[0] = inputs_[0];
+ // TODO(skyewm): improve error message
+ ExpectError(TF_INVALID_ARGUMENT,
+ "Requested return node 'p0' not found in graph def");
+}
+
+TEST_F(CApiWhileLoopTest, BadTypes) {
+ Init(1);
+ CreateCondGraph();
+ // Op that has a float input + output
+ TF_OperationDescription* desc = TF_NewOperation(
+ params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op");
+ TF_AddInput(desc, params_->body_inputs[0]);
+ TF_FinishOperation(desc, s_);
+ ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
+ string msg(TF_Message(s_));
+ EXPECT_NE(msg.find("Input 'inputs' passed int32 expected float while "
+ "building NodeDef 'float_op'"),
+ msg.npos);
+ TF_AbortWhile(params_.get());
+}
+
// Create a tensor with values of type TF_INT8 provided by `values`.
TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) {
int64_t num_values = 1;