aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/layers/advanced_activations.py
blob: eba10da6f3ce1367f4cb0180d16efdc5913fcddc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Layers that act as activation functions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras import activations
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export('keras.layers.LeakyReLU')
class LeakyReLU(Layer):
  """Rectified Linear Unit with a non-zero slope for negative inputs.

  The activation is piecewise linear, so a small gradient is still allowed
  when the unit is not active:
  `f(x) = alpha * x for x < 0`,
  `f(x) = x for x >= 0`.

  Input shape:
      Arbitrary. Use the keyword argument `input_shape`
      (tuple of integers, does not include the samples axis)
      when using this layer as the first layer in a model.

  Output shape:
      Same shape as the input.

  Arguments:
      alpha: float >= 0. Negative slope coefficient.

  """

  def __init__(self, alpha=0.3, **kwargs):
    """Stores `alpha` (cast to the backend float type) on the layer."""
    super(LeakyReLU, self).__init__(**kwargs)
    self.alpha = K.cast_to_floatx(alpha)
    self.supports_masking = True

  def call(self, inputs):
    """Returns `K.relu` of `inputs` using `self.alpha` as the `alpha` arg."""
    return K.relu(inputs, alpha=self.alpha)

  def get_config(self):
    """Returns the base layer config extended with `alpha` as a float."""
    merged = dict(super(LeakyReLU, self).get_config())
    merged['alpha'] = float(self.alpha)
    return merged

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Returns `input_shape` unchanged; the activation is element-wise."""
    return input_shape


@tf_export('keras.layers.PReLU')
class PReLU(Layer):
  """Parametric Rectified Linear Unit.

  It follows:
  `f(x) = alpha * x for x < 0`,
  `f(x) = x for x >= 0`,
  where `alpha` is a learned array with the same shape as x (without the
  samples axis, and with size 1 along every axis listed in `shared_axes`).

  Input shape:
      Arbitrary. Use the keyword argument `input_shape`
      (tuple of integers, does not include the samples axis)
      when using this layer as the first layer in a model.

  Output shape:
      Same shape as the input.

  Arguments:
      alpha_initializer: initializer function for the weights.
      alpha_regularizer: regularizer for the weights.
      alpha_constraint: constraint for the weights.
      shared_axes: the axes along which to share learnable
          parameters for the activation function.
          A single axis may be given as a bare value instead of a list.
          For example, if the incoming feature maps
          are from a 2D convolution
          with output shape `(batch, height, width, channels)`,
          and you wish to share parameters across space
          so that each filter only has one set of parameters,
          set `shared_axes=[1, 2]`.

  """

  def __init__(self,
               alpha_initializer='zeros',
               alpha_regularizer=None,
               alpha_constraint=None,
               shared_axes=None,
               **kwargs):
    super(PReLU, self).__init__(**kwargs)
    self.supports_masking = True
    self.alpha_initializer = initializers.get(alpha_initializer)
    self.alpha_regularizer = regularizers.get(alpha_regularizer)
    self.alpha_constraint = constraints.get(alpha_constraint)
    # Normalize `shared_axes` to either None or a list of axes.
    if shared_axes is None:
      self.shared_axes = None
    elif not isinstance(shared_axes, (list, tuple)):
      self.shared_axes = [shared_axes]
    else:
      self.shared_axes = list(shared_axes)

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    """Creates the `alpha` weight and the layer's input spec.

    Arguments:
        input_shape: shape of the input, including the samples axis.
    """
    # `alpha` covers every axis except the leading (samples) one.
    param_shape = list(input_shape[1:])
    # Kept as an attribute for backward compatibility; `call` does not use it.
    self.param_broadcast = [False] * len(param_shape)
    if self.shared_axes is not None:
      for i in self.shared_axes:
        # Axes are numbered on the full input, hence the `- 1` offset when
        # indexing the parameter shape (which lacks the samples axis).
        param_shape[i - 1] = 1
        self.param_broadcast[i - 1] = True
    self.alpha = self.add_weight(
        shape=param_shape,
        name='alpha',
        initializer=self.alpha_initializer,
        regularizer=self.alpha_regularizer,
        constraint=self.alpha_constraint)
    # Set input spec: when axes are shared, pin the size of every
    # non-shared, non-samples axis to the size seen at build time.
    axes = {}
    if self.shared_axes:
      for i in range(1, len(input_shape)):
        if i not in self.shared_axes:
          axes[i] = input_shape[i]
    self.input_spec = InputSpec(ndim=len(input_shape), axes=axes)
    self.built = True

  def call(self, inputs, mask=None):
    """Applies the activation: `relu(x) - alpha * relu(-x)`.

    Arguments:
        inputs: input tensor.
        mask: unused; accepted for signature compatibility.

    Returns:
        A tensor computed element-wise from `inputs` and `self.alpha`.
    """
    # The former Theano-specific branch (which relied on
    # `K.pattern_broadcast`) was unreachable with this backend and has been
    # dropped; this is the expression the TensorFlow path always used.
    pos = K.relu(inputs)
    neg = -self.alpha * K.relu(-inputs)
    return pos + neg

  def get_config(self):
    """Returns the base layer config plus the serialized PReLU arguments."""
    config = {
        'alpha_initializer': initializers.serialize(self.alpha_initializer),
        'alpha_regularizer': regularizers.serialize(self.alpha_regularizer),
        'alpha_constraint': constraints.serialize(self.alpha_constraint),
        'shared_axes': self.shared_axes
    }
    base_config = super(PReLU, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Returns `input_shape` unchanged."""
    return input_shape


