aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 19:24:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 19:27:54 -0700
commit496023e9dc84a076caeb2e5e8e13b6a3d819ad6d (patch)
tree9776c9865f7b98a15817bc6be4c2b683323a67b1 /tensorflow/python/feature_column
parent361a82d73a50a800510674b3aaa20e4845e56434 (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 209701635
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py48
-rw-r--r--tensorflow/python/feature_column/feature_column_v2_test.py48
2 files changed, 48 insertions, 48 deletions
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 6be930be87..9d2babc6e0 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -262,7 +262,7 @@ class NumericColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([price]))
self.assertIn('price', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
def test_parse_example_with_default_value(self):
@@ -284,7 +284,7 @@ class NumericColumnTest(test.TestCase):
no_data.SerializeToString()],
features=fc.make_parse_example_spec([price]))
self.assertIn('price', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.], [11., 11.]], features['price'].eval())
def test_normalizer_fn_must_be_callable(self):
@@ -298,7 +298,7 @@ class NumericColumnTest(test.TestCase):
price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
output = _transform_features({'price': [[1., 2.], [5., 6.]]}, [price])
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[3., 4.], [7., 8.]], output[price].eval())
def test_get_dense_tensor(self):
@@ -433,7 +433,7 @@ class BucketizedColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([bucketized_price]))
self.assertIn('price', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
def test_transform_feature(self):
@@ -700,7 +700,7 @@ class HashedCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -719,7 +719,7 @@ class HashedCategoricalColumnTest(test.TestCase):
output = outputs[hashed_sparse]
# Check exact hashed output. If hashing changes this test will break.
expected_values = [6, 4, 1]
- with self.test_session():
+ with self.cached_session():
self.assertEqual(dtypes.int64, output.values.dtype)
self.assertAllEqual(expected_values, output.values.eval())
self.assertAllEqual(wire_tensor.indices.eval(), output.indices.eval())
@@ -775,7 +775,7 @@ class HashedCategoricalColumnTest(test.TestCase):
output = builder.get(hashed_sparse)
# Check exact hashed output. If hashing changes this test will break.
expected_values = [3, 7, 5]
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_values, output.values.eval())
def test_int32_64_is_compatible(self):
@@ -789,7 +789,7 @@ class HashedCategoricalColumnTest(test.TestCase):
output = builder.get(hashed_sparse)
# Check exact hashed output. If hashing changes this test will break.
expected_values = [3, 7, 5]
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_values, output.values.eval())
def test_get_sparse_tensors(self):
@@ -984,7 +984,7 @@ class CrossedColumnTest(test.TestCase):
features=fc.make_parse_example_spec([price_cross_wire]))
self.assertIn('price', features)
self.assertIn('wire', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
wire_sparse = features['wire']
self.assertAllEqual([[0, 0], [0, 1]], wire_sparse.indices.eval())
@@ -1007,7 +1007,7 @@ class CrossedColumnTest(test.TestCase):
}
outputs = _transform_features(features, [price_cross_wire])
output = outputs[price_cross_wire]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_val = sess.run(output)
self.assertAllEqual(
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
@@ -3262,7 +3262,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'):
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
def test_invalid_vocabulary_size(self):
@@ -3286,7 +3286,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'):
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
def test_invalid_num_oov_buckets(self):
@@ -3350,7 +3350,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -3775,7 +3775,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -3797,7 +3797,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -4096,7 +4096,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -4365,7 +4365,7 @@ class IndicatorColumnTest(test.TestCase):
fc.categorical_column_with_hash_bucket('animal', 4))
builder = _LazyBuilder({'animal': ['fox', 'fox']})
output = builder.get(animal)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())
def test_2D_shape_succeeds(self):
@@ -4380,7 +4380,7 @@ class IndicatorColumnTest(test.TestCase):
dense_shape=[2, 1])
})
output = builder.get(animal)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())
def test_multi_hot(self):
@@ -4393,7 +4393,7 @@ class IndicatorColumnTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 1], dense_shape=[1, 2])
})
output = builder.get(animal)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 2., 0., 0.]], output.eval())
def test_multi_hot2(self):
@@ -4405,7 +4405,7 @@ class IndicatorColumnTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
})
output = builder.get(animal)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 1., 1., 0.]], output.eval())
def test_deep_copy(self):
@@ -4430,7 +4430,7 @@ class IndicatorColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a_indicator]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -4641,7 +4641,7 @@ class EmbeddingColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a_embedded]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -5407,7 +5407,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
features=fc.make_parse_example_spec([a_embedded, b_embedded]))
self.assertIn('aaa', features)
self.assertIn('bbb', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -5990,7 +5990,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
features=fc.make_parse_example_spec([a_weighted]))
self.assertIn('aaa', features)
self.assertIn('weights', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 80a9d5d40e..ad578d287a 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -269,7 +269,7 @@ class NumericColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([price]))
self.assertIn('price', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
def test_parse_example_with_default_value(self):
@@ -291,7 +291,7 @@ class NumericColumnTest(test.TestCase):
no_data.SerializeToString()],
features=fc.make_parse_example_spec([price]))
self.assertIn('price', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.], [11., 11.]], features['price'].eval())
def test_normalizer_fn_must_be_callable(self):
@@ -305,7 +305,7 @@ class NumericColumnTest(test.TestCase):
price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
output = _transform_features({'price': [[1., 2.], [5., 6.]]}, [price], None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[3., 4.], [7., 8.]], output[price].eval())
def test_get_dense_tensor(self):
@@ -439,7 +439,7 @@ class BucketizedColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([bucketized_price]))
self.assertIn('price', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
def test_transform_feature(self):
@@ -717,7 +717,7 @@ class HashedCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -736,7 +736,7 @@ class HashedCategoricalColumnTest(test.TestCase):
output = outputs[hashed_sparse]
# Check exact hashed output. If hashing changes this test will break.
expected_values = [6, 4, 1]
- with self.test_session():
+ with self.cached_session():
self.assertEqual(dtypes.int64, output.values.dtype)
self.assertAllEqual(expected_values, output.values.eval())
self.assertAllEqual(wire_tensor.indices.eval(), output.indices.eval())
@@ -792,7 +792,7 @@ class HashedCategoricalColumnTest(test.TestCase):
output = transformation_cache.get(hashed_sparse, None)
# Check exact hashed output. If hashing changes this test will break.
expected_values = [3, 7, 5]
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_values, output.values.eval())
def test_int32_64_is_compatible(self):
@@ -806,7 +806,7 @@ class HashedCategoricalColumnTest(test.TestCase):
output = transformation_cache.get(hashed_sparse, None)
# Check exact hashed output. If hashing changes this test will break.
expected_values = [3, 7, 5]
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_values, output.values.eval())
def test_get_sparse_tensors(self):
@@ -1000,7 +1000,7 @@ class CrossedColumnTest(test.TestCase):
features=fc.make_parse_example_spec([price_cross_wire]))
self.assertIn('price', features)
self.assertIn('wire', features)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[20., 110.]], features['price'].eval())
wire_sparse = features['wire']
self.assertAllEqual([[0, 0], [0, 1]], wire_sparse.indices.eval())
@@ -1023,7 +1023,7 @@ class CrossedColumnTest(test.TestCase):
}
outputs = _transform_features(features, [price_cross_wire], None)
output = outputs[price_cross_wire]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_val = sess.run(output)
self.assertAllEqual(
[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], output_val.indices)
@@ -3427,7 +3427,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
column.get_sparse_tensors(FeatureTransformationCache({'aaa': inputs}), None)
with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'):
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
def test_invalid_vocabulary_size(self):
@@ -3451,7 +3451,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2))
column.get_sparse_tensors(FeatureTransformationCache({'aaa': inputs}), None)
with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'):
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
def test_invalid_num_oov_buckets(self):
@@ -3521,7 +3521,7 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -3972,7 +3972,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -3994,7 +3994,7 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -4311,7 +4311,7 @@ class IdentityCategoricalColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -4595,7 +4595,7 @@ class IndicatorColumnTest(test.TestCase):
'animal': ['fox', 'fox']
})
output = transformation_cache.get(animal, None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())
def test_2D_shape_succeeds(self):
@@ -4610,7 +4610,7 @@ class IndicatorColumnTest(test.TestCase):
dense_shape=[2, 1])
})
output = transformation_cache.get(animal, None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())
def test_multi_hot(self):
@@ -4623,7 +4623,7 @@ class IndicatorColumnTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 1], dense_shape=[1, 2])
})
output = transformation_cache.get(animal, None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 2., 0., 0.]], output.eval())
def test_multi_hot2(self):
@@ -4635,7 +4635,7 @@ class IndicatorColumnTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
})
output = transformation_cache.get(animal, None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([[0., 1., 1., 0.]], output.eval())
def test_deep_copy(self):
@@ -4660,7 +4660,7 @@ class IndicatorColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a_indicator]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -4898,7 +4898,7 @@ class EmbeddingColumnTest(test.TestCase):
serialized=[data.SerializeToString()],
features=fc.make_parse_example_spec([a_embedded]))
self.assertIn('aaa', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -5698,7 +5698,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
features=fc.make_parse_example_spec([a_embedded, b_embedded]))
self.assertIn('aaa', features)
self.assertIn('bbb', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(
@@ -6271,7 +6271,7 @@ class WeightedCategoricalColumnTest(test.TestCase):
features=fc.make_parse_example_spec([a_weighted]))
self.assertIn('aaa', features)
self.assertIn('weights', features)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(
self,
sparse_tensor.SparseTensorValue(