aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/layers/convolutional.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/layers/convolutional.py')
-rw-r--r--tensorflow/python/keras/layers/convolutional.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index 8f5872385c..58024677ee 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -1951,6 +1951,7 @@ class UpSampling2D(Layer):
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
+ interpolation: A string, one of `nearest` or `bilinear`.
Input shape:
4D tensor with shape:
@@ -1967,10 +1968,18 @@ class UpSampling2D(Layer):
`(batch, channels, upsampled_rows, upsampled_cols)`
"""
- def __init__(self, size=(2, 2), data_format=None, **kwargs):
+ def __init__(self,
+ size=(2, 2),
+ data_format=None,
+ interpolation='nearest',
+ **kwargs):
super(UpSampling2D, self).__init__(**kwargs)
self.data_format = conv_utils.normalize_data_format(data_format)
self.size = conv_utils.normalize_tuple(size, 2, 'size')
+ if interpolation not in {'nearest', 'bilinear'}:
+ raise ValueError('`interpolation` argument should be one of `"nearest"` '
+ 'or `"bilinear"`.')
+ self.interpolation = interpolation
self.input_spec = InputSpec(ndim=4)
def compute_output_shape(self, input_shape):
@@ -1992,7 +2001,8 @@ class UpSampling2D(Layer):
def call(self, inputs):
return backend.resize_images(
- inputs, self.size[0], self.size[1], self.data_format)
+ inputs, self.size[0], self.size[1], self.data_format,
+ interpolation=self.interpolation)
def get_config(self):
config = {'size': self.size, 'data_format': self.data_format}