aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/losses
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-28 19:35:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-28 19:43:36 -0800
commitb95017e730b09fd033cc29c3b501205ee04c2b64 (patch)
tree02413fe1230cf563d6a846b1748d85a04f91036d /tensorflow/contrib/losses
parenta51863d3ee28fcda605ec60baeacab309c86daed (diff)
Name ops and vars in head and in various places in contrib/.
Change: 140431877
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops.py60
1 files changed, 31 insertions, 29 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index 7610f9275f..8078d9f51a 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -144,12 +144,13 @@ def _safe_mean(losses, num_present):
@deprecated_args(
"2016-11-25", "`weight` is being deprecated, use `weights`.", "weight")
def compute_weighted_loss(
- losses, weights=_WEIGHT_SENTINEL, weight=_WEIGHT_SENTINEL):
+ losses, weights=_WEIGHT_SENTINEL, scope=None, weight=_WEIGHT_SENTINEL):
"""Computes the weighted loss.
Args:
losses: A tensor of size [batch_size, d1, ... dN].
weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
+ scope: the scope for the operations performed in computing the loss.
weight: Deprecated alias for `weights`.
Returns:
@@ -161,27 +162,28 @@ def compute_weighted_loss(
`weights` is missing.
"""
weights = _weights(weights, weight)
- losses = ops.convert_to_tensor(losses)
- input_dtype = losses.dtype
- losses = math_ops.to_float(losses)
- weights = math_ops.to_float(ops.convert_to_tensor(weights))
+ with ops.name_scope(scope, "weighted_loss", [losses, weights]):
+ losses = ops.convert_to_tensor(losses)
+ input_dtype = losses.dtype
+ losses = math_ops.to_float(losses)
+ weights = math_ops.to_float(ops.convert_to_tensor(weights))
- if losses.get_shape().ndims is None:
- raise ValueError("losses.get_shape().ndims cannot be None")
- weights_shape = weights.get_shape()
- if weights_shape.ndims is None:
- raise ValueError("weight.get_shape().ndims cannot be None")
+ if losses.get_shape().ndims is None:
+ raise ValueError("losses.get_shape().ndims cannot be None")
+ weights_shape = weights.get_shape()
+ if weights_shape.ndims is None:
+ raise ValueError("weight.get_shape().ndims cannot be None")
- if weights_shape.ndims > 1 and weights_shape.dims[-1].is_compatible_with(1):
- weights = array_ops.squeeze(weights, [-1])
+ if weights_shape.ndims > 1 and weights_shape.dims[-1].is_compatible_with(1):
+ weights = array_ops.squeeze(weights, [-1])
- total_loss = _scale_losses(losses, weights)
- num_present = _num_present(losses, weights)
- mean_loss = _safe_mean(total_loss, num_present)
- # convert the result back to the input type
- mean_loss = math_ops.cast(mean_loss, input_dtype)
- add_loss(mean_loss)
- return mean_loss
+ total_loss = _scale_losses(losses, weights)
+ num_present = _num_present(losses, weights)
+ mean_loss = _safe_mean(total_loss, num_present)
+ # convert the result back to the input type
+ mean_loss = math_ops.cast(mean_loss, input_dtype)
+ add_loss(mean_loss)
+ return mean_loss
def _num_present(losses, weights, per_batch=False):
@@ -334,7 +336,7 @@ def absolute_difference(
predictions = math_ops.to_float(predictions)
labels = math_ops.to_float(labels)
losses = math_ops.abs(math_ops.sub(predictions, labels))
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)
@deprecated_args(
@@ -373,7 +375,7 @@ def sigmoid_cross_entropy(
"""
weights = _weights(weights, weight)
with ops.name_scope(scope, "sigmoid_cross_entropy_loss",
- [logits, multi_class_labels, weights]):
+ [logits, multi_class_labels, weights]) as scope:
logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape())
multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype)
@@ -384,7 +386,7 @@ def sigmoid_cross_entropy(
losses = nn.sigmoid_cross_entropy_with_logits(logits, multi_class_labels,
name="xentropy")
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)
@deprecated_args(
@@ -421,7 +423,7 @@ def softmax_cross_entropy(
"""
weights = _weights(weights, weight)
with ops.name_scope(scope, "softmax_cross_entropy_loss",
- [logits, onehot_labels, weights]):
+ [logits, onehot_labels, weights]) as scope:
logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape())
onehot_labels = math_ops.cast(onehot_labels, logits.dtype)
@@ -435,7 +437,7 @@ def softmax_cross_entropy(
losses = nn.softmax_cross_entropy_with_logits(logits, onehot_labels,
name="xentropy")
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)
@deprecated_args(
@@ -468,13 +470,13 @@ def sparse_softmax_cross_entropy(
"""
weights = _weights(weights, weight)
with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
- [logits, labels, weights]):
+ [logits, labels, weights]) as scope:
labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
weights = array_ops.squeeze(weights)
losses = nn.sparse_softmax_cross_entropy_with_logits(logits, labels,
name="xentropy")
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)
@deprecated_args(
@@ -523,7 +525,7 @@ def log_loss(
labels,
math_ops.log(predictions + epsilon)) - math_ops.mul(
(1 - labels), math_ops.log(1 - predictions + epsilon))
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)
@deprecated_args(
@@ -597,7 +599,7 @@ def mean_squared_error(
predictions = math_ops.to_float(predictions)
labels = math_ops.to_float(labels)
losses = math_ops.square(math_ops.sub(predictions, labels))
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)
@deprecated_args(
@@ -732,4 +734,4 @@ def cosine_distance(
radial_diffs = math_ops.mul(predictions, labels)
losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,])
- return compute_weighted_loss(losses, weights)
+ return compute_weighted_loss(losses, weights, scope=scope)