diff options
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/layers.py')
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers.py | 221 |
1 file changed, 220 insertions, 1 deletion
@add_arg_scope
def avg_pool3d(inputs,
               kernel_size,
               stride=2,
               padding='VALID',
               data_format=DATA_FORMAT_NDHWC,
               outputs_collections=None,
               scope=None):
  """Adds a 3D average pooling op.

  It is assumed that the pooling is done per image but not in batch or
  channels.

  Args:
    inputs: A 5-D tensor of shape `[batch_size, depth, height, width,
      channels]` if `data_format` is `NDHWC`, and `[batch_size, channels,
      depth, height, width]` if `data_format` is `NCDHW`.
    kernel_size: A list of length 3: [kernel_depth, kernel_height,
      kernel_width] of the pooling kernel over which the op is computed. Can
      be an int if all three values are the same.
    stride: A list of length 3: [stride_depth, stride_height, stride_width].
      Can be an int if all three strides are the same.
    padding: The padding method, either 'VALID' or 'SAME'.
    data_format: A string. `NDHWC` (default) and `NCDHW` are supported.
    outputs_collections: The collections to which the outputs are added.
    scope: Optional scope for name_scope.

  Returns:
    A `Tensor` representing the results of the pooling operation.

  Raises:
    ValueError: If `data_format` is neither `NDHWC` nor `NCDHW`.
  """
  if data_format not in (DATA_FORMAT_NCDHW, DATA_FORMAT_NDHWC):
    raise ValueError('data_format has to be either NCDHW or NDHWC.')
  with ops.name_scope(scope, 'AvgPool3D', [inputs]) as sc:
    inputs = ops.convert_to_tensor(inputs)
    # Map the NDHWC/NCDHW op-level format to the core layers' naming.
    df = ('channels_first' if data_format and data_format.startswith('NC')
          else 'channels_last')
    layer = pooling_layers.AveragePooling3D(pool_size=kernel_size,
                                            strides=stride,
                                            padding=padding,
                                            data_format=df,
                                            _scope=sc)
    outputs = layer.apply(inputs)
    return utils.collect_named_outputs(outputs_collections, sc, outputs)
