diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-23 10:32:31 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-23 11:45:26 -0700 |
commit | cbc9c5c7af0c63492aa65b418dc84e381f9443d3 (patch) | |
tree | f5e6ae117fd2aab047a458537c8e536ea5fd14f8 /tensorflow/contrib/lookup | |
parent | 3768c139003bdb4a7700e96b89938c4458f3c439 (diff) |
Support int64 to float mapping in MutableHashTable
Change: 151031147
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops_test.py | 20 |
1 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index fe8fa71981..0ec40a63f2 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -654,7 +654,25 @@ class MutableHashTableOpTest(test.TestCase): output = table.lookup(input_string) result = output.eval() - self.assertAllClose([0, 1.1, -1.5], result) + self.assertAllClose([0, 1.1, default_val], result) + + def testMutableHashTableIntFloat(self): + with self.test_session(): + default_val = -1.0 + keys = constant_op.constant([3, 7, 0], dtypes.int64) + values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32) + table = lookup.MutableHashTable(dtypes.int64, dtypes.float32, + default_val) + self.assertAllEqual(0, table.size().eval()) + + table.insert(keys, values).run() + self.assertAllEqual(3, table.size().eval()) + + input_string = constant_op.constant([7, 0, 11], dtypes.int64) + output = table.lookup(input_string) + + result = output.eval() + self.assertAllClose([-1.2, 9.9, default_val], result) def testMutableHashTableInt64String(self): with self.test_session(): |