From 5c5dc8d5641b7c915f681109921dfb2b3e082a9b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 16 Oct 2017 17:59:11 -0700 Subject: Adding an ItemHandler that does lookups. This allows decoding of tf.Examples where IDs are not materialized (e.g. 'image/object/class/text' present but 'image/object/class/label' not). PiperOrigin-RevId: 172406978 --- .../slim/python/slim/data/tfexample_decoder.py | 36 ++++++++++++++++++++++ .../python/slim/data/tfexample_decoder_test.py | 31 +++++++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index 094568389c..7a56df9e97 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -207,6 +207,42 @@ class Tensor(ItemHandler): return tensor +class LookupTensor(Tensor): + """An ItemHandler that returns a parsed Tensor, the result of a lookup.""" + + def __init__(self, + tensor_key, + table, + shape_keys=None, + shape=None, + default_value=''): + """Initializes the LookupTensor handler. + + See Tensor. Simply calls a vocabulary (most often, a label mapping) lookup. + + Args: + tensor_key: the name of the `TFExample` feature to read the tensor from. + table: A tf.lookup table. + shape_keys: Optional name or list of names of the TF-Example feature in + which the tensor shape is stored. If a list, then each corresponds to + one dimension of the shape. + shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is + reshaped accordingly. + default_value: The value used when the `tensor_key` is not found in a + particular `TFExample`. + + Raises: + ValueError: if both `shape_keys` and `shape` are specified. + """ + self._table = table + super(LookupTensor, self).__init__(tensor_key, shape_keys, shape, + default_value) + + def tensors_to_item(self, keys_to_tensors): + unmapped_tensor = super(LookupTensor, self).tensors_to_item(keys_to_tensors) + return self._table.lookup(unmapped_tensor) + + class SparseTensor(ItemHandler): """An ItemHandler for SparseTensors.""" diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py index 60d1eba07f..9c5a14d006 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import image_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test @@ -811,6 +812,36 @@ class TFExampleDecoderTest(test.TestCase): self.assertAllEqual(np.squeeze(output_image[0, :, :, :]), image) self.assertAllEqual(np.squeeze(output_image[1, :, :, :]), image) + def testDecodeExampleWithLookup(self): + + example = example_pb2.Example(features=feature_pb2.Features(feature={ + 'image/object/class/text': self._BytesFeature( + np.array(['cat', 'dog', 'guinea pig'])), + })) + serialized_example = example.SerializeToString() + # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2 + table = lookup_ops.index_table_from_tensor( + constant_op.constant(['dog', 'guinea pig', 'cat'])) + + with self.test_session() as sess: + sess.run(lookup_ops.tables_initializer()) + + serialized_example = array_ops.reshape(serialized_example, shape=[]) + + keys_to_features = { + 'image/object/class/text': parsing_ops.VarLenFeature(dtypes.string), + } + + items_to_handlers = { + 'labels': + tfexample_decoder.LookupTensor('image/object/class/text', table), + } + + decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, + items_to_handlers) + obtained_class_ids = decoder.decode(serialized_example)[0].eval() + + self.assertAllClose([2, 0, 1], obtained_class_ids) if __name__ == '__main__': test.main() -- cgit v1.2.3