aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/example
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-01 22:16:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-01 22:19:58 -0700
commit7d5cbd78a54319eeb45bca2e239ec037997dad20 (patch)
treeb71b84aa688a9cd1a9360552037a7b1c4bbe9520 /tensorflow/core/example
parent2935132bce62cbc785da68997248e449d63f5ff2 (diff)
Update feature_util to support SequenceExample proto.
PiperOrigin-RevId: 167359339
Diffstat (limited to 'tensorflow/core/example')
-rw-r--r--tensorflow/core/example/feature_util.cc124
-rw-r--r--tensorflow/core/example/feature_util.h228
-rw-r--r--tensorflow/core/example/feature_util_test.cc228
3 files changed, 470 insertions, 110 deletions
diff --git a/tensorflow/core/example/feature_util.cc b/tensorflow/core/example/feature_util.cc
index 6f3cc6c6c5..f0593ede82 100644
--- a/tensorflow/core/example/feature_util.cc
+++ b/tensorflow/core/example/feature_util.cc
@@ -18,77 +18,129 @@ limitations under the License.
namespace tensorflow {
namespace internal {
-
-::tensorflow::Feature& ExampleFeature(const string& name,
- ::tensorflow::Example* example) {
- ::tensorflow::Features* features = example->mutable_features();
- return (*features->mutable_feature())[name];
+Feature& ExampleFeature(const string& name, Example* example) {
+ return *GetFeature(name, example);
}
-} // namespace internal
+} // namespace internal
template <>
-bool ExampleHasFeature<protobuf_int64>(const string& name,
- const Example& example) {
- auto it = example.features().feature().find(name);
- return (it != example.features().feature().end()) &&
+bool HasFeature<>(const string& key, const Features& features) {
+ return (features.feature().find(key) != features.feature().end());
+}
+
+template <>
+bool HasFeature<protobuf_int64>(const string& key, const Features& features) {
+ auto it = features.feature().find(key);
+ return (it != features.feature().end()) &&
(it->second.kind_case() == Feature::KindCase::kInt64List);
}
template <>
-bool ExampleHasFeature<float>(const string& name, const Example& example) {
- auto it = example.features().feature().find(name);
- return (it != example.features().feature().end()) &&
+bool HasFeature<float>(const string& key, const Features& features) {
+ auto it = features.feature().find(key);
+ return (it != features.feature().end()) &&
(it->second.kind_case() == Feature::KindCase::kFloatList);
}
template <>
-bool ExampleHasFeature<string>(const string& name, const Example& example) {
- auto it = example.features().feature().find(name);
- return (it != example.features().feature().end()) &&
+bool HasFeature<string>(const string& key, const Features& features) {
+ auto it = features.feature().find(key);
+ return (it != features.feature().end()) &&
(it->second.kind_case() == Feature::KindCase::kBytesList);
}
+bool HasFeatureList(const string& key,
+ const SequenceExample& sequence_example) {
+ auto& feature_list = sequence_example.feature_lists().feature_list();
+ return (feature_list.find(key) != feature_list.end());
+}
+
template <>
const protobuf::RepeatedField<protobuf_int64>& GetFeatureValues<protobuf_int64>(
- const string& name, const Example& example) {
- return example.features().feature().at(name).int64_list().value();
+ const Feature& feature) {
+ return feature.int64_list().value();
}
template <>
protobuf::RepeatedField<protobuf_int64>* GetFeatureValues<protobuf_int64>(
- const string& name, Example* example) {
- return internal::ExampleFeature(name, example)
- .mutable_int64_list()
- ->mutable_value();
+ Feature* feature) {
+ return feature->mutable_int64_list()->mutable_value();
}
template <>
const protobuf::RepeatedField<float>& GetFeatureValues<float>(
- const string& name, const Example& example) {
- return example.features().feature().at(name).float_list().value();
+ const Feature& feature) {
+ return feature.float_list().value();
}
template <>
-protobuf::RepeatedField<float>* GetFeatureValues<float>(const string& name,
- Example* example) {
- return internal::ExampleFeature(name, example)
- .mutable_float_list()
- ->mutable_value();
+protobuf::RepeatedField<float>* GetFeatureValues<float>(Feature* feature) {
+ return feature->mutable_float_list()->mutable_value();
}
template <>
const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>(
- const string& name, const Example& example) {
- return example.features().feature().at(name).bytes_list().value();
+ const Feature& feature) {
+ return feature.bytes_list().value();
+}
+
+template <>
+protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(Feature* feature) {
+ return feature->mutable_bytes_list()->mutable_value();
+}
+
+const protobuf::RepeatedPtrField<Feature>& GetFeatureList(
+ const string& key, const SequenceExample& sequence_example) {
+ return sequence_example.feature_lists().feature_list().at(key).feature();
+}
+
+protobuf::RepeatedPtrField<Feature>* GetFeatureList(
+ const string& feature_list_key, SequenceExample* sequence_example) {
+ return (*sequence_example->mutable_feature_lists()
+ ->mutable_feature_list())[feature_list_key]
+ .mutable_feature();
+}
+
+template <>
+Features* GetFeatures<Features>(Features* proto) {
+ return proto;
+}
+
+template <>
+Features* GetFeatures<Example>(Example* proto) {
+ return proto->mutable_features();
}
template <>
-protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(const string& name,
- Example* example) {
- return internal::ExampleFeature(name, example)
- .mutable_bytes_list()
- ->mutable_value();
+const Features& GetFeatures<Features>(const Features& proto) {
+ return proto;
}
+template <>
+const Features& GetFeatures<Example>(const Example& proto) {
+ return proto.features();
+}
+
+template <>
+const protobuf::RepeatedField<protobuf_int64>& GetFeatureValues<protobuf_int64>(
+ const Feature& feature);
+
+template <>
+protobuf::RepeatedField<protobuf_int64>* GetFeatureValues<protobuf_int64>(
+ Feature* feature);
+
+template <>
+const protobuf::RepeatedField<float>& GetFeatureValues<float>(
+ const Feature& feature);
+
+template <>
+protobuf::RepeatedField<float>* GetFeatureValues<float>(Feature* feature);
+
+template <>
+const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>(
+ const Feature& feature);
+
+template <>
+protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(Feature* feature);
} // namespace tensorflow
diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h
index 4004411cb1..a87c2c9a57 100644
--- a/tensorflow/core/example/feature_util.h
+++ b/tensorflow/core/example/feature_util.h
@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// A set of lightweight wrappers which simplify access to Example features.
+// A set of lightweight wrappers which simplify access to Feature protos.
//
// TensorFlow Example proto uses associative maps on top of oneof fields.
+// SequenceExample proto uses associative map of FeatureList.
// So accessing feature values is not very convenient.
//
// For example, to read a first value of integer feature "tag":
@@ -42,9 +43,59 @@ limitations under the License.
// (RepeatedPtrField for byte list). So refer to its documentation of
// RepeatedField for full list of supported methods.
//
-// NOTE: It is also important to mention that due to the nature of oneof proto
-// fields setting a feature of one type automatically clears all values stored
-// as another type with the same feature name.
+// NOTE: Due to the nature of oneof proto fields setting a feature of one type
+// automatically clears all values stored as another type with the same feature
+// key.
+//
+// This library also has tools to work with SequenceExample protos.
+//
+// To get a value from SequenceExample.context:
+// int id = GetFeatureValues<protobuf_int64>("tag", se.context()).Get(0);
+// To add a value to the context:
+// GetFeatureValues<protobuf_int64>("tag", se.mutable_context())->Add(42);
+//
+// To add values to feature_lists:
+// AppendFeatureValues({4.0},
+// GetFeatureList("movie_ratings", &se)->Add());
+// AppendFeatureValues({5.0, 3.0},
+// GetFeatureList("movie_ratings", &se)->Add());
+// This will create a feature list keyed as "images" with two features:
+// feature_lists {
+// feature_list {
+// key: "images"
+// value {
+// feature { float_list { value: [4.0] } }
+// feature { float_list { value: [5.0, 3.0] } }
+// }
+// } }
+//
+// Functions exposed by this library:
+// HasFeature<[FeatureType]>(key, proto) -> bool
+// Returns true if a feature with the specified key, and optionally
+// FeatureType, belongs to the Features or Example proto.
+// HasFeatureList(key, sequence_example) -> bool
+// Returns true if SequenceExample has a feature_list with the key.
+// GetFeatureValues<FeatureType>(key, proto) -> RepeatedField<FeatureType>
+// Returns values for the specified key and the FeatureType.
+// Supported types for the proto: Example, Features.
+// GetFeatureList(key, sequence_example) -> RepeatedPtrField<Feature>
+// Returns Feature protos associated with a key.
+// AppendFeatureValues(begin, end, feature)
+// AppendFeatureValues(container or initializer_list, feature)
+// Copies values into a Feature.
+// AppendFeatureValues(begin, end, key, proto)
+// AppendFeatureValues(container or initializer_list, key, proto)
+// Copies values into Features and Example protos with the specified key.
+//
+// Auxiliary functions, it is unlikely you'll need to use them directly:
+// GetFeatures(proto) -> Features
+// A convenience function to get Features proto.
+// Supported types for the proto: Example, Features.
+// GetFeature(key, proto) -> Feature*
+// Returns a Feature proto for the specified key, creates a new if
+// necessary. Supported types for the proto: Example, Features.
+// GetFeatureValues<FeatureType>(feature) -> RepeatedField<FeatureType>
+// Returns values of the feature for the FeatureType.
#ifndef TENSORFLOW_EXAMPLE_FEATURE_H_
#define TENSORFLOW_EXAMPLE_FEATURE_H_
@@ -62,10 +113,11 @@ namespace tensorflow {
namespace internal {
+// DEPRECATED: Use GetFeature instead.
+// TODO(gorban): Update all clients in a followup CL.
// Returns a reference to a feature corresponding to the name.
// Note: it will create a new Feature if it is missing in the example.
-::tensorflow::Feature& ExampleFeature(const string& name,
- ::tensorflow::Example* example);
+Feature& ExampleFeature(const string& name, Example* example);
// Specializations of RepeatedFieldTrait define a type of RepeatedField
// corresponding to a selected feature type.
@@ -127,89 +179,135 @@ struct FeatureTrait<
} // namespace internal
-// Returns true if feature with the specified name belongs to the example proto.
-// Doesn't check feature type. Note that specialized versions return false if
-// the feature has a wrong type.
-template <typename FeatureType = void>
-bool ExampleHasFeature(const string& name, const Example& example) {
- return example.features().feature().find(name) !=
- example.features().feature().end();
-}
+// Returns true if sequence_example has a feature_list with the specified key.
+bool HasFeatureList(const string& key, const SequenceExample& sequence_example);
+
+// A family of template functions to return mutable Features proto from a
+// container proto. Supported ProtoTypes: Example, Features.
+template <typename ProtoType>
+Features* GetFeatures(ProtoType* proto);
+
+template <typename ProtoType>
+const Features& GetFeatures(const ProtoType& proto);
// Base declaration of a family of template functions to return a read only
-// repeated field corresponding to a feature with the specified name.
+// repeated field of feature values.
template <typename FeatureType>
const typename internal::RepeatedFieldTrait<FeatureType>::Type&
-GetFeatureValues(const string& name, const Example& example);
+GetFeatureValues(const Feature& feature);
+
+// Returns a read only repeated field corresponding to a feature with the
+// specified name and FeatureType. Supported ProtoTypes: Example, Features.
+template <typename FeatureType, typename ProtoType>
+const typename internal::RepeatedFieldTrait<FeatureType>::Type&
+GetFeatureValues(const string& key, const ProtoType& proto) {
+ return GetFeatureValues<FeatureType>(GetFeatures(proto).feature().at(key));
+}
-// Base declaration of a family of template functions to return a mutable
-// repeated field corresponding to a feature with the specified name.
+// Returns a mutable repeated field of a feature values.
template <typename FeatureType>
typename internal::RepeatedFieldTrait<FeatureType>::Type* GetFeatureValues(
- const string& name, Example* example);
+ Feature* feature);
+
+// Returns a mutable repeated field corresponding to a feature with the
+// specified name and FeatureType. Supported ProtoTypes: Example, Features.
+template <typename FeatureType, typename ProtoType>
+typename internal::RepeatedFieldTrait<FeatureType>::Type* GetFeatureValues(
+ const string& key, ProtoType* proto) {
+ ::tensorflow::Feature& feature =
+ (*GetFeatures(proto)->mutable_feature())[key];
+ return GetFeatureValues<FeatureType>(&feature);
+}
+
+// Returns a Feature proto for the specified key, creates a new if necessary.
+// Supported types for the proto: Example, Features.
+template <typename ProtoType>
+Feature* GetFeature(const string& key, ProtoType* proto) {
+ return &(*GetFeatures(proto)->mutable_feature())[key];
+}
+
+// Returns a repeated field with features corresponding to a feature_list key.
+const protobuf::RepeatedPtrField<Feature>& GetFeatureList(
+ const string& key, const SequenceExample& sequence_example);
+
+// Returns a mutable repeated field with features corresponding to a
+// feature_list key. It will create a new FeatureList if necessary.
+protobuf::RepeatedPtrField<Feature>* GetFeatureList(
+ const string& feature_list_key, SequenceExample* sequence_example);
-// Copies elements from the range, defined by [first, last) into a feature.
template <typename IteratorType>
void AppendFeatureValues(IteratorType first, IteratorType last,
- const string& name, Example* example) {
+ Feature* feature) {
using FeatureType = typename internal::FeatureTrait<
typename std::iterator_traits<IteratorType>::value_type>::Type;
- std::copy(first, last, protobuf::RepeatedFieldBackInserter(
- GetFeatureValues<FeatureType>(name, example)));
+ std::copy(first, last,
+ protobuf::RepeatedFieldBackInserter(
+ GetFeatureValues<FeatureType>(feature)));
+}
+
+template <typename ValueType>
+void AppendFeatureValues(std::initializer_list<ValueType> container,
+ Feature* feature) {
+ AppendFeatureValues(container.begin(), container.end(), feature);
}
-// Copies all elements from the container into a feature.
template <typename ContainerType>
-void AppendFeatureValues(const ContainerType& container, const string& name,
- Example* example) {
+void AppendFeatureValues(const ContainerType& container, Feature* feature) {
using IteratorType = typename ContainerType::const_iterator;
- AppendFeatureValues<IteratorType>(container.begin(), container.end(), name,
- example);
+ AppendFeatureValues<IteratorType>(container.begin(), container.end(),
+ feature);
}
-// Copies all elements from the initializer list into a feature.
-template <typename ValueType>
+// Copies elements from the range, defined by [first, last) into the feature
+// obtainable from the (proto, key) combination.
+template <typename IteratorType, typename ProtoType>
+void AppendFeatureValues(IteratorType first, IteratorType last,
+ const string& key, ProtoType* proto) {
+ AppendFeatureValues(first, last, GetFeature(key, GetFeatures(proto)));
+}
+
+// Copies all elements from the container into a feature.
+template <typename ContainerType, typename ProtoType>
+void AppendFeatureValues(const ContainerType& container, const string& key,
+ ProtoType* proto) {
+ using IteratorType = typename ContainerType::const_iterator;
+ AppendFeatureValues<IteratorType>(container.begin(), container.end(), key,
+ proto);
+}
+
+// Copies all elements from the initializer list into a Feature contained by
+// Features or Example proto.
+template <typename ValueType, typename ProtoType>
void AppendFeatureValues(std::initializer_list<ValueType> container,
- const string& name, Example* example) {
+ const string& key, ProtoType* proto) {
using IteratorType =
typename std::initializer_list<ValueType>::const_iterator;
- AppendFeatureValues<IteratorType>(container.begin(), container.end(), name,
- example);
+ AppendFeatureValues<IteratorType>(container.begin(), container.end(), key,
+ proto);
}
-template <>
-bool ExampleHasFeature<protobuf_int64>(const string& name,
- const Example& example);
-
-template <>
-bool ExampleHasFeature<float>(const string& name, const Example& example);
-
-template <>
-bool ExampleHasFeature<string>(const string& name, const Example& example);
-
-template <>
-const protobuf::RepeatedField<protobuf_int64>& GetFeatureValues<protobuf_int64>(
- const string& name, const Example& example);
-
-template <>
-protobuf::RepeatedField<protobuf_int64>* GetFeatureValues<protobuf_int64>(
- const string& name, Example* example);
-
-template <>
-const protobuf::RepeatedField<float>& GetFeatureValues<float>(
- const string& name, const Example& example);
-
-template <>
-protobuf::RepeatedField<float>* GetFeatureValues<float>(const string& name,
- Example* example);
-
-template <>
-const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>(
- const string& name, const Example& example);
+// Returns true if a feature with the specified key belongs to the Features.
+// The template parameter pack accepts zero or one template argument - which
+// is FeatureType. If the FeatureType not specified (zero template arguments)
+// the function will not check the feature type. Otherwise it will return false
+// if the feature has a wrong type.
+template <typename... FeatureType>
+bool HasFeature(const string& key, const Features& features);
+
+// Returns true if a feature with the specified key belongs to the Example.
+// Doesn't check feature type if used without FeatureType, otherwise the
+// specialized versions return false if the feature has a wrong type.
+template <typename... FeatureType>
+bool HasFeature(const string& key, const Example& example) {
+ return HasFeature<FeatureType...>(key, GetFeatures(example));
+};
-template <>
-protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(const string& name,
- Example* example);
+// DEPRECATED: use HasFeature instead.
+// TODO(gorban): update all clients in a followup CL.
+template <typename... FeatureType>
+bool ExampleHasFeature(const string& key, const Example& example) {
+ return HasFeature<FeatureType...>(key, example);
+}
} // namespace tensorflow
#endif // TENSORFLOW_EXAMPLE_FEATURE_H_
diff --git a/tensorflow/core/example/feature_util_test.cc b/tensorflow/core/example/feature_util_test.cc
index eb7b90af1b..cd32dee306 100644
--- a/tensorflow/core/example/feature_util_test.cc
+++ b/tensorflow/core/example/feature_util_test.cc
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
#include "tensorflow/core/example/feature_util.h"
#include <vector>
@@ -38,6 +37,16 @@ TEST(GetFeatureValuesInt64Test, ReadsASingleValue) {
EXPECT_EQ(42, tag.Get(0));
}
+TEST(GetFeatureValuesInt64Test, ReadsASingleValueFromFeature) {
+ Feature feature;
+ feature.mutable_int64_list()->add_value(42);
+
+ auto values = GetFeatureValues<protobuf_int64>(feature);
+
+ ASSERT_EQ(1, values.size());
+ EXPECT_EQ(42, values.Get(0));
+}
+
TEST(GetFeatureValuesInt64Test, WritesASingleValue) {
Example example;
@@ -48,25 +57,33 @@ TEST(GetFeatureValuesInt64Test, WritesASingleValue) {
EXPECT_EQ(42, example.features().feature().at("tag").int64_list().value(0));
}
+TEST(GetFeatureValuesInt64Test, WritesASingleValueToFeature) {
+ Feature feature;
+
+ GetFeatureValues<protobuf_int64>(&feature)->Add(42);
+
+ ASSERT_EQ(1, feature.int64_list().value_size());
+ EXPECT_EQ(42, feature.int64_list().value(0));
+}
+
TEST(GetFeatureValuesInt64Test, CheckUntypedFieldExistence) {
Example example;
-
- EXPECT_FALSE(ExampleHasFeature("tag", example));
+ ASSERT_FALSE(HasFeature("tag", example));
GetFeatureValues<protobuf_int64>("tag", &example)->Add(0);
- EXPECT_TRUE(ExampleHasFeature("tag", example));
+ EXPECT_TRUE(HasFeature("tag", example));
}
TEST(GetFeatureValuesInt64Test, CheckTypedFieldExistence) {
Example example;
GetFeatureValues<float>("tag", &example)->Add(3.14);
- ASSERT_FALSE(ExampleHasFeature<protobuf_int64>("tag", example));
+ ASSERT_FALSE(HasFeature<protobuf_int64>("tag", example));
GetFeatureValues<protobuf_int64>("tag", &example)->Add(42);
- EXPECT_TRUE(ExampleHasFeature<protobuf_int64>("tag", example));
+ EXPECT_TRUE(HasFeature<protobuf_int64>("tag", example));
auto tag_ro = GetFeatureValues<protobuf_int64>("tag", example);
ASSERT_EQ(1, tag_ro.size());
EXPECT_EQ(42, tag_ro.Get(0));
@@ -87,6 +104,16 @@ TEST(GetFeatureValuesInt64Test, CopyIterableToAField) {
EXPECT_EQ(3, tag_ro.Get(2));
}
+TEST(GetFeatureValuesFloatTest, ReadsASingleValueFromFeature) {
+ Feature feature;
+ feature.mutable_float_list()->add_value(3.14);
+
+ auto values = GetFeatureValues<float>(feature);
+
+ ASSERT_EQ(1, values.size());
+ EXPECT_NEAR(3.14, values.Get(0), kTolerance);
+}
+
TEST(GetFeatureValuesFloatTest, ReadsASingleValue) {
Example example;
(*example.mutable_features()->mutable_feature())["tag"]
@@ -99,6 +126,15 @@ TEST(GetFeatureValuesFloatTest, ReadsASingleValue) {
EXPECT_NEAR(3.14, tag.Get(0), kTolerance);
}
+TEST(GetFeatureValuesFloatTest, WritesASingleValueToFeature) {
+ Feature feature;
+
+ GetFeatureValues<float>(&feature)->Add(3.14);
+
+ ASSERT_EQ(1, feature.float_list().value_size());
+ EXPECT_NEAR(3.14, feature.float_list().value(0), kTolerance);
+}
+
TEST(GetFeatureValuesFloatTest, WritesASingleValue) {
Example example;
@@ -115,6 +151,20 @@ TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistence) {
Example example;
GetFeatureValues<protobuf_int64>("tag", &example)->Add(42);
+ ASSERT_FALSE(HasFeature<float>("tag", example));
+
+ GetFeatureValues<float>("tag", &example)->Add(3.14);
+
+ EXPECT_TRUE(HasFeature<float>("tag", example));
+ auto tag_ro = GetFeatureValues<float>("tag", example);
+ ASSERT_EQ(1, tag_ro.size());
+ EXPECT_NEAR(3.14, tag_ro.Get(0), kTolerance);
+}
+
+TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistenceForDeprecatedMethod) {
+ Example example;
+
+ GetFeatureValues<protobuf_int64>("tag", &example)->Add(42);
ASSERT_FALSE(ExampleHasFeature<float>("tag", example));
GetFeatureValues<float>("tag", &example)->Add(3.14);
@@ -125,6 +175,16 @@ TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistence) {
EXPECT_NEAR(3.14, tag_ro.Get(0), kTolerance);
}
+TEST(GetFeatureValuesStringTest, ReadsASingleValueFromFeature) {
+ Feature feature;
+ feature.mutable_bytes_list()->add_value("FOO");
+
+ auto values = GetFeatureValues<string>(feature);
+
+ ASSERT_EQ(1, values.size());
+ EXPECT_EQ("FOO", values.Get(0));
+}
+
TEST(GetFeatureValuesStringTest, ReadsASingleValue) {
Example example;
(*example.mutable_features()->mutable_feature())["tag"]
@@ -137,6 +197,15 @@ TEST(GetFeatureValuesStringTest, ReadsASingleValue) {
EXPECT_EQ("FOO", tag.Get(0));
}
+TEST(GetFeatureValuesStringTest, WritesASingleValueToFeature) {
+ Feature feature;
+
+ *GetFeatureValues<string>(&feature)->Add() = "FOO";
+
+ ASSERT_EQ(1, feature.bytes_list().value_size());
+ EXPECT_EQ("FOO", feature.bytes_list().value(0));
+}
+
TEST(GetFeatureValuesStringTest, WritesASingleValue) {
Example example;
@@ -148,15 +217,15 @@ TEST(GetFeatureValuesStringTest, WritesASingleValue) {
example.features().feature().at("tag").bytes_list().value(0));
}
-TEST(GetFeatureValuesBytesTest, CheckTypedFieldExistence) {
+TEST(GetFeatureValuesStringTest, CheckTypedFieldExistence) {
Example example;
GetFeatureValues<protobuf_int64>("tag", &example)->Add(42);
- ASSERT_FALSE(ExampleHasFeature<string>("tag", example));
+ ASSERT_FALSE(HasFeature<string>("tag", example));
*GetFeatureValues<string>("tag", &example)->Add() = "FOO";
- EXPECT_TRUE(ExampleHasFeature<string>("tag", example));
+ EXPECT_TRUE(HasFeature<string>("tag", example));
auto tag_ro = GetFeatureValues<string>("tag", example);
ASSERT_EQ(1, tag_ro.size());
EXPECT_EQ("FOO", tag_ro.Get(0));
@@ -228,5 +297,146 @@ TEST(AppendFeatureValuesTest, StringVariablesUsingInitializerList) {
EXPECT_EQ("BAZ", tag_ro.Get(2));
}
+TEST(SequenceExampleTest, ReadsASingleValueFromContext) {
+ SequenceExample se;
+ (*se.mutable_context()->mutable_feature())["tag"]
+ .mutable_int64_list()
+ ->add_value(42);
+
+ auto values = GetFeatureValues<protobuf_int64>("tag", se.context());
+
+ ASSERT_EQ(1, values.size());
+ EXPECT_EQ(42, values.Get(0));
+}
+
+TEST(SequenceExampleTest, WritesASingleValueToContext) {
+ SequenceExample se;
+
+ GetFeatureValues<protobuf_int64>("tag", se.mutable_context())->Add(42);
+
+ ASSERT_EQ(1, se.context().feature().at("tag").int64_list().value_size());
+ EXPECT_EQ(42, se.context().feature().at("tag").int64_list().value(0));
+}
+
+TEST(SequenceExampleTest, AppendFeatureValuesToContextSingleArg) {
+ SequenceExample se;
+
+ AppendFeatureValues({1.1, 2.2, 3.3}, "tag", se.mutable_context());
+
+ auto tag_ro = GetFeatureValues<float>("tag", se.context());
+ ASSERT_EQ(3, tag_ro.size());
+ EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance);
+ EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance);
+ EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance);
+}
+
+TEST(SequenceExampleTest, CheckTypedFieldExistence) {
+ SequenceExample se;
+
+ GetFeatureValues<float>("tag", se.mutable_context())->Add(3.14);
+ ASSERT_FALSE(HasFeature<protobuf_int64>("tag", se.context()));
+
+ GetFeatureValues<protobuf_int64>("tag", se.mutable_context())->Add(42);
+
+ EXPECT_TRUE(HasFeature<protobuf_int64>("tag", se.context()));
+ auto tag_ro = GetFeatureValues<protobuf_int64>("tag", se.context());
+ ASSERT_EQ(1, tag_ro.size());
+ EXPECT_EQ(42, tag_ro.Get(0));
+}
+
+TEST(SequenceExampleTest, ReturnsExistingFeatureLists) {
+ SequenceExample se;
+ (*se.mutable_feature_lists()->mutable_feature_list())["tag"]
+ .mutable_feature()
+ ->Add();
+
+ auto feature = GetFeatureList("tag", se);
+
+ ASSERT_EQ(1, feature.size());
+}
+
+TEST(SequenceExampleTest, CreatesNewFeatureLists) {
+ SequenceExample se;
+
+ GetFeatureList("tag", &se)->Add();
+
+ EXPECT_EQ(1, se.feature_lists().feature_list().at("tag").feature_size());
+}
+
+TEST(SequenceExampleTest, CheckFeatureListExistence) {
+ SequenceExample se;
+ ASSERT_FALSE(HasFeatureList("tag", se));
+
+ GetFeatureList("tag", &se)->Add();
+
+ ASSERT_TRUE(HasFeatureList("tag", se));
+}
+
+TEST(SequenceExampleTest, AppendFeatureValuesWithInitializerList) {
+ SequenceExample se;
+
+ AppendFeatureValues({1, 2, 3}, "ids", se.mutable_context());
+ AppendFeatureValues({"cam1-0", "cam2-0"},
+ GetFeatureList("images", &se)->Add());
+ AppendFeatureValues({"cam1-1", "cam2-2"},
+ GetFeatureList("images", &se)->Add());
+
+ EXPECT_EQ(se.DebugString(),
+ "context {\n"
+ " feature {\n"
+ " key: \"ids\"\n"
+ " value {\n"
+ " int64_list {\n"
+ " value: 1\n"
+ " value: 2\n"
+ " value: 3\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}\n"
+ "feature_lists {\n"
+ " feature_list {\n"
+ " key: \"images\"\n"
+ " value {\n"
+ " feature {\n"
+ " bytes_list {\n"
+ " value: \"cam1-0\"\n"
+ " value: \"cam2-0\"\n"
+ " }\n"
+ " }\n"
+ " feature {\n"
+ " bytes_list {\n"
+ " value: \"cam1-1\"\n"
+ " value: \"cam2-2\"\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}\n");
+}
+
+TEST(SequenceExampleTest, AppendFeatureValuesWithVectors) {
+ SequenceExample se;
+
+ std::vector<float> readings{1.0, 2.5, 5.0};
+ AppendFeatureValues(readings, GetFeatureList("movie_ratings", &se)->Add());
+
+ EXPECT_EQ(se.DebugString(),
+ "feature_lists {\n"
+ " feature_list {\n"
+ " key: \"movie_ratings\"\n"
+ " value {\n"
+ " feature {\n"
+ " float_list {\n"
+ " value: 1\n"
+ " value: 2.5\n"
+ " value: 5\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ " }\n"
+ "}\n");
+}
+
} // namespace
} // namespace tensorflow