diff options
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r-- | tensorflow/python/keras/backend.py | 64 |
1 files changed, 40 insertions, 24 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 13f52fbae7..7509ef9c59 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -2338,7 +2338,8 @@ def permute_dimensions(x, pattern): @tf_export('keras.backend.resize_images') -def resize_images(x, height_factor, width_factor, data_format): +def resize_images(x, height_factor, width_factor, data_format, + interpolation='nearest'): """Resizes the images contained in a 4D tensor. Arguments: @@ -2346,40 +2347,55 @@ def resize_images(x, height_factor, width_factor, data_format): height_factor: Positive integer. width_factor: Positive integer. data_format: One of `"channels_first"`, `"channels_last"`. + interpolation: A string, one of `nearest` or `bilinear`. Returns: A tensor. Raises: - ValueError: if `data_format` is neither - `channels_last` or `channels_first`. + ValueError: in case of incorrect value for + `data_format` or `interpolation`. """ if data_format == 'channels_first': - original_shape = int_shape(x) - new_shape = array_ops.shape(x)[2:] - new_shape *= constant_op.constant( - np.array([height_factor, width_factor]).astype('int32')) + rows, cols = 2, 3 + elif data_format == 'channels_last': + rows, cols = 1, 2 + else: + raise ValueError('Invalid `data_format` argument: %s' % (data_format,)) + + original_shape = int_shape(x) + new_shape = array_ops.shape(x)[rows:cols + 1] + new_shape *= constant_op.constant( + np.array([height_factor, width_factor], dtype='int32')) + + if data_format == 'channels_first': x = permute_dimensions(x, [0, 2, 3, 1]) + if interpolation == 'nearest': x = image_ops.resize_nearest_neighbor(x, new_shape) + elif interpolation == 'bilinear': + x = image_ops.resize_bilinear(x, new_shape) + else: + raise ValueError('interpolation should be one ' + 'of "nearest" or "bilinear".') + if data_format == 'channels_first': x = permute_dimensions(x, [0, 3, 1, 2]) - x.set_shape((None, None, original_shape[2] * height_factor - if original_shape[2] is not None else None, - original_shape[3] * width_factor - if original_shape[3] is not None else None)) - return x - elif data_format == 'channels_last': - original_shape = int_shape(x) - new_shape = array_ops.shape(x)[1:3] - new_shape *= constant_op.constant( - np.array([height_factor, width_factor]).astype('int32')) - x = image_ops.resize_nearest_neighbor(x, new_shape) - x.set_shape((None, original_shape[1] * height_factor - if original_shape[1] is not None else None, - original_shape[2] * width_factor - if original_shape[2] is not None else None, None)) - return x + + if original_shape[rows] is None: + new_height = None else: - raise ValueError('Invalid data_format: ' + str(data_format)) + new_height = original_shape[rows] * height_factor + + if original_shape[cols] is None: + new_width = None + else: + new_width = original_shape[cols] * width_factor + + if data_format == 'channels_first': + output_shape = (None, None, new_height, new_width) + else: + output_shape = (None, new_height, new_width, None) + x.set_shape(output_shape) + return x @tf_export('keras.backend.resize_volumes') |