/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <initializer_list>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"

using tensorflow::GraphDef;

namespace {

// Fixture for the C API while-loop builders (TF_NewWhile / TF_FinishWhile /
// TF_AbortWhile). A test calls Init() to create placeholder inputs in graph_
// and a TF_WhileParams over them, fills in the cond/body graphs through
// params_, then finalizes with ExpectOK() or ExpectError().
class CApiWhileLoopTest : public ::testing::Test {
 protected:
  CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}

  ~CApiWhileLoopTest() override {
    // Data members are destroyed only after this body runs, so without this
    // reset the session would be torn down after graph_ is deleted. Destroy
    // it explicitly first.
    csession_.reset();
    TF_DeleteGraph(graph_);
    TF_DeleteStatus(s_);
  }

  // Creates `ninputs` placeholders named "p0", "p1", ... in graph_, starts a
  // while loop named "test_loop" over them, and checks that TF_NewWhile() did
  // not modify graph_. Intended to be called once per test.
  void Init(int ninputs) {
    DCHECK(inputs_.empty());
    DCHECK_GT(ninputs, 0);

    for (int i = 0; i < ninputs; ++i) {
      TF_Operation* placeholder = Placeholder(
          graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str());
      // ASSERT rather than DCHECK so the failure is reported in opt builds.
      ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
      inputs_.push_back({placeholder, 0});
    }

    original_graph_description_ = GraphDebugString();

    params_.reset(new TF_WhileParams(TF_NewWhile(
        graph_, inputs_.data(), static_cast<int>(inputs_.size()), s_)));
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    ASSERT_EQ(original_graph_description_, GraphDebugString())
        << "TF_NewWhile() altered graph";

    params_->name = "test_loop";

    // Initialize outputs_ so we can easily detect errors/bugs
    outputs_.resize(ninputs, {nullptr, -1});
  }

  // Finalizes the loop, writing its outputs into outputs_, and expects OK.
  void ExpectOK() {
    TF_FinishWhile(params_.get(), s_, outputs_.data());
    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  }

  // Finalizes the loop and expects exactly `expected_code` / `expected_msg`.
  void ExpectError(TF_Code expected_code, const string& expected_msg) {
    TF_FinishWhile(params_.get(), s_, outputs_.data());
    EXPECT_EQ(expected_code, TF_GetCode(s_));
    EXPECT_EQ(expected_msg, TF_Message(s_));
    // TODO(skyewm): this assert is currently broken. Fix or remove guarantee.
    // ASSERT_EQ(original_graph_description_, GraphDebugString()) <<
    //           "TF_FinishWhile() altered graph on error";
  }

  // Runs graph_, feeding inputs_[i] with the i-th value of `input_values`
  // (as an int32 scalar), and fetches the loop outputs (outputs_).
  void Run(std::initializer_list<int> input_values) {
    Run(outputs_, input_values);
  }

  // Same as above, but fetches `run_outputs` instead of outputs_.
  void Run(const std::vector<TF_Output>& run_outputs,
           std::initializer_list<int> input_values) {
    // ASSERT (not DCHECK): with a size mismatch the loop below would write
    // past the end of `inputs` in builds where DCHECK is compiled out.
    ASSERT_EQ(inputs_.size(), input_values.size());
    std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size());
    int i = 0;
    for (int v : input_values) {
      inputs[i] = {inputs_[i].oper, Int32Tensor(v)};
      ++i;
    }
    // TODO(skyewm): use std::make_unique or absl::make_unique when possible.
    csession_.reset(new CSession(graph_, s_));
    csession_->SetInputs(inputs);
    csession_->SetOutputs(run_outputs);
    csession_->Run(s_);
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  }

  // Expects fetched output `idx` of the last Run() to be the int32 scalar
  // `expected_value`.
  void ExpectOutputValue(int idx, int expected_value) {
    TF_Tensor* out = csession_->output_tensor(idx);
    ASSERT_TRUE(out != nullptr);
    EXPECT_EQ(TF_INT32, TF_TensorType(out));
    EXPECT_EQ(0, TF_NumDims(out));
    ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out));
    int32_t* data = static_cast<int32_t*>(TF_TensorData(out));
    EXPECT_EQ(expected_value, *data);
  }

  // Create a valid conditional graph. Useful for testing unrelated errors.
  // Builds `cond_inputs[0] < 1` and sets it as params_->cond_output.
  void CreateCondGraph() {
    TF_Operation* one = ScalarConst(1, params_->cond_graph, s_);
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    TF_Operation* less_than =
        LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_);
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    params_->cond_output = {less_than, 0};
  }

  // Returns the DebugString() of graph_'s GraphDef. Used to detect whether
  // building a while loop modified graph_.
  string GraphDebugString() const {
    TF_Buffer* buf = TF_NewBuffer();
    TF_GraphToGraphDef(graph_, buf, s_);
    // EXPECT (ASSERT needs a void function): reported in all build modes.
    EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    GraphDef def;
    bool success = def.ParseFromArray(buf->data, buf->length);
    EXPECT_TRUE(success);
    TF_DeleteBuffer(buf);
    return def.DebugString();
  }

  TF_Status* s_;
  TF_Graph* graph_;
  std::vector<TF_Output> inputs_;   // The inputs to the while loop
  std::vector<TF_Output> outputs_;  // The final outputs of the while loop
  std::unique_ptr<TF_WhileParams> params_;
  std::unique_ptr<CSession> csession_;

 private:
  // Used to verify that errors don't change graph_
  string original_graph_description_;
};

