aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-01 11:41:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-01 11:49:17 -0700
commit9fc1642250713f27f520af0da080c388390912c5 (patch)
tree6487d21641e38abc6185561fbe44f57ecab94b95 /tensorflow/contrib/lookup
parent0aa3e01941d231fe313e600eaa5f7cc052c1c077 (diff)
Fix index_table_from_file to allow vocabulary_file be a Tensor
PiperOrigin-RevId: 157740677
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py7
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py20
2 files changed, 23 insertions, 4 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 65474f03fa..e49b62afa2 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -871,7 +871,7 @@ def index_table_from_file(vocabulary_file=None,
```
Args:
- vocabulary_file: The vocabulary filename.
+ vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
num_oov_buckets: The number of out-of-vocabulary buckets.
vocab_size: Number of the elements in the vocabulary, if known.
default_value: The value to use for out-of-vocabulary feature values.
@@ -889,8 +889,9 @@ def index_table_from_file(vocabulary_file=None,
ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
than zero.
"""
- if not vocabulary_file:
- raise ValueError("vocabulary_file must be specified.")
+ if vocabulary_file is None or (
+ isinstance(vocabulary_file, str) and not vocabulary_file):
+ raise ValueError("vocabulary_file must be specified and must not be empty.")
if num_oov_buckets < 0:
raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
% num_oov_buckets)
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 5ec169b6db..180dfefe29 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -1187,6 +1187,18 @@ class IndexTableFromFile(test.TestCase):
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
+ def test_string_index_table_from_file_tensor_filename(self):
+ vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
+ with self.test_session():
+ vocabulary_file = constant_op.constant(vocabulary_file)
+ table = lookup.index_table_from_file(
+ vocabulary_file=vocabulary_file, num_oov_buckets=1)
+ ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
+
+ self.assertRaises(errors_impl.OpError, ids.eval)
+ lookup_ops.tables_initializer().run()
+ self.assertAllEqual((1, 2, 3), ids.eval())
+
def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab2.txt", values=("42", "1", "-1000"))
@@ -1245,7 +1257,13 @@ class IndexTableFromFile(test.TestCase):
860), # 3 + fingerprint("toccata") mod 300.
ids.eval())
- def test_index_table_from_file_with_only_oov_buckets(self):
+ def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
+ self.assertRaises(
+ ValueError,
+ lookup.index_table_from_file,
+ vocabulary_file="")
+
+ def test_index_table_from_file_fails_with_empty_vocabulary(self):
self.assertRaises(
ValueError,
lookup.index_table_from_file,