Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r--  tensorflow/python/ops/nn_ops.py | 289
1 file changed, 161 insertions(+), 128 deletions(-)
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index bdaac65904..c4de2c7f00 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -23,6 +23,7 @@ import numbers
import numpy as np
from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops
@@ -32,13 +33,13 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_nn_ops import *
# pylint: enable=wildcard-import
+from tensorflow.python.util.deprecation import deprecated_args
+from tensorflow.python.util.deprecation import deprecated_argument_lookup
-from tensorflow.python.util import deprecation
# Aliases for some automatically-generated names.
local_response_normalization = gen_nn_ops.lrn
@@ -1645,17 +1646,18 @@ def _softmax(logits, compute_op, dim=-1, name=None):
return output
-def softmax(logits, dim=-1, name=None):
+@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
+def softmax(logits, axis=None, name=None, dim=None):
"""Computes softmax activations.
This function performs the equivalent of
- softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), dim)
+ softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)
Args:
logits: A non-empty `Tensor`. Must be one of the following types: `half`,
`float32`, `float64`.
- dim: The dimension softmax would be performed on. The default is -1 which
+ axis: The dimension softmax would be performed on. The default is -1 which
indicates the last dimension.
name: A name for the operation (optional).
@@ -1663,23 +1665,27 @@ def softmax(logits, dim=-1, name=None):
A `Tensor`. Has the same type and shape as `logits`.
Raises:
- InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
+ InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
dimension of `logits`.
"""
- return _softmax(logits, gen_nn_ops._softmax, dim, name)
+ axis = deprecated_argument_lookup("axis", axis, "dim", dim)
+ if axis is None:
+ axis = -1
+ return _softmax(logits, gen_nn_ops._softmax, axis, name)
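As a caller-side sketch of the new argument handling (illustrative values; assumes the TF 1.x API as patched here, where deprecated_argument_lookup resolves the two names and rejects conflicting ones):

import tensorflow as tf

logits = tf.constant([[1.0, 2.0, 3.0],
                      [1.0, 1.0, 1.0]])
# Preferred spelling after this change.
a = tf.nn.softmax(logits, axis=-1)
# Deprecated spelling, still accepted; @deprecated_args logs a warning.
b = tf.nn.softmax(logits, dim=-1)
# Passing both axis and dim should raise ValueError inside
# deprecated_argument_lookup, since the two names are aliases.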
-def log_softmax(logits, dim=-1, name=None):
+@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
+def log_softmax(logits, axis=None, name=None, dim=None):
"""Computes log softmax activations.
For each batch `i` and class `j` we have
- logsoftmax = logits - log(reduce_sum(exp(logits), dim))
+ logsoftmax = logits - log(reduce_sum(exp(logits), axis))
Args:
logits: A non-empty `Tensor`. Must be one of the following types: `half`,
`float32`, `float64`.
- dim: The dimension softmax would be performed on. The default is -1 which
+ axis: The dimension softmax would be performed on. The default is -1 which
indicates the last dimension.
name: A name for the operation (optional).
@@ -1687,10 +1693,13 @@ def log_softmax(logits, dim=-1, name=None):
A `Tensor`. Has the same type as `logits`. Same shape as `logits`.
Raises:
- InvalidArgumentError: if `logits` is empty or `dim` is beyond the last
+ InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
dimension of `logits`.
"""
- return _softmax(logits, gen_nn_ops._log_softmax, dim, name)
+ axis = deprecated_argument_lookup("axis", axis, "dim", dim)
+ if axis is None:
+ axis = -1
+ return _softmax(logits, gen_nn_ops._log_softmax, axis, name)
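The same axis/dim lookup applies to log_softmax. The identity in the docstring can be checked numerically with plain NumPy (illustrative only, not part of this patch; keepdims=True is used so the subtraction broadcasts for general shapes):

import numpy as np

logits = np.array([[1.0, 2.0, 3.0]])
# logsoftmax = logits - log(reduce_sum(exp(logits), axis)), per the docstring.
log_sm = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
# Exponentiating recovers a probability distribution along the axis.
assert np.allclose(np.exp(log_sm).sum(axis=-1), 1.0)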
def _ensure_xent_args(name, sentinel, labels, logits):
@@ -1702,9 +1711,9 @@ def _ensure_xent_args(name, sentinel, labels, logits):
raise ValueError("Both labels and logits must be provided.")
-def softmax_cross_entropy_with_logits_v2(_sentinel=None, # pylint: disable=invalid-name
- labels=None, logits=None,
- dim=-1, name=None):
+def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
+ labels=None, logits=None,
+ dim=-1, name=None):
"""Computes softmax cross entropy between `logits` and `labels`.
Measures the probability error in discrete classification tasks in which the
@@ -1728,10 +1737,6 @@ def softmax_cross_entropy_with_logits_v2(_sentinel=None, # pylint: disable=inva
`[batch_size, num_classes]` and the same dtype (either `float16`, `float32`,
or `float64`).
- Backpropagation will happen into both `logits` and `labels`. To disallow
- backpropagation into `labels`, pass label tensors through a `stop_gradients`
- before feeding it to this function.
-
**Note that to avoid confusion, it is required to pass only named arguments to
this function.**
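The _sentinel parameter is what enforces the named-argument requirement stated above; a minimal sketch of the calling contract (TF 1.x, illustrative shapes):

import tensorflow as tf

labels = tf.constant([[0.0, 1.0, 0.0]])
logits = tf.constant([[2.0, 1.0, 0.1]])
# Correct: named arguments only.
loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
# A positional call binds the first argument to _sentinel, and
# _ensure_xent_args then raises ValueError:
# tf.nn.softmax_cross_entropy_with_logits(labels, logits)  # ValueError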
@@ -1753,123 +1758,57 @@ def softmax_cross_entropy_with_logits_v2(_sentinel=None, # pylint: disable=inva
# could break users who call this with bad labels, but disregard the bad
# results.
- with ops.name_scope(
- name, "softmax_cross_entropy_with_logits", [logits, labels]) as name:
- logits = ops.convert_to_tensor(logits, name="logits")
- labels = ops.convert_to_tensor(labels, name="labels")
- precise_logits = math_ops.cast(logits, dtypes.float32) if (
- logits.dtype == dtypes.float16) else logits
- # labels and logits must be of the same type
- labels = math_ops.cast(labels, precise_logits.dtype)
- input_rank = array_ops.rank(precise_logits)
- # For shape inference.
- shape = logits.get_shape()
-
- # Move the dim to the end if dim is not the last dimension.
- if dim is not -1:
- def _move_dim_to_end(tensor, dim_index, rank):
- return array_ops.transpose(tensor,
- array_ops.concat([
- math_ops.range(dim_index),
- math_ops.range(dim_index + 1, rank),
- [dim_index]
- ], 0))
-
- precise_logits = _move_dim_to_end(precise_logits, dim, input_rank)
- labels = _move_dim_to_end(labels, dim, input_rank)
-
- input_shape = array_ops.shape(precise_logits)
-
- # Make precise_logits and labels into matrices.
- precise_logits = _flatten_outer_dims(precise_logits)
- labels = _flatten_outer_dims(labels)
-
- # Do the actual op computation.
- # The second output tensor contains the gradients. We use it in
- # _CrossEntropyGrad() in nn_grad but not here.
- cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
- precise_logits, labels, name=name)
-
- # The output cost shape should be the input minus dim.
- output_shape = array_ops.slice(input_shape, [0],
- [math_ops.subtract(input_rank, 1)])
- cost = array_ops.reshape(cost, output_shape)
-
- # Make shape inference work since reshape and transpose may erase its static
- # shape.
- if context.in_graph_mode() and shape is not None and shape.dims is not None:
- shape = shape.as_list()
- del shape[dim]
- cost.set_shape(shape)
-
- if logits.dtype == dtypes.float16:
- return math_ops.cast(cost, dtypes.float16)
- else:
- return cost
-
-
-_XENT_DEPRECATION = """
-Future major versions of TensorFlow will allow gradients to flow
-into the labels input on backprop by default.
-
-See tf.nn.softmax_cross_entropy_with_logits_v2.
-"""
-
-
-@deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION)
-def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
- labels=None, logits=None,
- dim=-1, name=None):
- """Computes softmax cross entropy between `logits` and `labels`.
-
- Measures the probability error in discrete classification tasks in which the
- classes are mutually exclusive (each entry is in exactly one class). For
- example, each CIFAR-10 image is labeled with one and only one label: an image
- can be a dog or a truck, but not both.
-
- **NOTE:** While the classes are mutually exclusive, their probabilities
- need not be. All that is required is that each row of `labels` is
- a valid probability distribution. If they are not, the computation of the
- gradient will be incorrect.
+ logits = ops.convert_to_tensor(logits)
+ labels = ops.convert_to_tensor(labels)
+ precise_logits = math_ops.cast(logits, dtypes.float32) if (
+ logits.dtype == dtypes.float16) else logits
+ # labels and logits must be of the same type
+ labels = math_ops.cast(labels, precise_logits.dtype)
+ input_rank = array_ops.rank(precise_logits)
+ # For shape inference.
+ shape = logits.get_shape()
- If using exclusive `labels` (wherein one and only
- one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.
+ # Move the dim to the end if dim is not the last dimension.
+ if dim is not -1:
+ def _move_dim_to_end(tensor, dim_index, rank):
+ return array_ops.transpose(tensor,
+ array_ops.concat([
+ math_ops.range(dim_index),
+ math_ops.range(dim_index + 1, rank),
+ [dim_index]
+ ], 0))
- **WARNING:** This op expects unscaled logits, since it performs a `softmax`
- on `logits` internally for efficiency. Do not call this op with the
- output of `softmax`, as it will produce incorrect results.
+ precise_logits = _move_dim_to_end(precise_logits, dim, input_rank)
+ labels = _move_dim_to_end(labels, dim, input_rank)
- `logits` and `labels` must have the same shape, e.g.
- `[batch_size, num_classes]` and the same dtype (either `float16`, `float32`,
- or `float64`).
+ input_shape = array_ops.shape(precise_logits)
- Backpropagation will happen only into `logits`. To calculate a cross entropy
- loss that allows backpropagation into both `logits` and `labels`, see
- @{tf.nn.softmax_cross_entropy_with_logits_v2}.
+ # Make precise_logits and labels into matrices.
+ precise_logits = _flatten_outer_dims(precise_logits)
+ labels = _flatten_outer_dims(labels)
- **Note that to avoid confusion, it is required to pass only named arguments to
- this function.**
+ # Do the actual op computation.
+ # The second output tensor contains the gradients. We use it in
+ # _CrossEntropyGrad() in nn_grad but not here.
+ cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
+ precise_logits, labels, name=name)
- Args:
- _sentinel: Used to prevent positional parameters. Internal, do not use.
- labels: Each row `labels[i]` must be a valid probability distribution.
- logits: Unscaled log probabilities.
- dim: The class dimension. Defaulted to -1 which is the last dimension.
- name: A name for the operation (optional).
+ # The output cost shape should be the input minus dim.
+ output_shape = array_ops.slice(input_shape, [0],
+ [math_ops.subtract(input_rank, 1)])
+ cost = array_ops.reshape(cost, output_shape)
- Returns:
- A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
- softmax cross entropy loss.
- """
- _ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel,
- labels, logits)
-
- with ops.name_scope(
- name, "softmax_cross_entropy_with_logits_sg", [logits, labels]) as name:
- labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
+ # Make shape inference work since reshape and transpose may erase its static
+ # shape.
+ if context.in_graph_mode() and shape is not None and shape.dims is not None:
+ shape = shape.as_list()
+ del shape[dim]
+ cost.set_shape(shape)
- return softmax_cross_entropy_with_logits_v2(
- labels=labels, logits=logits, dim=dim, name=name)
+ if logits.dtype == dtypes.float16:
+ return math_ops.cast(cost, dtypes.float16)
+ else:
+ return cost
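The body above moves dim to the last position before flattening the outer dimensions. A sketch of the permutation that the concat([range(dim), range(dim + 1, rank), [dim]]) expression builds, in plain NumPy with hypothetical shapes:

import numpy as np

def move_dim_to_end_perm(dim_index, rank):
  # Mirrors _move_dim_to_end: keep the dims before and after dim_index in
  # order, then append dim_index itself.
  return list(range(dim_index)) + list(range(dim_index + 1, rank)) + [dim_index]

x = np.zeros((2, 5, 7))                   # e.g. [batch, classes, width]
perm = move_dim_to_end_perm(1, x.ndim)    # -> [0, 2, 1]
assert np.transpose(x, perm).shape == (2, 7, 5)
# With the class dimension last, the outer dims can be flattened into a single
# batch dimension for the fused cross-entropy kernel, then reshaped back.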
def sparse_softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
@@ -2305,6 +2244,100 @@ def conv1d(value, filters, stride, padding,
return array_ops.squeeze(result, [spatial_start_dim])
+def conv1d_transpose(value,
+ filter,
+ output_shape,
+ stride,
+ padding="SAME",
+ data_format="NWC",
+ name=None):
+ """The transpose of `conv1d`.
+
+ This operation is sometimes called "deconvolution" after [Deconvolutional
+ Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is
+ actually the transpose (gradient) of `conv1d` rather than an actual
+ deconvolution.
+
+ Args:
+ value: A 3-D `Tensor` of type `float` and shape
+ `[batch, in_width, in_channels]` for `NWC` data format or
+ `[batch, in_channels, in_width]` for `NCW` data format.
+ filter: A 3-D `Tensor` with the same type as `value` and shape
+ `[filter_width, output_channels, in_channels]`. `filter`'s
+ `in_channels` dimension must match that of `value`.
+ output_shape: A 1-D `Tensor` representing the output shape of the
+ deconvolution op.
+ stride: An `integer`. The number of entries by which
+ the filter is moved right at each step.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+ See the @{tf.nn.convolution$comment here}
+ data_format: A string. 'NWC' and 'NCW' are supported.
+ name: Optional name for the returned tensor.
+
+ Returns:
+ A `Tensor` with the same type as `value`.
+
+ Raises:
+ ValueError: If input/output depth does not match `filter`'s shape, or if
+ padding is other than `'VALID'` or `'SAME'`.
+ """
+ with ops.name_scope(name, "conv1d_transpose",
+ [value, filter, output_shape]) as name:
+ output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
+ if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(3)):
+ raise ValueError("output_shape must have shape (3,), got {}"
+ .format(output_shape_.get_shape()))
+
+ # The format could be either NWC or NCW, map to NHWC or NCHW
+ if data_format is None or data_format == "NWC":
+ data_format_2d = "NHWC"
+ axis = 2
+ elif data_format == "NCW":
+ data_format_2d = "NCHW"
+ axis = 1
+ else:
+ raise ValueError("data_format must be \"NWC\" or \"NCW\".")
+
+ if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[2]):
+ raise ValueError("input channels does not match filter's input channels, "
+ "{} != {}".format(value.get_shape()[axis],
+ filter.get_shape()[2]))
+
+ if isinstance(output_shape, (list, np.ndarray)):
+ # output_shape's shape should be == [3] if reached this point.
+ if not filter.get_shape()[1].is_compatible_with(output_shape[axis]):
+ raise ValueError(
+ "output_shape does not match filter's output channels, "
+ "{} != {}".format(output_shape[axis], filter.get_shape()[1]))
+
+ if padding != "VALID" and padding != "SAME":
+ raise ValueError("padding must be either VALID or SAME:"
+ " {}".format(padding))
+
+ # Reshape the input tensor to [batch, 1, in_width, in_channels]
+ if data_format_2d == "NHWC":
+ output_shape_ = array_ops.concat([output_shape_[:1], [1],
+ output_shape_[1:]], axis=0)
+ spatial_start_dim = 1
+ strides = [1, 1, stride, 1]
+ else:
+ output_shape_ = array_ops.concat([output_shape_[:2], [1],
+ output_shape_[2:]], axis=0)
+ spatial_start_dim = 2
+ strides = [1, 1, 1, stride]
+ value = array_ops.expand_dims(value, spatial_start_dim)
+ filter = array_ops.expand_dims(filter, 0)
+
+ result = gen_nn_ops.conv2d_backprop_input(input_sizes=output_shape_,
+ filter=filter,
+ out_backprop=value,
+ strides=strides,
+ padding=padding,
+ data_format=data_format_2d,
+ name=name)
+ return array_ops.squeeze(result, [spatial_start_dim])
+
+
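A minimal usage sketch for the new conv1d_transpose (shapes and values are illustrative; it assumes the symbol is re-exported as tf.nn.conv1d_transpose, otherwise call it on nn_ops directly):

import tensorflow as tf

batch, in_width, in_channels, out_channels, stride = 4, 10, 3, 8, 2
value = tf.random_normal([batch, in_width, in_channels])   # NWC input
filt = tf.random_normal([5, out_channels, in_channels])    # [filter_width, out, in]
# With 'SAME' padding the output width is in_width * stride.
output_shape = [batch, in_width * stride, out_channels]
result = tf.nn.conv1d_transpose(value, filt, output_shape, stride,
                                padding="SAME", data_format="NWC")
# result: [batch, in_width * stride, out_channels], same dtype as value.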
@ops.RegisterStatistics("Dilation2D", "flops")
def _calc_dilation2d_flops(graph, node):
"""Calculates the compute resources needed for Dilation2D."""