aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-13 15:23:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-13 16:49:59 -0700
commit866de15d53cdd7f249043197cd9caae83fc4f8c6 (patch)
tree07b7e8da4b684a0bd140e9627cf0ffbc660b2b4b /tensorflow/contrib/lookup
parent9ee1e726f5eed318cea70faf267750e5f1dd1933 (diff)
Change public fns to support non-tensor inputs.
Fix `index_table_from_tensor` to handle `int32` keys with no oov buckets. Add better tests for sparse integer inputs. Change: 150006782
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py13
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py11
2 files changed, 17 insertions, 7 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 6a20ee4440..363bfa6ba2 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -1029,13 +1029,12 @@ def index_table_from_tensor(mapping,
name="table_init")
table = HashTable(
init, default_value, shared_name=shared_name, name=hash_table_scope)
- if num_oov_buckets:
- table = IdTableWithHashBuckets(
- table,
- num_oov_buckets=num_oov_buckets,
- hasher_spec=hasher_spec,
- name=feat_to_id_scope,
- key_dtype=dtype)
+ table = IdTableWithHashBuckets(
+ table,
+ num_oov_buckets=num_oov_buckets,
+ hasher_spec=hasher_spec,
+ name=feat_to_id_scope,
+ key_dtype=dtype)
return table
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index fe8fa71981..f0a04bd3ea 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -1341,6 +1341,17 @@ class IndexTableFromTensor(test.TestCase):
data_flow_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
+ def test_int32_index_table_from_tensor_with_no_buckets(self):
+ with self.test_session():
+ table = lookup.index_table_from_tensor(
+ mapping=(42, 1, -1000), dtype=dtypes.int32)
+ ids = table.lookup(
+ constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
+
+ self.assertRaises(errors_impl.OpError, ids.eval)
+ data_flow_ops.tables_initializer().run()
+ self.assertAllEqual((1, 2, -1), ids.eval())
+
def test_int64_index_table_from_tensor_with_tensor_init(self):
with self.test_session():
table = lookup.index_table_from_tensor(