TEST_F(CApiWhileLoopTest, BasicLoop) {
  Init(2);

  // Validate TF_WhileParams returned by TF_NewWhile()
  EXPECT_TRUE(params_->body_graph != nullptr);
  EXPECT_TRUE(params_->cond_graph != nullptr);

  EXPECT_EQ(params_->ninputs, 2);

  ASSERT_TRUE(params_->cond_inputs != nullptr);
  ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr);
  EXPECT_TRUE(params_->cond_inputs[1].oper != nullptr);

  ASSERT_TRUE(params_->body_inputs != nullptr);
  EXPECT_TRUE(params_->body_inputs[0].oper != nullptr);
  EXPECT_TRUE(params_->body_inputs[1].oper != nullptr);

  ASSERT_TRUE(params_->body_outputs != nullptr);

  // Create loop: while (input1 < input2) input1 += input2 + 1
  TF_Operation* less_than =
      LessThan(params_->cond_inputs[0], params_->cond_inputs[1],
               params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->cond_output = {less_than, 0};

  TF_Operation* add1 = Add(params_->body_inputs[0], params_->body_inputs[1],
                           params_->body_graph, s_, "add1");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* add2 = Add(add1, one, params_->body_graph, s_, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->body_outputs[0] = {add2, 0};
  params_->body_outputs[1] = params_->body_inputs[1];

  // Finalize while loop
  ExpectOK();

  // Validate while loop outputs returned by TF_FinishWhile()
  EXPECT_TRUE(outputs_[0].oper != nullptr);
  EXPECT_GE(outputs_[0].index, 0);
  EXPECT_TRUE(outputs_[1].oper != nullptr);
  EXPECT_GE(outputs_[1].index, 0);

  // Check that cond and body inputs are not present
  for (int i = 0; i < params_->ninputs; ++i) {
    string cond_name =
        ::tensorflow::strings::StrCat(params_->name, "/cond/cond_input", i);
    string body_name =
        ::tensorflow::strings::StrCat(params_->name, "/body/body_input", i);
    EXPECT_TRUE(TF_GraphOperationByName(graph_, cond_name.c_str()) == nullptr);
    EXPECT_TRUE(TF_GraphOperationByName(graph_, body_name.c_str()) == nullptr);
  }

  // Run the graph
  Run({-9, 2});
  ExpectOutputValue(0, 3);
  ExpectOutputValue(1, 2);
}

TEST_F(CApiWhileLoopTest, NestedLoop) {
  Init(2);
  // Create nested loop:
  //  while (input1 < 6) {
  //    inner_input1 = input1
  //    while (inner_input1 < 3) {
  //      input2 += 1
  //      inner_input1 += 2
  //    }
  //    input1 += input2
  //  }
  //
  // Expected execution with initial values input1 = input2 = 0:
  //
  // outer  inner                inner_
  // step#  step#  input1 input2 input1
  // ------------------------------------
  //   0      0      0      0      0
  //   0      1      0      1      2
  //   0      2      0      2      4
  //   0      -      2      2      -
  //   1      0      2      2      2
  //   1      1      2      3      4
  //   1      -      5      3      -
  //   2      0      5      3      5
  //   2      -      8      3      -

  // Create outer cond graph
  TF_Operation* six = ScalarConst(6, params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* less_than =
      LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->cond_output = {less_than, 0};

  // Create outer body graph
  // Init inner graph
  TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]};
  TF_WhileParams inner_params =
      TF_NewWhile(params_->body_graph, inner_inputs, 2, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.name = "inner_loop";

  // Create inner cond graph
  TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* inner_less_than = LessThan(
      inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.cond_output = {inner_less_than, 0};

  // Create inner body graph
  TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* input2_add =
      Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.body_outputs[1] = {input2_add, 0};

  TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two,
                                       inner_params.body_graph, s_, "add2");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  inner_params.body_outputs[0] = {inner_input1_add, 0};

  // Finalize inner graph
  TF_Output inner_outputs[2] = {{nullptr, -1}};
  TF_FinishWhile(&inner_params, s_, inner_outputs);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  TF_Operation* input1_add =
      Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->body_outputs[0] = {input1_add, 0};

  params_->body_outputs[1] = inner_outputs[1];

  // Finalize outer graph
  ExpectOK();

  // Check for a few expected nodes
  const char* node_name = "test_loop/cond/scalar";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
  node_name = "test_loop/body/add";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
  node_name = "test_loop/body/inner_loop/body/one";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
  node_name = "test_loop/body/inner_loop/cond/less_than";
  EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);

  // Run the graph
  Run({0, 0});
  ExpectOutputValue(0, 8);
  ExpectOutputValue(1, 3);
}

