aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-16 17:59:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-16 18:03:12 -0700
commit5c5dc8d5641b7c915f681109921dfb2b3e082a9b (patch)
tree96009b437c51afe7d3f43e103a51649b3f9a825c
parent684f88fa7e61721c3264dc70abeed2b3e6fa7717 (diff)
Adding an ItemHandler that does lookups. This allows decoding of tf.Examples where IDs are not materialized (e.g. 'image/object/class/text' present but 'image/object/class/label' not).
PiperOrigin-RevId: 172406978
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py36
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py31
2 files changed, 67 insertions, 0 deletions
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
index 094568389c..7a56df9e97 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
@@ -207,6 +207,42 @@ class Tensor(ItemHandler):
return tensor
+class LookupTensor(Tensor):
+ """An ItemHandler that returns a parsed Tensor, the result of a lookup."""
+
+ def __init__(self,
+ tensor_key,
+ table,
+ shape_keys=None,
+ shape=None,
+ default_value=''):
+ """Initializes the LookupTensor handler.
+
+ See Tensor. Simply calls a vocabulary (most often, a label mapping) lookup.
+
+ Args:
+ tensor_key: the name of the `TFExample` feature to read the tensor from.
+ table: A tf.lookup table.
+ shape_keys: Optional name or list of names of the TF-Example feature in
+ which the tensor shape is stored. If a list, then each corresponds to
+ one dimension of the shape.
+ shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is
+ reshaped accordingly.
+ default_value: The value used when the `tensor_key` is not found in a
+ particular `TFExample`.
+
+ Raises:
+ ValueError: if both `shape_keys` and `shape` are specified.
+ """
+ self._table = table
+ super(LookupTensor, self).__init__(tensor_key, shape_keys, shape,
+ default_value)
+
+ def tensors_to_item(self, keys_to_tensors):
+ unmapped_tensor = super(LookupTensor, self).tensors_to_item(keys_to_tensors)
+ return self._table.lookup(unmapped_tensor)
+
+
class SparseTensor(ItemHandler):
"""An ItemHandler for SparseTensors."""
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
index 60d1eba07f..9c5a14d006 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import image_ops
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
@@ -811,6 +812,36 @@ class TFExampleDecoderTest(test.TestCase):
self.assertAllEqual(np.squeeze(output_image[0, :, :, :]), image)
self.assertAllEqual(np.squeeze(output_image[1, :, :, :]), image)
+ def testDecodeExampleWithLookup(self):
+
+ example = example_pb2.Example(features=feature_pb2.Features(feature={
+ 'image/object/class/text': self._BytesFeature(
+ np.array(['cat', 'dog', 'guinea pig'])),
+ }))
+ serialized_example = example.SerializeToString()
+ # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2
+ table = lookup_ops.index_table_from_tensor(
+ constant_op.constant(['dog', 'guinea pig', 'cat']))
+
+ with self.test_session() as sess:
+ sess.run(lookup_ops.tables_initializer())
+
+ serialized_example = array_ops.reshape(serialized_example, shape=[])
+
+ keys_to_features = {
+ 'image/object/class/text': parsing_ops.VarLenFeature(dtypes.string),
+ }
+
+ items_to_handlers = {
+ 'labels':
+ tfexample_decoder.LookupTensor('image/object/class/text', table),
+ }
+
+ decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
+ items_to_handlers)
+ obtained_class_ids = decoder.decode(serialized_example)[0].eval()
+
+ self.assertAllClose([2, 0, 1], obtained_class_ids)
if __name__ == '__main__':
test.main()