author    Yifei Feng <yifeif@google.com>  2018-01-26 16:53:59 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-01-26 16:59:01 -0800
commit    aee7f95a027accc94f1f9130f0cfaecd9399bc1d (patch)
tree      6b8484915bf631f18b2fa0561a73549d9bf19fad /tensorflow/contrib/losses
parent    e95537708f070a98607393a8f60bc61f1611a77b (diff)
Add C0301 line-too-long error to pylint sanity check.
PiperOrigin-RevId: 183467186
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r--  tensorflow/contrib/losses/python/losses/loss_ops.py  | 130
1 file changed, 68 insertions(+), 62 deletions(-)
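For context, C0301 is pylint's line-too-long check; every hunk below is a reformatting to satisfy it. A hedged sketch of running the check by hand (the sanity script's exact configuration is an assumption here):

pylint --disable=all --enable=C0301 --max-line-length=80 \
    tensorflow/contrib/losses/python/losses/loss_ops.py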
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index 7c523ad492..8c3a8afe7a 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -30,20 +30,13 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.deprecation import deprecated_args
-__all__ = ["absolute_difference",
- "add_loss",
- "cosine_distance",
- "compute_weighted_loss",
- "get_losses",
- "get_regularization_losses",
- "get_total_loss",
- "hinge_loss",
- "log_loss",
- "mean_pairwise_squared_error",
- "mean_squared_error",
- "sigmoid_cross_entropy",
- "softmax_cross_entropy",
- "sparse_softmax_cross_entropy"]
+__all__ = [
+ "absolute_difference", "add_loss", "cosine_distance",
+ "compute_weighted_loss", "get_losses", "get_regularization_losses",
+ "get_total_loss", "hinge_loss", "log_loss", "mean_pairwise_squared_error",
+ "mean_squared_error", "sigmoid_cross_entropy", "softmax_cross_entropy",
+ "sparse_softmax_cross_entropy"
+]
def _scale_losses(losses, weights):
@@ -66,8 +59,8 @@ def _scale_losses(losses, weights):
# First, compute the sum of the losses over all elements:
start_index = max(0, weights.get_shape().ndims)
reduction_indices = list(range(start_index, losses.get_shape().ndims))
- reduced_losses = math_ops.reduce_sum(losses,
- reduction_indices=reduction_indices)
+ reduced_losses = math_ops.reduce_sum(
+ losses, reduction_indices=reduction_indices)
reduced_losses = math_ops.multiply(reduced_losses, weights)
return math_ops.reduce_sum(reduced_losses)
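For intuition, a minimal NumPy sketch of what _scale_losses computes when `weights` has lower rank than `losses` (values and shapes are illustrative):

import numpy as np

losses = np.array([[1.0, 2.0], [3.0, 4.0]])  # shape [2, 2]
weights = np.array([1.0, 0.5])               # shape [2]
# reduction_indices = range(1, 2): sum over every axis past weights' rank,
# scale each per-batch sum, then total, as _scale_losses does above.
reduced = losses.sum(axis=1) * weights       # -> [3.0, 3.5]
total = reduced.sum()                        # -> 6.5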
@@ -90,9 +83,10 @@ def _safe_div(numerator, denominator, name="value"):
"""
return array_ops.where(
math_ops.greater(denominator, 0),
- math_ops.div(numerator, array_ops.where(
- math_ops.equal(denominator, 0),
- array_ops.ones_like(denominator), denominator)),
+ math_ops.div(numerator,
+ array_ops.where(
+ math_ops.equal(denominator, 0),
+ array_ops.ones_like(denominator), denominator)),
array_ops.zeros_like(numerator),
name=name)
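The nested where/ones_like guard above avoids inf/nan from dividing by zero; a plain-Python sketch of the same semantics (safe_div_scalar is a hypothetical name):

def safe_div_scalar(numerator, denominator):
  # Hypothetical scalar analogue of _safe_div: divide where the
  # denominator is positive, otherwise return zero.
  return numerator / denominator if denominator > 0 else 0.0

assert safe_div_scalar(6.0, 3.0) == 2.0
assert safe_div_scalar(6.0, 0.0) == 0.0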
@@ -176,14 +170,15 @@ def _num_present(losses, weights, per_batch=False):
"""
# If weights is a scalar, it's easy to compute:
if weights.get_shape().ndims == 0:
- batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses),
- [0], [1]), [])
- num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)),
- math_ops.to_float(batch_size))
- num_per_batch = array_ops.where(math_ops.equal(weights, 0),
- 0.0, num_per_batch)
- num_per_batch = math_ops.multiply(array_ops.ones(
- array_ops.reshape(batch_size, [1])), num_per_batch)
+ batch_size = array_ops.reshape(
+ array_ops.slice(array_ops.shape(losses), [0], [1]), [])
+ num_per_batch = math_ops.div(
+ math_ops.to_float(array_ops.size(losses)),
+ math_ops.to_float(batch_size))
+ num_per_batch = array_ops.where(
+ math_ops.equal(weights, 0), 0.0, num_per_batch)
+ num_per_batch = math_ops.multiply(
+ array_ops.ones(array_ops.reshape(batch_size, [1])), num_per_batch)
return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)
# First, count the number of nonzero weights:
@@ -194,8 +189,8 @@ def _num_present(losses, weights, per_batch=False):
reduction_indices=reduction_indices)
# Next, determine the number of elements that weights would broadcast to:
- broadcast_dims = array_ops.slice(array_ops.shape(losses),
- [weights.get_shape().ndims], [-1])
+ broadcast_dims = array_ops.slice(
+ array_ops.shape(losses), [weights.get_shape().ndims], [-1])
num_to_broadcast = math_ops.to_float(math_ops.reduce_prod(broadcast_dims))
num_per_batch = math_ops.multiply(num_nonzero_per_batch, num_to_broadcast)
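A hedged NumPy sketch of the scalar-weights fast path of _num_present (shapes are illustrative):

import numpy as np

losses = np.zeros([4, 3])  # batch_size 4, 12 elements in total
weight = 2.0               # scalar weight
# Elements per batch item, or zero when the scalar weight is exactly zero.
num_per_batch = losses.size / losses.shape[0] if weight != 0 else 0.0
num_present = num_per_batch * losses.shape[0]  # -> 12.0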
@@ -303,8 +298,11 @@ def absolute_difference(predictions, labels=None, weights=1.0, scope=None):
@deprecated("2016-12-30",
"Use tf.losses.sigmoid_cross_entropy instead. Note that the order "
"of the predictions and labels arguments has been changed.")
-def sigmoid_cross_entropy(
- logits, multi_class_labels, weights=1.0, label_smoothing=0, scope=None):
+def sigmoid_cross_entropy(logits,
+ multi_class_labels,
+ weights=1.0,
+ label_smoothing=0,
+ scope=None):
"""Creates a cross-entropy loss using tf.nn.sigmoid_cross_entropy_with_logits.
`weights` acts as a coefficient for the loss. If a scalar is provided,
@@ -340,20 +338,22 @@ def sigmoid_cross_entropy(
multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype)
if label_smoothing > 0:
- multi_class_labels = (multi_class_labels * (1 - label_smoothing) +
- 0.5 * label_smoothing)
+ multi_class_labels = (
+ multi_class_labels * (1 - label_smoothing) + 0.5 * label_smoothing)
- losses = nn.sigmoid_cross_entropy_with_logits(labels=multi_class_labels,
- logits=logits,
- name="xentropy")
+ losses = nn.sigmoid_cross_entropy_with_logits(
+ labels=multi_class_labels, logits=logits, name="xentropy")
return compute_weighted_loss(losses, weights, scope=scope)
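A hedged TF 1.x usage sketch of the deprecated op above (shapes and values are illustrative); with label_smoothing=0.2 the formula in the diff maps a hard label y to y * 0.8 + 0.1, so 1 -> 0.9 and 0 -> 0.1:

import tensorflow as tf

logits = tf.random_normal([32, 5])
multi_class_labels = tf.ones([32, 5])
loss = tf.contrib.losses.sigmoid_cross_entropy(
    logits, multi_class_labels, weights=1.0, label_smoothing=0.2)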
@deprecated("2016-12-30",
"Use tf.losses.softmax_cross_entropy instead. Note that the order "
"of the logits and labels arguments has been changed.")
-def softmax_cross_entropy(
- logits, onehot_labels, weights=1.0, label_smoothing=0, scope=None):
+def softmax_cross_entropy(logits,
+ onehot_labels,
+ weights=1.0,
+ label_smoothing=0,
+ scope=None):
"""Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits.
`weights` acts as a coefficient for the loss. If a scalar is provided,
@@ -393,9 +393,8 @@ def softmax_cross_entropy(
smooth_negatives = label_smoothing / num_classes
onehot_labels = onehot_labels * smooth_positives + smooth_negatives
- losses = nn.softmax_cross_entropy_with_logits(labels=onehot_labels,
- logits=logits,
- name="xentropy")
+ losses = nn.softmax_cross_entropy_with_logits(
+ labels=onehot_labels, logits=logits, name="xentropy")
return compute_weighted_loss(losses, weights, scope=scope)
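Likewise for the softmax variant: with label_smoothing=0.1 and num_classes=10, the smoothed targets above become onehot * 0.9 + 0.01, i.e. 1 -> 0.91 and 0 -> 0.01 (a sketch with illustrative shapes):

import tensorflow as tf

logits = tf.random_normal([32, 10])
onehot_labels = tf.one_hot(tf.zeros([32], dtype=tf.int32), depth=10)
loss = tf.contrib.losses.softmax_cross_entropy(
    logits, onehot_labels, label_smoothing=0.1)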
@@ -429,9 +428,8 @@ def sparse_softmax_cross_entropy(logits, labels, weights=1.0, scope=None):
[logits, labels, weights]) as scope:
labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
- losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
- logits=logits,
- name="xentropy")
+ losses = nn.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=logits, name="xentropy")
return compute_weighted_loss(losses, weights, scope=scope)
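The sparse variant takes integer class indices of shape [batch_size] rather than one-hot rows, which is what the reshape above normalizes; a brief sketch (shapes illustrative):

import tensorflow as tf

logits = tf.random_normal([32, 10])
labels = tf.zeros([32], dtype=tf.int64)  # class indices, not one-hot
loss = tf.contrib.losses.sparse_softmax_cross_entropy(logits, labels)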
@@ -470,8 +468,7 @@ def log_loss(predictions, labels=None, weights=1.0, epsilon=1e-7, scope=None):
predictions = math_ops.to_float(predictions)
labels = math_ops.to_float(labels)
losses = -math_ops.multiply(
- labels,
- math_ops.log(predictions + epsilon)) - math_ops.multiply(
+ labels, math_ops.log(predictions + epsilon)) - math_ops.multiply(
(1 - labels), math_ops.log(1 - predictions + epsilon))
return compute_weighted_loss(losses, weights, scope=scope)
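The expression above is ordinary binary cross-entropy on probabilities; a plain-Python check of a single element (log_loss_elem is a hypothetical name, values illustrative):

import math

def log_loss_elem(label, prediction, epsilon=1e-7):
  # -y*log(p + eps) - (1 - y)*log(1 - p + eps), as computed above.
  return (-label * math.log(prediction + epsilon)
          - (1 - label) * math.log(1 - prediction + epsilon))

log_loss_elem(1.0, 0.8)  # ~0.223, i.e. -log(0.8)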
@@ -490,7 +487,8 @@ def hinge_loss(logits, labels=None, scope=None):
scope: The scope for the operations performed in computing the loss.
Returns:
- An unweighted `Tensor` of same shape as `logits` and `labels` representing the
+ An unweighted `Tensor` of same shape as `logits` and `labels` representing
+ the
loss values across the batch.
Raises:
@@ -544,8 +542,10 @@ def mean_squared_error(predictions, labels=None, weights=1.0, scope=None):
@deprecated("2016-12-30",
"Use tf.losses.mean_pairwise_squared_error instead. Note that the "
"order of the predictions and labels arguments has been changed.")
-def mean_pairwise_squared_error(
- predictions, labels=None, weights=1.0, scope=None):
+def mean_pairwise_squared_error(predictions,
+ labels=None,
+ weights=1.0,
+ scope=None):
"""Adds a pairwise-errors-squared loss to the training procedure.
Unlike `mean_squared_error`, which is a measure of the differences between
@@ -602,31 +602,34 @@ def mean_pairwise_squared_error(
reduction_indices = list(range(1, diffs.get_shape().ndims))
sum_squares_diff_per_batch = math_ops.reduce_sum(
- math_ops.square(diffs),
- reduction_indices=reduction_indices)
+ math_ops.square(diffs), reduction_indices=reduction_indices)
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
- term1 = 2.0 * _safe_div(sum_squares_diff_per_batch,
- num_present_per_batch)
+ term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, num_present_per_batch)
sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
- term2 = 2.0 * _safe_div(math_ops.square(sum_diff),
- math_ops.square(num_present_per_batch))
+ term2 = 2.0 * _safe_div(
+ math_ops.square(sum_diff), math_ops.square(num_present_per_batch))
loss = _scale_losses(term1 - term2, weights)
- mean_loss = array_ops.where(math_ops.reduce_sum(num_present_per_batch) > 0,
- loss,
- array_ops.zeros_like(loss),
- name="value")
+ mean_loss = array_ops.where(
+ math_ops.reduce_sum(num_present_per_batch) > 0,
+ loss,
+ array_ops.zeros_like(loss),
+ name="value")
add_loss(mean_loss)
return mean_loss
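For intuition about term1 - term2 above: per batch element it equals the mean over all ordered pairs (i, j) of (d_i - d_j)**2, where d = predictions - labels; a NumPy check with illustrative values:

import numpy as np

d = np.array([1.0, 2.0, 4.0])  # per-element diffs for one batch item
n = d.size
term1 = 2.0 * (d ** 2).sum() / n
term2 = 2.0 * d.sum() ** 2 / n ** 2
pairwise = ((d[:, None] - d[None, :]) ** 2).mean()
assert np.isclose(term1 - term2, pairwise)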
@deprecated("2016-12-30", "Use tf.losses.cosine_distance instead.")
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
-def cosine_distance(
- predictions, labels=None, axis=None, weights=1.0, scope=None, dim=None):
+def cosine_distance(predictions,
+ labels=None,
+ axis=None,
+ weights=1.0,
+ scope=None,
+ dim=None):
"""Adds a cosine-distance loss to the training procedure.
Note that the function assumes that `predictions` and `labels` are already
@@ -662,5 +665,8 @@ def cosine_distance(
labels = math_ops.to_float(labels)
radial_diffs = math_ops.multiply(predictions, labels)
- losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[axis,])
+ losses = 1 - math_ops.reduce_sum(
+ radial_diffs, reduction_indices=[
+ axis,
+ ])
return compute_weighted_loss(losses, weights, scope=scope)
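Finally, a usage sketch for the cosine-distance loss (shapes are illustrative); inputs are assumed already unit-normalized, and dim is kept only as a deprecated alias for axis:

import tensorflow as tf

predictions = tf.nn.l2_normalize(tf.random_normal([32, 128]), 1)
labels = tf.nn.l2_normalize(tf.random_normal([32, 128]), 1)
loss = tf.contrib.losses.cosine_distance(predictions, labels, axis=1)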