diff options
-rw-r--r-- | tensorflow/core/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/core/example/feature_util.cc | 93 | ||||
-rw-r--r-- | tensorflow/core/example/feature_util.h | 213 | ||||
-rw-r--r-- | tensorflow/core/example/feature_util_test.cc | 214 |
4 files changed, 525 insertions, 0 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index de1fd9f7c3..73f4e91fa8 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -15,6 +15,7 @@ # ":framework" - exports the public non-test headers for: # util/: General low-level TensorFlow-specific libraries # framework/: Support for adding new ops & kernels +# example/: Wrappers to simplify access to Example proto # ":ops" - defines TensorFlow ops, but no implementations / kernels # ops/: Standard ops # user_ops/: User-supplied ops @@ -241,6 +242,7 @@ cc_library( tf_cuda_library( name = "framework", hdrs = [ + "example/feature_util.h", "framework/allocator.h", "framework/attr_value_util.h", "framework/bfloat16.h", @@ -805,6 +807,8 @@ tf_cuda_library( name = "framework_internal", srcs = glob( [ + "example/**/*.h", + "example/**/*.cc", "framework/**/*.h", "framework/**/*.cc", "public/version.h", @@ -1209,6 +1213,7 @@ tf_cc_tests( [ "client/**/*_test.cc", "common_runtime/**/*_test.cc", + "example/**/*_test.cc", "framework/**/*_test.cc", "graph/**/*_test.cc", "util/**/*_test.cc", diff --git a/tensorflow/core/example/feature_util.cc b/tensorflow/core/example/feature_util.cc new file mode 100644 index 0000000000..863a56ba40 --- /dev/null +++ b/tensorflow/core/example/feature_util.cc @@ -0,0 +1,93 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/example/feature_util.h" + +namespace tensorflow { + +namespace internal { + +::tensorflow::Feature& ExampleFeature(const string& name, + ::tensorflow::Example* example) { + ::tensorflow::Features* features = example->mutable_features(); + return (*features->mutable_feature())[name]; +} + +} // namespace internal + +template <> +bool ExampleHasFeature<int64>(const string& name, const Example& example) { + auto it = example.features().feature().find(name); + return (it != example.features().feature().end()) && + (it->second.kind_case() == Feature::KindCase::kInt64List); +} + +template <> +bool ExampleHasFeature<float>(const string& name, const Example& example) { + auto it = example.features().feature().find(name); + return (it != example.features().feature().end()) && + (it->second.kind_case() == Feature::KindCase::kFloatList); +} + +template <> +bool ExampleHasFeature<string>(const string& name, const Example& example) { + auto it = example.features().feature().find(name); + return (it != example.features().feature().end()) && + (it->second.kind_case() == Feature::KindCase::kBytesList); +} + +template <> +const protobuf::RepeatedField<int64>& GetFeatureValues<int64>( + const string& name, const Example& example) { + return example.features().feature().at(name).int64_list().value(); +} + +template <> +protobuf::RepeatedField<int64>* GetFeatureValues<int64>(const string& name, + Example* example) { + return internal::ExampleFeature(name, example) + .mutable_int64_list() + ->mutable_value(); +} + +template <> +const protobuf::RepeatedField<float>& GetFeatureValues<float>( + const string& name, const Example& example) { + return example.features().feature().at(name).float_list().value(); +} + +template <> +protobuf::RepeatedField<float>* GetFeatureValues<float>(const string& name, + Example* example) { + return internal::ExampleFeature(name, example) + 
.mutable_float_list() + ->mutable_value(); +} + +template <> +const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>( + const string& name, const Example& example) { + return example.features().feature().at(name).bytes_list().value(); +} + +template <> +protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(const string& name, + Example* example) { + return internal::ExampleFeature(name, example) + .mutable_bytes_list() + ->mutable_value(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h new file mode 100644 index 0000000000..972bf4e885 --- /dev/null +++ b/tensorflow/core/example/feature_util.h @@ -0,0 +1,213 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A set of lightweight wrappers which simplify access to Example features. +// +// Tensorflow Example proto uses associative maps on top of oneof fields. +// So accessing feature values is not very convenient. +// +// For example, to read a first value of integer feature "tag": +// int id = example.features().feature().at("tag").int64_list().value(0) +// +// to add a value: +// auto features = example->mutable_features(); +// (*features->mutable_feature())["tag"].mutable_int64_list()->add_value(id) +// +// For float features you have to use float_list, for string - bytes_list. 
+// +// To do the same with this library: +// int id = GetFeatureValues<int64>("tag", example).Get(0); +// GetFeatureValues<int64>("tag", &example)->Add(id); +// +// Modification of bytes features is slightly different: +// auto tag = GetFeatureValues<string>("tag", &example); +// *tag->Add() = "lorem ipsum"; +// +// To copy multiple values into a feature: +// AppendFeatureValues({1,2,3}, "tag", &example); +// +// GetFeatureValues gives you access to the underlying data - a RepeatedField +// object (RepeatedPtrField for a bytes list). Refer to the documentation of +// RepeatedField for the full list of supported methods. +// +// NOTE: It is also important to mention that due to the nature of oneof proto +// fields, setting a feature of one type automatically clears all values stored +// as another type with the same feature name. + +#ifndef TENSORFLOW_EXAMPLE_FEATURE_H_ +#define TENSORFLOW_EXAMPLE_FEATURE_H_ + +#include <iterator> +#include <type_traits> + +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +namespace internal { + +// Returns a reference to a feature corresponding to the name. +// Note: it will create a new Feature if it is missing in the example. +::tensorflow::Feature& ExampleFeature(const string& name, + ::tensorflow::Example* example); + +// Specializations of RepeatedFieldTrait define a type of RepeatedField +// corresponding to a selected feature type. 
+template <typename FeatureType> +struct RepeatedFieldTrait; + +template <> +struct RepeatedFieldTrait<int64> { + using Type = protobuf::RepeatedField<int64>; +}; + +template <> +struct RepeatedFieldTrait<float> { + using Type = protobuf::RepeatedField<float>; +}; + +template <> +struct RepeatedFieldTrait<string> { + using Type = protobuf::RepeatedPtrField<string>; +}; + +// Specializations of FeatureTrait define a type of feature corresponding to a +// selected value type. +template <typename ValueType, class Enable = void> +struct FeatureTrait; + +template <typename ValueType> +struct FeatureTrait<ValueType, typename std::enable_if< + std::is_integral<ValueType>::value>::type> { + using Type = int64; +}; + +template <typename ValueType> +struct FeatureTrait< + ValueType, + typename std::enable_if<std::is_floating_point<ValueType>::value>::type> { + using Type = float; +}; + +template <typename T> +struct is_string + : public std::integral_constant< + bool, + std::is_same<char*, typename std::decay<T>::type>::value || + std::is_same<const char*, typename std::decay<T>::type>::value> { +}; + +template <> +struct is_string<std::string> : std::true_type {}; + +template <> +struct is_string<::tensorflow::StringPiece> : std::true_type {}; + +template <typename ValueType> +struct FeatureTrait< + ValueType, typename std::enable_if<is_string<ValueType>::value>::type> { + using Type = string; +}; + +} // namespace internal + +// Returns true if feature with the specified name belongs to the example proto. +// Doesn't check feature type. Note that specialized versions return false if +// the feature has a wrong type. +template <typename FeatureType = void> +bool ExampleHasFeature(const string& name, const Example& example) { + return example.features().feature().find(name) != + example.features().feature().end(); +} + +// Base declaration of a family of template functions to return a read only +// repeated field corresponding to a feature with the specified name. 
+template <typename FeatureType> +const typename internal::RepeatedFieldTrait<FeatureType>::Type& +GetFeatureValues(const string& name, const Example& example); + +// Base declaration of a family of template functions to return a mutable +// repeated field corresponding to a feature with the specified name. +template <typename FeatureType> +typename internal::RepeatedFieldTrait<FeatureType>::Type* GetFeatureValues( + const string& name, Example* example); + +// Copies elements from the range, defined by [first, last) into a feature. +template <typename IteratorType> +void AppendFeatureValues(IteratorType first, IteratorType last, + const string& name, Example* example) { + using FeatureType = typename internal::FeatureTrait< + typename std::iterator_traits<IteratorType>::value_type>::Type; + std::copy(first, last, protobuf::RepeatedFieldBackInserter( + GetFeatureValues<FeatureType>(name, example))); +} + +// Copies all elements from the container into a feature. +template <typename ContainerType> +void AppendFeatureValues(const ContainerType& container, const string& name, + Example* example) { + using IteratorType = typename ContainerType::const_iterator; + AppendFeatureValues<IteratorType>(container.begin(), container.end(), name, + example); +} + +// Copies all elements from the initializer list into a feature. 
+template <typename ValueType> +void AppendFeatureValues(std::initializer_list<ValueType> container, + const string& name, Example* example) { + using IteratorType = + typename std::initializer_list<ValueType>::const_iterator; + AppendFeatureValues<IteratorType>(container.begin(), container.end(), name, + example); +} + +template <> +bool ExampleHasFeature<int64>(const string& name, const Example& example); + +template <> +bool ExampleHasFeature<float>(const string& name, const Example& example); + +template <> +bool ExampleHasFeature<string>(const string& name, const Example& example); + +template <> +const protobuf::RepeatedField<int64>& GetFeatureValues<int64>( + const string& name, const Example& example); + +template <> +protobuf::RepeatedField<int64>* GetFeatureValues<int64>(const string& name, + Example* example); + +template <> +const protobuf::RepeatedField<float>& GetFeatureValues<float>( + const string& name, const Example& example); + +template <> +protobuf::RepeatedField<float>* GetFeatureValues<float>(const string& name, + Example* example); + +template <> +const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>( + const string& name, const Example& example); + +template <> +protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(const string& name, + Example* example); + +} // namespace tensorflow +#endif // TENSORFLOW_EXAMPLE_FEATURE_H_ diff --git a/tensorflow/core/example/feature_util_test.cc b/tensorflow/core/example/feature_util_test.cc new file mode 100644 index 0000000000..bcee0fb587 --- /dev/null +++ b/tensorflow/core/example/feature_util_test.cc @@ -0,0 +1,214 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/example/feature_util.h" + +#include <vector> + +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +const float kTolerance = 1e-5; + +TEST(GetFeatureValuesInt64Test, ReadsASingleValue) { + Example example; + (*example.mutable_features()->mutable_feature())["tag"] + .mutable_int64_list() + ->add_value(42); + + auto tag = GetFeatureValues<int64>("tag", example); + + ASSERT_EQ(1, tag.size()); + EXPECT_EQ(42, tag.Get(0)); +} + +TEST(GetFeatureValuesInt64Test, WritesASingleValue) { + Example example; + + GetFeatureValues<int64>("tag", &example)->Add(42); + + ASSERT_EQ(1, + example.features().feature().at("tag").int64_list().value_size()); + EXPECT_EQ(42, example.features().feature().at("tag").int64_list().value(0)); +} + +TEST(GetFeatureValuesInt64Test, CheckUntypedFieldExistence) { + Example example; + + EXPECT_FALSE(ExampleHasFeature("tag", example)); + + GetFeatureValues<int64>("tag", &example)->Add(0); + + EXPECT_TRUE(ExampleHasFeature("tag", example)); +} + +TEST(GetFeatureValuesInt64Test, CheckTypedFieldExistence) { + Example example; + + GetFeatureValues<float>("tag", &example)->Add(3.14); + ASSERT_FALSE(ExampleHasFeature<int64>("tag", example)); + + GetFeatureValues<int64>("tag", &example)->Add(42); + + EXPECT_TRUE(ExampleHasFeature<int64>("tag", example)); + auto tag_ro = GetFeatureValues<int64>("tag", example); + ASSERT_EQ(1, tag_ro.size()); + EXPECT_EQ(42, tag_ro.Get(0)); +} + 
+TEST(GetFeatureValuesInt64Test, CopyIterableToAField) { + Example example; + std::vector<int> values{1, 2, 3}; + + std::copy(values.begin(), values.end(), + protobuf::RepeatedFieldBackInserter( + GetFeatureValues<int64>("tag", &example))); + + auto tag_ro = GetFeatureValues<int64>("tag", example); + ASSERT_EQ(3, tag_ro.size()); + EXPECT_EQ(1, tag_ro.Get(0)); + EXPECT_EQ(2, tag_ro.Get(1)); + EXPECT_EQ(3, tag_ro.Get(2)); +} + +TEST(GetFeatureValuesFloatTest, ReadsASingleValue) { + Example example; + (*example.mutable_features()->mutable_feature())["tag"] + .mutable_float_list() + ->add_value(3.14); + + auto tag = GetFeatureValues<float>("tag", example); + + ASSERT_EQ(1, tag.size()); + EXPECT_NEAR(3.14, tag.Get(0), kTolerance); +} + +TEST(GetFeatureValuesFloatTest, WritesASingleValue) { + Example example; + + GetFeatureValues<float>("tag", &example)->Add(3.14); + + ASSERT_EQ(1, + example.features().feature().at("tag").float_list().value_size()); + EXPECT_NEAR(3.14, + example.features().feature().at("tag").float_list().value(0), + kTolerance); +} + +TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistence) { + Example example; + + GetFeatureValues<int64>("tag", &example)->Add(42); + ASSERT_FALSE(ExampleHasFeature<float>("tag", example)); + + GetFeatureValues<float>("tag", &example)->Add(3.14); + + EXPECT_TRUE(ExampleHasFeature<float>("tag", example)); + auto tag_ro = GetFeatureValues<float>("tag", example); + ASSERT_EQ(1, tag_ro.size()); + EXPECT_NEAR(3.14, tag_ro.Get(0), kTolerance); +} + +TEST(GetFeatureValuesStringTest, ReadsASingleValue) { + Example example; + (*example.mutable_features()->mutable_feature())["tag"] + .mutable_bytes_list() + ->add_value("FOO"); + + auto tag = GetFeatureValues<string>("tag", example); + + ASSERT_EQ(1, tag.size()); + EXPECT_EQ("FOO", tag.Get(0)); +} + +TEST(GetFeatureValuesStringTest, WritesASingleValue) { + Example example; + + *GetFeatureValues<string>("tag", &example)->Add() = "FOO"; + + ASSERT_EQ(1, + 
example.features().feature().at("tag").bytes_list().value_size()); + EXPECT_EQ("FOO", + example.features().feature().at("tag").bytes_list().value(0)); +} + +TEST(GetFeatureValuesBytesTest, CheckTypedFieldExistence) { + Example example; + + GetFeatureValues<int64>("tag", &example)->Add(42); + ASSERT_FALSE(ExampleHasFeature<string>("tag", example)); + + *GetFeatureValues<string>("tag", &example)->Add() = "FOO"; + + EXPECT_TRUE(ExampleHasFeature<string>("tag", example)); + auto tag_ro = GetFeatureValues<string>("tag", example); + ASSERT_EQ(1, tag_ro.size()); + EXPECT_EQ("FOO", tag_ro.Get(0)); +} + +TEST(AppendFeatureValuesTest, FloatValuesFromContainer) { + Example example; + + std::vector<double> values{1.1, 2.2, 3.3}; + AppendFeatureValues(values, "tag", &example); + + auto tag_ro = GetFeatureValues<float>("tag", example); + ASSERT_EQ(3, tag_ro.size()); + EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance); + EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance); + EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance); +} + +TEST(AppendFeatureValuesTest, FloatValuesUsingInitializerList) { + Example example; + + AppendFeatureValues({1.1, 2.2, 3.3}, "tag", &example); + + auto tag_ro = GetFeatureValues<float>("tag", example); + ASSERT_EQ(3, tag_ro.size()); + EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance); + EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance); + EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance); +} + +TEST(AppendFeatureValuesTest, Int64ValuesUsingInitializerList) { + Example example; + + AppendFeatureValues({1, 2, 3}, "tag", &example); + + auto tag_ro = GetFeatureValues<int64>("tag", example); + ASSERT_EQ(3, tag_ro.size()); + EXPECT_EQ(1, tag_ro.Get(0)); + EXPECT_EQ(2, tag_ro.Get(1)); + EXPECT_EQ(3, tag_ro.Get(2)); +} + +TEST(AppendFeatureValuesTest, StringValuesUsingInitializerList) { + Example example; + + AppendFeatureValues({"FOO", "BAR", "BAZ"}, "tag", &example); + + auto tag_ro = GetFeatureValues<string>("tag", example); + ASSERT_EQ(3, tag_ro.size()); + EXPECT_EQ("FOO", tag_ro.Get(0)); + 
EXPECT_EQ("BAR", tag_ro.Get(1)); + EXPECT_EQ("BAZ", tag_ro.Get(2)); +} + +} // namespace +} // namespace tensorflow |