blob: a3750851769a31466eebba5cfd5e665f4cbc4f9c (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extract parse_example op configuration to a proto."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.example import example_parser_configuration_pb2
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
def extract_example_parser_configuration(parse_example_op, sess):
  """Returns an ExampleParserConfiguration proto describing a ParseExample op.

  Feature counts, dtypes and dense shapes are read from the op's attributes.
  The feature key names and dense default values are op inputs, so they are
  evaluated with `sess`.

  Args:
    parse_example_op: A ParseExample `Operation`.
    sess: A tf.Session needed to obtain some configuration values (the feature
      key names and the dense default values).

  Returns:
    An ExampleParserConfiguration proto with one `feature_map` entry per
    feature key. Each dense feature gets a `fixed_len_feature` entry and each
    sparse feature gets a `var_len_feature` entry.

  Raises:
    ValueError: If attributes are inconsistent, or if the op's inputs or the
      fetched values do not have the expected length.
  """
  num_sparse = parse_example_op.get_attr("Nsparse")
  num_dense = parse_example_op.get_attr("Ndense")
  total_features = num_dense + num_sparse

  sparse_types = parse_example_op.get_attr("sparse_types")
  dense_types = parse_example_op.get_attr("Tdense")
  dense_shapes = parse_example_op.get_attr("dense_shapes")

  if len(sparse_types) != num_sparse:
    raise ValueError("len(sparse_types) attribute does not match "
                     "Nsparse attribute (%d vs %d)" %
                     (len(sparse_types), num_sparse))
  if len(dense_types) != num_dense:
    raise ValueError("len(dense_types) attribute does not match "
                     "Ndense attribute (%d vs %d)" %
                     (len(dense_types), num_dense))
  if len(dense_shapes) != num_dense:
    raise ValueError("len(dense_shapes) attribute does not match "
                     "Ndense attribute (%d vs %d)" %
                     (len(dense_shapes), num_dense))

  # Skip over the serialized input and the names input. The remaining inputs
  # are the sparse keys, then the dense keys, then the dense default values.
  fetch_list = parse_example_op.inputs[2:]
  # Fetch total_features key names and num_dense default values.
  expected_fetches = total_features + num_dense
  if len(fetch_list) != expected_fetches:
    raise ValueError("len(fetch_list) does not match total features + "
                     "num_dense (%d vs %d)" %
                     (len(fetch_list), expected_fetches))

  fetched = sess.run(fetch_list)
  if len(fetched) != len(fetch_list):
    raise ValueError("len(fetched) does not match len(fetch_list) "
                     "(%d vs %d)" % (len(fetched), len(fetch_list)))

  # Offsets into `fetched` (mirrors the input order described above).
  sparse_keys_start = 0
  dense_keys_start = sparse_keys_start + num_sparse
  dense_def_start = dense_keys_start + num_dense

  # Offsets into the op's outputs: sparse indices, sparse values, sparse
  # shapes (num_sparse each), then the dense values.
  sparse_indices_start = 0
  sparse_values_start = num_sparse
  sparse_shapes_start = sparse_values_start + num_sparse
  dense_values_start = sparse_shapes_start + num_sparse

  # Build the proto only after all consistency checks have passed.
  config = example_parser_configuration_pb2.ExampleParserConfiguration()

  # Dense features.
  for i in range(num_dense):
    key = fetched[dense_keys_start + i]
    fixed_config = config.feature_map[key].fixed_len_feature
    # Convert the default value fetched from the session run into a
    # TensorProto.
    fixed_config.default_value.CopyFrom(
        tensor_util.make_tensor_proto(fetched[dense_def_start + i]))
    # Convert the shape from the attributes into a TensorShapeProto.
    fixed_config.shape.CopyFrom(
        tensor_shape.TensorShape(dense_shapes[i]).as_proto())
    fixed_config.dtype = int(dense_types[i])
    # Record the name of the output tensor holding this feature's values.
    fixed_config.values_output_tensor_name = parse_example_op.outputs[
        dense_values_start + i].name

  # Sparse features.
  for i in range(num_sparse):
    key = fetched[sparse_keys_start + i]
    var_len_feature = config.feature_map[key].var_len_feature
    var_len_feature.dtype = int(sparse_types[i])
    var_len_feature.indices_output_tensor_name = parse_example_op.outputs[
        sparse_indices_start + i].name
    var_len_feature.values_output_tensor_name = parse_example_op.outputs[
        sparse_values_start + i].name
    var_len_feature.shapes_output_tensor_name = parse_example_op.outputs[
        sparse_shapes_start + i].name

  return config
|