From c98ffffcb4e0cc668c0ff7b73d51677a7eb7dcf4 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Tue, 9 Oct 2018 16:19:46 -0700 Subject: Part 2/3 of the update of tf.keras to the Keras 2.2.4 API. PiperOrigin-RevId: 216442569 --- tensorflow/python/keras/layers/convolutional.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) (limited to 'tensorflow/python/keras/layers/convolutional.py') diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index 8f5872385c..58024677ee 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -1951,6 +1951,7 @@ class UpSampling2D(Layer): It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". + interpolation: A string, one of `nearest` or `bilinear`. Input shape: 4D tensor with shape: @@ -1967,10 +1968,18 @@ class UpSampling2D(Layer): `(batch, channels, upsampled_rows, upsampled_cols)` """ - def __init__(self, size=(2, 2), data_format=None, **kwargs): + def __init__(self, + size=(2, 2), + data_format=None, + interpolation='nearest', + **kwargs): super(UpSampling2D, self).__init__(**kwargs) self.data_format = conv_utils.normalize_data_format(data_format) self.size = conv_utils.normalize_tuple(size, 2, 'size') + if interpolation not in {'nearest', 'bilinear'}: + raise ValueError('`interpolation` argument should be one of `"nearest"` ' + 'or `"bilinear"`.') + self.interpolation = interpolation self.input_spec = InputSpec(ndim=4) def compute_output_shape(self, input_shape): @@ -1992,7 +2001,8 @@ class UpSampling2D(Layer): def call(self, inputs): return backend.resize_images( - inputs, self.size[0], self.size[1], self.data_format) + inputs, self.size[0], self.size[1], self.data_format, + interpolation=self.interpolation) def get_config(self): config = {'size': self.size, 'data_format': self.data_format} -- cgit v1.2.3