diff options
Diffstat (limited to 'tensorflow/python/keras/layers/convolutional.py')
-rw-r--r-- | tensorflow/python/keras/layers/convolutional.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index 8f5872385c..58024677ee 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -1951,6 +1951,7 @@ class UpSampling2D(Layer): It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". + interpolation: A string, one of `nearest` or `bilinear`. Input shape: 4D tensor with shape: @@ -1967,10 +1968,18 @@ class UpSampling2D(Layer): `(batch, channels, upsampled_rows, upsampled_cols)` """ - def __init__(self, size=(2, 2), data_format=None, **kwargs): + def __init__(self, + size=(2, 2), + data_format=None, + interpolation='nearest', + **kwargs): super(UpSampling2D, self).__init__(**kwargs) self.data_format = conv_utils.normalize_data_format(data_format) self.size = conv_utils.normalize_tuple(size, 2, 'size') + if interpolation not in {'nearest', 'bilinear'}: + raise ValueError('`interpolation` argument should be one of `"nearest"` ' + 'or `"bilinear"`.') + self.interpolation = interpolation self.input_spec = InputSpec(ndim=4) def compute_output_shape(self, input_shape): @@ -1992,7 +2001,8 @@ class UpSampling2D(Layer): def call(self, inputs): return backend.resize_images( - inputs, self.size[0], self.size[1], self.data_format) + inputs, self.size[0], self.size[1], self.data_format, + interpolation=self.interpolation) def get_config(self): config = {'size': self.size, 'data_format': self.data_format} |