diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-09-01 22:16:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-01 22:19:58 -0700 |
commit | 7d5cbd78a54319eeb45bca2e239ec037997dad20 (patch) | |
tree | b71b84aa688a9cd1a9360552037a7b1c4bbe9520 /tensorflow/core/example | |
parent | 2935132bce62cbc785da68997248e449d63f5ff2 (diff) |
Update feature_util to support SequenceExample proto.
PiperOrigin-RevId: 167359339
Diffstat (limited to 'tensorflow/core/example')
-rw-r--r-- | tensorflow/core/example/feature_util.cc | 124 | ||||
-rw-r--r-- | tensorflow/core/example/feature_util.h | 228 | ||||
-rw-r--r-- | tensorflow/core/example/feature_util_test.cc | 228 |
3 files changed, 470 insertions, 110 deletions
diff --git a/tensorflow/core/example/feature_util.cc b/tensorflow/core/example/feature_util.cc index 6f3cc6c6c5..f0593ede82 100644 --- a/tensorflow/core/example/feature_util.cc +++ b/tensorflow/core/example/feature_util.cc @@ -18,77 +18,129 @@ limitations under the License. namespace tensorflow { namespace internal { - -::tensorflow::Feature& ExampleFeature(const string& name, - ::tensorflow::Example* example) { - ::tensorflow::Features* features = example->mutable_features(); - return (*features->mutable_feature())[name]; +Feature& ExampleFeature(const string& name, Example* example) { + return *GetFeature(name, example); } -} // namespace internal +} // namespace internal template <> -bool ExampleHasFeature<protobuf_int64>(const string& name, - const Example& example) { - auto it = example.features().feature().find(name); - return (it != example.features().feature().end()) && +bool HasFeature<>(const string& key, const Features& features) { + return (features.feature().find(key) != features.feature().end()); +} + +template <> +bool HasFeature<protobuf_int64>(const string& key, const Features& features) { + auto it = features.feature().find(key); + return (it != features.feature().end()) && (it->second.kind_case() == Feature::KindCase::kInt64List); } template <> -bool ExampleHasFeature<float>(const string& name, const Example& example) { - auto it = example.features().feature().find(name); - return (it != example.features().feature().end()) && +bool HasFeature<float>(const string& key, const Features& features) { + auto it = features.feature().find(key); + return (it != features.feature().end()) && (it->second.kind_case() == Feature::KindCase::kFloatList); } template <> -bool ExampleHasFeature<string>(const string& name, const Example& example) { - auto it = example.features().feature().find(name); - return (it != example.features().feature().end()) && +bool HasFeature<string>(const string& key, const Features& features) { + auto it = features.feature().find(key); + return (it != features.feature().end()) && (it->second.kind_case() == Feature::KindCase::kBytesList); } +bool HasFeatureList(const string& key, + const SequenceExample& sequence_example) { + auto& feature_list = sequence_example.feature_lists().feature_list(); + return (feature_list.find(key) != feature_list.end()); +} + template <> const protobuf::RepeatedField<protobuf_int64>& GetFeatureValues<protobuf_int64>( - const string& name, const Example& example) { - return example.features().feature().at(name).int64_list().value(); + const Feature& feature) { + return feature.int64_list().value(); } template <> protobuf::RepeatedField<protobuf_int64>* GetFeatureValues<protobuf_int64>( - const string& name, Example* example) { - return internal::ExampleFeature(name, example) - .mutable_int64_list() - ->mutable_value(); + Feature* feature) { + return feature->mutable_int64_list()->mutable_value(); } template <> const protobuf::RepeatedField<float>& GetFeatureValues<float>( - const string& name, const Example& example) { - return example.features().feature().at(name).float_list().value(); + const Feature& feature) { + return feature.float_list().value(); } template <> -protobuf::RepeatedField<float>* GetFeatureValues<float>(const string& name, - Example* example) { - return internal::ExampleFeature(name, example) - .mutable_float_list() - ->mutable_value(); +protobuf::RepeatedField<float>* GetFeatureValues<float>(Feature* feature) { + return feature->mutable_float_list()->mutable_value(); } template <> const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>( - const string& name, const Example& example) { - return example.features().feature().at(name).bytes_list().value(); + const Feature& feature) { + return feature.bytes_list().value(); +} + +template <> +protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(Feature* feature) { + return feature->mutable_bytes_list()->mutable_value(); +} + +const protobuf::RepeatedPtrField<Feature>& GetFeatureList( + const string& key, const SequenceExample& sequence_example) { + return sequence_example.feature_lists().feature_list().at(key).feature(); +} + +protobuf::RepeatedPtrField<Feature>* GetFeatureList( + const string& feature_list_key, SequenceExample* sequence_example) { + return (*sequence_example->mutable_feature_lists() + ->mutable_feature_list())[feature_list_key] + .mutable_feature(); +} + +template <> +Features* GetFeatures<Features>(Features* proto) { + return proto; +} + +template <> +Features* GetFeatures<Example>(Example* proto) { + return proto->mutable_features(); } template <> -protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(const string& name, - Example* example) { - return internal::ExampleFeature(name, example) - .mutable_bytes_list() - ->mutable_value(); +const Features& GetFeatures<Features>(const Features& proto) { + return proto; } +template <> +const Features& GetFeatures<Example>(const Example& proto) { + return proto.features(); +} + +template <> +const protobuf::RepeatedField<protobuf_int64>& GetFeatureValues<protobuf_int64>( + const Feature& feature); + +template <> +protobuf::RepeatedField<protobuf_int64>* GetFeatureValues<protobuf_int64>( + Feature* feature); + +template <> +const protobuf::RepeatedField<float>& GetFeatureValues<float>( + const Feature& feature); + +template <> +protobuf::RepeatedField<float>* GetFeatureValues<float>(Feature* feature); + +template <> +const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>( + const Feature& feature); + +template <> +protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(Feature* feature); } // namespace tensorflow diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h index 4004411cb1..a87c2c9a57 100644 --- a/tensorflow/core/example/feature_util.h +++ b/tensorflow/core/example/feature_util.h @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// A set of lightweight wrappers which simplify access to Example features. +// A set of lightweight wrappers which simplify access to Feature protos. // // TensorFlow Example proto uses associative maps on top of oneof fields. +// SequenceExample proto uses associative map of FeatureList. // So accessing feature values is not very convenient. // // For example, to read a first value of integer feature "tag": @@ -42,9 +43,59 @@ limitations under the License. // (RepeatedPtrField for byte list). So refer to its documentation of // RepeatedField for full list of supported methods. // -// NOTE: It is also important to mention that due to the nature of oneof proto -// fields setting a feature of one type automatically clears all values stored -// as another type with the same feature name. +// NOTE: Due to the nature of oneof proto fields setting a feature of one type +// automatically clears all values stored as another type with the same feature +// key. +// +// This library also has tools to work with SequenceExample protos. +// +// To get a value from SequenceExample.context: +// int id = GetFeatureValues<protobuf_int64>("tag", se.context()).Get(0); +// To add a value to the context: +// GetFeatureValues<protobuf_int64>("tag", se.mutable_context())->Add(42); +// +// To add values to feature_lists: +// AppendFeatureValues({4.0}, +// GetFeatureList("movie_ratings", &se)->Add()); +// AppendFeatureValues({5.0, 3.0}, +// GetFeatureList("movie_ratings", &se)->Add()); +// This will create a feature list keyed as "images" with two features: +// feature_lists { +// feature_list { +// key: "images" +// value { +// feature { float_list { value: [4.0] } } +// feature { float_list { value: [5.0, 3.0] } } +// } +// } } +// +// Functions exposed by this library: +// HasFeature<[FeatureType]>(key, proto) -> bool +// Returns true if a feature with the specified key, and optionally +// FeatureType, belongs to the Features or Example proto. +// HasFeatureList(key, sequence_example) -> bool +// Returns true if SequenceExample has a feature_list with the key. +// GetFeatureValues<FeatureType>(key, proto) -> RepeatedField<FeatureType> +// Returns values for the specified key and the FeatureType. +// Supported types for the proto: Example, Features. +// GetFeatureList(key, sequence_example) -> RepeatedPtrField<Feature> +// Returns Feature protos associated with a key. +// AppendFeatureValues(begin, end, feature) +// AppendFeatureValues(container or initializer_list, feature) +// Copies values into a Feature. +// AppendFeatureValues(begin, end, key, proto) +// AppendFeatureValues(container or initializer_list, key, proto) +// Copies values into Features and Example protos with the specified key. +// +// Auxiliary functions, it is unlikely you'll need to use them directly: +// GetFeatures(proto) -> Features +// A convenience function to get Features proto. +// Supported types for the proto: Example, Features. +// GetFeature(key, proto) -> Feature* +// Returns a Feature proto for the specified key, creates a new if +// necessary. Supported types for the proto: Example, Features. +// GetFeatureValues<FeatureType>(feature) -> RepeatedField<FeatureType> +// Returns values of the feature for the FeatureType. #ifndef TENSORFLOW_EXAMPLE_FEATURE_H_ #define TENSORFLOW_EXAMPLE_FEATURE_H_ @@ -62,10 +113,11 @@ namespace tensorflow { namespace internal { +// DEPRECATED: Use GetFeature instead. +// TODO(gorban): Update all clients in a followup CL. // Returns a reference to a feature corresponding to the name. // Note: it will create a new Feature if it is missing in the example. -::tensorflow::Feature& ExampleFeature(const string& name, - ::tensorflow::Example* example); +Feature& ExampleFeature(const string& name, Example* example); // Specializations of RepeatedFieldTrait define a type of RepeatedField // corresponding to a selected feature type. @@ -127,89 +179,135 @@ struct FeatureTrait< } // namespace internal -// Returns true if feature with the specified name belongs to the example proto. -// Doesn't check feature type. Note that specialized versions return false if -// the feature has a wrong type. -template <typename FeatureType = void> -bool ExampleHasFeature(const string& name, const Example& example) { - return example.features().feature().find(name) != - example.features().feature().end(); -} +// Returns true if sequence_example has a feature_list with the specified key. +bool HasFeatureList(const string& key, const SequenceExample& sequence_example); + +// A family of template functions to return mutable Features proto from a +// container proto. Supported ProtoTypes: Example, Features. +template <typename ProtoType> +Features* GetFeatures(ProtoType* proto); + +template <typename ProtoType> +const Features& GetFeatures(const ProtoType& proto); // Base declaration of a family of template functions to return a read only -// repeated field corresponding to a feature with the specified name. +// repeated field of feature values. template <typename FeatureType> const typename internal::RepeatedFieldTrait<FeatureType>::Type& -GetFeatureValues(const string& name, const Example& example); +GetFeatureValues(const Feature& feature); + +// Returns a read only repeated field corresponding to a feature with the +// specified name and FeatureType. Supported ProtoTypes: Example, Features. +template <typename FeatureType, typename ProtoType> +const typename internal::RepeatedFieldTrait<FeatureType>::Type& +GetFeatureValues(const string& key, const ProtoType& proto) { + return GetFeatureValues<FeatureType>(GetFeatures(proto).feature().at(key)); +} -// Base declaration of a family of template functions to return a mutable -// repeated field corresponding to a feature with the specified name. +// Returns a mutable repeated field of a feature values. template <typename FeatureType> typename internal::RepeatedFieldTrait<FeatureType>::Type* GetFeatureValues( - const string& name, Example* example); + Feature* feature); + +// Returns a mutable repeated field corresponding to a feature with the +// specified name and FeatureType. Supported ProtoTypes: Example, Features. +template <typename FeatureType, typename ProtoType> +typename internal::RepeatedFieldTrait<FeatureType>::Type* GetFeatureValues( + const string& key, ProtoType* proto) { + ::tensorflow::Feature& feature = + (*GetFeatures(proto)->mutable_feature())[key]; + return GetFeatureValues<FeatureType>(&feature); +} + +// Returns a Feature proto for the specified key, creates a new if necessary. +// Supported types for the proto: Example, Features. +template <typename ProtoType> +Feature* GetFeature(const string& key, ProtoType* proto) { + return &(*GetFeatures(proto)->mutable_feature())[key]; +} + +// Returns a repeated field with features corresponding to a feature_list key. +const protobuf::RepeatedPtrField<Feature>& GetFeatureList( + const string& key, const SequenceExample& sequence_example); + +// Returns a mutable repeated field with features corresponding to a +// feature_list key. It will create a new FeatureList if necessary. +protobuf::RepeatedPtrField<Feature>* GetFeatureList( + const string& feature_list_key, SequenceExample* sequence_example); -// Copies elements from the range, defined by [first, last) into a feature. template <typename IteratorType> void AppendFeatureValues(IteratorType first, IteratorType last, - const string& name, Example* example) { + Feature* feature) { using FeatureType = typename internal::FeatureTrait< typename std::iterator_traits<IteratorType>::value_type>::Type; - std::copy(first, last, protobuf::RepeatedFieldBackInserter( - GetFeatureValues<FeatureType>(name, example))); + std::copy(first, last, + protobuf::RepeatedFieldBackInserter( + GetFeatureValues<FeatureType>(feature))); +} + +template <typename ValueType> +void AppendFeatureValues(std::initializer_list<ValueType> container, + Feature* feature) { + AppendFeatureValues(container.begin(), container.end(), feature); } -// Copies all elements from the container into a feature. template <typename ContainerType> -void AppendFeatureValues(const ContainerType& container, const string& name, - Example* example) { +void AppendFeatureValues(const ContainerType& container, Feature* feature) { using IteratorType = typename ContainerType::const_iterator; - AppendFeatureValues<IteratorType>(container.begin(), container.end(), name, - example); + AppendFeatureValues<IteratorType>(container.begin(), container.end(), + feature); } -// Copies all elements from the initializer list into a feature. -template <typename ValueType> +// Copies elements from the range, defined by [first, last) into the feature +// obtainable from the (proto, key) combination. +template <typename IteratorType, typename ProtoType> +void AppendFeatureValues(IteratorType first, IteratorType last, + const string& key, ProtoType* proto) { + AppendFeatureValues(first, last, GetFeature(key, GetFeatures(proto))); +} + +// Copies all elements from the container into a feature. +template <typename ContainerType, typename ProtoType> +void AppendFeatureValues(const ContainerType& container, const string& key, + ProtoType* proto) { + using IteratorType = typename ContainerType::const_iterator; + AppendFeatureValues<IteratorType>(container.begin(), container.end(), key, + proto); +} + +// Copies all elements from the initializer list into a Feature contained by +// Features or Example proto. +template <typename ValueType, typename ProtoType> void AppendFeatureValues(std::initializer_list<ValueType> container, - const string& name, Example* example) { + const string& key, ProtoType* proto) { using IteratorType = typename std::initializer_list<ValueType>::const_iterator; - AppendFeatureValues<IteratorType>(container.begin(), container.end(), name, - example); + AppendFeatureValues<IteratorType>(container.begin(), container.end(), key, + proto); } -template <> -bool ExampleHasFeature<protobuf_int64>(const string& name, - const Example& example); - -template <> -bool ExampleHasFeature<float>(const string& name, const Example& example); - -template <> -bool ExampleHasFeature<string>(const string& name, const Example& example); - -template <> -const protobuf::RepeatedField<protobuf_int64>& GetFeatureValues<protobuf_int64>( - const string& name, const Example& example); - -template <> -protobuf::RepeatedField<protobuf_int64>* GetFeatureValues<protobuf_int64>( - const string& name, Example* example); - -template <> -const protobuf::RepeatedField<float>& GetFeatureValues<float>( - const string& name, const Example& example); - -template <> -protobuf::RepeatedField<float>* GetFeatureValues<float>(const string& name, - Example* example); - -template <> -const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>( - const string& name, const Example& example); +// Returns true if a feature with the specified key belongs to the Features. +// The template parameter pack accepts zero or one template argument - which +// is FeatureType. If the FeatureType not specified (zero template arguments) +// the function will not check the feature type. Otherwise it will return false +// if the feature has a wrong type. +template <typename... FeatureType> +bool HasFeature(const string& key, const Features& features); + +// Returns true if a feature with the specified key belongs to the Example. +// Doesn't check feature type if used without FeatureType, otherwise the +// specialized versions return false if the feature has a wrong type. +template <typename... FeatureType> +bool HasFeature(const string& key, const Example& example) { + return HasFeature<FeatureType...>(key, GetFeatures(example)); +}; -template <> -protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(const string& name, - Example* example); +// DEPRECATED: use HasFeature instead. +// TODO(gorban): update all clients in a followup CL. +template <typename... FeatureType> +bool ExampleHasFeature(const string& key, const Example& example) { + return HasFeature<FeatureType...>(key, example); +} } // namespace tensorflow #endif // TENSORFLOW_EXAMPLE_FEATURE_H_ diff --git a/tensorflow/core/example/feature_util_test.cc b/tensorflow/core/example/feature_util_test.cc index eb7b90af1b..cd32dee306 100644 --- a/tensorflow/core/example/feature_util_test.cc +++ b/tensorflow/core/example/feature_util_test.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - #include "tensorflow/core/example/feature_util.h" #include <vector> @@ -38,6 +37,16 @@ TEST(GetFeatureValuesInt64Test, ReadsASingleValue) { EXPECT_EQ(42, tag.Get(0)); } +TEST(GetFeatureValuesInt64Test, ReadsASingleValueFromFeature) { + Feature feature; + feature.mutable_int64_list()->add_value(42); + + auto values = GetFeatureValues<protobuf_int64>(feature); + + ASSERT_EQ(1, values.size()); + EXPECT_EQ(42, values.Get(0)); +} + TEST(GetFeatureValuesInt64Test, WritesASingleValue) { Example example; @@ -48,25 +57,33 @@ TEST(GetFeatureValuesInt64Test, WritesASingleValue) { EXPECT_EQ(42, example.features().feature().at("tag").int64_list().value(0)); } +TEST(GetFeatureValuesInt64Test, WritesASingleValueToFeature) { + Feature feature; + + GetFeatureValues<protobuf_int64>(&feature)->Add(42); + + ASSERT_EQ(1, feature.int64_list().value_size()); + EXPECT_EQ(42, feature.int64_list().value(0)); +} + TEST(GetFeatureValuesInt64Test, CheckUntypedFieldExistence) { Example example; - - EXPECT_FALSE(ExampleHasFeature("tag", example)); + ASSERT_FALSE(HasFeature("tag", example)); GetFeatureValues<protobuf_int64>("tag", &example)->Add(0); - EXPECT_TRUE(ExampleHasFeature("tag", example)); + EXPECT_TRUE(HasFeature("tag", example)); } TEST(GetFeatureValuesInt64Test, CheckTypedFieldExistence) { Example example; GetFeatureValues<float>("tag", &example)->Add(3.14); - ASSERT_FALSE(ExampleHasFeature<protobuf_int64>("tag", example)); + ASSERT_FALSE(HasFeature<protobuf_int64>("tag", example)); GetFeatureValues<protobuf_int64>("tag", &example)->Add(42); - EXPECT_TRUE(ExampleHasFeature<protobuf_int64>("tag", example)); + EXPECT_TRUE(HasFeature<protobuf_int64>("tag", example)); auto tag_ro = GetFeatureValues<protobuf_int64>("tag", example); ASSERT_EQ(1, tag_ro.size()); EXPECT_EQ(42, tag_ro.Get(0)); @@ -87,6 +104,16 @@ TEST(GetFeatureValuesInt64Test, CopyIterableToAField) { EXPECT_EQ(3, tag_ro.Get(2)); } +TEST(GetFeatureValuesFloatTest, ReadsASingleValueFromFeature) { + Feature feature; + feature.mutable_float_list()->add_value(3.14); + + auto values = GetFeatureValues<float>(feature); + + ASSERT_EQ(1, values.size()); + EXPECT_NEAR(3.14, values.Get(0), kTolerance); +} + TEST(GetFeatureValuesFloatTest, ReadsASingleValue) { Example example; (*example.mutable_features()->mutable_feature())["tag"] @@ -99,6 +126,15 @@ TEST(GetFeatureValuesFloatTest, ReadsASingleValue) { EXPECT_NEAR(3.14, tag.Get(0), kTolerance); } +TEST(GetFeatureValuesFloatTest, WritesASingleValueToFeature) { + Feature feature; + + GetFeatureValues<float>(&feature)->Add(3.14); + + ASSERT_EQ(1, feature.float_list().value_size()); + EXPECT_NEAR(3.14, feature.float_list().value(0), kTolerance); +} + TEST(GetFeatureValuesFloatTest, WritesASingleValue) { Example example; @@ -115,6 +151,20 @@ TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistence) { Example example; GetFeatureValues<protobuf_int64>("tag", &example)->Add(42); + ASSERT_FALSE(HasFeature<float>("tag", example)); + + GetFeatureValues<float>("tag", &example)->Add(3.14); + + EXPECT_TRUE(HasFeature<float>("tag", example)); + auto tag_ro = GetFeatureValues<float>("tag", example); + ASSERT_EQ(1, tag_ro.size()); + EXPECT_NEAR(3.14, tag_ro.Get(0), kTolerance); +} + +TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistenceForDeprecatedMethod) { + Example example; + + GetFeatureValues<protobuf_int64>("tag", &example)->Add(42); ASSERT_FALSE(ExampleHasFeature<float>("tag", example)); GetFeatureValues<float>("tag", &example)->Add(3.14); @@ -125,6 +175,16 @@ TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistence) { EXPECT_NEAR(3.14, tag_ro.Get(0), kTolerance); } +TEST(GetFeatureValuesStringTest, ReadsASingleValueFromFeature) { + Feature feature; + feature.mutable_bytes_list()->add_value("FOO"); + + auto values = GetFeatureValues<string>(feature); + + ASSERT_EQ(1, values.size()); + EXPECT_EQ("FOO", values.Get(0)); +} + TEST(GetFeatureValuesStringTest, ReadsASingleValue) { Example example; (*example.mutable_features()->mutable_feature())["tag"] @@ -137,6 +197,15 @@ TEST(GetFeatureValuesStringTest, ReadsASingleValue) { EXPECT_EQ("FOO", tag.Get(0)); } +TEST(GetFeatureValuesStringTest, WritesASingleValueToFeature) { + Feature feature; + + *GetFeatureValues<string>(&feature)->Add() = "FOO"; + + ASSERT_EQ(1, feature.bytes_list().value_size()); + EXPECT_EQ("FOO", feature.bytes_list().value(0)); +} + TEST(GetFeatureValuesStringTest, WritesASingleValue) { Example example; @@ -148,15 +217,15 @@ TEST(GetFeatureValuesStringTest, WritesASingleValue) { example.features().feature().at("tag").bytes_list().value(0)); } -TEST(GetFeatureValuesBytesTest, CheckTypedFieldExistence) { +TEST(GetFeatureValuesStringTest, CheckTypedFieldExistence) { Example example; GetFeatureValues<protobuf_int64>("tag", &example)->Add(42); - ASSERT_FALSE(ExampleHasFeature<string>("tag", example)); + ASSERT_FALSE(HasFeature<string>("tag", example)); *GetFeatureValues<string>("tag", &example)->Add() = "FOO"; - EXPECT_TRUE(ExampleHasFeature<string>("tag", example)); + EXPECT_TRUE(HasFeature<string>("tag", example)); auto tag_ro = GetFeatureValues<string>("tag", example); ASSERT_EQ(1, tag_ro.size()); EXPECT_EQ("FOO", tag_ro.Get(0)); @@ -228,5 +297,146 @@ TEST(AppendFeatureValuesTest, StringVariablesUsingInitializerList) { EXPECT_EQ("BAZ", tag_ro.Get(2)); } +TEST(SequenceExampleTest, ReadsASingleValueFromContext) { + SequenceExample se; + (*se.mutable_context()->mutable_feature())["tag"] + .mutable_int64_list() + ->add_value(42); + + auto values = GetFeatureValues<protobuf_int64>("tag", se.context()); + + ASSERT_EQ(1, values.size()); + EXPECT_EQ(42, values.Get(0)); +} + +TEST(SequenceExampleTest, WritesASingleValueToContext) { + SequenceExample se; + + GetFeatureValues<protobuf_int64>("tag", se.mutable_context())->Add(42); + + ASSERT_EQ(1, se.context().feature().at("tag").int64_list().value_size()); + EXPECT_EQ(42, se.context().feature().at("tag").int64_list().value(0)); +} + +TEST(SequenceExampleTest, AppendFeatureValuesToContextSingleArg) { + SequenceExample se; + + AppendFeatureValues({1.1, 2.2, 3.3}, "tag", se.mutable_context()); + + auto tag_ro = GetFeatureValues<float>("tag", se.context()); + ASSERT_EQ(3, tag_ro.size()); + EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance); + EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance); + EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance); +} + +TEST(SequenceExampleTest, CheckTypedFieldExistence) { + SequenceExample se; + + GetFeatureValues<float>("tag", se.mutable_context())->Add(3.14); + ASSERT_FALSE(HasFeature<protobuf_int64>("tag", se.context())); + + GetFeatureValues<protobuf_int64>("tag", se.mutable_context())->Add(42); + + EXPECT_TRUE(HasFeature<protobuf_int64>("tag", se.context())); + auto tag_ro = GetFeatureValues<protobuf_int64>("tag", se.context()); + ASSERT_EQ(1, tag_ro.size()); + EXPECT_EQ(42, tag_ro.Get(0)); +} + +TEST(SequenceExampleTest, ReturnsExistingFeatureLists) { + SequenceExample se; + (*se.mutable_feature_lists()->mutable_feature_list())["tag"] + .mutable_feature() + ->Add(); + + auto feature = GetFeatureList("tag", se); + + ASSERT_EQ(1, feature.size()); +} + +TEST(SequenceExampleTest, CreatesNewFeatureLists) { + SequenceExample se; + + GetFeatureList("tag", &se)->Add(); + + EXPECT_EQ(1, se.feature_lists().feature_list().at("tag").feature_size()); +} + +TEST(SequenceExampleTest, CheckFeatureListExistence) { + SequenceExample se; + ASSERT_FALSE(HasFeatureList("tag", se)); + + GetFeatureList("tag", &se)->Add(); + + ASSERT_TRUE(HasFeatureList("tag", se)); +} + +TEST(SequenceExampleTest, AppendFeatureValuesWithInitializerList) { + SequenceExample se; + + AppendFeatureValues({1, 2, 3}, "ids", se.mutable_context()); + AppendFeatureValues({"cam1-0", "cam2-0"}, + GetFeatureList("images", &se)->Add()); + AppendFeatureValues({"cam1-1", "cam2-2"}, + GetFeatureList("images", &se)->Add()); + + EXPECT_EQ(se.DebugString(), + "context {\n" + " feature {\n" + " key: \"ids\"\n" + " value {\n" + " int64_list {\n" + " value: 1\n" + " value: 2\n" + " value: 3\n" + " }\n" + " }\n" + " }\n" + "}\n" + "feature_lists {\n" + " feature_list {\n" + " key: \"images\"\n" + " value {\n" + " feature {\n" + " bytes_list {\n" + " value: \"cam1-0\"\n" + " value: \"cam2-0\"\n" + " }\n" + " }\n" + " feature {\n" + " bytes_list {\n" + " value: \"cam1-1\"\n" + " value: \"cam2-2\"\n" + " }\n" + " }\n" + " }\n" + " }\n" + "}\n"); +} + +TEST(SequenceExampleTest, AppendFeatureValuesWithVectors) { + SequenceExample se; + + std::vector<float> readings{1.0, 2.5, 5.0}; + AppendFeatureValues(readings, GetFeatureList("movie_ratings", &se)->Add()); + + EXPECT_EQ(se.DebugString(), + "feature_lists {\n" + " feature_list {\n" + " key: \"movie_ratings\"\n" + " value {\n" + " feature {\n" + " float_list {\n" + " value: 1\n" + " value: 2.5\n" + " value: 5\n" + " }\n" + " }\n" + " }\n" + " }\n" + "}\n"); +} + } // namespace } // namespace tensorflow |