aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_impl.py')
-rw-r--r--tensorflow/python/ops/nn_impl.py36
1 files changed, 18 insertions, 18 deletions
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index be46bf305a..6e7c6efb4c 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -809,8 +809,8 @@ def _sum_rows(x):
def _compute_sampled_logits(weights,
biases,
- inputs,
labels,
+ inputs,
num_sampled,
num_classes,
num_true=1,
@@ -834,11 +834,11 @@ def _compute_sampled_logits(weights,
objects whose concatenation along dimension 0 has shape
`[num_classes, dim]`. The (possibly-partitioned) class embeddings.
biases: A `Tensor` of shape `[num_classes]`. The class biases.
- inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
- activations of the input network.
labels: A `Tensor` of type `int64` and shape `[batch_size,
num_true]`. The target classes. Note that this format differs from
the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
+ inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
+ activations of the input network.
num_sampled: An `int`. The number of classes to randomly sample per batch.
num_classes: An `int`. The number of possible classes.
num_true: An `int`. The number of target classes per training example.
@@ -975,8 +975,8 @@ def _compute_sampled_logits(weights,
def nce_loss(weights,
biases,
- inputs,
labels,
+ inputs,
num_sampled,
num_classes,
num_true=1,
@@ -1012,10 +1012,10 @@ def nce_loss(weights,
objects whose concatenation along dimension 0 has shape
[num_classes, dim]. The (possibly-partitioned) class embeddings.
biases: A `Tensor` of shape `[num_classes]`. The class biases.
- inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
- activations of the input network.
labels: A `Tensor` of type `int64` and shape `[batch_size,
num_true]`. The target classes.
+ inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
+ activations of the input network.
num_sampled: An `int`. The number of classes to randomly sample per batch.
num_classes: An `int`. The number of possible classes.
num_true: An `int`. The number of target classes per training example.
@@ -1038,12 +1038,12 @@ def nce_loss(weights,
A `batch_size` 1-D tensor of per-example NCE losses.
"""
logits, labels = _compute_sampled_logits(
- weights,
- biases,
- inputs,
- labels,
- num_sampled,
- num_classes,
+ weights=weights,
+ biases=biases,
+ labels=labels,
+ inputs=inputs,
+ num_sampled=num_sampled,
+ num_classes=num_classes,
num_true=num_true,
sampled_values=sampled_values,
subtract_log_q=True,
@@ -1114,12 +1114,12 @@ def sampled_softmax_loss(weights,
"""
logits, labels = _compute_sampled_logits(
- weights,
- biases,
- inputs,
- labels,
- num_sampled,
- num_classes,
+ weights=weights,
+ biases=biases,
+ labels=labels,
+ inputs=inputs,
+ num_sampled=num_sampled,
+ num_classes=num_classes,
num_true=num_true,
sampled_values=sampled_values,
subtract_log_q=True,