diff options
author | Patrik Sundberg <sundberg@google.com> | 2018-08-31 07:28:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-31 07:32:19 -0700 |
commit | d09bd33fc9175c7fdbdded8f98b5c5d3a9f8ad7d (patch) | |
tree | bb869bee40e9833299b2662d2076547c85d5df39 /tensorflow/core/util | |
parent | 1251e8c6cee24d9a295c63fa2362b03deef4396b (diff) |
Add a batch sequence example parsing op, part 2.
PiperOrigin-RevId: 211082479
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r-- | tensorflow/core/util/example_proto_fast_parsing.cc | 228 | ||||
-rw-r--r-- | tensorflow/core/util/example_proto_fast_parsing.h | 3 | ||||
-rw-r--r-- | tensorflow/core/util/example_proto_helper.cc | 53 | ||||
-rw-r--r-- | tensorflow/core/util/example_proto_helper.h | 61 |
4 files changed, 268 insertions, 77 deletions
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index a38cd1d09f..e52d55e2ff 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -1722,10 +1722,11 @@ Status FastParseSequenceExample( const FastParseExampleConfig& feature_list_config, gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names, thread::ThreadPool* thread_pool, Result* context_result, - Result* feature_list_result) { + Result* feature_list_result, std::vector<Tensor>* dense_feature_lengths) { int num_examples = serialized.size(); DCHECK(context_result != nullptr); DCHECK(feature_list_result != nullptr); + DCHECK(dense_feature_lengths != nullptr); std::map<StringPiece, bool> context_is_sparse; std::map<StringPiece, std::pair<DataType, size_t>> context_feature_type_and_lengths; @@ -1740,9 +1741,22 @@ Status FastParseSequenceExample( context_is_sparse[c.feature_name] = true; } for (auto& c : context_config.dense) { + if (context_is_sparse[c.feature_name]) { + return errors::InvalidArgument("Context feature " + c.feature_name + + " cannot be both dense and sparse"); + } TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); context_feature_type_and_lengths[c.feature_name] = - std::make_pair(c.dtype, 0); + std::make_pair(c.dtype, c.default_value.NumElements()); + if (c.default_value.NumElements() > 0) { + if (!c.shape.IsCompatibleWith(c.default_value.shape())) { + return errors::InvalidArgument("Default value for context feature ", + c.feature_name, + " has an incorrect shape: saw ", + c.default_value.shape().DebugString(), + " but expected ", c.shape.DebugString()); + } + } context_is_sparse[c.feature_name] = false; } std::map<StringPiece, bool> sequence_is_sparse; @@ -1755,6 +1769,10 @@ Status FastParseSequenceExample( sequence_is_sparse[c.feature_name] = true; } for (auto& c : feature_list_config.dense) { + if (sequence_is_sparse[c.feature_name]) { + return errors::InvalidArgument("Sequence feature " + c.feature_name + + " cannot be both dense and sparse"); + } TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); sequence_feature_type_and_lengths[c.feature_name] = std::make_pair(c.dtype, 0); @@ -1792,14 +1810,14 @@ Status FastParseSequenceExample( features = sequence_features; config = &sequence_feature_type_and_lengths; } else if (!SkipExtraneousTag(&stream)) { - return errors::InvalidArgument(strings::StrCat( - "Invalid protocol message input, example id: ", example_name)); + return errors::InvalidArgument( + "Invalid protocol message input, example id: ", example_name); } if (features != nullptr) { uint32 length; if (!stream.ReadVarint32(&length)) { - return errors::InvalidArgument(strings::StrCat( - "Invalid protocol message input, example id: ", example_name)); + return errors::InvalidArgument( + "Invalid protocol message input, example id: ", example_name); } auto limit = stream.PushLimit(length); while (!stream.ExpectAtEnd()) { @@ -1807,16 +1825,16 @@ Status FastParseSequenceExample( uint32 length; if (!stream.ExpectTag(kDelimitedTag(1)) || !stream.ReadVarint32(&length)) { - return errors::InvalidArgument(strings::StrCat( - "Invalid protocol message input, example id: ", example_name)); + return errors::InvalidArgument( + "Invalid protocol message input, example id: ", example_name); } auto limit = stream.PushLimit(length); if (!stream.ExpectTag(kDelimitedTag(1)) || !ParseString(&stream, &key) || !stream.ExpectTag(kDelimitedTag(2)) || !ParseString(&stream, &value) || !stream.ExpectAtEnd()) { - return errors::InvalidArgument(strings::StrCat( - "Invalid protocol message input, example id: ", example_name)); + return errors::InvalidArgument( + "Invalid protocol message input, example id: ", example_name); } stream.PopLimit(limit); // Only save if this feature was requested. @@ -1851,9 +1869,8 @@ Status FastParseSequenceExample( break; } if (num == -1) { - return errors::InvalidArgument( - strings::StrCat("Error in context feature ", c.first, - " in example ", example_name)); + return errors::InvalidArgument("Error in context feature ", c.first, + " in example ", example_name); } num_elements += num; } @@ -1876,9 +1893,9 @@ Status FastParseSequenceExample( uint32 feature_length; if (!stream.ExpectTag(kDelimitedTag(1)) || !stream.ReadVarint32(&feature_length)) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.first, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.first, " in example ", + example_name); } if (feature_length > 2) { auto limit = stream.PushLimit(feature_length); @@ -1898,22 +1915,22 @@ Status FastParseSequenceExample( break; } if (num == -1) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.first, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.first, " in example ", + example_name); } num_elements += num; stream.PopLimit(limit); } else if (feature_length == 2) { if (!SkipEmptyFeature(&stream, dtype)) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.first, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.first, " in example ", + example_name); } } else if (feature_length != 0) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.first, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.first, " in example ", + example_name); } } } @@ -1936,15 +1953,19 @@ Status FastParseSequenceExample( feature_list_result->sparse_indices.resize(feature_list_config.sparse.size()); feature_list_result->sparse_shapes.resize(feature_list_config.sparse.size()); feature_list_result->dense_values.resize(feature_list_config.dense.size()); + dense_feature_lengths->resize(feature_list_config.dense.size()); + int t = 0; for (const auto& c : context_config.dense) { - TensorShape dense_shape; + TensorShape dense_shape, example_shape; DataType dtype = c.dtype; - size_t expected_max_elements = + const size_t expected_max_elements = context_feature_type_and_lengths[c.feature_name].second; - if (expected_max_elements != dense_shape.num_elements()) { - return errors::InvalidArgument(strings::StrCat( - "Inconsistent number of elements for feature ", c.feature_name)); + if (!c.shape.AsTensorShape(&example_shape) || + expected_max_elements != example_shape.num_elements()) { + return errors::InvalidArgument( + "Inconsistent number of elements for feature ", c.feature_name, ": ", + expected_max_elements, " vs ", dense_shape.num_elements()); } dense_shape.AddDim(num_examples); for (const int dim : c.shape.dim_sizes()) { @@ -1968,18 +1989,58 @@ Status FastParseSequenceExample( out_int64 = context_result->dense_values[t].flat<int64>().data(); break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in feature ", c.feature_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in feature ", c.feature_name); } t++; // Fill in the values. for (int e = 0; e < num_examples; e++) { size_t num_elements = 0; - const auto& feature = all_context_features[e][c.feature_name]; + const auto feature_iter = all_context_features[e].find(c.feature_name); const string& example_name = example_names.empty() ? kUnknown : example_names[e]; - if (!feature.empty()) { + if (feature_iter == all_context_features[e].end()) { + // Copy the default value, if present. If not, return an error. + if (c.default_value.NumElements() == 0) { + return errors::InvalidArgument( + "Feature: ", c.feature_name, + " (data type: ", DataTypeString(c.dtype), ")", + " is required but could not be found."); + } + const string* in_bytes = nullptr; + const float* in_float = nullptr; + const int64* in_int64 = nullptr; + size_t num = 0; + switch (dtype) { + case DT_STRING: + in_bytes = c.default_value.flat<string>().data(); + num = c.default_value.NumElements(); + for (int p = 0; p < num; p++) { + *out_bytes++ = *in_bytes++; + } + break; + case DT_FLOAT: + in_float = c.default_value.flat<float>().data(); + num = c.default_value.NumElements(); + for (int p = 0; p < num; p++) { + *out_float++ = *in_float++; + } + break; + case DT_INT64: + in_int64 = c.default_value.flat<int64>().data(); + num = c.default_value.NumElements(); + for (int p = 0; p < num; p++) { + *out_int64++ = *in_int64++; + } + break; + default: + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in example ", example_name); + } + num_elements += num; + } else if (!feature_iter->second.empty()) { + const auto& feature = feature_iter->second; protobuf::io::CodedInputStream stream( reinterpret_cast<const uint8*>(feature.data()), feature.size()); EnableAliasing(&stream); @@ -1998,14 +2059,14 @@ Status FastParseSequenceExample( out_int64 += num_added; break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in example ", example_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in example ", example_name); } num_elements += num_added; } if (num_elements != expected_max_elements) { - return errors::InvalidArgument(strings::StrCat( - "Unexpected number of elements in example ", example_name)); + return errors::InvalidArgument( + "Unexpected number of elements in example ", example_name); } } } @@ -2037,8 +2098,8 @@ Status FastParseSequenceExample( out_int64 = context_result->sparse_values[t].flat<int64>().data(); break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in feature ", c.feature_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in feature ", c.feature_name); } int64* out_indices = context_result->sparse_indices[t].flat<int64>().data(); auto out_shape = context_result->sparse_shapes[t].vec<int64>(); @@ -2070,8 +2131,8 @@ Status FastParseSequenceExample( out_int64 += num_added; break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in example ", example_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in example ", example_name); } num_elements += num_added; max_num_cols = std::max(max_num_cols, num_added); @@ -2082,30 +2143,35 @@ Status FastParseSequenceExample( } } if (num_elements != expected_num_elements) { - return errors::InvalidArgument(strings::StrCat( - "Unexpected total number of elements in feature ", c.feature_name)); + return errors::InvalidArgument( + "Unexpected total number of elements in feature ", c.feature_name); } out_shape(0) = num_examples; out_shape(1) = max_num_cols; } t = 0; + TensorShape dense_length_shape({num_examples}); for (const auto& c : feature_list_config.dense) { TensorShape dense_shape, row_shape; DataType dtype = c.dtype; - size_t expected_max_elements = + const size_t expected_max_elements = sequence_feature_type_and_lengths[c.feature_name].second; - int64 expected_max_rows = expected_max_elements / row_shape.num_elements(); if (!c.shape.AsTensorShape(&row_shape) || - expected_max_elements != expected_max_rows * row_shape.num_elements()) { - return errors::InvalidArgument(strings::StrCat( - "Unexpected shape error in feature ", c.feature_name)); + expected_max_elements != + (expected_max_elements / row_shape.num_elements()) * + row_shape.num_elements()) { + return errors::InvalidArgument("Unexpected shape error in feature ", + c.feature_name); } + int64 expected_max_rows = expected_max_elements / row_shape.num_elements(); dense_shape.AddDim(num_examples); dense_shape.AddDim(expected_max_rows); for (const int dim : feature_list_config.dense[t].shape.dim_sizes()) { dense_shape.AddDim(dim); } feature_list_result->dense_values[t] = Tensor(dtype, dense_shape); + (*dense_feature_lengths)[t] = Tensor(DT_INT64, dense_length_shape); + int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data(); string* out_bytes = nullptr; float* out_float = nullptr; @@ -2121,18 +2187,26 @@ Status FastParseSequenceExample( out_int64 = feature_list_result->dense_values[t].flat<int64>().data(); break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in feature ", c.feature_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in feature ", c.feature_name); } t++; // Fill in the values. for (int e = 0; e < num_examples; e++) { - size_t num_elements = 0; - const auto& feature = all_sequence_features[e][c.feature_name]; + size_t num_elements = 0, num_rows = 0; + const auto feature_iter = all_sequence_features[e].find(c.feature_name); const string& example_name = example_names.empty() ? kUnknown : example_names[e]; - if (!feature.empty()) { + if (feature_iter == all_sequence_features[e].end()) { + // Return an error if this feature was not allowed to be missing. + // Otherwise, we'll pad as needed below. + if (!c.variable_length) { + return errors::InvalidArgument("Missing feature ", c.feature_name, + " in example ", example_name); + } + } else if (!feature_iter->second.empty()) { + const auto& feature = feature_iter->second; protobuf::io::CodedInputStream stream( reinterpret_cast<const uint8*>(feature.data()), feature.size()); EnableAliasing(&stream); @@ -2140,9 +2214,9 @@ Status FastParseSequenceExample( uint32 feature_length; if (!stream.ExpectTag(kDelimitedTag(1)) || !stream.ReadVarint32(&feature_length)) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.feature_name, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.feature_name, " in example ", + example_name); } auto limit = stream.PushLimit(feature_length); size_t num_added; @@ -2160,10 +2234,11 @@ Status FastParseSequenceExample( out_int64 += num_added; break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in example ", example_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in example ", example_name); } num_elements += num_added; + num_rows++; if (num_added != row_shape.num_elements()) { return errors::InvalidArgument( "Unexpected number of elements in feature ", c.feature_name, @@ -2172,6 +2247,7 @@ Status FastParseSequenceExample( stream.PopLimit(limit); } } + *out_lengths++ = num_rows; // Pad as necessary. int num_to_pad = expected_max_elements - num_elements; switch (dtype) { @@ -2187,8 +2263,8 @@ Status FastParseSequenceExample( out_int64 += num_to_pad; break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in example ", example_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in example ", example_name); } } } @@ -2219,8 +2295,8 @@ Status FastParseSequenceExample( out_int64 = feature_list_result->sparse_values[t].flat<int64>().data(); break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in feature ", c.feature_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in feature ", c.feature_name); } int64* out_indices = feature_list_result->sparse_indices[t].flat<int64>().data(); @@ -2244,9 +2320,9 @@ Status FastParseSequenceExample( uint32 feature_length; if (!stream.ExpectTag(kDelimitedTag(1)) || !stream.ReadVarint32(&feature_length)) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.feature_name, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.feature_name, " in example ", + example_name); } if (feature_length > 2) { auto limit = stream.PushLimit(feature_length); @@ -2265,8 +2341,8 @@ Status FastParseSequenceExample( out_int64 += num_added; break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in example ", example_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in example ", example_name); } num_elements += num_added; max_num_cols = std::max(max_num_cols, num_added); @@ -2278,14 +2354,14 @@ Status FastParseSequenceExample( stream.PopLimit(limit); } else if (feature_length == 2) { if (!SkipEmptyFeature(&stream, dtype)) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.feature_name, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.feature_name, " in example ", + example_name); } } else if (feature_length != 0) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.feature_name, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.feature_name, " in example ", + example_name); } num_rows++; } @@ -2293,8 +2369,8 @@ Status FastParseSequenceExample( } } if (num_elements != expected_num_elements) { - return errors::InvalidArgument(strings::StrCat( - "Unexpected number of elements in feature ", c.feature_name)); + return errors::InvalidArgument( + "Unexpected number of elements in feature ", c.feature_name); } out_shape(0) = num_examples; out_shape(1) = max_num_rows; diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h index db5b5ff929..055d9c2c30 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.h +++ b/tensorflow/core/util/example_proto_fast_parsing.h @@ -118,7 +118,8 @@ Status FastParseSequenceExample( const example::FastParseExampleConfig& feature_list_config, gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names, thread::ThreadPool* thread_pool, example::Result* context_result, - example::Result* feature_list_result); + example::Result* feature_list_result, + std::vector<Tensor>* dense_feature_lengths); // This function parses serialized Example and populates given example. // It uses the same specialized parser as FastParseExample which is efficient. diff --git a/tensorflow/core/util/example_proto_helper.cc b/tensorflow/core/util/example_proto_helper.cc index e156a3bc8f..41fb20c00a 100644 --- a/tensorflow/core/util/example_proto_helper.cc +++ b/tensorflow/core/util/example_proto_helper.cc @@ -443,6 +443,59 @@ Status ParseSingleExampleAttrs::FinishInit() { return Status::OK(); } +Status ParseSequenceExampleAttrs::FinishInit() { + if (num_context_sparse != context_sparse_keys.size() || + num_context_sparse != context_sparse_types.size()) { + return errors::InvalidArgument( + "num_context_sparse (", num_context_sparse, + ") must match the size of context_sparse_keys (", + context_sparse_keys.size(), ") and context_sparse_types (", + context_sparse_types.size(), ")"); + } + if (num_context_dense != context_dense_keys.size() || + num_context_dense != context_dense_types.size() || + num_context_dense != context_dense_shapes.size()) { + return errors::InvalidArgument( + "num_context_dense (", num_context_dense, + ") must match the size of context_dense_keys (", + context_dense_keys.size(), "), context_dense_types (", + context_dense_types.size(), ") and context_dense_shapes (", + context_dense_shapes.size(), ")"); + } + if (num_feature_list_sparse != feature_list_sparse_keys.size() || + num_feature_list_sparse != feature_list_sparse_types.size()) { + return errors::InvalidArgument( + "num_feature_list_sparse (", num_feature_list_sparse, + ") must match the size of feature_list_sparse_keys (", + feature_list_sparse_keys.size(), ") and feature_list_sparse_types (", + feature_list_sparse_types.size(), ")"); + } + if (num_feature_list_dense != feature_list_dense_keys.size() || + num_feature_list_dense != feature_list_dense_types.size() || + num_feature_list_dense != feature_list_dense_shapes.size()) { + return errors::InvalidArgument( + "num_feature_list_dense (", num_feature_list_dense, + ") must match the size of feature_list_dense_keys (", + feature_list_dense_keys.size(), "), feature_list_dense_types (", + feature_list_dense_types.size(), ") and feature_list_dense_shapes (", + feature_list_dense_shapes.size(), ")"); + } + for (const DataType& type : context_dense_types) { + TF_RETURN_IF_ERROR(CheckValidType(type)); + } + for (const DataType& type : context_sparse_types) { + TF_RETURN_IF_ERROR(CheckValidType(type)); + } + for (const DataType& type : feature_list_dense_types) { + TF_RETURN_IF_ERROR(CheckValidType(type)); + } + for (const DataType& type : feature_list_sparse_types) { + TF_RETURN_IF_ERROR(CheckValidType(type)); + } + + return Status::OK(); +} + Status ParseSingleSequenceExampleAttrs::FinishInit() { if (static_cast<size_t>(num_context_sparse) != context_sparse_types.size()) { return errors::InvalidArgument( diff --git a/tensorflow/core/util/example_proto_helper.h b/tensorflow/core/util/example_proto_helper.h index e511704962..c183ee4d96 100644 --- a/tensorflow/core/util/example_proto_helper.h +++ b/tensorflow/core/util/example_proto_helper.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/sparse/sparse_tensor.h" @@ -271,6 +272,66 @@ class ParseSingleExampleAttrs { Status FinishInit(); // for context-independent parts of Init. }; +// Parses the attributes passed to ParseSequenceExample. +// REQUIRES: Init must be called after construction. +class ParseSequenceExampleAttrs { + public: + template <typename ContextType> + Status Init(ContextType* ctx) { + std::vector<string> feature_list_dense_missing_assumed_empty_tmp; + TF_RETURN_IF_ERROR( + ctx->GetAttr("feature_list_dense_missing_assumed_empty", + &feature_list_dense_missing_assumed_empty_tmp)); + for (const string& feature : feature_list_dense_missing_assumed_empty_tmp) { + feature_list_dense_missing_assumed_empty.insert(feature); + } + TF_RETURN_IF_ERROR( + ctx->GetAttr("context_sparse_keys", &context_sparse_keys)); + TF_RETURN_IF_ERROR(ctx->GetAttr("context_dense_keys", &context_dense_keys)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("feature_list_sparse_keys", &feature_list_sparse_keys)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("feature_list_dense_keys", &feature_list_dense_keys)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("context_sparse_types", &context_sparse_types)); + TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_dense", &num_context_dense)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("Nfeature_list_dense", &num_feature_list_dense)); + TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_sparse", &num_context_sparse)); + TF_RETURN_IF_ERROR(ctx->GetAttr("Tcontext_dense", &context_dense_types)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("feature_list_sparse_types", &feature_list_sparse_types)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("feature_list_dense_types", &feature_list_dense_types)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("Nfeature_list_sparse", &num_feature_list_sparse)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("context_dense_shapes", &context_dense_shapes)); + TF_RETURN_IF_ERROR( + ctx->GetAttr("feature_list_dense_shapes", &feature_list_dense_shapes)); + return FinishInit(); + } + + std::unordered_set<string> feature_list_dense_missing_assumed_empty; + int64 num_context_sparse; + int64 num_context_dense; + int64 num_feature_list_sparse; + int64 num_feature_list_dense; + std::vector<string> context_sparse_keys; + std::vector<string> context_dense_keys; + std::vector<string> feature_list_sparse_keys; + std::vector<string> feature_list_dense_keys; + std::vector<DataType> context_sparse_types; + std::vector<DataType> context_dense_types; + std::vector<TensorShape> context_dense_shapes; + std::vector<DataType> feature_list_sparse_types; + std::vector<DataType> feature_list_dense_types; + std::vector<TensorShape> feature_list_dense_shapes; + + private: + Status FinishInit(); // for context-independent parts of Init. +}; + // Parses the attributes passed to ParseSingleSequenceExample. // REQUIRES: Init must be called after construction. class ParseSingleSequenceExampleAttrs { |