diff options
author | Ben Lee <blee@google.com> | 2016-06-28 17:32:45 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-06-28 18:48:04 -0700 |
commit | ae41b85dc12cbd5edd046e3ed0d8d3ef0f63f885 (patch) | |
tree | ccd9c4e130419c4b6906a3650551c534ee8e9266 /tensorflow/core/example | |
parent | 41e0d7b3e51c7dd78cf0518191cb9750ed8131c4 (diff) |
Python tensorflow.Example parser configuration extractor
- Proto definition for configuration
- Utility for converting from proto
- Visibility change
Change: 126145305
Diffstat (limited to 'tensorflow/core/example')
4 files changed, 161 insertions, 0 deletions
diff --git a/tensorflow/core/example/example_parser_configuration.cc b/tensorflow/core/example/example_parser_configuration.cc
index d730e80cdc..55b2b03c83 100644
--- a/tensorflow/core/example/example_parser_configuration.cc
+++ b/tensorflow/core/example/example_parser_configuration.cc
@@ -157,4 +157,42 @@ Status ExtractExampleParserConfiguration(
   return Status::OK();
 }
+// Converts an ExampleParserConfiguration proto into the C++ feature structs.
+//
+// For every entry of config_proto.feature_map() one struct is appended to
+// *fixed_len_features (entry sets fixed_len_feature) or to *var_len_features
+// (entry sets var_len_feature); the map key becomes the struct's `key`.
+//
+// Returns InvalidArgument if a fixed-length feature's default_value cannot be
+// loaded into a Tensor of the configured dtype and shape, or if an entry sets
+// neither member of the `config` oneof. Structs appended before the failing
+// entry remain in the output vectors.
+Status ExampleParserConfigurationProtoToFeatureVectors(
+    const ExampleParserConfiguration& config_proto,
+    std::vector<FixedLenFeature>* fixed_len_features,
+    std::vector<VarLenFeature>* var_len_features) {
+  for (const auto& entry : config_proto.feature_map()) {
+    // Bind by reference; each struct below stores its own copy of the key.
+    const string& key = entry.first;
+    const auto& config = entry.second;
+    if (config.has_fixed_len_feature()) {
+      const auto& fixed_config = config.fixed_len_feature();
+      FixedLenFeature f;
+      f.key = key;
+      f.dtype = fixed_config.dtype();
+      f.shape = TensorShape(fixed_config.shape());
+      // The default is materialized with the feature's own dtype/shape so a
+      // mismatching default_value proto is rejected here.
+      Tensor default_value(f.dtype, f.shape);
+      if (!default_value.FromProto(fixed_config.default_value())) {
+        return errors::InvalidArgument(
+            "Invalid default_value in config proto ",
+            fixed_config.default_value().DebugString());
+      }
+      f.default_value = default_value;
+      f.values_output_tensor_name = fixed_config.values_output_tensor_name();
+      fixed_len_features->push_back(f);
+    } else if (config.has_var_len_feature()) {
+      const auto& var_len_config = config.var_len_feature();
+      VarLenFeature v;
+      v.key = key;
+      v.dtype = var_len_config.dtype();
+      v.values_output_tensor_name = var_len_config.values_output_tensor_name();
+      v.indices_output_tensor_name =
+          var_len_config.indices_output_tensor_name();
+      v.shapes_output_tensor_name = var_len_config.shapes_output_tensor_name();
+      var_len_features->push_back(v);
+    } else {
+      // Neither oneof member is set: report it instead of emitting an empty
+      // VarLenFeature built from default field values.
+      return errors::InvalidArgument(
+          "Feature '", key,
+          "' in config proto sets neither fixed_len_feature nor "
+          "var_len_feature");
+    }
+  }
+  return Status::OK();
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/example/example_parser_configuration.h b/tensorflow/core/example/example_parser_configuration.h
index 01afff6096..69955ec4cb 100644
---
a/tensorflow/core/example/example_parser_configuration.h
+++ b/tensorflow/core/example/example_parser_configuration.h
@@ -20,6 +20,7 @@ limitations under the License.
 #include <vector>
 #include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/example/example_parser_configuration.pb.h"
 #include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -42,6 +43,14 @@ Status ExtractExampleParserConfiguration(
     std::vector<FixedLenFeature>* fixed_len_features,
     std::vector<VarLenFeature>* var_len_features);
+// Given a config proto (typically one extracted via Python), appends to the
+// output vectors the C++ structs needed to call the
+// tensorflow.Example -> Tensor conversion code.
+//
+// Entries of config_proto.feature_map() that set fixed_len_feature are
+// appended to *fixed_len_features and entries that set var_len_feature are
+// appended to *var_len_features; the map key is stored in each struct's `key`.
+//
+// Returns a non-OK Status if an entry cannot be converted, e.g. when a
+// fixed-length feature's default_value cannot be loaded into a Tensor.
+Status ExampleParserConfigurationProtoToFeatureVectors(
+    const ExampleParserConfiguration& config_proto,
+    std::vector<FixedLenFeature>* fixed_len_features,
+    std::vector<VarLenFeature>* var_len_features);
+
 }  // namespace tensorflow
 #endif  // THIRD_PARTY_TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSE_CONFIGURATION_H_