@add_arg_scope
def convolution3d_transpose(
    inputs,
    num_outputs,
    kernel_size,
    stride=1,
    padding='SAME',
    data_format=DATA_FORMAT_NDHWC,
    activation_fn=nn.relu,
    normalizer_fn=None,
    normalizer_params=None,
    weights_initializer=initializers.xavier_initializer(),
    weights_regularizer=None,
    biases_initializer=init_ops.zeros_initializer(),
    biases_regularizer=None,
    reuse=None,
    variables_collections=None,
    outputs_collections=None,
    trainable=True,
    scope=None):
  """Adds a convolution3d_transpose with an optional batch normalization layer.

  The function creates a variable called `weights`, representing the
  kernel, that is convolved with the input. If `normalizer_fn` is `None`, a
  second variable called 'biases' is added to the result of the operation.

  Args:
    inputs: A 5-D `Tensor` of type `float` and shape
      `[batch, depth, height, width, in_channels]` for `NDHWC` data format or
      `[batch, in_channels, depth, height, width]` for `NCDHW` data format.
    num_outputs: Integer, the number of output filters.
    kernel_size: A list of length 3 holding the [kernel_depth, kernel_height,
      kernel_width] of the filters. Can be an int if all three values are the
      same.
    stride: A list of length 3: [stride_depth, stride_height, stride_width].
      Can be an int if all three strides are the same.
    padding: One of 'VALID' or 'SAME'.
    data_format: A string. `NDHWC` (default) and `NCDHW` are supported.
    activation_fn: Activation function. The default value is a ReLU function.
      Explicitly set it to None to skip it and maintain a linear activation.
    normalizer_fn: Normalization function to use instead of `biases`. If
      `normalizer_fn` is provided then `biases_initializer` and
      `biases_regularizer` are ignored and `biases` are not created nor added.
      Default set to None for no normalizer function.
    normalizer_params: Normalization function parameters.
    weights_initializer: An initializer for the weights.
    weights_regularizer: Optional regularizer for the weights.
    biases_initializer: An initializer for the biases. If None skip biases.
    biases_regularizer: Optional regularizer for the biases.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional list of collections for all the variables
      or a dictionary containing a different list of collection per variable.
    outputs_collections: Collection to add the outputs.
    trainable: Whether or not the variables should be trainable or not.
    scope: Optional scope for variable_scope.

  Returns:
    A tensor representing the output of the operation.

  Raises:
    ValueError: If `kernel_size` is not a list of length 3.
    ValueError: If `data_format` is neither `NDHWC` nor `NCDHW`.
    ValueError: If `C` dimension of `inputs` is None.
  """
  # Route the core layer's 'kernel'/'bias' variable names to the
  # contrib-style 'weights'/'biases' names.
  layer_variable_getter = _build_variable_getter(
      {'bias': 'biases', 'kernel': 'weights'})

  with variable_scope.variable_scope(
      scope, 'Conv3d_transpose', [inputs], reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    if data_format not in (DATA_FORMAT_NCDHW, DATA_FORMAT_NDHWC):
      raise ValueError('data_format has to be either NCDHW or NDHWC.')

    inputs = ops.convert_to_tensor(inputs)

    # Map the NDHWC/NCDHW op-level format to the core layers' naming.
    df = ('channels_first' if data_format and data_format.startswith('NC')
          else 'channels_last')
    layer = convolutional_layers.Convolution3DTranspose(
        filters=num_outputs,
        kernel_size=kernel_size,
        strides=stride,
        padding=padding,
        data_format=df,
        activation=None,
        use_bias=not normalizer_fn and biases_initializer,
        kernel_initializer=weights_initializer,
        bias_initializer=biases_initializer,
        kernel_regularizer=weights_regularizer,
        bias_regularizer=biases_regularizer,
        activity_regularizer=None,
        trainable=trainable,
        name=sc.name,
        dtype=inputs.dtype.base_dtype,
        _scope=sc,
        _reuse=reuse)
    outputs = layer.apply(inputs)

    # Add variables to collections.
    _add_variable_to_collections(layer.kernel, variables_collections,
                                 'weights')
    # Compare against None rather than relying on truthiness: evaluating a
    # `Variable`/`Tensor` in a Python boolean context is not allowed in
    # graph mode (matches the check used by convolution2d_transpose).
    if layer.bias is not None:
      _add_variable_to_collections(layer.bias, variables_collections,
                                   'biases')

    if normalizer_fn is not None:
      normalizer_params = normalizer_params or {}
      outputs = normalizer_fn(outputs, **normalizer_params)

    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)
@add_arg_scope
def max_pool3d(inputs,
               kernel_size,
               stride=2,
               padding='VALID',
               data_format=DATA_FORMAT_NDHWC,
               outputs_collections=None,
               scope=None):
  """Adds a 3D Max Pooling op.

  It is assumed that the pooling is done per image but not in batch or
  channels.

  Args:
    inputs: A 5-D tensor of shape `[batch_size, depth, height, width,
      channels]` if `data_format` is `NDHWC`, and `[batch_size, channels,
      depth, height, width]` if `data_format` is `NCDHW`.
    kernel_size: A list of length 3: [kernel_depth, kernel_height,
      kernel_width] of the pooling kernel over which the op is computed. Can
      be an int if all three values are the same.
    stride: A list of length 3: [stride_depth, stride_height, stride_width].
      Can be an int if all three strides are the same.
    padding: The padding method, either 'VALID' or 'SAME'.
    data_format: A string. `NDHWC` (default) and `NCDHW` are supported.
    outputs_collections: The collections to which the outputs are added.
    scope: Optional scope for name_scope.

  Returns:
    A `Tensor` representing the results of the pooling operation.

  Raises:
    ValueError: If `data_format` is neither `NDHWC` nor `NCDHW`.
    ValueError: If `kernel_size` is not a list of length 3.
  """
  if data_format not in (DATA_FORMAT_NCDHW, DATA_FORMAT_NDHWC):
    raise ValueError('data_format has to be either NCDHW or NDHWC.')
  with ops.name_scope(scope, 'MaxPool3D', [inputs]) as sc:
    inputs = ops.convert_to_tensor(inputs)
    # Map the NDHWC/NCDHW op-level format to the core layers' naming.
    df = ('channels_first' if data_format and data_format.startswith('NC')
          else 'channels_last')
    layer = pooling_layers.MaxPooling3D(pool_size=kernel_size,
                                        strides=stride,
                                        padding=padding,
                                        data_format=df,
                                        _scope=sc)
    outputs = layer.apply(inputs)
    return utils.collect_named_outputs(outputs_collections, sc, outputs)