diff options
author | Patrik Sundberg <sundberg@google.com> | 2018-08-04 12:48:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-04 12:52:16 -0700 |
commit | cb80a1ed9fafc9274f031adca4ec5b754ac93f2a (patch) | |
tree | cc5089e88f4d1b51fbb2f513ab09251e88308dec /tensorflow/core/util | |
parent | 3a41e5363530f058cb2b57cf0add09931ec788b2 (diff) |
Add a batch sequence example parsing op, part 1.
PiperOrigin-RevId: 207406637
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r-- | tensorflow/core/util/example_proto_fast_parsing.cc | 772 | ||||
-rw-r--r-- | tensorflow/core/util/example_proto_fast_parsing.h | 11 |
2 files changed, 781 insertions, 2 deletions
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index 3ce7988057..418e97ac24 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -325,9 +325,9 @@ bool ParseExample(protobuf::io::CodedInputStream* stream, while (!stream->ExpectAtEnd()) { if (!stream->ExpectTag(kDelimitedTag(1))) { if (!SkipExtraneousTag(stream)) return false; - continue; + } else { + if (!ParseFeatures(stream, example)) return false; } - if (!ParseFeatures(stream, example)) return false; } return true; } @@ -1455,5 +1455,773 @@ Status FastParseSingleExample(const Config& config, const string& serialized, return Status::OK(); } +// Return the number of bytes elements parsed, or -1 on error. If out is null, +// this method simply counts the number of elements without any copying. +inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream, + string* out) { + int num_elements = 0; + uint32 length; + if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) { + return -1; + } + if (length > 0) { + auto limit = stream->PushLimit(length); + while (!stream->ExpectAtEnd()) { + uint32 bytes_length; + if (!stream->ExpectTag(kDelimitedTag(1)) || + !stream->ReadVarint32(&bytes_length) || + (out != nullptr && !stream->ReadString(out++, bytes_length))) { + return -1; + } + if (out == nullptr) { + stream->Skip(bytes_length); + } + num_elements++; + } + stream->PopLimit(limit); + } + return num_elements; +} + +inline void PadFloatFeature(int num_to_pad, float* out) { + for (int i = 0; i < num_to_pad; i++) { + *out++ = 0.0; + } +} + +inline void PadInt64Feature(int num_to_pad, int64* out) { + for (int i = 0; i < num_to_pad; i++) { + *out++ = 0; + } +} + +// Return the number of float elements parsed, or -1 on error. If out is null, +// this method simply counts the number of elements without any copying. +inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream, + float* out) { + int num_elements = 0; + uint32 length; + if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) { + return -1; + } + if (length > 0) { + auto limit = stream->PushLimit(length); + uint8 peek_tag = PeekTag(stream); + if (peek_tag == kDelimitedTag(1)) { // packed + uint32 packed_length; + if (!stream->ExpectTag(kDelimitedTag(1)) || + !stream->ReadVarint32(&packed_length)) { + return -1; + } + auto packed_limit = stream->PushLimit(packed_length); + while (!stream->ExpectAtEnd()) { + uint32 buffer32; + if (!stream->ReadLittleEndian32(&buffer32)) { + return -1; + } + if (out != nullptr) { + *out++ = bit_cast<float>(buffer32); + } + num_elements++; + } + stream->PopLimit(packed_limit); + } else if (peek_tag == kFixed32Tag(1)) { + while (!stream->ExpectAtEnd()) { + uint32 buffer32; + if (!stream->ExpectTag(kFixed32Tag(1)) || + !stream->ReadLittleEndian32(&buffer32)) { + return -1; + } + if (out != nullptr) { + *out++ = bit_cast<float>(buffer32); + } + num_elements++; + } + } else { + // Unknown tag. + return -1; + } + stream->PopLimit(limit); + } + return num_elements; +} + +// Return the number of int64 elements parsed, or -1 on error. If out is null, +// this method simply counts the number of elements without any copying. +inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream, + int64* out) { + int num_elements = 0; + uint32 length; + if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) { + return -1; + } + if (length > 0) { + auto limit = stream->PushLimit(length); + uint8 peek_tag = PeekTag(stream); + if (peek_tag == kDelimitedTag(1)) { // packed + uint32 packed_length; + if (!stream->ExpectTag(kDelimitedTag(1)) || + !stream->ReadVarint32(&packed_length)) { + return -1; + } + auto packed_limit = stream->PushLimit(packed_length); + while (!stream->ExpectAtEnd()) { + protobuf_uint64 n; // There is no API for int64 + if (!stream->ReadVarint64(&n)) { + return -1; + } + if (out != nullptr) { + *out++ = n; + } + num_elements++; + } + stream->PopLimit(packed_limit); + } else if (peek_tag == kVarintTag(1)) { + while (!stream->ExpectAtEnd()) { + protobuf_uint64 n; // There is no API for int64 + if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) { + return -1; + } + if (out != nullptr) { + *out++ = n; + } + num_elements++; + } + } else { + // Unknown tag. + return -1; + } + stream->PopLimit(limit); + } + return num_elements; +} + +inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) { + uint8 peek_tag = PeekTag(stream); + switch (peek_tag) { + case kDelimitedTag(1): + return DT_STRING; + case kDelimitedTag(2): + return DT_FLOAT; + case kDelimitedTag(3): + return DT_INT64; + default: + return DT_INVALID; + } +} + +inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream, + DataType dtype) { + switch (dtype) { + case DT_STRING: + if (!stream->ExpectTag(kDelimitedTag(1))) { + return false; + } + break; + case DT_FLOAT: + if (!stream->ExpectTag(kDelimitedTag(2))) { + return false; + } + break; + case DT_INT64: + if (!stream->ExpectTag(kDelimitedTag(3))) { + return false; + } + break; + default: + return false; + } + uint32 length; + return stream->ReadVarint32(&length) && length == 0; +} + +// TODO(sundberg): Use the threadpool to parallelize example parsing. +Status FastParseSequenceExample( + const FastParseExampleConfig& context_config, + const FastParseExampleConfig& feature_list_config, + gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names, + thread::ThreadPool* thread_pool, Result* context_result, + Result* feature_list_result) { + int num_examples = serialized.size(); + DCHECK(context_result != nullptr); + DCHECK(feature_list_result != nullptr); + std::map<StringPiece, bool> context_is_sparse; + std::map<StringPiece, std::pair<DataType, size_t>> + context_feature_type_and_lengths; + if (!example_names.empty() && example_names.size() != num_examples) { + return errors::InvalidArgument( + "example_names must be empty or have the correct number of elements"); + } + for (auto& c : context_config.sparse) { + TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); + context_feature_type_and_lengths[c.feature_name] = + std::make_pair(c.dtype, 0); + context_is_sparse[c.feature_name] = true; + } + for (auto& c : context_config.dense) { + TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); + context_feature_type_and_lengths[c.feature_name] = + std::make_pair(c.dtype, 0); + context_is_sparse[c.feature_name] = false; + } + std::map<StringPiece, bool> sequence_is_sparse; + std::map<StringPiece, std::pair<DataType, size_t>> + sequence_feature_type_and_lengths; + for (auto& c : feature_list_config.sparse) { + TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); + sequence_feature_type_and_lengths[c.feature_name] = + std::make_pair(c.dtype, 0); + sequence_is_sparse[c.feature_name] = true; + } + for (auto& c : feature_list_config.dense) { + TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); + sequence_feature_type_and_lengths[c.feature_name] = + std::make_pair(c.dtype, 0); + sequence_is_sparse[c.feature_name] = false; + } + + std::vector<std::map<StringPiece, StringPiece>> all_context_features( + num_examples); + std::vector<std::map<StringPiece, StringPiece>> all_sequence_features( + num_examples); + const string kUnknown = "<unknown>"; + for (int d = 0; d < num_examples; d++) { + const string& example = serialized[d]; + const string& example_name = + example_names.empty() ? kUnknown : example_names[d]; + auto* context_features = &all_context_features[d]; + auto* sequence_features = &all_sequence_features[d]; + + protobuf::io::CodedInputStream stream( + reinterpret_cast<const uint8*>(example.data()), example.size()); + // Not clear what this does. Why not stream.EnableAliasing()? + EnableAliasing(&stream); + + // Extract pointers to all features within this serialized example. + while (!stream.ExpectAtEnd()) { + std::map<StringPiece, StringPiece>* features = nullptr; + const std::map<StringPiece, std::pair<DataType, size_t>>* config = + nullptr; + if (stream.ExpectTag(kDelimitedTag(1))) { + // Context + features = context_features; + config = &context_feature_type_and_lengths; + } else if (stream.ExpectTag(kDelimitedTag(2))) { + // Sequence + features = sequence_features; + config = &sequence_feature_type_and_lengths; + } else if (!SkipExtraneousTag(&stream)) { + return errors::InvalidArgument(strings::StrCat( + "Invalid protocol message input, example id: ", example_name)); + } + if (features != nullptr) { + uint32 length; + if (!stream.ReadVarint32(&length)) { + return errors::InvalidArgument(strings::StrCat( + "Invalid protocol message input, example id: ", example_name)); + } + auto limit = stream.PushLimit(length); + while (!stream.ExpectAtEnd()) { + StringPiece key, value; + uint32 length; + if (!stream.ExpectTag(kDelimitedTag(1)) || + !stream.ReadVarint32(&length)) { + return errors::InvalidArgument(strings::StrCat( + "Invalid protocol message input, example id: ", example_name)); + } + auto limit = stream.PushLimit(length); + if (!stream.ExpectTag(kDelimitedTag(1)) || + !ParseString(&stream, &key) || + !stream.ExpectTag(kDelimitedTag(2)) || + !ParseString(&stream, &value) || !stream.ExpectAtEnd()) { + return errors::InvalidArgument(strings::StrCat( + "Invalid protocol message input, example id: ", example_name)); + } + stream.PopLimit(limit); + // Only save if this feature was requested. + if (config->count(key) > 0) { + (*features)[key] = value; + } + } + stream.PopLimit(limit); + } + } + + for (const auto& c : *context_features) { + size_t num_elements = 0; + if (!c.second.empty()) { + protobuf::io::CodedInputStream stream( + reinterpret_cast<const uint8*>(c.second.data()), c.second.size()); + EnableAliasing(&stream); + DataType dtype = context_feature_type_and_lengths[c.first].first; + int64 num; + switch (dtype) { + case DT_STRING: + num = ParseBytesFeature(&stream, nullptr); + break; + case DT_FLOAT: + num = ParseFloatFeature(&stream, nullptr); + break; + case DT_INT64: + num = ParseInt64Feature(&stream, nullptr); + break; + default: + num = -1; + break; + } + if (num == -1) { + return errors::InvalidArgument( + strings::StrCat("Error in context feature ", c.first, + " in example ", example_name)); + } + num_elements += num; + } + if (context_is_sparse[c.first]) { + context_feature_type_and_lengths[c.first].second += num_elements; + } else { + size_t current_max = context_feature_type_and_lengths[c.first].second; + context_feature_type_and_lengths[c.first].second = + std::max(current_max, num_elements); + } + } + for (const auto& c : *sequence_features) { + size_t num_elements = 0; + if (!c.second.empty()) { + protobuf::io::CodedInputStream stream( + reinterpret_cast<const uint8*>(c.second.data()), c.second.size()); + EnableAliasing(&stream); + DataType dtype = sequence_feature_type_and_lengths[c.first].first; + while (!stream.ExpectAtEnd()) { + uint32 feature_length; + if (!stream.ExpectTag(kDelimitedTag(1)) || + !stream.ReadVarint32(&feature_length)) { + return errors::InvalidArgument( + strings::StrCat("Error in sequence feature ", c.first, + " in example ", example_name)); + } + if (feature_length > 2) { + auto limit = stream.PushLimit(feature_length); + int64 num; + switch (dtype) { + case DT_STRING: + num = ParseBytesFeature(&stream, nullptr); + break; + case DT_FLOAT: + num = ParseFloatFeature(&stream, nullptr); + break; + case DT_INT64: + num = ParseInt64Feature(&stream, nullptr); + break; + default: + num = -1; + break; + } + if (num == -1) { + return errors::InvalidArgument( + strings::StrCat("Error in sequence feature ", c.first, + " in example ", example_name)); + } + num_elements += num; + stream.PopLimit(limit); + } else if (feature_length == 2) { + if (!SkipEmptyFeature(&stream, dtype)) { + return errors::InvalidArgument( + strings::StrCat("Error in sequence feature ", c.first, + " in example ", example_name)); + } + } else if (feature_length != 0) { + return errors::InvalidArgument( + strings::StrCat("Error in sequence feature ", c.first, + " in example ", example_name)); + } + } + } + if (sequence_is_sparse[c.first]) { + sequence_feature_type_and_lengths[c.first].second += num_elements; + } else { + size_t current_max = sequence_feature_type_and_lengths[c.first].second; + sequence_feature_type_and_lengths[c.first].second = + std::max(current_max, num_elements); + } + } + } + + // Allocate memory. + context_result->sparse_values.resize(context_config.sparse.size()); + context_result->sparse_indices.resize(context_config.sparse.size()); + context_result->sparse_shapes.resize(context_config.sparse.size()); + context_result->dense_values.resize(context_config.dense.size()); + feature_list_result->sparse_values.resize(feature_list_config.sparse.size()); + feature_list_result->sparse_indices.resize(feature_list_config.sparse.size()); + feature_list_result->sparse_shapes.resize(feature_list_config.sparse.size()); + feature_list_result->dense_values.resize(feature_list_config.dense.size()); + int t = 0; + for (const auto& c : context_config.dense) { + TensorShape dense_shape; + DataType dtype = c.dtype; + size_t expected_max_elements = + context_feature_type_and_lengths[c.feature_name].second; + if (expected_max_elements != dense_shape.num_elements()) { + return errors::InvalidArgument(strings::StrCat( + "Inconsistent number of elements for feature ", c.feature_name)); + } + dense_shape.AddDim(num_examples); + for (const int dim : c.shape.dim_sizes()) { + dense_shape.AddDim(dim); + } + context_result->dense_values[t] = Tensor(dtype, dense_shape); + + // TODO(sundberg): Refactor to reduce code duplication, and add bounds + // checking for the outputs. + string* out_bytes = nullptr; + float* out_float = nullptr; + int64* out_int64 = nullptr; + switch (dtype) { + case DT_STRING: + out_bytes = context_result->dense_values[t].flat<string>().data(); + break; + case DT_FLOAT: + out_float = context_result->dense_values[t].flat<float>().data(); + break; + case DT_INT64: + out_int64 = context_result->dense_values[t].flat<int64>().data(); + break; + default: + return errors::InvalidArgument(strings::StrCat( + "Unexpected dtype ", dtype, " in feature ", c.feature_name)); + } + t++; + + // Fill in the values. + for (int e = 0; e < num_examples; e++) { + size_t num_elements = 0; + const auto& feature = all_context_features[e][c.feature_name]; + const string& example_name = + example_names.empty() ? kUnknown : example_names[e]; + if (!feature.empty()) { + protobuf::io::CodedInputStream stream( + reinterpret_cast<const uint8*>(feature.data()), feature.size()); + EnableAliasing(&stream); + size_t num_added; + switch (dtype) { + case DT_STRING: + num_added = ParseBytesFeature(&stream, out_bytes); + out_bytes += num_added; + break; + case DT_FLOAT: + num_added = ParseFloatFeature(&stream, out_float); + out_float += num_added; + break; + case DT_INT64: + num_added = ParseInt64Feature(&stream, out_int64); + out_int64 += num_added; + break; + default: + return errors::InvalidArgument(strings::StrCat( + "Unexpected dtype ", dtype, " in example ", example_name)); + } + num_elements += num_added; + } + if (num_elements != expected_max_elements) { + return errors::InvalidArgument(strings::StrCat( + "Unexpected number of elements in example ", example_name)); + } + } + } + t = 0; + for (const auto& c : context_config.sparse) { + TensorShape indices_shape, values_shape; + DataType dtype = c.dtype; + size_t expected_num_elements = + context_feature_type_and_lengths[c.feature_name].second; + indices_shape.AddDim(expected_num_elements); + indices_shape.AddDim(2); + values_shape.AddDim(expected_num_elements); + context_result->sparse_indices[t] = Tensor(DT_INT64, indices_shape); + context_result->sparse_values[t] = Tensor(dtype, values_shape); + context_result->sparse_shapes[t] = Tensor(DT_INT64, TensorShape({2})); + // TODO(sundberg): Refactor to reduce code duplication, and add bounds + // checking for the outputs. + string* out_bytes = nullptr; + float* out_float = nullptr; + int64* out_int64 = nullptr; + switch (dtype) { + case DT_STRING: + out_bytes = context_result->sparse_values[t].flat<string>().data(); + break; + case DT_FLOAT: + out_float = context_result->sparse_values[t].flat<float>().data(); + break; + case DT_INT64: + out_int64 = context_result->sparse_values[t].flat<int64>().data(); + break; + default: + return errors::InvalidArgument(strings::StrCat( + "Unexpected dtype ", dtype, " in feature ", c.feature_name)); + } + int64* out_indices = context_result->sparse_indices[t].flat<int64>().data(); + auto out_shape = context_result->sparse_shapes[t].vec<int64>(); + t++; + + // Fill in the values. + size_t num_elements = 0; + size_t max_num_cols = 0; + for (int e = 0; e < num_examples; e++) { + const auto& feature = all_context_features[e][c.feature_name]; + const string& example_name = + example_names.empty() ? kUnknown : example_names[e]; + if (!feature.empty()) { + protobuf::io::CodedInputStream stream( + reinterpret_cast<const uint8*>(feature.data()), feature.size()); + EnableAliasing(&stream); + size_t num_added; + switch (dtype) { + case DT_STRING: + num_added = ParseBytesFeature(&stream, out_bytes); + out_bytes += num_added; + break; + case DT_FLOAT: + num_added = ParseFloatFeature(&stream, out_float); + out_float += num_added; + break; + case DT_INT64: + num_added = ParseInt64Feature(&stream, out_int64); + out_int64 += num_added; + break; + default: + return errors::InvalidArgument(strings::StrCat( + "Unexpected dtype ", dtype, " in example ", example_name)); + } + num_elements += num_added; + max_num_cols = std::max(max_num_cols, num_added); + for (int i = 0; i < num_added; i++) { + *out_indices++ = e; + *out_indices++ = i; + } + } + } + if (num_elements != expected_num_elements) { + return errors::InvalidArgument(strings::StrCat( + "Unexpected total number of elements in feature ", c.feature_name)); + } + out_shape(0) = num_examples; + out_shape(1) = max_num_cols; + } + t = 0; + for (const auto& c : feature_list_config.dense) { + TensorShape dense_shape, row_shape; + DataType dtype = c.dtype; + size_t expected_max_elements = + sequence_feature_type_and_lengths[c.feature_name].second; + int64 expected_max_rows = expected_max_elements / row_shape.num_elements(); + if (!c.shape.AsTensorShape(&row_shape) || + expected_max_elements != expected_max_rows * row_shape.num_elements()) { + return errors::InvalidArgument(strings::StrCat( + "Unexpected shape error in feature ", c.feature_name)); + } + dense_shape.AddDim(num_examples); + dense_shape.AddDim(expected_max_rows); + for (const int dim : feature_list_config.dense[t].shape.dim_sizes()) { + dense_shape.AddDim(dim); + } + feature_list_result->dense_values[t] = Tensor(dtype, dense_shape); + + string* out_bytes = nullptr; + float* out_float = nullptr; + int64* out_int64 = nullptr; + switch (dtype) { + case DT_STRING: + out_bytes = feature_list_result->dense_values[t].flat<string>().data(); + break; + case DT_FLOAT: + out_float = feature_list_result->dense_values[t].flat<float>().data(); + break; + case DT_INT64: + out_int64 = feature_list_result->dense_values[t].flat<int64>().data(); + break; + default: + return errors::InvalidArgument(strings::StrCat( + "Unexpected dtype ", dtype, " in feature ", c.feature_name)); + } + t++; + + // Fill in the values. + for (int e = 0; e < num_examples; e++) { + size_t num_elements = 0; + const auto& feature = all_sequence_features[e][c.feature_name]; + const string& example_name = + example_names.empty() ? kUnknown : example_names[e]; + if (!feature.empty()) { + protobuf::io::CodedInputStream stream( + reinterpret_cast<const uint8*>(feature.data()), feature.size()); + EnableAliasing(&stream); + while (!stream.ExpectAtEnd()) { + uint32 feature_length; + if (!stream.ExpectTag(kDelimitedTag(1)) || + !stream.ReadVarint32(&feature_length)) { + return errors::InvalidArgument( + strings::StrCat("Error in sequence feature ", c.feature_name, + " in example ", example_name)); + } + auto limit = stream.PushLimit(feature_length); + size_t num_added; + switch (dtype) { + case DT_STRING: + num_added = ParseBytesFeature(&stream, out_bytes); + out_bytes += num_added; + break; + case DT_FLOAT: + num_added = ParseFloatFeature(&stream, out_float); + out_float += num_added; + break; + case DT_INT64: + num_added = ParseInt64Feature(&stream, out_int64); + out_int64 += num_added; + break; + default: + return errors::InvalidArgument(strings::StrCat( + "Unexpected dtype ", dtype, " in example ", example_name)); + } + num_elements += num_added; + if (num_added != row_shape.num_elements()) { + return errors::InvalidArgument( + "Unexpected number of elements in feature ", c.feature_name, + ", example ", example_name); + } + stream.PopLimit(limit); + } + } + // Pad as necessary. + int num_to_pad = expected_max_elements - num_elements; + switch (dtype) { + case DT_STRING: + out_bytes += num_to_pad; + break; + case DT_FLOAT: + PadFloatFeature(num_to_pad, out_float); + out_float += num_to_pad; + break; + case DT_INT64: + PadInt64Feature(num_to_pad, out_int64); + out_int64 += num_to_pad; + break; + default: + return errors::InvalidArgument(strings::StrCat( + "Unexpected dtype ", dtype, " in example ", example_name)); + } + } + } + t = 0; + for (const auto& c : feature_list_config.sparse) { + TensorShape indices_shape, values_shape; + DataType dtype = c.dtype; + size_t expected_num_elements = + sequence_feature_type_and_lengths[c.feature_name].second; + indices_shape.AddDim(expected_num_elements); + indices_shape.AddDim(3); + values_shape.AddDim(expected_num_elements); + feature_list_result->sparse_indices[t] = Tensor(DT_INT64, indices_shape); + feature_list_result->sparse_values[t] = Tensor(dtype, values_shape); + feature_list_result->sparse_shapes[t] = Tensor(DT_INT64, TensorShape({3})); + + string* out_bytes = nullptr; + float* out_float = nullptr; + int64* out_int64 = nullptr; + switch (dtype) { + case DT_STRING: + out_bytes = feature_list_result->sparse_values[t].flat<string>().data(); + break; + case DT_FLOAT: + out_float = feature_list_result->sparse_values[t].flat<float>().data(); + break; + case DT_INT64: + out_int64 = feature_list_result->sparse_values[t].flat<int64>().data(); + break; + default: + return errors::InvalidArgument(strings::StrCat( + "Unexpected dtype ", dtype, " in feature ", c.feature_name)); + } + int64* out_indices = + feature_list_result->sparse_indices[t].flat<int64>().data(); + auto out_shape = feature_list_result->sparse_shapes[t].vec<int64>(); + t++; + + // Fill in the values. + size_t num_elements = 0; + size_t max_num_rows = 0; + size_t max_num_cols = 0; + for (int e = 0; e < num_examples; e++) { + const auto& feature = all_sequence_features[e][c.feature_name]; + const string& example_name = + example_names.empty() ? kUnknown : example_names[e]; + if (!feature.empty()) { + protobuf::io::CodedInputStream stream( + reinterpret_cast<const uint8*>(feature.data()), feature.size()); + EnableAliasing(&stream); + size_t num_rows = 0; + while (!stream.ExpectAtEnd()) { + uint32 feature_length; + if (!stream.ExpectTag(kDelimitedTag(1)) || + !stream.ReadVarint32(&feature_length)) { + return errors::InvalidArgument( + strings::StrCat("Error in sequence feature ", c.feature_name, + " in example ", example_name)); + } + if (feature_length > 2) { + auto limit = stream.PushLimit(feature_length); + size_t num_added; + switch (dtype) { + case DT_STRING: + num_added = ParseBytesFeature(&stream, out_bytes); + out_bytes += num_added; + break; + case DT_FLOAT: + num_added = ParseFloatFeature(&stream, out_float); + out_float += num_added; + break; + case DT_INT64: + num_added = ParseInt64Feature(&stream, out_int64); + out_int64 += num_added; + break; + default: + return errors::InvalidArgument(strings::StrCat( + "Unexpected dtype ", dtype, " in example ", example_name)); + } + num_elements += num_added; + max_num_cols = std::max(max_num_cols, num_added); + for (int i = 0; i < num_added; i++) { + *out_indices++ = e; + *out_indices++ = num_rows; + *out_indices++ = i; + } + stream.PopLimit(limit); + } else if (feature_length == 2) { + if (!SkipEmptyFeature(&stream, dtype)) { + return errors::InvalidArgument( + strings::StrCat("Error in sequence feature ", c.feature_name, + " in example ", example_name)); + } + } else if (feature_length != 0) { + return errors::InvalidArgument( + strings::StrCat("Error in sequence feature ", c.feature_name, + " in example ", example_name)); + } + num_rows++; + } + max_num_rows = std::max(max_num_rows, num_rows); + } + } + if (num_elements != expected_num_elements) { + return errors::InvalidArgument(strings::StrCat( + "Unexpected number of elements in feature ", c.feature_name)); + } + out_shape(0) = num_examples; + out_shape(1) = max_num_rows; + out_shape(2) = max_num_cols; + } + + return Status::OK(); +} + } // namespace example } // namespace tensorflow diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h index 1b08f02267..024a4518ee 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.h +++ b/tensorflow/core/util/example_proto_fast_parsing.h @@ -85,6 +85,17 @@ typedef FastParseExampleConfig FastParseSingleExampleConfig; Status FastParseSingleExample(const FastParseSingleExampleConfig& config, const string& serialized, Result* result); +// Parses a batch of serialized SequenceExample protos and converts them into +// result according to given config. +// Given example names have to either be empty or the same size as serialized. +// example_names are used only for error messages. +Status FastParseSequenceExample( + const example::FastParseExampleConfig& context_config, + const example::FastParseExampleConfig& feature_list_config, + gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names, + thread::ThreadPool* thread_pool, example::Result* context_result, + example::Result* feature_list_result); + // This function parses serialized Example and populates given example. // It uses the same specialized parser as FastParseExample which is efficient. // But then constructs Example which is relatively slow. |