diff options
author | 2017-02-07 08:16:06 -0800 | |
---|---|---|
committer | 2017-02-07 08:29:31 -0800 | |
commit | 661058e52460b36417573e2c5a73de9a8b9e5edb (patch) | |
tree | 4e0505510d82eba8108d1f55ec867d022ee94a01 /tensorflow/c/c_api_test.cc | |
parent | 1bed3d798e79ccfea7f0e5bb782f86466bb6bffb (diff) |
C API methods for creating while loops
Change: 146788176
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r-- | tensorflow/c/c_api_test.cc | 362 |
1 files changed, 354 insertions, 8 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 5591409d99..fac486d6eb 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -269,15 +269,17 @@ static TF_Tensor* Int32Tensor(int32 v) { &Int32Deallocator, nullptr); } -TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s) { - TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", "feed"); +TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, + const char* name = "feed") { + TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name); TF_SetAttrType(desc, "dtype", TF_INT32); return TF_FinishOperation(desc, s); } -TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s) { +TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s, + const char* name = "scalar") { unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor); - TF_OperationDescription* desc = TF_NewOperation(graph, "Const", "scalar"); + TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name); TF_SetAttrTensor(desc, "value", tensor.get(), s); if (TF_GetCode(s) != TF_OK) return nullptr; TF_SetAttrType(desc, "dtype", TF_INT32); @@ -285,13 +287,21 @@ TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s) { } TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, - TF_Status* s) { - TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", "add"); + TF_Status* s, const char* name = "add") { + TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; TF_AddInputList(desc, add_inputs, 2); return TF_FinishOperation(desc, s); } +TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, + const char* name = "add") { + TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); + TF_Output inputs[2] = {l, r}; + TF_AddInputList(desc, inputs, 2); + return TF_FinishOperation(desc, s); +} + TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) { TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg"); TF_Output neg_input = {n, 0}; @@ -299,6 +309,14 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) { return TF_FinishOperation(desc, s); } +TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, + TF_Status* s) { + TF_OperationDescription* desc = TF_NewOperation(graph, "Less", "less_than"); + TF_AddInput(desc, l); + TF_AddInput(desc, r); + return TF_FinishOperation(desc, s); +} + bool IsPlaceholder(const NodeDef& node_def) { if (node_def.op() != "Placeholder" || node_def.name() != "feed") { return false; @@ -667,6 +685,28 @@ TEST(CAPI, Graph) { TF_DeleteStatus(s); } +/* +TODO(skyewm): this test currently DCHECKs, change to bad status + +TEST(CAPI, InputFromDifferentGraphError) { + TF_Status* s = TF_NewStatus(); + TF_Graph* g1 = TF_NewGraph(); + TF_Graph* g2 = TF_NewGraph(); + + TF_Operation* feed = Placeholder(g1, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Attempt to create node in g2 with input from g1 + Neg(feed, g2, s); + EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)); + EXPECT_STREQ("foo", TF_Message(s)); + + TF_DeleteGraph(g1); + TF_DeleteGraph(g2); + TF_DeleteStatus(s); +} +*/ + TEST(CAPI, ImportGraphDef) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -793,8 +833,7 @@ class CSession { TF_DeleteStatus(s); } - void SetInputs( - std::initializer_list<std::pair<TF_Operation*, TF_Tensor*>> inputs) { + void SetInputs(std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs) { DeleteInputValues(); inputs_.clear(); for (const auto& p : inputs) { @@ -811,6 +850,11 @@ class CSession { } } + void SetOutputs(const std::vector<TF_Output>& outputs) { + ResetOutputValues(); + outputs_ = outputs; + } + void SetTargets(std::initializer_list<TF_Operation*> targets) { targets_.clear(); for (TF_Operation* t : targets) { @@ -1068,6 +1112,308 @@ TEST(CAPI, SavedModelNullArgsAreValid) { TF_DeleteStatus(s); } +class CApiWhileLoopTest : public ::testing::Test { + protected: + CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {} + + ~CApiWhileLoopTest() override { + TF_DeleteGraph(graph_); + TF_DeleteStatus(s_); + } + + void Init(int ninputs) { + DCHECK(inputs_.empty()); + DCHECK_GT(ninputs, 0); + + for (int i = 0; i < ninputs; ++i) { + TF_Operation* placeholder = Placeholder( + graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str()); + DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inputs_.push_back({placeholder, 0}); + } + + original_graph_description_ = GraphDebugString(); + + params_.reset(new TF_WhileParams( + TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_))); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + ASSERT_EQ(original_graph_description_, GraphDebugString()) + << "TF_NewWhile() altered graph"; + + params_->name = "test_loop"; + + // Initialize outputs_ so we can easily detect errors/bugs + outputs_.resize(ninputs, {nullptr, -1}); + } + + void ExpectOK() { + TF_FinishWhile(params_.get(), s_, &outputs_[0]); + EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + void ExpectError(TF_Code expected_code, const string& expected_msg) { + TF_FinishWhile(params_.get(), s_, &outputs_[0]); + EXPECT_EQ(expected_code, TF_GetCode(s_)); + EXPECT_EQ(expected_msg, TF_Message(s_)); + // TODO(skyewm): this assert is currently broken. Fix or remove guarantee. + // ASSERT_EQ(original_graph_description_, GraphDebugString()) << + // "TF_FinishWhile() altered graph on error"; + } + + void Run(std::initializer_list<int> input_values) { + DCHECK_EQ(inputs_.size(), input_values.size()); + std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size()); + int i = 0; + for (int v : input_values) { + inputs[i] = {inputs_[i].oper, Int32Tensor(v)}; + ++i; + } + csession_.reset(new CSession(graph_, s_)); + csession_->SetInputs(inputs); + csession_->SetOutputs(outputs_); + csession_->Run(s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + } + + void ExpectOutputValue(int idx, int expected_value) { + TF_Tensor* out = csession_->output_tensor(idx); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); + ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); + int32* data = static_cast<int32*>(TF_TensorData(out)); + EXPECT_EQ(expected_value, *data); + } + + // Create a valid conditonal graph. Useful for testing unrelated errors. + void CreateCondGraph() { + TF_Operation* one = ScalarConst(1, params_->cond_graph, s_); + TF_Operation* less_than = + LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_); + DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->cond_output = {less_than, 0}; + } + + string GraphDebugString() const { + TF_Buffer* buf = TF_NewBuffer(); + TF_GraphToGraphDef(graph_, buf, s_); + DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + GraphDef def; + bool success = def.ParseFromArray(buf->data, buf->length); + DCHECK(success); + TF_DeleteBuffer(buf); + return def.DebugString(); + } + + TF_Status* s_; + TF_Graph* graph_; + std::vector<TF_Output> inputs_; // The inputs to the while loop + std::vector<TF_Output> outputs_; // The final outputs of the while loop + std::unique_ptr<TF_WhileParams> params_; + std::unique_ptr<CSession> csession_; + + private: + // Used to verify that errors don't change graph_ + string original_graph_description_; +}; + +TEST_F(CApiWhileLoopTest, BasicLoop) { + Init(2); + + // Validate TF_WhileParams returned by TF_NewWhile() + EXPECT_TRUE(params_->body_graph != nullptr); + EXPECT_TRUE(params_->cond_graph != nullptr); + + EXPECT_EQ(params_->ninputs, 2); + + ASSERT_TRUE(params_->cond_inputs != nullptr); + ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr); + EXPECT_TRUE(params_->cond_inputs[1].oper != nullptr); + + ASSERT_TRUE(params_->body_inputs != nullptr); + EXPECT_TRUE(params_->body_inputs[0].oper != nullptr); + EXPECT_TRUE(params_->body_inputs[1].oper != nullptr); + + ASSERT_TRUE(params_->body_outputs != nullptr); + + // Create loop: while (input1 < input2) input1 += input2 + 1 + TF_Operation* less_than = + LessThan(params_->cond_inputs[0], params_->cond_inputs[1], + params_->cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->cond_output = {less_than, 0}; + + TF_Operation* add1 = Add(params_->body_inputs[0], params_->body_inputs[1], + params_->body_graph, s_, "add1"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* one = ScalarConst(1, params_->body_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* add2 = Add(add1, one, params_->body_graph, s_, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->body_outputs[0] = {add2, 0}; + params_->body_outputs[1] = params_->body_inputs[1]; + + // Finalize while loop + ExpectOK(); + + // Validate while loop outputs returned by TF_FinishWhile() + EXPECT_TRUE(outputs_[0].oper != nullptr); + EXPECT_GE(outputs_[0].index, 0); + EXPECT_TRUE(outputs_[1].oper != nullptr); + EXPECT_GE(outputs_[1].index, 0); + + // Run the graph + Run({-9, 2}); + ExpectOutputValue(0, 3); + ExpectOutputValue(1, 2); +} + +TEST_F(CApiWhileLoopTest, NestedLoop) { + Init(2); + // Create nested loop: + // while (input1 < 6) { + // inner_input1 = input1 + // while (inner_input1 < 3) { + // input2 += 1 + // inner_input1 += 2 + // } + // input1 += input2 + // } + // + // Expected execution with initial values input1 = input2 = 0: + // + // outer inner inner_ + // step# step# input1 input2 input1 + // ------------------------------------ + // 0 0 0 0 0 + // 0 1 0 1 2 + // 0 2 0 2 4 + // 0 - 2 2 - + // 1 0 2 2 2 + // 1 1 2 3 4 + // 1 - 5 3 - + // 2 0 5 3 5 + // 2 - 8 3 - + + // Create outer cond graph + TF_Operation* six = ScalarConst(6, params_->cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* less_than = + LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->cond_output = {less_than, 0}; + + // Create outer body graph + // Init inner graph + TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]}; + TF_WhileParams inner_params = + TF_NewWhile(params_->body_graph, inner_inputs, 2, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inner_params.name = "inner_loop"; + + // Create inner cond graph + TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* inner_less_than = LessThan( + inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inner_params.cond_output = {inner_less_than, 0}; + + // Create inner body graph + TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* input2_add = + Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inner_params.body_outputs[1] = {input2_add, 0}; + + TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two, + inner_params.body_graph, s_, "add2"); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + inner_params.body_outputs[0] = {inner_input1_add, 0}; + + // Finalize inner graph + TF_Output inner_outputs[2] = {{nullptr, -1}}; + TF_FinishWhile(&inner_params, s_, inner_outputs); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + TF_Operation* input1_add = + Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->body_outputs[0] = {input1_add, 0}; + + params_->body_outputs[1] = inner_outputs[1]; + + // Finalize outer graph + ExpectOK(); + + // Check for a few expected nodes + const char* node_name = "test_loop/cond/scalar"; + EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); + node_name = "test_loop/body/add"; + EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); + node_name = "test_loop/body/inner_loop/body/one"; + EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); + node_name = "test_loop/body/inner_loop/cond/less_than"; + EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr); + + // Run the graph + Run({0, 0}); + ExpectOutputValue(0, 8); + ExpectOutputValue(1, 3); +} + +TEST_F(CApiWhileLoopTest, BadCondOutput) { + Init(1); + params_->body_outputs[0] = params_->body_inputs[0]; + ExpectError(TF_INVALID_ARGUMENT, + "TF_WhileParams `cond_output` field isn't set"); +} + +TEST_F(CApiWhileLoopTest, BadBodyOutput) { + Init(1); + CreateCondGraph(); + ExpectError(TF_INVALID_ARGUMENT, + "TF_WhileParams `body_outputs[0]` field isn't set"); +} + +TEST_F(CApiWhileLoopTest, NullName) { + Init(1); + CreateCondGraph(); + params_->body_outputs[0] = params_->body_inputs[0]; + params_->name = nullptr; + ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `name` field is null"); +} + +TEST_F(CApiWhileLoopTest, WrongGraph) { + Init(1); + CreateCondGraph(); + // Set body output to output from outer graph + params_->body_outputs[0] = inputs_[0]; + // TODO(skyewm): improve error message + ExpectError(TF_INVALID_ARGUMENT, + "Requested return node 'p0' not found in graph def"); +} + +TEST_F(CApiWhileLoopTest, BadTypes) { + Init(1); + CreateCondGraph(); + // Op that has a float input + output + TF_OperationDescription* desc = TF_NewOperation( + params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op"); + TF_AddInput(desc, params_->body_inputs[0]); + TF_FinishOperation(desc, s_); + ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + string msg(TF_Message(s_)); + EXPECT_NE(msg.find("Input 'inputs' passed int32 expected float while " + "building NodeDef 'float_op'"), + msg.npos); + TF_AbortWhile(params_.get()); +} + // Create a tensor with values of type TF_INT8 provided by `values`. TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims, const char* values) { int64_t num_values = 1; |