diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-01 11:41:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-01 11:49:17 -0700 |
commit | 9fc1642250713f27f520af0da080c388390912c5 (patch) | |
tree | 6487d21641e38abc6185561fbe44f57ecab94b95 /tensorflow/contrib/lookup | |
parent | 0aa3e01941d231fe313e600eaa5f7cc052c1c077 (diff) |
Fix index_table_from_file to allow vocabulary_file to be a Tensor
PiperOrigin-RevId: 157740677
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops_test.py | 20 |
2 files changed, 23 insertions, 4 deletions
~~~diff
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 65474f03fa..e49b62afa2 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -871,7 +871,7 @@ def index_table_from_file(vocabulary_file=None,
   ```
 
   Args:
-    vocabulary_file: The vocabulary filename.
+    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
     num_oov_buckets: The number of out-of-vocabulary buckets.
     vocab_size: Number of the elements in the vocabulary, if known.
     default_value: The value to use for out-of-vocabulary feature values.
@@ -889,8 +889,9 @@ def index_table_from_file(vocabulary_file=None,
     ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
       than zero.
   """
-  if not vocabulary_file:
-    raise ValueError("vocabulary_file must be specified.")
+  if vocabulary_file is None or (
+      isinstance(vocabulary_file, str) and not vocabulary_file):
+    raise ValueError("vocabulary_file must be specified and must not be empty.")
   if num_oov_buckets < 0:
     raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
                      % num_oov_buckets)
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 5ec169b6db..180dfefe29 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -1187,6 +1187,18 @@ class IndexTableFromFile(test.TestCase):
       lookup_ops.tables_initializer().run()
       self.assertAllEqual((1, 2, 3), ids.eval())
 
+  def test_string_index_table_from_file_tensor_filename(self):
+    vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
+    with self.test_session():
+      vocabulary_file = constant_op.constant(vocabulary_file)
+      table = lookup.index_table_from_file(
+          vocabulary_file=vocabulary_file, num_oov_buckets=1)
+      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
+
+      self.assertRaises(errors_impl.OpError, ids.eval)
+      lookup_ops.tables_initializer().run()
+      self.assertAllEqual((1, 2, 3), ids.eval())
+
   def test_int32_index_table_from_file(self):
     vocabulary_file = self._createVocabFile(
         "f2i_vocab2.txt", values=("42", "1", "-1000"))
@@ -1245,7 +1257,13 @@ class IndexTableFromFile(test.TestCase):
               860),  # 3 + fingerprint("toccata") mod 300.
           ids.eval())
 
-  def test_index_table_from_file_with_only_oov_buckets(self):
+  def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
+    self.assertRaises(
+        ValueError,
+        lookup.index_table_from_file,
+        vocabulary_file="")
+
+  def test_index_table_from_file_fails_with_empty_vocabulary(self):
     self.assertRaises(
         ValueError,
         lookup.index_table_from_file,
~~~