diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2018-08-15 16:42:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-15 16:48:43 -0700 |
commit | 058b499b07e72a3ad21bb4be1a69412a731036bc (patch) | |
tree | 1ffee9883e879abde817faa5639de6aa7a13c3f7 /tensorflow/contrib/lookup | |
parent | 1ffe362c1768cc3fb09f3792e7a44122aa852471 (diff) |
[tf.contrib.lookup] More cleanups and add some simple benchmarks.
PiperOrigin-RevId: 208906068
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r-- | tensorflow/contrib/lookup/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops_test.py | 60 |
3 files changed, 61 insertions, 5 deletions
diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD index e3928a82a2..83e80f25bc 100644 --- a/tensorflow/contrib/lookup/BUILD +++ b/tensorflow/contrib/lookup/BUILD @@ -34,6 +34,7 @@ tf_py_test( ":lookup_py", "//third_party/py/numpy", "@six_archive//:six", + "//tensorflow/contrib/data", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index 8c0bfefb30..291972cce3 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -419,11 +419,10 @@ class MutableHashTable(LookupInterface): TypeError: when `keys` or `values` doesn't match the table data types. """ - # pylint: disable=protected-access - lookup_ops._check_table_dtypes(self, keys.dtype, values.dtype) - # pylint: enable=protected-access with ops.name_scope(name, "%s_lookup_table_insert" % self._name, [self._table_ref, keys, values]) as name: + keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys") + values = ops.convert_to_tensor(values, self._value_dtype, name="values") with ops.colocate_with(self._table_ref): # pylint: disable=protected-access op = gen_lookup_ops.lookup_table_insert_v2( diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 6fb5244fc6..81257e1de5 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -23,6 +23,7 @@ import numpy as np import six from tensorflow.contrib import lookup +from tensorflow.contrib.data.python.ops import counter from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -646,11 +647,11 @@ class MutableHashTableOpTest(test.TestCase): default_val) # insert with keys of the wrong type - with self.assertRaises(TypeError): + with self.assertRaises(ValueError): table.insert(constant_op.constant([4, 5, 6]), values).run() # insert with values of the wrong type - with self.assertRaises(TypeError): + with self.assertRaises(ValueError): table.insert(keys, constant_op.constant(["a", "b", "c"])).run() self.assertAllEqual(0, table.size().eval()) @@ -2397,5 +2398,60 @@ class IdTableWithHashBucketsTest(test.TestCase): hasher_spec=lookup.StrongHashSpec([None, 2])) +class MutableHashTableBenchmark(test.Benchmark): + + def _create_table(self): + return lookup.MutableHashTable(dtypes.int64, dtypes.float32, 0.0) + + def benchmark_single_repeated_scalar_insert_scalar(self): + table = self._create_table() + value = variables.Variable(1.0) + insert = table.insert(0, value) + size = table.size() + with session.Session() as sess: + sess.run(value.initializer) + self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000) + assert sess.run(size) == 1 + + def benchmark_many_repeated_scalar_insert_scalar(self): + table = self._create_table() + c = counter.Counter().make_one_shot_iterator().get_next() + value = variables.Variable(1.0) + insert = table.insert(c, value) + size = table.size() + with session.Session() as sess: + sess.run(value.initializer) + self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000) + assert sess.run(size) >= 10000 + + def benchmark_single_repeated_batch_32_insert_scalar(self): + table = self._create_table() + value = variables.Variable([1.0] * 32) + insert = table.insert(list(range(32)), value) + size = table.size() + with session.Session() as sess: + sess.run(value.initializer) + self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000) + assert sess.run(size) == 32 + + def benchmark_many_repeated_batch_32_insert_scalar(self): + table = self._create_table() + c = counter.Counter().make_one_shot_iterator().get_next() + value = variables.Variable([1.0] * 32) + insert = table.insert(32 * c + list(range(32)), value) + size = table.size() + with session.Session() as sess: + sess.run(value.initializer) + self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000) + assert sess.run(size) >= 1000*32 + + +class MutableDenseHashTableBenchmark(MutableHashTableBenchmark): + + def _create_table(self): + return lookup.MutableDenseHashTable( + dtypes.int64, dtypes.float32, default_value=0.0, empty_key=-1) + + if __name__ == "__main__": test.main() |