diff options
author | 2018-08-20 12:37:46 -0700 | |
---|---|---|
committer | 2018-08-20 12:42:36 -0700 | |
commit | 600caf99897e82cd0db8665acca5e7630ec1a292 (patch) | |
tree | 42827783d04dec896fff08fa28d3eb7d640f2934 | |
parent | 5e1cd6a15f6b65ccb5660714274368b71486c7f6 (diff) |
Clarify meaning of `num_sampled` in `nce_loss`.
Fixes #17949
PiperOrigin-RevId: 209465778
-rw-r--r-- | tensorflow/python/ops/nn_impl.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 51f812b395..2a1919e66f 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -1210,7 +1210,9 @@ def nce_loss(weights, num_true]`. The target classes. inputs: A `Tensor` of shape `[batch_size, dim]`. The forward activations of the input network. - num_sampled: An `int`. The number of classes to randomly sample per batch. + num_sampled: An `int`. The number of negative classes to randomly sample + per batch. This single sample of negative classes is evaluated for each + element in the batch. num_classes: An `int`. The number of possible classes. num_true: An `int`. The number of target classes per training example. sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, |