path: root/tensorflow/contrib/text
author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-08-29 11:13:22 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-08-29 11:21:09 -0700
commit    ecdadc84bd5b2eb09a76979c9632d66f7fd08d2c (patch)
tree      9cdab8fe81d19771b82cd5b9a7e2637bed471675 /tensorflow/contrib/text
parent    e87b2be0e68412e6a72e5b7184968e1e0b1f9178 (diff)
Add a `corpus_size` argument to the tf.contrib.text.skip_gram_sample_with_text_vocab op. This argument mirrors the `corpus_size` argument used by the tf.contrib.text.skip_gram_sample op.
It allows vocab subsampling to work properly when vocabulary files have been preprocessed to eliminate all infrequent tokens (those with frequency < vocab_min_count), saving the memory those unused tokens would otherwise occupy. This preprocessing helps when training skip-gram models on a large vocabulary, where the memory usage of the internal lookup table can become a performance bottleneck. If `corpus_size` is needed but not supplied, it will be calculated from `vocab_freq_file`, matching the existing implementation. PiperOrigin-RevId: 166873744
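As a rough illustration of the new argument (a minimal sketch, not taken from this commit; the vocab file path, frequency threshold, and corpus total below are hypothetical), a caller that has already pruned its vocabulary file could pass the original corpus total explicitly:

import tensorflow as tf
from tensorflow.contrib import text

# "vocab_freq.csv" is a hypothetical comma-delimited file of token,count
# rows from which tokens with frequency < vocab_min_count were removed.
tokens = tf.constant(["the", "quick", "brown", "fox", "jumps"])

# corpus_size is the token total of the *original* corpus, so it is >= the
# sum of the counts remaining in the pruned file.
tokens_out, labels = text.skip_gram_sample_with_text_vocab(
    input_tensor=tokens,
    vocab_freq_file="vocab_freq.csv",
    vocab_min_count=5,
    vocab_subsampling=1e-3,
    corpus_size=1000000,  # assumed pre-computed total; illustrative value
    min_skips=1,
    max_skips=2)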
Diffstat (limited to 'tensorflow/contrib/text')
-rw-r--r--  tensorflow/contrib/text/python/ops/skip_gram_ops.py       | 25
-rw-r--r--  tensorflow/contrib/text/python/ops/skip_gram_ops_test.py  | 41
2 files changed, 61 insertions, 5 deletions
diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops.py b/tensorflow/contrib/text/python/ops/skip_gram_ops.py
index 410ee517e0..7ed45031a3 100644
--- a/tensorflow/contrib/text/python/ops/skip_gram_ops.py
+++ b/tensorflow/contrib/text/python/ops/skip_gram_ops.py
@@ -216,6 +216,7 @@ def skip_gram_sample_with_text_vocab(input_tensor,
vocab_delimiter=",",
vocab_min_count=0,
vocab_subsampling=None,
+ corpus_size=None,
min_skips=1,
max_skips=5,
start=0,
@@ -267,6 +268,18 @@ def skip_gram_sample_with_text_vocab(input_tensor,
frequently will be randomly down-sampled. Reasonable starting values may
be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546 for
more details.
+ corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
+ total number of tokens in the corpus (e.g., the sum of all the frequency
+ counts of `vocab_freq_file`). Used with `vocab_subsampling` for
+ down-sampling frequently occurring tokens. If this is specified,
+ `vocab_freq_file` and `vocab_subsampling` must also be specified.
+ If `corpus_size` is needed but not supplied, it will be calculated from
+ `vocab_freq_file`. You might want to supply your own value if you have
+ already eliminated infrequent tokens (those with frequency <
+ `vocab_min_count`) from your vocabulary files, since keeping such unused
+ tokens would waste memory in the internal token lookup table. The
+ user-supplied `corpus_size` must be greater than or equal to the sum of
+ all the frequency counts of `vocab_freq_file`.
min_skips: `int` or scalar `Tensor` specifying the minimum window size to
randomly use for each token. Must be >= 0 and <= `max_skips`. If
`min_skips` and `max_skips` are both 0, the only label outputted will be
@@ -316,7 +329,7 @@ def skip_gram_sample_with_text_vocab(input_tensor,
# Iterates through the vocab file and calculates the number of vocab terms as
# well as the total corpus size (by summing the frequency counts of all the
# vocab terms).
- corpus_size = 0.0
+ calculated_corpus_size = 0.0
vocab_size = 0
with gfile.GFile(vocab_freq_file, mode="r") as f:
reader = csv.reader(f, delimiter=vocab_delimiter)
@@ -334,7 +347,15 @@ def skip_gram_sample_with_text_vocab(input_tensor,
format(freq, row))
# Note: tokens whose frequencies are below vocab_min_count will still
# contribute to the total corpus size used for vocab subsampling.
- corpus_size += freq
+ calculated_corpus_size += freq
+
+ if not corpus_size:
+ corpus_size = calculated_corpus_size
+ elif calculated_corpus_size - corpus_size > 1e-6:
+ raise ValueError(
+ "`corpus_size`={} must be greater than or equal to the sum of all the "
+ "frequency counts ({}) of `vocab_freq_file` ({}).".format(
+ corpus_size, calculated_corpus_size, vocab_freq_file))
vocab_freq_table = lookup.HashTable(
lookup.TextFileInitializer(
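To make the docstring above concrete, here is a small pure-Python sketch of how `corpus_size` feeds the vocab-subsampling keep probability. This assumes the op follows Eq. 5 of http://arxiv.org/abs/1310.4546, which the docstring cites; the function name and numbers below are illustrative, not part of the diff:

import math

def keep_probability(freq, corpus_size, subsample=1e-3):
    # Probability of keeping a token whose corpus frequency is `freq`,
    # per the subsampling heuristic from the cited paper (assumed here).
    ratio = freq / (subsample * corpus_size)
    return (math.sqrt(ratio) + 1.0) / ratio

# If the pruned vocab file sums to 100 tokens but the true corpus had 120,
# supplying the true corpus_size yields the intended (slightly higher)
# keep rate for frequent tokens.
print(keep_probability(freq=30, corpus_size=100))  # pruned-file total
print(keep_probability(freq=30, corpus_size=120))  # true corpus total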
diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
index d989942f73..84e36146d5 100644
--- a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
+++ b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
@@ -470,7 +470,7 @@ class SkipGramOpsTest(test.TestCase):
self.assertAllEqual(expected_labels, labels.eval())
def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count,
- vocab_freq_dtype):
+ vocab_freq_dtype, corpus_size=None):
# The outputs are non-deterministic, so set random seed to help ensure that
# the outputs remain constant for testing.
random_seed.set_random_seed(42)
@@ -499,6 +499,7 @@ class SkipGramOpsTest(test.TestCase):
vocab_freq_dtype=vocab_freq_dtype,
vocab_min_count=vocab_min_count,
vocab_subsampling=0.05,
+ corpus_size=corpus_size,
min_skips=1,
max_skips=1,
seed=123)
@@ -523,10 +524,27 @@ class SkipGramOpsTest(test.TestCase):
# the: 30
# to: 20
# universe: 2
+ #
+ # corpus_size for the above vocab is 40+8+30+20+2 = 100.
+ text_vocab_freq_file = self._make_text_vocab_freq_file()
self._text_vocab_subsample_vocab_helper(
- vocab_freq_file=self._make_text_vocab_freq_file(),
+ vocab_freq_file=text_vocab_freq_file,
vocab_min_count=3,
vocab_freq_dtype=dtypes.int64)
+ self._text_vocab_subsample_vocab_helper(
+ vocab_freq_file=text_vocab_freq_file,
+ vocab_min_count=3,
+ vocab_freq_dtype=dtypes.int64,
+ corpus_size=100)
+
+ # The user-supplied corpus_size should not be less than the sum of all
+ # the frequency counts of vocab_freq_file, which is 100.
+ with self.assertRaises(ValueError):
+ self._text_vocab_subsample_vocab_helper(
+ vocab_freq_file=text_vocab_freq_file,
+ vocab_min_count=3,
+ vocab_freq_dtype=dtypes.int64,
+ corpus_size=99)
def test_skip_gram_sample_with_text_vocab_subsample_vocab_float(self):
"""Tests skip-gram sampling with text vocab and subsampling with floats."""
@@ -536,10 +554,27 @@ class SkipGramOpsTest(test.TestCase):
# the: 0.3
# to: 0.2
# universe: 0.02
+ #
+ # corpus_size for the above vocab is 0.4+0.08+0.3+0.2+0.02 = 1.
+ text_vocab_float_file = self._make_text_vocab_float_file()
self._text_vocab_subsample_vocab_helper(
- vocab_freq_file=self._make_text_vocab_float_file(),
+ vocab_freq_file=text_vocab_float_file,
vocab_min_count=0.03,
vocab_freq_dtype=dtypes.float32)
+ self._text_vocab_subsample_vocab_helper(
+ vocab_freq_file=text_vocab_float_file,
+ vocab_min_count=0.03,
+ vocab_freq_dtype=dtypes.float32,
+ corpus_size=1.0)
+
+ # The user-supplied corpus_size should not be less than the sum of all
+ # the frequency counts of vocab_freq_file, which is 1.
+ with self.assertRaises(ValueError):
+ self._text_vocab_subsample_vocab_helper(
+ vocab_freq_file=text_vocab_float_file,
+ vocab_min_count=0.03,
+ vocab_freq_dtype=dtypes.float32,
+ corpus_size=0.99)
def test_skip_gram_sample_with_text_vocab_errors(self):
"""Tests various errors raised by skip_gram_sample_with_text_vocab()."""