aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/c/c_api_experimental_test.cc
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-03-26 10:51:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-26 10:53:46 -0700
commitbe917027e37c5e8f21f6ba07f24bdbf072cf6dfd (patch)
tree3df042f7c47099bac18f446f5a5719ab2424e268 /tensorflow/c/c_api_experimental_test.cc
parentcc6b2ae837e9c0ce3678671ff5bd59f0f8e53e06 (diff)
Added experimental C APIs to build a stack of dataset + iterator nodes that
reads imagenet TFRecord files. PiperOrigin-RevId: 190488817
Diffstat (limited to 'tensorflow/c/c_api_experimental_test.cc')
-rw-r--r--tensorflow/c/c_api_experimental_test.cc84
1 files changed, 69 insertions, 15 deletions
diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc
index 9ddd65f0c5..49d64d18bf 100644
--- a/tensorflow/c/c_api_experimental_test.cc
+++ b/tensorflow/c/c_api_experimental_test.cc
@@ -15,38 +15,36 @@ limitations under the License.
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/c_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
-void TestIteratorStack() {
+void TestFakeIteratorStack() {
TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph();
- TF_Function* dataset_func = nullptr;
-
- TF_Operation* get_next =
- TF_MakeIteratorGetNextWithDatasets(graph, "dummy_path", &dataset_func, s);
+ TF_Operation* get_next = TF_MakeFakeIteratorGetNextWithDatasets(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- ASSERT_NE(dataset_func, nullptr);
- TF_DeleteFunction(dataset_func);
-
CSession csession(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Run the graph.
- for (int i = 0; i < 1; ++i) {
+ const float base_value = 42.0;
+ for (int i = 0; i < 3; ++i) {
csession.SetOutputs({get_next});
csession.Run(s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Tensor* out = csession.output_tensor(0);
ASSERT_TRUE(out != nullptr);
- EXPECT_EQ(TF_INT32, TF_TensorType(out));
- EXPECT_EQ(0, TF_NumDims(out)); // scalar
- ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out));
- int32* output_contents = static_cast<int32*>(TF_TensorData(out));
- EXPECT_EQ(1, *output_contents);
+ ASSERT_EQ(TF_FLOAT, TF_TensorType(out));
+ ASSERT_EQ(0, TF_NumDims(out)); // scalar
+ ASSERT_EQ(sizeof(float), TF_TensorByteSize(out));
+ float* output_contents = static_cast<float*>(TF_TensorData(out));
+ ASSERT_EQ(base_value + i, *output_contents);
}
// This should error out since we've exhausted the iterator.
@@ -60,7 +58,63 @@ void TestIteratorStack() {
TF_DeleteStatus(s);
}
-TEST(CAPI_EXPERIMENTAL, IteratorGetNext) { TestIteratorStack(); }
+TEST(CAPI_EXPERIMENTAL, FakeIteratorGetNext) { TestFakeIteratorStack(); }
+
+TEST(CAPI_EXPERIMENTAL, ImagenetIteratorGetNext) {
+ TF_Status* s = TF_NewStatus();
+ TF_Graph* graph = TF_NewGraph();
+
+ const string file_path = tensorflow::io::JoinPath(
+ tensorflow::testing::TensorFlowSrcRoot(), "c/testdata/tf_record");
+ VLOG(1) << "data file path is " << file_path;
+ const int batch_size = 64;
+ TF_Operation* get_next = TF_MakeImagenetIteratorGetNextWithDatasets(
+ graph, file_path.c_str(), batch_size, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ CSession csession(graph, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ // Run the graph.
+ // The two output tensors should look like:
+ // Tensor("IteratorGetNext:0", shape=(batch_size, 224, 224, 3), dtype=float32)
+ // Tensor("IteratorGetNext:1", shape=(batch_size, ), dtype=int32)
+ for (int i = 0; i < 3; ++i) {
+ LOG(INFO) << "Running iter " << i;
+ csession.SetOutputs({{get_next, 0}, {get_next, 1}});
+ csession.Run(s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+
+ {
+ TF_Tensor* image = csession.output_tensor(0);
+ ASSERT_TRUE(image != nullptr);
+ ASSERT_EQ(TF_FLOAT, TF_TensorType(image));
+ // Confirm shape is 224 X 224 X 3
+ ASSERT_EQ(4, TF_NumDims(image));
+ ASSERT_EQ(batch_size, TF_Dim(image, 0));
+ ASSERT_EQ(224, TF_Dim(image, 1));
+ ASSERT_EQ(224, TF_Dim(image, 2));
+ ASSERT_EQ(3, TF_Dim(image, 3));
+ ASSERT_EQ(sizeof(float) * batch_size * 224 * 224 * 3,
+ TF_TensorByteSize(image));
+ }
+
+ {
+ TF_Tensor* label = csession.output_tensor(1);
+ ASSERT_TRUE(label != nullptr);
+ ASSERT_EQ(TF_INT32, TF_TensorType(label));
+ ASSERT_EQ(1, TF_NumDims(label));
+ ASSERT_EQ(batch_size, TF_Dim(label, 0));
+ ASSERT_EQ(sizeof(int32) * batch_size, TF_TensorByteSize(label));
+ }
+ }
+
+ // Clean up
+ csession.CloseAndDelete(s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ TF_DeleteGraph(graph);
+ TF_DeleteStatus(s);
+}
} // namespace
} // namespace tensorflow