aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/slim
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-30 14:00:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 14:15:18 -0700
commit87ab69541f71e83a490fc2e1563bf1094665d2ab (patch)
treead4b3c62e5d454a14d913b6f70ed2af9b666aaaa /tensorflow/contrib/slim
parentd340f4700ec1eceb2011d1d5620633c60a12da48 (diff)
Make dtype in Image class actually modifiable.
Changing dtype to any other type other than default will cause a crash because decode_jpeg or decode_image will promise to return uint8 all the time while decode_raw will actually vary its return type. This mismatch of types causes tf.case to fail and makes dtype parameter unusable. PiperOrigin-RevId: 210975290
Diffstat (limited to 'tensorflow/contrib/slim')
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py11
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py316
2 files changed, 197 insertions, 130 deletions
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
index d877831fce..a6ce45c203 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
@@ -416,12 +416,17 @@ class Image(ItemHandler):
def decode_image():
"""Decodes a image based on the headers."""
- return image_ops.decode_image(image_buffer, channels=self._channels)
+ return math_ops.cast(
+ image_ops.decode_image(image_buffer, channels=self._channels),
+ self._dtype)
def decode_jpeg():
"""Decodes a jpeg image with specified '_dct_method'."""
- return image_ops.decode_jpeg(
- image_buffer, channels=self._channels, dct_method=self._dct_method)
+ return math_ops.cast(
+ image_ops.decode_jpeg(
+ image_buffer,
+ channels=self._channels,
+ dct_method=self._dct_method), self._dtype)
def check_jpeg():
"""Checks if an image is jpeg."""
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
index d783d4fef4..826242c9d7 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
@@ -37,12 +37,12 @@ from tensorflow.python.platform import test
class TFExampleDecoderTest(test.TestCase):
def _EncodedFloatFeature(self, ndarray):
- return feature_pb2.Feature(float_list=feature_pb2.FloatList(
- value=ndarray.flatten().tolist()))
+ return feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=ndarray.flatten().tolist()))
def _EncodedInt64Feature(self, ndarray):
- return feature_pb2.Feature(int64_list=feature_pb2.Int64List(
- value=ndarray.flatten().tolist()))
+ return feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=ndarray.flatten().tolist()))
def _EncodedBytesFeature(self, tf_encoded):
with self.test_session():
@@ -74,12 +74,14 @@ class TFExampleDecoderTest(test.TestCase):
if image_format in ['raw', 'RAW']:
return constant_op.constant(image.tostring(), dtype=dtypes.string)
- def GenerateImage(self, image_format, image_shape):
+ def GenerateImage(self, image_format, image_shape, image_dtype=np.uint8):
"""Generates an image and an example containing the encoded image.
Args:
image_format: the encoding format of the image.
image_shape: the shape of the image to generate.
+ image_dtype: the dtype of values in the image. Only 'raw' image can have
+ type different than uint8.
Returns:
image: the generated image.
@@ -87,14 +89,18 @@ class TFExampleDecoderTest(test.TestCase):
serialized image and a feature key 'image/format' set to the image
encoding format ['jpeg', 'JPEG', 'png', 'PNG', 'raw'].
"""
+ assert image_format in ['raw', 'RAW'] or image_dtype == np.uint8
num_pixels = image_shape[0] * image_shape[1] * image_shape[2]
image = np.linspace(
- 0, num_pixels - 1, num=num_pixels).reshape(image_shape).astype(np.uint8)
+ 0, num_pixels - 1,
+ num=num_pixels).reshape(image_shape).astype(image_dtype)
tf_encoded = self._Encoder(image, image_format)
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/encoded': self._EncodedBytesFeature(tf_encoded),
- 'image/format': self._StringFeature(image_format)
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/encoded': self._EncodedBytesFeature(tf_encoded),
+ 'image/format': self._StringFeature(image_format)
+ }))
return image, example.SerializeToString()
@@ -168,8 +174,7 @@ class TFExampleDecoderTest(test.TestCase):
tf_decoded_image = self.DecodeExample(
serialized_example,
- tfexample_decoder.Image(
- shape=None, channels=channels),
+ tfexample_decoder.Image(shape=None, channels=channels),
image_format='jpeg')
self.assertEqual(tf_decoded_image.get_shape().ndims, 3)
@@ -225,27 +230,38 @@ class TFExampleDecoderTest(test.TestCase):
self.assertAllClose(image, decoded_image, atol=0)
- def testDecodeExampleWithJpegEncodingAt16BitCausesError(self):
+ def testDecodeExampleWithRawEncodingFloatDtype(self):
image_shape = (2, 3, 3)
- unused_image, serialized_example = self.GenerateImage(
+ image, serialized_example = self.GenerateImage(
+ image_format='raw', image_shape=image_shape, image_dtype=np.float32)
+
+ decoded_image = self.RunDecodeExample(
+ serialized_example,
+ tfexample_decoder.Image(shape=image_shape, dtype=dtypes.float32),
+ image_format='raw')
+
+ self.assertAllClose(image, decoded_image, atol=0)
+
+ def testDecodeExampleWithJpegEncodingAt16BitDoesNotCauseError(self):
+ image_shape = (2, 3, 3)
+ # Image has type uint8 but decoding at uint16 should not cause problems.
+ image, serialized_example = self.GenerateImage(
image_format='jpeg', image_shape=image_shape)
- # decode_raw support uint16 now so ValueError will be thrown instead.
- with self.assertRaisesRegexp(
- ValueError,
- 'true_fn and false_fn must have the same type: uint16, uint8'):
- unused_decoded_image = self.RunDecodeExample(
- serialized_example,
- tfexample_decoder.Image(dtype=dtypes.uint16),
- image_format='jpeg')
+ decoded_image = self.RunDecodeExample(
+ serialized_example,
+ tfexample_decoder.Image(dtype=dtypes.uint16),
+ image_format='jpeg')
+ self.assertAllClose(image, decoded_image, atol=1.001)
def testDecodeExampleWithStringTensor(self):
tensor_shape = (2, 3, 1)
np_array = np.array([[['ab'], ['cd'], ['ef']],
[['ghi'], ['jkl'], ['mnop']]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._BytesFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._BytesFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -259,7 +275,9 @@ class TFExampleDecoderTest(test.TestCase):
default_value=constant_op.constant(
'', shape=tensor_shape, dtype=dtypes.string))
}
- items_to_handlers = {'labels': tfexample_decoder.Tensor('labels'),}
+ items_to_handlers = {
+ 'labels': tfexample_decoder.Tensor('labels'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_labels] = decoder.decode(serialized_example, ['labels'])
@@ -271,9 +289,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithFloatTensor(self):
np_array = np.random.rand(2, 3, 1).astype('f')
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'array': self._EncodedFloatFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'array': self._EncodedFloatFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -282,7 +301,9 @@ class TFExampleDecoderTest(test.TestCase):
keys_to_features = {
'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.float32)
}
- items_to_handlers = {'array': tfexample_decoder.Tensor('array'),}
+ items_to_handlers = {
+ 'array': tfexample_decoder.Tensor('array'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_array] = decoder.decode(serialized_example, ['array'])
@@ -291,9 +312,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithInt64Tensor(self):
np_array = np.random.randint(1, 10, size=(2, 3, 1))
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'array': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'array': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -302,7 +324,9 @@ class TFExampleDecoderTest(test.TestCase):
keys_to_features = {
'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.int64)
}
- items_to_handlers = {'array': tfexample_decoder.Tensor('array'),}
+ items_to_handlers = {
+ 'array': tfexample_decoder.Tensor('array'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_array] = decoder.decode(serialized_example, ['array'])
@@ -311,9 +335,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithVarLenTensor(self):
np_array = np.array([[[1], [2], [3]], [[4], [5], [6]]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -322,7 +347,9 @@ class TFExampleDecoderTest(test.TestCase):
keys_to_features = {
'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
}
- items_to_handlers = {'labels': tfexample_decoder.Tensor('labels'),}
+ items_to_handlers = {
+ 'labels': tfexample_decoder.Tensor('labels'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_labels] = decoder.decode(serialized_example, ['labels'])
@@ -332,9 +359,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithFixLenTensorWithShape(self):
np_array = np.array([[1, 2, 3], [4, 5, 6]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -342,12 +370,10 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'labels':
- parsing_ops.FixedLenFeature(
- np_array.shape, dtype=dtypes.int64),
+ parsing_ops.FixedLenFeature(np_array.shape, dtype=dtypes.int64),
}
items_to_handlers = {
- 'labels': tfexample_decoder.Tensor(
- 'labels', shape=np_array.shape),
+ 'labels': tfexample_decoder.Tensor('labels', shape=np_array.shape),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -357,9 +383,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithVarLenTensorToDense(self):
np_array = np.array([[1, 2, 3], [4, 5, 6]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -369,8 +396,7 @@ class TFExampleDecoderTest(test.TestCase):
'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
}
items_to_handlers = {
- 'labels': tfexample_decoder.Tensor(
- 'labels', shape=np_array.shape),
+ 'labels': tfexample_decoder.Tensor('labels', shape=np_array.shape),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -382,12 +408,18 @@ class TFExampleDecoderTest(test.TestCase):
np_image = np.random.rand(2, 3, 1).astype('f')
np_labels = np.array([[[1], [2], [3]], [[4], [5], [6]]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image': self._EncodedFloatFeature(np_image),
- 'image/shape': self._EncodedInt64Feature(np.array(np_image.shape)),
- 'labels': self._EncodedInt64Feature(np_labels),
- 'labels/shape': self._EncodedInt64Feature(np.array(np_labels.shape)),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image':
+ self._EncodedFloatFeature(np_image),
+ 'image/shape':
+ self._EncodedInt64Feature(np.array(np_image.shape)),
+ 'labels':
+ self._EncodedInt64Feature(np_labels),
+ 'labels/shape':
+ self._EncodedInt64Feature(np.array(np_labels.shape)),
+ }))
serialized_example = example.SerializeToString()
@@ -401,11 +433,9 @@ class TFExampleDecoderTest(test.TestCase):
}
items_to_handlers = {
'image':
- tfexample_decoder.Tensor(
- 'image', shape_keys='image/shape'),
+ tfexample_decoder.Tensor('image', shape_keys='image/shape'),
'labels':
- tfexample_decoder.Tensor(
- 'labels', shape_keys='labels/shape'),
+ tfexample_decoder.Tensor('labels', shape_keys='labels/shape'),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -419,14 +449,22 @@ class TFExampleDecoderTest(test.TestCase):
np_labels = np.array([[[1], [2], [3]], [[4], [5], [6]]])
height, width, depth = np_labels.shape
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image': self._EncodedFloatFeature(np_image),
- 'image/shape': self._EncodedInt64Feature(np.array(np_image.shape)),
- 'labels': self._EncodedInt64Feature(np_labels),
- 'labels/height': self._EncodedInt64Feature(np.array([height])),
- 'labels/width': self._EncodedInt64Feature(np.array([width])),
- 'labels/depth': self._EncodedInt64Feature(np.array([depth])),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image':
+ self._EncodedFloatFeature(np_image),
+ 'image/shape':
+ self._EncodedInt64Feature(np.array(np_image.shape)),
+ 'labels':
+ self._EncodedInt64Feature(np_labels),
+ 'labels/height':
+ self._EncodedInt64Feature(np.array([height])),
+ 'labels/width':
+ self._EncodedInt64Feature(np.array([width])),
+ 'labels/depth':
+ self._EncodedInt64Feature(np.array([depth])),
+ }))
serialized_example = example.SerializeToString()
@@ -442,8 +480,7 @@ class TFExampleDecoderTest(test.TestCase):
}
items_to_handlers = {
'image':
- tfexample_decoder.Tensor(
- 'image', shape_keys='image/shape'),
+ tfexample_decoder.Tensor('image', shape_keys='image/shape'),
'labels':
tfexample_decoder.Tensor(
'labels',
@@ -459,10 +496,12 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithSparseTensor(self):
np_indices = np.array([[1], [2], [5]])
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ }))
serialized_example = example.SerializeToString()
@@ -472,7 +511,9 @@ class TFExampleDecoderTest(test.TestCase):
'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
'values': parsing_ops.VarLenFeature(dtype=dtypes.float32),
}
- items_to_handlers = {'labels': tfexample_decoder.SparseTensor(),}
+ items_to_handlers = {
+ 'labels': tfexample_decoder.SparseTensor(),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_labels] = decoder.decode(serialized_example, ['labels'])
@@ -485,11 +526,13 @@ class TFExampleDecoderTest(test.TestCase):
np_indices = np.array([[1], [2], [5]])
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
np_shape = np.array([6])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- 'shape': self._EncodedInt64Feature(np_shape),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ 'shape': self._EncodedInt64Feature(np_shape),
+ }))
serialized_example = example.SerializeToString()
@@ -515,10 +558,12 @@ class TFExampleDecoderTest(test.TestCase):
np_indices = np.array([[1], [2], [5]])
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
np_shape = np.array([6])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ }))
serialized_example = example.SerializeToString()
@@ -544,10 +589,12 @@ class TFExampleDecoderTest(test.TestCase):
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
np_shape = np.array([6])
np_dense = np.array([0.0, 0.1, 0.2, 0.0, 0.0, 0.6]).astype('f')
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ }))
serialized_example = example.SerializeToString()
@@ -559,8 +606,7 @@ class TFExampleDecoderTest(test.TestCase):
}
items_to_handlers = {
'labels':
- tfexample_decoder.SparseTensor(
- shape=np_shape, densify=True),
+ tfexample_decoder.SparseTensor(shape=np_shape, densify=True),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -572,9 +618,10 @@ class TFExampleDecoderTest(test.TestCase):
tensor_shape = (2, 3, 1)
np_array = np.random.rand(2, 3, 1)
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/depth_map': self._EncodedFloatFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'image/depth_map': self._EncodedFloatFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -603,9 +650,10 @@ class TFExampleDecoderTest(test.TestCase):
tensor_shape = (2, 3, 1)
np_array = np.random.rand(2, 3, 1)
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/depth_map': self._EncodedFloatFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'image/depth_map': self._EncodedFloatFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -701,12 +749,14 @@ class TFExampleDecoderTest(test.TestCase):
np_xmax = np.random.rand(num_bboxes, 1)
np_bboxes = np.hstack([np_ymin, np_xmin, np_ymax, np_xmax])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
- 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
- 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
- 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
+ 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
+ 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
+ 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
+ }))
serialized_example = example.SerializeToString()
with self.test_session():
@@ -740,26 +790,32 @@ class TFExampleDecoderTest(test.TestCase):
np_xmax = np.random.rand(num_bboxes, 1)
np_bboxes = np.hstack([np_ymin, np_xmin, np_ymax, np_xmax])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
- 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
- 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
- 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
+ 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
+ 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
+ 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
+ }))
serialized_example = example.SerializeToString()
with self.test_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
- 'image/object/bbox/ymin': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
- 'image/object/bbox/xmin': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
- 'image/object/bbox/ymax': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
- 'image/object/bbox/xmax': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/ymin':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/xmin':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/ymax':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/xmax':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
}
items_to_handlers = {
@@ -784,11 +840,16 @@ class TFExampleDecoderTest(test.TestCase):
with self.test_session():
tf_string = tf_encoded.eval()
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/encoded': feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
- value=[tf_string, tf_string])),
- 'image/format': self._StringFeature(image_format),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/encoded':
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[tf_string, tf_string])),
+ 'image/format':
+ self._StringFeature(image_format),
+ }))
serialized_example = example.SerializeToString()
with self.test_session():
@@ -797,8 +858,7 @@ class TFExampleDecoderTest(test.TestCase):
decoder = tfexample_decoder.TFExampleDecoder(
keys_to_features={
'image/encoded':
- parsing_ops.FixedLenFeature(
- (2,), dtypes.string),
+ parsing_ops.FixedLenFeature((2,), dtypes.string),
'image/format':
parsing_ops.FixedLenFeature(
(), dtypes.string, default_value=image_format),
@@ -814,10 +874,12 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithLookup(self):
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/object/class/text': self._BytesFeature(
- np.array(['cat', 'dog', 'guinea pig'])),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/class/text':
+ self._BytesFeature(np.array(['cat', 'dog', 'guinea pig'])),
+ }))
serialized_example = example.SerializeToString()
# 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2
table = lookup_ops.index_table_from_tensor(