diff options
author | Patrik Sundberg <sundberg@google.com> | 2018-08-31 07:28:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-31 07:32:19 -0700 |
commit | d09bd33fc9175c7fdbdded8f98b5c5d3a9f8ad7d (patch) | |
tree | bb869bee40e9833299b2662d2076547c85d5df39 /tensorflow/core/util/example_proto_fast_parsing.cc | |
parent | 1251e8c6cee24d9a295c63fa2362b03deef4396b (diff) |
Add a batch sequence example parsing op, part 2.
PiperOrigin-RevId: 211082479
Diffstat (limited to 'tensorflow/core/util/example_proto_fast_parsing.cc')
-rw-r--r-- | tensorflow/core/util/example_proto_fast_parsing.cc | 228 |
1 files changed, 152 insertions, 76 deletions
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index a38cd1d09f..e52d55e2ff 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -1722,10 +1722,11 @@ Status FastParseSequenceExample( const FastParseExampleConfig& feature_list_config, gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names, thread::ThreadPool* thread_pool, Result* context_result, - Result* feature_list_result) { + Result* feature_list_result, std::vector<Tensor>* dense_feature_lengths) { int num_examples = serialized.size(); DCHECK(context_result != nullptr); DCHECK(feature_list_result != nullptr); + DCHECK(dense_feature_lengths != nullptr); std::map<StringPiece, bool> context_is_sparse; std::map<StringPiece, std::pair<DataType, size_t>> context_feature_type_and_lengths; @@ -1740,9 +1741,22 @@ Status FastParseSequenceExample( context_is_sparse[c.feature_name] = true; } for (auto& c : context_config.dense) { + if (context_is_sparse[c.feature_name]) { + return errors::InvalidArgument("Context feature " + c.feature_name + + " cannot be both dense and sparse"); + } TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); context_feature_type_and_lengths[c.feature_name] = - std::make_pair(c.dtype, 0); + std::make_pair(c.dtype, c.default_value.NumElements()); + if (c.default_value.NumElements() > 0) { + if (!c.shape.IsCompatibleWith(c.default_value.shape())) { + return errors::InvalidArgument("Default value for context feature ", + c.feature_name, + " has an incorrect shape: saw ", + c.default_value.shape().DebugString(), + " but expected ", c.shape.DebugString()); + } + } context_is_sparse[c.feature_name] = false; } std::map<StringPiece, bool> sequence_is_sparse; @@ -1755,6 +1769,10 @@ Status FastParseSequenceExample( sequence_is_sparse[c.feature_name] = true; } for (auto& c : feature_list_config.dense) { + if (sequence_is_sparse[c.feature_name]) { + return errors::InvalidArgument("Sequence feature " + c.feature_name + + " cannot be both dense and sparse"); + } TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); sequence_feature_type_and_lengths[c.feature_name] = std::make_pair(c.dtype, 0); @@ -1792,14 +1810,14 @@ Status FastParseSequenceExample( features = sequence_features; config = &sequence_feature_type_and_lengths; } else if (!SkipExtraneousTag(&stream)) { - return errors::InvalidArgument(strings::StrCat( - "Invalid protocol message input, example id: ", example_name)); + return errors::InvalidArgument( + "Invalid protocol message input, example id: ", example_name); } if (features != nullptr) { uint32 length; if (!stream.ReadVarint32(&length)) { - return errors::InvalidArgument(strings::StrCat( - "Invalid protocol message input, example id: ", example_name)); + return errors::InvalidArgument( + "Invalid protocol message input, example id: ", example_name); } auto limit = stream.PushLimit(length); while (!stream.ExpectAtEnd()) { @@ -1807,16 +1825,16 @@ Status FastParseSequenceExample( uint32 length; if (!stream.ExpectTag(kDelimitedTag(1)) || !stream.ReadVarint32(&length)) { - return errors::InvalidArgument(strings::StrCat( - "Invalid protocol message input, example id: ", example_name)); + return errors::InvalidArgument( + "Invalid protocol message input, example id: ", example_name); } auto limit = stream.PushLimit(length); if (!stream.ExpectTag(kDelimitedTag(1)) || !ParseString(&stream, &key) || !stream.ExpectTag(kDelimitedTag(2)) || !ParseString(&stream, &value) || !stream.ExpectAtEnd()) { - return errors::InvalidArgument(strings::StrCat( - "Invalid protocol message input, example id: ", example_name)); + return errors::InvalidArgument( + "Invalid protocol message input, example id: ", example_name); } stream.PopLimit(limit); // Only save if this feature was requested. @@ -1851,9 +1869,8 @@ Status FastParseSequenceExample( break; } if (num == -1) { - return errors::InvalidArgument( - strings::StrCat("Error in context feature ", c.first, - " in example ", example_name)); + return errors::InvalidArgument("Error in context feature ", c.first, + " in example ", example_name); } num_elements += num; } @@ -1876,9 +1893,9 @@ Status FastParseSequenceExample( uint32 feature_length; if (!stream.ExpectTag(kDelimitedTag(1)) || !stream.ReadVarint32(&feature_length)) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.first, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.first, " in example ", + example_name); } if (feature_length > 2) { auto limit = stream.PushLimit(feature_length); @@ -1898,22 +1915,22 @@ Status FastParseSequenceExample( break; } if (num == -1) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.first, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.first, " in example ", + example_name); } num_elements += num; stream.PopLimit(limit); } else if (feature_length == 2) { if (!SkipEmptyFeature(&stream, dtype)) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.first, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.first, " in example ", + example_name); } } else if (feature_length != 0) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.first, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.first, " in example ", + example_name); } } } @@ -1936,15 +1953,19 @@ Status FastParseSequenceExample( feature_list_result->sparse_indices.resize(feature_list_config.sparse.size()); feature_list_result->sparse_shapes.resize(feature_list_config.sparse.size()); feature_list_result->dense_values.resize(feature_list_config.dense.size()); + dense_feature_lengths->resize(feature_list_config.dense.size()); + int t = 0; for (const auto& c : context_config.dense) { - TensorShape dense_shape; + TensorShape dense_shape, example_shape; DataType dtype = c.dtype; - size_t expected_max_elements = + const size_t expected_max_elements = context_feature_type_and_lengths[c.feature_name].second; - if (expected_max_elements != dense_shape.num_elements()) { - return errors::InvalidArgument(strings::StrCat( - "Inconsistent number of elements for feature ", c.feature_name)); + if (!c.shape.AsTensorShape(&example_shape) || + expected_max_elements != example_shape.num_elements()) { + return errors::InvalidArgument( + "Inconsistent number of elements for feature ", c.feature_name, ": ", + expected_max_elements, " vs ", dense_shape.num_elements()); } dense_shape.AddDim(num_examples); for (const int dim : c.shape.dim_sizes()) { @@ -1968,18 +1989,58 @@ Status FastParseSequenceExample( out_int64 = context_result->dense_values[t].flat<int64>().data(); break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in feature ", c.feature_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in feature ", c.feature_name); } t++; // Fill in the values. for (int e = 0; e < num_examples; e++) { size_t num_elements = 0; - const auto& feature = all_context_features[e][c.feature_name]; + const auto feature_iter = all_context_features[e].find(c.feature_name); const string& example_name = example_names.empty() ? kUnknown : example_names[e]; - if (!feature.empty()) { + if (feature_iter == all_context_features[e].end()) { + // Copy the default value, if present. If not, return an error. + if (c.default_value.NumElements() == 0) { + return errors::InvalidArgument( + "Feature: ", c.feature_name, + " (data type: ", DataTypeString(c.dtype), ")", + " is required but could not be found."); + } + const string* in_bytes = nullptr; + const float* in_float = nullptr; + const int64* in_int64 = nullptr; + size_t num = 0; + switch (dtype) { + case DT_STRING: + in_bytes = c.default_value.flat<string>().data(); + num = c.default_value.NumElements(); + for (int p = 0; p < num; p++) { + *out_bytes++ = *in_bytes++; + } + break; + case DT_FLOAT: + in_float = c.default_value.flat<float>().data(); + num = c.default_value.NumElements(); + for (int p = 0; p < num; p++) { + *out_float++ = *in_float++; + } + break; + case DT_INT64: + in_int64 = c.default_value.flat<int64>().data(); + num = c.default_value.NumElements(); + for (int p = 0; p < num; p++) { + *out_int64++ = *in_int64++; + } + break; + default: + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in example ", example_name); + } + num_elements += num; + } else if (!feature_iter->second.empty()) { + const auto& feature = feature_iter->second; protobuf::io::CodedInputStream stream( reinterpret_cast<const uint8*>(feature.data()), feature.size()); EnableAliasing(&stream); @@ -1998,14 +2059,14 @@ Status FastParseSequenceExample( out_int64 += num_added; break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in example ", example_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in example ", example_name); } num_elements += num_added; } if (num_elements != expected_max_elements) { - return errors::InvalidArgument(strings::StrCat( - "Unexpected number of elements in example ", example_name)); + return errors::InvalidArgument( + "Unexpected number of elements in example ", example_name); } } } @@ -2037,8 +2098,8 @@ Status FastParseSequenceExample( out_int64 = context_result->sparse_values[t].flat<int64>().data(); break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in feature ", c.feature_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in feature ", c.feature_name); } int64* out_indices = context_result->sparse_indices[t].flat<int64>().data(); auto out_shape = context_result->sparse_shapes[t].vec<int64>(); @@ -2070,8 +2131,8 @@ Status FastParseSequenceExample( out_int64 += num_added; break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in example ", example_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in example ", example_name); } num_elements += num_added; max_num_cols = std::max(max_num_cols, num_added); @@ -2082,30 +2143,35 @@ Status FastParseSequenceExample( } } if (num_elements != expected_num_elements) { - return errors::InvalidArgument(strings::StrCat( - "Unexpected total number of elements in feature ", c.feature_name)); + return errors::InvalidArgument( + "Unexpected total number of elements in feature ", c.feature_name); } out_shape(0) = num_examples; out_shape(1) = max_num_cols; } t = 0; + TensorShape dense_length_shape({num_examples}); for (const auto& c : feature_list_config.dense) { TensorShape dense_shape, row_shape; DataType dtype = c.dtype; - size_t expected_max_elements = + const size_t expected_max_elements = sequence_feature_type_and_lengths[c.feature_name].second; - int64 expected_max_rows = expected_max_elements / row_shape.num_elements(); if (!c.shape.AsTensorShape(&row_shape) || - expected_max_elements != expected_max_rows * row_shape.num_elements()) { - return errors::InvalidArgument(strings::StrCat( - "Unexpected shape error in feature ", c.feature_name)); + expected_max_elements != + (expected_max_elements / row_shape.num_elements()) * + row_shape.num_elements()) { + return errors::InvalidArgument("Unexpected shape error in feature ", + c.feature_name); } + int64 expected_max_rows = expected_max_elements / row_shape.num_elements(); dense_shape.AddDim(num_examples); dense_shape.AddDim(expected_max_rows); for (const int dim : feature_list_config.dense[t].shape.dim_sizes()) { dense_shape.AddDim(dim); } feature_list_result->dense_values[t] = Tensor(dtype, dense_shape); + (*dense_feature_lengths)[t] = Tensor(DT_INT64, dense_length_shape); + int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data(); string* out_bytes = nullptr; float* out_float = nullptr; @@ -2121,18 +2187,26 @@ Status FastParseSequenceExample( out_int64 = feature_list_result->dense_values[t].flat<int64>().data(); break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in feature ", c.feature_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in feature ", c.feature_name); } t++; // Fill in the values. for (int e = 0; e < num_examples; e++) { - size_t num_elements = 0; - const auto& feature = all_sequence_features[e][c.feature_name]; + size_t num_elements = 0, num_rows = 0; + const auto feature_iter = all_sequence_features[e].find(c.feature_name); const string& example_name = example_names.empty() ? kUnknown : example_names[e]; - if (!feature.empty()) { + if (feature_iter == all_sequence_features[e].end()) { + // Return an error if this feature was not allowed to be missing. + // Otherwise, we'll pad as needed below. + if (!c.variable_length) { + return errors::InvalidArgument("Missing feature ", c.feature_name, + " in example ", example_name); + } + } else if (!feature_iter->second.empty()) { + const auto& feature = feature_iter->second; protobuf::io::CodedInputStream stream( reinterpret_cast<const uint8*>(feature.data()), feature.size()); EnableAliasing(&stream); @@ -2140,9 +2214,9 @@ Status FastParseSequenceExample( uint32 feature_length; if (!stream.ExpectTag(kDelimitedTag(1)) || !stream.ReadVarint32(&feature_length)) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.feature_name, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.feature_name, " in example ", + example_name); } auto limit = stream.PushLimit(feature_length); size_t num_added; @@ -2160,10 +2234,11 @@ Status FastParseSequenceExample( out_int64 += num_added; break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in example ", example_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in example ", example_name); } num_elements += num_added; + num_rows++; if (num_added != row_shape.num_elements()) { return errors::InvalidArgument( "Unexpected number of elements in feature ", c.feature_name, @@ -2172,6 +2247,7 @@ Status FastParseSequenceExample( stream.PopLimit(limit); } } + *out_lengths++ = num_rows; // Pad as necessary. int num_to_pad = expected_max_elements - num_elements; switch (dtype) { @@ -2187,8 +2263,8 @@ Status FastParseSequenceExample( out_int64 += num_to_pad; break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in example ", example_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in example ", example_name); } } } @@ -2219,8 +2295,8 @@ Status FastParseSequenceExample( out_int64 = feature_list_result->sparse_values[t].flat<int64>().data(); break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in feature ", c.feature_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in feature ", c.feature_name); } int64* out_indices = feature_list_result->sparse_indices[t].flat<int64>().data(); @@ -2244,9 +2320,9 @@ Status FastParseSequenceExample( uint32 feature_length; if (!stream.ExpectTag(kDelimitedTag(1)) || !stream.ReadVarint32(&feature_length)) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.feature_name, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.feature_name, " in example ", + example_name); } if (feature_length > 2) { auto limit = stream.PushLimit(feature_length); @@ -2265,8 +2341,8 @@ Status FastParseSequenceExample( out_int64 += num_added; break; default: - return errors::InvalidArgument(strings::StrCat( - "Unexpected dtype ", dtype, " in example ", example_name)); + return errors::InvalidArgument("Unexpected dtype ", dtype, + " in example ", example_name); } num_elements += num_added; max_num_cols = std::max(max_num_cols, num_added); @@ -2278,14 +2354,14 @@ Status FastParseSequenceExample( stream.PopLimit(limit); } else if (feature_length == 2) { if (!SkipEmptyFeature(&stream, dtype)) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.feature_name, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.feature_name, " in example ", + example_name); } } else if (feature_length != 0) { - return errors::InvalidArgument( - strings::StrCat("Error in sequence feature ", c.feature_name, - " in example ", example_name)); + return errors::InvalidArgument("Error in sequence feature ", + c.feature_name, " in example ", + example_name); } num_rows++; } @@ -2293,8 +2369,8 @@ Status FastParseSequenceExample( } } if (num_elements != expected_num_elements) { - return errors::InvalidArgument(strings::StrCat( - "Unexpected number of elements in feature ", c.feature_name)); + return errors::InvalidArgument( + "Unexpected number of elements in feature ", c.feature_name); } out_shape(0) = num_examples; out_shape(1) = max_num_rows; |