aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/nn
diff options
context:
space:
mode:
authorGravatar Dandelion Mané <dandelion@google.com>2017-12-15 18:15:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-15 18:19:09 -0800
commit90e42f3ac8c43474633136af4242dca04b6a1e09 (patch)
tree64dbb44252c89c847bee86db07cea5aa94072e7c /tensorflow/contrib/nn
parent713d45278491d792c525344de6038a61ebcb2136 (diff)
Automated g4 rollback of changelist 179260538
PiperOrigin-RevId: 179263865
Diffstat (limited to 'tensorflow/contrib/nn')
-rw-r--r--tensorflow/contrib/nn/__init__.py1
-rw-r--r--tensorflow/contrib/nn/python/ops/sampling_ops.py100
2 files changed, 101 insertions, 0 deletions
diff --git a/tensorflow/contrib/nn/__init__.py b/tensorflow/contrib/nn/__init__.py
index 0bc133a00e..96d60e1498 100644
--- a/tensorflow/contrib/nn/__init__.py
+++ b/tensorflow/contrib/nn/__init__.py
@@ -21,6 +21,7 @@
@@deprecated_flipped_sigmoid_cross_entropy_with_logits
@@nth_element
@@rank_sampled_softmax_loss
+@@sampled_sparse_softmax_loss
@@scaled_softplus
"""
diff --git a/tensorflow/contrib/nn/python/ops/sampling_ops.py b/tensorflow/contrib/nn/python/ops/sampling_ops.py
index 98749cff7e..63fc487dca 100644
--- a/tensorflow/contrib/nn/python/ops/sampling_ops.py
+++ b/tensorflow/contrib/nn/python/ops/sampling_ops.py
@@ -24,6 +24,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_impl
+from tensorflow.python.ops import nn_ops
def _rank_resample(weights, biases, inputs, sampled_values, num_resampled,
@@ -240,3 +242,101 @@ def rank_sampled_softmax_loss(weights,
remove_accidental_hits=remove_accidental_hits,
partition_strategy=partition_strategy,
name=name)
+
+
+def sampled_sparse_softmax_loss(weights,
+                                biases,
+                                labels,
+                                inputs,
+                                num_sampled,
+                                num_classes,
+                                sampled_values=None,
+                                remove_accidental_hits=True,
+                                partition_strategy="mod",
+                                name="sampled_sparse_softmax_loss"):
+  """Computes and returns the sampled sparse softmax training loss.
+
+  This is a faster way to train a softmax classifier over a huge number of
+  classes.
+
+  This operation is for training only. It is generally an underestimate of
+  the full softmax loss.
+
+  A common use case is to use this method for training, and calculate the full
+  softmax loss for evaluation or inference. In this case, you must set
+  `partition_strategy="div"` for the two losses to be consistent, as in the
+  following example:
+
+  ```python
+  if mode == "train":
+    loss = tf.contrib.nn.sampled_sparse_softmax_loss(
+        weights=weights,
+        biases=biases,
+        labels=labels,
+        inputs=inputs,
+        ...,
+        partition_strategy="div")
+  elif mode == "eval":
+    logits = tf.matmul(inputs, tf.transpose(weights))
+    logits = tf.nn.bias_add(logits, biases)
+    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        labels=tf.squeeze(labels, axis=[1]),
+        logits=logits)
+  ```
+
+  See our [Candidate Sampling Algorithms Reference]
+  (https://www.tensorflow.org/extras/candidate_sampling.pdf)
+
+  Also see Section 3 of [Jean et al., 2014](http://arxiv.org/abs/1412.2007)
+  ([pdf](http://arxiv.org/pdf/1412.2007.pdf)) for the math.
+
+  Args:
+    weights: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
+        objects whose concatenation along dimension 0 has shape
+        [num_classes, dim]. The (possibly-sharded) class embeddings.
+    biases: A `Tensor` of shape `[num_classes]`. The class biases.
+    labels: A `Tensor` of type `int64` and shape `[batch_size, 1]`.
+        The index of the single target class for each row of logits. Note that
+        this format differs from the `labels` argument of
+        `nn.sparse_softmax_cross_entropy_with_logits`.
+    inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
+        activations of the input network.
+    num_sampled: An `int`. The number of classes to randomly sample per batch.
+    num_classes: An `int`. The number of possible classes.
+    sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
+        `sampled_expected_count`) returned by a `*_candidate_sampler` function.
+        (if None, we default to `log_uniform_candidate_sampler`)
+    remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
+        where a sampled class equals one of the target classes. Default is
+        True.
+    partition_strategy: A string specifying the partitioning strategy, relevant
+        if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
+        Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
+    name: A name for the operation (optional).
+
+  Returns:
+    A `batch_size` 1-D tensor of per-example sampled softmax losses.
+
+  """
+  logits, _ = nn_impl._compute_sampled_logits(
+      weights=weights,
+      biases=biases,
+      labels=labels,
+      inputs=inputs,
+      num_sampled=num_sampled,
+      num_classes=num_classes,
+      num_true=1,
+      sampled_values=sampled_values,
+      subtract_log_q=True,
+      remove_accidental_hits=remove_accidental_hits,
+      partition_strategy=partition_strategy,
+      name=name)
+
+  # _compute_sampled_logits puts the logit of the single true class at index
+  # 0, so every row's target is 0. Create the labels directly as a
+  # [batch_size] vector: squeezing a [batch_size, 1] tensor without an axis
+  # would also drop the batch dimension when batch_size is 1.
+  batch_size = array_ops.shape(logits)[0]
+  true_class = array_ops.zeros([batch_size], dtype=dtypes.int64)
+  return nn_ops.sparse_softmax_cross_entropy_with_logits(
+      labels=true_class, logits=logits)