aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/layers/python/layers/embedding_ops_test.py')
-rw-r--r--tensorflow/contrib/layers/python/layers/embedding_ops_test.py54
1 files changed, 27 insertions, 27 deletions
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
index 7ede193029..124515e5a6 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
@@ -109,7 +109,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
return sparse_ids, sparse_weights
def test_safe_embedding_lookup_sparse_return_zero_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -122,7 +122,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4])
def test_safe_embedding_lookup_sparse_return_special_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -136,7 +136,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights[0][2], embedding_weights[0][3]])
def test_safe_embedding_lookup_sparse_no_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, _ = self._ids_and_weights_2d()
@@ -150,7 +150,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])
def test_safe_embedding_lookup_sparse_partitioned(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, _ = self._ids_and_weights_2d()
@@ -164,7 +164,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
(embedding_weights[0] + embedding_weights[1]) / 2.0])
def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -179,7 +179,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights, sparse_ids, sparse_weights)
def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -192,7 +192,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
], [embedding_weights[0][2], [0] * 4, [0] * 4]])
def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -208,7 +208,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
]])
def test_safe_embedding_lookup_sparse_3d_no_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, _ = self._ids_and_weights_3d()
@@ -224,7 +224,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
]])
def test_safe_embedding_lookup_sparse_3d_partitioned(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, _ = self._ids_and_weights_3d()
@@ -241,7 +241,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -276,7 +276,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
return embedding_weights
def test_scattered_embedding_consistency(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant(["foo", "foo"])
@@ -288,7 +288,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
embedding_lookup_result[1])
def test_scattered_embedding_multiple_partition(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=7)
values = constant_op.constant([4, 4, 5])
@@ -304,7 +304,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
self.assertGreater(embedding_diff, 0)
def test_scattered_embedding_coverage(self):
- with self.test_session():
+ with self.cached_session():
size = 8
embedding_weights = self._random_weights(size=size, num_shards=3)
values = constant_op.constant(["foo"])
@@ -316,7 +316,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
self.assertEqual(len(np.unique(embedding_lookup_result[0])), size)
def test_scattered_embedding_multi_dimension(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant([["foo", "bar", "bar"],
["bar", "bar", "foo"]])
@@ -329,7 +329,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
embedding_lookup_result[1][2])
def test_scattered_embedding_lookup_sparse(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_tensor = sparse_tensor_lib.SparseTensor(
values=["foo", "bar", "foo", "bar"],
@@ -358,7 +358,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
embeds = np.random.randn(n_embed, d_embed)
idx = np.random.randint(0, n_embed, idx_shape)
- with self.test_session():
+ with self.cached_session():
embedded_np = embeds[idx]
embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -370,7 +370,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
idx = np.random.randint(0, 5, 10)
idx2d = np.random.randint(0, 5, (10, 2))
- with self.test_session():
+ with self.cached_session():
embedded_np = embeds[idx]
embedded_np2d = embeds[idx2d]
embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -408,7 +408,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
return embedding_weights
def test_hashed_embedding_consistency(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant(["foo", "foo"])
# The first three sampled_candidates are equal, so the first three
@@ -429,7 +429,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
embedding_lookup_result[1][3])
def test_hashed_embedding_multi_dimension(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant([["foo", "bar", "bar"],
["bar", "bar", "foo"]])
@@ -467,7 +467,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_shape(self):
"""Verifies the shape of the output tensor."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a", "a", "b", "c", "d", "e", "f"],
indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]],
@@ -481,7 +481,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_values(self):
"""Verifies the values in a trivial case."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a"], indices=[[1, 0]], dense_shape=[3, 1])
params = constant_op.constant([.1, .2, .3])
@@ -495,7 +495,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_values_with_sampled_candidates(self):
"""Verifies the values for given sampled_candidates."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a", "a", "b", "c", "d", "e", "f"],
indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]],
@@ -520,7 +520,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_values_with_sign_hash(self):
"""Verifies the values in a trivial case with hash_signs=True."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a"], indices=[[1, 0]], dense_shape=[3, 1])
params = constant_op.constant([.1, .1, .1])
@@ -537,7 +537,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_distributive_property(self):
"""Verifies the distributive property of matrix multiplication."""
- with self.test_session():
+ with self.cached_session():
params = constant_op.constant([.1, .2, .3])
sp_values_a = sparse_tensor_lib.SparseTensor(
values=["a"], indices=[[0, 0]], dense_shape=[3, 1])
@@ -710,7 +710,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
[1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32,
dtypes.float64], [True, False]):
- with self.test_session():
+ with self.cached_session():
p, params, feed_dict = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
embedding_sum = \
@@ -749,7 +749,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
for num_shards, combiner, dtype, ignore_weights in itertools.product(
[1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32,
dtypes.float64], [True, False]):
- with self.test_session():
+ with self.cached_session():
x, params, _ = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
@@ -767,7 +767,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
sp_ids = sparse_tensor_lib.SparseTensor(
constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),