diff options
author | 2017-10-08 15:50:43 -0700 | |
---|---|---|
committer | 2017-10-08 15:54:34 -0700 | |
commit | cab4f6f615e259546a1c0719a32d019730b2ee71 (patch) | |
tree | 6da822ea7ee2718621dfa9e5e8a3067788ef0476 | |
parent | 3431602bdf00038a87522b3afb08095d20e9a064 (diff) |
Improve the invalid vocab_size ValueError by appending the vocabulary file to the error message.
This helps identify the erroneous vocab file in the common case of training programs that use multiple vocabs.
PiperOrigin-RevId: 171476954
-rw-r--r-- | tensorflow/python/kernel_tests/lookup_ops_test.py | 21 | ||||
-rw-r--r-- | tensorflow/python/ops/lookup_ops.py | 7 |
2 files changed, 27 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py
index 1d92a08f5c..76c790a0a2 100644
--- a/tensorflow/python/kernel_tests/lookup_ops_test.py
+++ b/tensorflow/python/kernel_tests/lookup_ops_test.py
@@ -378,6 +378,27 @@ class IndexTableFromFile(test.TestCase):
     self.assertRaises(
         ValueError, lookup_ops.index_table_from_file, vocabulary_file=None)
 
+  def test_index_table_from_file_str_fails_with_zero_size_vocabulary(self):
+    vocabulary_file = self._createVocabFile("zero_vocab_str.txt")
+    self.assertRaisesRegexp(
+        ValueError,
+        "vocab_size must be greater than 0, got 0. "
+        "vocabulary_file: .*zero_vocab_str.txt",
+        lookup_ops.index_table_from_file,
+        vocabulary_file=vocabulary_file,
+        vocab_size=0)
+
+  def test_index_table_from_file_tensor_fails_with_zero_size_vocabulary(self):
+    vocabulary_file = constant_op.constant(
+        self._createVocabFile("zero_vocab_tensor.txt"))
+    self.assertRaisesRegexp(
+        ValueError,
+        "vocab_size must be greater than 0, got 0. "
+        "vocabulary_file: .*zero_vocab_tensor.txt",
+        lookup_ops.index_table_from_file,
+        vocabulary_file=vocabulary_file,
+        vocab_size=0)
+
   def test_index_table_from_file_with_vocab_size_too_small(self):
     vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
     with self.test_session():
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index bbfa38aa17..7f00344be2 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_lookup_ops
@@ -927,7 +928,11 @@ def index_table_from_file(vocabulary_file=None,
     raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
                      % num_oov_buckets)
   if vocab_size is not None and vocab_size < 1:
-    raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
+    vocab_file_value = vocabulary_file
+    if isinstance(vocabulary_file, ops.Tensor):
+      vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?"
+    raise ValueError("vocab_size must be greater than 0, got %d. "
+                     "vocabulary_file: %s" % (vocab_size, vocab_file_value))
   if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
     raise TypeError("Only integer and string keys are supported.")
 