aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util/example_parser_configuration.py
blob: a3750851769a31466eebba5cfd5e665f4cbc4f9c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extract parse_example op configuration to a proto."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.example import example_parser_configuration_pb2
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util


def extract_example_parser_configuration(parse_example_op, sess):
  """Returns an ExampleParserConfig proto.

  Args:
    parse_example_op: A ParseExample `Operation`
    sess: A tf.Session needed to obtain some configuration values.
  Returns:
    A ExampleParserConfig proto.

  Raises:
    ValueError: If attributes are inconsistent.
  """
  config = example_parser_configuration_pb2.ExampleParserConfiguration()

  num_sparse = parse_example_op.get_attr("Nsparse")
  num_dense = parse_example_op.get_attr("Ndense")
  total_features = num_dense + num_sparse

  sparse_types = parse_example_op.get_attr("sparse_types")
  dense_types = parse_example_op.get_attr("Tdense")
  dense_shapes = parse_example_op.get_attr("dense_shapes")

  if len(sparse_types) != num_sparse:
    raise ValueError("len(sparse_types) attribute does not match "
                     "Nsparse attribute (%d vs %d)" %
                     (len(sparse_types), num_sparse))

  if len(dense_types) != num_dense:
    raise ValueError("len(dense_types) attribute does not match "
                     "Ndense attribute (%d vs %d)" %
                     (len(dense_types), num_dense))

  if len(dense_shapes) != num_dense:
    raise ValueError("len(dense_shapes) attribute does not match "
                     "Ndense attribute (%d vs %d)" %
                     (len(dense_shapes), num_dense))

  # Skip over the serialized input, and the names input.
  fetch_list = parse_example_op.inputs[2:]

  # Fetch total_features key names and num_dense default values.
  if len(fetch_list) != (total_features + num_dense):
    raise ValueError("len(fetch_list) does not match total features + num_dense"
                     "(%d vs %d" % (len(fetch_list),
                                    (total_features + num_dense)))

  fetched = sess.run(fetch_list)

  if len(fetched) != len(fetch_list):
    raise ValueError("len(fetched) does not match len(fetch_list)"
                     "(%d vs %d" % (len(fetched), len(fetch_list)))

  # Fetch indices.
  sparse_keys_start = 0
  dense_keys_start = sparse_keys_start + num_sparse
  dense_def_start = dense_keys_start + num_dense

  # Output tensor indices.
  sparse_indices_start = 0
  sparse_values_start = num_sparse
  sparse_shapes_start = sparse_values_start + num_sparse
  dense_values_start = sparse_shapes_start + num_sparse

  # Dense features.
  for i in range(num_dense):
    key = fetched[dense_keys_start + i]
    feature_config = config.feature_map[key]
    # Convert the default value numpy array fetched from the session run
    # into a TensorProto.
    fixed_config = feature_config.fixed_len_feature

    fixed_config.default_value.CopyFrom(
        tensor_util.make_tensor_proto(fetched[dense_def_start + i]))
    # Convert the shape from the attributes
    # into a TensorShapeProto.
    fixed_config.shape.CopyFrom(
        tensor_shape.TensorShape(dense_shapes[i]).as_proto())

    fixed_config.dtype = int(dense_types[i])
    # Get the output tensor name.
    fixed_config.values_output_tensor_name = parse_example_op.outputs[
        dense_values_start + i].name

  # Sparse features.
  for i in range(num_sparse):
    key = fetched[sparse_keys_start + i]
    feature_config = config.feature_map[key]
    var_len_feature = feature_config.var_len_feature
    var_len_feature.dtype = int(sparse_types[i])
    var_len_feature.indices_output_tensor_name = parse_example_op.outputs[
        sparse_indices_start + i].name
    var_len_feature.values_output_tensor_name = parse_example_op.outputs[
        sparse_values_start + i].name
    var_len_feature.shapes_output_tensor_name = parse_example_op.outputs[
        sparse_shapes_start + i].name

  return config