diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-10-28 12:18:07 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-28 13:32:10 -0700 |
commit | 7993cb5745657fbb8903a87d217441338303ac25 (patch) | |
tree | 8dc68b0ed77cddb880d1d0efe8575e4ded2fbfa7 /tensorflow/contrib/losses/python | |
parent | e861edc29a8a2babcfde95672e28cec5c31912e3 (diff) |
Use consistent naming for `labels` and `weights` args.
Fix asserts to pass expected before actual values.
Change: 137544983
Diffstat (limited to 'tensorflow/contrib/losses/python')
-rw-r--r-- | tensorflow/contrib/losses/python/losses/__init__.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/losses/python/losses/loss_ops.py | 368 | ||||
-rw-r--r-- | tensorflow/contrib/losses/python/losses/loss_ops_test.py | 679 |
3 files changed, 680 insertions, 373 deletions
diff --git a/tensorflow/contrib/losses/python/losses/__init__.py b/tensorflow/contrib/losses/python/losses/__init__.py index b76ba547a1..54c21684bd 100644 --- a/tensorflow/contrib/losses/python/losses/__init__.py +++ b/tensorflow/contrib/losses/python/losses/__init__.py @@ -30,10 +30,10 @@ log_loss penalty be twice as severe as the sum_of_squares_loss, we would implement this as: # Explicitely set the weight. - tf.contrib.losses.log(predictions, targets, weight=2.0) + tf.contrib.losses.log(predictions, labels, weight=2.0) # Uses default weight of 1.0 - tf.contrib.losses.sum_of_squares(predictions, targets) + tf.contrib.losses.sum_of_squares(predictions, labels) # All the losses are collected into the `GraphKeys.LOSSES` collection. losses = tf.get_collection(tf.GraphKeys.LOSSES) @@ -62,7 +62,7 @@ Finally, in certain cases, we may want to specify a different loss for every single measurable value. For example, if we are performing per-pixel depth prediction, or per-pixel denoising, a single batch sample has P values where P is the number of pixels in the image. For many losses, the number of measurable -values matches the number of elements in the predictions and targets tensors. +values matches the number of elements in the predictions and labels tensors. For others, such as softmax_cross_entropy and cosine_distance, the loss functions reduces the dimensions of the inputs to produces a tensor of losses for each measurable value. For example, softmax_cross_entropy takes as diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 023efd125d..7610f9275f 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -21,6 +21,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework.python.ops import add_arg_scope from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -45,17 +46,46 @@ __all__ = ["absolute_difference", "sparse_softmax_cross_entropy"] -def _scale_losses(losses, weight): +# TODO(b/32171727): Remove when deprecated `targets` is removed. +def _labels(labels, targets): + if labels is None: + labels = targets + elif targets is not None: + raise ValueError("Can not specify both `labels` and `targets`.") + if labels is None: + raise ValueError("Must provide 1 of `labels` and `targets`.") + return labels + + +# TODO(b/32171727): Remove when deprecated `weight` is removed. +_WEIGHT_SENTINEL = object() + + +# TODO(b/32171727): Remove when deprecated `weight` is removed. Also, restore +# weights=1.0 as default in all calling fns. +def _weights(weights, weight): + if weights is _WEIGHT_SENTINEL: + weights = weight + elif weight is not _WEIGHT_SENTINEL: + raise ValueError("Can not specify both `weights` and `weight`.") + if weights is None: + raise ValueError("`weights` cannot be None.") + if weights is _WEIGHT_SENTINEL: + weights = 1.0 + return weights + + +def _scale_losses(losses, weights): """Computes the scaled loss. Args: losses: A `Tensor` of size [batch_size, d1, ... dN]. - weight: A `Tensor` of size [1], [batch_size] or [batch_size, d1, ... dN]. + weights: A `Tensor` of size [1], [batch_size] or [batch_size, d1, ... dN]. The `losses` are reduced (tf.reduce_sum) until its dimension matches - that of `weight` at which point the reduced `losses` are element-wise - multiplied by `weight` and a final reduce_sum is computed on the result. + that of `weights` at which point the reduced `losses` are element-wise + multiplied by `weights` and a final reduce_sum is computed on the result. Conceptually, this operation is equivalent to broadcasting (tiling) - `weight` to be the same size as `losses`, performing an element-wise + `weights` to be the same size as `losses`, performing an element-wise multiplication, and summing the result. Returns: @@ -63,11 +93,11 @@ def _scale_losses(losses, weight): `losses`. """ # First, compute the sum of the losses over all elements: - start_index = max(0, weight.get_shape().ndims) + start_index = max(0, weights.get_shape().ndims) reduction_indices = list(range(start_index, losses.get_shape().ndims)) reduced_losses = math_ops.reduce_sum(losses, reduction_indices=reduction_indices) - reduced_losses = math_ops.mul(reduced_losses, weight) + reduced_losses = math_ops.mul(reduced_losses, weights) return math_ops.reduce_sum(reduced_losses) @@ -111,38 +141,42 @@ def _safe_mean(losses, num_present): return _safe_div(total_loss, num_present) -def compute_weighted_loss(losses, weight=1.0): +@deprecated_args( + "2016-11-25", "`weight` is being deprecated, use `weights`.", "weight") +def compute_weighted_loss( + losses, weights=_WEIGHT_SENTINEL, weight=_WEIGHT_SENTINEL): """Computes the weighted loss. Args: losses: A tensor of size [batch_size, d1, ... dN]. - weight: A tensor of size [1] or [batch_size, d1, ... dK] where K < N. + weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N. + weight: Deprecated alias for `weights`. Returns: A scalar `Tensor` that returns the weighted loss. Raises: - ValueError: If the weight is None or the shape is not compatible with the - losses shape or if the number of dimensions (rank) of either losses or - weight is missing. + ValueError: If `weights` is `None` or the shape is not compatible with + `losses`, or if the number of dimensions (rank) of either `losses` or + `weights` is missing. """ - if weight is None: - raise ValueError("`weight` cannot be None") + weights = _weights(weights, weight) + losses = ops.convert_to_tensor(losses) input_dtype = losses.dtype losses = math_ops.to_float(losses) - weight = math_ops.to_float(ops.convert_to_tensor(weight)) + weights = math_ops.to_float(ops.convert_to_tensor(weights)) if losses.get_shape().ndims is None: raise ValueError("losses.get_shape().ndims cannot be None") - if weight.get_shape().ndims is None: + weights_shape = weights.get_shape() + if weights_shape.ndims is None: raise ValueError("weight.get_shape().ndims cannot be None") - weight_shape = weight.get_shape() - if weight_shape.ndims > 1 and weight_shape.dims[-1].is_compatible_with(1): - weight = array_ops.squeeze(weight, [-1]) + if weights_shape.ndims > 1 and weights_shape.dims[-1].is_compatible_with(1): + weights = array_ops.squeeze(weights, [-1]) - total_loss = _scale_losses(losses, weight) - num_present = _num_present(losses, weight) + total_loss = _scale_losses(losses, weights) + num_present = _num_present(losses, weights) mean_loss = _safe_mean(total_loss, num_present) # convert the result back to the input type mean_loss = math_ops.cast(mean_loss, input_dtype) @@ -150,19 +184,19 @@ def compute_weighted_loss(losses, weight=1.0): return mean_loss -def _num_present(losses, weight, per_batch=False): - """Computes the number of elements in the loss function induced by `weight`. +def _num_present(losses, weights, per_batch=False): + """Computes the number of elements in the loss function induced by `weights`. - A given weight tensor induces different numbers of usable elements in the - `losses` tensor. The `weight` tensor is broadcast across `losses` for all + A given weights tensor induces different numbers of usable elements in the + `losses` tensor. The `weights` tensor is broadcast across `losses` for all possible dimensions. For example, if `losses` is a tensor of dimension - [4, 5, 6, 3] and weight is a tensor of size [4, 5], then weight is, in effect, - tiled to match the size of `losses`. Following this effective tile, the total - number of present elements is the number of non-zero weights. + [4, 5, 6, 3] and `weights` is a tensor of size [4, 5], then `weights` is, in + effect, tiled to match the size of `losses`. Following this effective tile, + the total number of present elements is the number of non-zero weights. Args: losses: A tensor of size [batch_size, d1, ... dN]. - weight: A tensor of size [1] or [batch_size, d1, ... dK] where K < N. + weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N. per_batch: Whether to return the number of elements per batch or as a sum total. @@ -171,28 +205,28 @@ def _num_present(losses, weight, per_batch=False): `per_batch` is True, the value is returned as a tensor of size [batch_size]. Otherwise, a single scalar tensor is returned. """ - # If the weight is a scalar, its easy to compute: - if weight.get_shape().ndims == 0: + # If weights is a scalar, its easy to compute: + if weights.get_shape().ndims == 0: batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses), [0], [1]), []) num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)), math_ops.to_float(batch_size)) - num_per_batch = math_ops.select(math_ops.equal(weight, 0), + num_per_batch = math_ops.select(math_ops.equal(weights, 0), 0.0, num_per_batch) num_per_batch = math_ops.mul(array_ops.ones( array_ops.reshape(batch_size, [1])), num_per_batch) return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch) # First, count the number of nonzero weights: - if weight.get_shape().ndims >= 1: - reduction_indices = list(range(1, weight.get_shape().ndims)) + if weights.get_shape().ndims >= 1: + reduction_indices = list(range(1, weights.get_shape().ndims)) num_nonzero_per_batch = math_ops.reduce_sum( - math_ops.to_float(math_ops.not_equal(weight, 0)), + math_ops.to_float(math_ops.not_equal(weights, 0)), reduction_indices=reduction_indices) # Next, determine the number of elements that weight would broadcast to: broadcast_dims = array_ops.slice(array_ops.shape(losses), - [weight.get_shape().ndims], [-1]) + [weights.get_shape().ndims], [-1]) num_to_broadcast = math_ops.to_float(math_ops.reduce_prod(broadcast_dims)) num_per_batch = math_ops.mul(num_nonzero_per_batch, num_to_broadcast) @@ -258,7 +292,14 @@ def get_total_loss(add_regularization_losses=True, name="total_loss"): return math_ops.add_n(losses, name=name) -def absolute_difference(predictions, targets, weight=1.0, scope=None): +@deprecated_args( + "2016-11-25", + "`targets` is being deprecated, use `labels`." + " `weight` is being deprecated, use `weights`.", + "targets", "weight") +def absolute_difference( + predictions, labels=None, weights=_WEIGHT_SENTINEL, scope=None, + targets=None, weight=_WEIGHT_SENTINEL): """Adds an Absolute Difference loss to the training procedure. `weight` acts as a coefficient for the loss. If a scalar is provided, then the @@ -271,31 +312,36 @@ def absolute_difference(predictions, targets, weight=1.0, scope=None): Args: predictions: The predicted outputs. - targets: The ground truth output tensor, same dimensions as 'predictions'. - weight: Coefficients for the loss a scalar, a tensor of shape + labels: The ground truth output tensor, same dimensions as 'predictions'. + weights: Coefficients for the loss a scalar, a tensor of shape [batch_size] or a tensor whose shape matches `predictions`. scope: The scope for the operations performed in computing the loss. + targets: Deprecated alias for `labels`. + weight: Deprecated alias for `weights`. Returns: A scalar `Tensor` representing the loss value. Raises: - ValueError: If the shape of `predictions` doesn't match that of `targets` or + ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weight` is invalid. """ + labels = _labels(labels, targets) + weights = _weights(weights, weight) with ops.name_scope(scope, "absolute_difference", - [predictions, targets]) as scope: - predictions.get_shape().assert_is_compatible_with(targets.get_shape()) - if weight is None: - raise ValueError("`weight` cannot be None") + [predictions, labels, weights]) as scope: + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) predictions = math_ops.to_float(predictions) - targets = math_ops.to_float(targets) - losses = math_ops.abs(math_ops.sub(predictions, targets)) - return compute_weighted_loss(losses, weight) + labels = math_ops.to_float(labels) + losses = math_ops.abs(math_ops.sub(predictions, labels)) + return compute_weighted_loss(losses, weights) -def sigmoid_cross_entropy(logits, multi_class_labels, weight=1.0, - label_smoothing=0, scope=None): +@deprecated_args( + "2016-11-25", "`weight` is being deprecated, use `weights`", "weight") +def sigmoid_cross_entropy( + logits, multi_class_labels, weights=_WEIGHT_SENTINEL, label_smoothing=0, + scope=None, weight=_WEIGHT_SENTINEL): """Creates a cross-entropy loss using tf.nn.sigmoid_cross_entropy_with_logits. `weight` acts as a coefficient for the loss. If a scalar is provided, @@ -311,20 +357,23 @@ def sigmoid_cross_entropy(logits, multi_class_labels, weight=1.0, Args: logits: [batch_size, num_classes] logits outputs of the network . multi_class_labels: [batch_size, num_classes] target labels in (0, 1). - weight: Coefficients for the loss. The tensor must be a scalar, a tensor of + weights: Coefficients for the loss. The tensor must be a scalar, a tensor of shape [batch_size] or shape [batch_size, num_classes]. label_smoothing: If greater than 0 then smooth the labels. scope: The scope for the operations performed in computing the loss. + weight: Deprecated alias for `weights`. Returns: A scalar `Tensor` representing the loss value. Raises: - ValueError: If the shape of `predictions` doesn't match that of `targets` or - if the shape of `weight` is invalid or if `weight` is None. + ValueError: If the shape of `logits` doesn't match that of + `multi_class_labels` or if the shape of `weight` is invalid, or if + `weight` is None. """ + weights = _weights(weights, weight) with ops.name_scope(scope, "sigmoid_cross_entropy_loss", - [logits, multi_class_labels]): + [logits, multi_class_labels, weights]): logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape()) multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype) @@ -335,11 +384,14 @@ def sigmoid_cross_entropy(logits, multi_class_labels, weight=1.0, losses = nn.sigmoid_cross_entropy_with_logits(logits, multi_class_labels, name="xentropy") - return compute_weighted_loss(losses, weight) + return compute_weighted_loss(losses, weights) -def softmax_cross_entropy(logits, onehot_labels, weight=1.0, - label_smoothing=0, scope=None): +@deprecated_args( + "2016-11-25", "`weight` is being deprecated, use `weights`", "weight") +def softmax_cross_entropy( + logits, onehot_labels, weights=_WEIGHT_SENTINEL, label_smoothing=0, + scope=None, weight=_WEIGHT_SENTINEL): """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits. `weight` acts as a coefficient for the loss. If a scalar is provided, @@ -354,10 +406,11 @@ def softmax_cross_entropy(logits, onehot_labels, weight=1.0, Args: logits: [batch_size, num_classes] logits outputs of the network . onehot_labels: [batch_size, num_classes] target one_hot_encoded labels. - weight: Coefficients for the loss. The tensor must be a scalar or a tensor + weights: Coefficients for the loss. The tensor must be a scalar or a tensor of shape [batch_size]. label_smoothing: If greater than 0 then smooth the labels. scope: the scope for the operations performed in computing the loss. + weight: Deprecated alias for `weights`. Returns: A scalar `Tensor` representing the loss value. @@ -366,8 +419,9 @@ def softmax_cross_entropy(logits, onehot_labels, weight=1.0, ValueError: If the shape of `logits` doesn't match that of `onehot_labels` or if the shape of `weight` is invalid or if `weight` is None. """ + weights = _weights(weights, weight) with ops.name_scope(scope, "softmax_cross_entropy_loss", - [logits, onehot_labels]): + [logits, onehot_labels, weights]): logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape()) onehot_labels = math_ops.cast(onehot_labels, logits.dtype) @@ -381,11 +435,15 @@ def softmax_cross_entropy(logits, onehot_labels, weight=1.0, losses = nn.softmax_cross_entropy_with_logits(logits, onehot_labels, name="xentropy") - return compute_weighted_loss(losses, weight) + return compute_weighted_loss(losses, weights) -def sparse_softmax_cross_entropy(logits, labels, weight=1.0, scope=None): - """Cross-entropy loss using tf.nn.sparse_softmax_cross_entropy_with_logits. +@deprecated_args( + "2016-11-25", "`weight` is being deprecated, use `weights`", "weight") +def sparse_softmax_cross_entropy( + logits, labels, weights=_WEIGHT_SENTINEL, scope=None, + weight=_WEIGHT_SENTINEL): + """Cross-entropy loss using `tf.nn.sparse_softmax_cross_entropy_with_logits`. `weight` acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If `weight` is a @@ -396,9 +454,10 @@ def sparse_softmax_cross_entropy(logits, labels, weight=1.0, scope=None): logits: [batch_size, num_classes] logits outputs of the network . labels: [batch_size, 1] or [batch_size] target labels of dtype `int32` or `int64` in the range `[0, num_classes)`. - weight: Coefficients for the loss. The tensor must be a scalar or a tensor + weights: Coefficients for the loss. The tensor must be a scalar or a tensor of shape [batch_size] or [batch_size, 1]. scope: the scope for the operations performed in computing the loss. + weight: Deprecated alias for `weights`. Returns: A scalar `Tensor` representing the loss value. @@ -407,17 +466,25 @@ def sparse_softmax_cross_entropy(logits, labels, weight=1.0, scope=None): ValueError: If the shapes of logits, labels, and weight are incompatible, or if `weight` is None. """ + weights = _weights(weights, weight) with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss", - [logits, labels]): + [logits, labels, weights]): labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]]) - weight = array_ops.squeeze(weight) + weights = array_ops.squeeze(weights) losses = nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name="xentropy") - return compute_weighted_loss(losses, weight) + return compute_weighted_loss(losses, weights) -def log_loss(predictions, targets, weight=1.0, epsilon=1e-7, scope=None): +@deprecated_args( + "2016-11-25", + "`targets` is being deprecated, use `labels`." + " `weight` is being deprecated, use `weights`.", + "targets", "weight") +def log_loss( + predictions, labels=None, weights=_WEIGHT_SENTINEL, epsilon=1e-7, + scope=None, targets=None, weight=_WEIGHT_SENTINEL): """Adds a Log Loss term to the training procedure. `weight` acts as a coefficient for the loss. If a scalar is provided, then the @@ -430,60 +497,72 @@ def log_loss(predictions, targets, weight=1.0, epsilon=1e-7, scope=None): Args: predictions: The predicted outputs. - targets: The ground truth output tensor, same dimensions as 'predictions'. - weight: Coefficients for the loss a scalar, a tensor of shape + labels: The ground truth output tensor, same dimensions as 'predictions'. + weights: Coefficients for the loss a scalar, a tensor of shape [batch_size] or a tensor whose shape matches `predictions`. epsilon: A small increment to add to avoid taking a log of zero. scope: The scope for the operations performed in computing the loss. + targets: Deprecated alias for `labels`. + weight: Deprecated alias for `weights`. Returns: A scalar `Tensor` representing the loss value. Raises: - ValueError: If the shape of `predictions` doesn't match that of `targets` or + ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weight` is invalid. """ + labels = _labels(labels, targets) + weights = _weights(weights, weight) with ops.name_scope(scope, "log_loss", - [predictions, targets]) as scope: - predictions.get_shape().assert_is_compatible_with(targets.get_shape()) - if weight is None: - raise ValueError("`weight` cannot be None") + [predictions, labels, weights]) as scope: + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) predictions = math_ops.to_float(predictions) - targets = math_ops.to_float(targets) + labels = math_ops.to_float(labels) losses = -math_ops.mul( - targets, + labels, math_ops.log(predictions + epsilon)) - math_ops.mul( - (1 - targets), math_ops.log(1 - predictions + epsilon)) - return compute_weighted_loss(losses, weight) + (1 - labels), math_ops.log(1 - predictions + epsilon)) + return compute_weighted_loss(losses, weights) -def hinge_loss(logits, target, scope=None): +@deprecated_args( + "2016-11-25", "`target` is being deprecated, use `labels`.", "target") +def hinge_loss(logits, labels=None, scope=None, target=None): """Method that returns the loss tensor for hinge loss. Args: logits: The logits, a float tensor. - target: The ground truth output tensor. Its shape should match the shape of + labels: The ground truth output tensor. Its shape should match the shape of logits. The values of the tensor are expected to be 0.0 or 1.0. scope: The scope for the operations performed in computing the loss. + target: Deprecated alias for `labels`. Returns: A `Tensor` of same shape as logits and target representing the loss values across the batch. Raises: - ValueError: If the shapes of `logits` and `target` don't match. + ValueError: If the shapes of `logits` and `labels` don't match. """ - with ops.name_scope(scope, "hinge_loss", [logits, target]) as scope: - logits.get_shape().assert_is_compatible_with(target.get_shape()) + labels = _labels(labels, target) + with ops.name_scope(scope, "hinge_loss", [logits, labels]) as scope: + logits.get_shape().assert_is_compatible_with(labels.get_shape()) # We first need to convert binary labels to -1/1 labels (as floats). - target = math_ops.to_float(target) - all_ones = array_ops.ones_like(target) - labels = math_ops.sub(2 * target, all_ones) - losses = nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits))) - return losses - - -def mean_squared_error(predictions, targets, weight=1.0, scope=None): + labels = math_ops.to_float(labels) + all_ones = array_ops.ones_like(labels) + labels = math_ops.sub(2 * labels, all_ones) + return nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits))) + + +@deprecated_args( + "2016-11-25", + "`targets` is being deprecated, use `labels`." + " `weight` is being deprecated, use `weights`.", + "targets", "weight") +def mean_squared_error( + predictions, labels=None, weights=_WEIGHT_SENTINEL, scope=None, + targets=None, weight=_WEIGHT_SENTINEL): """Adds a Sum-of-Squares loss to the training procedure. `weight` acts as a coefficient for the loss. If a scalar is provided, then the @@ -496,38 +575,47 @@ def mean_squared_error(predictions, targets, weight=1.0, scope=None): Args: predictions: The predicted outputs. - targets: The ground truth output tensor, same dimensions as 'predictions'. - weight: Coefficients for the loss a scalar, a tensor of shape + labels: The ground truth output tensor, same dimensions as 'predictions'. + weights: Coefficients for the loss a scalar, a tensor of shape [batch_size] or a tensor whose shape matches `predictions`. scope: The scope for the operations performed in computing the loss. + targets: Deprecated alias for `labels`. + weight: Deprecated alias for `weights`. Returns: A scalar `Tensor` representing the loss value. Raises: - ValueError: If the shape of `predictions` doesn't match that of `targets` or + ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weight` is invalid. """ + labels = _labels(labels, targets) + weights = _weights(weights, weight) with ops.name_scope(scope, "mean_squared_error", - [predictions, targets]) as scope: - predictions.get_shape().assert_is_compatible_with(targets.get_shape()) - if weight is None: - raise ValueError("`weight` cannot be None") + [predictions, labels, weights]) as scope: + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) predictions = math_ops.to_float(predictions) - targets = math_ops.to_float(targets) - losses = math_ops.square(math_ops.sub(predictions, targets)) - return compute_weighted_loss(losses, weight) - - -def mean_pairwise_squared_error(predictions, targets, weight=1.0, scope=None): + labels = math_ops.to_float(labels) + losses = math_ops.square(math_ops.sub(predictions, labels)) + return compute_weighted_loss(losses, weights) + + +@deprecated_args( + "2016-11-25", + "`targets` is being deprecated, use `labels`." + " `weight` is being deprecated, use `weights`.", + "targets", "weight") +def mean_pairwise_squared_error( + predictions, labels=None, weights=_WEIGHT_SENTINEL, scope=None, + targets=None, weight=_WEIGHT_SENTINEL): """Adds a pairwise-errors-squared loss to the training procedure. Unlike `mean_squared_error`, which is a measure of the differences between - corresponding elements of `predictions` and `targets`, + corresponding elements of `predictions` and `labels`, `mean_pairwise_squared_error` is a measure of the differences between pairs of - corresponding elements of `predictions` and `targets`. + corresponding elements of `predictions` and `labels`. - For example, if `targets`=[a, b, c] and `predictions`=[x, y, z], there are + For example, if `labels`=[a, b, c] and `predictions`=[x, y, z], there are three pairs of differences are summed to compute the loss: loss = [ ((a-b) - (x-y)).^2 + ((a-c) - (x-z)).^2 + ((b-c) - (y-z)).^2 ] / 3 @@ -545,42 +633,44 @@ def mean_pairwise_squared_error(predictions, targets, weight=1.0, scope=None): Args: predictions: The predicted outputs, a tensor of size [batch_size, d0, .. dN] where N+1 is the total number of dimensions in `predictions`. - targets: The ground truth output tensor, whose shape must match the shape of + labels: The ground truth output tensor, whose shape must match the shape of the `predictions` tensor. - weight: Coefficients for the loss a scalar, a tensor of shape [batch_size] + weights: Coefficients for the loss a scalar, a tensor of shape [batch_size] or a tensor whose shape matches `predictions`. scope: The scope for the operations performed in computing the loss. + targets: Deprecated alias for `labels`. + weight: Deprecated alias for `weights`. Returns: A scalar `Tensor` representing the loss value. Raises: - ValueError: If the shape of `predictions` doesn't match that of `targets` or + ValueError: If the shape of `predictions` doesn't match that of `labels` or if the shape of `weight` is invalid. """ + labels = _labels(labels, targets) + weights = _weights(weights, weight) with ops.name_scope(scope, "mean_pairwise_squared_error", - [predictions, targets]) as scope: - predictions.get_shape().assert_is_compatible_with(targets.get_shape()) - if weight is None: - raise ValueError("`weight` cannot be None") + [predictions, labels, weights]) as scope: + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) predictions = math_ops.to_float(predictions) - targets = math_ops.to_float(targets) - weight = math_ops.to_float(ops.convert_to_tensor(weight)) + labels = math_ops.to_float(labels) + weights = math_ops.to_float(ops.convert_to_tensor(weights)) - diffs = math_ops.sub(predictions, targets) + diffs = math_ops.sub(predictions, labels) # Need to verify here since the function doesn't use compute_weighted_loss if diffs.get_shape().ndims is None: raise ValueError("diffs.get_shape().ndims cannot be None") - if weight.get_shape().ndims is None: - raise ValueError("weight.get_shape().ndims cannot be None") + if weights.get_shape().ndims is None: + raise ValueError("weights.get_shape().ndims cannot be None") reduction_indices = list(range(1, diffs.get_shape().ndims)) sum_squares_diff_per_batch = math_ops.reduce_sum( math_ops.square(diffs), reduction_indices=reduction_indices) - num_present_per_batch = _num_present(diffs, weight, per_batch=True) + num_present_per_batch = _num_present(diffs, weights, per_batch=True) term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, num_present_per_batch) @@ -589,7 +679,7 @@ def mean_pairwise_squared_error(predictions, targets, weight=1.0, scope=None): term2 = 2.0 * _safe_div(math_ops.square(sum_diff), math_ops.square(num_present_per_batch)) - loss = _scale_losses(term1 - term2, weight) + loss = _scale_losses(term1 - term2, weights) mean_loss = math_ops.select(math_ops.reduce_sum(num_present_per_batch) > 0, loss, @@ -599,37 +689,47 @@ def mean_pairwise_squared_error(predictions, targets, weight=1.0, scope=None): return mean_loss -def cosine_distance(predictions, targets, dim, weight=1.0, scope=None): +@deprecated_args( + "2016-11-25", + "`targets` is being deprecated, use `labels`." + " `weight` is being deprecated, use `weights`.", + "targets", "weight") +def cosine_distance( + predictions, labels=None, dim=None, weights=_WEIGHT_SENTINEL, scope=None, + targets=None, weight=_WEIGHT_SENTINEL): """Adds a cosine-distance loss to the training procedure. - Note that the function assumes that the predictions and targets are already + Note that the function assumes that `predictions` and `labels` are already unit-normalized. Args: predictions: An arbitrary matrix. - targets: A `Tensor` whose shape matches 'predictions' + labels: A `Tensor` whose shape matches 'predictions' dim: The dimension along which the cosine distance is computed. - weight: Coefficients for the loss a scalar, a tensor of shape + weights: Coefficients for the loss a scalar, a tensor of shape [batch_size] or a tensor whose shape matches `predictions`. scope: The scope for the operations performed in computing the loss. + targets: Deprecated alias for `labels`. + weight: Deprecated alias for `weights`. Returns: A scalar `Tensor` representing the loss value. Raises: - ValueError: If predictions.shape doesn't match targets.shape, if the ignore - mask is provided and its shape doesn't match targets.shape or if - the ignore mask is not boolean valued. + ValueError: If `predictions` shape doesn't match `labels` shape, or + `weights` is `None`. """ + labels = _labels(labels, targets) + weights = _weights(weights, weight) + if dim is None: + raise ValueError("`dim` cannot be None.") with ops.name_scope(scope, "cosine_distance_loss", - [predictions, targets]) as scope: - predictions.get_shape().assert_is_compatible_with(targets.get_shape()) - if weight is None: - raise ValueError("`weight` cannot be None") + [predictions, labels, weights]) as scope: + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) predictions = math_ops.to_float(predictions) - targets = math_ops.to_float(targets) + labels = math_ops.to_float(labels) - radial_diffs = math_ops.mul(predictions, targets) + radial_diffs = math_ops.mul(predictions, labels) losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,]) - return compute_weighted_loss(losses, weight) + return compute_weighted_loss(losses, weights) diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py index 363caf4f3d..d96a63da6c 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py @@ -28,12 +28,16 @@ class AbsoluteDifferenceLossTest(tf.test.TestCase): def setUp(self): self._predictions = tf.constant([4, 8, 12, 8, 1, 3], shape=(2, 3)) - self._targets = tf.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + self._labels = tf.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) def testValueErrorThrownWhenWeightIsNone(self): with self.test_session(): with self.assertRaises(ValueError): tf.contrib.losses.absolute_difference( + self._predictions, self._predictions, weights=None) + # TODO(b/32171727): Remove when deprecated `weight` is removed. + with self.assertRaises(ValueError): + tf.contrib.losses.absolute_difference( self._predictions, self._predictions, weight=None) def testAllCorrectNoLossWeight(self): @@ -44,56 +48,79 @@ class AbsoluteDifferenceLossTest(tf.test.TestCase): def testNonZeroLoss(self): loss = tf.contrib.losses.absolute_difference( - self._predictions, self._targets) + self._predictions, self._labels) with self.test_session(): self.assertAlmostEqual(5.5, loss.eval(), 3) def testNonZeroLossWithPythonScalarWeight(self): - weight = 2.3 + weights = 2.3 loss = tf.contrib.losses.absolute_difference( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): - self.assertAlmostEqual(5.5 * weight, loss.eval(), 3) + self.assertAlmostEqual(5.5 * weights, loss.eval(), 3) + + # TODO(b/32171727): Remove when deprecated args are removed. + def testDeprecatedArgs(self): + w = 2.3 + weighted_loss_tensors = ( + tf.contrib.losses.absolute_difference( + self._predictions, self._labels, w), + tf.contrib.losses.absolute_difference( + self._predictions, self._labels, weights=w), + tf.contrib.losses.absolute_difference( + self._predictions, self._labels, weight=w), + tf.contrib.losses.absolute_difference( + self._predictions, labels=self._labels, weights=w), + tf.contrib.losses.absolute_difference( + self._predictions, labels=self._labels, weight=w), + tf.contrib.losses.absolute_difference( + self._predictions, targets=self._labels, weights=w), + tf.contrib.losses.absolute_difference( + self._predictions, targets=self._labels, weight=w), + ) + with self.test_session(): + for weighted_loss_tensor in weighted_loss_tensors: + self.assertAlmostEqual(5.5 * w, weighted_loss_tensor.eval(), 3) def testNonZeroLossWithScalarTensorWeight(self): - weight = 2.3 + weights = 2.3 loss = tf.contrib.losses.absolute_difference( - self._predictions, self._targets, tf.constant(weight)) + self._predictions, self._labels, tf.constant(weights)) with self.test_session(): - self.assertAlmostEqual(5.5 * weight, loss.eval(), 3) + self.assertAlmostEqual(5.5 * weights, loss.eval(), 3) def testNonZeroLossWithOneDimBatchSpecificWeights(self): - weight = tf.constant([1.2, 0.0], shape=[2,]) + weights = tf.constant([1.2, 0.0], shape=[2,]) loss = tf.contrib.losses.absolute_difference( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): self.assertAlmostEqual(5.6, loss.eval(), 3) def testNonZeroLossWithTwoDimBatchSpecificWeights(self): - weight = tf.constant([1.2, 0.0], shape=[2, 1]) + weights = tf.constant([1.2, 0.0], shape=[2, 1]) loss = tf.contrib.losses.absolute_difference( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): self.assertAlmostEqual(5.6, loss.eval(), 3) def testNonZeroLossWithSampleSpecificWeights(self): - weight = tf.constant([3, 6, 5, 0, 4, 2], shape=[2, 3]) + weights = tf.constant([3, 6, 5, 0, 4, 2], shape=[2, 3]) loss = tf.contrib.losses.absolute_difference( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): self.assertAlmostEqual(16.6, loss.eval(), 3) def testNonZeroLossWithSampleSpecificWeightsMostZero(self): - weight = tf.constant([0, 0, 0, 0, 0, 2], shape=[2, 3]) + weights = tf.constant([0, 0, 0, 0, 0, 2], shape=[2, 3]) loss = tf.contrib.losses.absolute_difference( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): self.assertAlmostEqual(6.0, loss.eval(), 3) def testLossWithSampleSpecificWeightsAllZero(self): - weight = tf.zeros((2, 3)) + weights = tf.zeros((2, 3)) loss = tf.contrib.losses.absolute_difference( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) @@ -109,6 +136,9 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase): [0, 0, 1]]) with self.test_session(): with self.assertRaises(ValueError): + tf.contrib.losses.softmax_cross_entropy(logits, labels, weights=None) + # TODO(b/32171727): Remove when deprecated `weight` is removed. + with self.assertRaises(ValueError): tf.contrib.losses.softmax_cross_entropy(logits, labels, weight=None) def testAllCorrect(self): @@ -120,7 +150,7 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase): [0, 1, 0], [0, 0, 1]]) loss = tf.contrib.losses.softmax_cross_entropy(logits, labels) - self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value') + self.assertEquals('softmax_cross_entropy_loss/value', loss.op.name) self.assertAlmostEqual(loss.eval(), 0.0, 3) def testAllWrong(self): @@ -143,10 +173,31 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase): labels = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) - weight = 2.3 + weights = 2.3 with self.test_session(): - loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight) - self.assertAlmostEqual(loss.eval(), weight * 10.0, 3) + loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weights) + self.assertAlmostEqual(weights * 10.0, loss.eval(), 3) + + # TODO(b/32171727): Remove when deprecated args are removed. + def testDeprecatedArgs(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + w = 2.3 + with self.test_session(): + loss_tensors = ( + tf.contrib.losses.softmax_cross_entropy(logits, labels, w), + tf.contrib.losses.softmax_cross_entropy(logits, labels, weights=w), + tf.contrib.losses.softmax_cross_entropy(logits, labels, weight=w), + tf.contrib.losses.softmax_cross_entropy( + logits, onehot_labels=labels, weights=w), + tf.contrib.losses.softmax_cross_entropy( + logits, onehot_labels=labels, weight=w)) + for loss_tensor in loss_tensors: + self.assertAlmostEqual(w * 10.0, loss_tensor.eval(), 3) def testNonZeroLossWithScalarTensorWeight(self): logits = tf.constant([[10.0, 0.0, 0.0], @@ -155,11 +206,11 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase): labels = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) - weight = 2.3 + weights = 2.3 with self.test_session(): loss = tf.contrib.losses.softmax_cross_entropy( - logits, labels, tf.constant(weight)) - self.assertAlmostEqual(loss.eval(), weight * 10.0, 3) + logits, labels, tf.constant(weights)) + self.assertAlmostEqual(weights * 10.0, loss.eval(), 3) def testNonZeroLossWithOneDimBatchSpecificWeights(self): logits = tf.constant([[10.0, 0.0, 0.0], @@ -168,10 +219,10 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase): labels = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) - weight = tf.constant([1.2, 3.4, 5.6], shape=[3]) + weights = tf.constant([1.2, 3.4, 5.6], shape=[3]) with self.test_session(): - loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight) - self.assertAlmostEqual(loss.eval(), (1.2 + 3.4 + 5.6) * 10.0 / 3.0, 3) + loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weights) + self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3) def testAllWrongAllWeightsMissing(self): logits = tf.constant([[10.0, 0.0, 0.0], @@ -180,10 +231,10 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase): labels = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) - weight = tf.constant([0, 0, 0], shape=[3]) + weights = tf.constant([0, 0, 0], shape=[3]) with self.test_session(): - loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight) - self.assertAlmostEqual(loss.eval(), 0.0, 3) + loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weights) + self.assertAlmostEqual(0.0, loss.eval(), 3) def testSomeWeightsMissing(self): logits = tf.constant([[10.0, 0.0, 0.0], @@ -192,10 +243,10 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase): labels = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) - weight = tf.constant([1.2, 0, 0], shape=[3]) + weights = tf.constant([1.2, 0, 0], shape=[3]) with self.test_session(): - loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weight) - self.assertAlmostEqual(loss.eval(), 12.0, 3) + loss = tf.contrib.losses.softmax_cross_entropy(logits, labels, weights) + self.assertAlmostEqual(12.0, loss.eval(), 3) def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self): with self.test_session(): @@ -205,13 +256,17 @@ class SoftmaxCrossEntropyLossTest(tf.test.TestCase): labels = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - weight = tf.constant([[3, 4, 5], - [2, 6, 0], - [8, 0, 1]]) + weights = tf.constant([[3, 4, 5], + [2, 6, 0], + [8, 0, 1]]) with self.assertRaises(ValueError): tf.contrib.losses.softmax_cross_entropy( - logits, labels, weight=weight).eval() + logits, labels, weights=weights).eval() + # TODO(b/32171727): Remove when deprecated `weight` is removed. + with self.assertRaises(ValueError): + tf.contrib.losses.softmax_cross_entropy( + logits, labels, weight=weights).eval() def testSoftmaxLabelSmoothing(self): with self.test_session(): @@ -245,6 +300,10 @@ class SparseSoftmaxCrossEntropyLossTest(tf.test.TestCase): with self.test_session(): with self.assertRaises(ValueError): tf.contrib.losses.sparse_softmax_cross_entropy( + logits, labels, weights=None) + # TODO(b/32171727): Remove when deprecated `weight` is removed. + with self.assertRaises(ValueError): + tf.contrib.losses.sparse_softmax_cross_entropy( logits, labels, weight=None) def testAllCorrectInt32Labels(self): @@ -315,66 +374,88 @@ class SparseSoftmaxCrossEntropyLossTest(tf.test.TestCase): [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) labels = tf.constant([[2], [0], [1]]) - weight = 2.3 + weights = 2.3 with self.test_session(): loss = tf.contrib.losses.sparse_softmax_cross_entropy( - logits, labels, weight) - self.assertAlmostEqual(loss.eval(), weight * 10.0, 3) + logits, labels, weights) + self.assertAlmostEqual(weights * 10.0, loss.eval(), 3) + + # TODO(b/32171727): Remove when deprecated args are removed. + def testDeprecatedArgs(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[2], [0], [1]]) + w = 2.3 + with self.test_session(): + loss_tensors = ( + tf.contrib.losses.sparse_softmax_cross_entropy( + logits, labels, w), + tf.contrib.losses.sparse_softmax_cross_entropy( + logits, labels, weights=w), + tf.contrib.losses.sparse_softmax_cross_entropy( + logits, labels, weight=w), + tf.contrib.losses.sparse_softmax_cross_entropy( + logits, labels=labels, weights=w), + tf.contrib.losses.sparse_softmax_cross_entropy( + logits, labels=labels, weight=w)) + for loss_tensor in loss_tensors: + self.assertAlmostEqual(w * 10.0, loss_tensor.eval(), 3) def testNonZeroLossWithScalarTensorWeight(self): logits = tf.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) labels = tf.constant([[2], [0], [1]]) - weight = 2.3 + weights = 2.3 with self.test_session(): loss = tf.contrib.losses.sparse_softmax_cross_entropy( - logits, labels, tf.constant(weight)) - self.assertAlmostEqual(loss.eval(), weight * 10.0, 3) + logits, labels, tf.constant(weights)) + self.assertAlmostEqual(weights * 10.0, loss.eval(), 3) def testNonZeroLossWithOneDimBatchSpecificWeights(self): logits = tf.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) labels = tf.constant([[2], [0], [1]]) - weight = tf.constant([1.2, 3.4, 5.6], shape=[3]) + weights = tf.constant([1.2, 3.4, 5.6], shape=[3]) with self.test_session(): loss = tf.contrib.losses.sparse_softmax_cross_entropy( - logits, labels, weight) - self.assertAlmostEqual(loss.eval(), (1.2 + 3.4 + 5.6) * 10.0 / 3.0, 3) + logits, labels, weights) + self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3) def testNonZeroLossWithColumnWeights(self): logits = tf.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) labels = tf.constant([[2], [0], [1]]) - weight = tf.constant([[1.2], [3.4], [5.6]]) + weights = tf.constant([[1.2], [3.4], [5.6]]) with self.test_session(): loss = tf.contrib.losses.sparse_softmax_cross_entropy( - logits, labels, weight) - self.assertAlmostEqual(loss.eval(), (1.2 + 3.4 + 5.6) * 10.0 / 3.0, 3) + logits, labels, weights) + self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3) def testAllWrongAllWeightsMissing(self): logits = tf.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) labels = tf.constant([[2], [0], [1]]) - weight = tf.constant([0, 0, 0], shape=[3]) + weights = tf.constant([0, 0, 0], shape=[3]) with self.test_session(): loss = tf.contrib.losses.sparse_softmax_cross_entropy( - logits, labels, weight) - self.assertAlmostEqual(loss.eval(), 0.0, 3) + logits, labels, weights) + self.assertAlmostEqual(0.0, loss.eval(), 3) def testSomeWeightsMissing(self): logits = tf.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]) labels = tf.constant([[2], [0], [1]]) - weight = tf.constant([1.2, 0, 0], shape=[3]) + weights = tf.constant([1.2, 0, 0], shape=[3]) with self.test_session(): loss = tf.contrib.losses.sparse_softmax_cross_entropy( - logits, labels, weight) - self.assertAlmostEqual(loss.eval(), 12.0, 3) + logits, labels, weights) + self.assertAlmostEqual(12.0, loss.eval(), 3) def testMeasurementSpecificWeightsRaisesException(self): with self.test_session(): @@ -382,13 +463,17 @@ class SparseSoftmaxCrossEntropyLossTest(tf.test.TestCase): [-100.0, 100.0, -100.0], [-100.0, -100.0, 100.0]]) labels = tf.constant([[0], [1], [2]]) - weight = tf.constant([[3, 4, 5], - [2, 6, 0], - [8, 0, 1]]) + weights = tf.constant([[3, 4, 5], + [2, 6, 0], + [8, 0, 1]]) with self.assertRaises(ValueError): tf.contrib.losses.sparse_softmax_cross_entropy( - logits, labels, weight=weight).eval() + logits, labels, weights=weights).eval() + # TODO(b/32171727): Remove when deprecated `weight` is removed. + with self.assertRaises(ValueError): + tf.contrib.losses.sparse_softmax_cross_entropy( + logits, labels, weight=weights).eval() def testInconsistentWeightSizeRaisesException(self): """The weight tensor has incorrect number of elements.""" @@ -397,11 +482,15 @@ class SparseSoftmaxCrossEntropyLossTest(tf.test.TestCase): [-100.0, 100.0, -100.0], [-100.0, -100.0, 100.0]]) labels = tf.constant([[0], [1], [2]]) - weight = tf.constant([1.2, 3.4, 5.6, 7.8]) + weights = tf.constant([1.2, 3.4, 5.6, 7.8]) with self.assertRaises(ValueError): tf.contrib.losses.sparse_softmax_cross_entropy( - logits, labels, weight=weight).eval() + logits, labels, weights=weights).eval() + # TODO(b/32171727): Remove when deprecated `weight` is removed. + with self.assertRaises(ValueError): + tf.contrib.losses.sparse_softmax_cross_entropy( + logits, labels, weight=weights).eval() def testInconsistentLabelSizeRaisesException(self): """The label tensor has incorrect number of elements.""" @@ -410,11 +499,15 @@ class SparseSoftmaxCrossEntropyLossTest(tf.test.TestCase): [-100.0, 100.0, -100.0], [-100.0, -100.0, 100.0]]) labels = tf.constant([[0], [1], [2], [3]]) - weight = tf.constant([1.2, 3.4, 5.6]) + weights = tf.constant([1.2, 3.4, 5.6]) with self.assertRaises(ValueError): tf.contrib.losses.sparse_softmax_cross_entropy( - logits, labels, weight=weight).eval() + logits, labels, weights=weights).eval() + # TODO(b/32171727): Remove when deprecated `weight` is removed. + with self.assertRaises(ValueError): + tf.contrib.losses.sparse_softmax_cross_entropy( + logits, labels, weight=weights).eval() def testInconsistentWeightShapeRaisesException(self): """The weight tensor has incorrect shape.""" @@ -424,11 +517,15 @@ class SparseSoftmaxCrossEntropyLossTest(tf.test.TestCase): [-100.0, -100.0, 100.0, -100.0], [-100.0, -100.0, -100.0, 100.0]]) labels = tf.constant([[0], [1], [2], [3]]) - weight = tf.constant([[1.2, 3.4], [5.6, 7.8]]) + weights = tf.constant([[1.2, 3.4], [5.6, 7.8]]) with self.assertRaises(ValueError): tf.contrib.losses.sparse_softmax_cross_entropy( - logits, labels, weight=weight).eval() + logits, labels, weights=weights).eval() + # TODO(b/32171727): Remove when deprecated `weight` is removed. + with self.assertRaises(ValueError): + tf.contrib.losses.sparse_softmax_cross_entropy( + logits, labels, weight=weights).eval() def testInconsistentLabelShapeRaisesException(self): """The label tensor has incorrect shape.""" @@ -438,11 +535,15 @@ class SparseSoftmaxCrossEntropyLossTest(tf.test.TestCase): [-100.0, -100.0, 100.0, -100.0], [-100.0, -100.0, -100.0, 100.0]]) labels = tf.constant([[0, 1], [2, 3]]) - weight = tf.constant([1.2, 3.4, 5.6, 7.8]) + weights = tf.constant([1.2, 3.4, 5.6, 7.8]) with self.assertRaises(tf.errors.InvalidArgumentError): tf.contrib.losses.sparse_softmax_cross_entropy( - logits, labels, weight=weight).eval() + logits, labels, weights=weights).eval() + # TODO(b/32171727): Remove when deprecated `weight` is removed. + with self.assertRaises(tf.errors.InvalidArgumentError): + tf.contrib.losses.sparse_softmax_cross_entropy( + logits, labels, weight=weights).eval() class SigmoidCrossEntropyLossTest(tf.test.TestCase): @@ -457,35 +558,35 @@ class SigmoidCrossEntropyLossTest(tf.test.TestCase): [0, 0, 1]]) loss = tf.contrib.losses.sigmoid_cross_entropy(logits, labels) self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value') - self.assertAlmostEqual(loss.eval(), 0.0, 3) + self.assertAlmostEqual(0.0, loss.eval(), 3) def testLossWithSingleDimPlaceholderForLogitsAndWeights1(self): logits = tf.placeholder(tf.float32, shape=(None, 1)) labels = tf.placeholder(tf.float32, shape=(None, 1)) - weight = tf.ones_like(logits, dtype=tf.float32) + weights = tf.ones_like(logits, dtype=tf.float32) - loss = tf.contrib.losses.sigmoid_cross_entropy(logits, labels, weight) + loss = tf.contrib.losses.sigmoid_cross_entropy(logits, labels, weights) with self.test_session() as sess: loss = sess.run(loss, feed_dict={ logits: np.ones((32, 1)), labels: np.ones((32, 1)), }) - self.assertAlmostEqual(loss, 0.313, 3) + self.assertAlmostEqual(0.313, loss, 3) def testLossWithSingleDimPlaceholderForLogitsAndWeights2(self): logits = tf.placeholder(tf.float32, shape=(None, 2)) labels = tf.placeholder(tf.float32, shape=(None, 2)) - weight = tf.ones_like(logits, dtype=tf.float32) + weights = tf.ones_like(logits, dtype=tf.float32) - loss = tf.contrib.losses.sigmoid_cross_entropy(logits, labels, weight) + loss = tf.contrib.losses.sigmoid_cross_entropy(logits, labels, weights) with self.test_session() as sess: loss = sess.run(loss, feed_dict={ logits: np.ones((32, 2)), labels: np.ones((32, 2)), }) - self.assertAlmostEqual(loss, 0.313, 3) + self.assertAlmostEqual(0.313, loss, 3) def testAllWrongSigmoid(self): with self.test_session(): @@ -507,13 +608,13 @@ class SigmoidCrossEntropyLossTest(tf.test.TestCase): labels = tf.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) - weight = tf.constant([[3, 4, 5], - [2, 6, 0], - [8, 0, 1]]) + weights = tf.constant([[3, 4, 5], + [2, 6, 0], + [8, 0, 1]]) loss = tf.contrib.losses.sigmoid_cross_entropy( - logits, labels, weight=weight) + logits, labels, weights) self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value') - self.assertAlmostEqual(loss.eval(), 1700.0 / 7.0, 3) + self.assertAlmostEqual(1700.0 / 7.0, loss.eval(), 3) def testMultiCorrectSigmoid(self): logits = tf.constant([[100.0, -100.0, 100.0], @@ -569,170 +670,193 @@ class LogLossTest(tf.test.TestCase): def setUp(self): predictions = np.asarray([.9, .2, .2, .8, .4, .6]).reshape((2, 3)) - targets = np.asarray([1.0, 0.0, 1.0, 1.0, 0.0, 0.0]).reshape((2, 3)) + labels = np.asarray([1.0, 0.0, 1.0, 1.0, 0.0, 0.0]).reshape((2, 3)) self._np_predictions = predictions - self._np_targets = targets + self._np_labels = labels epsilon = 1e-7 self._expected_losses = np.multiply( - targets, np.log(predictions + epsilon)) + np.multiply( - 1 - targets, np.log(1 - predictions + epsilon)) + labels, np.log(predictions + epsilon)) + np.multiply( + 1 - labels, np.log(1 - predictions + epsilon)) self._predictions = tf.constant(predictions) - self._targets = tf.constant(targets) + self._labels = tf.constant(labels) def testValueErrorThrownWhenWeightIsNone(self): with self.test_session(): with self.assertRaises(ValueError): - tf.contrib.losses.log_loss(self._targets, self._targets, weight=None) + tf.contrib.losses.log_loss(self._labels, self._labels, weights=None) + # TODO(b/32171727): Remove when deprecated `weight` is removed. + with self.assertRaises(ValueError): + tf.contrib.losses.log_loss(self._labels, self._labels, weight=None) def testAllCorrectNoLossWeight(self): - loss = tf.contrib.losses.log_loss(self._targets, self._targets) + loss = tf.contrib.losses.log_loss(self._labels, self._labels) with self.test_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) def testAllCorrectNoLossWeightWithPlaceholder(self): - tf_predictions = tf.placeholder(tf.float32, shape=self._np_targets.shape) - loss = tf.contrib.losses.log_loss(tf_predictions, self._targets) + tf_predictions = tf.placeholder(tf.float32, shape=self._np_labels.shape) + loss = tf.contrib.losses.log_loss(tf_predictions, self._labels) with self.test_session(): self.assertAlmostEqual(0.0, loss.eval(feed_dict={ - tf_predictions: self._np_targets}), 3) + tf_predictions: self._np_labels}), 3) def testNonZeroLoss(self): - loss = tf.contrib.losses.log_loss(self._predictions, self._targets) + loss = tf.contrib.losses.log_loss(self._predictions, self._labels) with self.test_session(): self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0, loss.eval(), 3) def testNonZeroLossWithPythonScalarWeight(self): - weight = 2.3 + weights = 2.3 loss = tf.contrib.losses.log_loss( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): - self.assertAlmostEqual(weight * -np.sum(self._expected_losses) / 6.0, + self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, loss.eval(), 3) + # TODO(b/32171727): Remove when deprecated args are removed. + def testDeprecatedArgs(self): + w = 2.3 + loss_tensors = ( + tf.contrib.losses.log_loss(self._predictions, self._labels, w), + tf.contrib.losses.log_loss(self._predictions, self._labels, weights=w), + tf.contrib.losses.log_loss(self._predictions, self._labels, weight=w), + tf.contrib.losses.log_loss( + self._predictions, labels=self._labels, weights=w), + tf.contrib.losses.log_loss( + self._predictions, labels=self._labels, weight=w), + tf.contrib.losses.log_loss( + self._predictions, targets=self._labels, weights=w), + tf.contrib.losses.log_loss( + self._predictions, targets=self._labels, weight=w)) + with self.test_session(): + for loss_tensor in loss_tensors: + self.assertAlmostEqual( + w * -np.sum(self._expected_losses) / 6.0, loss_tensor.eval(), 3) + def testNonZeroLossWithScalarTensorWeight(self): - weight = 2.3 + weights = 2.3 loss = tf.contrib.losses.log_loss( - self._predictions, self._targets, tf.constant(weight)) + self._predictions, self._labels, tf.constant(weights)) with self.test_session(): - self.assertAlmostEqual(weight * -np.sum(self._expected_losses) / 6.0, + self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, loss.eval(), 3) def testNonZeroLossWithScalarTensorWeightAndPlaceholder(self): tf_predictions = tf.placeholder(tf.float32, shape=self._np_predictions.shape) - weight = 2.3 + weights = 2.3 loss = tf.contrib.losses.log_loss( - tf_predictions, self._targets, tf.constant(weight)) + tf_predictions, self._labels, tf.constant(weights)) with self.test_session() as sess: loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) - self.assertAlmostEqual(weight * -np.sum(self._expected_losses) / 6.0, + self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, loss, 3) def testNonZeroLossWithScalarTensorWeightAndPlaceholderWithRankOnly(self): tf_predictions = tf.placeholder(tf.float32, shape=[None, None]) - weight = 2.3 + weights = 2.3 loss = tf.contrib.losses.log_loss( - tf_predictions, self._targets, tf.constant(weight)) + tf_predictions, self._labels, tf.constant(weights)) with self.test_session() as sess: loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) - self.assertAlmostEqual(weight * -np.sum(self._expected_losses) / 6.0, + self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, loss, 3) def testNonZeroLossWithOneDimBatchSpecificWeights(self): - weight = tf.constant([1.2, 3.4], shape=[2]) + weights = tf.constant([1.2, 3.4], shape=[2]) expected_losses = np.multiply( self._expected_losses, np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3))) loss = tf.contrib.losses.log_loss( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): self.assertAlmostEqual(-np.sum(expected_losses) / 6.0, loss.eval(), 3) def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self): - weight = tf.constant([1.2, 0], shape=[2]) + weights = tf.constant([1.2, 0], shape=[2]) expected_losses = np.multiply( self._expected_losses, np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape((2, 3))) loss = tf.contrib.losses.log_loss( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3) def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self): - weight = tf.constant([1.2, 0], shape=[2, 1]) + weights = tf.constant([1.2, 0], shape=[2, 1]) expected_losses = np.multiply( self._expected_losses, np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape((2, 3))) loss = tf.contrib.losses.log_loss( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3) def testWeightsWithSameNumDimsButWrongShapeThrowsException(self): - weight = tf.constant(np.random.normal(size=(2, 4)), shape=[2, 4]) + weights = tf.constant(np.random.normal(size=(2, 4)), shape=[2, 4]) with self.test_session(): with self.assertRaises(ValueError): - tf.contrib.losses.log_loss(self._predictions, self._targets, weight) + tf.contrib.losses.log_loss(self._predictions, self._labels, weights) def testNonZeroLossWithMeasurementSpecificWeights(self): - weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3)) - expected_losses = np.multiply(self._expected_losses, weight) + weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3)) + expected_losses = np.multiply(self._expected_losses, weights) loss = tf.contrib.losses.log_loss( self._predictions, - self._targets, - weight=tf.constant(weight, shape=(2, 3))) + self._labels, + tf.constant(weights, shape=(2, 3))) with self.test_session(): self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss.eval(), 3) def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self): - weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3)) - expected_losses = np.multiply(self._expected_losses, weight) + weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3)) + expected_losses = np.multiply(self._expected_losses, weights) tf_predictions = tf.placeholder(tf.float32, shape=[2, 3]) loss = tf.contrib.losses.log_loss( tf_predictions, - self._targets, - weight=tf.constant(weight, shape=(2, 3))) + self._labels, + tf.constant(weights, shape=(2, 3))) with self.test_session() as sess: loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3) def testNonZeroLossWithSampleSpecificWeightsMostZero(self): - weight = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3)) - expected_losses = np.multiply(self._expected_losses, weight) + weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3)) + expected_losses = np.multiply(self._expected_losses, weights) loss = tf.contrib.losses.log_loss( self._predictions, - self._targets, - weight=tf.constant(weight, shape=(2, 3))) + self._labels, + tf.constant(weights, shape=(2, 3))) with self.test_session(): self.assertAlmostEqual(-np.sum(expected_losses), loss.eval(), 3) def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self): - weight = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3)) - expected_losses = np.multiply(self._expected_losses, weight) + weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3)) + expected_losses = np.multiply(self._expected_losses, weights) tf_predictions = tf.placeholder(tf.float32, shape=[2, 3]) - tf_weight = tf.constant(weight, shape=(2, 3)) - loss = tf.contrib.losses.log_loss(tf_predictions, self._targets, tf_weight) + tf_weights = tf.constant(weights, shape=(2, 3)) + loss = tf.contrib.losses.log_loss(tf_predictions, self._labels, tf_weights) with self.test_session() as sess: loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) self.assertAlmostEqual(-np.sum(expected_losses), loss, 3) def testLossWithSampleSpecificWeightsAllZero(self): - tf_weight = tf.zeros(shape=(2, 3)) + tf_weights = tf.zeros(shape=(2, 3)) loss = tf.contrib.losses.log_loss( - self._predictions, self._targets, tf_weight) + self._predictions, self._labels, tf_weights) with self.test_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) @@ -742,22 +866,22 @@ class HingeLossTest(tf.test.TestCase): def testIncompatibleShapes(self): with self.test_session(): logits = tf.constant([[-1.0], [2.1]]) - target = tf.constant([0.0, 1.0]) + labels = tf.constant([0.0, 1.0]) with self.assertRaises(ValueError): - _ = tf.contrib.losses.hinge_loss(logits, target).eval() + _ = tf.contrib.losses.hinge_loss(logits, labels).eval() def testAllOutsideMargin(self): with self.test_session(): logits = tf.constant([1.2, -1.4, -1.0, 2.1]) - target = tf.constant([1.0, 0.0, 0.0, 1.0]) - loss = tf.contrib.losses.hinge_loss(logits, target) + labels = tf.constant([1.0, 0.0, 0.0, 1.0]) + loss = tf.contrib.losses.hinge_loss(logits, labels) self.assertAllClose(loss.eval(), [0.0, 0.0, 0.0, 0.0], atol=1e-3) def testSomeInsideMargin(self): with self.test_session(): logits = tf.constant([[-0.7], [-1.4], [1.4], [0.6]]) - target = tf.constant([[0.0], [0.0], [1.0], [1.0]]) - loss = tf.contrib.losses.hinge_loss(logits, target) + labels = tf.constant([[0.0], [0.0], [1.0], [1.0]]) + loss = tf.contrib.losses.hinge_loss(logits, labels) # Examples 1 and 4 are on the correct side of the hyperplane but within # the margin so they incur some (small) loss. self.assertAllClose(loss.eval(), [[0.3], [0.0], [0.0], [0.4]], atol=1e-3) @@ -765,8 +889,8 @@ class HingeLossTest(tf.test.TestCase): def testSomeMisclassified(self): with self.test_session(): logits = tf.constant([[[1.2], [0.4], [-1.0], [-1.1]]]) - target = tf.constant([[[1.0], [0.0], [0.0], [1.0]]]) - loss = tf.contrib.losses.hinge_loss(logits, target) + labels = tf.constant([[[1.0], [0.0], [0.0], [1.0]]]) + loss = tf.contrib.losses.hinge_loss(logits, labels) # Examples 2 and 4 are on the wrong side of the hyperplane so they incur # some (fairly large) loss. self.assertAllClose( @@ -777,12 +901,15 @@ class MeanSquaredErrorTest(tf.test.TestCase): def setUp(self): self._predictions = tf.constant([4, 8, 12, 8, 1, 3], shape=(2, 3)) - self._targets = tf.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + self._labels = tf.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) def testValueErrorThrownWhenWeightIsNone(self): with self.test_session(): with self.assertRaises(ValueError): tf.contrib.losses.mean_squared_error( + self._predictions, self._predictions, weights=None) + # TODO(b/32171727): Remove when deprecated `weight` is removed. + tf.contrib.losses.mean_squared_error( self._predictions, self._predictions, weight=None) def testAllCorrectNoLossWeight(self): @@ -793,56 +920,78 @@ class MeanSquaredErrorTest(tf.test.TestCase): def testNonZeroLoss(self): loss = tf.contrib.losses.mean_squared_error( - self._predictions, self._targets) + self._predictions, self._labels) with self.test_session(): self.assertAlmostEqual(49.5, loss.eval(), 3) def testNonZeroLossWithPythonScalarWeight(self): - weight = 2.3 + weights = 2.3 loss = tf.contrib.losses.mean_squared_error( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): - self.assertAlmostEqual(49.5 * weight, loss.eval(), 3) + self.assertAlmostEqual(49.5 * weights, loss.eval(), 3) def testNonZeroLossWithScalarTensorWeight(self): - weight = 2.3 + weights = 2.3 loss = tf.contrib.losses.mean_squared_error( - self._predictions, self._targets, tf.constant(weight)) + self._predictions, self._labels, tf.constant(weights)) + with self.test_session(): + self.assertAlmostEqual(49.5 * weights, loss.eval(), 3) + + # TODO(b/32171727): Remove when deprecated args are removed. + def testDeprecatedArgs(self): + w = 2.3 + loss_tensors = ( + tf.contrib.losses.mean_squared_error( + self._predictions, self._labels, tf.constant(w)), + tf.contrib.losses.mean_squared_error( + self._predictions, self._labels, weights=tf.constant(w)), + tf.contrib.losses.mean_squared_error( + self._predictions, self._labels, weight=tf.constant(w)), + tf.contrib.losses.mean_squared_error( + self._predictions, labels=self._labels, weights=tf.constant(w)), + tf.contrib.losses.mean_squared_error( + self._predictions, labels=self._labels, weight=tf.constant(w)), + tf.contrib.losses.mean_squared_error( + self._predictions, targets=self._labels, weights=tf.constant(w)), + tf.contrib.losses.mean_squared_error( + self._predictions, targets=self._labels, weight=tf.constant(w))) with self.test_session(): - self.assertAlmostEqual(49.5 * weight, loss.eval(), 3) + for loss_tensor in loss_tensors: + self.assertAlmostEqual(49.5 * w, loss_tensor.eval(), 3) def testNonZeroLossWithOneDimBatchSpecificWeights(self): - weight = tf.constant([1.2, 3.4], shape=[2,]) + weights = tf.constant([1.2, 3.4], shape=[2,]) loss = tf.contrib.losses.mean_squared_error( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3) def testNonZeroLossWithTwoDimBatchSpecificWeights(self): - weight = tf.constant([1.2, 3.4], shape=[2, 1]) + weights = tf.constant([1.2, 3.4], shape=[2, 1]) loss = tf.contrib.losses.mean_squared_error( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3) def testNonZeroLossWithSampleSpecificWeights(self): - weight = tf.constant([3, 6, 5, 0, 4, 2], shape=[2, 3]) + weights = tf.constant([3, 6, 5, 0, 4, 2], shape=[2, 3]) loss = tf.contrib.losses.mean_squared_error( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): self.assertAlmostEqual(587 / 5.0, loss.eval(), 3) def testNonZeroLossWithSampleSpecificWeightsMostZero(self): - weight = tf.constant([0, 0, 0, 0, 0, 2], shape=[2, 3]) + weights = tf.constant([0, 0, 0, 0, 0, 2], shape=[2, 3]) loss = tf.contrib.losses.mean_squared_error( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): self.assertAlmostEqual(18.0, loss.eval(), 3) def testLossWithSampleSpecificWeightsAllZero(self): - weight = tf.zeros((2, 3)) + weights = tf.zeros((2, 3)) loss = tf.contrib.losses.mean_squared_error( - self._predictions, self._targets, weight) + self._predictions, self._labels, weights) with self.test_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) @@ -852,10 +1001,10 @@ class MeanPairwiseSquaresErrorTest(tf.test.TestCase): def setUp(self): self._predictions = np.array([[4, 8, 12], [8, 1, 3]]) - self._targets = np.array([[1, 9, 2], - [-5, -5, 7]]) + self._labels = np.array([[1, 9, 2], + [-5, -5, 7]]) - batch_size, dims = self._targets.shape + batch_size, dims = self._labels.shape # Compute the expected loss 'manually'. total = np.zeros((batch_size, 1)) @@ -863,7 +1012,7 @@ class MeanPairwiseSquaresErrorTest(tf.test.TestCase): for i in range(dims): for j in range(dims): x = self._predictions[b, i].item() - self._predictions[b, j].item() - y = self._targets[b, i].item() - self._targets[b, j].item() + y = self._labels[b, i].item() - self._labels[b, j].item() tmp = (x-y) * (x-y) total[b] += tmp @@ -873,21 +1022,27 @@ class MeanPairwiseSquaresErrorTest(tf.test.TestCase): with self.test_session(): with self.assertRaises(ValueError): tf.contrib.losses.mean_pairwise_squared_error( - predictions=tf.constant(self._targets), - targets=tf.constant(self._targets), + predictions=tf.constant(self._labels), + labels=tf.constant(self._labels), + weights=None) + # TODO(b/32171727): Remove when deprecated `weight` is removed. + with self.assertRaises(ValueError): + tf.contrib.losses.mean_pairwise_squared_error( + predictions=tf.constant(self._labels), + labels=tf.constant(self._labels), weight=None) def testAllCorrectNoLossWeight(self): loss = tf.contrib.losses.mean_pairwise_squared_error( - predictions=tf.constant(self._targets), - targets=tf.constant(self._targets)) + predictions=tf.constant(self._labels), + labels=tf.constant(self._labels)) with self.test_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) def testNonZeroLoss(self): loss = tf.contrib.losses.mean_pairwise_squared_error( predictions=tf.constant(self._predictions), - targets=tf.constant(self._targets)) + labels=tf.constant(self._labels)) with self.test_session(): self.assertAlmostEqual(np.sum(self._expected_losses), loss.eval(), 3) @@ -918,93 +1073,121 @@ class MeanPairwiseSquaresErrorTest(tf.test.TestCase): self.assertFalse(np.isnan(np_grad).any()) def testNonZeroLossWithPythonScalarWeight(self): - weight = 2.3 + weights = 2.3 loss = tf.contrib.losses.mean_pairwise_squared_error( predictions=tf.constant(self._predictions), - targets=tf.constant(self._targets), - weight=weight) + labels=tf.constant(self._labels), + weights=weights) with self.test_session(): - self.assertAlmostEqual(weight * np.sum(self._expected_losses), + self.assertAlmostEqual(weights * np.sum(self._expected_losses), loss.eval(), 3) def testNonZeroLossWithScalarTensorWeight(self): - weight = 2.3 + weights = 2.3 loss = tf.contrib.losses.mean_pairwise_squared_error( predictions=tf.constant(self._predictions), - targets=tf.constant(self._targets), - weight=tf.constant(weight)) + labels=tf.constant(self._labels), + weights=tf.constant(weights)) with self.test_session(): - self.assertAlmostEqual(weight * np.sum(self._expected_losses), + self.assertAlmostEqual(weights * np.sum(self._expected_losses), loss.eval(), 3) def testNonZeroLossWithScalarZeroWeight(self): - weight = 0 + weights = 0 loss = tf.contrib.losses.mean_pairwise_squared_error( predictions=tf.constant(self._predictions), - targets=tf.constant(self._targets), - weight=tf.constant(weight)) + labels=tf.constant(self._labels), + weights=tf.constant(weights)) with self.test_session(): self.assertAlmostEqual(0, loss.eval(), 3) def testNonZeroLossWithScalarTensorWeightWithPlaceholder(self): - weight = 2.3 + weights = 2.3 tf_predictions = tf.placeholder(tf.float32, shape=self._predictions.shape) - tf_targets = tf.placeholder(tf.float32, shape=self._targets.shape) + tf_labels = tf.placeholder(tf.float32, shape=self._labels.shape) loss = tf.contrib.losses.mean_pairwise_squared_error( predictions=tf_predictions, - targets=tf_targets, - weight=tf.constant(weight)) + labels=tf_labels, + weights=tf.constant(weights)) with self.test_session() as sess: loss = sess.run(loss, feed_dict={ tf_predictions: self._predictions, - tf_targets: self._targets, + tf_labels: self._labels, }) - self.assertAlmostEqual(weight * np.sum(self._expected_losses), loss, 3) + self.assertAlmostEqual(weights * np.sum(self._expected_losses), loss, 3) + + # TODO(b/32171727): Remove when deprecated args are removed. + def testDeprecatedArgs(self): + w = 2.3 + tf_predictions = tf.placeholder(tf.float32, shape=self._predictions.shape) + tf_labels = tf.placeholder(tf.float32, shape=self._labels.shape) + loss_tensors = ( + tf.contrib.losses.mean_pairwise_squared_error( + tf_predictions, tf_labels, tf.constant(w)), + tf.contrib.losses.mean_pairwise_squared_error( + tf_predictions, tf_labels, weights=tf.constant(w)), + tf.contrib.losses.mean_pairwise_squared_error( + tf_predictions, tf_labels, weight=tf.constant(w)), + tf.contrib.losses.mean_pairwise_squared_error( + tf_predictions, labels=tf_labels, weights=tf.constant(w)), + tf.contrib.losses.mean_pairwise_squared_error( + tf_predictions, labels=tf_labels, weight=tf.constant(w)), + tf.contrib.losses.mean_pairwise_squared_error( + tf_predictions, targets=tf_labels, weights=tf.constant(w)), + tf.contrib.losses.mean_pairwise_squared_error( + tf_predictions, targets=tf_labels, weight=tf.constant(w))) + with self.test_session() as sess: + for loss_tensor in loss_tensors: + loss = sess.run(loss_tensor, feed_dict={ + tf_predictions: self._predictions, + tf_labels: self._labels, + }) + self.assertAlmostEqual(w * np.sum(self._expected_losses), loss, 3) def testNonZeroLossWithOneDimBatchSpecificWeights(self): - weight = np.asarray([2.0, 1.0]).reshape((2, 1)) - expected_losses = np.multiply(weight, self._expected_losses) + weights = np.asarray([2.0, 1.0]).reshape((2, 1)) + expected_losses = np.multiply(weights, self._expected_losses) loss = tf.contrib.losses.mean_pairwise_squared_error( predictions=tf.constant(self._predictions), - targets=tf.constant(self._targets), - weight=tf.constant(weight, shape=[2])) + labels=tf.constant(self._labels), + weights=tf.constant(weights, shape=[2])) with self.test_session(): self.assertAlmostEqual(np.sum(expected_losses), loss.eval(), 3) def testZeroLossWithOneDimBatchZeroWeights(self): - weight = np.asarray([0.0, 0.0]).reshape((2, 1)) + weights = np.asarray([0.0, 0.0]).reshape((2, 1)) loss = tf.contrib.losses.mean_pairwise_squared_error( predictions=tf.constant(self._predictions), - targets=tf.constant(self._targets), - weight=tf.constant(weight, shape=[2])) + labels=tf.constant(self._labels), + weights=tf.constant(weights, shape=[2])) with self.test_session(): self.assertAlmostEqual(0, loss.eval(), 3) def testNonZeroLossWithOneDimBatchSpecificWeightsAndPlaceholders(self): - weight = np.asarray([1.2, 3.4]).reshape((2, 1)) - expected_losses = np.multiply(weight, self._expected_losses) + weights = np.asarray([1.2, 3.4]).reshape((2, 1)) + expected_losses = np.multiply(weights, self._expected_losses) tf_predictions = tf.placeholder(tf.float32, shape=self._predictions.shape) - tf_targets = tf.placeholder(tf.int32, shape=self._targets.shape) + tf_labels = tf.placeholder(tf.int32, shape=self._labels.shape) loss = tf.contrib.losses.mean_pairwise_squared_error( predictions=tf_predictions, - targets=tf_targets, - weight=tf.constant(weight, shape=[2])) + labels=tf_labels, + weights=tf.constant(weights, shape=[2])) with self.test_session() as sess: loss = sess.run(loss, feed_dict={ tf_predictions: self._predictions, - tf_targets: self._targets, + tf_labels: self._labels, }) self.assertAlmostEqual(np.sum(expected_losses), loss, 3) def testLossWithAllZeroBatchSpecificWeights(self): - weight = np.zeros((2, 1)) + weights = np.zeros((2, 1)) loss = tf.contrib.losses.mean_pairwise_squared_error( predictions=tf.constant(self._predictions), - targets=tf.constant(self._targets), - weight=tf.constant(weight, shape=[2])) + labels=tf.constant(self._labels), + weights=tf.constant(weights, shape=[2])) with self.test_session(): self.assertAlmostEqual(0.0, loss.eval(), 3) @@ -1019,26 +1202,33 @@ class CosineDistanceLossTest(tf.test.TestCase): [0, 0, -1], # Batch 3 [1, 0, 0]]).reshape((3, 2, 3)) - self._targets = np.asarray([[1, 0, 0], - [0, 0, 1], - [0, 1, 0], - [1, 0, 0], - [0, 0, 1], - [0, 1, 0]]).reshape((3, 2, 3)) + self._labels = np.asarray([[1, 0, 0], + [0, 0, 1], + [0, 1, 0], + [1, 0, 0], + [0, 0, 1], + [0, 1, 0]]).reshape((3, 2, 3)) def testValueErrorThrownWhenWeightIsNone(self): with self.test_session(): with self.assertRaises(ValueError): tf.contrib.losses.cosine_distance( - predictions=tf.constant(self._targets), - targets=tf.constant(self._targets), + predictions=tf.constant(self._labels), + labels=tf.constant(self._labels), + dim=2, + weights=None) + # TODO(b/32171727): Remove when deprecated `weight` is removed. + with self.assertRaises(ValueError): + tf.contrib.losses.cosine_distance( + predictions=tf.constant(self._labels), + labels=tf.constant(self._labels), dim=2, weight=None) def testAllCorrectNoWeights(self): loss = tf.contrib.losses.cosine_distance( - predictions=tf.constant(self._targets), - targets=tf.constant(self._targets), + predictions=tf.constant(self._labels), + labels=tf.constant(self._labels), dim=2) with self.test_session(): self.assertAlmostEqual(0, loss.eval(), 5) @@ -1046,7 +1236,7 @@ class CosineDistanceLossTest(tf.test.TestCase): def testPartiallyCorrectWithIntegerValues(self): loss = tf.contrib.losses.cosine_distance( predictions=tf.constant(self._predictions), - targets=tf.constant(self._targets), + labels=tf.constant(self._labels), dim=2) with self.test_session(): self.assertAlmostEqual(1, loss.eval(), 5) @@ -1056,14 +1246,14 @@ class CosineDistanceLossTest(tf.test.TestCase): '0.819031913261206 0.567041924552012 0.087465312324590;' '-0.665139432070255 -0.739487441769973 -0.103671883216994;' '0.707106781186548 -0.707106781186548 0')) - targets = np.matrix(( + labels = np.matrix(( '0.819031913261206 0.567041924552012 0.087465312324590;' '0.665139432070255 0.739487441769973 0.103671883216994;' '0.707106781186548 0.707106781186548 0')) tf_preds = tf.constant(predictions, shape=(3, 1, 3), dtype=tf.float32) - tf_targets = tf.constant(targets, shape=(3, 1, 3), dtype=tf.float32) - loss = tf.contrib.losses.cosine_distance(tf_preds, tf_targets, dim=2) + tf_labels = tf.constant(labels, shape=(3, 1, 3), dtype=tf.float32) + loss = tf.contrib.losses.cosine_distance(tf_preds, tf_labels, dim=2) with self.test_session(): self.assertAlmostEqual(1.0, loss.eval(), 5) @@ -1071,18 +1261,18 @@ class CosineDistanceLossTest(tf.test.TestCase): def testSampleSpecificWeights(self): loss = tf.contrib.losses.cosine_distance( predictions=tf.constant(self._predictions), - targets=tf.constant(self._targets), + labels=tf.constant(self._labels), dim=2, - weight=tf.constant([1, 0, 0])) + weights=tf.constant([1, 0, 0])) with self.test_session(): self.assertEqual(1.0, loss.eval()) def testMeasurementSpecificWeights(self): loss = tf.contrib.losses.cosine_distance( predictions=tf.constant(self._predictions), - targets=tf.constant(self._targets), + labels=tf.constant(self._labels), dim=2, - weight=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2))) + weights=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2))) with self.test_session(): self.assertEqual(3.0 / 4.0, loss.eval()) @@ -1092,17 +1282,17 @@ class CosineDistanceLossTest(tf.test.TestCase): with self.assertRaises(ValueError): tf.contrib.losses.cosine_distance( predictions=tf_predictions, - targets=tf.constant(self._targets), + labels=tf.constant(self._labels), dim=2, - weight=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2))) + weights=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2))) def testMeasurementSpecificWeightsWithPlaceholderWithShape(self): - tf_predictions = tf.placeholder(tf.float32, shape=self._targets.shape) + tf_predictions = tf.placeholder(tf.float32, shape=self._labels.shape) loss = tf.contrib.losses.cosine_distance( predictions=tf_predictions, - targets=tf.constant(self._targets), + labels=tf.constant(self._labels), dim=2, - weight=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2))) + weights=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2))) with self.test_session() as sess: loss = sess.run(loss, feed_dict={tf_predictions: self._predictions}) self.assertEqual(3.0 / 4.0, loss) @@ -1110,28 +1300,45 @@ class CosineDistanceLossTest(tf.test.TestCase): def testZeroLossWhenAllSampleSpecificWeightsAreZero(self): loss = tf.contrib.losses.cosine_distance( predictions=tf.constant(self._predictions), - targets=tf.constant(self._targets), + labels=tf.constant(self._labels), dim=2, - weight=tf.zeros((3,))) + weights=tf.zeros((3,))) with self.test_session(): self.assertEqual(0, loss.eval()) def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self): loss = tf.contrib.losses.cosine_distance( predictions=tf.constant(self._predictions), - targets=tf.constant(self._targets), + labels=tf.constant(self._labels), dim=2, - weight=tf.zeros((3, 2))) + weights=tf.zeros((3, 2))) with self.test_session(): self.assertEqual(0, loss.eval()) class ComputeWeightedLossTest(tf.test.TestCase): + # TODO(b/32171727): Remove when deprecated `weight` is removed. + def testDeprecatedArgs(self): + losses = (1.2, 0.4, -1.0, -1.1) + expected_loss = -0.5 / 4 + weights = 2.0 + expected_weighted_loss = weights * expected_loss + with self.test_session(): + loss_tensor = tf.contrib.losses.compute_weighted_loss(losses) + self.assertAllClose(expected_loss, loss_tensor.eval(), atol=1e-3) + weighted_loss_tensors = ( + tf.contrib.losses.compute_weighted_loss(losses, weights), + tf.contrib.losses.compute_weighted_loss(losses, weights=weights), + tf.contrib.losses.compute_weighted_loss(losses, weight=weights)) + for weighted_loss_tensor in weighted_loss_tensors: + self.assertAllClose( + expected_weighted_loss, weighted_loss_tensor.eval(), atol=1e-3) + def testHingeLoss(self): logits = tf.constant([1.2, 0.4, -1.0, -1.1]) - target = tf.constant([1.0, 0.0, 0.0, 1.0]) - losses = tf.contrib.losses.hinge_loss(logits, target) + labels = tf.constant([1.0, 0.0, 0.0, 1.0]) + losses = tf.contrib.losses.hinge_loss(logits, labels) self.assertFalse(tf.contrib.losses.get_losses()) loss = tf.contrib.losses.compute_weighted_loss(losses) self.assertTrue(tf.contrib.losses.get_losses()) @@ -1144,8 +1351,8 @@ class AddLossTest(tf.test.TestCase): def testAddExternalLoss(self): logits = tf.constant([1.2, 0.4, -1.0, -1.1]) - target = tf.constant([1.0, 0.0, 0.0, 1.0]) - losses = tf.contrib.losses.hinge_loss(logits, target) + labels = tf.constant([1.0, 0.0, 0.0, 1.0]) + losses = tf.contrib.losses.hinge_loss(logits, labels) self.assertFalse(tf.contrib.losses.get_losses()) tf.contrib.losses.add_loss(tf.reduce_mean(losses)) self.assertTrue(tf.contrib.losses.get_losses()) @@ -1156,8 +1363,8 @@ class AddLossTest(tf.test.TestCase): def testNoneLossCollection(self): logits = tf.constant([1.2, 0.4, -1.0, -1.1]) - target = tf.constant([1.0, 0.0, 0.0, 1.0]) - losses = tf.contrib.losses.hinge_loss(logits, target) + labels = tf.constant([1.0, 0.0, 0.0, 1.0]) + losses = tf.contrib.losses.hinge_loss(logits, labels) self.assertFalse(tf.contrib.losses.get_losses()) tf.contrib.losses.add_loss(tf.reduce_mean(losses), loss_collection=None) self.assertFalse(tf.contrib.losses.get_losses()) @@ -1166,15 +1373,15 @@ class AddLossTest(tf.test.TestCase): def testNoCollectLosses(self): logits = tf.constant([1.2, 0.4, -1.0, -1.1]) - target = tf.constant([1.0, 0.0, 0.0, 1.0]) + labels = tf.constant([1.0, 0.0, 0.0, 1.0]) self.assertFalse(tf.contrib.losses.get_losses()) with tf.contrib.framework.arg_scope([tf.contrib.losses.add_loss], loss_collection=None): - tf.contrib.losses.absolute_difference(logits, target) - tf.contrib.losses.log_loss(logits, target) - tf.contrib.losses.mean_squared_error(logits, target) - tf.contrib.losses.sigmoid_cross_entropy(logits, target) - tf.contrib.losses.softmax_cross_entropy(logits, target) + tf.contrib.losses.absolute_difference(logits, labels) + tf.contrib.losses.log_loss(logits, labels) + tf.contrib.losses.mean_squared_error(logits, labels) + tf.contrib.losses.sigmoid_cross_entropy(logits, labels) + tf.contrib.losses.softmax_cross_entropy(logits, labels) self.assertFalse(tf.contrib.losses.get_losses()) if __name__ == '__main__': |