diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-13 15:23:25 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-13 16:49:59 -0700 |
commit | 866de15d53cdd7f249043197cd9caae83fc4f8c6 (patch) | |
tree | 07b7e8da4b684a0bd140e9627cf0ffbc660b2b4b /tensorflow/contrib/lookup | |
parent | 9ee1e726f5eed318cea70faf267750e5f1dd1933 (diff) |
Change public fns to support non-tensor inputs.
Fix `index_table_from_tensor` to handle `int32` keys with no oov buckets.
Add better tests for sparse integer inputs.
Change: 150006782
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops.py | 13 | ||||
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops_test.py | 11 |
2 files changed, 17 insertions, 7 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 6a20ee4440..363bfa6ba2 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -1029,13 +1029,12 @@ def index_table_from_tensor(mapping, name="table_init") table = HashTable( init, default_value, shared_name=shared_name, name=hash_table_scope) - if num_oov_buckets: - table = IdTableWithHashBuckets( - table, - num_oov_buckets=num_oov_buckets, - hasher_spec=hasher_spec, - name=feat_to_id_scope, - key_dtype=dtype) + table = IdTableWithHashBuckets( + table, + num_oov_buckets=num_oov_buckets, + hasher_spec=hasher_spec, + name=feat_to_id_scope, + key_dtype=dtype) return table diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index fe8fa71981..f0a04bd3ea 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -1341,6 +1341,17 @@ class IndexTableFromTensor(test.TestCase): data_flow_ops.tables_initializer().run() self.assertAllEqual((1, 2, 3), ids.eval()) + def test_int32_index_table_from_tensor_with_no_buckets(self): + with self.test_session(): + table = lookup.index_table_from_tensor( + mapping=(42, 1, -1000), dtype=dtypes.int32) + ids = table.lookup( + constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) + + self.assertRaises(errors_impl.OpError, ids.eval) + data_flow_ops.tables_initializer().run() + self.assertAllEqual((1, 2, -1), ids.eval()) + def test_int64_index_table_from_tensor_with_tensor_init(self): with self.test_session(): table = lookup.index_table_from_tensor( |