aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2018-08-15 16:42:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-15 16:48:43 -0700
commit058b499b07e72a3ad21bb4be1a69412a731036bc (patch)
tree1ffee9883e879abde817faa5639de6aa7a13c3f7 /tensorflow/contrib/lookup
parent1ffe362c1768cc3fb09f3792e7a44122aa852471 (diff)
[tf.contrib.lookup] More cleanups and add some simple benchmarks.
PiperOrigin-RevId: 208906068
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r--tensorflow/contrib/lookup/BUILD1
-rw-r--r--tensorflow/contrib/lookup/lookup_ops.py5
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py60
3 files changed, 61 insertions, 5 deletions
diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD
index e3928a82a2..83e80f25bc 100644
--- a/tensorflow/contrib/lookup/BUILD
+++ b/tensorflow/contrib/lookup/BUILD
@@ -34,6 +34,7 @@ tf_py_test(
":lookup_py",
"//third_party/py/numpy",
"@six_archive//:six",
+ "//tensorflow/contrib/data",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py
index 8c0bfefb30..291972cce3 100644
--- a/tensorflow/contrib/lookup/lookup_ops.py
+++ b/tensorflow/contrib/lookup/lookup_ops.py
@@ -419,11 +419,10 @@ class MutableHashTable(LookupInterface):
TypeError: when `keys` or `values` doesn't match the table data
types.
"""
- # pylint: disable=protected-access
- lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype)
- # pylint: enable=protected-access
with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
[self._table_ref, keys, values]) as name:
+ keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys")
+ values = ops.convert_to_tensor(values, self._value_dtype, name="values")
with ops.colocate_with(self._table_ref):
# pylint: disable=protected-access
op = gen_lookup_ops.lookup_table_insert_v2(
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 6fb5244fc6..81257e1de5 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -23,6 +23,7 @@ import numpy as np
import six
from tensorflow.contrib import lookup
+from tensorflow.contrib.data.python.ops import counter
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -646,11 +647,11 @@ class MutableHashTableOpTest(test.TestCase):
default_val)
# insert with keys of the wrong type
- with self.assertRaises(TypeError):
+ with self.assertRaises(ValueError):
table.insert(constant_op.constant([4, 5, 6]), values).run()
# insert with values of the wrong type
- with self.assertRaises(TypeError):
+ with self.assertRaises(ValueError):
table.insert(keys, constant_op.constant(["a", "b", "c"])).run()
self.assertAllEqual(0, table.size().eval())
@@ -2397,5 +2398,60 @@ class IdTableWithHashBucketsTest(test.TestCase):
hasher_spec=lookup.StrongHashSpec([None, 2]))
+class MutableHashTableBenchmark(test.Benchmark):
+
+ def _create_table(self):
+ return lookup.MutableHashTable(dtypes.int64, dtypes.float32, 0.0)
+
+ def benchmark_single_repeated_scalar_insert_scalar(self):
+ table = self._create_table()
+ value = variables.Variable(1.0)
+ insert = table.insert(0, value)
+ size = table.size()
+ with session.Session() as sess:
+ sess.run(value.initializer)
+ self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000)
+ assert sess.run(size) == 1
+
+ def benchmark_many_repeated_scalar_insert_scalar(self):
+ table = self._create_table()
+ c = counter.Counter().make_one_shot_iterator().get_next()
+ value = variables.Variable(1.0)
+ insert = table.insert(c, value)
+ size = table.size()
+ with session.Session() as sess:
+ sess.run(value.initializer)
+ self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000)
+ assert sess.run(size) >= 10000
+
+ def benchmark_single_repeated_batch_32_insert_scalar(self):
+ table = self._create_table()
+ value = variables.Variable([1.0] * 32)
+ insert = table.insert(list(range(32)), value)
+ size = table.size()
+ with session.Session() as sess:
+ sess.run(value.initializer)
+ self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000)
+ assert sess.run(size) == 32
+
+ def benchmark_many_repeated_batch_32_insert_scalar(self):
+ table = self._create_table()
+ c = counter.Counter().make_one_shot_iterator().get_next()
+ value = variables.Variable([1.0] * 32)
+ insert = table.insert(32 * c + list(range(32)), value)
+ size = table.size()
+ with session.Session() as sess:
+ sess.run(value.initializer)
+ self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000)
+ assert sess.run(size) >= 1000*32
+
+
+class MutableDenseHashTableBenchmark(MutableHashTableBenchmark):
+
+ def _create_table(self):
+ return lookup.MutableDenseHashTable(
+ dtypes.int64, dtypes.float32, default_value=0.0, empty_key=-1)
+
+
if __name__ == "__main__":
test.main()