diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/lookup_table_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/lookup_table_op_test.py | 195 |
1 files changed, 195 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/lookup_table_op_test.py b/tensorflow/python/kernel_tests/lookup_table_op_test.py new file mode 100644 index 0000000000..cd170876e6 --- /dev/null +++ b/tensorflow/python/kernel_tests/lookup_table_op_test.py @@ -0,0 +1,195 @@ +"""Tests for lookup table ops from tf.""" +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + + +class HashTableOpTest(tf.test.TestCase): + + def testHashTable(self): + with self.test_session(): + shared_name = '' + default_val = -1 + table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) + + # Initialize with keys and values tensors. + keys = tf.constant(['brain', 'salad', 'surgery']) + values = tf.constant([0, 1, 2], tf.int64) + init = table.initialize_from(keys, values) + init.run() + self.assertAllEqual(3, table.size().eval()) + + input_string = tf.constant(['brain', 'salad', 'tank']) + output = table.lookup(input_string) + + result = output.eval() + self.assertAllEqual([0, 1, -1], result) + + def testHashTableInitWithPythonArrays(self): + with self.test_session(): + shared_name = '' + default_val = -1 + table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) + # Empty table. + self.assertAllEqual(0, table.size().eval()) + + # Initialize with keys and values tensors. + keys = ['brain', 'salad', 'surgery'] + values = [0, 1, 2] + init = table.initialize_from(keys, values) + init.run() + self.assertAllEqual(3, table.size().eval()) + + input_string = tf.constant(['brain', 'salad', 'tank']) + output = table.lookup(input_string) + + result = output.eval() + self.assertAllEqual([0, 1, -1], result) + + def testHashTableInitWithNumPyArrays(self): + with self.test_session(): + shared_name = '' + default_val = -1 + table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) + + # Initialize with keys and values tensors. + keys = np.array(['brain', 'salad', 'surgery'], dtype=np.str) + values = np.array([0, 1, 2], dtype=np.int64) + init = table.initialize_from(keys, values) + init.run() + self.assertAllEqual(3, table.size().eval()) + + input_string = tf.constant(['brain', 'salad', 'tank']) + output = table.lookup(input_string) + + result = output.eval() + self.assertAllEqual([0, 1, -1], result) + + def testMultipleHashTables(self): + with self.test_session() as sess: + shared_name = '' + default_val = -1 + table1 = tf.HashTable(tf.string, tf.int64, default_val, shared_name) + table2 = tf.HashTable(tf.string, tf.int64, default_val, shared_name) + table3 = tf.HashTable(tf.string, tf.int64, default_val, shared_name) + + keys = tf.constant(['brain', 'salad', 'surgery']) + values = tf.constant([0, 1, 2], tf.int64) + table1.initialize_from(keys, values) + table2.initialize_from(keys, values) + table3.initialize_from(keys, values) + + tf.initialize_all_tables().run() + self.assertAllEqual(3, table1.size().eval()) + self.assertAllEqual(3, table2.size().eval()) + self.assertAllEqual(3, table3.size().eval()) + + input_string = tf.constant(['brain', 'salad', 'tank']) + output1 = table1.lookup(input_string) + output2 = table2.lookup(input_string) + output3 = table3.lookup(input_string) + + out1, out2, out3 = sess.run([output1, output2, output3]) + self.assertAllEqual([0, 1, -1], out1) + self.assertAllEqual([0, 1, -1], out2) + self.assertAllEqual([0, 1, -1], out3) + + def testHashTableWithTensorDefault(self): + with self.test_session(): + shared_name = '' + default_val = tf.constant(-1, tf.int64) + table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) + + # Initialize with keys and values tensors. + keys = tf.constant(['brain', 'salad', 'surgery']) + values = tf.constant([0, 1, 2], tf.int64) + init = table.initialize_from(keys, values) + init.run() + + input_string = tf.constant(['brain', 'salad', 'tank']) + output = table.lookup(input_string) + + result = output.eval() + self.assertAllEqual([0, 1, -1], result) + + def testSignatureMismatch(self): + with self.test_session(): + shared_name = '' + default_val = -1 + table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) + + # Initialize with keys and values tensors. + keys = tf.constant(['brain', 'salad', 'surgery']) + values = tf.constant([0, 1, 2], tf.int64) + init = table.initialize_from(keys, values) + init.run() + + input_string = tf.constant([1, 2, 3], tf.int64) + with self.assertRaises(TypeError): + table.lookup(input_string) + + with self.assertRaises(TypeError): + tf.HashTable(tf.string, tf.int64, 'UNK', shared_name) + + def testDTypes(self): + with self.test_session(): + shared_name = '' + default_val = -1 + with self.assertRaises(TypeError): + tf.HashTable([tf.string], tf.string, default_val, shared_name) + + def testNotInitialized(self): + with self.test_session(): + shared_name = '' + default_val = -1 + table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) + + input_string = tf.constant(['brain', 'salad', 'surgery']) + output = table.lookup(input_string) + + with self.assertRaisesOpError('Table not initialized'): + output.eval() + + def testInitializeTwice(self): + with self.test_session(): + shared_name = '' + default_val = -1 + table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) + + # Initialize with keys and values tensors. + keys = tf.constant(['brain', 'salad', 'surgery']) + values = tf.constant([0, 1, 2], tf.int64) + init = table.initialize_from(keys, values) + init.run() + + with self.assertRaisesOpError('Table already initialized'): + init.run() + + def testInitializationWithInvalidDimensions(self): + with self.test_session(): + shared_name = '' + default_val = -1 + table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) + + # Initialize with keys and values tensors. + keys = tf.constant(['brain', 'salad', 'surgery']) + values = tf.constant([0, 1, 2, 3, 4], tf.int64) + with self.assertRaises(ValueError): + table.initialize_from(keys, values) + + def testInitializationWithInvalidDataTypes(self): + with self.test_session(): + shared_name = '' + default_val = -1 + table = tf.HashTable(tf.string, tf.int64, default_val, shared_name) + + # Initialize with keys and values tensors. + keys = [0, 1, 2] + values = ['brain', 'salad', 'surgery'] + with self.assertRaises(TypeError): + table.initialize_from(keys, values) + + +if __name__ == '__main__': + tf.test.main() |