about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org>, 2017-10-08 15:50:43 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2017-10-08 15:54:34 -0700
commit: cab4f6f615e259546a1c0719a32d019730b2ee71 (patch)
tree: 6da822ea7ee2718621dfa9e5e8a3067788ef0476
parent: 3431602bdf00038a87522b3afb08095d20e9a064 (diff)
Improve invalid size vocab ValueError by appending the vocab file.
This helps identify the erroneous vocab file in the common case of training programs with multiple vocabs.

PiperOrigin-RevId: 171476954
-rw-r--r-- tensorflow/python/kernel_tests/lookup_ops_test.py | 21
-rw-r--r-- tensorflow/python/ops/lookup_ops.py | 7
2 files changed, 27 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py
index 1d92a08f5c..76c790a0a2 100644
--- a/tensorflow/python/kernel_tests/lookup_ops_test.py
+++ b/tensorflow/python/kernel_tests/lookup_ops_test.py
@@ -378,6 +378,27 @@ class IndexTableFromFile(test.TestCase):
self.assertRaises(
ValueError, lookup_ops.index_table_from_file, vocabulary_file=None)
+ def test_index_table_from_file_str_fails_with_zero_size_vocabulary(self):
+ vocabulary_file = self._createVocabFile("zero_vocab_str.txt")
+ self.assertRaisesRegexp(
+ ValueError,
+ "vocab_size must be greater than 0, got 0. "
+ "vocabulary_file: .*zero_vocab_str.txt",
+ lookup_ops.index_table_from_file,
+ vocabulary_file=vocabulary_file,
+ vocab_size=0)
+
+ def test_index_table_from_file_tensor_fails_with_zero_size_vocabulary(self):
+ vocabulary_file = constant_op.constant(
+ self._createVocabFile("zero_vocab_tensor.txt"))
+ self.assertRaisesRegexp(
+ ValueError,
+ "vocab_size must be greater than 0, got 0. "
+ "vocabulary_file: .*zero_vocab_tensor.txt",
+ lookup_ops.index_table_from_file,
+ vocabulary_file=vocabulary_file,
+ vocab_size=0)
+
def test_index_table_from_file_with_vocab_size_too_small(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
with self.test_session():
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index bbfa38aa17..7f00344be2 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_lookup_ops
@@ -927,7 +928,11 @@ def index_table_from_file(vocabulary_file=None,
raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
% num_oov_buckets)
if vocab_size is not None and vocab_size < 1:
- raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
+ vocab_file_value = vocabulary_file
+ if isinstance(vocabulary_file, ops.Tensor):
+ vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?"
+ raise ValueError("vocab_size must be greater than 0, got %d. "
+ "vocabulary_file: %s" % (vocab_size, vocab_file_value))
if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
raise TypeError("Only integer and string keys are supported.")