author    Eugene Brevdo <ebrevdo@google.com>    2017-02-07 16:45:06 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-02-07 17:09:35 -0800
commit    780bc6b4d98665125c43685b20eeba6ad2804c0c (patch)
tree      4acac8d596888cae078e520e65d836ff1a2c28d3 /tensorflow/core/util/example_proto_fast_parsing.cc
parent    e6bfaf47374b44bb688023904eac98576baf4cd4 (diff)
Add support for a variable major dimension in dense features in the example parser C++ op.
Full Python support (including more comprehensive documentation) is coming soon.
Change: 146852707
Diffstat (limited to 'tensorflow/core/util/example_proto_fast_parsing.cc')
-rw-r--r--  tensorflow/core/util/example_proto_fast_parsing.cc | 317
1 file changed, 264 insertions(+), 53 deletions(-)
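What this change does: a dense feature whose major (first) dimension is variable now parses per example into a buffer, and the batched output is padded out to the longest example using the feature's default value. The sketch below shows just that padding semantics, with plain std::vector standing in for Tensor; PadRagged is a hypothetical helper, not part of this commit.

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hypothetical illustration: pad ragged per-example rows into a dense
    // [batch_size, max_len] buffer, using `pad` the way the parser uses the
    // feature's default_value for examples shorter than the longest one.
    std::vector<int64_t> PadRagged(const std::vector<std::vector<int64_t>>& rows,
                                   int64_t pad) {
      size_t max_len = 0;
      for (const auto& r : rows) max_len = std::max(max_len, r.size());
      std::vector<int64_t> out(rows.size() * max_len, pad);
      for (size_t i = 0; i < rows.size(); ++i) {
        std::copy(rows[i].begin(), rows[i].end(), out.begin() + i * max_len);
      }
      return out;
    }

For example, PadRagged({{1, 2}, {3}}, 0) yields {1, 2, 3, 0}, i.e. a [2, 2] batch padded with the default.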
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index e14f50551e..facb092dbc 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -424,6 +424,7 @@ Status FastParseSerializedExample(
const size_t example_index, const Config& config,
const PresizedCuckooMap<std::pair<size_t, Type>>& config_index,
SeededHasher hasher, std::vector<Tensor>* output_dense,
+ std::vector<SparseBuffer>* output_varlen_dense,
std::vector<SparseBuffer>* output_sparse) {
DCHECK(output_dense != nullptr);
DCHECK(output_sparse != nullptr);
@@ -463,9 +464,9 @@ Status FastParseSerializedExample(
}
auto example_error = [&](StringPiece suffix) {
- return errors::InvalidArgument("Name: ", example_name, ", Key: ",
- feature_name, ", Index: ", example_index,
- ". ", suffix);
+ return errors::InvalidArgument("Name: ", example_name,
+ ", Key: ", feature_name,
+ ", Index: ", example_index, ". ", suffix);
};
auto parse_error = [&] {
@@ -494,54 +495,117 @@ Status FastParseSerializedExample(
dense_feature_last_example[d] = example_index;
if (example_dtype != config.dense[d].dtype) {
- return example_error(
- strings::StrCat("Data types don't match. Data type: ",
- DataTypeString(example_dtype), "Expected type: ",
- DataTypeString(config.dense[d].dtype)));
+ return example_error(strings::StrCat(
+ "Data types don't match. Data type: ",
+ DataTypeString(example_dtype),
+ "Expected type: ", DataTypeString(config.dense[d].dtype)));
}
- Tensor& out = (*output_dense)[d];
+ if (!config.dense[d].variable_length) {
+ Tensor& out = (*output_dense)[d];
+
+ const std::size_t num_elements = config.dense[d].elements_per_stride;
+ const std::size_t offset = example_index * num_elements;
+
+ auto shape_error = [&](size_t size, StringPiece type_str) {
+ return example_error(strings::StrCat(
+ "Number of ", type_str,
+ " values != expected. "
+ "Values size: ",
+ size,
+ " but output shape: ", config.dense[d].shape.DebugString()));
+ };
+
+ switch (config.dense[d].dtype) {
+ case DT_INT64: {
+ auto out_p = out.flat<int64>().data() + offset;
+ LimitedArraySlice<int64> slice(out_p, num_elements);
+ if (!feature.ParseInt64List(&slice)) return parse_error();
+ if (slice.EndDistance() != 0) {
+ return shape_error(num_elements - slice.EndDistance(), "int64");
+ }
+ break;
+ }
+ case DT_FLOAT: {
+ auto out_p = out.flat<float>().data() + offset;
+ LimitedArraySlice<float> slice(out_p, num_elements);
+ if (!feature.ParseFloatList(&slice)) return parse_error();
+ if (slice.EndDistance() != 0) {
+ return shape_error(num_elements - slice.EndDistance(), "float");
+ }
+ break;
+ }
+ case DT_STRING: {
+ auto out_p = out.flat<string>().data() + offset;
+ LimitedArraySlice<string> slice(out_p, num_elements);
+ if (!feature.ParseBytesList(&slice)) return parse_error();
+ if (slice.EndDistance() != 0) {
+ return shape_error(num_elements - slice.EndDistance(), "bytes");
+ }
+ break;
+ }
+ default:
+ CHECK(false) << "Should not happen.";
+ }
+ } else { // if variable length
+ SparseBuffer& out = (*output_varlen_dense)[d];
- const std::size_t num_elements = config.dense[d].elements_per_stride;
- const std::size_t offset = example_index * num_elements;
+ const std::size_t num_elements = config.dense[d].elements_per_stride;
- auto shape_error = [&](size_t size, StringPiece type_str) {
- return example_error(strings::StrCat(
- "Number of ", type_str,
- " values != expected. "
- "Values size: ",
- size, " but output shape: ", config.dense[d].shape.DebugString()));
- };
+ if (example_dtype != DT_INVALID &&
+ example_dtype != config.dense[d].dtype) {
+ return example_error(strings::StrCat(
+ "Data types don't match. ",
+ "Expected type: ", DataTypeString(config.dense[d].dtype)));
+ }
- switch (config.dense[d].dtype) {
- case DT_INT64: {
- auto out_p = out.flat<int64>().data() + offset;
- LimitedArraySlice<int64> slice(out_p, num_elements);
- if (!feature.ParseInt64List(&slice)) return parse_error();
- if (slice.EndDistance() != 0) {
- return shape_error(num_elements - slice.EndDistance(), "int64");
+ auto shape_error = [&](size_t size, StringPiece type_str) {
+ return example_error(strings::StrCat(
+ "Number of ", type_str,
+ " values is not a multiple of stride length. Saw ", size,
+ " values but output shape is: ",
+ config.dense[d].shape.DebugString()));
+ };
+
+ switch (config.dense[d].dtype) {
+ case DT_INT64: {
+ if (example_dtype != DT_INVALID) {
+ if (!feature.ParseInt64List(&out.int64_list)) {
+ return parse_error();
+ }
+ if (out.int64_list.size() % num_elements != 0) {
+ return shape_error(out.int64_list.size(), "int64");
+ }
+ }
+ out.example_end_indices.push_back(out.int64_list.size());
+ break;
}
- break;
- }
- case DT_FLOAT: {
- auto out_p = out.flat<float>().data() + offset;
- LimitedArraySlice<float> slice(out_p, num_elements);
- if (!feature.ParseFloatList(&slice)) return parse_error();
- if (slice.EndDistance() != 0) {
- return shape_error(num_elements - slice.EndDistance(), "float");
+ case DT_FLOAT: {
+ if (example_dtype != DT_INVALID) {
+ if (!feature.ParseFloatList(&out.float_list)) {
+ return parse_error();
+ }
+ if (out.float_list.size() % num_elements != 0) {
+ return shape_error(out.float_list.size(), "float");
+ }
+ }
+ out.example_end_indices.push_back(out.float_list.size());
+ break;
}
- break;
- }
- case DT_STRING: {
- auto out_p = out.flat<string>().data() + offset;
- LimitedArraySlice<string> slice(out_p, num_elements);
- if (!feature.ParseBytesList(&slice)) return parse_error();
- if (slice.EndDistance() != 0) {
- return shape_error(num_elements - slice.EndDistance(), "bytes");
+ case DT_STRING: {
+ if (example_dtype != DT_INVALID) {
+ if (!feature.ParseBytesList(&out.bytes_list)) {
+ return parse_error();
+ }
+ if (out.bytes_list.size() % num_elements != 0) {
+ return shape_error(out.bytes_list.size(), "bytes");
+ }
+ }
+ out.example_end_indices.push_back(out.bytes_list.size());
+ break;
}
- break;
+ default:
+ CHECK(false) << "Should not happen.";
}
- default:
- CHECK(false) << "Should not happen.";
}
} else {
// If feature was already visited, skip.
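The varlen branch above does no padding at parse time: it appends each example's values to a per-feature SparseBuffer and records the running end offset. A minimal sketch of that bookkeeping, with a simplified int64-only stand-in for SparseBuffer:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Simplified stand-in for SparseBuffer: values from consecutive examples
    // share one flat list, and example_end_indices records the list size after
    // each example (so example i's values end at offset end_indices[i]).
    struct VarlenBufferSketch {
      std::vector<int64_t> int64_list;
      std::vector<size_t> example_end_indices;

      void AppendExample(const std::vector<int64_t>& values) {
        int64_list.insert(int64_list.end(), values.begin(), values.end());
        example_end_indices.push_back(int64_list.size());
      }
    };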
@@ -563,9 +627,9 @@ Status FastParseSerializedExample(
SparseBuffer& out = (*output_sparse)[d];
if (example_dtype != DT_INVALID &&
example_dtype != config.sparse[d].dtype) {
- return example_error(
- strings::StrCat("Data types don't match. ", "Expected type: ",
- DataTypeString(config.sparse[d].dtype)));
+ return example_error(strings::StrCat(
+ "Data types don't match. ",
+ "Expected type: ", DataTypeString(config.sparse[d].dtype)));
}
switch (config.sparse[d].dtype) {
@@ -602,8 +666,9 @@ Status FastParseSerializedExample(
}
}
- // Handle missing dense features.
+ // Handle missing dense features for fixed strides.
for (size_t d = 0; d < config.dense.size(); ++d) {
+ if (config.dense[d].variable_length) continue;
if (dense_feature_last_example[d] == example_index) continue;
if (config.dense[d].default_value.NumElements() == 0) {
return errors::InvalidArgument(
@@ -637,6 +702,16 @@ Status FastParseSerializedExample(
}
}
+ // Handle missing varlen dense features.
+ for (size_t d = 0; d < config.dense.size(); ++d) {
+ if (!config.dense[d].variable_length) continue;
+ if (dense_feature_last_example[d] == example_index) continue;
+ SparseBuffer& out = (*output_varlen_dense)[d];
+ size_t prev_example_end_index =
+ out.example_end_indices.empty() ? 0 : out.example_end_indices.back();
+ out.example_end_indices.push_back(prev_example_end_index);
+ }
+
// Handle missing sparse features.
for (size_t d = 0; d < config.sparse.size(); ++d) {
if (sparse_feature_last_example[d] == example_index) continue;
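When a varlen dense feature is missing from an example, the loop above simply repeats the previous end index, which encodes a length-0 row without touching the value list. Per-example lengths are then recoverable in one pass; RowLengths below is a sketch, not a function in this file:

    #include <cstddef>
    #include <vector>

    // Sketch: per-example value counts from example_end_indices. A missing
    // example repeats the previous end index and contributes length 0.
    std::vector<size_t> RowLengths(const std::vector<size_t>& end_indices) {
      std::vector<size_t> lengths;
      size_t prev = 0;
      for (size_t end : end_indices) {
        lengths.push_back(end - prev);
        prev = end;
      }
      return lengths;
    }

For example, end_indices {2, 2, 3} gives lengths {2, 0, 1}.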
@@ -661,6 +736,65 @@ Status CheckConfigDataType(DataType dtype) {
}
}
+template <typename T>
+const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);
+
+template <>
+const SmallVector<int64>& GetListFromBuffer<int64>(const SparseBuffer& buffer) {
+ return buffer.int64_list;
+}
+template <>
+const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
+ return buffer.float_list;
+}
+template <>
+const SmallVector<string>& GetListFromBuffer<string>(
+ const SparseBuffer& buffer) {
+ return buffer.bytes_list;
+}
+
+template <typename T>
+void CopyOrMoveBlock(const T* b, const T* e, T* t) {
+ std::copy(b, e, t);
+}
+template <>
+void CopyOrMoveBlock(const string* b, const string* e, string* t) {
+ std::move(b, e, t);
+}
+
+template <typename T>
+void FillAndCopyVarLen(
+ const int d, const size_t num_elements,
+ const size_t num_elements_per_minibatch, const size_t data_stride_size,
+ const Config& config,
+ const std::vector<std::vector<SparseBuffer>>& varlen_dense_buffers,
+ Tensor* values) {
+ const Tensor& default_value = config.dense[d].default_value;
+
+ // Copy-fill the tensors (creating the zero/fill-padding)
+ std::fill(values->flat<T>().data(), values->flat<T>().data() + num_elements,
+ default_value.flat<T>()(0));
+
+ // Iterate over minibatch elements
+ for (size_t i = 0; i < varlen_dense_buffers.size(); ++i) {
+ const SparseBuffer& buffer = varlen_dense_buffers[i][d];
+ const size_t offset = i * num_elements_per_minibatch;
+ const size_t stride_size = config.dense[d].elements_per_stride;
+
+ // Copy values over.
+ auto& list = GetListFromBuffer<T>(buffer);
+ auto list_ptr = list.begin();
+ auto data = values->flat<T>().data() + offset;
+ DCHECK(list.size() % stride_size == 0);
+ const size_t num_entries = list.size() / stride_size;
+ for (size_t j = 0; j < num_entries; ++j) {
+ CopyOrMoveBlock(list_ptr, list_ptr + stride_size, data);
+ list_ptr += stride_size;
+ data += data_stride_size;
+ }
+ }
+}
+
} // namespace
Status FastParseExample(const Config& config,
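FillAndCopyVarLen above works in two passes: fill the whole output with the default value, then block-copy each buffer's stride-sized entries into place (CopyOrMoveBlock is specialized so strings are moved rather than copied). The sketch below shows the same fill-then-copy pattern, simplified to one value buffer per example rather than per minibatch:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Illustration only (one buffer per example; the real code walks
    // per-minibatch buffers): fill the padded [batch, max_rows, stride]
    // output with the default, then copy each stride-sized row into place.
    void FillAndCopySketch(const std::vector<std::vector<float>>& buffers,
                           size_t stride, size_t max_rows, float default_value,
                           std::vector<float>* values) {
      values->assign(buffers.size() * max_rows * stride, default_value);
      for (size_t i = 0; i < buffers.size(); ++i) {
        float* data = values->data() + i * max_rows * stride;
        const std::vector<float>& list = buffers[i];
        for (size_t j = 0; j < list.size() / stride; ++j) {
          std::copy(list.begin() + j * stride, list.begin() + (j + 1) * stride,
                    data + j * stride);
        }
      }
    }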
@@ -701,14 +835,17 @@ Status FastParseExample(const Config& config,
"Could not avoid collision. This should not happen.");
}
- // Allocate dense output (sparse have to be buffered).
+ // Allocate dense output for fixed length dense values
+ // (variable-length dense and sparse have to be buffered).
+ std::vector<Tensor> fixed_dense_values(config.dense.size());
for (size_t d = 0; d < config.dense.size(); ++d) {
+ if (config.dense[d].variable_length) continue;
TensorShape out_shape;
out_shape.AddDim(serialized.size());
for (const int64 dim : config.dense[d].shape.dim_sizes()) {
out_shape.AddDim(dim);
}
- result->dense_values.emplace_back(config.dense[d].dtype, out_shape);
+ fixed_dense_values[d] = Tensor(config.dense[d].dtype, out_shape);
}
// This parameter affects performance in a big and data-dependent way.
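Fixed-length dense outputs can be allocated up front because their full shape is known before parsing: the batch size is prepended to the configured per-example shape, exactly as the AddDim calls above do. A sketch with a plain vector standing in for TensorShape:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Sketch (plain vector instead of TensorShape): a dense feature of
    // shape {3} in a batch of 8 serialized Examples becomes an [8, 3] output.
    std::vector<int64_t> FixedDenseShape(size_t batch_size,
                                         const std::vector<int64_t>& dims) {
      std::vector<int64_t> shape{static_cast<int64_t>(batch_size)};
      shape.insert(shape.end(), dims.begin(), dims.end());
      return shape;
    }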
@@ -750,17 +887,19 @@ Status FastParseExample(const Config& config,
// Do minibatches in parallel.
std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
+ std::vector<std::vector<SparseBuffer>> varlen_dense_buffers(num_minibatches);
std::vector<Status> status_of_minibatch(num_minibatches);
auto ProcessMiniBatch = [&](size_t minibatch) {
sparse_buffers[minibatch].resize(config.sparse.size());
+ varlen_dense_buffers[minibatch].resize(config.dense.size());
size_t start = first_example_of_minibatch(minibatch);
size_t end = first_example_of_minibatch(minibatch + 1);
for (size_t e = start; e < end; ++e) {
status_of_minibatch[minibatch] = FastParseSerializedExample(
serialized[e],
(example_names.size() > 0 ? example_names[e] : "<unknown>"), e,
- config, config_index, hasher, &result->dense_values,
- &sparse_buffers[minibatch]);
+ config, config_index, hasher, &fixed_dense_values,
+ &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch]);
if (!status_of_minibatch[minibatch].ok()) break;
}
};
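Each minibatch gets its own varlen and sparse buffers, so the ProcessMiniBatch closures can run in parallel without locking; results are reconciled only in the merge passes below. first_example_of_minibatch is defined outside this diff; a proportional split like the following is the assumed behavior:

    #include <cstddef>

    // Assumed partition (first_example_of_minibatch itself is not shown in
    // this diff): minibatch m covers examples [first(m), first(m + 1)).
    size_t FirstExampleOfMinibatch(size_t minibatch, size_t num_minibatches,
                                   size_t num_examples) {
      return (num_examples * minibatch) / num_minibatches;
    }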
@@ -771,8 +910,12 @@ Status FastParseExample(const Config& config,
TF_RETURN_IF_ERROR(status);
}
+ for (size_t d = 0; d < config.dense.size(); ++d) {
+ result->dense_values.push_back(std::move(fixed_dense_values[d]));
+ }
+
// Merge SparseBuffers from all minibatches for every config.sparse.
- auto MergeMinibatches = [&](size_t d) {
+ auto MergeSparseMinibatches = [&](size_t d) {
// Loop over minibatches
size_t total_num_features = 0;
size_t max_num_features = 0;
@@ -849,8 +992,76 @@ Status FastParseExample(const Config& config,
}
};
+ // Merge SparseBuffers from all minibatches for every config.dense having
+ // variable_length.
+ auto MergeDenseVarLenMinibatches = [&](size_t d) {
+ if (!config.dense[d].variable_length) return;
+
+ // Loop over minibatches
+ size_t max_num_features = 0;
+ for (auto& dense_values_tmp : varlen_dense_buffers) {
+ std::vector<size_t>& end_indices =
+ dense_values_tmp[d].example_end_indices;
+ max_num_features = std::max(max_num_features, end_indices[0]);
+ for (size_t i = 1; i < end_indices.size(); ++i) {
+ size_t example_size = end_indices[i] - end_indices[i - 1];
+ max_num_features = std::max(max_num_features, example_size);
+ }
+ }
+
+ const size_t stride_size = config.dense[d].elements_per_stride;
+ const size_t max_num_elements = max_num_features / stride_size;
+ TensorShape values_shape;
+ DCHECK(max_num_features % config.dense[d].elements_per_stride == 0);
+ const size_t batch_size = serialized.size();
+ values_shape.AddDim(batch_size);
+ values_shape.AddDim(max_num_elements);
+ for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
+ values_shape.AddDim(config.dense[d].shape.dim_size(i));
+ }
+ Tensor values(config.dense[d].dtype, values_shape);
+ result->dense_values[d] = values;
+ const size_t num_elements = values.NumElements();
+
+ // Nothing to write, exit early.
+ if (num_elements == 0) return;
+
+ const size_t num_elements_per_minibatch = num_elements / batch_size;
+ const size_t data_stride_size =
+ (max_num_elements == 0)
+ ? 0
+ : (num_elements_per_minibatch / max_num_elements);
+
+ switch (config.dense[d].dtype) {
+ case DT_INT64: {
+ FillAndCopyVarLen<int64>(d, num_elements, num_elements_per_minibatch,
+ data_stride_size, config, varlen_dense_buffers,
+ &values);
+ break;
+ }
+ case DT_FLOAT: {
+ FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
+ data_stride_size, config, varlen_dense_buffers,
+ &values);
+ break;
+ }
+ case DT_STRING: {
+ FillAndCopyVarLen<string>(d, num_elements, num_elements_per_minibatch,
+ data_stride_size, config,
+ varlen_dense_buffers, &values);
+ break;
+ }
+ default:
+ CHECK(false) << "Should not happen.";
+ }
+ };
+
+ for (size_t d = 0; d < config.dense.size(); ++d) {
+ MergeDenseVarLenMinibatches(d);
+ }
+
for (size_t d = 0; d < config.sparse.size(); ++d) {
- MergeMinibatches(d);
+ MergeSparseMinibatches(d);
}
return Status::OK();
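MergeDenseVarLenMinibatches sizes the padded output from the longest example across all minibatches: the maximum per-example value count, divided by elements_per_stride, becomes the padded major dimension, followed by the remaining configured dimensions. A compact sketch of that shape computation, simplified to a single minibatch:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Sketch: padded shape is [batch_size, max_features / stride, trailing
    // dims...], where max_features is the largest per-example value count.
    // The real code DCHECKs that max_features divides evenly by the stride.
    std::vector<int64_t> PaddedShape(const std::vector<size_t>& end_indices,
                                     size_t batch_size, size_t stride,
                                     const std::vector<int64_t>& trailing_dims) {
      size_t max_features = 0, prev = 0;
      for (size_t end : end_indices) {
        max_features = std::max(max_features, end - prev);
        prev = end;
      }
      std::vector<int64_t> shape{static_cast<int64_t>(batch_size),
                                 static_cast<int64_t>(max_features / stride)};
      shape.insert(shape.end(), trailing_dims.begin(), trailing_dims.end());
      return shape;
    }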