diff options
author | Andrew Selle <aselle@google.com> | 2016-12-20 14:12:07 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-20 14:26:56 -0800 |
commit | cb4acf5e47574deccf0c578d6d1d18d74f6117af (patch) | |
tree | 3904556c008dced949672b7d54e1018159917317 /tensorflow/contrib/losses | |
parent | d2b92f24e072d0bb39584a9e6bee049edb49f905 (diff) |
Rename usages of tf.mul, tf.neg, tf.sub that are used internally
Change: 142595367
Diffstat (limited to 'tensorflow/contrib/losses')
-rw-r--r-- | tensorflow/contrib/losses/python/losses/loss_ops.py | 23 |
1 files changed, 12 insertions, 11 deletions
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 49eed50ed1..15956554f7 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -67,7 +67,7 @@ def _scale_losses(losses, weights): reduction_indices = list(range(start_index, losses.get_shape().ndims)) reduced_losses = math_ops.reduce_sum(losses, reduction_indices=reduction_indices) - reduced_losses = math_ops.mul(reduced_losses, weights) + reduced_losses = math_ops.multiply(reduced_losses, weights) return math_ops.reduce_sum(reduced_losses) @@ -181,7 +181,7 @@ def _num_present(losses, weights, per_batch=False): math_ops.to_float(batch_size)) num_per_batch = array_ops.where(math_ops.equal(weights, 0), 0.0, num_per_batch) - num_per_batch = math_ops.mul(array_ops.ones( + num_per_batch = math_ops.multiply(array_ops.ones( array_ops.reshape(batch_size, [1])), num_per_batch) return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch) @@ -197,7 +197,7 @@ def _num_present(losses, weights, per_batch=False): [weights.get_shape().ndims], [-1]) num_to_broadcast = math_ops.to_float(math_ops.reduce_prod(broadcast_dims)) - num_per_batch = math_ops.mul(num_nonzero_per_batch, num_to_broadcast) + num_per_batch = math_ops.multiply(num_nonzero_per_batch, num_to_broadcast) return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch) @@ -295,7 +295,7 @@ def absolute_difference(predictions, labels=None, weights=1.0, scope=None): predictions.get_shape().assert_is_compatible_with(labels.get_shape()) predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) - losses = math_ops.abs(math_ops.sub(predictions, labels)) + losses = math_ops.abs(math_ops.subtract(predictions, labels)) return compute_weighted_loss(losses, weights, scope=scope) @@ -458,9 +458,9 @@ def log_loss(predictions, labels=None, weights=1.0, epsilon=1e-7, scope=None): predictions.get_shape().assert_is_compatible_with(labels.get_shape()) predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) - losses = -math_ops.mul( + losses = -math_ops.multiply( labels, - math_ops.log(predictions + epsilon)) - math_ops.mul( + math_ops.log(predictions + epsilon)) - math_ops.multiply( (1 - labels), math_ops.log(1 - predictions + epsilon)) return compute_weighted_loss(losses, weights, scope=scope) @@ -487,8 +487,9 @@ def hinge_loss(logits, labels=None, scope=None): # We first need to convert binary labels to -1/1 labels (as floats). labels = math_ops.to_float(labels) all_ones = array_ops.ones_like(labels) - labels = math_ops.sub(2 * labels, all_ones) - return nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits))) + labels = math_ops.subtract(2 * labels, all_ones) + return nn_ops.relu( + math_ops.subtract(all_ones, math_ops.multiply(labels, logits))) @deprecated("2016-12-30", "Use tf.losses.mean_squared_error instead.") @@ -522,7 +523,7 @@ def mean_squared_error(predictions, labels=None, weights=1.0, scope=None): predictions.get_shape().assert_is_compatible_with(labels.get_shape()) predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) - losses = math_ops.square(math_ops.sub(predictions, labels)) + losses = math_ops.square(math_ops.subtract(predictions, labels)) return compute_weighted_loss(losses, weights, scope=scope) @@ -574,7 +575,7 @@ def mean_pairwise_squared_error( labels = math_ops.to_float(labels) weights = math_ops.to_float(ops.convert_to_tensor(weights)) - diffs = math_ops.sub(predictions, labels) + diffs = math_ops.subtract(predictions, labels) # Need to verify here since the function doesn't use compute_weighted_loss if diffs.get_shape().ndims is None: @@ -638,6 +639,6 @@ def cosine_distance( predictions = math_ops.to_float(predictions) labels = math_ops.to_float(labels) - radial_diffs = math_ops.mul(predictions, labels) + radial_diffs = math_ops.multiply(predictions, labels) losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,]) return compute_weighted_loss(losses, weights, scope=scope) |