author    Martin Wicke <wicke@google.com>    2017-01-04 21:25:34 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-01-04 21:46:08 -0800
commit    333dc32ff79af21484695157f3d141dc776f7c02 (patch)
tree      b379bcaa56bfa54d12ea839fb7e62ab163490743 /tensorflow/python
parent    d9541696b068cfcc1fab66b03d0b8d605b64f14d (diff)
Change arg order for {softmax,sparse_softmax,sigmoid}_cross_entropy_with_logits to be (labels, logits), and force use of named args to avoid accidents.
Change: 143629623
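A minimal before/after sketch of the calling convention this commit enforces (tensor values are illustrative; assumes a TensorFlow build that includes this change):

import tensorflow as tf

labels = tf.constant([[0.0, 1.0], [1.0, 0.0]])
logits = tf.constant([[2.0, 1.0], [0.5, 3.0]])

# Old style (positional, logits first) -- no longer accepted:
# loss = tf.nn.softmax_cross_entropy_with_logits(logits, labels)

# New style: named arguments only; a positional call raises ValueError.
loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)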
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/kernel_tests/sparse_xent_op_test.py  19
-rw-r--r--  tensorflow/python/kernel_tests/xent_op_test.py          5
-rw-r--r--  tensorflow/python/ops/gradient_checker_test.py          2
-rw-r--r--  tensorflow/python/ops/losses/losses_impl.py             9
-rw-r--r--  tensorflow/python/ops/nn_impl.py                       35
-rw-r--r--  tensorflow/python/ops/nn_ops.py                        51
-rw-r--r--  tensorflow/python/ops/nn_xent_test.py                  34
-rw-r--r--  tensorflow/python/training/saver_test.py               11
8 files changed, 107 insertions, 59 deletions
diff --git a/tensorflow/python/kernel_tests/sparse_xent_op_test.py b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
index ef94af54fe..d2a815a0d7 100644
--- a/tensorflow/python/kernel_tests/sparse_xent_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_xent_op_test.py
@@ -141,25 +141,26 @@ class SparseXentTest(test.TestCase):
with self.test_session(use_gpu=True):
with self.assertRaisesRegexp(ValueError, ".*Rank mismatch:*"):
nn_ops.sparse_softmax_cross_entropy_with_logits(
- [[0., 1.], [2., 3.], [2., 3.]], [[0, 2]])
+ labels=[[0, 2]], logits=[[0., 1.], [2., 3.], [2., 3.]])
def testScalar(self):
with self.test_session(use_gpu=True):
with self.assertRaisesRegexp(ValueError, ".*Logits cannot be scalars*"):
nn_ops.sparse_softmax_cross_entropy_with_logits(
- constant_op.constant(1.0), constant_op.constant(0))
+ labels=constant_op.constant(0), logits=constant_op.constant(1.0))
def testLabelsPlaceholderScalar(self):
with self.test_session(use_gpu=True):
labels = array_ops.placeholder(np.int32)
- y = nn_ops.sparse_softmax_cross_entropy_with_logits([[7.]], labels)
+ y = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=[[7.]])
with self.assertRaisesOpError("labels must be 1-D"):
y.eval(feed_dict={labels: 0})
def testVector(self):
with self.test_session(use_gpu=True):
loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
- constant_op.constant([1.0]), constant_op.constant(0))
+ labels=constant_op.constant(0), logits=constant_op.constant([1.0]))
self.assertAllClose(0.0, loss.eval())
def testFloat(self):
@@ -191,7 +192,8 @@ class SparseXentTest(test.TestCase):
shape=[3, 4],
dtype=dtypes.float64,
name="f")
- x = nn_ops.sparse_softmax_cross_entropy_with_logits(f, l, name="xent")
+ x = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=l, logits=f, name="xent")
err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
print("cross entropy gradient err = ", err)
self.assertLess(err, 5e-8)
@@ -201,7 +203,8 @@ class SparseXentTest(test.TestCase):
# manually reshape loss
np_loss = np.reshape(np_loss, np.array(labels).shape)
with self.test_session(use_gpu=True) as sess:
- loss = nn_ops.sparse_softmax_cross_entropy_with_logits(features, labels)
+ loss = nn_ops.sparse_softmax_cross_entropy_with_logits(
+ labels=labels, logits=features)
backprop = loss.op.inputs[0].op.outputs[1]
tf_loss, tf_backprop = sess.run([loss, backprop])
self.assertAllCloseAccordingToType(np_loss, tf_loss)
@@ -225,7 +228,7 @@ class SparseXentTest(test.TestCase):
labels = array_ops.placeholder(dtypes.int32, shape=[None, 1])
logits = array_ops.placeholder(dtypes.float32, shape=[None, 3])
ce = nn_ops.sparse_softmax_cross_entropy_with_logits(
- logits, array_ops.squeeze(labels))
+ labels=array_ops.squeeze(labels), logits=logits)
labels_v2 = np.zeros((1, 1), dtype=np.int32)
logits_v2 = np.random.randn(1, 3)
sess.run([ce], feed_dict={labels: labels_v2, logits: logits_v2})
@@ -243,7 +246,7 @@ def _sparse_vs_dense_xent_benchmark_dense(labels, logits):
array_ops.stack([length]), 1.0, 0.0)
target = array_ops.reshape(target, array_ops.stack([-1, num_entries]))
crossent = nn_ops.softmax_cross_entropy_with_logits(
- logits, target, name="SequenceLoss/CrossEntropy")
+ labels=target, logits=logits, name="SequenceLoss/CrossEntropy")
crossent_sum = math_ops.reduce_sum(crossent)
grads = gradients_impl.gradients([crossent_sum], [logits])[0]
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index ac56f567ce..e1e0566124 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -57,7 +57,7 @@ class XentTest(test.TestCase):
np_loss, _ = self._npXent(np_features, np_labels, dim=dim)
with self.test_session(use_gpu=use_gpu) as sess:
loss = nn_ops.softmax_cross_entropy_with_logits(
- np_features, np_labels, dim=dim)
+ labels=np_labels, logits=np_features, dim=dim)
tf_loss = sess.run(loss)
print("np_loss:", np_loss)
print("tf_loss:", tf_loss)
@@ -166,7 +166,8 @@ class XentTest(test.TestCase):
shape=[3, 4],
dtype=dtypes.float64,
name="f")
- x = nn_ops.softmax_cross_entropy_with_logits(f, l, name="xent")
+ x = nn_ops.softmax_cross_entropy_with_logits(labels=l, logits=f,
+ name="xent")
err = gradient_checker.compute_gradient_error(f, [3, 4], x, [3])
print("cross entropy gradient err = ", err)
self.assertLess(err, 5e-8)
diff --git a/tensorflow/python/ops/gradient_checker_test.py b/tensorflow/python/ops/gradient_checker_test.py
index 8182352658..3ea8f3798c 100644
--- a/tensorflow/python/ops/gradient_checker_test.py
+++ b/tensorflow/python/ops/gradient_checker_test.py
@@ -267,7 +267,7 @@ class MiniMNISTTest(test.TestCase):
dtype=dtypes.float64,
name="labels")
cost = nn_ops.softmax_cross_entropy_with_logits(
- logits, labels, name="cost")
+ labels=labels, logits=logits, name="cost")
# Test the gradients.
err = gradient_checker.compute_gradient_error(
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 749e46cb77..c23d046d70 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -559,7 +559,8 @@ def sigmoid_cross_entropy(
multi_class_labels = (multi_class_labels * (1 - label_smoothing) +
0.5 * label_smoothing)
- losses = nn.sigmoid_cross_entropy_with_logits(logits, multi_class_labels,
+ losses = nn.sigmoid_cross_entropy_with_logits(labels=multi_class_labels,
+ logits=logits,
name="xentropy")
return compute_weighted_loss(losses, weights, scope, loss_collection)
@@ -613,7 +614,8 @@ def softmax_cross_entropy(
smooth_negatives = label_smoothing / num_classes
onehot_labels = onehot_labels * smooth_positives + smooth_negatives
- losses = nn.softmax_cross_entropy_with_logits(logits, onehot_labels,
+ losses = nn.softmax_cross_entropy_with_logits(labels=onehot_labels,
+ logits=logits,
name="xentropy")
return compute_weighted_loss(losses, weights, scope, loss_collection)
@@ -653,7 +655,8 @@ def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
[logits, labels, weights]) as scope:
labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
- losses = nn.sparse_softmax_cross_entropy_with_logits(logits, labels,
+ losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
+ logits=logits,
name="xentropy")
# Reshape losses to [batch_size, 1] to be consistent with weights.
losses = array_ops.reshape(losses, shape=[array_ops.shape(losses)[0], 1])
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 5aba723017..60499c36ec 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -96,7 +96,9 @@ def log_poisson_loss(targets, log_input, compute_full_loss=False, name=None):
return result
-def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
+def sigmoid_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
+ labels=None, logits=None,
+ name=None):
"""Computes sigmoid cross entropy given `logits`.
Measures the probability error in discrete classification tasks in which each
@@ -104,7 +106,7 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
perform multilabel classification where a picture can contain both an elephant
and a dog at the same time.
- For brevity, let `x = logits`, `z = targets`. The logistic loss is
+ For brevity, let `x = logits`, `z = labels`. The logistic loss is
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
@@ -124,11 +126,12 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
max(x, 0) - x * z + log(1 + exp(-abs(x)))
- `logits` and `targets` must have the same type and shape.
+ `logits` and `labels` must have the same type and shape.
Args:
+ _sentinel: Used to prevent positional parameters. Internal, do not use.
+ labels: A `Tensor` of the same type and shape as `logits`.
logits: A `Tensor` of type `float32` or `float64`.
- targets: A `Tensor` of the same type and shape as `logits`.
name: A name for the operation (optional).
Returns:
@@ -136,16 +139,21 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
logistic losses.
Raises:
- ValueError: If `logits` and `targets` do not have the same shape.
+ ValueError: If `logits` and `labels` do not have the same shape.
"""
- with ops.name_scope(name, "logistic_loss", [logits, targets]) as name:
+ # pylint: disable=protected-access
+ nn_ops._ensure_xent_args("sigmoid_cross_entropy_with_logits",
+ _sentinel, labels, logits)
+ # pylint: enable=protected-access
+
+ with ops.name_scope(name, "logistic_loss", [logits, labels]) as name:
logits = ops.convert_to_tensor(logits, name="logits")
- targets = ops.convert_to_tensor(targets, name="targets")
+ labels = ops.convert_to_tensor(labels, name="labels")
try:
- targets.get_shape().merge_with(logits.get_shape())
+ labels.get_shape().merge_with(logits.get_shape())
except ValueError:
- raise ValueError("logits and targets must have the same shape (%s vs %s)"
- % (logits.get_shape(), targets.get_shape()))
+ raise ValueError("logits and labels must have the same shape (%s vs %s)"
+ % (logits.get_shape(), labels.get_shape()))
# The logistic loss formula from above is
# x - x * z + log(1 + exp(-x))
@@ -159,7 +167,7 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
cond = (logits >= zeros)
relu_logits = array_ops.where(cond, logits, zeros)
neg_abs_logits = array_ops.where(cond, -logits, logits)
- return math_ops.add(relu_logits - logits * targets,
+ return math_ops.add(relu_logits - logits * labels,
math_ops.log1p(math_ops.exp(neg_abs_logits)),
name=name)
@@ -1095,7 +1103,7 @@ def nce_loss(weights,
partition_strategy=partition_strategy,
name=name)
sampled_losses = sigmoid_cross_entropy_with_logits(
- logits, labels, name="sampled_losses")
+ labels=labels, logits=logits, name="sampled_losses")
# sampled_losses is batch_size x {true_loss, sampled_losses...}
# We sum out true and sampled losses.
return _sum_rows(sampled_losses)
@@ -1170,6 +1178,7 @@ def sampled_softmax_loss(weights,
remove_accidental_hits=remove_accidental_hits,
partition_strategy=partition_strategy,
name=name)
- sampled_losses = nn_ops.softmax_cross_entropy_with_logits(logits, labels)
+ sampled_losses = nn_ops.softmax_cross_entropy_with_logits(labels=labels,
+ logits=logits)
# sampled_losses is a [batch_size] tensor.
return sampled_losses
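As an aside, the docstring retained above keeps the numerically stable form of the logistic loss, max(x, 0) - x * z + log(1 + exp(-abs(x))). A small NumPy check of that identity against the naive z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) (illustrative values, not part of this commit):

import numpy as np

x = np.array([-3.0, -0.5, 0.0, 2.0, 8.0])   # logits
z = np.array([0.0, 1.0, 1.0, 0.0, 1.0])     # labels

sigmoid = 1.0 / (1.0 + np.exp(-x))
naive = z * -np.log(sigmoid) + (1 - z) * -np.log(1 - sigmoid)
stable = np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))

print(np.allclose(naive, stable))  # True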
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index ee199efccf..f5075f0675 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -14,7 +14,6 @@
# ==============================================================================
"""Wrappers for primitive Neural Net (NN) Operations."""
-# pylint: disable=invalid-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -1047,7 +1046,7 @@ def conv2d_transpose(value,
raise ValueError("data_format has to be either NCHW or NHWC.")
value = ops.convert_to_tensor(value, name="value")
filter = ops.convert_to_tensor(filter, name="filter")
- axis = 3 if data_format=="NHWC" else 1
+ axis = 3 if data_format == "NHWC" else 1
if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[3]):
raise ValueError("input channels does not match filter's input channels, "
"{} != {}".format(value.get_shape()[3], filter.get_shape(
@@ -1528,7 +1527,18 @@ def log_softmax(logits, dim=-1, name=None):
return _softmax(logits, gen_nn_ops._log_softmax, dim, name)
-def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
+def _ensure_xent_args(name, sentinel, labels, logits):
+ # Make sure that all arguments were passed as named arguments.
+ if sentinel is not None:
+ raise ValueError("Only call `%s` with "
+ "named arguments (labels=..., logits=..., ...)" % name)
+ if labels is None or logits is None:
+ raise ValueError("Both labels and logits must be provided.")
+
+
+def softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
+ labels=None, logits=None,
+ dim=-1, name=None):
"""Computes softmax cross entropy between `logits` and `labels`.
Measures the probability error in discrete classification tasks in which the
@@ -1551,9 +1561,13 @@ def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
`logits` and `labels` must have the same shape `[batch_size, num_classes]`
and the same dtype (either `float16`, `float32`, or `float64`).
+ **Note that to avoid confusion, it is required to pass only named arguments to
+ this function.**
+
Args:
- logits: Unscaled log probabilities.
+ _sentinel: Used to prevent positional parameters. Internal, do not use.
labels: Each row `labels[i]` must be a valid probability distribution.
+ logits: Unscaled log probabilities.
dim: The class dimension. Defaulted to -1 which is the last dimension.
name: A name for the operation (optional).
@@ -1561,6 +1575,9 @@ def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
softmax cross entropy loss.
"""
+ _ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel,
+ labels, logits)
+
# TODO(pcmurray) Raise an error when the labels do not sum to 1. Note: This
# could break users who call this with bad labels, but disregard the bad
# results.
@@ -1569,7 +1586,7 @@ def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
labels = ops.convert_to_tensor(labels)
precise_logits = math_ops.cast(logits, dtypes.float32) if (
logits.dtype == dtypes.float16) else logits
- # Labels and logits must be of the same type
+ # labels and logits must be of the same type
labels = math_ops.cast(labels, precise_logits.dtype)
input_rank = array_ops.rank(precise_logits)
# For shape inference.
@@ -1618,7 +1635,9 @@ def softmax_cross_entropy_with_logits(logits, labels, dim=-1, name=None):
return cost
-def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
+def sparse_softmax_cross_entropy_with_logits(_sentinel=None, # pylint: disable=invalid-name
+ labels=None, logits=None,
+ name=None):
"""Computes sparse softmax cross entropy between `logits` and `labels`.
Measures the probability error in discrete classification tasks in which the
@@ -1640,14 +1659,18 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
A common use case is to have logits of shape `[batch_size, num_classes]` and
labels of shape `[batch_size]`. But higher dimensions are supported.
+ **Note that to avoid confusion, it is required to pass only named arguments to
+ this function.**
+
Args:
- logits: Unscaled log probabilities of rank `r` and shape
- `[d_0, d_1, ..., d_{r-2}, num_classes]` and dtype `float32` or `float64`.
+ _sentinel: Used to prevent positional parameters. Internal, do not use.
labels: `Tensor` of shape `[d_0, d_1, ..., d_{r-2}]` and dtype `int32` or
`int64`. Each entry in `labels` must be an index in `[0, num_classes)`.
Other values will raise an exception when this op is run on CPU, and
return `NaN` for the corresponding loss and gradient rows
on GPU.
+ logits: Unscaled log probabilities of rank `r` and shape
+ `[d_0, d_1, ..., d_{r-2}, num_classes]` and dtype `float32` or `float64`.
name: A name for the operation (optional).
Returns:
@@ -1658,6 +1681,9 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
ValueError: If logits are scalars (need to have rank >= 1) or if the rank
of the labels is not equal to the rank of the logits minus one.
"""
+ _ensure_xent_args("sparse_softmax_cross_entropy_with_logits", _sentinel,
+ labels, logits)
+
# TODO(pcmurray) Raise an error when the label is not an index in
# [0, num_classes). Note: This could break users who call this with bad
# labels, but disregard the bad results.
@@ -1679,8 +1705,8 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
if logits.get_shape().ndims is not None and (
labels_static_shape.ndims is not None and
labels_static_shape.ndims != logits.get_shape().ndims - 1):
- raise ValueError("Rank mismatch: Rank of labels (received %s) should equal "
- "rank of logits minus 1 (received %s)." %
+ raise ValueError("Rank mismatch: Rank of labels (received %s) should "
+ "equal rank of logits minus 1 (received %s)." %
(labels_static_shape.ndims, logits.get_shape().ndims))
# Check if no reshapes are required.
if logits.get_shape().ndims == 2:
@@ -1857,8 +1883,7 @@ def xw_plus_b_v1(x, weights, biases, name=None): # pylint: disable=invalid-name
return bias_add_v1(mm, biases, name=name)
-# pylint: disable=invalid-name
-def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
+def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name
"""Computes dropout.
With probability `keep_prob`, outputs the input element scaled up by
@@ -2082,5 +2107,3 @@ def erosion2d(value, kernel, strides, rates, padding, name=None):
rates=rates,
padding=padding,
name=name))
-
-# pylint: enable=invalid-name
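For reference, a standalone sketch of the keyword-only guard pattern that _ensure_xent_args introduces in nn_ops.py above (the guard mirrors the diff; toy_xent and its body are purely illustrative):

def _ensure_xent_args(name, sentinel, labels, logits):
  # Any positional argument lands in `sentinel`, so its presence means the
  # caller did not use named arguments.
  if sentinel is not None:
    raise ValueError("Only call `%s` with "
                     "named arguments (labels=..., logits=..., ...)" % name)
  if labels is None or logits is None:
    raise ValueError("Both labels and logits must be provided.")


def toy_xent(_sentinel=None, labels=None, logits=None, name=None):
  _ensure_xent_args("toy_xent", _sentinel, labels, logits)
  return [l - g for l, g in zip(labels, logits)]  # placeholder computation


toy_xent(labels=[1.0], logits=[0.5])  # OK
try:
  toy_xent([1.0], [0.5])              # positional call is rejected
except ValueError as e:
  print(e)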
diff --git a/tensorflow/python/ops/nn_xent_test.py b/tensorflow/python/ops/nn_xent_test.py
index 3d0157fb89..90f4b40770 100644
--- a/tensorflow/python/ops/nn_xent_test.py
+++ b/tensorflow/python/ops/nn_xent_test.py
@@ -57,7 +57,7 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
with self.test_session():
logits, targets, _ = self._Inputs()
loss = nn_impl.sigmoid_cross_entropy_with_logits(
- logits, targets, name="mylogistic")
+ labels=targets, logits=logits, name="mylogistic")
self.assertEqual("mylogistic", loss.op.name)
def testLogisticOutput(self):
@@ -65,7 +65,8 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
for dtype in [dtypes.float32, dtypes.float16]:
with self.test_session(use_gpu=use_gpu):
logits, targets, losses = self._Inputs(dtype=dtype)
- loss = nn_impl.sigmoid_cross_entropy_with_logits(logits, targets)
+ loss = nn_impl.sigmoid_cross_entropy_with_logits(
+ labels=targets, logits=logits)
np_loss = np.array(losses).astype(np.float32)
tf_loss = loss.eval()
self.assertAllClose(np_loss, tf_loss, atol=0.001)
@@ -75,7 +76,8 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
for dtype in [dtypes.float32, dtypes.float16]:
with self.test_session(use_gpu=use_gpu):
logits, targets, losses = self._Inputs(dtype=dtype, sizes=[2, 2, 2])
- loss = nn_impl.sigmoid_cross_entropy_with_logits(logits, targets)
+ loss = nn_impl.sigmoid_cross_entropy_with_logits(
+ labels=targets, logits=logits)
np_loss = np.array(losses).astype(np.float32)
tf_loss = loss.eval()
self.assertAllClose(np_loss, tf_loss, atol=0.001)
@@ -84,7 +86,8 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
sizes = [4, 2]
with self.test_session():
logits, targets, _ = self._Inputs(sizes=sizes)
- loss = nn_impl.sigmoid_cross_entropy_with_logits(logits, targets)
+ loss = nn_impl.sigmoid_cross_entropy_with_logits(
+ labels=targets, logits=logits)
err = gradient_checker.compute_gradient_error(logits, sizes, loss, sizes)
print("logistic loss gradient err = ", err)
self.assertLess(err, 1e-7)
@@ -93,13 +96,15 @@ class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
with self.test_session():
logits = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
targets = constant_op.constant([0.0, 1.0], dtype=dtypes.float64)
- loss = nn_impl.sigmoid_cross_entropy_with_logits(logits, targets)
+ loss = nn_impl.sigmoid_cross_entropy_with_logits(
+ labels=targets, logits=logits)
grads = gradients_impl.gradients(loss, logits)[0].eval()
self.assertAllClose(grads, [0.5, -0.5])
def testShapeError(self):
with self.assertRaisesRegexp(ValueError, "must have the same shape"):
- nn_impl.sigmoid_cross_entropy_with_logits([[2, 1]], [1, 2, 3])
+ nn_impl.sigmoid_cross_entropy_with_logits(labels=[1, 2, 3],
+ logits=[[2, 1]])
class WeightedCrossEntropyTest(test.TestCase):
@@ -128,15 +133,15 @@ class WeightedCrossEntropyTest(test.TestCase):
with self.test_session():
logits, targets, pos_weight, _ = self._Inputs()
loss = nn_impl.weighted_cross_entropy_with_logits(
- targets, logits, pos_weight, name="mybce")
+ targets=targets, logits=logits, pos_weight=pos_weight, name="mybce")
self.assertEqual("mybce", loss.op.name)
def testOutput(self):
for use_gpu in [True, False]:
with self.test_session(use_gpu=use_gpu):
logits, targets, pos_weight, losses = self._Inputs(dtype=dtypes.float32)
- loss = nn_impl.weighted_cross_entropy_with_logits(targets, logits,
- pos_weight)
+ loss = nn_impl.weighted_cross_entropy_with_logits(
+ targets=targets, logits=logits, pos_weight=pos_weight)
np_loss = np.array(losses).astype(np.float32)
tf_loss = loss.eval()
self.assertAllClose(np_loss, tf_loss, atol=0.001)
@@ -146,8 +151,8 @@ class WeightedCrossEntropyTest(test.TestCase):
with self.test_session(use_gpu=use_gpu):
logits, targets, pos_weight, losses = self._Inputs(
dtype=dtypes.float32, sizes=[2, 2, 2])
- loss = nn_impl.weighted_cross_entropy_with_logits(targets, logits,
- pos_weight)
+ loss = nn_impl.weighted_cross_entropy_with_logits(
+ targets=targets, logits=logits, pos_weight=pos_weight)
np_loss = np.array(losses).astype(np.float32)
tf_loss = loss.eval()
self.assertAllClose(np_loss, tf_loss, atol=0.001)
@@ -156,15 +161,16 @@ class WeightedCrossEntropyTest(test.TestCase):
sizes = [4, 2]
with self.test_session():
logits, targets, pos_weight, _ = self._Inputs(sizes=sizes)
- loss = nn_impl.weighted_cross_entropy_with_logits(targets, logits,
- pos_weight)
+ loss = nn_impl.weighted_cross_entropy_with_logits(
+ targets=targets, logits=logits, pos_weight=pos_weight)
err = gradient_checker.compute_gradient_error(logits, sizes, loss, sizes)
print("logistic loss gradient err = ", err)
self.assertLess(err, 1e-7)
def testShapeError(self):
with self.assertRaisesRegexp(ValueError, "must have the same shape"):
- nn_impl.weighted_cross_entropy_with_logits([1, 2, 3], [[2, 1]], 2.0)
+ nn_impl.weighted_cross_entropy_with_logits(
+ targets=[1, 2, 3], logits=[[2, 1]], pos_weight=2.0)
if __name__ == "__main__":
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 68e6bb5e63..2bde726b45 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -1614,7 +1614,7 @@ class MetaGraphTest(test.TestCase):
concated, array_ops.stack([batch_size, 10]), 1.0, 0.0)
logits = ops_lib.get_collection("logits")[0]
cross_entropy = nn_ops.softmax_cross_entropy_with_logits(
- logits, onehot_labels, name="xentropy")
+ labels=onehot_labels, logits=logits, name="xentropy")
loss = math_ops.reduce_mean(cross_entropy, name="xentropy_mean")
summary.scalar("loss", loss)
@@ -1698,7 +1698,8 @@ class MetaGraphTest(test.TestCase):
bias = variables.Variable(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")
nn_ops.softmax(logit, name="prediction")
- cost = nn_ops.softmax_cross_entropy_with_logits(logit, label, name="cost")
+ cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
+ logits=logit, name="cost")
adam.AdamOptimizer().minimize(cost, name="optimize")
saver = saver_module.Saver()
sess.run(variables.global_variables_initializer())
@@ -1726,7 +1727,8 @@ class MetaGraphTest(test.TestCase):
bias = variables.Variable(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
nn_ops.softmax(logit, name="prediction")
- cost = nn_ops.softmax_cross_entropy_with_logits(logit, label)
+ cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
+ logits=logit)
adam.AdamOptimizer().minimize(cost, name="optimize")
meta_graph_def = saver_module.export_meta_graph()
@@ -1758,7 +1760,8 @@ class MetaGraphTest(test.TestCase):
bias = variables.Variable(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
nn_ops.softmax(logit, name="prediction")
- cost = nn_ops.softmax_cross_entropy_with_logits(logit, label)
+ cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
+ logits=logit)
adam.AdamOptimizer().minimize(cost, name="optimize")
meta_graph_def = saver_module.export_meta_graph(clear_devices=True)
graph_io.write_graph(meta_graph_def, "/tmp", "meta_graph.pbtxt")