author    Patrik Sundberg <sundberg@google.com> 2018-08-04 12:48:13 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-04 12:52:16 -0700
commit    cb80a1ed9fafc9274f031adca4ec5b754ac93f2a (patch)
tree      cc5089e88f4d1b51fbb2f513ab09251e88308dec /tensorflow/core/util
parent    3a41e5363530f058cb2b57cf0add09931ec788b2 (diff)
Add a batch sequence example parsing op, part 1.
PiperOrigin-RevId: 207406637
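
For context, the new op consumes serialized SequenceExample protos: a set of context features plus per-step feature lists. The snippet below is a minimal illustration (not part of this change) of building such an input with the generated proto API; the feature names "label" and "tokens" are hypothetical.

#include <string>
#include "tensorflow/core/example/example.pb.h"

// Builds one serialized SequenceExample of the kind the new parser consumes.
std::string MakeSerializedSequenceExample() {
  tensorflow::SequenceExample se;
  // Context: a single int64 feature (hypothetical name "label").
  (*se.mutable_context()->mutable_feature())["label"]
      .mutable_int64_list()
      ->add_value(1);
  // Feature list: one Feature per time step (hypothetical name "tokens").
  tensorflow::FeatureList& tokens =
      (*se.mutable_feature_lists()->mutable_feature_list())["tokens"];
  tokens.add_feature()->mutable_int64_list()->add_value(42);
  tokens.add_feature()->mutable_int64_list()->add_value(7);
  return se.SerializeAsString();
}
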
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r--  tensorflow/core/util/example_proto_fast_parsing.cc  772
-rw-r--r--  tensorflow/core/util/example_proto_fast_parsing.h    11
2 files changed, 781 insertions(+), 2 deletions(-)
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index 3ce7988057..418e97ac24 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -325,9 +325,9 @@ bool ParseExample(protobuf::io::CodedInputStream* stream,
while (!stream->ExpectAtEnd()) {
if (!stream->ExpectTag(kDelimitedTag(1))) {
if (!SkipExtraneousTag(stream)) return false;
- continue;
+ } else {
+ if (!ParseFeatures(stream, example)) return false;
}
- if (!ParseFeatures(stream, example)) return false;
}
return true;
}
@@ -1455,5 +1455,773 @@ Status FastParseSingleExample(const Config& config, const string& serialized,
return Status::OK();
}
+// Return the number of bytes elements parsed, or -1 on error. If out is null,
+// this method simply counts the number of elements without any copying.
+inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream,
+ string* out) {
+ int num_elements = 0;
+ uint32 length;
+ if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) {
+ return -1;
+ }
+ if (length > 0) {
+ auto limit = stream->PushLimit(length);
+ while (!stream->ExpectAtEnd()) {
+ uint32 bytes_length;
+ if (!stream->ExpectTag(kDelimitedTag(1)) ||
+ !stream->ReadVarint32(&bytes_length) ||
+ (out != nullptr && !stream->ReadString(out++, bytes_length))) {
+ return -1;
+ }
+ if (out == nullptr) {
+ stream->Skip(bytes_length);
+ }
+ num_elements++;
+ }
+ stream->PopLimit(limit);
+ }
+ return num_elements;
+}
+
+inline void PadFloatFeature(int num_to_pad, float* out) {
+ for (int i = 0; i < num_to_pad; i++) {
+ *out++ = 0.0;
+ }
+}
+
+inline void PadInt64Feature(int num_to_pad, int64* out) {
+ for (int i = 0; i < num_to_pad; i++) {
+ *out++ = 0;
+ }
+}
+
+// Return the number of float elements parsed, or -1 on error. If out is null,
+// this method simply counts the number of elements without any copying.
+inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
+ float* out) {
+ int num_elements = 0;
+ uint32 length;
+ if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
+ return -1;
+ }
+ if (length > 0) {
+ auto limit = stream->PushLimit(length);
+ uint8 peek_tag = PeekTag(stream);
+ if (peek_tag == kDelimitedTag(1)) { // packed
+ uint32 packed_length;
+ if (!stream->ExpectTag(kDelimitedTag(1)) ||
+ !stream->ReadVarint32(&packed_length)) {
+ return -1;
+ }
+ auto packed_limit = stream->PushLimit(packed_length);
+ while (!stream->ExpectAtEnd()) {
+ uint32 buffer32;
+ if (!stream->ReadLittleEndian32(&buffer32)) {
+ return -1;
+ }
+ if (out != nullptr) {
+ *out++ = bit_cast<float>(buffer32);
+ }
+ num_elements++;
+ }
+ stream->PopLimit(packed_limit);
+ } else if (peek_tag == kFixed32Tag(1)) {
+ while (!stream->ExpectAtEnd()) {
+ uint32 buffer32;
+ if (!stream->ExpectTag(kFixed32Tag(1)) ||
+ !stream->ReadLittleEndian32(&buffer32)) {
+ return -1;
+ }
+ if (out != nullptr) {
+ *out++ = bit_cast<float>(buffer32);
+ }
+ num_elements++;
+ }
+ } else {
+ // Unknown tag.
+ return -1;
+ }
+ stream->PopLimit(limit);
+ }
+ return num_elements;
+}
+
+// Return the number of int64 elements parsed, or -1 on error. If out is null,
+// this method simply counts the number of elements without any copying.
+inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
+ int64* out) {
+ int num_elements = 0;
+ uint32 length;
+ if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
+ return -1;
+ }
+ if (length > 0) {
+ auto limit = stream->PushLimit(length);
+ uint8 peek_tag = PeekTag(stream);
+ if (peek_tag == kDelimitedTag(1)) { // packed
+ uint32 packed_length;
+ if (!stream->ExpectTag(kDelimitedTag(1)) ||
+ !stream->ReadVarint32(&packed_length)) {
+ return -1;
+ }
+ auto packed_limit = stream->PushLimit(packed_length);
+ while (!stream->ExpectAtEnd()) {
+ protobuf_uint64 n; // There is no API for int64
+ if (!stream->ReadVarint64(&n)) {
+ return -1;
+ }
+ if (out != nullptr) {
+ *out++ = n;
+ }
+ num_elements++;
+ }
+ stream->PopLimit(packed_limit);
+ } else if (peek_tag == kVarintTag(1)) {
+ while (!stream->ExpectAtEnd()) {
+ protobuf_uint64 n; // There is no API for int64
+ if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
+ return -1;
+ }
+ if (out != nullptr) {
+ *out++ = n;
+ }
+ num_elements++;
+ }
+ } else {
+ // Unknown tag.
+ return -1;
+ }
+ stream->PopLimit(limit);
+ }
+ return num_elements;
+}
+
+inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) {
+ uint8 peek_tag = PeekTag(stream);
+ switch (peek_tag) {
+ case kDelimitedTag(1):
+ return DT_STRING;
+ case kDelimitedTag(2):
+ return DT_FLOAT;
+ case kDelimitedTag(3):
+ return DT_INT64;
+ default:
+ return DT_INVALID;
+ }
+}
+
+inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream,
+ DataType dtype) {
+ switch (dtype) {
+ case DT_STRING:
+ if (!stream->ExpectTag(kDelimitedTag(1))) {
+ return false;
+ }
+ break;
+ case DT_FLOAT:
+ if (!stream->ExpectTag(kDelimitedTag(2))) {
+ return false;
+ }
+ break;
+ case DT_INT64:
+ if (!stream->ExpectTag(kDelimitedTag(3))) {
+ return false;
+ }
+ break;
+ default:
+ return false;
+ }
+ uint32 length;
+ return stream->ReadVarint32(&length) && length == 0;
+}
+
+// TODO(sundberg): Use the threadpool to parallelize example parsing.
+Status FastParseSequenceExample(
+ const FastParseExampleConfig& context_config,
+ const FastParseExampleConfig& feature_list_config,
+ gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
+ thread::ThreadPool* thread_pool, Result* context_result,
+ Result* feature_list_result) {
+ int num_examples = serialized.size();
+ DCHECK(context_result != nullptr);
+ DCHECK(feature_list_result != nullptr);
+ std::map<StringPiece, bool> context_is_sparse;
+ std::map<StringPiece, std::pair<DataType, size_t>>
+ context_feature_type_and_lengths;
+ if (!example_names.empty() && example_names.size() != num_examples) {
+ return errors::InvalidArgument(
+ "example_names must be empty or have the correct number of elements");
+ }
+ for (auto& c : context_config.sparse) {
+ TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
+ context_feature_type_and_lengths[c.feature_name] =
+ std::make_pair(c.dtype, 0);
+ context_is_sparse[c.feature_name] = true;
+ }
+ for (auto& c : context_config.dense) {
+ TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
+ context_feature_type_and_lengths[c.feature_name] =
+ std::make_pair(c.dtype, 0);
+ context_is_sparse[c.feature_name] = false;
+ }
+ std::map<StringPiece, bool> sequence_is_sparse;
+ std::map<StringPiece, std::pair<DataType, size_t>>
+ sequence_feature_type_and_lengths;
+ for (auto& c : feature_list_config.sparse) {
+ TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
+ sequence_feature_type_and_lengths[c.feature_name] =
+ std::make_pair(c.dtype, 0);
+ sequence_is_sparse[c.feature_name] = true;
+ }
+ for (auto& c : feature_list_config.dense) {
+ TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
+ sequence_feature_type_and_lengths[c.feature_name] =
+ std::make_pair(c.dtype, 0);
+ sequence_is_sparse[c.feature_name] = false;
+ }
+
+ std::vector<std::map<StringPiece, StringPiece>> all_context_features(
+ num_examples);
+ std::vector<std::map<StringPiece, StringPiece>> all_sequence_features(
+ num_examples);
+ const string kUnknown = "<unknown>";
+ for (int d = 0; d < num_examples; d++) {
+ const string& example = serialized[d];
+ const string& example_name =
+ example_names.empty() ? kUnknown : example_names[d];
+ auto* context_features = &all_context_features[d];
+ auto* sequence_features = &all_sequence_features[d];
+
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(example.data()), example.size());
+ // Not clear what this does. Why not stream.EnableAliasing()?
+ EnableAliasing(&stream);
+
+ // Extract pointers to all features within this serialized example.
+ while (!stream.ExpectAtEnd()) {
+ std::map<StringPiece, StringPiece>* features = nullptr;
+ const std::map<StringPiece, std::pair<DataType, size_t>>* config =
+ nullptr;
+ if (stream.ExpectTag(kDelimitedTag(1))) {
+ // Context
+ features = context_features;
+ config = &context_feature_type_and_lengths;
+ } else if (stream.ExpectTag(kDelimitedTag(2))) {
+ // Sequence
+ features = sequence_features;
+ config = &sequence_feature_type_and_lengths;
+ } else if (!SkipExtraneousTag(&stream)) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid protocol message input, example id: ", example_name));
+ }
+ if (features != nullptr) {
+ uint32 length;
+ if (!stream.ReadVarint32(&length)) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid protocol message input, example id: ", example_name));
+ }
+ auto limit = stream.PushLimit(length);
+ while (!stream.ExpectAtEnd()) {
+ StringPiece key, value;
+ uint32 length;
+ if (!stream.ExpectTag(kDelimitedTag(1)) ||
+ !stream.ReadVarint32(&length)) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid protocol message input, example id: ", example_name));
+ }
+ auto limit = stream.PushLimit(length);
+ if (!stream.ExpectTag(kDelimitedTag(1)) ||
+ !ParseString(&stream, &key) ||
+ !stream.ExpectTag(kDelimitedTag(2)) ||
+ !ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid protocol message input, example id: ", example_name));
+ }
+ stream.PopLimit(limit);
+ // Only save if this feature was requested.
+ if (config->count(key) > 0) {
+ (*features)[key] = value;
+ }
+ }
+ stream.PopLimit(limit);
+ }
+ }
+
+ for (const auto& c : *context_features) {
+ size_t num_elements = 0;
+ if (!c.second.empty()) {
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(c.second.data()), c.second.size());
+ EnableAliasing(&stream);
+ DataType dtype = context_feature_type_and_lengths[c.first].first;
+ int64 num;
+ switch (dtype) {
+ case DT_STRING:
+ num = ParseBytesFeature(&stream, nullptr);
+ break;
+ case DT_FLOAT:
+ num = ParseFloatFeature(&stream, nullptr);
+ break;
+ case DT_INT64:
+ num = ParseInt64Feature(&stream, nullptr);
+ break;
+ default:
+ num = -1;
+ break;
+ }
+ if (num == -1) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in context feature ", c.first,
+ " in example ", example_name));
+ }
+ num_elements += num;
+ }
+ if (context_is_sparse[c.first]) {
+ context_feature_type_and_lengths[c.first].second += num_elements;
+ } else {
+ size_t current_max = context_feature_type_and_lengths[c.first].second;
+ context_feature_type_and_lengths[c.first].second =
+ std::max(current_max, num_elements);
+ }
+ }
+ for (const auto& c : *sequence_features) {
+ size_t num_elements = 0;
+ if (!c.second.empty()) {
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(c.second.data()), c.second.size());
+ EnableAliasing(&stream);
+ DataType dtype = sequence_feature_type_and_lengths[c.first].first;
+ while (!stream.ExpectAtEnd()) {
+ uint32 feature_length;
+ if (!stream.ExpectTag(kDelimitedTag(1)) ||
+ !stream.ReadVarint32(&feature_length)) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.first,
+ " in example ", example_name));
+ }
+ if (feature_length > 2) {
+ auto limit = stream.PushLimit(feature_length);
+ int64 num;
+ switch (dtype) {
+ case DT_STRING:
+ num = ParseBytesFeature(&stream, nullptr);
+ break;
+ case DT_FLOAT:
+ num = ParseFloatFeature(&stream, nullptr);
+ break;
+ case DT_INT64:
+ num = ParseInt64Feature(&stream, nullptr);
+ break;
+ default:
+ num = -1;
+ break;
+ }
+ if (num == -1) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.first,
+ " in example ", example_name));
+ }
+ num_elements += num;
+ stream.PopLimit(limit);
+ } else if (feature_length == 2) {
+ if (!SkipEmptyFeature(&stream, dtype)) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.first,
+ " in example ", example_name));
+ }
+ } else if (feature_length != 0) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.first,
+ " in example ", example_name));
+ }
+ }
+ }
+ if (sequence_is_sparse[c.first]) {
+ sequence_feature_type_and_lengths[c.first].second += num_elements;
+ } else {
+ size_t current_max = sequence_feature_type_and_lengths[c.first].second;
+ sequence_feature_type_and_lengths[c.first].second =
+ std::max(current_max, num_elements);
+ }
+ }
+ }
+
+ // Allocate memory.
+ context_result->sparse_values.resize(context_config.sparse.size());
+ context_result->sparse_indices.resize(context_config.sparse.size());
+ context_result->sparse_shapes.resize(context_config.sparse.size());
+ context_result->dense_values.resize(context_config.dense.size());
+ feature_list_result->sparse_values.resize(feature_list_config.sparse.size());
+ feature_list_result->sparse_indices.resize(feature_list_config.sparse.size());
+ feature_list_result->sparse_shapes.resize(feature_list_config.sparse.size());
+ feature_list_result->dense_values.resize(feature_list_config.dense.size());
+ int t = 0;
+ for (const auto& c : context_config.dense) {
+ TensorShape dense_shape, example_shape;
+ DataType dtype = c.dtype;
+ size_t expected_max_elements =
+ context_feature_type_and_lengths[c.feature_name].second;
+ if (!c.shape.AsTensorShape(&example_shape) ||
+ expected_max_elements != example_shape.num_elements()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Inconsistent number of elements for feature ", c.feature_name));
+ }
+ dense_shape.AddDim(num_examples);
+ for (const int dim : c.shape.dim_sizes()) {
+ dense_shape.AddDim(dim);
+ }
+ context_result->dense_values[t] = Tensor(dtype, dense_shape);
+
+ // TODO(sundberg): Refactor to reduce code duplication, and add bounds
+ // checking for the outputs.
+ string* out_bytes = nullptr;
+ float* out_float = nullptr;
+ int64* out_int64 = nullptr;
+ switch (dtype) {
+ case DT_STRING:
+ out_bytes = context_result->dense_values[t].flat<string>().data();
+ break;
+ case DT_FLOAT:
+ out_float = context_result->dense_values[t].flat<float>().data();
+ break;
+ case DT_INT64:
+ out_int64 = context_result->dense_values[t].flat<int64>().data();
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ }
+ t++;
+
+ // Fill in the values.
+ for (int e = 0; e < num_examples; e++) {
+ size_t num_elements = 0;
+ const auto& feature = all_context_features[e][c.feature_name];
+ const string& example_name =
+ example_names.empty() ? kUnknown : example_names[e];
+ if (!feature.empty()) {
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(feature.data()), feature.size());
+ EnableAliasing(&stream);
+ size_t num_added;
+ switch (dtype) {
+ case DT_STRING:
+ num_added = ParseBytesFeature(&stream, out_bytes);
+ out_bytes += num_added;
+ break;
+ case DT_FLOAT:
+ num_added = ParseFloatFeature(&stream, out_float);
+ out_float += num_added;
+ break;
+ case DT_INT64:
+ num_added = ParseInt64Feature(&stream, out_int64);
+ out_int64 += num_added;
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in example ", example_name));
+ }
+ num_elements += num_added;
+ }
+ if (num_elements != expected_max_elements) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected number of elements in example ", example_name));
+ }
+ }
+ }
+ t = 0;
+ for (const auto& c : context_config.sparse) {
+ TensorShape indices_shape, values_shape;
+ DataType dtype = c.dtype;
+ size_t expected_num_elements =
+ context_feature_type_and_lengths[c.feature_name].second;
+ indices_shape.AddDim(expected_num_elements);
+ indices_shape.AddDim(2);
+ values_shape.AddDim(expected_num_elements);
+ context_result->sparse_indices[t] = Tensor(DT_INT64, indices_shape);
+ context_result->sparse_values[t] = Tensor(dtype, values_shape);
+ context_result->sparse_shapes[t] = Tensor(DT_INT64, TensorShape({2}));
+ // TODO(sundberg): Refactor to reduce code duplication, and add bounds
+ // checking for the outputs.
+ string* out_bytes = nullptr;
+ float* out_float = nullptr;
+ int64* out_int64 = nullptr;
+ switch (dtype) {
+ case DT_STRING:
+ out_bytes = context_result->sparse_values[t].flat<string>().data();
+ break;
+ case DT_FLOAT:
+ out_float = context_result->sparse_values[t].flat<float>().data();
+ break;
+ case DT_INT64:
+ out_int64 = context_result->sparse_values[t].flat<int64>().data();
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ }
+ int64* out_indices = context_result->sparse_indices[t].flat<int64>().data();
+ auto out_shape = context_result->sparse_shapes[t].vec<int64>();
+ t++;
+
+ // Fill in the values.
+ size_t num_elements = 0;
+ size_t max_num_cols = 0;
+ for (int e = 0; e < num_examples; e++) {
+ const auto& feature = all_context_features[e][c.feature_name];
+ const string& example_name =
+ example_names.empty() ? kUnknown : example_names[e];
+ if (!feature.empty()) {
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(feature.data()), feature.size());
+ EnableAliasing(&stream);
+ size_t num_added;
+ switch (dtype) {
+ case DT_STRING:
+ num_added = ParseBytesFeature(&stream, out_bytes);
+ out_bytes += num_added;
+ break;
+ case DT_FLOAT:
+ num_added = ParseFloatFeature(&stream, out_float);
+ out_float += num_added;
+ break;
+ case DT_INT64:
+ num_added = ParseInt64Feature(&stream, out_int64);
+ out_int64 += num_added;
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in example ", example_name));
+ }
+ num_elements += num_added;
+ max_num_cols = std::max(max_num_cols, num_added);
+ for (int i = 0; i < num_added; i++) {
+ *out_indices++ = e;
+ *out_indices++ = i;
+ }
+ }
+ }
+ if (num_elements != expected_num_elements) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected total number of elements in feature ", c.feature_name));
+ }
+ out_shape(0) = num_examples;
+ out_shape(1) = max_num_cols;
+ }
+ t = 0;
+ for (const auto& c : feature_list_config.dense) {
+ TensorShape dense_shape, row_shape;
+ DataType dtype = c.dtype;
+ size_t expected_max_elements =
+ sequence_feature_type_and_lengths[c.feature_name].second;
+ if (!c.shape.AsTensorShape(&row_shape) ||
+ expected_max_elements !=
+ (expected_max_elements / row_shape.num_elements()) *
+ row_shape.num_elements()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected shape error in feature ", c.feature_name));
+ }
+ int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
+ dense_shape.AddDim(num_examples);
+ dense_shape.AddDim(expected_max_rows);
+ for (const int dim : feature_list_config.dense[t].shape.dim_sizes()) {
+ dense_shape.AddDim(dim);
+ }
+ feature_list_result->dense_values[t] = Tensor(dtype, dense_shape);
+
+ string* out_bytes = nullptr;
+ float* out_float = nullptr;
+ int64* out_int64 = nullptr;
+ switch (dtype) {
+ case DT_STRING:
+ out_bytes = feature_list_result->dense_values[t].flat<string>().data();
+ break;
+ case DT_FLOAT:
+ out_float = feature_list_result->dense_values[t].flat<float>().data();
+ break;
+ case DT_INT64:
+ out_int64 = feature_list_result->dense_values[t].flat<int64>().data();
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ }
+ t++;
+
+ // Fill in the values.
+ for (int e = 0; e < num_examples; e++) {
+ size_t num_elements = 0;
+ const auto& feature = all_sequence_features[e][c.feature_name];
+ const string& example_name =
+ example_names.empty() ? kUnknown : example_names[e];
+ if (!feature.empty()) {
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(feature.data()), feature.size());
+ EnableAliasing(&stream);
+ while (!stream.ExpectAtEnd()) {
+ uint32 feature_length;
+ if (!stream.ExpectTag(kDelimitedTag(1)) ||
+ !stream.ReadVarint32(&feature_length)) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.feature_name,
+ " in example ", example_name));
+ }
+ auto limit = stream.PushLimit(feature_length);
+ size_t num_added;
+ switch (dtype) {
+ case DT_STRING:
+ num_added = ParseBytesFeature(&stream, out_bytes);
+ out_bytes += num_added;
+ break;
+ case DT_FLOAT:
+ num_added = ParseFloatFeature(&stream, out_float);
+ out_float += num_added;
+ break;
+ case DT_INT64:
+ num_added = ParseInt64Feature(&stream, out_int64);
+ out_int64 += num_added;
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in example ", example_name));
+ }
+ num_elements += num_added;
+ if (num_added != row_shape.num_elements()) {
+ return errors::InvalidArgument(
+ "Unexpected number of elements in feature ", c.feature_name,
+ ", example ", example_name);
+ }
+ stream.PopLimit(limit);
+ }
+ }
+ // Pad as necessary.
+ int num_to_pad = expected_max_elements - num_elements;
+ switch (dtype) {
+ case DT_STRING:
+ out_bytes += num_to_pad;
+ break;
+ case DT_FLOAT:
+ PadFloatFeature(num_to_pad, out_float);
+ out_float += num_to_pad;
+ break;
+ case DT_INT64:
+ PadInt64Feature(num_to_pad, out_int64);
+ out_int64 += num_to_pad;
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in example ", example_name));
+ }
+ }
+ }
+ t = 0;
+ for (const auto& c : feature_list_config.sparse) {
+ TensorShape indices_shape, values_shape;
+ DataType dtype = c.dtype;
+ size_t expected_num_elements =
+ sequence_feature_type_and_lengths[c.feature_name].second;
+ indices_shape.AddDim(expected_num_elements);
+ indices_shape.AddDim(3);
+ values_shape.AddDim(expected_num_elements);
+ feature_list_result->sparse_indices[t] = Tensor(DT_INT64, indices_shape);
+ feature_list_result->sparse_values[t] = Tensor(dtype, values_shape);
+ feature_list_result->sparse_shapes[t] = Tensor(DT_INT64, TensorShape({3}));
+
+ string* out_bytes = nullptr;
+ float* out_float = nullptr;
+ int64* out_int64 = nullptr;
+ switch (dtype) {
+ case DT_STRING:
+ out_bytes = feature_list_result->sparse_values[t].flat<string>().data();
+ break;
+ case DT_FLOAT:
+ out_float = feature_list_result->sparse_values[t].flat<float>().data();
+ break;
+ case DT_INT64:
+ out_int64 = feature_list_result->sparse_values[t].flat<int64>().data();
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ }
+ int64* out_indices =
+ feature_list_result->sparse_indices[t].flat<int64>().data();
+ auto out_shape = feature_list_result->sparse_shapes[t].vec<int64>();
+ t++;
+
+ // Fill in the values.
+ size_t num_elements = 0;
+ size_t max_num_rows = 0;
+ size_t max_num_cols = 0;
+ for (int e = 0; e < num_examples; e++) {
+ const auto& feature = all_sequence_features[e][c.feature_name];
+ const string& example_name =
+ example_names.empty() ? kUnknown : example_names[e];
+ if (!feature.empty()) {
+ protobuf::io::CodedInputStream stream(
+ reinterpret_cast<const uint8*>(feature.data()), feature.size());
+ EnableAliasing(&stream);
+ size_t num_rows = 0;
+ while (!stream.ExpectAtEnd()) {
+ uint32 feature_length;
+ if (!stream.ExpectTag(kDelimitedTag(1)) ||
+ !stream.ReadVarint32(&feature_length)) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.feature_name,
+ " in example ", example_name));
+ }
+ if (feature_length > 2) {
+ auto limit = stream.PushLimit(feature_length);
+ size_t num_added;
+ switch (dtype) {
+ case DT_STRING:
+ num_added = ParseBytesFeature(&stream, out_bytes);
+ out_bytes += num_added;
+ break;
+ case DT_FLOAT:
+ num_added = ParseFloatFeature(&stream, out_float);
+ out_float += num_added;
+ break;
+ case DT_INT64:
+ num_added = ParseInt64Feature(&stream, out_int64);
+ out_int64 += num_added;
+ break;
+ default:
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected dtype ", dtype, " in example ", example_name));
+ }
+ num_elements += num_added;
+ max_num_cols = std::max(max_num_cols, num_added);
+ for (int i = 0; i < num_added; i++) {
+ *out_indices++ = e;
+ *out_indices++ = num_rows;
+ *out_indices++ = i;
+ }
+ stream.PopLimit(limit);
+ } else if (feature_length == 2) {
+ if (!SkipEmptyFeature(&stream, dtype)) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.feature_name,
+ " in example ", example_name));
+ }
+ } else if (feature_length != 0) {
+ return errors::InvalidArgument(
+ strings::StrCat("Error in sequence feature ", c.feature_name,
+ " in example ", example_name));
+ }
+ num_rows++;
+ }
+ max_num_rows = std::max(max_num_rows, num_rows);
+ }
+ }
+ if (num_elements != expected_num_elements) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Unexpected number of elements in feature ", c.feature_name));
+ }
+ out_shape(0) = num_examples;
+ out_shape(1) = max_num_rows;
+ out_shape(2) = max_num_cols;
+ }
+
+ return Status::OK();
+}
+
} // namespace example
} // namespace tensorflow
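
The parsers added above read the Feature wire format directly rather than deserializing full protos. As a reading aid (not part of this change), the sketch below shows how the tag constants relate to protobuf wire types and to the field numbers in feature.proto; the helper name MakeTag is illustrative, and the real kDelimitedTag/kVarintTag/kFixed32Tag helpers are defined earlier in this file.

#include <cstdint>

// A protobuf wire-format tag packs a field number and a wire type into one
// varint: tag = (field_number << 3) | wire_type.
// Wire types: 0 = varint, 2 = length-delimited, 5 = 32-bit fixed.
constexpr uint32_t kWireTypeVarint = 0;
constexpr uint32_t kWireTypeDelimited = 2;
constexpr uint32_t kWireTypeFixed32 = 5;

constexpr uint32_t MakeTag(uint32_t field_number, uint32_t wire_type) {
  return (field_number << 3) | wire_type;
}

// Field numbers from feature.proto: Feature.bytes_list = 1,
// Feature.float_list = 2, Feature.int64_list = 3 -- which is why
// ParseDataType above maps the delimited tags for fields 1/2/3 to
// DT_STRING/DT_FLOAT/DT_INT64.
static_assert(MakeTag(1, kWireTypeDelimited) == 0x0A, "bytes_list tag");
static_assert(MakeTag(2, kWireTypeDelimited) == 0x12, "float_list tag");
static_assert(MakeTag(3, kWireTypeDelimited) == 0x1A, "int64_list tag");
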
diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h
index 1b08f02267..024a4518ee 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.h
+++ b/tensorflow/core/util/example_proto_fast_parsing.h
@@ -85,6 +85,17 @@ typedef FastParseExampleConfig FastParseSingleExampleConfig;
Status FastParseSingleExample(const FastParseSingleExampleConfig& config,
const string& serialized, Result* result);
+// Parses a batch of serialized SequenceExample protos and fills in
+// context_result and feature_list_result according to the given configs.
+// example_names must either be empty or have the same number of elements as
+// serialized; the names are used only for error messages.
+Status FastParseSequenceExample(
+ const example::FastParseExampleConfig& context_config,
+ const example::FastParseExampleConfig& feature_list_config,
+ gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
+ thread::ThreadPool* thread_pool, example::Result* context_result,
+ example::Result* feature_list_result);
+
// This function parses serialized Example and populates given example.
// It uses the same specialized parser as FastParseExample which is efficient.
// But then constructs Example which is relatively slow.
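
A minimal calling sketch for the new entry point, based only on the signature and struct members visible in this diff. Assumptions are marked in comments: the Sparse/Dense entry types are taken to be default-constructible, the feature names are hypothetical, and thread_pool is passed as null because the TODO above notes the pool is not used yet.

#include <string>
#include <vector>
#include "tensorflow/core/util/example_proto_fast_parsing.h"

tensorflow::Status ParseBatch(
    const std::vector<std::string>& serialized_sequence_examples) {
  using tensorflow::example::FastParseExampleConfig;
  using tensorflow::example::Result;

  // Context: a sparse int64 feature (hypothetical name "label").
  FastParseExampleConfig context_config;
  FastParseExampleConfig::Sparse label;  // assumed default-constructible
  label.feature_name = "label";
  label.dtype = tensorflow::DT_INT64;
  context_config.sparse.push_back(label);

  // Feature lists: a dense int64 feature with one value per step
  // (hypothetical name "tokens").
  FastParseExampleConfig feature_list_config;
  FastParseExampleConfig::Dense tokens;  // assumed default-constructible
  tokens.feature_name = "tokens";
  tokens.dtype = tensorflow::DT_INT64;
  tokens.shape = tensorflow::PartialTensorShape({1});
  feature_list_config.dense.push_back(tokens);

  Result context_result, feature_list_result;
  // example_names may be empty; it is only used in error messages.
  return tensorflow::example::FastParseSequenceExample(
      context_config, feature_list_config, serialized_sequence_examples,
      /*example_names=*/{}, /*thread_pool=*/nullptr, &context_result,
      &feature_list_result);
}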