From 600caf99897e82cd0db8665acca5e7630ec1a292 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Mon, 20 Aug 2018 12:37:46 -0700 Subject: Clarify meaning of `num_sampled` in `nce_loss`. Fixes #17949 PiperOrigin-RevId: 209465778 --- tensorflow/python/ops/nn_impl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 51f812b395..2a1919e66f 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -1210,7 +1210,9 @@ def nce_loss(weights, num_true]`. The target classes. inputs: A `Tensor` of shape `[batch_size, dim]`. The forward activations of the input network. - num_sampled: An `int`. The number of classes to randomly sample per batch. + num_sampled: An `int`. The number of negative classes to randomly sample + per batch. This single sample of negative classes is evaluated for each + element in the batch. num_classes: An `int`. The number of possible classes. num_true: An `int`. The number of target classes per training example. sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, -- cgit v1.2.3