author    | Eugene Brevdo <ebrevdo@google.com>              | 2017-02-07 16:45:06 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-07 17:09:35 -0800
commit    | 780bc6b4d98665125c43685b20eeba6ad2804c0c (patch)
tree      | 4acac8d596888cae078e520e65d836ff1a2c28d3 /tensorflow/core/util/example_proto_fast_parsing.cc
parent    | e6bfaf47374b44bb688023904eac98576baf4cd4 (diff)
Add support for variable major dimension in dense features in the example parser C++ op.
Full Python support (including more comprehensive documentation) coming soon.
Change: 146852707
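
To picture what a "variable major dimension" means for a dense feature: within a batch, each Example may supply a different number of values for the same dense key, and the parser pads every example out to the longest one using the feature's default value. The sketch below only illustrates that padding idea in isolation; the `PadVarLenRows` helper and the flat row-major layout are hypothetical and are not TensorFlow's API.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical helper (not part of TensorFlow): pads a ragged batch of rows
// into a flat row-major [batch_size, max_row_length] buffer, filling the
// missing tail of each short row with a default value.
std::vector<long long> PadVarLenRows(const std::vector<std::vector<long long>>& rows,
                                     long long default_value,
                                     std::size_t* max_row_length) {
  std::size_t max_len = 0;
  for (const auto& row : rows) max_len = std::max(max_len, row.size());

  std::vector<long long> dense(rows.size() * max_len, default_value);
  for (std::size_t i = 0; i < rows.size(); ++i) {
    std::copy(rows[i].begin(), rows[i].end(), dense.begin() + i * max_len);
  }
  *max_row_length = max_len;
  return dense;
}

int main() {
  // Three examples provide 2, 0, and 3 values for the same dense key.
  const std::vector<std::vector<long long>> rows = {{1, 2}, {}, {3, 4, 5}};
  std::size_t max_len = 0;
  const std::vector<long long> dense =
      PadVarLenRows(rows, /*default_value=*/0, &max_len);

  // Prints a [3, 3] matrix; short rows are padded with the default value 0.
  for (std::size_t i = 0; i < rows.size(); ++i) {
    for (std::size_t j = 0; j < max_len; ++j) {
      std::cout << dense[i * max_len + j] << (j + 1 < max_len ? ' ' : '\n');
    }
  }
  return 0;
}
```

In the op itself the same idea is applied per dense key across the whole batch: the padded length is taken from the longest example seen, and missing or short examples are filled from the feature's default value.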
Diffstat (limited to 'tensorflow/core/util/example_proto_fast_parsing.cc')
-rw-r--r-- | tensorflow/core/util/example_proto_fast_parsing.cc | 317 |
1 file changed, 264 insertions, 53 deletions
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index e14f50551e..facb092dbc 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -424,6 +424,7 @@ Status FastParseSerializedExample(
     const size_t example_index, const Config& config,
     const PresizedCuckooMap<std::pair<size_t, Type>>& config_index,
     SeededHasher hasher, std::vector<Tensor>* output_dense,
+    std::vector<SparseBuffer>* output_varlen_dense,
     std::vector<SparseBuffer>* output_sparse) {
   DCHECK(output_dense != nullptr);
   DCHECK(output_sparse != nullptr);
@@ -463,9 +464,9 @@ Status FastParseSerializedExample(
   }

   auto example_error = [&](StringPiece suffix) {
-    return errors::InvalidArgument("Name: ", example_name, ", Key: ",
-                                   feature_name, ", Index: ", example_index,
-                                   ". ", suffix);
+    return errors::InvalidArgument("Name: ", example_name,
+                                   ", Key: ", feature_name,
+                                   ", Index: ", example_index, ". ", suffix);
   };

   auto parse_error = [&] {
@@ -494,54 +495,117 @@ Status FastParseSerializedExample(
       dense_feature_last_example[d] = example_index;

       if (example_dtype != config.dense[d].dtype) {
-        return example_error(
-            strings::StrCat("Data types don't match. Data type: ",
-                            DataTypeString(example_dtype), "Expected type: ",
-                            DataTypeString(config.dense[d].dtype)));
+        return example_error(strings::StrCat(
+            "Data types don't match. Data type: ",
+            DataTypeString(example_dtype),
+            "Expected type: ", DataTypeString(config.dense[d].dtype)));
       }
-      Tensor& out = (*output_dense)[d];
+      if (!config.dense[d].variable_length) {
+        Tensor& out = (*output_dense)[d];
+
+        const std::size_t num_elements = config.dense[d].elements_per_stride;
+        const std::size_t offset = example_index * num_elements;
+
+        auto shape_error = [&](size_t size, StringPiece type_str) {
+          return example_error(strings::StrCat(
+              "Number of ", type_str,
+              " values != expected.  "
+              "Values size: ",
+              size,
+              " but output shape: ", config.dense[d].shape.DebugString()));
+        };
+
+        switch (config.dense[d].dtype) {
+          case DT_INT64: {
+            auto out_p = out.flat<int64>().data() + offset;
+            LimitedArraySlice<int64> slice(out_p, num_elements);
+            if (!feature.ParseInt64List(&slice)) return parse_error();
+            if (slice.EndDistance() != 0) {
+              return shape_error(num_elements - slice.EndDistance(), "int64");
+            }
+            break;
+          }
+          case DT_FLOAT: {
+            auto out_p = out.flat<float>().data() + offset;
+            LimitedArraySlice<float> slice(out_p, num_elements);
+            if (!feature.ParseFloatList(&slice)) return parse_error();
+            if (slice.EndDistance() != 0) {
+              return shape_error(num_elements - slice.EndDistance(), "float");
+            }
+            break;
+          }
+          case DT_STRING: {
+            auto out_p = out.flat<string>().data() + offset;
+            LimitedArraySlice<string> slice(out_p, num_elements);
+            if (!feature.ParseBytesList(&slice)) return parse_error();
+            if (slice.EndDistance() != 0) {
+              return shape_error(num_elements - slice.EndDistance(), "bytes");
+            }
+            break;
+          }
+          default:
+            CHECK(false) << "Should not happen.";
+        }
+      } else {  // if variable length
+        SparseBuffer& out = (*output_varlen_dense)[d];

-      const std::size_t num_elements = config.dense[d].elements_per_stride;
-      const std::size_t offset = example_index * num_elements;
+        const std::size_t num_elements = config.dense[d].elements_per_stride;

-      auto shape_error = [&](size_t size, StringPiece type_str) {
-        return example_error(strings::StrCat(
-            "Number of ", type_str,
-            " values != expected.  "
-            "Values size: ",
-            size, " but output shape: ", config.dense[d].shape.DebugString()));
-      };
+        if (example_dtype != DT_INVALID &&
+            example_dtype != config.dense[d].dtype) {
+          return example_error(strings::StrCat(
+              "Data types don't match. ",
+              "Expected type: ", DataTypeString(config.dense[d].dtype)));
+        }

-      switch (config.dense[d].dtype) {
-        case DT_INT64: {
-          auto out_p = out.flat<int64>().data() + offset;
-          LimitedArraySlice<int64> slice(out_p, num_elements);
-          if (!feature.ParseInt64List(&slice)) return parse_error();
-          if (slice.EndDistance() != 0) {
-            return shape_error(num_elements - slice.EndDistance(), "int64");
+        auto shape_error = [&](size_t size, StringPiece type_str) {
+          return example_error(strings::StrCat(
+              "Number of ", type_str,
+              " values is not a multiple of stride length. Saw ", size,
+              " values but output shape is: ",
+              config.dense[d].shape.DebugString()));
+        };
+
+        switch (config.dense[d].dtype) {
+          case DT_INT64: {
+            if (example_dtype != DT_INVALID) {
+              if (!feature.ParseInt64List(&out.int64_list)) {
+                return parse_error();
+              }
+              if (out.int64_list.size() % num_elements != 0) {
+                return shape_error(out.int64_list.size(), "int64");
+              }
+            }
+            out.example_end_indices.push_back(out.int64_list.size());
+            break;
           }
-          break;
-        }
-        case DT_FLOAT: {
-          auto out_p = out.flat<float>().data() + offset;
-          LimitedArraySlice<float> slice(out_p, num_elements);
-          if (!feature.ParseFloatList(&slice)) return parse_error();
-          if (slice.EndDistance() != 0) {
-            return shape_error(num_elements - slice.EndDistance(), "float");
+          case DT_FLOAT: {
+            if (example_dtype != DT_INVALID) {
+              if (!feature.ParseFloatList(&out.float_list)) {
+                return parse_error();
+              }
+              if (out.float_list.size() % num_elements != 0) {
+                return shape_error(out.float_list.size(), "float");
+              }
+            }
+            out.example_end_indices.push_back(out.float_list.size());
+            break;
           }
-          break;
-        }
-        case DT_STRING: {
-          auto out_p = out.flat<string>().data() + offset;
-          LimitedArraySlice<string> slice(out_p, num_elements);
-          if (!feature.ParseBytesList(&slice)) return parse_error();
-          if (slice.EndDistance() != 0) {
-            return shape_error(num_elements - slice.EndDistance(), "bytes");
+          case DT_STRING: {
+            if (example_dtype != DT_INVALID) {
+              if (!feature.ParseBytesList(&out.bytes_list)) {
+                return parse_error();
+              }
+              if (out.bytes_list.size() % num_elements != 0) {
+                return shape_error(out.bytes_list.size(), "bytes");
+              }
+            }
+            out.example_end_indices.push_back(out.bytes_list.size());
+            break;
           }
-          break;
+          default:
+            CHECK(false) << "Should not happen.";
         }
-        default:
-          CHECK(false) << "Should not happen.";
       }
     } else {
       // If feature was already visited, skip.
@@ -563,9 +627,9 @@ Status FastParseSerializedExample(
       SparseBuffer& out = (*output_sparse)[d];
       if (example_dtype != DT_INVALID &&
           example_dtype != config.sparse[d].dtype) {
-        return example_error(
-            strings::StrCat("Data types don't match. ", "Expected type: ",
-                            DataTypeString(config.sparse[d].dtype)));
+        return example_error(strings::StrCat(
+            "Data types don't match. ",
+            "Expected type: ", DataTypeString(config.sparse[d].dtype)));
       }

       switch (config.sparse[d].dtype) {
@@ -602,8 +666,9 @@ Status FastParseSerializedExample(
     }
   }

-  // Handle missing dense features.
+  // Handle missing dense features for fixed strides.
   for (size_t d = 0; d < config.dense.size(); ++d) {
+    if (config.dense[d].variable_length) continue;
     if (dense_feature_last_example[d] == example_index) continue;
     if (config.dense[d].default_value.NumElements() == 0) {
       return errors::InvalidArgument(
@@ -637,6 +702,16 @@ Status FastParseSerializedExample(
     }
   }

+  // Handle missing varlen dense features.
+  for (size_t d = 0; d < config.dense.size(); ++d) {
+    if (!config.dense[d].variable_length) continue;
+    if (dense_feature_last_example[d] == example_index) continue;
+    SparseBuffer& out = (*output_varlen_dense)[d];
+    size_t prev_example_end_index =
+        out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
+    out.example_end_indices.push_back(prev_example_end_index);
+  }
+
   // Handle missing sparse features.
   for (size_t d = 0; d < config.sparse.size(); ++d) {
     if (sparse_feature_last_example[d] == example_index) continue;
@@ -661,6 +736,65 @@ Status CheckConfigDataType(DataType dtype) {
   }
 }

+template <typename T>
+const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);
+
+template <>
+const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer) {
+  return buffer.int64_list;
+}
+template <>
+const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
+  return buffer.float_list;
+}
+template <>
+const SmallVector<string>& GetListFromBuffer<string>(
+    const SparseBuffer& buffer) {
+  return buffer.bytes_list;
+}
+
+template <typename T>
+void CopyOrMoveBlock(const T* b, const T* e, T* t) {
+  std::copy(b, e, t);
+}
+template <>
+void CopyOrMoveBlock(const string* b, const string* e, string* t) {
+  std::move(b, e, t);
+}
+
+template <typename T>
+void FillAndCopyVarLen(
+    const int d, const size_t num_elements,
+    const size_t num_elements_per_minibatch, const size_t data_stride_size,
+    const Config& config,
+    const std::vector<std::vector<SparseBuffer>>& varlen_dense_buffers,
+    Tensor* values) {
+  const Tensor& default_value = config.dense[d].default_value;
+
+  // Copy-fill the tensors (creating the zero/fill-padding)
+  std::fill(values->flat<T>().data(), values->flat<T>().data() + num_elements,
+            default_value.flat<T>()(0));
+
+  // Iterate over minibatch elements
+  for (size_t i = 0; i < varlen_dense_buffers.size(); ++i) {
+    const SparseBuffer& buffer = varlen_dense_buffers[i][d];
+    const size_t offset = i * num_elements_per_minibatch;
+    const size_t stride_size = config.dense[d].elements_per_stride;
+
+    // Copy values over.
+    auto& list = GetListFromBuffer<T>(buffer);
+    auto list_ptr = list.begin();
+    auto data = values->flat<T>().data() + offset;
+    DCHECK(list.size() % stride_size == 0);
+    const size_t num_entries = list.size() / stride_size;
+    for (size_t j = 0; j < num_entries; ++j) {
+      CopyOrMoveBlock(list_ptr, list_ptr + stride_size, data);
+      list_ptr += stride_size;
+      data += data_stride_size;
+    }
+  }
+}
+
 }  // namespace

 Status FastParseExample(const Config& config,
@@ -701,14 +835,17 @@ Status FastParseExample(const Config& config,
         "Could not avoid collision. This should not happen.");
   }

-  // Allocate dense output (sparse have to be buffered).
+  // Allocate dense output for fixed length dense values
+  // (variable-length dense and sparse have to be buffered).
+  std::vector<Tensor> fixed_dense_values(config.dense.size());
   for (size_t d = 0; d < config.dense.size(); ++d) {
+    if (config.dense[d].variable_length) continue;
     TensorShape out_shape;
     out_shape.AddDim(serialized.size());
     for (const int64 dim : config.dense[d].shape.dim_sizes()) {
       out_shape.AddDim(dim);
     }
-    result->dense_values.emplace_back(config.dense[d].dtype, out_shape);
+    fixed_dense_values[d] = Tensor(config.dense[d].dtype, out_shape);
   }

   // This parameter affects performance in a big and data-dependent way.
@@ -750,17 +887,19 @@ Status FastParseExample(const Config& config,

   // Do minibatches in parallel.
   std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
+  std::vector<std::vector<SparseBuffer>> varlen_dense_buffers(num_minibatches);
   std::vector<Status> status_of_minibatch(num_minibatches);
   auto ProcessMiniBatch = [&](size_t minibatch) {
     sparse_buffers[minibatch].resize(config.sparse.size());
+    varlen_dense_buffers[minibatch].resize(config.dense.size());
     size_t start = first_example_of_minibatch(minibatch);
     size_t end = first_example_of_minibatch(minibatch + 1);
     for (size_t e = start; e < end; ++e) {
       status_of_minibatch[minibatch] = FastParseSerializedExample(
           serialized[e],
           (example_names.size() > 0 ? example_names[e] : "<unknown>"), e,
-          config, config_index, hasher, &result->dense_values,
-          &sparse_buffers[minibatch]);
+          config, config_index, hasher, &fixed_dense_values,
+          &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch]);
       if (!status_of_minibatch[minibatch].ok()) break;
     }
   };
@@ -771,8 +910,12 @@ Status FastParseExample(const Config& config,
     TF_RETURN_IF_ERROR(status);
   }

+  for (size_t d = 0; d < config.dense.size(); ++d) {
+    result->dense_values.push_back(std::move(fixed_dense_values[d]));
+  }
+
   // Merge SparseBuffers from all minibatches for every config.sparse.
-  auto MergeMinibatches = [&](size_t d) {
+  auto MergeSparseMinibatches = [&](size_t d) {
     // Loop over minibatches
     size_t total_num_features = 0;
     size_t max_num_features = 0;
@@ -849,8 +992,76 @@ Status FastParseExample(const Config& config,
     }
   };

+  // Merge SparseBuffers from all minibatches for every config.dense having
+  // variable_length.
+  auto MergeDenseVarLenMinibatches = [&](size_t d) {
+    if (!config.dense[d].variable_length) return;
+
+    // Loop over minibatches
+    size_t max_num_features = 0;
+    for (auto& dense_values_tmp : varlen_dense_buffers) {
+      std::vector<size_t>& end_indices =
+          dense_values_tmp[d].example_end_indices;
+      max_num_features = std::max(max_num_features, end_indices[0]);
+      for (size_t i = 1; i < end_indices.size(); ++i) {
+        size_t example_size = end_indices[i] - end_indices[i - 1];
+        max_num_features = std::max(max_num_features, example_size);
+      }
+    }
+
+    const size_t stride_size = config.dense[d].elements_per_stride;
+    const size_t max_num_elements = max_num_features / stride_size;
+    TensorShape values_shape;
+    DCHECK(max_num_features % config.dense[d].elements_per_stride == 0);
+    const size_t batch_size = serialized.size();
+    values_shape.AddDim(batch_size);
+    values_shape.AddDim(max_num_elements);
+    for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
+      values_shape.AddDim(config.dense[d].shape.dim_size(i));
+    }
+    Tensor values(config.dense[d].dtype, values_shape);
+    result->dense_values[d] = values;
+    const size_t num_elements = values.NumElements();
+
+    // Nothing to write, exit early.
+    if (num_elements == 0) return;
+
+    const size_t num_elements_per_minibatch = num_elements / batch_size;
+    const size_t data_stride_size =
+        (max_num_elements == 0)
+            ? 0
+            : (num_elements_per_minibatch / max_num_elements);
+
+    switch (config.dense[d].dtype) {
+      case DT_INT64: {
+        FillAndCopyVarLen<int64>(d, num_elements, num_elements_per_minibatch,
                                 data_stride_size, config, varlen_dense_buffers,
                                 &values);
+        break;
+      }
+      case DT_FLOAT: {
+        FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
                                 data_stride_size, config, varlen_dense_buffers,
                                 &values);
+        break;
+      }
+      case DT_STRING: {
+        FillAndCopyVarLen<string>(d, num_elements, num_elements_per_minibatch,
                                  data_stride_size, config,
                                  varlen_dense_buffers, &values);
+        break;
+      }
+      default:
+        CHECK(false) << "Should not happen.";
+    }
+  };
+
+  for (size_t d = 0; d < config.dense.size(); ++d) {
+    MergeDenseVarLenMinibatches(d);
+  }
+
   for (size_t d = 0; d < config.sparse.size(); ++d) {
-    MergeMinibatches(d);
+    MergeSparseMinibatches(d);
   }

   return Status::OK();
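
A note on the bookkeeping used in the hunks above: each per-minibatch `SparseBuffer` records one cumulative `example_end_indices` entry per example (a missing feature simply repeats the previous end index), and the merge step derives per-example lengths and the padded major dimension from those offsets. The following is a small, self-contained sketch of that arithmetic under the assumption of a single flat value list; the variable names are illustrative, not the parser's actual types.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Cumulative end offsets into a flat value list, one entry per example,
  // in the spirit of SparseBuffer::example_end_indices in the diff above.
  // Per-example value counts of 2, 0, and 3 yield end indices 2, 2, 5.
  const std::vector<std::size_t> example_end_indices = {2, 2, 5};

  // Recover each example's length and the longest example in the batch;
  // the longest length determines the padded major dimension of the output.
  std::size_t prev_end = 0;
  std::size_t max_num_features = 0;
  for (std::size_t end : example_end_indices) {
    const std::size_t example_size = end - prev_end;
    max_num_features = std::max(max_num_features, example_size);
    std::cout << "example length: " << example_size << '\n';
    prev_end = end;
  }
  std::cout << "padded major dimension: " << max_num_features << '\n';
  return 0;
}
```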