path: root/tensorflow/contrib/layers/python/layers/layers.py
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers.py')
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers.py  221
1 file changed, 220 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index f2a904b521..ed4b723ca7 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -49,15 +49,20 @@ from tensorflow.python.training import moving_averages
# TODO(b/28426988): Replace legacy_* fns migrated from slim.
# TODO(b/28426988): Remove legacy_* when all uses have migrated to new API.
__all__ = ['avg_pool2d',
+          'avg_pool3d',
           'batch_norm',
           'bias_add',
           'conv2d',
+          'conv3d',
           'conv2d_in_plane',
           'conv2d_transpose',
+          'conv3d_transpose',
           'convolution',
           'convolution2d',
           'convolution2d_in_plane',
           'convolution2d_transpose',
+          'convolution3d',
+          'convolution3d_transpose',
           'dropout',
           'elu',
           'flatten',
@@ -66,6 +71,7 @@ __all__ = ['avg_pool2d',
           'linear',
           'pool',
           'max_pool2d',
+          'max_pool3d',
           'one_hot_encoding',
           'relu',
           'relu6',
@@ -82,6 +88,8 @@ __all__ = ['avg_pool2d',
DATA_FORMAT_NCHW = 'NCHW'
DATA_FORMAT_NHWC = 'NHWC'
+DATA_FORMAT_NCDHW = 'NCDHW'
+DATA_FORMAT_NDHWC = 'NDHWC'
@add_arg_scope
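
For reference, the two new constants name the two supported memory layouts for
5-D inputs. As a concrete (hypothetical) example, a batch of four 8-frame RGB
clips at 32x32 resolution would be shaped as follows under each layout:

    # NDHWC: [batch, depth, height, width, channels] (channels-last, default)
    ndhwc_shape = [4, 8, 32, 32, 3]
    # NCDHW: [batch, channels, depth, height, width] (channels-first)
    ncdhw_shape = [4, 3, 8, 32, 32]
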
@@ -132,6 +140,54 @@ def avg_pool2d(inputs,
    return utils.collect_named_outputs(outputs_collections, sc, outputs)
+@add_arg_scope
+def avg_pool3d(inputs,
+               kernel_size,
+               stride=2,
+               padding='VALID',
+               data_format=DATA_FORMAT_NDHWC,
+               outputs_collections=None,
+               scope=None):
+  """Adds a 3D average pooling op.
+
+  The pooling is applied to the depth, height and width dimensions only; the
+  batch and channel dimensions are left untouched.
+
+  Args:
+    inputs: A 5-D tensor of shape
+      `[batch_size, depth, height, width, channels]` if `data_format` is
+      `NDHWC`, and `[batch_size, channels, depth, height, width]` if
+      `data_format` is `NCDHW`.
+    kernel_size: A list of length 3: [kernel_depth, kernel_height,
+      kernel_width] of the pooling kernel over which the op is computed. Can
+      be an int if all three values are the same.
+    stride: A list of length 3: [stride_depth, stride_height, stride_width].
+      Can be an int if all three strides are the same. Note that presently
+      all three strides must have the same value.
+    padding: The padding method, either 'VALID' or 'SAME'.
+    data_format: A string. `NDHWC` (default) and `NCDHW` are supported.
+    outputs_collections: The collections to which the outputs are added.
+    scope: Optional scope for name_scope.
+
+  Returns:
+    A `Tensor` representing the results of the pooling operation.
+
+  Raises:
+    ValueError: If `data_format` is neither `NDHWC` nor `NCDHW`.
+  """
+  if data_format not in (DATA_FORMAT_NCDHW, DATA_FORMAT_NDHWC):
+    raise ValueError('data_format has to be either NCDHW or NDHWC.')
+  with ops.name_scope(scope, 'AvgPool3D', [inputs]) as sc:
+    inputs = ops.convert_to_tensor(inputs)
+    df = ('channels_first' if data_format and data_format.startswith('NC')
+          else 'channels_last')
+    layer = pooling_layers.AveragePooling3D(pool_size=kernel_size,
+                                            strides=stride,
+                                            padding=padding,
+                                            data_format=df,
+                                            _scope=sc)
+    outputs = layer.apply(inputs)
+    return utils.collect_named_outputs(outputs_collections, sc, outputs)
+
+
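
A minimal usage sketch for the new op (assumes a TF 1.x build that includes
this commit, so that `tf.contrib.layers.avg_pool3d` is available; the shapes
are illustrative):

    import tensorflow as tf

    # A batch of 16-frame, 64x64, 3-channel volumes in the default NDHWC layout.
    volumes = tf.placeholder(tf.float32, [None, 16, 64, 64, 3])
    # An int kernel_size applies to all three spatial dimensions; with the
    # default stride of 2 and 'VALID' padding, each spatial dimension is halved.
    pooled = tf.contrib.layers.avg_pool3d(volumes, kernel_size=2)
    print(pooled.get_shape())  # (?, 8, 32, 32, 3)
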
def _fused_batch_norm(
    inputs,
    decay=0.999,
@@ -985,6 +1041,7 @@ def convolution(inputs,
                                       sc.original_name_scope, outputs)
convolution2d = convolution
+convolution3d = convolution
@add_arg_scope
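
Because `convolution` infers the spatial rank from the rank of its input, the
new `convolution3d` alias needs no separate implementation. A usage sketch
(same TF 1.x assumption as above; shapes are illustrative):

    import tensorflow as tf

    volumes = tf.placeholder(tf.float32, [None, 16, 64, 64, 1])
    # 32 filters of size 3x3x3; the defaults are stride 1 and 'SAME' padding,
    # so the spatial dimensions are preserved.
    features = tf.contrib.layers.convolution3d(volumes, num_outputs=32,
                                               kernel_size=3)
    print(features.get_shape())  # (?, 16, 64, 64, 32)
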
@@ -1204,6 +1261,116 @@ def convolution2d_transpose(
@add_arg_scope
+def convolution3d_transpose(
+    inputs,
+    num_outputs,
+    kernel_size,
+    stride=1,
+    padding='SAME',
+    data_format=DATA_FORMAT_NDHWC,
+    activation_fn=nn.relu,
+    normalizer_fn=None,
+    normalizer_params=None,
+    weights_initializer=initializers.xavier_initializer(),
+    weights_regularizer=None,
+    biases_initializer=init_ops.zeros_initializer(),
+    biases_regularizer=None,
+    reuse=None,
+    variables_collections=None,
+    outputs_collections=None,
+    trainable=True,
+    scope=None):
+ """Adds a convolution3d_transpose with an optional batch normalization layer.
+
+ The function creates a variable called `weights`, representing the
+ kernel, that is convolved with the input. If `batch_norm_params` is `None`, a
+ second variable called 'biases' is added to the result of the operation.
+ Args:
+ inputs: A 5-D `Tensor` of type `float` and shape
+ `[batch, depth, height, width, in_channels]` for `NDHWC` data format or
+ `[batch, in_channels, depth, height, width]` for `NCDHW` data format.
+ num_outputs: Integer, the number of output filters.
+ kernel_size: A list of length 3 holding the [kernel_depth, kernel_height, kernel_width] of
+ of the filters. Can be an int if both values are the same.
+ stride: A list of length 3: [stride_depth, stride_height, stride_width].
+ Can be an int if both strides are the same. Note that presently
+ both strides must have the same value.
+ padding: One of 'VALID' or 'SAME'.
+ data_format: A string. `NDHWC` (default) and `NCDHW` are supported.
+ activation_fn: Activation function. The default value is a ReLU function.
+ Explicitly set it to None to skip it and maintain a linear activation.
+ normalizer_fn: Normalization function to use instead of `biases`. If
+ `normalizer_fn` is provided then `biases_initializer` and
+ `biases_regularizer` are ignored and `biases` are not created nor added.
+ default set to None for no normalizer function
+ normalizer_params: Normalization function parameters.
+ weights_initializer: An initializer for the weights.
+ weights_regularizer: Optional regularizer for the weights.
+ biases_initializer: An initializer for the biases. If None skip biases.
+ biases_regularizer: Optional regularizer for the biases.
+ reuse: Whether or not the layer and its variables should be reused. To be
+ able to reuse the layer scope must be given.
+ variables_collections: Optional list of collections for all the variables or
+ a dictionary containing a different list of collection per variable.
+ outputs_collections: Collection to add the outputs.
+ trainable: Whether or not the variables should be trainable or not.
+ scope: Optional scope for variable_scope.
+ Returns:
+ A tensor representing the output of the operation.
+ Raises:
+ ValueError: If 'kernel_size' is not a list of length 3.
+ ValueError: If `data_format` is neither `NDHWC` nor `NCDHW`.
+ ValueError: If `C` dimension of `inputs` is None.
+ """
+  layer_variable_getter = _build_variable_getter(
+      {'bias': 'biases', 'kernel': 'weights'})
+
+  with variable_scope.variable_scope(
+      scope, 'Conv3d_transpose', [inputs], reuse=reuse,
+      custom_getter=layer_variable_getter) as sc:
+    if data_format not in (DATA_FORMAT_NCDHW, DATA_FORMAT_NDHWC):
+      raise ValueError('data_format has to be either NCDHW or NDHWC.')
+
+    inputs = ops.convert_to_tensor(inputs)
+
+    df = ('channels_first' if data_format and data_format.startswith('NC')
+          else 'channels_last')
+    layer = convolutional_layers.Convolution3DTranspose(
+        filters=num_outputs,
+        kernel_size=kernel_size,
+        strides=stride,
+        padding=padding,
+        data_format=df,
+        activation=None,
+        use_bias=not normalizer_fn and biases_initializer,
+        kernel_initializer=weights_initializer,
+        bias_initializer=biases_initializer,
+        kernel_regularizer=weights_regularizer,
+        bias_regularizer=biases_regularizer,
+        activity_regularizer=None,
+        trainable=trainable,
+        name=sc.name,
+        dtype=inputs.dtype.base_dtype,
+        _scope=sc,
+        _reuse=reuse)
+    outputs = layer.apply(inputs)
+
+    # Add variables to collections.
+    _add_variable_to_collections(layer.kernel, variables_collections, 'weights')
+    if layer.bias is not None:
+      _add_variable_to_collections(layer.bias, variables_collections, 'biases')
+
+    if normalizer_fn is not None:
+      normalizer_params = normalizer_params or {}
+      outputs = normalizer_fn(outputs, **normalizer_params)
+
+    if activation_fn is not None:
+      outputs = activation_fn(outputs)
+    return utils.collect_named_outputs(outputs_collections,
+                                       sc.original_name_scope, outputs)
+
+
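
A usage sketch for the transpose op (again assuming this commit's
`tf.contrib.layers`; `conv3d_transpose` is the alias defined at the bottom of
this diff). With 'SAME' padding, stride 2 doubles each spatial dimension, the
usual building block of a volumetric decoder:

    import tensorflow as tf

    codes = tf.placeholder(tf.float32, [None, 4, 8, 8, 64])
    upsampled = tf.contrib.layers.conv3d_transpose(
        codes, num_outputs=32, kernel_size=3, stride=2)
    print(upsampled.get_shape())  # (?, 8, 16, 16, 32)
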
+@add_arg_scope
def dropout(inputs,
            keep_prob=0.5,
            noise_shape=None,
@@ -1467,7 +1634,8 @@ def fully_connected(inputs,
    ValueError: If x has rank less than 2 or if its last dimension is not set.
  """
  if not isinstance(num_outputs, six.integer_types):
-    raise ValueError('num_outputs should be int or long, got %s.', num_outputs)
+    raise ValueError(
+        'num_outputs should be int or long, got %s.' % (num_outputs,))
  layer_variable_getter = _build_variable_getter({'bias': 'biases',
                                                  'kernel': 'weights'})
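
For context, a stand-alone illustration of the bug this hunk fixes: unlike the
logging functions, `ValueError` does not %-interpolate extra arguments, so the
old comma form stored a tuple instead of a formatted message:

    err = ValueError('num_outputs should be int or long, got %s.', 3.5)
    print(err)  # ('num_outputs should be int or long, got %s.', 3.5)
    err = ValueError('num_outputs should be int or long, got %s.' % (3.5,))
    print(err)  # num_outputs should be int or long, got 3.5.
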
@@ -1690,6 +1858,55 @@ def max_pool2d(inputs,
@add_arg_scope
+def max_pool3d(inputs,
+               kernel_size,
+               stride=2,
+               padding='VALID',
+               data_format=DATA_FORMAT_NDHWC,
+               outputs_collections=None,
+               scope=None):
+  """Adds a 3D Max Pooling op.
+
+  The pooling is applied to the depth, height and width dimensions only; the
+  batch and channel dimensions are left untouched.
+
+  Args:
+    inputs: A 5-D tensor of shape
+      `[batch_size, depth, height, width, channels]` if `data_format` is
+      `NDHWC`, and `[batch_size, channels, depth, height, width]` if
+      `data_format` is `NCDHW`.
+    kernel_size: A list of length 3: [kernel_depth, kernel_height,
+      kernel_width] of the pooling kernel over which the op is computed. Can
+      be an int if all three values are the same.
+    stride: A list of length 3: [stride_depth, stride_height, stride_width].
+      Can be an int if all three strides are the same. Note that presently
+      all three strides must have the same value.
+    padding: The padding method, either 'VALID' or 'SAME'.
+    data_format: A string. `NDHWC` (default) and `NCDHW` are supported.
+    outputs_collections: The collections to which the outputs are added.
+    scope: Optional scope for name_scope.
+
+  Returns:
+    A `Tensor` representing the results of the pooling operation.
+
+  Raises:
+    ValueError: If `data_format` is neither `NDHWC` nor `NCDHW`.
+    ValueError: If `kernel_size` is not a list of length 3.
+  """
+  if data_format not in (DATA_FORMAT_NCDHW, DATA_FORMAT_NDHWC):
+    raise ValueError('data_format has to be either NCDHW or NDHWC.')
+  with ops.name_scope(scope, 'MaxPool3D', [inputs]) as sc:
+    inputs = ops.convert_to_tensor(inputs)
+    df = ('channels_first' if data_format and data_format.startswith('NC')
+          else 'channels_last')
+    layer = pooling_layers.MaxPooling3D(pool_size=kernel_size,
+                                        strides=stride,
+                                        padding=padding,
+                                        data_format=df,
+                                        _scope=sc)
+    outputs = layer.apply(inputs)
+    return utils.collect_named_outputs(outputs_collections, sc, outputs)
+
+
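
A sketch of the channels-first path (same TF 1.x assumption; note that, as
with the 2D ops, NCDHW layouts typically require a GPU kernel at run time,
although shape inference works everywhere):

    import tensorflow as tf

    # NCDHW layout: [batch, channels, depth, height, width].
    volumes = tf.placeholder(tf.float32, [None, 3, 16, 64, 64])
    pooled = tf.contrib.layers.max_pool3d(volumes, kernel_size=[2, 2, 2],
                                          data_format='NCDHW')
    print(pooled.get_shape())  # (?, 3, 8, 32, 32)
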
+@add_arg_scope
def pool(inputs,
         kernel_size,
         pooling_type,
@@ -2346,6 +2563,8 @@ linear = functools.partial(fully_connected, activation_fn=None)
# Simple alias.
conv2d = convolution2d
+conv3d = convolution3d
conv2d_transpose = convolution2d_transpose
+conv3d_transpose = convolution3d_transpose
conv2d_in_plane = convolution2d_in_plane
separable_conv2d = separable_convolution2d