diff --git a/tensorflow/core/example/example_parser_configuration.proto b/tensorflow/core/example/example_parser_configuration.proto
new file mode 100644
index 0000000000..852151dc93
--- /dev/null
+++ b/tensorflow/core/example/example_parser_configuration.proto
@@ -0,0 +1,39 @@
+// Protocol messages for describing the configuration of the ExampleParserOp.
+
+syntax = "proto3";
+
+// option cc_enable_arenas = true;
+option java_outer_classname = "ExampleParserConfigurationProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.example";
+option java_generate_equals_and_hash = true;
+package tensorflow;
+
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/framework/types.proto";
+
+// Configuration of one variable-length feature. The C++ conversion utility
+// copies each field into the same-named member of a VarLenFeature struct.
+message VarLenFeatureProto {
+  tensorflow.DataType dtype = 1;
+  string values_output_tensor_name = 2;
+  string indices_output_tensor_name = 3;
+  string shapes_output_tensor_name = 4;
+};
+
+// Configuration of one fixed-length feature. The C++ conversion utility
+// builds a Tensor from `dtype` and `shape` and loads `default_value` into it;
+// conversion fails with InvalidArgument when that load is rejected.
+message FixedLenFeatureProto {
+  tensorflow.DataType dtype = 1;
+  tensorflow.TensorShapeProto shape = 2;
+  tensorflow.TensorProto default_value = 3;
+  string values_output_tensor_name = 4;
+};
+
+// Configuration of a single feature: either fixed-length or variable-length.
+// As a oneof, at most one of the two members is set.
+message FeatureConfiguration {
+  oneof config {
+    FixedLenFeatureProto fixed_len_feature = 1;
+    VarLenFeatureProto var_len_feature = 2;
+  }
+};
+
+// Full parser configuration: maps each feature key to its configuration.
+message ExampleParserConfiguration {
+  map<string, FeatureConfiguration> feature_map = 1;
+};
diff --git a/tensorflow/core/example/example_parser_configuration_test.cc b/tensorflow/core/example/example_parser_configuration_test.cc
index 8de410b1b8..0fa772bd6b 100644
--- a/tensorflow/core/example/example_parser_configuration_test.cc
+++ b/tensorflow/core/example/example_parser_configuration_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/example/example_parser_configuration.h"
 #include "tensorflow/core/example/example.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/protobuf.h"
 #include "tensorflow/core/platform/test.h"
@@ -144,5 +145,79 @@ TEST_F(ExtractExampleParserConfigurationTest, Basic) {
             dense_vec[2].values_output_tensor_name);
 }
+// Text-format ExampleParserConfiguration with one fixed-length float feature
+// "x" (shape [1], default 33.0) and one variable-length string feature "y".
+static const char kExampleParseConfigurationProto[] = R"( feature_map {
+  key: "x"
+  value {
+    fixed_len_feature {
+      dtype: DT_FLOAT
+      shape {
+        dim {
+          size: 1
+        }
+      }
+      default_value {
+        dtype: DT_FLOAT
+        tensor_shape {
+          dim {
+            size: 1
+          }
+        }
+        float_val: 33.0
+      }
+      values_output_tensor_name: "ParseExample/ParseExample:3"
+    }
+  }
+}
+feature_map {
+  key: "y"
+  value {
+    var_len_feature {
+      dtype: DT_STRING
+      values_output_tensor_name: "ParseExample/ParseExample:1"
+      indices_output_tensor_name: "ParseExample/ParseExample:0"
+      shapes_output_tensor_name: "ParseExample/ParseExample:2"
+    }
+  }
+}
+)";
+
+// Fixture that parses kExampleParseConfigurationProto into config_proto_
+// before each test; a parse failure aborts via CHECK.
+class ExampleParserConfigurationProtoToFeatureVectorsTest
+    : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    CHECK(protobuf::TextFormat::ParseFromString(kExampleParseConfigurationProto,
+                                                &config_proto_));
+  }
+  ExampleParserConfiguration config_proto_;
+};
+
+// Converts the fixture proto and verifies every field of the resulting
+// fixed-length and variable-length feature structs.
+TEST_F(ExampleParserConfigurationProtoToFeatureVectorsTest, Basic) {
+  std::vector<FixedLenFeature> fixed_len_features;
+  std::vector<VarLenFeature> var_len_features;
+  // Check the returned Status so a failed conversion is reported directly
+  // rather than surfacing later as a size mismatch.
+  const Status status = ExampleParserConfigurationProtoToFeatureVectors(
+      config_proto_, &fixed_len_features, &var_len_features);
+  ASSERT_TRUE(status.ok()) << status.ToString();
+  ASSERT_EQ(1, fixed_len_features.size());
+  ASSERT_EQ(1, var_len_features.size());
+
+  const FixedLenFeature& f = fixed_len_features[0];
+  ASSERT_EQ(DT_FLOAT, f.dtype);
+  ASSERT_EQ("x", f.key);
+  ASSERT_EQ("ParseExample/ParseExample:3", f.values_output_tensor_name);
+
+  TensorShape expected_shape({1});
+  ASSERT_EQ(expected_shape.dims(), f.shape.dims());
+  ASSERT_EQ(1, f.shape.dim_size(0));
+
+  Tensor expected_default(DT_FLOAT, TensorShape({1}));
+  test::FillIota<float>(&expected_default, 33.0);
+  test::ExpectTensorEqual<float>(expected_default, f.default_value);
+
+  const VarLenFeature& v = var_len_features[0];
+  ASSERT_EQ(DT_STRING, v.dtype);
+  ASSERT_EQ("y", v.key);
+  ASSERT_EQ("ParseExample/ParseExample:0", v.indices_output_tensor_name);
+  ASSERT_EQ("ParseExample/ParseExample:1", v.values_output_tensor_name);
+  ASSERT_EQ("ParseExample/ParseExample:2", v.shapes_output_tensor_name);
+}
+
 }  // namespace
 }  // namespace tensorflow |