@tf_export('keras.layers.ELU')
class ELU(Layer):
  """Exponential Linear Unit.

  The activation is defined piecewise:
  `f(x) =  alpha * (exp(x) - 1.) for x < 0`,
  `f(x) = x for x >= 0`.

  Input shape:
      Arbitrary. Use the keyword argument `input_shape`
      (tuple of integers, does not include the samples axis)
      when using this layer as the first layer in a model.

  Output shape:
      Same shape as the input.

  Arguments:
      alpha: scale for the negative factor.

  """

  def __init__(self, alpha=1.0, **kwargs):
    """Stores `alpha` (cast to the backend float type) on the layer."""
    super(ELU, self).__init__(**kwargs)
    self.alpha = K.cast_to_floatx(alpha)
    self.supports_masking = True

  def call(self, inputs):
    """Returns `K.elu` of `inputs` with `self.alpha` as second argument."""
    return K.elu(inputs, self.alpha)

  def get_config(self):
    """Returns the base layer config extended with `alpha` as a float."""
    merged = dict(super(ELU, self).get_config())
    merged['alpha'] = float(self.alpha)
    return merged

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Returns `input_shape` unchanged; the activation is element-wise."""
    return input_shape


@tf_export('keras.layers.ThresholdedReLU')
class ThresholdedReLU(Layer):
  """Thresholded Rectified Linear Unit.

  Values strictly above the threshold pass through; everything else is
  zeroed:
  `f(x) = x for x > theta`,
  `f(x) = 0 otherwise`.

  Input shape:
      Arbitrary. Use the keyword argument `input_shape`
      (tuple of integers, does not include the samples axis)
      when using this layer as the first layer in a model.

  Output shape:
      Same shape as the input.

  Arguments:
      theta: float >= 0. Threshold location of activation.

  """

  def __init__(self, theta=1.0, **kwargs):
    """Stores `theta` (cast to the backend float type) on the layer."""
    super(ThresholdedReLU, self).__init__(**kwargs)
    self.theta = K.cast_to_floatx(theta)
    self.supports_masking = True

  def call(self, inputs, mask=None):
    """Multiplies `inputs` by a 0/1 indicator of `inputs > theta`.

    Arguments:
        inputs: input tensor.
        mask: unused; accepted for signature compatibility.
    """
    above_threshold = math_ops.greater(inputs, self.theta)
    # The boolean comparison is cast to the backend float type so it can be
    # used as a multiplicative gate.
    gate = math_ops.cast(above_threshold, K.floatx())
    return inputs * gate

  def get_config(self):
    """Returns the base layer config extended with `theta` as a float."""
    merged = dict(super(ThresholdedReLU, self).get_config())
    merged['theta'] = float(self.theta)
    return merged

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Returns `input_shape` unchanged; the activation is element-wise."""
    return input_shape


@tf_export('keras.layers.Softmax')
class Softmax(Layer):
  """Softmax activation function, exposed as a layer.

  Input shape:
      Arbitrary. Use the keyword argument `input_shape`
      (tuple of integers, does not include the samples axis)
      when using this layer as the first layer in a model.

  Output shape:
      Same shape as the input.

  Arguments:
      axis: Integer, axis along which the softmax normalization is applied.
  """

  def __init__(self, axis=-1, **kwargs):
    """Stores the normalization `axis` on the layer."""
    super(Softmax, self).__init__(**kwargs)
    self.axis = axis
    self.supports_masking = True

  def call(self, inputs):
    """Returns `activations.softmax` of `inputs` along `self.axis`."""
    return activations.softmax(inputs, axis=self.axis)

  def get_config(self):
    """Returns the base layer config extended with `axis`."""
    merged = dict(super(Softmax, self).get_config())
    merged['axis'] = self.axis
    return merged

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Returns `input_shape` unchanged."""
    return input_shape


@tf_export('keras.layers.ReLU')
class ReLU(Layer):
  """Rectified Linear Unit activation function.

  Delegates to `activations.relu`, forwarding `max_value` as the optional
  upper bound on the activation.

  Input shape:
      Arbitrary. Use the keyword argument `input_shape`
      (tuple of integers, does not include the samples axis)
      when using this layer as the first layer in a model.

  Output shape:
      Same shape as the input.

  Arguments:
      max_value: float >= 0, or `None`. Maximum activation value.
          `None` (the default) is forwarded unchanged to `activations.relu`.

  Raises:
      ValueError: if `max_value` is negative.
  """

  def __init__(self, max_value=None, **kwargs):
    super(ReLU, self).__init__(**kwargs)
    # Spelled `supports_masking`, matching the other layers in this module;
    # the earlier `support_masking` spelling set an unrelated attribute.
    self.supports_masking = True
    if max_value is None:
      # Keep None as None: casting it to a float array would no longer read
      # as "no maximum" (NOTE(review): NumPy is expected to turn None into
      # NaN here, which also slips past the negativity check below).
      self.max_value = None
    else:
      self.max_value = K.cast_to_floatx(max_value)
      if self.max_value < 0.:
        raise ValueError('max_value of Relu layer '
                         'cannot be negative value: ' + str(max_value))

  def call(self, inputs):
    """Returns `activations.relu` of `inputs`, bounded by `self.max_value`."""
    return activations.relu(inputs, max_value=self.max_value)

  def get_config(self):
    """Returns the base layer config extended with `max_value`.

    `max_value` is emitted as a plain Python float (or `None`), the same way
    the other layers in this module emit their scalar arguments.
    """
    max_value = self.max_value
    config = {'max_value': None if max_value is None else float(max_value)}
    base_config = super(ReLU, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Returns `input_shape` unchanged; the activation is element-wise."""
    return input_shape