author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-08-29 11:13:22 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-08-29 11:21:09 -0700
commit     ecdadc84bd5b2eb09a76979c9632d66f7fd08d2c (patch)
tree       9cdab8fe81d19771b82cd5b9a7e2637bed471675
parent     e87b2be0e68412e6a72e5b7184968e1e0b1f9178 (diff)
Add a `corpus_size` argument to the tf.contrib.text.skip_gram_sample_with_text_vocab op. This argument mirrors the `corpus_size` argument of the tf.contrib.text.skip_gram_sample op.
It allows vocab subsampling to work correctly when vocabulary files have been preprocessed to drop all infrequent tokens (those with frequency < vocab_min_count), which avoids spending memory on unused tokens. Such preprocessing helps when training skip-gram models on a large vocabulary, where the memory usage of the internal lookup table could otherwise become a performance bottleneck.
If `corpus_size` is needed but not supplied, it is calculated from `vocab_freq_file`, matching the existing behavior.
PiperOrigin-RevId: 166873744
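
As a rough sketch of the intended usage (the file name, token counts, and input tensor below are illustrative assumptions, not part of this change): suppose "corpus.vocab" has been pruned to tokens with frequency >= vocab_min_count and its counts now sum to only 50, while the full corpus contained 100 tokens. Passing corpus_size=100 keeps the subsampling probabilities consistent with the full corpus:

    import tensorflow as tf
    from tensorflow.contrib import text

    tokens = tf.constant(["the", "to", "life", "the", "universe"])

    # corpus_size describes the full corpus, not the pruned vocab file. If it
    # were omitted, the op would fall back to summing corpus.vocab (50 here),
    # which undercounts the corpus and skews the subsampling probabilities.
    targets, labels = text.skip_gram_sample_with_text_vocab(
        input_tensor=tokens,
        vocab_freq_file="corpus.vocab",
        vocab_min_count=3,
        vocab_subsampling=1e-3,
        corpus_size=100,
        min_skips=1,
        max_skips=2)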
-rw-r--r--  tensorflow/contrib/text/python/ops/skip_gram_ops.py       | 25
-rw-r--r--  tensorflow/contrib/text/python/ops/skip_gram_ops_test.py  | 41
2 files changed, 61 insertions(+), 5 deletions(-)
diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops.py b/tensorflow/contrib/text/python/ops/skip_gram_ops.py
index 410ee517e0..7ed45031a3 100644
--- a/tensorflow/contrib/text/python/ops/skip_gram_ops.py
+++ b/tensorflow/contrib/text/python/ops/skip_gram_ops.py
@@ -216,6 +216,7 @@ def skip_gram_sample_with_text_vocab(input_tensor,
                                      vocab_delimiter=",",
                                      vocab_min_count=0,
                                      vocab_subsampling=None,
+                                     corpus_size=None,
                                      min_skips=1,
                                      max_skips=5,
                                      start=0,
@@ -267,6 +268,18 @@ def skip_gram_sample_with_text_vocab(input_tensor,
       frequently will be randomly down-sampled. Reasonable starting values may
       be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546 for
       more details.
+    corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
+      total number of tokens in the corpus (e.g., sum of all the frequency
+      counts of `vocab_freq_file`). Used with `vocab_subsampling` for
+      down-sampling frequently occurring tokens. If this is specified,
+      `vocab_freq_file` and `vocab_subsampling` must also be specified.
+      If `corpus_size` is needed but not supplied, then it will be calculated
+      from `vocab_freq_file`. You might want to supply your own value if you
+      have already eliminated infrequent tokens from your vocabulary files
+      (where frequency < vocab_min_count) to save memory in the internal token
+      lookup table. Otherwise, the unused tokens' variables will waste memory.
+      The user-supplied `corpus_size` value must be greater than or equal to the
+      sum of all the frequency counts of `vocab_freq_file`.
     min_skips: `int` or scalar `Tensor` specifying the minimum window size to
       randomly use for each token. Must be >= 0 and <= `max_skips`. If
       `min_skips` and `max_skips` are both 0, the only label outputted will be
@@ -316,7 +329,7 @@ def skip_gram_sample_with_text_vocab(input_tensor,
   # Iterates through the vocab file and calculates the number of vocab terms as
   # well as the total corpus size (by summing the frequency counts of all the
   # vocab terms).
-  corpus_size = 0.0
+  calculated_corpus_size = 0.0
   vocab_size = 0
   with gfile.GFile(vocab_freq_file, mode="r") as f:
     reader = csv.reader(f, delimiter=vocab_delimiter)
@@ -334,7 +347,15 @@ def skip_gram_sample_with_text_vocab(input_tensor,
             format(freq, row))
       # Note: tokens whose frequencies are below vocab_min_count will still
       # contribute to the total corpus size used for vocab subsampling.
-      corpus_size += freq
+      calculated_corpus_size += freq
+
+  if not corpus_size:
+    corpus_size = calculated_corpus_size
+  elif calculated_corpus_size - corpus_size > 1e-6:
+    raise ValueError(
+        "`corpus_size`={} must be greater than or equal to the sum of all the "
+        "frequency counts ({}) of `vocab_freq_file` ({}).".format(
+            corpus_size, calculated_corpus_size, vocab_freq_file))
 
   vocab_freq_table = lookup.HashTable(
       lookup.TextFileInitializer(
diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
index d989942f73..84e36146d5 100644
--- a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
+++ b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py
@@ -470,7 +470,7 @@ class SkipGramOpsTest(test.TestCase):
       self.assertAllEqual(expected_labels, labels.eval())
 
   def _text_vocab_subsample_vocab_helper(self, vocab_freq_file, vocab_min_count,
-                                         vocab_freq_dtype):
+                                         vocab_freq_dtype, corpus_size=None):
     # The outputs are non-deterministic, so set random seed to help ensure that
     # the outputs remain constant for testing.
     random_seed.set_random_seed(42)
@@ -499,6 +499,7 @@ class SkipGramOpsTest(test.TestCase):
         vocab_freq_dtype=vocab_freq_dtype,
         vocab_min_count=vocab_min_count,
         vocab_subsampling=0.05,
+        corpus_size=corpus_size,
         min_skips=1,
         max_skips=1,
         seed=123)
@@ -523,10 +524,27 @@ class SkipGramOpsTest(test.TestCase):
     # the: 30
     # to: 20
     # universe: 2
+    #
+    # corpus_size for the above vocab is 40+8+30+20+2 = 100.
+    text_vocab_freq_file = self._make_text_vocab_freq_file()
     self._text_vocab_subsample_vocab_helper(
-        vocab_freq_file=self._make_text_vocab_freq_file(),
+        vocab_freq_file=text_vocab_freq_file,
         vocab_min_count=3,
         vocab_freq_dtype=dtypes.int64)
+    self._text_vocab_subsample_vocab_helper(
+        vocab_freq_file=text_vocab_freq_file,
+        vocab_min_count=3,
+        vocab_freq_dtype=dtypes.int64,
+        corpus_size=100)
+
+    # The user-supplied corpus_size should not be less than the sum of all
+    # the frequency counts of vocab_freq_file, which is 100.
+    with self.assertRaises(ValueError):
+      self._text_vocab_subsample_vocab_helper(
+          vocab_freq_file=text_vocab_freq_file,
+          vocab_min_count=3,
+          vocab_freq_dtype=dtypes.int64,
+          corpus_size=99)
 
   def test_skip_gram_sample_with_text_vocab_subsample_vocab_float(self):
     """Tests skip-gram sampling with text vocab and subsampling with floats."""
@@ -536,10 +554,27 @@ class SkipGramOpsTest(test.TestCase):
     # the: 0.3
     # to: 0.2
     # universe: 0.02
+    #
+    # corpus_size for the above vocab is 0.4+0.08+0.3+0.2+0.02 = 1.
+    text_vocab_float_file = self._make_text_vocab_float_file()
     self._text_vocab_subsample_vocab_helper(
-        vocab_freq_file=self._make_text_vocab_float_file(),
+        vocab_freq_file=text_vocab_float_file,
         vocab_min_count=0.03,
         vocab_freq_dtype=dtypes.float32)
+    self._text_vocab_subsample_vocab_helper(
+        vocab_freq_file=text_vocab_float_file,
+        vocab_min_count=0.03,
+        vocab_freq_dtype=dtypes.float32,
+        corpus_size=1.0)
+
+    # The user-supplied corpus_size should not be less than the sum of all
+    # the frequency counts of vocab_freq_file, which is 1.
+    with self.assertRaises(ValueError):
+      self._text_vocab_subsample_vocab_helper(
+          vocab_freq_file=text_vocab_float_file,
+          vocab_min_count=0.03,
+          vocab_freq_dtype=dtypes.float32,
+          corpus_size=0.99)
 
   def test_skip_gram_sample_with_text_vocab_errors(self):
     """Tests various errors raised by skip_gram_sample_with_text_vocab()."""
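
For completeness, a hypothetical illustration of the new validation, continuing the assumed "corpus.vocab" from the sketch above (whose counts sum to 50): any user-supplied corpus_size below that sum is rejected, mirroring what the assertRaises tests exercise.

    # Raises ValueError of the form:
    #   "`corpus_size`=49 must be greater than or equal to the sum of all the
    #    frequency counts (50.0) of `vocab_freq_file` (corpus.vocab)."
    text.skip_gram_sample_with_text_vocab(
        input_tensor=tokens,
        vocab_freq_file="corpus.vocab",
        vocab_min_count=3,
        vocab_subsampling=1e-3,
        corpus_size=49)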