diff options
author | 2016-11-16 10:12:45 -0800 | |
---|---|---|
committer | 2016-11-16 10:26:41 -0800 | |
commit | 634790044c81521c438799b558d33b8440fa9e23 (patch) | |
tree | a0319de4c600738e4e917a3f9f2a194b08214635 /tensorflow/contrib/lookup | |
parent | eabf41b7cc8ab515b556bc91b4f282d1d671c1a7 (diff) |
Fixing a bug in the MutableDenseHashTable implementation where the difference in shapes between the Insert and Import functions was causing issues with a vector key and scalar value input.
Fixed by splitting the LookupInterface CheckKeysAndValueTensors method into one for Insert and the other for Import.
Change: 139346138
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r-- | tensorflow/contrib/lookup/lookup_ops_test.py | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 5e49a3818f..7f69dcd480 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -984,6 +984,65 @@ class MutableDenseHashTableOpTest(tf.test.TestCase): self.assertAllEqual([[0, 1], [2, 3], [-1, -2], [4, 5], [-1, -2]], output.eval()) + def testVectorScalarSaveRestore(self): + save_dir = os.path.join(self.get_temp_dir(), "vector_scalar_save_restore") + save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") + + with self.test_session(graph=tf.Graph()) as sess: + empty_key = tf.constant([11, 13], tf.int64) + default_value = tf.constant(-1, tf.int64) + keys = tf.constant([[11, 12], [11, 14], [13, 14]], tf.int64) + values = tf.constant([0, 1, 2], tf.int64) + table = tf.contrib.lookup.MutableDenseHashTable( + tf.int64, + tf.int64, + default_value=default_value, + empty_key=empty_key, + name="t2", + checkpoint=True, + initial_num_buckets=32) + + save = tf.train.Saver() + + self.assertAllEqual(0, table.size().eval()) + table.insert(keys, values).run() + self.assertAllEqual(3, table.size().eval()) + self.assertAllEqual(32, len(table.export()[0].eval())) + + val = save.save(sess, save_path) + self.assertTrue(isinstance(val, six.string_types)) + self.assertEqual(save_path, val) + + with self.test_session(graph=tf.Graph()) as sess: + empty_key = tf.constant([11, 13], tf.int64) + default_value = tf.constant(-1, tf.int64) + table = tf.contrib.lookup.MutableDenseHashTable( + tf.int64, + tf.int64, + default_value=default_value, + empty_key=empty_key, + name="t2", + checkpoint=True, + initial_num_buckets=64) + table.insert( + tf.constant([[11, 12], [13, 15]], tf.int64), + tf.constant([3, 4], tf.int64)).run() + self.assertAllEqual(2, table.size().eval()) + self.assertAllEqual(64, len(table.export()[0].eval())) + + save = tf.train.Saver() + + # Restore the saved values in the parameter nodes. + save.restore(sess, save_path) + + self.assertAllEqual(3, table.size().eval()) + self.assertAllEqual(32, len(table.export()[0].eval())) + + input_string = tf.constant( + [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], tf.int64) + output = table.lookup(input_string) + self.assertAllEqual([0, 1, -1, 2, -1], output.eval()) + def testReprobe(self): with self.test_session(): # Insert 6 keys into a table with 8 buckets. |