TEST_F(CApiWhileLoopTest, UnsetCondOutput) {
  Init(1);
  params_->body_outputs[0] = params_->body_inputs[0];
  ExpectError(TF_INVALID_ARGUMENT,
              "TF_WhileParams `cond_output` field isn't set");
}

TEST_F(CApiWhileLoopTest, WrongCondOutputType) {
  Init(1);
  params_->cond_output = params_->cond_inputs[0];
  params_->body_outputs[0] = params_->body_inputs[0];
  ExpectError(TF_INVALID_ARGUMENT,
              "BuildWhileLoop: 'cond' argument must return a boolean output, "
              "got int32");
}

TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) {
  Init(1);
  // Try to reuse node from parent graph
  params_->cond_output = inputs_[0];
  params_->body_outputs[0] = params_->body_inputs[0];
  // TODO(skyewm): this error message could be more informative. Add explicit
  // checks for this case in the while loop implementation?
  ExpectError(TF_INVALID_ARGUMENT,
              "Requested return tensor 'p0:0' not found in graph def");
}

TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) {
  Init(1);
  CreateCondGraph();
  params_->cond_output.index = 100;
  params_->body_outputs[0] = params_->body_inputs[0];
  ExpectError(TF_INVALID_ARGUMENT,
              "Invalid return output 100 of node 'less_than', which has 1 "
              "output(s)");
}

// TODO(skyewm): test bad cond output shape

TEST_F(CApiWhileLoopTest, UnsetBodyOutput) {
  Init(1);
  CreateCondGraph();
  ExpectError(TF_INVALID_ARGUMENT,
              "TF_WhileParams `body_outputs[0]` field isn't set");
}

// TODO(skyewm): enable this when it works (currently doesn't error)
// TEST_F(CApiWhileLoopTest, WrongBodyOutputType) {
//   Init(1);
//   CreateCondGraph();
//   TF_Operation* double_scalar =
//       ScalarConst(1.0, params_->body_graph, s_, "double_scalar");
//   params_->body_outputs[0] = {double_scalar, 0};
//   ExpectError(TF_INVALID_ARGUMENT, "bad body output type");
// }

TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) {
  Init(1);
  CreateCondGraph();
  // Try to reuse node from parent graph
  params_->body_outputs[0] = inputs_[0];
  // TODO(skyewm): this error message could be more informative. Add explicit
  // checks for this case in the while loop implementation?
  ExpectError(TF_INVALID_ARGUMENT,
              "Requested return tensor 'p0:0' not found in graph def");
}

// TODO(skyewm): enable this when it works (currently segfaults!)
// TEST_F(CApiWhileLoopTest, InvalidBodyOutputIndex) {
//   Init(1);
//   CreateCondGraph();
//   params_->body_outputs[0] = params_->body_inputs[0];
//   params_->body_outputs[0].index = 100;
//   ExpectError(TF_INVALID_ARGUMENT,
//               "Invalid return output 100 of node 'less_than', which has 1 "
//               "output(s)");
// }

// TODO(skyewm): test bad body output shape

TEST_F(CApiWhileLoopTest, NullName) {
  Init(1);
  CreateCondGraph();
  params_->body_outputs[0] = params_->body_inputs[0];
  params_->name = nullptr;
  ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `name` field is null");
}

TEST_F(CApiWhileLoopTest, WrongGraph) {
  Init(1);
  CreateCondGraph();
  // Set body output to output from outer graph
  params_->body_outputs[0] = inputs_[0];
  // TODO(skyewm): improve error message
  ExpectError(TF_INVALID_ARGUMENT,
              "Requested return tensor 'p0:0' not found in graph def");
}

TEST_F(CApiWhileLoopTest, BadTypes) {
  Init(1);
  CreateCondGraph();
  // Op that has a float input + output
  TF_OperationDescription* desc = TF_NewOperation(
      params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op");
  TF_AddInput(desc, params_->body_inputs[0]);
  TF_FinishOperation(desc, s_);
  ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  string msg(TF_Message(s_));
  EXPECT_NE(msg.find("Input 'inputs' passed int32 expected float while "
                     "building NodeDef 'float_op'"),
            msg.npos);
  TF_AbortWhile(params_.get());
}

// This is a basic test to make sure the C++ gradient code can handle while
// loops created by the C API (which calls the C++ API under the hood). There
// are more while loop gradient tests in cc/framework/while_gradients_test.cc.
TEST_F(CApiWhileLoopTest, Gradients) { Init(1); // Create loop: while (i < 10) i += 1 TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_); TF_Operation* less_than = LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_); DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); params_->cond_output = {less_than, 0}; TF_Operation* one = ScalarConst(1, params_->body_graph, s_); TF_Operation* add = Add(params_->body_inputs[0], {one, 0}, params_->body_graph, s_); ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); params_->body_outputs[0] = {add, 0}; ExpectOK(); // Create backprop graph TF_Output grad_output; TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1, nullptr, s_, &grad_output); ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); // Run gradient Run({grad_output}, {0}); ExpectOutputValue(0, 1); } } // namespace