diff options
author | Mingsheng Hong <hongm@google.com> | 2018-03-26 10:51:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-26 10:53:46 -0700 |
commit | be917027e37c5e8f21f6ba07f24bdbf072cf6dfd (patch) | |
tree | 3df042f7c47099bac18f446f5a5719ab2424e268 /tensorflow/c/c_api_experimental_test.cc | |
parent | cc6b2ae837e9c0ce3678671ff5bd59f0f8e53e06 (diff) |
Added experimental C APIs to build a stack of dataset + iterator nodes that
reads imagenet TFRecord files.
PiperOrigin-RevId: 190488817
Diffstat (limited to 'tensorflow/c/c_api_experimental_test.cc')
-rw-r--r-- | tensorflow/c/c_api_experimental_test.cc | 84 |
1 files changed, 69 insertions, 15 deletions
diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index 9ddd65f0c5..49d64d18bf 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -15,38 +15,36 @@ limitations under the License. #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/c_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" namespace tensorflow { namespace { -void TestIteratorStack() { +void TestFakeIteratorStack() { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); - TF_Function* dataset_func = nullptr; - - TF_Operation* get_next = - TF_MakeIteratorGetNextWithDatasets(graph, "dummy_path", &dataset_func, s); + TF_Operation* get_next = TF_MakeFakeIteratorGetNextWithDatasets(graph, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - ASSERT_NE(dataset_func, nullptr); - TF_DeleteFunction(dataset_func); - CSession csession(graph, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); // Run the graph. - for (int i = 0; i < 1; ++i) { + const float base_value = 42.0; + for (int i = 0; i < 3; ++i) { csession.SetOutputs({get_next}); csession.Run(s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_Tensor* out = csession.output_tensor(0); ASSERT_TRUE(out != nullptr); - EXPECT_EQ(TF_INT32, TF_TensorType(out)); - EXPECT_EQ(0, TF_NumDims(out)); // scalar - ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); - int32* output_contents = static_cast<int32*>(TF_TensorData(out)); - EXPECT_EQ(1, *output_contents); + ASSERT_EQ(TF_FLOAT, TF_TensorType(out)); + ASSERT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(float), TF_TensorByteSize(out)); + float* output_contents = static_cast<float*>(TF_TensorData(out)); + ASSERT_EQ(base_value + i, *output_contents); } // This should error out since we've exhausted the iterator. @@ -60,7 +58,63 @@ void TestIteratorStack() { TF_DeleteStatus(s); } -TEST(CAPI_EXPERIMENTAL, IteratorGetNext) { TestIteratorStack(); } +TEST(CAPI_EXPERIMENTAL, FakeIteratorGetNext) { TestFakeIteratorStack(); } + +TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + const string file_path = tensorflow::io::JoinPath( + tensorflow::testing::TensorFlowSrcRoot(), "c/testdata/tf_record"); + VLOG(1) << "data file path is " << file_path; + const int batch_size = 64; + TF_Operation* get_next = TF_MakeImagenetIteratorGetNextWithDatasets( + graph, file_path.c_str(), batch_size, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + CSession csession(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Run the graph. + // The two output tensors should look like: + // Tensor("IteratorGetNext:0", shape=(batch_size, 224, 224, 3), dtype=float32) + // Tensor("IteratorGetNext:1", shape=(batch_size, ), dtype=int32) + for (int i = 0; i < 3; ++i) { + LOG(INFO) << "Running iter " << i; + csession.SetOutputs({{get_next, 0}, {get_next, 1}}); + csession.Run(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + { + TF_Tensor* image = csession.output_tensor(0); + ASSERT_TRUE(image != nullptr); + ASSERT_EQ(TF_FLOAT, TF_TensorType(image)); + // Confirm shape is 224 X 224 X 3 + ASSERT_EQ(4, TF_NumDims(image)); + ASSERT_EQ(batch_size, TF_Dim(image, 0)); + ASSERT_EQ(224, TF_Dim(image, 1)); + ASSERT_EQ(224, TF_Dim(image, 2)); + ASSERT_EQ(3, TF_Dim(image, 3)); + ASSERT_EQ(sizeof(float) * batch_size * 224 * 224 * 3, + TF_TensorByteSize(image)); + } + + { + TF_Tensor* label = csession.output_tensor(1); + ASSERT_TRUE(label != nullptr); + ASSERT_EQ(TF_INT32, TF_TensorType(label)); + ASSERT_EQ(1, TF_NumDims(label)); + ASSERT_EQ(batch_size, TF_Dim(label, 0)); + ASSERT_EQ(sizeof(int32) * batch_size, TF_TensorByteSize(label)); + } + } + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} } // namespace } // namespace tensorflow |