aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/slim
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-24 14:45:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-24 14:50:10 -0700
commit8d1a4fa09cb40ee98ecddc99f207f17b05176897 (patch)
tree2f2d2595848258cac9358909dad9af0ecbeb0d75 /tensorflow/contrib/slim
parentdfc7b26b0dc0dd54038a1be3b31b05bd39c1e79f (diff)
Add a MultiHandler that can conditionally apply handling logic based on presence of input Tensors.
PiperOrigin-RevId: 173314020
Diffstat (limited to 'tensorflow/contrib/slim')
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py34
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py51
2 files changed, 85 insertions, 0 deletions
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
index 7a56df9e97..0544404e9e 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
@@ -243,6 +243,40 @@ class LookupTensor(Tensor):
return self._table.lookup(unmapped_tensor)
+class BackupHandler(ItemHandler):
+ """An ItemHandler that tries two ItemHandlers in order."""
+
+ def __init__(self, handler, backup):
+ """Initializes the BackupHandler handler.
+
+ If the first Handler's tensors_to_item returns a Tensor with no elements,
+ the second Handler is used.
+
+ Args:
+ handler: The primary ItemHandler.
+ backup: The backup ItemHandler.
+
+ Raises:
+ ValueError: if either is not an ItemHandler.
+ """
+ if not isinstance(handler, ItemHandler):
+ raise ValueError('Primary handler is of type %s instead of ItemHandler'
+ % type(handler))
+ if not isinstance(backup, ItemHandler):
+ raise ValueError('Backup handler is of type %s instead of ItemHandler'
+ % type(backup))
+ self._handler = handler
+ self._backup = backup
+ super(BackupHandler, self).__init__(handler.keys + backup.keys)
+
+ def tensors_to_item(self, keys_to_tensors):
+ item = self._handler.tensors_to_item(keys_to_tensors)
+ return control_flow_ops.cond(
+ pred=math_ops.equal(math_ops.reduce_prod(array_ops.shape(item)), 0),
+ true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors),
+ false_fn=lambda: item)
+
+
class SparseTensor(ItemHandler):
"""An ItemHandler for SparseTensors."""
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
index 9c5a14d006..d783d4fef4 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
@@ -843,5 +843,56 @@ class TFExampleDecoderTest(test.TestCase):
self.assertAllClose([2, 0, 1], obtained_class_ids)
+ def testDecodeExampleWithBackupHandlerLookup(self):
+
+ example1 = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/class/text':
+ self._BytesFeature(np.array(['cat', 'dog', 'guinea pig'])),
+ 'image/object/class/label':
+ self._EncodedInt64Feature(np.array([42, 10, 900]))
+ }))
+ example2 = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/class/text':
+ self._BytesFeature(np.array(['cat', 'dog', 'guinea pig'])),
+ }))
+ example3 = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/class/label':
+ self._EncodedInt64Feature(np.array([42, 10, 901]))
+ }))
+ # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2
+ table = lookup_ops.index_table_from_tensor(
+ constant_op.constant(['dog', 'guinea pig', 'cat']))
+ keys_to_features = {
+ 'image/object/class/text': parsing_ops.VarLenFeature(dtypes.string),
+ 'image/object/class/label': parsing_ops.VarLenFeature(dtypes.int64),
+ }
+ backup_handler = tfexample_decoder.BackupHandler(
+ handler=tfexample_decoder.Tensor('image/object/class/label'),
+ backup=tfexample_decoder.LookupTensor('image/object/class/text', table))
+ items_to_handlers = {
+ 'labels': backup_handler,
+ }
+ decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
+ items_to_handlers)
+ obtained_class_ids_each_example = []
+ with self.test_session() as sess:
+ sess.run(lookup_ops.tables_initializer())
+ for example in [example1, example2, example3]:
+ serialized_example = array_ops.reshape(
+ example.SerializeToString(), shape=[])
+ obtained_class_ids_each_example.append(
+ decoder.decode(serialized_example)[0].eval())
+
+ self.assertAllClose([42, 10, 900], obtained_class_ids_each_example[0])
+ self.assertAllClose([2, 0, 1], obtained_class_ids_each_example[1])
+ self.assertAllClose([42, 10, 901], obtained_class_ids_each_example[2])
+
+
if __name__ == '__main__':
test.main()