aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py')
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py41
1 files changed, 28 insertions, 13 deletions
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
index 8c83700d51..a2d82cf800 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
@@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import tensorflow as tf
-
from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework.test_util import TensorFlowTestCase
from tensorflow.python.platform import googletest
@@ -33,16 +33,20 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
with self.test_session():
default_val = -1
empty_key = 0
- keys = tf.constant([11, 12, 13], tf.int64)
- values = tf.constant([0, 1, 2], tf.int64)
+ keys = constant_op.constant([11, 12, 13], dtypes.int64)
+ values = constant_op.constant([0, 1, 2], dtypes.int64)
table = ShardedMutableDenseHashTable(
- tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards)
+ dtypes.int64,
+ dtypes.int64,
+ default_val,
+ empty_key,
+ num_shards=num_shards)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
- input_string = tf.constant([11, 12, 14], tf.int64)
+ input_string = constant_op.constant([11, 12, 14], dtypes.int64)
output = table.lookup(input_string)
self.assertAllEqual([3], output.get_shape())
self.assertAllEqual([0, 1, -1], output.eval())
@@ -52,16 +56,23 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
with self.test_session():
default_val = [-0.1, 0.2]
empty_key = [0, 1]
- keys = tf.constant([[11, 12], [13, 14], [15, 16]], tf.int64)
- values = tf.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]], tf.float32)
+ keys = constant_op.constant([[11, 12], [13, 14], [15, 16]],
+ dtypes.int64)
+ values = constant_op.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]],
+ dtypes.float32)
table = ShardedMutableDenseHashTable(
- tf.int64, tf.float32, default_val, empty_key, num_shards=num_shards)
+ dtypes.int64,
+ dtypes.float32,
+ default_val,
+ empty_key,
+ num_shards=num_shards)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()
self.assertAllEqual(3, table.size().eval())
- input_string = tf.constant([[11, 12], [13, 14], [11, 14]], tf.int64)
+ input_string = constant_op.constant([[11, 12], [13, 14], [11, 14]],
+ dtypes.int64)
output = table.lookup(input_string)
self.assertAllEqual([3, 2], output.get_shape())
self.assertAllClose([[0.5, 0.6], [1.5, 1.6], [-0.1, 0.2]],
@@ -72,10 +83,14 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
empty_key = -2
default_val = -1
num_shards = 2
- keys = tf.constant([10, 11, 12], tf.int64)
- values = tf.constant([2, 3, 4], tf.int64)
+ keys = constant_op.constant([10, 11, 12], dtypes.int64)
+ values = constant_op.constant([2, 3, 4], dtypes.int64)
table = ShardedMutableDenseHashTable(
- tf.int64, tf.int64, default_val, empty_key, num_shards=num_shards)
+ dtypes.int64,
+ dtypes.int64,
+ default_val,
+ empty_key,
+ num_shards=num_shards)
self.assertAllEqual(0, table.size().eval())
table.insert(keys, values).run()