aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lookup
diff options
context:
space:
mode:
authorGravatar Rohan Jain <rohanj@google.com>2016-11-16 10:12:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-16 10:26:41 -0800
commit634790044c81521c438799b558d33b8440fa9e23 (patch)
treea0319de4c600738e4e917a3f9f2a194b08214635 /tensorflow/contrib/lookup
parenteabf41b7cc8ab515b556bc91b4f282d1d671c1a7 (diff)
Fixing a bug in the MutableDenseHashTable implementation where the difference in shapes between the Insert and Import functions was causing issues with a vector key and scalar value input.
Fixed by splitting the LookupInterface CheckKeysAndValueTensors method into one for Insert and the other for Import. Change: 139346138
Diffstat (limited to 'tensorflow/contrib/lookup')
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py59
1 files changed, 59 insertions, 0 deletions
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 5e49a3818f..7f69dcd480 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -984,6 +984,65 @@ class MutableDenseHashTableOpTest(tf.test.TestCase):
self.assertAllEqual([[0, 1], [2, 3], [-1, -2], [4, 5], [-1, -2]],
output.eval())
+ def testVectorScalarSaveRestore(self):
+ save_dir = os.path.join(self.get_temp_dir(), "vector_scalar_save_restore")
+ save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
+
+ with self.test_session(graph=tf.Graph()) as sess:
+ empty_key = tf.constant([11, 13], tf.int64)
+ default_value = tf.constant(-1, tf.int64)
+ keys = tf.constant([[11, 12], [11, 14], [13, 14]], tf.int64)
+ values = tf.constant([0, 1, 2], tf.int64)
+ table = tf.contrib.lookup.MutableDenseHashTable(
+ tf.int64,
+ tf.int64,
+ default_value=default_value,
+ empty_key=empty_key,
+ name="t2",
+ checkpoint=True,
+ initial_num_buckets=32)
+
+ save = tf.train.Saver()
+
+ self.assertAllEqual(0, table.size().eval())
+ table.insert(keys, values).run()
+ self.assertAllEqual(3, table.size().eval())
+ self.assertAllEqual(32, len(table.export()[0].eval()))
+
+ val = save.save(sess, save_path)
+ self.assertTrue(isinstance(val, six.string_types))
+ self.assertEqual(save_path, val)
+
+ with self.test_session(graph=tf.Graph()) as sess:
+ empty_key = tf.constant([11, 13], tf.int64)
+ default_value = tf.constant(-1, tf.int64)
+ table = tf.contrib.lookup.MutableDenseHashTable(
+ tf.int64,
+ tf.int64,
+ default_value=default_value,
+ empty_key=empty_key,
+ name="t2",
+ checkpoint=True,
+ initial_num_buckets=64)
+ table.insert(
+ tf.constant([[11, 12], [13, 15]], tf.int64),
+ tf.constant([3, 4], tf.int64)).run()
+ self.assertAllEqual(2, table.size().eval())
+ self.assertAllEqual(64, len(table.export()[0].eval()))
+
+ save = tf.train.Saver()
+
+ # Restore the saved values in the parameter nodes.
+ save.restore(sess, save_path)
+
+ self.assertAllEqual(3, table.size().eval())
+ self.assertAllEqual(32, len(table.export()[0].eval()))
+
+ input_string = tf.constant(
+ [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], tf.int64)
+ output = table.lookup(input_string)
+ self.assertAllEqual([0, 1, -1, 2, -1], output.eval())
+
def testReprobe(self):
with self.test_session():
# Insert 6 keys into a table with 8 buckets.