aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/example/feature_util_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-05-06 01:58:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-06 03:01:31 -0700
commit357b19ee7f39d46b32639847e260dac2312b3da5 (patch)
tree48b52ef4992183a39914ed30be16e69efddd3b3c /tensorflow/core/example/feature_util_test.cc
parentceb04a1b806563615d68d4617822f057405b3bfb (diff)
A set of lightweight wrappers to simplify access to Example proto fields in C++.
Change: 121659770
Diffstat (limited to 'tensorflow/core/example/feature_util_test.cc')
-rw-r--r--tensorflow/core/example/feature_util_test.cc214
1 files changed, 214 insertions, 0 deletions
diff --git a/tensorflow/core/example/feature_util_test.cc b/tensorflow/core/example/feature_util_test.cc
new file mode 100644
index 0000000000..bcee0fb587
--- /dev/null
+++ b/tensorflow/core/example/feature_util_test.cc
@@ -0,0 +1,214 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/example/feature_util.h"
+
+#include <vector>
+
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+const float kTolerance = 1e-5;
+
+TEST(GetFeatureValuesInt64Test, ReadsASingleValue) {
+ Example example;
+ (*example.mutable_features()->mutable_feature())["tag"]
+ .mutable_int64_list()
+ ->add_value(42);
+
+ auto tag = GetFeatureValues<int64>("tag", example);
+
+ ASSERT_EQ(1, tag.size());
+ EXPECT_EQ(42, tag.Get(0));
+}
+
+TEST(GetFeatureValuesInt64Test, WritesASingleValue) {
+ Example example;
+
+ GetFeatureValues<int64>("tag", &example)->Add(42);
+
+ ASSERT_EQ(1,
+ example.features().feature().at("tag").int64_list().value_size());
+ EXPECT_EQ(42, example.features().feature().at("tag").int64_list().value(0));
+}
+
+TEST(GetFeatureValuesInt64Test, CheckUntypedFieldExistence) {
+ Example example;
+
+ EXPECT_FALSE(ExampleHasFeature("tag", example));
+
+ GetFeatureValues<int64>("tag", &example)->Add(0);
+
+ EXPECT_TRUE(ExampleHasFeature("tag", example));
+}
+
+TEST(GetFeatureValuesInt64Test, CheckTypedFieldExistence) {
+ Example example;
+
+ GetFeatureValues<float>("tag", &example)->Add(3.14);
+ ASSERT_FALSE(ExampleHasFeature<int64>("tag", example));
+
+ GetFeatureValues<int64>("tag", &example)->Add(42);
+
+ EXPECT_TRUE(ExampleHasFeature<int64>("tag", example));
+ auto tag_ro = GetFeatureValues<int64>("tag", example);
+ ASSERT_EQ(1, tag_ro.size());
+ EXPECT_EQ(42, tag_ro.Get(0));
+}
+
+TEST(GetFeatureValuesInt64Test, CopyIterableToAField) {
+ Example example;
+ std::vector<int> values{1, 2, 3};
+
+ std::copy(values.begin(), values.end(),
+ protobuf::RepeatedFieldBackInserter(
+ GetFeatureValues<int64>("tag", &example)));
+
+ auto tag_ro = GetFeatureValues<int64>("tag", example);
+ ASSERT_EQ(3, tag_ro.size());
+ EXPECT_EQ(1, tag_ro.Get(0));
+ EXPECT_EQ(2, tag_ro.Get(1));
+ EXPECT_EQ(3, tag_ro.Get(2));
+}
+
+TEST(GetFeatureValuesFloatTest, ReadsASingleValue) {
+ Example example;
+ (*example.mutable_features()->mutable_feature())["tag"]
+ .mutable_float_list()
+ ->add_value(3.14);
+
+ auto tag = GetFeatureValues<float>("tag", example);
+
+ ASSERT_EQ(1, tag.size());
+ EXPECT_NEAR(3.14, tag.Get(0), kTolerance);
+}
+
+TEST(GetFeatureValuesFloatTest, WritesASingleValue) {
+ Example example;
+
+ GetFeatureValues<float>("tag", &example)->Add(3.14);
+
+ ASSERT_EQ(1,
+ example.features().feature().at("tag").float_list().value_size());
+ EXPECT_NEAR(3.14,
+ example.features().feature().at("tag").float_list().value(0),
+ kTolerance);
+}
+
+TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistence) {
+ Example example;
+
+ GetFeatureValues<int64>("tag", &example)->Add(42);
+ ASSERT_FALSE(ExampleHasFeature<float>("tag", example));
+
+ GetFeatureValues<float>("tag", &example)->Add(3.14);
+
+ EXPECT_TRUE(ExampleHasFeature<float>("tag", example));
+ auto tag_ro = GetFeatureValues<float>("tag", example);
+ ASSERT_EQ(1, tag_ro.size());
+ EXPECT_NEAR(3.14, tag_ro.Get(0), kTolerance);
+}
+
+TEST(GetFeatureValuesStringTest, ReadsASingleValue) {
+ Example example;
+ (*example.mutable_features()->mutable_feature())["tag"]
+ .mutable_bytes_list()
+ ->add_value("FOO");
+
+ auto tag = GetFeatureValues<string>("tag", example);
+
+ ASSERT_EQ(1, tag.size());
+ EXPECT_EQ("FOO", tag.Get(0));
+}
+
+TEST(GetFeatureValuesStringTest, WritesASingleValue) {
+ Example example;
+
+ *GetFeatureValues<string>("tag", &example)->Add() = "FOO";
+
+ ASSERT_EQ(1,
+ example.features().feature().at("tag").bytes_list().value_size());
+ EXPECT_EQ("FOO",
+ example.features().feature().at("tag").bytes_list().value(0));
+}
+
+TEST(GetFeatureValuesBytesTest, CheckTypedFieldExistence) {
+ Example example;
+
+ GetFeatureValues<int64>("tag", &example)->Add(42);
+ ASSERT_FALSE(ExampleHasFeature<string>("tag", example));
+
+ *GetFeatureValues<string>("tag", &example)->Add() = "FOO";
+
+ EXPECT_TRUE(ExampleHasFeature<string>("tag", example));
+ auto tag_ro = GetFeatureValues<string>("tag", example);
+ ASSERT_EQ(1, tag_ro.size());
+ EXPECT_EQ("FOO", tag_ro.Get(0));
+}
+
+TEST(AppendFeatureValuesTest, FloatValuesFromContainer) {
+ Example example;
+
+ std::vector<double> values{1.1, 2.2, 3.3};
+ AppendFeatureValues(values, "tag", &example);
+
+ auto tag_ro = GetFeatureValues<float>("tag", example);
+ ASSERT_EQ(3, tag_ro.size());
+ EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance);
+ EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance);
+ EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance);
+}
+
+TEST(AppendFeatureValuesTest, FloatValuesUsingInitializerList) {
+ Example example;
+
+ AppendFeatureValues({1.1, 2.2, 3.3}, "tag", &example);
+
+ auto tag_ro = GetFeatureValues<float>("tag", example);
+ ASSERT_EQ(3, tag_ro.size());
+ EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance);
+ EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance);
+ EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance);
+}
+
+TEST(AppendFeatureValuesTest, Int64ValuesUsingInitializerList) {
+ Example example;
+
+ AppendFeatureValues({1, 2, 3}, "tag", &example);
+
+ auto tag_ro = GetFeatureValues<int64>("tag", example);
+ ASSERT_EQ(3, tag_ro.size());
+ EXPECT_EQ(1, tag_ro.Get(0));
+ EXPECT_EQ(2, tag_ro.Get(1));
+ EXPECT_EQ(3, tag_ro.Get(2));
+}
+
+TEST(AppendFeatureValuesTest, StringValuesUsingInitializerList) {
+ Example example;
+
+ AppendFeatureValues({"FOO", "BAR", "BAZ"}, "tag", &example);
+
+ auto tag_ro = GetFeatureValues<string>("tag", example);
+ ASSERT_EQ(3, tag_ro.size());
+ EXPECT_EQ("FOO", tag_ro.Get(0));
+ EXPECT_EQ("BAR", tag_ro.Get(1));
+ EXPECT_EQ("BAZ", tag_ro.Get(2));
+}
+
+} // namespace
+} // namespace tensorflow