aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-23 10:32:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-23 11:45:26 -0700
commitcbc9c5c7af0c63492aa65b418dc84e381f9443d3 (patch)
treef5e6ae117fd2aab047a458537c8e536ea5fd14f8 /tensorflow/contrib/lookup
parent3768c139003bdb4a7700e96b89938c4458f3c439 (diff)
Support int64 to float mapping in MutableHashTable
Change: 151031147
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py20
1 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index fe8fa71981..0ec40a63f2 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -654,7 +654,25 @@ class MutableHashTableOpTest(test.TestCase):
output = table.lookup(input_string)
result = output.eval()
- self.assertAllClose([0, 1.1, -1.5], result)
+ self.assertAllClose([0, 1.1, default_val], result)
+
+ def testMutableHashTableIntFloat(self):
+ with self.test_session():
+ default_val = -1.0
+ keys = constant_op.constant([3, 7, 0], dtypes.int64)
+ values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32)
+ table = lookup.MutableHashTable(dtypes.int64, dtypes.float32,
+ default_val)
+ self.assertAllEqual(0, table.size().eval())
+
+ table.insert(keys, values).run()
+ self.assertAllEqual(3, table.size().eval())
+
+ input_string = constant_op.constant([7, 0, 11], dtypes.int64)
+ output = table.lookup(input_string)
+
+ result = output.eval()
+ self.assertAllClose([-1.2, 9.9, default_val], result)
def testMutableHashTableInt64String(self):
with self.test_session():