aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/backend.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r--tensorflow/python/keras/backend.py64
1 files changed, 40 insertions, 24 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 13f52fbae7..7509ef9c59 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -2338,7 +2338,8 @@ def permute_dimensions(x, pattern):
@tf_export('keras.backend.resize_images')
-def resize_images(x, height_factor, width_factor, data_format):
+def resize_images(x, height_factor, width_factor, data_format,
+ interpolation='nearest'):
"""Resizes the images contained in a 4D tensor.
Arguments:
@@ -2346,40 +2347,55 @@ def resize_images(x, height_factor, width_factor, data_format):
height_factor: Positive integer.
width_factor: Positive integer.
data_format: One of `"channels_first"`, `"channels_last"`.
+ interpolation: A string, one of `nearest` or `bilinear`.
Returns:
A tensor.
Raises:
- ValueError: if `data_format` is neither
- `channels_last` or `channels_first`.
+ ValueError: in case of incorrect value for
+ `data_format` or `interpolation`.
"""
if data_format == 'channels_first':
- original_shape = int_shape(x)
- new_shape = array_ops.shape(x)[2:]
- new_shape *= constant_op.constant(
- np.array([height_factor, width_factor]).astype('int32'))
+ rows, cols = 2, 3
+ elif data_format == 'channels_last':
+ rows, cols = 1, 2
+ else:
+ raise ValueError('Invalid `data_format` argument: %s' % (data_format,))
+
+ original_shape = int_shape(x)
+ new_shape = array_ops.shape(x)[rows:cols + 1]
+ new_shape *= constant_op.constant(
+ np.array([height_factor, width_factor], dtype='int32'))
+
+ if data_format == 'channels_first':
x = permute_dimensions(x, [0, 2, 3, 1])
+ if interpolation == 'nearest':
x = image_ops.resize_nearest_neighbor(x, new_shape)
+ elif interpolation == 'bilinear':
+ x = image_ops.resize_bilinear(x, new_shape)
+ else:
+ raise ValueError('interpolation should be one '
+ 'of "nearest" or "bilinear".')
+ if data_format == 'channels_first':
x = permute_dimensions(x, [0, 3, 1, 2])
- x.set_shape((None, None, original_shape[2] * height_factor
- if original_shape[2] is not None else None,
- original_shape[3] * width_factor
- if original_shape[3] is not None else None))
- return x
- elif data_format == 'channels_last':
- original_shape = int_shape(x)
- new_shape = array_ops.shape(x)[1:3]
- new_shape *= constant_op.constant(
- np.array([height_factor, width_factor]).astype('int32'))
- x = image_ops.resize_nearest_neighbor(x, new_shape)
- x.set_shape((None, original_shape[1] * height_factor
- if original_shape[1] is not None else None,
- original_shape[2] * width_factor
- if original_shape[2] is not None else None, None))
- return x
+
+ if original_shape[rows] is None:
+ new_height = None
else:
- raise ValueError('Invalid data_format: ' + str(data_format))
+ new_height = original_shape[rows] * height_factor
+
+ if original_shape[cols] is None:
+ new_width = None
+ else:
+ new_width = original_shape[cols] * width_factor
+
+ if data_format == 'channels_first':
+ output_shape = (None, None, new_height, new_width)
+ else:
+ output_shape = (None, new_height, new_width, None)
+ x.set_shape(output_shape)
+ return x
@tf_export('keras.backend.resize_volumes')