diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-24 14:45:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-24 14:50:10 -0700 |
commit | 8d1a4fa09cb40ee98ecddc99f207f17b05176897 (patch) | |
tree | 2f2d2595848258cac9358909dad9af0ecbeb0d75 /tensorflow/contrib/slim | |
parent | dfc7b26b0dc0dd54038a1be3b31b05bd39c1e79f (diff) |
Add a MultiHandler that can conditionally apply handling logic based on presence of input Tensors.
PiperOrigin-RevId: 173314020
Diffstat (limited to 'tensorflow/contrib/slim')
-rw-r--r-- | tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py | 34 | ||||
-rw-r--r-- | tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py | 51 |
2 files changed, 85 insertions, 0 deletions
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index 7a56df9e97..0544404e9e 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -243,6 +243,40 @@ class LookupTensor(Tensor): return self._table.lookup(unmapped_tensor) +class BackupHandler(ItemHandler): + """An ItemHandler that tries two ItemHandlers in order.""" + + def __init__(self, handler, backup): + """Initializes the BackupHandler handler. + + If the first Handler's tensors_to_item returns a Tensor with no elements, + the second Handler is used. + + Args: + handler: The primary ItemHandler. + backup: The backup ItemHandler. + + Raises: + ValueError: if either is not an ItemHandler. + """ + if not isinstance(handler, ItemHandler): + raise ValueError('Primary handler is of type %s instead of ItemHandler' + % type(handler)) + if not isinstance(backup, ItemHandler): + raise ValueError('Backup handler is of type %s instead of ItemHandler' + % type(backup)) + self._handler = handler + self._backup = backup + super(BackupHandler, self).__init__(handler.keys + backup.keys) + + def tensors_to_item(self, keys_to_tensors): + item = self._handler.tensors_to_item(keys_to_tensors) + return control_flow_ops.cond( + pred=math_ops.equal(math_ops.reduce_prod(array_ops.shape(item)), 0), + true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors), + false_fn=lambda: item) + + class SparseTensor(ItemHandler): """An ItemHandler for SparseTensors.""" diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py index 9c5a14d006..d783d4fef4 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py @@ -843,5 +843,56 @@ class TFExampleDecoderTest(test.TestCase): self.assertAllClose([2, 0, 1], obtained_class_ids) + def testDecodeExampleWithBackupHandlerLookup(self): + + example1 = example_pb2.Example( + features=feature_pb2.Features( + feature={ + 'image/object/class/text': + self._BytesFeature(np.array(['cat', 'dog', 'guinea pig'])), + 'image/object/class/label': + self._EncodedInt64Feature(np.array([42, 10, 900])) + })) + example2 = example_pb2.Example( + features=feature_pb2.Features( + feature={ + 'image/object/class/text': + self._BytesFeature(np.array(['cat', 'dog', 'guinea pig'])), + })) + example3 = example_pb2.Example( + features=feature_pb2.Features( + feature={ + 'image/object/class/label': + self._EncodedInt64Feature(np.array([42, 10, 901])) + })) + # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2 + table = lookup_ops.index_table_from_tensor( + constant_op.constant(['dog', 'guinea pig', 'cat'])) + keys_to_features = { + 'image/object/class/text': parsing_ops.VarLenFeature(dtypes.string), + 'image/object/class/label': parsing_ops.VarLenFeature(dtypes.int64), + } + backup_handler = tfexample_decoder.BackupHandler( + handler=tfexample_decoder.Tensor('image/object/class/label'), + backup=tfexample_decoder.LookupTensor('image/object/class/text', table)) + items_to_handlers = { + 'labels': backup_handler, + } + decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, + items_to_handlers) + obtained_class_ids_each_example = [] + with self.test_session() as sess: + sess.run(lookup_ops.tables_initializer()) + for example in [example1, example2, example3]: + serialized_example = array_ops.reshape( + example.SerializeToString(), shape=[]) + obtained_class_ids_each_example.append( + decoder.decode(serialized_example)[0].eval()) + + self.assertAllClose([42, 10, 900], obtained_class_ids_each_example[0]) + self.assertAllClose([2, 0, 1], obtained_class_ids_each_example[1]) + self.assertAllClose([42, 10, 901], obtained_class_ids_each_example[2]) + + if __name__ == '__main__': test.main() |