aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-05-06 01:58:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-06 03:01:31 -0700
commit357b19ee7f39d46b32639847e260dac2312b3da5 (patch)
tree48b52ef4992183a39914ed30be16e69efddd3b3c
parentceb04a1b806563615d68d4617822f057405b3bfb (diff)
A set of lightweight wrappers to simplify access to Example proto fields in C++.
Change: 121659770
-rw-r--r--tensorflow/core/BUILD5
-rw-r--r--tensorflow/core/example/feature_util.cc93
-rw-r--r--tensorflow/core/example/feature_util.h213
-rw-r--r--tensorflow/core/example/feature_util_test.cc214
4 files changed, 525 insertions, 0 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index de1fd9f7c3..73f4e91fa8 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -15,6 +15,7 @@
# ":framework" - exports the public non-test headers for:
# util/: General low-level TensorFlow-specific libraries
# framework/: Support for adding new ops & kernels
+# example/: Wrappers to simplify access to Example proto
# ":ops" - defines TensorFlow ops, but no implementations / kernels
# ops/: Standard ops
# user_ops/: User-supplied ops
@@ -241,6 +242,7 @@ cc_library(
tf_cuda_library(
name = "framework",
hdrs = [
+ "example/feature_util.h",
"framework/allocator.h",
"framework/attr_value_util.h",
"framework/bfloat16.h",
@@ -805,6 +807,8 @@ tf_cuda_library(
name = "framework_internal",
srcs = glob(
[
+ "example/**/*.h",
+ "example/**/*.cc",
"framework/**/*.h",
"framework/**/*.cc",
"public/version.h",
@@ -1209,6 +1213,7 @@ tf_cc_tests(
[
"client/**/*_test.cc",
"common_runtime/**/*_test.cc",
+ "example/**/*_test.cc",
"framework/**/*_test.cc",
"graph/**/*_test.cc",
"util/**/*_test.cc",
diff --git a/tensorflow/core/example/feature_util.cc b/tensorflow/core/example/feature_util.cc
new file mode 100644
index 0000000000..863a56ba40
--- /dev/null
+++ b/tensorflow/core/example/feature_util.cc
@@ -0,0 +1,93 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/example/feature_util.h"
+
+namespace tensorflow {
+
+namespace internal {
+
+::tensorflow::Feature& ExampleFeature(const string& name,
+ ::tensorflow::Example* example) {
+ ::tensorflow::Features* features = example->mutable_features();
+ return (*features->mutable_feature())[name];
+}
+
+} // namespace internal
+
+template <>
+bool ExampleHasFeature<int64>(const string& name, const Example& example) {
+ auto it = example.features().feature().find(name);
+ return (it != example.features().feature().end()) &&
+ (it->second.kind_case() == Feature::KindCase::kInt64List);
+}
+
+template <>
+bool ExampleHasFeature<float>(const string& name, const Example& example) {
+ auto it = example.features().feature().find(name);
+ return (it != example.features().feature().end()) &&
+ (it->second.kind_case() == Feature::KindCase::kFloatList);
+}
+
+template <>
+bool ExampleHasFeature<string>(const string& name, const Example& example) {
+ auto it = example.features().feature().find(name);
+ return (it != example.features().feature().end()) &&
+ (it->second.kind_case() == Feature::KindCase::kBytesList);
+}
+
+template <>
+const protobuf::RepeatedField<int64>& GetFeatureValues<int64>(
+ const string& name, const Example& example) {
+ return example.features().feature().at(name).int64_list().value();
+}
+
+template <>
+protobuf::RepeatedField<int64>* GetFeatureValues<int64>(const string& name,
+ Example* example) {
+ return internal::ExampleFeature(name, example)
+ .mutable_int64_list()
+ ->mutable_value();
+}
+
+template <>
+const protobuf::RepeatedField<float>& GetFeatureValues<float>(
+ const string& name, const Example& example) {
+ return example.features().feature().at(name).float_list().value();
+}
+
+template <>
+protobuf::RepeatedField<float>* GetFeatureValues<float>(const string& name,
+ Example* example) {
+ return internal::ExampleFeature(name, example)
+ .mutable_float_list()
+ ->mutable_value();
+}
+
+template <>
+const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>(
+ const string& name, const Example& example) {
+ return example.features().feature().at(name).bytes_list().value();
+}
+
+template <>
+protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(const string& name,
+ Example* example) {
+ return internal::ExampleFeature(name, example)
+ .mutable_bytes_list()
+ ->mutable_value();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h
new file mode 100644
index 0000000000..972bf4e885
--- /dev/null
+++ b/tensorflow/core/example/feature_util.h
@@ -0,0 +1,213 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// A set of lightweight wrappers which simplify access to Example features.
+//
+// Tensorflow Example proto uses associative maps on top of oneof fields.
+// So accessing feature values is not very convenient.
+//
+// For example, to read a first value of integer feature "tag":
+// int id = example.features().feature().at("tag").int64_list().value(0)
+//
+// to add a value:
+// auto features = example->mutable_features();
+// (*features->mutable_feature())["tag"].mutable_int64_list()->add_value(id)
+//
+// For float features you have to use float_list, for string - bytes_list.
+//
+// To do the same with this library:
+// int id = GetFeatureValues<int64>("tag", example).Get(0);
+// GetFeatureValues<int64>("tag", &example)->Add(id);
+//
+// Modification of bytes features is slightly different:
+// auto tag = GetFeatureValues<string>("tag", example);
+// *tag->Add() = "lorem ipsum";
+//
+// To copy multiple values into a feature:
+// AppendFeatureValues({1,2,3}, "tag", &example);
+//
+// GetFeatureValues gives you access to underlying data - RepeatedField object
+// (RepeatedPtrField for byte list). So refer to its documentation of
+// RepeatedField for full list of supported methods.
+//
+// NOTE: It is also important to mention that due to the nature of oneof proto
+// fields setting a feature of one type automatically clears all values stored
+// as another type with the same feature name.
+
+#ifndef TENSORFLOW_EXAMPLE_FEATURE_H_
+#define TENSORFLOW_EXAMPLE_FEATURE_H_
+
+#include <iterator>
+#include <type_traits>
+
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/example/feature.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+namespace internal {
+
+// Returns a reference to a feature corresponding to the name.
+// Note: it will create a new Feature if it is missing in the example.
+::tensorflow::Feature& ExampleFeature(const string& name,
+ ::tensorflow::Example* example);
+
+// Specializations of RepeatedFieldTrait define a type of RepeatedField
+// corresponding to a selected feature type.
+template <typename FeatureType>
+struct RepeatedFieldTrait;
+
+template <>
+struct RepeatedFieldTrait<int64> {
+ using Type = protobuf::RepeatedField<int64>;
+};
+
+template <>
+struct RepeatedFieldTrait<float> {
+ using Type = protobuf::RepeatedField<float>;
+};
+
+template <>
+struct RepeatedFieldTrait<string> {
+ using Type = protobuf::RepeatedPtrField<string>;
+};
+
+// Specializations of FeatureTrait define a type of feature corresponding to a
+// selected value type.
+template <typename ValueType, class Enable = void>
+struct FeatureTrait;
+
+template <typename ValueType>
+struct FeatureTrait<ValueType, typename std::enable_if<
+ std::is_integral<ValueType>::value>::type> {
+ using Type = int64;
+};
+
+template <typename ValueType>
+struct FeatureTrait<
+ ValueType,
+ typename std::enable_if<std::is_floating_point<ValueType>::value>::type> {
+ using Type = float;
+};
+
+template <typename T>
+struct is_string
+ : public std::integral_constant<
+ bool,
+ std::is_same<char*, typename std::decay<T>::type>::value ||
+ std::is_same<const char*, typename std::decay<T>::type>::value> {
+};
+
+template <>
+struct is_string<std::string> : std::true_type {};
+
+template <>
+struct is_string<::tensorflow::StringPiece> : std::true_type {};
+
+template <typename ValueType>
+struct FeatureTrait<
+ ValueType, typename std::enable_if<is_string<ValueType>::value>::type> {
+ using Type = string;
+};
+
+} // namespace internal
+
+// Returns true if feature with the specified name belongs to the example proto.
+// Doesn't check feature type. Note that specialized versions return false if
+// the feature has a wrong type.
+template <typename FeatureType = void>
+bool ExampleHasFeature(const string& name, const Example& example) {
+ return example.features().feature().find(name) !=
+ example.features().feature().end();
+}
+
+// Base declaration of a family of template functions to return a read only
+// repeated field corresponding to a feature with the specified name.
+template <typename FeatureType>
+const typename internal::RepeatedFieldTrait<FeatureType>::Type&
+GetFeatureValues(const string& name, const Example& example);
+
+// Base declaration of a family of template functions to return a mutable
+// repeated field corresponding to a feature with the specified name.
+template <typename FeatureType>
+typename internal::RepeatedFieldTrait<FeatureType>::Type* GetFeatureValues(
+ const string& name, Example* example);
+
+// Copies elements from the range, defined by [first, last) into a feature.
+template <typename IteratorType>
+void AppendFeatureValues(IteratorType first, IteratorType last,
+ const string& name, Example* example) {
+ using FeatureType = typename internal::FeatureTrait<
+ typename std::iterator_traits<IteratorType>::value_type>::Type;
+ std::copy(first, last, protobuf::RepeatedFieldBackInserter(
+ GetFeatureValues<FeatureType>(name, example)));
+}
+
+// Copies all elements from the container into a feature.
+template <typename ContainerType>
+void AppendFeatureValues(const ContainerType& container, const string& name,
+ Example* example) {
+ using IteratorType = typename ContainerType::const_iterator;
+ AppendFeatureValues<IteratorType>(container.begin(), container.end(), name,
+ example);
+}
+
+// Copies all elements from the initializer list into a feature.
+template <typename ValueType>
+void AppendFeatureValues(std::initializer_list<ValueType> container,
+ const string& name, Example* example) {
+ using IteratorType =
+ typename std::initializer_list<ValueType>::const_iterator;
+ AppendFeatureValues<IteratorType>(container.begin(), container.end(), name,
+ example);
+}
+
+template <>
+bool ExampleHasFeature<int64>(const string& name, const Example& example);
+
+template <>
+bool ExampleHasFeature<float>(const string& name, const Example& example);
+
+template <>
+bool ExampleHasFeature<string>(const string& name, const Example& example);
+
+template <>
+const protobuf::RepeatedField<int64>& GetFeatureValues<int64>(
+ const string& name, const Example& example);
+
+template <>
+protobuf::RepeatedField<int64>* GetFeatureValues<int64>(const string& name,
+ Example* example);
+
+template <>
+const protobuf::RepeatedField<float>& GetFeatureValues<float>(
+ const string& name, const Example& example);
+
+template <>
+protobuf::RepeatedField<float>* GetFeatureValues<float>(const string& name,
+ Example* example);
+
+template <>
+const protobuf::RepeatedPtrField<string>& GetFeatureValues<string>(
+ const string& name, const Example& example);
+
+template <>
+protobuf::RepeatedPtrField<string>* GetFeatureValues<string>(const string& name,
+ Example* example);
+
+} // namespace tensorflow
+#endif // TENSORFLOW_EXAMPLE_FEATURE_H_
diff --git a/tensorflow/core/example/feature_util_test.cc b/tensorflow/core/example/feature_util_test.cc
new file mode 100644
index 0000000000..bcee0fb587
--- /dev/null
+++ b/tensorflow/core/example/feature_util_test.cc
@@ -0,0 +1,214 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/example/feature_util.h"
+
+#include <vector>
+
+#include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+const float kTolerance = 1e-5;
+
+TEST(GetFeatureValuesInt64Test, ReadsASingleValue) {
+ Example example;
+ (*example.mutable_features()->mutable_feature())["tag"]
+ .mutable_int64_list()
+ ->add_value(42);
+
+ auto tag = GetFeatureValues<int64>("tag", example);
+
+ ASSERT_EQ(1, tag.size());
+ EXPECT_EQ(42, tag.Get(0));
+}
+
+TEST(GetFeatureValuesInt64Test, WritesASingleValue) {
+ Example example;
+
+ GetFeatureValues<int64>("tag", &example)->Add(42);
+
+ ASSERT_EQ(1,
+ example.features().feature().at("tag").int64_list().value_size());
+ EXPECT_EQ(42, example.features().feature().at("tag").int64_list().value(0));
+}
+
+TEST(GetFeatureValuesInt64Test, CheckUntypedFieldExistence) {
+ Example example;
+
+ EXPECT_FALSE(ExampleHasFeature("tag", example));
+
+ GetFeatureValues<int64>("tag", &example)->Add(0);
+
+ EXPECT_TRUE(ExampleHasFeature("tag", example));
+}
+
+TEST(GetFeatureValuesInt64Test, CheckTypedFieldExistence) {
+ Example example;
+
+ GetFeatureValues<float>("tag", &example)->Add(3.14);
+ ASSERT_FALSE(ExampleHasFeature<int64>("tag", example));
+
+ GetFeatureValues<int64>("tag", &example)->Add(42);
+
+ EXPECT_TRUE(ExampleHasFeature<int64>("tag", example));
+ auto tag_ro = GetFeatureValues<int64>("tag", example);
+ ASSERT_EQ(1, tag_ro.size());
+ EXPECT_EQ(42, tag_ro.Get(0));
+}
+
+TEST(GetFeatureValuesInt64Test, CopyIterableToAField) {
+ Example example;
+ std::vector<int> values{1, 2, 3};
+
+ std::copy(values.begin(), values.end(),
+ protobuf::RepeatedFieldBackInserter(
+ GetFeatureValues<int64>("tag", &example)));
+
+ auto tag_ro = GetFeatureValues<int64>("tag", example);
+ ASSERT_EQ(3, tag_ro.size());
+ EXPECT_EQ(1, tag_ro.Get(0));
+ EXPECT_EQ(2, tag_ro.Get(1));
+ EXPECT_EQ(3, tag_ro.Get(2));
+}
+
+TEST(GetFeatureValuesFloatTest, ReadsASingleValue) {
+ Example example;
+ (*example.mutable_features()->mutable_feature())["tag"]
+ .mutable_float_list()
+ ->add_value(3.14);
+
+ auto tag = GetFeatureValues<float>("tag", example);
+
+ ASSERT_EQ(1, tag.size());
+ EXPECT_NEAR(3.14, tag.Get(0), kTolerance);
+}
+
+TEST(GetFeatureValuesFloatTest, WritesASingleValue) {
+ Example example;
+
+ GetFeatureValues<float>("tag", &example)->Add(3.14);
+
+ ASSERT_EQ(1,
+ example.features().feature().at("tag").float_list().value_size());
+ EXPECT_NEAR(3.14,
+ example.features().feature().at("tag").float_list().value(0),
+ kTolerance);
+}
+
+TEST(GetFeatureValuesFloatTest, CheckTypedFieldExistence) {
+ Example example;
+
+ GetFeatureValues<int64>("tag", &example)->Add(42);
+ ASSERT_FALSE(ExampleHasFeature<float>("tag", example));
+
+ GetFeatureValues<float>("tag", &example)->Add(3.14);
+
+ EXPECT_TRUE(ExampleHasFeature<float>("tag", example));
+ auto tag_ro = GetFeatureValues<float>("tag", example);
+ ASSERT_EQ(1, tag_ro.size());
+ EXPECT_NEAR(3.14, tag_ro.Get(0), kTolerance);
+}
+
+TEST(GetFeatureValuesStringTest, ReadsASingleValue) {
+ Example example;
+ (*example.mutable_features()->mutable_feature())["tag"]
+ .mutable_bytes_list()
+ ->add_value("FOO");
+
+ auto tag = GetFeatureValues<string>("tag", example);
+
+ ASSERT_EQ(1, tag.size());
+ EXPECT_EQ("FOO", tag.Get(0));
+}
+
+TEST(GetFeatureValuesStringTest, WritesASingleValue) {
+ Example example;
+
+ *GetFeatureValues<string>("tag", &example)->Add() = "FOO";
+
+ ASSERT_EQ(1,
+ example.features().feature().at("tag").bytes_list().value_size());
+ EXPECT_EQ("FOO",
+ example.features().feature().at("tag").bytes_list().value(0));
+}
+
+TEST(GetFeatureValuesBytesTest, CheckTypedFieldExistence) {
+ Example example;
+
+ GetFeatureValues<int64>("tag", &example)->Add(42);
+ ASSERT_FALSE(ExampleHasFeature<string>("tag", example));
+
+ *GetFeatureValues<string>("tag", &example)->Add() = "FOO";
+
+ EXPECT_TRUE(ExampleHasFeature<string>("tag", example));
+ auto tag_ro = GetFeatureValues<string>("tag", example);
+ ASSERT_EQ(1, tag_ro.size());
+ EXPECT_EQ("FOO", tag_ro.Get(0));
+}
+
+TEST(AppendFeatureValuesTest, FloatValuesFromContainer) {
+ Example example;
+
+ std::vector<double> values{1.1, 2.2, 3.3};
+ AppendFeatureValues(values, "tag", &example);
+
+ auto tag_ro = GetFeatureValues<float>("tag", example);
+ ASSERT_EQ(3, tag_ro.size());
+ EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance);
+ EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance);
+ EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance);
+}
+
+TEST(AppendFeatureValuesTest, FloatValuesUsingInitializerList) {
+ Example example;
+
+ AppendFeatureValues({1.1, 2.2, 3.3}, "tag", &example);
+
+ auto tag_ro = GetFeatureValues<float>("tag", example);
+ ASSERT_EQ(3, tag_ro.size());
+ EXPECT_NEAR(1.1, tag_ro.Get(0), kTolerance);
+ EXPECT_NEAR(2.2, tag_ro.Get(1), kTolerance);
+ EXPECT_NEAR(3.3, tag_ro.Get(2), kTolerance);
+}
+
+TEST(AppendFeatureValuesTest, Int64ValuesUsingInitializerList) {
+ Example example;
+
+ AppendFeatureValues({1, 2, 3}, "tag", &example);
+
+ auto tag_ro = GetFeatureValues<int64>("tag", example);
+ ASSERT_EQ(3, tag_ro.size());
+ EXPECT_EQ(1, tag_ro.Get(0));
+ EXPECT_EQ(2, tag_ro.Get(1));
+ EXPECT_EQ(3, tag_ro.Get(2));
+}
+
+TEST(AppendFeatureValuesTest, StringValuesUsingInitializerList) {
+ Example example;
+
+ AppendFeatureValues({"FOO", "BAR", "BAZ"}, "tag", &example);
+
+ auto tag_ro = GetFeatureValues<string>("tag", example);
+ ASSERT_EQ(3, tag_ro.size());
+ EXPECT_EQ("FOO", tag_ro.Get(0));
+ EXPECT_EQ("BAR", tag_ro.Get(1));
+ EXPECT_EQ("BAZ", tag_ro.Get(2));
+}
+
+} // namespace
+} // namespace tensorflow