aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_test.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2017-07-17 08:29:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-17 08:33:52 -0700
commitbed7cbe389420abe760d86cd345e88d45eb624ab (patch)
treec4d5c9d789b06e7596b6473902c140638fbb5de7 /tensorflow/c/c_api_test.cc
parenta925c14596135b19bedb73e0f6ae3cd170180106 (diff)
Split out new while_loop_test.cc from c_api_test.cc
This change also separates shared functionality into c_test_util.h/cc. This brings c_api_test.cc to a mere 1715 LOC (further splits can be more easily done now too). PiperOrigin-RevId: 162216399
Diffstat (limited to 'tensorflow/c/c_api_test.cc')
-rw-r--r--tensorflow/c/c_api_test.cc604
1 files changed, 1 insertions, 603 deletions
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 123ed7aeca..41f7622668 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <iterator>
#include <memory>
#include <vector>
+#include "tensorflow/c/c_test_util.h"
#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/example/example.pb.h"
@@ -56,9 +57,6 @@ TF_Tensor* TF_Tensor_EncodeStrings(const Tensor& src);
namespace {
-typedef std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)>
- unique_tensor_ptr;
-
TEST(CAPI, Version) { EXPECT_NE("", string(TF_Version())); }
TEST(CAPI, Status) {
@@ -276,194 +274,6 @@ TEST(CAPI, GetAllOpList) {
TF_DeleteBuffer(buf);
}
-static void Int32Deallocator(void* data, size_t, void* arg) {
- delete[] static_cast<int32*>(data);
-}
-
-// Create a tensor with values of type TF_INT8 provided by `values`.
-static TF_Tensor* Int8Tensor(const int64_t* dims, int num_dims,
- const char* values) {
- int64_t num_values = 1;
- for (int i = 0; i < num_dims; ++i) {
- num_values *= dims[i];
- }
- TF_Tensor* t =
- TF_AllocateTensor(TF_INT8, dims, num_dims, sizeof(char) * num_values);
- memcpy(TF_TensorData(t), values, sizeof(char) * num_values);
- return t;
-}
-
-static TF_Tensor* Int32Tensor(int32 v) {
- const int num_bytes = sizeof(int32);
- int32* values = new int32[1];
- values[0] = v;
- return TF_NewTensor(TF_INT32, nullptr, 0, values, num_bytes,
- &Int32Deallocator, nullptr);
-}
-
-TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s,
- const char* name = "feed") {
- TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
- TF_SetAttrType(desc, "dtype", TF_INT32);
- return TF_FinishOperation(desc, s);
-}
-
-TF_Operation* Const(TF_Tensor* t, TF_Graph* graph, TF_Status* s,
- const char* name = "const") {
- TF_OperationDescription* desc = TF_NewOperation(graph, "Const", name);
- TF_SetAttrTensor(desc, "value", t, s);
- if (TF_GetCode(s) != TF_OK) return nullptr;
- TF_SetAttrType(desc, "dtype", TF_TensorType(t));
- return TF_FinishOperation(desc, s);
-}
-
-TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s,
- const char* name = "scalar") {
- unique_tensor_ptr tensor(Int32Tensor(v), TF_DeleteTensor);
- return Const(tensor.get(), graph, s, name);
-}
-
-TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
- TF_Status* s, const char* name = "add") {
- TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
- TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
- TF_AddInputList(desc, add_inputs, 2);
- return TF_FinishOperation(desc, s);
-}
-
-TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s,
- const char* name = "add") {
- TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
- TF_Output inputs[2] = {l, r};
- TF_AddInputList(desc, inputs, 2);
- return TF_FinishOperation(desc, s);
-}
-
-TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
- TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg");
- TF_Output neg_input = {n, 0};
- TF_AddInput(desc, neg_input);
- return TF_FinishOperation(desc, s);
-}
-
-TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph,
- TF_Status* s) {
- TF_OperationDescription* desc = TF_NewOperation(graph, "Less", "less_than");
- TF_AddInput(desc, l);
- TF_AddInput(desc, r);
- return TF_FinishOperation(desc, s);
-}
-
-bool IsPlaceholder(const NodeDef& node_def) {
- if (node_def.op() != "Placeholder" || node_def.name() != "feed") {
- return false;
- }
- bool found_dtype = false;
- bool found_shape = false;
- for (const auto& attr : node_def.attr()) {
- if (attr.first == "dtype") {
- if (attr.second.type() == tensorflow::DT_INT32) {
- found_dtype = true;
- } else {
- return false;
- }
- } else if (attr.first == "shape") {
- found_shape = true;
- }
- }
- return found_dtype && found_shape;
-}
-
-bool IsScalarConst(const NodeDef& node_def, int v) {
- if (node_def.op() != "Const" || node_def.name() != "scalar") {
- return false;
- }
- bool found_dtype = false;
- bool found_value = false;
- for (const auto& attr : node_def.attr()) {
- if (attr.first == "dtype") {
- if (attr.second.type() == tensorflow::DT_INT32) {
- found_dtype = true;
- } else {
- return false;
- }
- } else if (attr.first == "value") {
- if (attr.second.has_tensor() &&
- attr.second.tensor().int_val_size() == 1 &&
- attr.second.tensor().int_val(0) == v) {
- found_value = true;
- } else {
- return false;
- }
- }
- }
- return found_dtype && found_value;
-}
-
-bool IsAddN(const NodeDef& node_def, int n) {
- if (node_def.op() != "AddN" || node_def.name() != "add" ||
- node_def.input_size() != n) {
- return false;
- }
- bool found_t = false;
- bool found_n = false;
- for (const auto& attr : node_def.attr()) {
- if (attr.first == "T") {
- if (attr.second.type() == tensorflow::DT_INT32) {
- found_t = true;
- } else {
- return false;
- }
- } else if (attr.first == "N") {
- if (attr.second.i() == n) {
- found_n = true;
- } else {
- return false;
- }
- }
- }
- return found_t && found_n;
-}
-
-bool IsNeg(const NodeDef& node_def, const string& input) {
- return node_def.op() == "Neg" && node_def.name() == "neg" &&
- node_def.input_size() == 1 && node_def.input(0) == input;
-}
-
-bool GetGraphDef(TF_Graph* graph, GraphDef* graph_def) {
- TF_Status* s = TF_NewStatus();
- TF_Buffer* buffer = TF_NewBuffer();
- TF_GraphToGraphDef(graph, buffer, s);
- bool ret = TF_GetCode(s) == TF_OK;
- EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
- TF_DeleteBuffer(buffer);
- TF_DeleteStatus(s);
- return ret;
-}
-
-bool GetNodeDef(TF_Operation* oper, NodeDef* node_def) {
- TF_Status* s = TF_NewStatus();
- TF_Buffer* buffer = TF_NewBuffer();
- TF_OperationToNodeDef(oper, buffer, s);
- bool ret = TF_GetCode(s) == TF_OK;
- EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- if (ret) ret = node_def->ParseFromArray(buffer->data, buffer->length);
- TF_DeleteBuffer(buffer);
- TF_DeleteStatus(s);
- return ret;
-}
-
-bool GetAttrValue(TF_Operation* oper, const char* attr_name,
- tensorflow::AttrValue* attr_value, TF_Status* s) {
- TF_Buffer* buffer = TF_NewBuffer();
- TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
- bool ret = TF_GetCode(s) == TF_OK;
- if (ret) ret = attr_value->ParseFromArray(buffer->data, buffer->length);
- TF_DeleteBuffer(buffer);
- return ret;
-}
-
TEST(CAPI, SetShape) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
@@ -880,116 +690,6 @@ TEST(CAPI, ImportGraphDef) {
TF_DeleteStatus(s);
}
-class CSession {
- public:
- CSession(TF_Graph* graph, TF_Status* s) {
- TF_SessionOptions* opts = TF_NewSessionOptions();
- session_ = TF_NewSession(graph, opts, s);
- TF_DeleteSessionOptions(opts);
- }
-
- explicit CSession(TF_Session* session) : session_(session) {}
-
- ~CSession() {
- TF_Status* s = TF_NewStatus();
- CloseAndDelete(s);
- EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_DeleteStatus(s);
- }
-
- void SetInputs(std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs) {
- DeleteInputValues();
- inputs_.clear();
- for (const auto& p : inputs) {
- inputs_.emplace_back(TF_Output{p.first, 0});
- input_values_.emplace_back(p.second);
- }
- }
-
- void SetOutputs(std::initializer_list<TF_Operation*> outputs) {
- ResetOutputValues();
- outputs_.clear();
- for (TF_Operation* o : outputs) {
- outputs_.emplace_back(TF_Output{o, 0});
- }
- output_values_.resize(outputs_.size());
- }
-
- void SetOutputs(const std::vector<TF_Output>& outputs) {
- ResetOutputValues();
- outputs_ = outputs;
- output_values_.resize(outputs_.size());
- }
-
- void SetTargets(std::initializer_list<TF_Operation*> targets) {
- targets_.clear();
- for (TF_Operation* t : targets) {
- targets_.emplace_back(t);
- }
- }
-
- void Run(TF_Status* s) {
- if (inputs_.size() != input_values_.size()) {
- ADD_FAILURE() << "Call SetInputs() before Run()";
- return;
- }
- ResetOutputValues();
- output_values_.resize(outputs_.size(), nullptr);
-
- const TF_Output* inputs_ptr = inputs_.empty() ? nullptr : &inputs_[0];
- TF_Tensor* const* input_values_ptr =
- input_values_.empty() ? nullptr : &input_values_[0];
-
- const TF_Output* outputs_ptr = outputs_.empty() ? nullptr : &outputs_[0];
- TF_Tensor** output_values_ptr =
- output_values_.empty() ? nullptr : &output_values_[0];
-
- TF_Operation* const* targets_ptr =
- targets_.empty() ? nullptr : &targets_[0];
-
- TF_SessionRun(session_, nullptr, inputs_ptr, input_values_ptr,
- inputs_.size(), outputs_ptr, output_values_ptr,
- outputs_.size(), targets_ptr, targets_.size(), nullptr, s);
-
- DeleteInputValues();
- }
-
- void CloseAndDelete(TF_Status* s) {
- DeleteInputValues();
- ResetOutputValues();
- if (session_ != nullptr) {
- TF_CloseSession(session_, s);
- EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_DeleteSession(session_, s);
- session_ = nullptr;
- }
- }
-
- TF_Tensor* output_tensor(int i) { return output_values_[i]; }
-
- private:
- void DeleteInputValues() {
- for (int i = 0; i < input_values_.size(); ++i) {
- TF_DeleteTensor(input_values_[i]);
- }
- input_values_.clear();
- }
-
- void ResetOutputValues() {
- for (int i = 0; i < output_values_.size(); ++i) {
- if (output_values_[i] != nullptr) TF_DeleteTensor(output_values_[i]);
- }
- output_values_.clear();
- }
-
- TF_Session* session_;
- std::vector<TF_Output> inputs_;
- std::vector<TF_Tensor*> input_values_;
- std::vector<TF_Output> outputs_;
- std::vector<TF_Tensor*> output_values_;
- std::vector<TF_Operation*> targets_;
-};
-
TEST(CAPI, Session) {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
@@ -1275,308 +975,6 @@ TEST(CAPI, SavedModelNullArgsAreValid) {
TF_DeleteStatus(s);
}
-class CApiWhileLoopTest : public ::testing::Test {
- protected:
- CApiWhileLoopTest() : s_(TF_NewStatus()), graph_(TF_NewGraph()) {}
-
- ~CApiWhileLoopTest() override {
- TF_DeleteGraph(graph_);
- TF_DeleteStatus(s_);
- }
-
- void Init(int ninputs) {
- DCHECK(inputs_.empty());
- DCHECK_GT(ninputs, 0);
-
- for (int i = 0; i < ninputs; ++i) {
- TF_Operation* placeholder = Placeholder(
- graph_, s_, ::tensorflow::strings::StrCat("p", i).c_str());
- DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- inputs_.push_back({placeholder, 0});
- }
-
- original_graph_description_ = GraphDebugString();
-
- params_.reset(new TF_WhileParams(
- TF_NewWhile(graph_, &inputs_[0], inputs_.size(), s_)));
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- ASSERT_EQ(original_graph_description_, GraphDebugString())
- << "TF_NewWhile() altered graph";
-
- params_->name = "test_loop";
-
- // Initialize outputs_ so we can easily detect errors/bugs
- outputs_.resize(ninputs, {nullptr, -1});
- }
-
- void ExpectOK() {
- TF_FinishWhile(params_.get(), s_, &outputs_[0]);
- EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- }
-
- void ExpectError(TF_Code expected_code, const string& expected_msg) {
- TF_FinishWhile(params_.get(), s_, &outputs_[0]);
- EXPECT_EQ(expected_code, TF_GetCode(s_));
- EXPECT_EQ(expected_msg, TF_Message(s_));
- // TODO(skyewm): this assert is currently broken. Fix or remove guarantee.
- // ASSERT_EQ(original_graph_description_, GraphDebugString()) <<
- // "TF_FinishWhile() altered graph on error";
- }
-
- void Run(std::initializer_list<int> input_values) {
- DCHECK_EQ(inputs_.size(), input_values.size());
- std::vector<std::pair<TF_Operation*, TF_Tensor*>> inputs(inputs_.size());
- int i = 0;
- for (int v : input_values) {
- inputs[i] = {inputs_[i].oper, Int32Tensor(v)};
- ++i;
- }
- csession_.reset(new CSession(graph_, s_));
- csession_->SetInputs(inputs);
- csession_->SetOutputs(outputs_);
- csession_->Run(s_);
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- }
-
- void ExpectOutputValue(int idx, int expected_value) {
- TF_Tensor* out = csession_->output_tensor(idx);
- ASSERT_TRUE(out != nullptr);
- EXPECT_EQ(TF_INT32, TF_TensorType(out));
- EXPECT_EQ(0, TF_NumDims(out));
- ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
- int32* data = static_cast<int32*>(TF_TensorData(out));
- EXPECT_EQ(expected_value, *data);
- }
-
- // Create a valid conditional graph. Useful for testing unrelated errors.
- void CreateCondGraph() {
- TF_Operation* one = ScalarConst(1, params_->cond_graph, s_);
- TF_Operation* less_than =
- LessThan(params_->cond_inputs[0], {one, 0}, params_->cond_graph, s_);
- DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- params_->cond_output = {less_than, 0};
- }
-
- string GraphDebugString() const {
- TF_Buffer* buf = TF_NewBuffer();
- TF_GraphToGraphDef(graph_, buf, s_);
- DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- GraphDef def;
- bool success = def.ParseFromArray(buf->data, buf->length);
- DCHECK(success);
- TF_DeleteBuffer(buf);
- return def.DebugString();
- }
-
- TF_Status* s_;
- TF_Graph* graph_;
- std::vector<TF_Output> inputs_; // The inputs to the while loop
- std::vector<TF_Output> outputs_; // The final outputs of the while loop
- std::unique_ptr<TF_WhileParams> params_;
- std::unique_ptr<CSession> csession_;
-
- private:
- // Used to verify that errors don't change graph_
- string original_graph_description_;
-};
-
-TEST_F(CApiWhileLoopTest, BasicLoop) {
- Init(2);
-
- // Validate TF_WhileParams returned by TF_NewWhile()
- EXPECT_TRUE(params_->body_graph != nullptr);
- EXPECT_TRUE(params_->cond_graph != nullptr);
-
- EXPECT_EQ(params_->ninputs, 2);
-
- ASSERT_TRUE(params_->cond_inputs != nullptr);
- ASSERT_TRUE(params_->cond_inputs[0].oper != nullptr);
- EXPECT_TRUE(params_->cond_inputs[1].oper != nullptr);
-
- ASSERT_TRUE(params_->body_inputs != nullptr);
- EXPECT_TRUE(params_->body_inputs[0].oper != nullptr);
- EXPECT_TRUE(params_->body_inputs[1].oper != nullptr);
-
- ASSERT_TRUE(params_->body_outputs != nullptr);
-
- // Create loop: while (input1 < input2) input1 += input2 + 1
- TF_Operation* less_than =
- LessThan(params_->cond_inputs[0], params_->cond_inputs[1],
- params_->cond_graph, s_);
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- params_->cond_output = {less_than, 0};
-
- TF_Operation* add1 = Add(params_->body_inputs[0], params_->body_inputs[1],
- params_->body_graph, s_, "add1");
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- TF_Operation* one = ScalarConst(1, params_->body_graph, s_);
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- TF_Operation* add2 = Add(add1, one, params_->body_graph, s_, "add2");
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- params_->body_outputs[0] = {add2, 0};
- params_->body_outputs[1] = params_->body_inputs[1];
-
- // Finalize while loop
- ExpectOK();
-
- // Validate while loop outputs returned by TF_FinishWhile()
- EXPECT_TRUE(outputs_[0].oper != nullptr);
- EXPECT_GE(outputs_[0].index, 0);
- EXPECT_TRUE(outputs_[1].oper != nullptr);
- EXPECT_GE(outputs_[1].index, 0);
-
- // Run the graph
- Run({-9, 2});
- ExpectOutputValue(0, 3);
- ExpectOutputValue(1, 2);
-}
-
-TEST_F(CApiWhileLoopTest, NestedLoop) {
- Init(2);
- // Create nested loop:
- // while (input1 < 6) {
- // inner_input1 = input1
- // while (inner_input1 < 3) {
- // input2 += 1
- // inner_input1 += 2
- // }
- // input1 += input2
- // }
- //
- // Expected execution with initial values input1 = input2 = 0:
- //
- // outer inner inner_
- // step# step# input1 input2 input1
- // ------------------------------------
- // 0 0 0 0 0
- // 0 1 0 1 2
- // 0 2 0 2 4
- // 0 - 2 2 -
- // 1 0 2 2 2
- // 1 1 2 3 4
- // 1 - 5 3 -
- // 2 0 5 3 5
- // 2 - 8 3 -
-
- // Create outer cond graph
- TF_Operation* six = ScalarConst(6, params_->cond_graph, s_);
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- TF_Operation* less_than =
- LessThan(params_->cond_inputs[0], {six, 0}, params_->cond_graph, s_);
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- params_->cond_output = {less_than, 0};
-
- // Create outer body graph
- // Init inner graph
- TF_Output inner_inputs[] = {params_->body_inputs[0], params_->body_inputs[1]};
- TF_WhileParams inner_params =
- TF_NewWhile(params_->body_graph, inner_inputs, 2, s_);
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- inner_params.name = "inner_loop";
-
- // Create inner cond graph
- TF_Operation* three = ScalarConst(3, inner_params.cond_graph, s_);
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- TF_Operation* inner_less_than = LessThan(
- inner_params.cond_inputs[0], {three, 0}, inner_params.cond_graph, s_);
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- inner_params.cond_output = {inner_less_than, 0};
-
- // Create inner body graph
- TF_Operation* one = ScalarConst(1, inner_params.body_graph, s_, "one");
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- TF_Operation* two = ScalarConst(2, inner_params.body_graph, s_, "two");
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-
- TF_Operation* input2_add =
- Add(inner_params.body_inputs[1].oper, one, inner_params.body_graph, s_);
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- inner_params.body_outputs[1] = {input2_add, 0};
-
- TF_Operation* inner_input1_add = Add(inner_params.body_inputs[0].oper, two,
- inner_params.body_graph, s_, "add2");
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- inner_params.body_outputs[0] = {inner_input1_add, 0};
-
- // Finalize inner graph
- TF_Output inner_outputs[2] = {{nullptr, -1}};
- TF_FinishWhile(&inner_params, s_, inner_outputs);
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
-
- TF_Operation* input1_add =
- Add(params_->body_inputs[0], inner_outputs[1], params_->body_graph, s_);
- ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
- params_->body_outputs[0] = {input1_add, 0};
-
- params_->body_outputs[1] = inner_outputs[1];
-
- // Finalize outer graph
- ExpectOK();
-
- // Check for a few expected nodes
- const char* node_name = "test_loop/cond/scalar";
- EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
- node_name = "test_loop/body/add";
- EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
- node_name = "test_loop/body/inner_loop/body/one";
- EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
- node_name = "test_loop/body/inner_loop/cond/less_than";
- EXPECT_TRUE(TF_GraphOperationByName(graph_, node_name) != nullptr);
-
- // Run the graph
- Run({0, 0});
- ExpectOutputValue(0, 8);
- ExpectOutputValue(1, 3);
-}
-
-TEST_F(CApiWhileLoopTest, BadCondOutput) {
- Init(1);
- params_->body_outputs[0] = params_->body_inputs[0];
- ExpectError(TF_INVALID_ARGUMENT,
- "TF_WhileParams `cond_output` field isn't set");
-}
-
-TEST_F(CApiWhileLoopTest, BadBodyOutput) {
- Init(1);
- CreateCondGraph();
- ExpectError(TF_INVALID_ARGUMENT,
- "TF_WhileParams `body_outputs[0]` field isn't set");
-}
-
-TEST_F(CApiWhileLoopTest, NullName) {
- Init(1);
- CreateCondGraph();
- params_->body_outputs[0] = params_->body_inputs[0];
- params_->name = nullptr;
- ExpectError(TF_INVALID_ARGUMENT, "TF_WhileParams `name` field is null");
-}
-
-TEST_F(CApiWhileLoopTest, WrongGraph) {
- Init(1);
- CreateCondGraph();
- // Set body output to output from outer graph
- params_->body_outputs[0] = inputs_[0];
- // TODO(skyewm): improve error message
- ExpectError(TF_INVALID_ARGUMENT,
- "Requested return node 'p0' not found in graph def");
-}
-
-TEST_F(CApiWhileLoopTest, BadTypes) {
- Init(1);
- CreateCondGraph();
- // Op that has a float input + output
- TF_OperationDescription* desc = TF_NewOperation(
- params_->body_graph, "FakeQuantWithMinMaxArgs", "float_op");
- TF_AddInput(desc, params_->body_inputs[0]);
- TF_FinishOperation(desc, s_);
- ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
- string msg(TF_Message(s_));
- EXPECT_NE(msg.find("Input 'inputs' passed int32 expected float while "
- "building NodeDef 'float_op'"),
- msg.npos);
- TF_AbortWhile(params_.get());
-}
-
REGISTER_OP("TestOpWithNoGradient")
.Input("x: T")
.Output("y: T")