aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util
diff options
context:
space:
mode:
authorGravatar Patrik Sundberg <sundberg@google.com>2018-08-31 07:28:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-31 07:32:19 -0700
commitd09bd33fc9175c7fdbdded8f98b5c5d3a9f8ad7d (patch)
treebb869bee40e9833299b2662d2076547c85d5df39 /tensorflow/core/util
parent1251e8c6cee24d9a295c63fa2362b03deef4396b (diff)
Add a batch sequence example parsing op, part 2.
PiperOrigin-RevId: 211082479
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing.cc228
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing.h3
-rw-r--r--tensorflow/core/util/example_proto_helper.cc53
-rw-r--r--tensorflow/core/util/example_proto_helper.h61
4 files changed, 268 insertions, 77 deletions
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index a38cd1d09f..e52d55e2ff 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -1722,10 +1722,11 @@ Status FastParseSequenceExample(
const FastParseExampleConfig& feature_list_config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
thread::ThreadPool* thread_pool, Result* context_result,
- Result* feature_list_result) {
+ Result* feature_list_result, std::vector<Tensor>* dense_feature_lengths) {
int num_examples = serialized.size();
DCHECK(context_result != nullptr);
DCHECK(feature_list_result != nullptr);
+ DCHECK(dense_feature_lengths != nullptr);
std::map<StringPiece, bool> context_is_sparse;
std::map<StringPiece, std::pair<DataType, size_t>>
context_feature_type_and_lengths;
@@ -1740,9 +1741,22 @@ Status FastParseSequenceExample(
context_is_sparse[c.feature_name] = true;
}
for (auto& c : context_config.dense) {
+ if (context_is_sparse[c.feature_name]) {
+ return errors::InvalidArgument("Context feature " + c.feature_name +
+ " cannot be both dense and sparse");
+ }
TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
context_feature_type_and_lengths[c.feature_name] =
- std::make_pair(c.dtype, 0);
+ std::make_pair(c.dtype, c.default_value.NumElements());
+ if (c.default_value.NumElements() > 0) {
+ if (!c.shape.IsCompatibleWith(c.default_value.shape())) {
+ return errors::InvalidArgument("Default value for context feature ",
+ c.feature_name,
+ " has an incorrect shape: saw ",
+ c.default_value.shape().DebugString(),
+ " but expected ", c.shape.DebugString());
+ }
+ }
context_is_sparse[c.feature_name] = false;
}
std::map<StringPiece, bool> sequence_is_sparse;
@@ -1755,6 +1769,10 @@ Status FastParseSequenceExample(
sequence_is_sparse[c.feature_name] = true;
}
for (auto& c : feature_list_config.dense) {
+ if (sequence_is_sparse[c.feature_name]) {
+ return errors::InvalidArgument("Sequence feature " + c.feature_name +
+ " cannot be both dense and sparse");
+ }
TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
sequence_feature_type_and_lengths[c.feature_name] =
std::make_pair(c.dtype, 0);
@@ -1792,14 +1810,14 @@ Status FastParseSequenceExample(
features = sequence_features;
config = &sequence_feature_type_and_lengths;
} else if (!SkipExtraneousTag(&stream)) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
if (features != nullptr) {
uint32 length;
if (!stream.ReadVarint32(&length)) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
auto limit = stream.PushLimit(length);
while (!stream.ExpectAtEnd()) {
@@ -1807,16 +1825,16 @@ Status FastParseSequenceExample(
uint32 length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&length)) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
auto limit = stream.PushLimit(length);
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!ParseString(&stream, &key) ||
!stream.ExpectTag(kDelimitedTag(2)) ||
!ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
stream.PopLimit(limit);
// Only save if this feature was requested.
@@ -1851,9 +1869,8 @@ Status FastParseSequenceExample(
break;
}
if (num == -1) {
- return errors::InvalidArgument(
- strings::StrCat("Error in context feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in context feature ", c.first,
+ " in example ", example_name);
}
num_elements += num;
}
@@ -1876,9 +1893,9 @@ Status FastParseSequenceExample(
uint32 feature_length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&feature_length)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
if (feature_length > 2) {
auto limit = stream.PushLimit(feature_length);
@@ -1898,22 +1915,22 @@ Status FastParseSequenceExample(
break;
}
if (num == -1) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
num_elements += num;
stream.PopLimit(limit);
} else if (feature_length == 2) {
if (!SkipEmptyFeature(&stream, dtype)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
} else if (feature_length != 0) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
}
}
@@ -1936,15 +1953,19 @@ Status FastParseSequenceExample(
feature_list_result->sparse_indices.resize(feature_list_config.sparse.size());
feature_list_result->sparse_shapes.resize(feature_list_config.sparse.size());
feature_list_result->dense_values.resize(feature_list_config.dense.size());
+ dense_feature_lengths->resize(feature_list_config.dense.size());
+
int t = 0;
for (const auto& c : context_config.dense) {
- TensorShape dense_shape;
+ TensorShape dense_shape, example_shape;
DataType dtype = c.dtype;
- size_t expected_max_elements =
+ const size_t expected_max_elements =
context_feature_type_and_lengths[c.feature_name].second;
- if (expected_max_elements != dense_shape.num_elements()) {
- return errors::InvalidArgument(strings::StrCat(
- "Inconsistent number of elements for feature ", c.feature_name));
+ if (!c.shape.AsTensorShape(&example_shape) ||
+ expected_max_elements != example_shape.num_elements()) {
+ return errors::InvalidArgument(
+ "Inconsistent number of elements for feature ", c.feature_name, ": ",
+ expected_max_elements, " vs ", dense_shape.num_elements());
}
dense_shape.AddDim(num_examples);
for (const int dim : c.shape.dim_sizes()) {
@@ -1968,18 +1989,58 @@ Status FastParseSequenceExample(
out_int64 = context_result->dense_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
t++;
// Fill in the values.
for (int e = 0; e < num_examples; e++) {
size_t num_elements = 0;
- const auto& feature = all_context_features[e][c.feature_name];
+ const auto feature_iter = all_context_features[e].find(c.feature_name);
const string& example_name =
example_names.empty() ? kUnknown : example_names[e];
- if (!feature.empty()) {
+ if (feature_iter == all_context_features[e].end()) {
+ // Copy the default value, if present. If not, return an error.
+ if (c.default_value.NumElements() == 0) {
+ return errors::InvalidArgument(
+ "Feature: ", c.feature_name,
+ " (data type: ", DataTypeString(c.dtype), ")",
+ " is required but could not be found.");
+ }
+ const string* in_bytes = nullptr;
+ const float* in_float = nullptr;
+ const int64* in_int64 = nullptr;
+ size_t num = 0;
+ switch (dtype) {
+ case DT_STRING:
+ in_bytes = c.default_value.flat<string>().data();
+ num = c.default_value.NumElements();
+ for (int p = 0; p < num; p++) {
+ *out_bytes++ = *in_bytes++;
+ }
+ break;
+ case DT_FLOAT:
+ in_float = c.default_value.flat<float>().data();
+ num = c.default_value.NumElements();
+ for (int p = 0; p < num; p++) {
+ *out_float++ = *in_float++;
+ }
+ break;
+ case DT_INT64:
+ in_int64 = c.default_value.flat<int64>().data();
+ num = c.default_value.NumElements();
+ for (int p = 0; p < num; p++) {
+ *out_int64++ = *in_int64++;
+ }
+ break;
+ default:
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
+ }
+ num_elements += num;
+ } else if (!feature_iter->second.empty()) {
+ const auto& feature = feature_iter->second;
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(feature.data()), feature.size());
EnableAliasing(&stream);
@@ -1998,14 +2059,14 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
}
if (num_elements != expected_max_elements) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected number of elements in example ", example_name));
+ return errors::InvalidArgument(
+ "Unexpected number of elements in example ", example_name);
}
}
}
@@ -2037,8 +2098,8 @@ Status FastParseSequenceExample(
out_int64 = context_result->sparse_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
int64* out_indices = context_result->sparse_indices[t].flat<int64>().data();
auto out_shape = context_result->sparse_shapes[t].vec<int64>();
@@ -2070,8 +2131,8 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
max_num_cols = std::max(max_num_cols, num_added);
@@ -2082,30 +2143,35 @@ Status FastParseSequenceExample(
}
}
if (num_elements != expected_num_elements) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected total number of elements in feature ", c.feature_name));
+ return errors::InvalidArgument(
+ "Unexpected total number of elements in feature ", c.feature_name);
}
out_shape(0) = num_examples;
out_shape(1) = max_num_cols;
}
t = 0;
+ TensorShape dense_length_shape({num_examples});
for (const auto& c : feature_list_config.dense) {
TensorShape dense_shape, row_shape;
DataType dtype = c.dtype;
- size_t expected_max_elements =
+ const size_t expected_max_elements =
sequence_feature_type_and_lengths[c.feature_name].second;
- int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
if (!c.shape.AsTensorShape(&row_shape) ||
- expected_max_elements != expected_max_rows * row_shape.num_elements()) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected shape error in feature ", c.feature_name));
+ expected_max_elements !=
+ (expected_max_elements / row_shape.num_elements()) *
+ row_shape.num_elements()) {
+ return errors::InvalidArgument("Unexpected shape error in feature ",
+ c.feature_name);
}
+ int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
dense_shape.AddDim(num_examples);
dense_shape.AddDim(expected_max_rows);
for (const int dim : feature_list_config.dense[t].shape.dim_sizes()) {
dense_shape.AddDim(dim);
}
feature_list_result->dense_values[t] = Tensor(dtype, dense_shape);
+ (*dense_feature_lengths)[t] = Tensor(DT_INT64, dense_length_shape);
+ int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data();
string* out_bytes = nullptr;
float* out_float = nullptr;
@@ -2121,18 +2187,26 @@ Status FastParseSequenceExample(
out_int64 = feature_list_result->dense_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
t++;
// Fill in the values.
for (int e = 0; e < num_examples; e++) {
- size_t num_elements = 0;
- const auto& feature = all_sequence_features[e][c.feature_name];
+ size_t num_elements = 0, num_rows = 0;
+ const auto feature_iter = all_sequence_features[e].find(c.feature_name);
const string& example_name =
example_names.empty() ? kUnknown : example_names[e];
- if (!feature.empty()) {
+ if (feature_iter == all_sequence_features[e].end()) {
+ // Return an error if this feature was not allowed to be missing.
+ // Otherwise, we'll pad as needed below.
+ if (!c.variable_length) {
+ return errors::InvalidArgument("Missing feature ", c.feature_name,
+ " in example ", example_name);
+ }
+ } else if (!feature_iter->second.empty()) {
+ const auto& feature = feature_iter->second;
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(feature.data()), feature.size());
EnableAliasing(&stream);
@@ -2140,9 +2214,9 @@ Status FastParseSequenceExample(
uint32 feature_length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&feature_length)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
auto limit = stream.PushLimit(feature_length);
size_t num_added;
@@ -2160,10 +2234,11 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
+ num_rows++;
if (num_added != row_shape.num_elements()) {
return errors::InvalidArgument(
"Unexpected number of elements in feature ", c.feature_name,
@@ -2172,6 +2247,7 @@ Status FastParseSequenceExample(
stream.PopLimit(limit);
}
}
+ *out_lengths++ = num_rows;
// Pad as necessary.
int num_to_pad = expected_max_elements - num_elements;
switch (dtype) {
@@ -2187,8 +2263,8 @@ Status FastParseSequenceExample(
out_int64 += num_to_pad;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
}
}
@@ -2219,8 +2295,8 @@ Status FastParseSequenceExample(
out_int64 = feature_list_result->sparse_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
int64* out_indices =
feature_list_result->sparse_indices[t].flat<int64>().data();
@@ -2244,9 +2320,9 @@ Status FastParseSequenceExample(
uint32 feature_length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&feature_length)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
if (feature_length > 2) {
auto limit = stream.PushLimit(feature_length);
@@ -2265,8 +2341,8 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
max_num_cols = std::max(max_num_cols, num_added);
@@ -2278,14 +2354,14 @@ Status FastParseSequenceExample(
stream.PopLimit(limit);
} else if (feature_length == 2) {
if (!SkipEmptyFeature(&stream, dtype)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
} else if (feature_length != 0) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
num_rows++;
}
@@ -2293,8 +2369,8 @@ Status FastParseSequenceExample(
}
}
if (num_elements != expected_num_elements) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected number of elements in feature ", c.feature_name));
+ return errors::InvalidArgument(
+ "Unexpected number of elements in feature ", c.feature_name);
}
out_shape(0) = num_examples;
out_shape(1) = max_num_rows;
diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h
index db5b5ff929..055d9c2c30 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.h
+++ b/tensorflow/core/util/example_proto_fast_parsing.h
@@ -118,7 +118,8 @@ Status FastParseSequenceExample(
const example::FastParseExampleConfig& feature_list_config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
thread::ThreadPool* thread_pool, example::Result* context_result,
- example::Result* feature_list_result);
+ example::Result* feature_list_result,
+ std::vector<Tensor>* dense_feature_lengths);
// This function parses serialized Example and populates given example.
// It uses the same specialized parser as FastParseExample which is efficient.
diff --git a/tensorflow/core/util/example_proto_helper.cc b/tensorflow/core/util/example_proto_helper.cc
index e156a3bc8f..41fb20c00a 100644
--- a/tensorflow/core/util/example_proto_helper.cc
+++ b/tensorflow/core/util/example_proto_helper.cc
@@ -443,6 +443,59 @@ Status ParseSingleExampleAttrs::FinishInit() {
return Status::OK();
}
+Status ParseSequenceExampleAttrs::FinishInit() {
+ if (num_context_sparse != context_sparse_keys.size() ||
+ num_context_sparse != context_sparse_types.size()) {
+ return errors::InvalidArgument(
+ "num_context_sparse (", num_context_sparse,
+ ") must match the size of context_sparse_keys (",
+ context_sparse_keys.size(), ") and context_sparse_types (",
+ context_sparse_types.size(), ")");
+ }
+ if (num_context_dense != context_dense_keys.size() ||
+ num_context_dense != context_dense_types.size() ||
+ num_context_dense != context_dense_shapes.size()) {
+ return errors::InvalidArgument(
+ "num_context_dense (", num_context_dense,
+ ") must match the size of context_dense_keys (",
+ context_dense_keys.size(), "), context_dense_types (",
+ context_dense_types.size(), ") and context_dense_shapes (",
+ context_dense_shapes.size(), ")");
+ }
+ if (num_feature_list_sparse != feature_list_sparse_keys.size() ||
+ num_feature_list_sparse != feature_list_sparse_types.size()) {
+ return errors::InvalidArgument(
+ "num_feature_list_sparse (", num_feature_list_sparse,
+ ") must match the size of feature_list_sparse_keys (",
+ feature_list_sparse_keys.size(), ") and feature_list_sparse_types (",
+ feature_list_sparse_types.size(), ")");
+ }
+ if (num_feature_list_dense != feature_list_dense_keys.size() ||
+ num_feature_list_dense != feature_list_dense_types.size() ||
+ num_feature_list_dense != feature_list_dense_shapes.size()) {
+ return errors::InvalidArgument(
+ "num_feature_list_dense (", num_feature_list_dense,
+ ") must match the size of feature_list_dense_keys (",
+ feature_list_dense_keys.size(), "), feature_list_dense_types (",
+ feature_list_dense_types.size(), ") and feature_list_dense_shapes (",
+ feature_list_dense_shapes.size(), ")");
+ }
+ for (const DataType& type : context_dense_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+ for (const DataType& type : context_sparse_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+ for (const DataType& type : feature_list_dense_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+ for (const DataType& type : feature_list_sparse_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+
+ return Status::OK();
+}
+
Status ParseSingleSequenceExampleAttrs::FinishInit() {
if (static_cast<size_t>(num_context_sparse) != context_sparse_types.size()) {
return errors::InvalidArgument(
diff --git a/tensorflow/core/util/example_proto_helper.h b/tensorflow/core/util/example_proto_helper.h
index e511704962..c183ee4d96 100644
--- a/tensorflow/core/util/example_proto_helper.h
+++ b/tensorflow/core/util/example_proto_helper.h
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
@@ -271,6 +272,66 @@ class ParseSingleExampleAttrs {
Status FinishInit(); // for context-independent parts of Init.
};
+// Parses the attributes passed to ParseSequenceExample.
+// REQUIRES: Init must be called after construction.
+class ParseSequenceExampleAttrs {
+ public:
+ template <typename ContextType>
+ Status Init(ContextType* ctx) {
+ std::vector<string> feature_list_dense_missing_assumed_empty_tmp;
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_missing_assumed_empty",
+ &feature_list_dense_missing_assumed_empty_tmp));
+ for (const string& feature : feature_list_dense_missing_assumed_empty_tmp) {
+ feature_list_dense_missing_assumed_empty.insert(feature);
+ }
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("context_sparse_keys", &context_sparse_keys));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("context_dense_keys", &context_dense_keys));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_sparse_keys", &feature_list_sparse_keys));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_keys", &feature_list_dense_keys));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("context_sparse_types", &context_sparse_types));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_dense", &num_context_dense));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("Nfeature_list_dense", &num_feature_list_dense));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_sparse", &num_context_sparse));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("Tcontext_dense", &context_dense_types));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_sparse_types", &feature_list_sparse_types));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_types", &feature_list_dense_types));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("Nfeature_list_sparse", &num_feature_list_sparse));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("context_dense_shapes", &context_dense_shapes));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_shapes", &feature_list_dense_shapes));
+ return FinishInit();
+ }
+
+ std::unordered_set<string> feature_list_dense_missing_assumed_empty;
+ int64 num_context_sparse;
+ int64 num_context_dense;
+ int64 num_feature_list_sparse;
+ int64 num_feature_list_dense;
+ std::vector<string> context_sparse_keys;
+ std::vector<string> context_dense_keys;
+ std::vector<string> feature_list_sparse_keys;
+ std::vector<string> feature_list_dense_keys;
+ std::vector<DataType> context_sparse_types;
+ std::vector<DataType> context_dense_types;
+ std::vector<TensorShape> context_dense_shapes;
+ std::vector<DataType> feature_list_sparse_types;
+ std::vector<DataType> feature_list_dense_types;
+ std::vector<TensorShape> feature_list_dense_shapes;
+
+ private:
+ Status FinishInit(); // for context-independent parts of Init.
+};
+
// Parses the attributes passed to ParseSingleSequenceExample.
// REQUIRES: Init must be called after construction.
class ParseSingleSequenceExampleAttrs {