aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mark Daoust <markdaoust@google.com>2018-08-20 12:37:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 12:42:36 -0700
commit600caf99897e82cd0db8665acca5e7630ec1a292 (patch)
tree42827783d04dec896fff08fa28d3eb7d640f2934
parent5e1cd6a15f6b65ccb5660714274368b71486c7f6 (diff)
Clarify meaning of `num_sampled` in `nce_loss`.
Fixes #17949 PiperOrigin-RevId: 209465778
-rw-r--r--tensorflow/python/ops/nn_impl.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 51f812b395..2a1919e66f 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -1210,7 +1210,9 @@ def nce_loss(weights,
num_true]`. The target classes.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
- num_sampled: An `int`. The number of classes to randomly sample per batch.
+ num_sampled: An `int`. The number of negative classes to randomly sample
+ per batch. This single sample of negative classes is evaluated for each
+ element in the batch.
num_classes: An `int`. The number of possible classes.
num_true: An `int`. The number of target classes per training example.
sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,