author    A. Unique TensorFlower <gardener@tensorflow.org> 2017-07-12 12:44:53 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-07-12 12:49:07 -0700
commit    c9d03a568a221e96c47ee7d5be703984d61b95a4 (patch)
tree      c1aed41de4cbc5180cd819c08e7d7fef992121af /tensorflow/contrib/nn
parent    9aa0dcbf282f119ac5f53bbb71af40f432bc3be9 (diff)
Add tf.contrib.nn.rank_sampled_softmax_loss, a variant of tf.nn.sampled_softmax_loss that has been shown to improve rank loss. Paper: https://arxiv.org/abs/1707.03073
PiperOrigin-RevId: 161702455
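For orientation, here is a minimal usage sketch of the new op as declared in the diff below. The variable names and hyperparameter values (vocabulary size, `num_sampled=64`, `num_resampled=16`, the temperature) are illustrative placeholders, not part of this change:

```python
import tensorflow as tf

num_classes = 10000   # illustrative vocabulary size
dim = 128             # illustrative embedding dimension

weights = tf.get_variable("class_weights", [num_classes, dim])
biases = tf.get_variable("class_biases", [num_classes])
inputs = tf.placeholder(tf.float32, [None, dim])   # forward activations
labels = tf.placeholder(tf.int64, [None, 1])       # one target class per example

loss = tf.contrib.nn.rank_sampled_softmax_loss(
    weights=weights,
    biases=biases,
    labels=labels,
    inputs=inputs,
    num_sampled=64,        # first-pass candidates
    num_resampled=16,      # candidates kept after adaptive resampling
    num_classes=num_classes,
    num_true=1,
    sampled_values=None,   # defaults to learned_unigram_candidate_sampler
    resampling_temperature=tf.constant(1.0),
    remove_accidental_hits=True,
    partition_strategy="div")
```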
Diffstat (limited to 'tensorflow/contrib/nn')
-rw-r--r--  tensorflow/contrib/nn/BUILD                             21
-rw-r--r--  tensorflow/contrib/nn/python/ops/sampling_ops.py       243
-rw-r--r--  tensorflow/contrib/nn/python/ops/sampling_ops_test.py  322
3 files changed, 586 insertions, 0 deletions
diff --git a/tensorflow/contrib/nn/BUILD b/tensorflow/contrib/nn/BUILD
index dbac049d83..af33496e5d 100644
--- a/tensorflow/contrib/nn/BUILD
+++ b/tensorflow/contrib/nn/BUILD
@@ -7,6 +7,8 @@ exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])
+load("//tensorflow:tensorflow.bzl", "py_test")
+
py_library(
name = "nn_py",
srcs = [
@@ -14,15 +16,34 @@ py_library(
"python/__init__.py",
"python/ops/__init__.py",
"python/ops/cross_entropy.py",
+ "python/ops/sampling_ops.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:util",
],
)
+py_test(
+ name = "sampling_ops_test",
+ size = "small",
+ srcs = ["python/ops/sampling_ops_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":nn_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:nn",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/nn/python/ops/sampling_ops.py b/tensorflow/contrib/nn/python/ops/sampling_ops.py
new file mode 100644
index 0000000000..7a9eed511b
--- /dev/null
+++ b/tensorflow/contrib/nn/python/ops/sampling_ops.py
@@ -0,0 +1,243 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops related to candidate sampling."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+
+
+def _rank_resample(weights, biases, inputs, sampled_values, num_resampled,
+ resampling_temperature, partition_strategy):
+ """A helper function for rank_sampled_softmax_loss.
+
+ This computes, for each i in `sampled_values`,
+
+ log(sum_j exp((w_i * x_j + b_i) / resampling_temperature))
+
+ where w_i, b_i are the weight and bias of the i-th class, respectively,
+ and j ranges over the rows of `inputs`. For efficiency, we rearrange the
+ computation to
+
+ log(sum_j exp(w_i * (x_j / resampling_temperature))) +
+ b_i / resampling_temperature.
+
+ This translates to the following batched computation using TensorFlow ops:
+
+ reduce_logsumexp(matmul(embeddings,
+ transpose(inputs / resampling_temperature))) +
+ biases / resampling_temperature
+
+ The computation of the first term is colocated with the embeddings using
+ `transform_fn` in `embedding_ops._embedding_lookup_and_transform`. The second
+ term, which is not the bottleneck, is computed at the worker.
+
+ Args:
+ weights: From `rank_sampled_softmax_loss`.
+ biases: From `rank_sampled_softmax_loss`.
+ inputs: From `rank_sampled_softmax_loss`.
+ sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
+ `sampled_expected_count`) returned by a `*_candidate_sampler` function.
+ num_resampled: An `int`. This many values are selected from
+ `sampled_values` using the adaptive resampling algorithm. The caller
+ must ensure that `num_resampled` is less than the size of
+ `sampled_values`.
+ resampling_temperature: A scalar `Tensor` with the temperature parameter
+ for the adaptive resampling algorithm.
+ partition_strategy: From `rank_sampled_softmax_loss`.
+
+ Returns:
+ A tuple of (`resampled_candidates`, `true_expected_count`,
+ `resampled_expected_count`), similar to `sampled_values` but sampled
+ down to `num_resampled` values.
+ """
+ # This code supports passing a Tensor for num_resampled, but since it is only
+ # called with an int, that's what we specify in the arg list. If this
+ # function is ever externalized, we should change the doc to support Tensor.
+
+ sampled, true_expected_count, sampled_expected_count = sampled_values
+
+ sampled = math_ops.cast(array_ops.stop_gradient(sampled), dtypes.int64)
+ true_expected_count = array_ops.stop_gradient(true_expected_count)
+ sampled_expected_count = array_ops.stop_gradient(sampled_expected_count)
+
+ reweighted_inputs = inputs / resampling_temperature
+
+ def logsumexp_logit(embeddings):
+ return math_ops.reduce_logsumexp(
+ math_ops.matmul(embeddings, reweighted_inputs, transpose_b=True),
+ axis=1,
+ keep_dims=False)
+
+ # Calling this protected form of embedding_lookup allows co-locating
+ # the logsumexp computation with the partitioned weights, which yields
+ # a large speedup in practice.
+ sampled_logits = embedding_ops._embedding_lookup_and_transform( # pylint: disable=protected-access
+ weights, sampled, partition_strategy, transform_fn=logsumexp_logit)
+ sampled_b = array_ops.reshape(
+ embedding_ops.embedding_lookup(biases, sampled, partition_strategy), [-1])
+ sampled_logits += sampled_b / resampling_temperature
+
+ _, resampled_indices = nn.top_k(sampled_logits, k=num_resampled, sorted=False)
+ resampled = array_ops.gather(sampled, indices=resampled_indices)
+ resampled_expected_count = array_ops.gather(
+ sampled_expected_count, indices=resampled_indices)
+
+ return resampled, true_expected_count, resampled_expected_count
+
+
+# TODO(ccolby): Before checkin, add reference to TAPAS paper when in arxiv.org.
+def rank_sampled_softmax_loss(weights,
+ biases,
+ labels,
+ inputs,
+ num_sampled,
+ num_resampled,
+ num_classes,
+ num_true,
+ sampled_values,
+ resampling_temperature,
+ remove_accidental_hits,
+ partition_strategy,
+ name=None):
+ """Computes softmax loss using rank-based adaptive resampling.
+
+ This has been shown to improve rank loss after training compared to
+ @{tf.nn.sampled_softmax_loss}. For a description of the algorithm and some
+ experimental results, please see: [TAPAS: Two-pass Approximate Adaptive
+ Sampling for Softmax](https://arxiv.org/abs/1707.03073).
+
+ Sampling follows two phases:
+ * In the first phase, `num_sampled` classes are selected using
+ @{tf.nn.learned_unigram_candidate_sampler} or supplied `sampled_values`.
+ The logits are calculated on those sampled classes. This phase is
+ similar to @{tf.nn.sampled_softmax_loss}.
+ * In the second phase, the `num_resampled` classes with highest predicted
+ probability are kept. Probabilities are
+ `LogSumExp(logits / resampling_temperature)`, where the sum is over
+ `inputs`.
+
+ The `resampling_temperature` parameter controls the "adaptiveness" of the
+ resampling. At lower temperatures, resampling is more adaptive because it
+ picks more candidates close to the predicted classes. A common strategy is
+ to decrease the temperature as training proceeds.
+
+ See @{tf.nn.sampled_softmax_loss} for more documentation on sampling and
+ for typical default values for some of the parameters.
+
+ This operation is for training only. It is generally an underestimate of
+ the full softmax loss.
+
+ A common use case is to use this method for training, and calculate the full
+ softmax loss for evaluation or inference. In this case, you must set
+ `partition_strategy="div"` for the two losses to be consistent, as in the
+ following example:
+
+ ```python
+ if mode == "train":
+ loss = rank_sampled_softmax_loss(
+ weights=weights,
+ biases=biases,
+ labels=labels,
+ inputs=inputs,
+ ...,
+ partition_strategy="div")
+ elif mode == "eval":
+ logits = tf.matmul(inputs, tf.transpose(weights))
+ logits = tf.nn.bias_add(logits, biases)
+ labels_one_hot = tf.one_hot(labels, n_classes)
+ loss = tf.nn.softmax_cross_entropy_with_logits(
+ labels=labels_one_hot,
+ logits=logits)
+ ```
+
+ Args:
+ weights: A `Tensor` or `PartitionedVariable` of shape `[num_classes, dim]`,
+ or a list of `Tensor` objects whose concatenation along dimension 0
+ has shape [num_classes, dim]. The (possibly-sharded) class embeddings.
+ biases: A `Tensor` or `PartitionedVariable` of shape `[num_classes]`.
+ The (possibly-sharded) class biases.
+ labels: A `Tensor` of type `int64` and shape `[batch_size,
+ num_true]`. The target classes. Note that this format differs from
+ the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
+ inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
+ activations of the input network.
+ num_sampled: An `int`. The number of classes to randomly sample per batch.
+ num_resampled: An `int`. The number of classes to select from the
+ `num_sampled` classes using the adaptive resampling algorithm. Must be
+ less than `num_sampled`.
+ num_classes: An `int`. The number of possible classes.
+ num_true: An `int`. The number of target classes per training example.
+ sampled_values: A tuple of (`sampled_candidates`, `true_expected_count`,
+ `sampled_expected_count`) returned by a `*_candidate_sampler` function.
+ If None, defaults to `nn.learned_unigram_candidate_sampler`.
+ resampling_temperature: A scalar `Tensor` with the temperature parameter
+ for the adaptive resampling algorithm.
+ remove_accidental_hits: A `bool`. Whether to remove "accidental hits"
+ where a sampled class equals one of the target classes.
+ partition_strategy: A string specifying the partitioning strategy, relevant
+ if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
+ See @{tf.nn.embedding_lookup} for more details.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `batch_size` 1-D tensor of per-example sampled softmax losses.
+
+ Raises:
+ ValueError: If `num_sampled <= num_resampled`.
+ """
+ if num_sampled > num_classes:
+ raise ValueError("num_sampled ({}) cannot be greater than num_classes ({})".
+ format(num_sampled, num_classes))
+ if num_sampled <= num_resampled:
+ raise ValueError("num_resampled ({}) must be less than num_sampled ({})".
+ format(num_resampled, num_sampled))
+ if partition_strategy not in ("div", "mod"):
+ raise ValueError(
+ "unsupported partition_strategy ({})".format(partition_strategy))
+ with ops.name_scope(name, "rank_sampled_softmax_loss", [
+ weights, biases, labels, inputs, sampled_values, resampling_temperature
+ ]) as name:
+ if not sampled_values:
+ sampled_values = nn.learned_unigram_candidate_sampler(
+ true_classes=labels,
+ num_true=num_true,
+ num_sampled=num_sampled,
+ unique=True,
+ range_max=num_classes)
+ # From sampled_values, select the top num_resampled values using the
+ # adaptive rank resampling strategy.
+ resampled_values = _rank_resample(weights, biases, inputs, sampled_values,
+ num_resampled, resampling_temperature,
+ partition_strategy)
+ return nn.sampled_softmax_loss(
+ weights=weights,
+ biases=biases,
+ labels=labels,
+ inputs=inputs,
+ num_sampled=num_resampled,
+ num_classes=num_classes,
+ num_true=num_true,
+ sampled_values=resampled_values,
+ remove_accidental_hits=remove_accidental_hits,
+ partition_strategy=partition_strategy,
+ name=name)
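As a quick sanity check of the rearrangement described in the `_rank_resample` docstring above, here is a small NumPy sketch (shapes and random values are made up for illustration) confirming that rescaling the inputs by the temperature and pulling the bias term out of the log-sum-exp yields the same per-class score as the direct form:

```python
import numpy as np

rng = np.random.RandomState(0)
T = 2.0                  # resampling temperature
W = rng.randn(5, 3)      # embeddings of the sampled classes, [num_sampled, dim]
b = rng.randn(5)         # biases of the sampled classes, [num_sampled]
X = rng.randn(4, 3)      # inputs, [batch_size, dim]

# Direct form: log sum_j exp((w_i . x_j + b_i) / T) for each sampled class i.
direct = np.log(np.exp((W @ X.T + b[:, None]) / T).sum(axis=1))

# Rearranged form used by _rank_resample: log-sum-exp over W @ (X / T)^T,
# with the bias term added outside the reduction.
rearranged = np.log(np.exp(W @ (X / T).T).sum(axis=1)) + b / T

assert np.allclose(direct, rearranged)
```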
diff --git a/tensorflow/contrib/nn/python/ops/sampling_ops_test.py b/tensorflow/contrib/nn/python/ops/sampling_ops_test.py
new file mode 100644
index 0000000000..1d4fe1321b
--- /dev/null
+++ b/tensorflow/contrib/nn/python/ops/sampling_ops_test.py
@@ -0,0 +1,322 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for sampling_ops.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.nn.python.ops import sampling_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import nn
+from tensorflow.python.platform import test
+
+
+class RankSampledSoftmaxLossTest(test.TestCase):
+
+ def setUp(self):
+ self._sampled = [3, 4, 5, 6, 7]
+ self._num_sampled = len(self._sampled)
+ # Because the values of all matrices increase with index, logits increase with
+ # class id, so for the sampled classes above, adaptive resampling will select
+ # these resampled classes.
+ self._resampled = [5, 6, 7]
+ self._num_resampled = len(self._resampled)
+ self._num_classes = 10
+ self._num_true = 2
+ self._sampled_values = (self._sampled, [[0.5], [0.5]],
+ [0.5, 0.5, 0.5, 0.5, 0.5])
+ self._resampled_values = (self._resampled, [[0.5], [0.5]], [0.5, 0.5, 0.5])
+ self._remove_accidental_hits = False
+ self._embed_dim = 5
+ self._batch_size = 2
+
+ def _weights(self):
+ return constant_op.constant([
+ [0.0, 0.1, 0.2, 0.3, 0.4],
+ [1.0, 1.1, 1.2, 1.3, 1.4],
+ [2.0, 2.1, 2.2, 2.3, 2.4],
+ [3.0, 3.1, 3.2, 3.3, 3.4],
+ [4.0, 4.1, 4.2, 4.3, 4.4],
+ [5.0, 5.1, 5.2, 5.3, 5.4],
+ [6.0, 6.1, 6.2, 6.3, 6.4],
+ [7.0, 7.1, 7.2, 7.3, 7.4],
+ [8.0, 8.1, 8.2, 8.3, 8.4],
+ [9.0, 9.1, 9.2, 9.3, 9.4],
+ ])
+
+ def _div_sharded_weights(self):
+ return [
+ constant_op.constant([
+ [0.0, 0.1, 0.2, 0.3, 0.4],
+ [1.0, 1.1, 1.2, 1.3, 1.4],
+ ]),
+ constant_op.constant([
+ [2.0, 2.1, 2.2, 2.3, 2.4],
+ [3.0, 3.1, 3.2, 3.3, 3.4],
+ ]),
+ constant_op.constant([
+ [4.0, 4.1, 4.2, 4.3, 4.4],
+ [5.0, 5.1, 5.2, 5.3, 5.4],
+ ]),
+ constant_op.constant([
+ [6.0, 6.1, 6.2, 6.3, 6.4],
+ [7.0, 7.1, 7.2, 7.3, 7.4],
+ ]),
+ constant_op.constant([
+ [8.0, 8.1, 8.2, 8.3, 8.4],
+ [9.0, 9.1, 9.2, 9.3, 9.4],
+ ]),
+ ]
+
+ def _mod_sharded_weights(self):
+ return [
+ constant_op.constant([
+ [0.0, 0.1, 0.2, 0.3, 0.4],
+ [5.0, 5.1, 5.2, 5.3, 5.4],
+ ]),
+ constant_op.constant([
+ [1.0, 1.1, 1.2, 1.3, 1.4],
+ [6.0, 6.1, 6.2, 6.3, 6.4],
+ ]),
+ constant_op.constant([
+ [2.0, 2.1, 2.2, 2.3, 2.4],
+ [7.0, 7.1, 7.2, 7.3, 7.4],
+ ]),
+ constant_op.constant([
+ [3.0, 3.1, 3.2, 3.3, 3.4],
+ [8.0, 8.1, 8.2, 8.3, 8.4],
+ ]),
+ constant_op.constant([
+ [4.0, 4.1, 4.2, 4.3, 4.4],
+ [9.0, 9.1, 9.2, 9.3, 9.4],
+ ]),
+ ]
+
+ def _biases(self):
+ return constant_op.constant(
+ [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
+
+ def _div_sharded_biases(self):
+ return [
+ constant_op.constant([0.0, 0.1]),
+ constant_op.constant([0.2, 0.3]),
+ constant_op.constant([0.4, 0.5]),
+ constant_op.constant([0.6, 0.7]),
+ constant_op.constant([0.8, 0.9]),
+ ]
+
+ def _mod_sharded_biases(self):
+ return [
+ constant_op.constant([0.0, 0.5]),
+ constant_op.constant([0.1, 0.6]),
+ constant_op.constant([0.2, 0.7]),
+ constant_op.constant([0.3, 0.8]),
+ constant_op.constant([0.4, 0.9]),
+ ]
+
+ def _labels(self):
+ return constant_op.constant(
+ [[0, 1], [1, 2]],
+ shape=(self._batch_size, self._num_true),
+ name='labels',
+ dtype=dtypes.int64)
+
+ def _inputs(self):
+ return constant_op.constant(
+ [
+ [0., 1., 2., 3., 4.],
+ [10., 11., 12., 13., 14.],
+ ],
+ shape=(self._batch_size, self._embed_dim),
+ name='inputs')
+
+ def testInvalidNumSampled0(self):
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'num_resampled \(3\) must be less than num_sampled \(3\)'):
+ sampling_ops.rank_sampled_softmax_loss(
+ weights=self._weights(),
+ biases=self._biases(),
+ labels=self._labels(),
+ inputs=self._inputs(),
+ num_sampled=3,
+ num_resampled=3,
+ num_classes=self._num_classes,
+ num_true=self._num_true,
+ sampled_values=None,
+ resampling_temperature=1.,
+ remove_accidental_hits=True,
+ partition_strategy='div')
+
+ def testInvalidNumSampled1(self):
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ ValueError,
+ r'num_resampled \(3\) must be less than num_sampled \(2\)'):
+ sampling_ops.rank_sampled_softmax_loss(
+ weights=self._weights(),
+ biases=self._biases(),
+ labels=self._labels(),
+ inputs=self._inputs(),
+ num_sampled=2,
+ num_resampled=3,
+ num_classes=self._num_classes,
+ num_true=self._num_true,
+ sampled_values=None,
+ resampling_temperature=1.,
+ remove_accidental_hits=True,
+ partition_strategy='div')
+
+ def testMissingPartitionStrategy(self):
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(ValueError,
+ r'unsupported partition_strategy \(None\)'):
+ sampling_ops.rank_sampled_softmax_loss(
+ weights=self._weights(),
+ biases=self._biases(),
+ labels=self._labels(),
+ inputs=self._inputs(),
+ num_sampled=2,
+ num_resampled=1,
+ num_classes=self._num_classes,
+ num_true=self._num_true,
+ sampled_values=None,
+ resampling_temperature=1.,
+ remove_accidental_hits=True,
+ partition_strategy=None)
+
+ def _testCompareWithNN(self, weights, biases, partition_strategy):
+ with ops.Graph().as_default():
+ loss = sampling_ops.rank_sampled_softmax_loss(
+ weights=weights(),
+ biases=biases(),
+ labels=self._labels(),
+ inputs=self._inputs(),
+ num_sampled=self._num_sampled,
+ num_resampled=self._num_resampled,
+ num_classes=self._num_classes,
+ num_true=self._num_true,
+ sampled_values=self._sampled_values,
+ resampling_temperature=1.,
+ remove_accidental_hits=self._remove_accidental_hits,
+ partition_strategy=partition_strategy)
+ loss_nn = nn.sampled_softmax_loss(
+ weights=weights(),
+ biases=biases(),
+ labels=self._labels(),
+ inputs=self._inputs(),
+ num_sampled=self._num_resampled,
+ num_classes=self._num_classes,
+ num_true=self._num_true,
+ sampled_values=self._resampled_values,
+ remove_accidental_hits=self._remove_accidental_hits,
+ partition_strategy=partition_strategy)
+ with self.test_session() as sess:
+ loss_val = sess.run(loss)
+ loss_nn_val = sess.run(loss_nn)
+
+ self.assertAllClose(loss_val, loss_nn_val)
+
+ def testCompareWithNNUnsharded(self):
+ self._testCompareWithNN(self._weights, self._biases, 'div')
+
+ def testCompareWithNNShardWeightsDiv(self):
+ self._testCompareWithNN(self._div_sharded_weights, self._biases, 'div')
+
+ def testCompareWithNNShardWeightsAndBiasesDiv(self):
+ self._testCompareWithNN(self._div_sharded_weights, self._div_sharded_biases,
+ 'div')
+
+ def testCompareWithNNShardWeightsMod(self):
+ self._testCompareWithNN(self._mod_sharded_weights, self._biases, 'mod')
+
+ def testCompareWithNNShardWeightsAndBiasesMod(self):
+ self._testCompareWithNN(self._mod_sharded_weights, self._mod_sharded_biases,
+ 'mod')
+
+ def _testCompareWithNNTemperature(self, temperature, resampled):
+ weights = [[1., 2.], [3., 4.]] # two sampled classes
+ inputs = [[6., -5. / 2.], [-11., 21. / 2.]]
+ # Let w0, w1 = weights of sampled classes (biases set to 0 for simplicity)
+ # Let x0, x1 = inputs
+ # logits:
+ # w0.x0 = 1
+ # w0.x1 = 10
+ # w1.x0 = 8
+ # w1.x1 = 9
+ # Resampling 1 class with temperature = t will pick the larger of:
+ # exp(1/t) + exp(10/t) ==> w0, for values of t < 2.12
+ # exp(8/t) + exp(9/t) ==> w1, for values of t > 2.13
+ num_sampled = 2
+ num_resampled = 1
+ num_classes = 2
+ num_true = 1
+ sampled_values = [0, 1], [[1.], [1.]], [1., 1.]
+ resampled_values = [resampled], [[1.], [1.]], [1.]
+ remove_accidental_hits = False
+ with ops.Graph().as_default():
+ weights = constant_op.constant(weights)
+ biases = constant_op.constant([0., 0.])
+ labels = constant_op.constant([[0], [1]], dtype=dtypes.int64)
+ inputs = constant_op.constant(inputs)
+ loss = sampling_ops.rank_sampled_softmax_loss(
+ weights=weights,
+ biases=biases,
+ labels=labels,
+ inputs=inputs,
+ num_sampled=num_sampled,
+ num_resampled=num_resampled,
+ num_classes=num_classes,
+ num_true=num_true,
+ sampled_values=sampled_values,
+ resampling_temperature=constant_op.constant(temperature),
+ remove_accidental_hits=remove_accidental_hits,
+ partition_strategy='div')
+ loss_nn = nn.sampled_softmax_loss(
+ weights=weights,
+ biases=biases,
+ labels=labels,
+ inputs=inputs,
+ num_sampled=num_resampled,
+ num_classes=num_classes,
+ num_true=num_true,
+ sampled_values=resampled_values,
+ remove_accidental_hits=remove_accidental_hits,
+ partition_strategy='div')
+ with self.test_session() as sess:
+ loss_val = sess.run(loss)
+ loss_nn_val = sess.run(loss_nn)
+
+ self.assertAllClose(loss_val, loss_nn_val)
+
+ def testCompareWithNNTemperatureLo1(self):
+ self._testCompareWithNNTemperature(1., 0)
+
+ def testCompareWithNNTemperatureLo2(self):
+ self._testCompareWithNNTemperature(2.12, 0)
+
+ def testCompareWithNNTemperatureHi1(self):
+ self._testCompareWithNNTemperature(2.13, 1)
+
+ def testCompareWithNNTemperatureHi2(self):
+ self._testCompareWithNNTemperature(3., 1)
+
+
+if __name__ == '__main__':
+ test.main()
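The temperature thresholds 2.12 and 2.13 used in `_testCompareWithNNTemperature` come from the crossover point at which `exp(1/t) + exp(10/t)` and `exp(8/t) + exp(9/t)` trade places. A small standalone check of that arithmetic (pure NumPy, not part of the test) under the logits worked out in the comment above:

```python
import numpy as np

def resampling_scores(t):
    # Per-class scores log(sum_j exp(logit_ij / t)) for the two sampled classes.
    w0 = np.log(np.exp(1.0 / t) + np.exp(10.0 / t))
    w1 = np.log(np.exp(8.0 / t) + np.exp(9.0 / t))
    return w0, w1

assert resampling_scores(2.12)[0] > resampling_scores(2.12)[1]  # low t: w0 kept
assert resampling_scores(2.13)[0] < resampling_scores(2.13)[1]  # high t: w1 kept
```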