diff options
Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 128 |
1 file changed, 118 insertions, 10 deletions
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index bdaac65904..ec7b9372ca 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -40,6 +40,7 @@ from tensorflow.python.ops.gen_nn_ops import * from tensorflow.python.util import deprecation + # Aliases for some automatically-generated names. local_response_normalization = gen_nn_ops.lrn @@ -1645,52 +1646,62 @@ def _softmax(logits, compute_op, dim=-1, name=None): return output -def softmax(logits, dim=-1, name=None): +@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim") +def softmax(logits, axis=None, name=None, dim=None): """Computes softmax activations. This function performs the equivalent of - softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), dim) + softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis) Args: logits: A non-empty `Tensor`. Must be one of the following types: `half`, `float32`, `float64`. - dim: The dimension softmax would be performed on. The default is -1 which + axis: The dimension softmax would be performed on. The default is -1 which indicates the last dimension. name: A name for the operation (optional). + dim: Deprecated alias for `axis`. Returns: A `Tensor`. Has the same type and shape as `logits`. Raises: - InvalidArgumentError: if `logits` is empty or `dim` is beyond the last + InvalidArgumentError: if `logits` is empty or `axis` is beyond the last dimension of `logits`. """ - return _softmax(logits, gen_nn_ops._softmax, dim, name) + axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim) + if axis is None: + axis = -1 + return _softmax(logits, gen_nn_ops._softmax, axis, name) -def log_softmax(logits, dim=-1, name=None): +@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim") +def log_softmax(logits, axis=None, name=None, dim=None): """Computes log softmax activations. 
For each batch `i` and class `j` we have - logsoftmax = logits - log(reduce_sum(exp(logits), dim)) + logsoftmax = logits - log(reduce_sum(exp(logits), axis)) Args: logits: A non-empty `Tensor`. Must be one of the following types: `half`, `float32`, `float64`. - dim: The dimension softmax would be performed on. The default is -1 which + axis: The dimension softmax would be performed on. The default is -1 which indicates the last dimension. name: A name for the operation (optional). + dim: Deprecated alias for `axis`. Returns: A `Tensor`. Has the same type as `logits`. Same shape as `logits`. Raises: - InvalidArgumentError: if `logits` is empty or `dim` is beyond the last + InvalidArgumentError: if `logits` is empty or `axis` is beyond the last dimension of `logits`. """ - return _softmax(logits, gen_nn_ops._log_softmax, dim, name) + axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim) + if axis is None: + axis = -1 + return _softmax(logits, gen_nn_ops._log_softmax, axis, name) def _ensure_xent_args(name, sentinel, labels, logits): @@ -2305,6 +2316,103 @@ def conv1d(value, filters, stride, padding, return array_ops.squeeze(result, [spatial_start_dim]) +def conv1d_transpose( + value, + filter, # pylint: disable=redefined-builtin + output_shape, + stride, + padding="SAME", + data_format="NWC", + name=None): + """The transpose of `conv1d`. + + This operation is sometimes called "deconvolution" after [Deconvolutional + Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is + actually the transpose (gradient) of `conv1d` rather than an actual + deconvolution. + + Args: + value: A 3-D `Tensor` of type `float` and shape + `[batch, in_width, in_channels]` for `NWC` data format or + `[batch, in_channels, in_width]` for `NCW` data format. + filter: A 3-D `Tensor` with the same type as `value` and shape + `[filter_width, output_channels, in_channels]`. `filter`'s + `in_channels` dimension must match that of `value`. 
+ output_shape: A 1-D `Tensor` representing the output shape of the + deconvolution op. + stride: An `integer`. The number of entries by which + the filter is moved right at each step. + padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. + See the @{tf.nn.convolution$comment here} + data_format: A string. 'NWC' and 'NCW' are supported. + name: Optional name for the returned tensor. + + Returns: + A `Tensor` with the same type as `value`. + + Raises: + ValueError: If input/output depth does not match `filter`'s shape, or if + padding is other than `'VALID'` or `'SAME'`. + """ + with ops.name_scope(name, "conv1d_transpose", + [value, filter, output_shape]) as name: + output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape") + if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(3)): + raise ValueError("output_shape must have shape (3,), got {}".format( + output_shape_.get_shape())) + + # The format could be either NWC or NCW, map to NHWC or NCHW + if data_format is None or data_format == "NWC": + data_format_2d = "NHWC" + axis = 2 + elif data_format == "NCW": + data_format_2d = "NCHW" + axis = 1 + else: + raise ValueError("data_format must be \"NWC\" or \"NCW\".") + + if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[2]): + raise ValueError("input channels does not match filter's input channels, " + "{} != {}".format(value.get_shape()[axis], + filter.get_shape()[2])) + + if isinstance(output_shape, (list, np.ndarray)): + # output_shape's shape should be == [3] if reached this point. 
+ if not filter.get_shape()[1].is_compatible_with(output_shape[axis]): + raise ValueError( + "output_shape does not match filter's output channels, " + "{} != {}".format(output_shape[axis], + filter.get_shape()[1])) + + if padding != "VALID" and padding != "SAME": + raise ValueError("padding must be either VALID or SAME:" + " {}".format(padding)) + + # Reshape the input tensor to [batch, 1, in_width, in_channels] + if data_format_2d == "NHWC": + output_shape_ = array_ops.concat( + [output_shape_[:1], [1], output_shape_[1:]], axis=0) + spatial_start_dim = 1 + strides = [1, 1, stride, 1] + else: + output_shape_ = array_ops.concat( + [output_shape_[:2], [1], output_shape_[2:]], axis=0) + spatial_start_dim = 2 + strides = [1, 1, 1, stride] + value = array_ops.expand_dims(value, spatial_start_dim) + filter = array_ops.expand_dims(filter, 0) + + result = gen_nn_ops.conv2d_backprop_input( + input_sizes=output_shape_, + filter=filter, + out_backprop=value, + strides=strides, + padding=padding, + data_format=data_format_2d, + name=name) + return array_ops.squeeze(result, [spatial_start_dim]) + + @ops.RegisterStatistics("Dilation2D", "flops") def _calc_dilation2d_flops(graph, node): """Calculates the compute resources needed for Dilation2D.""" |