diff options
Diffstat (limited to 'tensorflow/contrib')
12 files changed, 86 insertions, 85 deletions
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index 7846814546..01ee8703a9 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -43,7 +43,7 @@ class BatchOpsTest(test.TestCase): def testBasicBatch(self): """Tests that a single batched tensor executes together and only once.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, index, _ = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, @@ -83,7 +83,7 @@ class BatchOpsTest(test.TestCase): def testBatchWithPadding(self): """Test that batching with padding up to an allowed batch size works.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) batched, index, _ = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=10, @@ -113,7 +113,7 @@ class BatchOpsTest(test.TestCase): def testMultipleBatch(self): """Tests that multiple batched tensors execute together.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, _, _ = batch_ops.batch( @@ -152,7 +152,7 @@ class BatchOpsTest(test.TestCase): def testIllegalBatchDifferentDim0Sizes(self): """Tests illegally feeding tensors with different dim0 sizes.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) batched, index, _ = batch_ops.batch( @@ -166,7 +166,7 @@ class BatchOpsTest(test.TestCase): def testBasicUnbatch(self): """Tests that batch and unbatch work together.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=10, @@ -190,7 +190,8 @@ class BatchOpsTest(test.TestCase): def testBasicUnbatchV1Decorated(self): """Tests that the batch_function_v1 decorator works.""" - with self.test_session() as sess: + with self.cached_session() as sess: + @batch_ops.batch_function_v1(1, 10, 100000) def computation(in_t): return in_t + 1 @@ -211,7 +212,7 @@ class BatchOpsTest(test.TestCase): def testBasicUnbatchDecorated(self): """Tests that the batch_function decorator works.""" - with self.test_session() as sess: + with self.cached_session() as sess: # TODO(apassos): Removing this line causes test flakiness! Ideally should # be investigated. default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable @@ -236,7 +237,7 @@ class BatchOpsTest(test.TestCase): def testBatchDecoratedWithCapturedInput(self): """Tests that the batch_function decorator works.""" - with self.test_session() as sess: + with self.cached_session() as sess: captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) @@ -260,7 +261,7 @@ class BatchOpsTest(test.TestCase): def testBatchFunctionOp(self): """Tests that the batch_function op works.""" - with self.test_session() as sess: + with self.cached_session() as sess: @function.Defun(dtypes.int32) def computation(in_t): @@ -289,7 +290,7 @@ class BatchOpsTest(test.TestCase): def testBatchFunctionOpWithCapturedInput(self): """Tests that batch_function op works with captured input.""" - with self.test_session() as sess: + with self.cached_session() as sess: captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) @@ -323,7 +324,7 @@ class BatchOpsTest(test.TestCase): def testBatchFunctionOpWithInputError(self): """Tests that batch_function op works with error in the inputs.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) @function.Defun(dtypes.int32, dtypes.int32) @@ -346,7 +347,7 @@ class BatchOpsTest(test.TestCase): def testBasicUnbatchDecoratedWithReshape(self): """Tests that the batch_function decorator works.""" - with self.test_session() as sess: + with self.cached_session() as sess: @batch_ops.batch_function(1, 10, 100000) def computation(in_t): @@ -368,7 +369,7 @@ class BatchOpsTest(test.TestCase): def testUnbatchTimeout(self): """Tests that the unbatch timeout works.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, @@ -410,7 +411,7 @@ class BatchOpsTest(test.TestCase): def testUnbatchGrad(self): """Tests that batch and unbatch are differentiable.""" - with self.test_session() as sess: + with self.cached_session() as sess: inp = array_ops.placeholder(dtype=dtypes.float32, shape=[1]) batched, index, id_t = batch_ops.batch( [inp], num_batch_threads=1, max_batch_size=2, diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index 42b3b9f026..3e631b5909 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -173,7 +173,7 @@ class JITTest(test.TestCase): class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): - with self.test_session(): + with self.cached_session(): x = constant_op.constant([[3.]]) y_nc = math_ops.matmul(x, x, name="not_compiled") with jit.experimental_jit_scope(): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py index 8dad80aa64..c32ea9ade7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py @@ -93,12 +93,12 @@ class SoftsignBijectorTest(test.TestCase): bijector.inverse_log_det_jacobian(y, event_ndims=1))) def testScalarCongruency(self): - with self.test_session(): + with self.cached_session(): bijector = Softsign(validate_args=True) assert_scalar_congruency(bijector, lower_x=-20., upper_x=20.) def testBijectiveAndFinite(self): - with self.test_session(): + with self.cached_session(): bijector = Softsign(validate_args=True) x = np.linspace(-20., 20., 100).astype(np.float32) y = np.linspace(-0.99, 0.99, 100).astype(np.float32) diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 58f348034f..64d6706199 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -399,7 +399,7 @@ class StarGANModelTest(test.TestCase): target_tensor = train._generate_stargan_random_domain_target( batch_size, domain_numbers) - with self.test_session() as sess: + with self.cached_session() as sess: targets = sess.run(target_tensor) self.assertTupleEqual((batch_size, domain_numbers), targets.shape) for target in targets: @@ -676,7 +676,7 @@ class GANLossTest(test.TestCase, parameterized.TestCase): self.assertIsInstance(model_loss, namedtuples.GANLoss) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py index fed8a771cc..27aed091c2 100644 --- a/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py +++ b/tensorflow/contrib/grid_rnn/python/kernel_tests/grid_rnn_test.py @@ -233,7 +233,7 @@ class GridRNNCellTest(test.TestCase): ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]]))) def testGrid2LSTMCellWithRelu(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -261,7 +261,7 @@ class GridRNNCellTest(test.TestCase): """ def testGrid2BasicRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([2, 2]) @@ -292,7 +292,7 @@ class GridRNNCellTest(test.TestCase): [[0.80049908, 0.80049908], [0.97574311, 0.97574311]])) def testGrid2BasicRNNCellTied(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([2, 2]) @@ -323,7 +323,7 @@ class GridRNNCellTest(test.TestCase): [[0.80049908, 0.80049908], [0.97574311, 0.97574311]])) def testGrid2BasicRNNCellWithRelu(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -348,7 +348,7 @@ class GridRNNCellTest(test.TestCase): """ def testGrid1LSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)) as root_scope: x = array_ops.zeros([1, 3]) @@ -410,7 +410,7 @@ class GridRNNCellTest(test.TestCase): """ def testGrid3LSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -455,7 +455,7 @@ class GridRNNCellTest(test.TestCase): """ def testGridRNNEdgeCasesLikeRelu(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([3, 2]) @@ -481,7 +481,7 @@ class GridRNNCellTest(test.TestCase): self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],)) def testGridRNNEdgeCasesNoOutput(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -541,7 +541,7 @@ class GridRNNCellTest(test.TestCase): self.assertEqual(out[0].get_shape()[1], num_units) self.assertEqual(out[0].dtype, inp.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((batch_size, input_size)) @@ -581,7 +581,7 @@ class GridRNNCellTest(test.TestCase): self.assertEqual(out[0].get_shape()[1], num_units) self.assertEqual(out[0].dtype, inp.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((batch_size, input_size)) @@ -623,7 +623,7 @@ class GridRNNCellTest(test.TestCase): self.assertEqual(out[0].get_shape()[1], num_units) self.assertEqual(out[0].dtype, inp.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((batch_size, input_size)) @@ -663,7 +663,7 @@ class GridRNNCellTest(test.TestCase): self.assertEqual(out[0].get_shape(), (3, num_units)) self.assertEqual(out[0].dtype, inp.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((batch_size, input_size)) @@ -700,7 +700,7 @@ class GridRNNCellTest(test.TestCase): self.assertEqual(out[0].get_shape()[1], num_units) self.assertEqual(out[0].dtype, inp.dtype) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) input_value = np.ones((3, input_size)) @@ -715,7 +715,7 @@ class GridRNNCellTest(test.TestCase): def testGrid2LSTMCellLegacy(self): """Test for legacy case (when state_is_tuple=False).""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( 'root', initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py index 9ed017592a..f44edaa14c 100644 --- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py +++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class InputPipelineOpsTest(test.TestCase): def testObtainNext(self): - with self.test_session(): + with self.cached_session(): var = state_ops.variable_op([], dtypes.int64) state_ops.assign(var, -1).op.run() c = constant_op.constant(["a", "b"]) @@ -45,7 +45,7 @@ class InputPipelineOpsTest(test.TestCase): def testSeekNext(self): string_list = ["a", "b", "c"] - with self.test_session() as session: + with self.cached_session() as session: elem = input_pipeline_ops.seek_next(string_list) session.run([variables.global_variables_initializer()]) self.assertEqual(b"a", session.run(elem)) @@ -65,7 +65,7 @@ class InputPipelineOpsTest(test.TestCase): def testSeekNextLimitEpochs(self): string_list = ["a", "b", "c"] - with self.test_session() as session: + with self.cached_session() as session: elem = input_pipeline_ops.seek_next(string_list, num_epochs=1) session.run([ variables.local_variables_initializer(), @@ -75,7 +75,7 @@ class InputPipelineOpsTest(test.TestCase): def testSeekNextLimitEpochsThree(self): string_list = ["a", "b", "c"] - with self.test_session() as session: + with self.cached_session() as session: elem = input_pipeline_ops.seek_next(string_list, num_epochs=3) session.run([ variables.local_variables_initializer(), diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py index 72507539f8..4d5cc24ce0 100644 --- a/tensorflow/contrib/kernel_methods/python/losses_test.py +++ b/tensorflow/contrib/kernel_methods/python/losses_test.py @@ -32,7 +32,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testInvalidLogitsShape(self): """An error is raised when logits have invalid shape.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1], shape=(2,)) labels = constant_op.constant([0, 1]) with self.assertRaises(ValueError): @@ -40,7 +40,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testInvalidLabelsShape(self): """An error is raised when labels have invalid shape.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) labels = constant_op.constant([1, 0], shape=(1, 1, 2)) with self.assertRaises(ValueError): @@ -48,7 +48,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testInvalidWeightsShape(self): """An error is raised when weights have invalid shape.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) labels = constant_op.constant([1, 0], shape=(2,)) weights = constant_op.constant([1.5, 0.2], shape=(2, 1, 1)) @@ -57,7 +57,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testInvalidLabelsDtype(self): """An error is raised when labels have invalid shape.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) labels = constant_op.constant([1, 0], dtype=dtypes.float32) with self.assertRaises(ValueError): @@ -65,7 +65,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testNoneWeightRaisesValueError(self): """An error is raised when weights are None.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) labels = constant_op.constant([1, 0]) with self.assertRaises(ValueError): @@ -73,7 +73,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testInconsistentLabelsAndWeightsShapesSameRank(self): """Error raised when weights and labels have same ranks, different sizes.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1, 4.1], shape=(3, 1)) labels = constant_op.constant([1, 0, 2], shape=(3, 1)) weights = constant_op.constant([1.1, 2.0], shape=(2, 1)) @@ -82,7 +82,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testInconsistentLabelsAndWeightsShapesDifferentRank(self): """Error raised when weights and labels have different ranks and sizes.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) labels = constant_op.constant([1, 0], shape=(2, 1)) weights = constant_op.constant([1.1, 2.0, 2.8], shape=(3,)) @@ -91,7 +91,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testOutOfRangeLabels(self): """An error is raised when labels are not in [0, num_classes).""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0], [0.5, 1.8, -1.0]]) labels = constant_op.constant([1, 0, 4]) @@ -101,7 +101,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testZeroLossInt32Labels(self): """Loss is 0 if true class logits sufficiently higher than other classes.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0], [0.5, 1.8, -1.0]]) labels = constant_op.constant([0, 2, 1], dtype=dtypes.int32) @@ -110,7 +110,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testZeroLossInt64Labels(self): """Loss is 0 if true class logits sufficiently higher than other classes.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[2.1, -0.4, -1.0], [1.4, 2.8, 4.0], [-0.5, 0.8, -1.0]]) labels = constant_op.constant([0, 2, 1], dtype=dtypes.int64) @@ -130,7 +130,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): ] for batch_size, num_classes in logits_shapes: - with self.test_session(): + with self.cached_session(): logits = array_ops.placeholder( dtypes.float32, shape=(batch_size, num_classes)) labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,)) @@ -140,7 +140,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testCorrectPredictionsSomeClassesInsideMargin(self): """Loss is > 0 even if true class logits are higher than other classes.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.2, -1.4, 0.8], [1.4, 1.8, 4.0], [1.5, 1.8, -1.0]]) labels = constant_op.constant([0, 2, 1]) @@ -150,7 +150,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testIncorrectPredictions(self): """Loss is >0 when an incorrect class has higher logits than true class.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[2.6, 0.4, 0.8], [1.4, 0.8, -1.0], [0.5, -1.8, 2.0]]) labels = constant_op.constant([1, 0, 2]) @@ -162,7 +162,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testIncorrectPredictionsColumnLabels(self): """Same as above but labels is a rank-2 tensor.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], [0.2, -1.8, 4.0]]) labels = constant_op.constant([1, 0, 2], shape=(3, 1)) @@ -174,7 +174,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testIncorrectPredictionsZeroWeights(self): """Loss is 0 when all weights are missing even if predictions are wrong.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], [0.2, -1.8, 4.0]]) labels = constant_op.constant([1, 0, 2], shape=(3, 1)) @@ -185,7 +185,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testNonZeroLossWithPythonScalarWeights(self): """Weighted loss is correctly computed when weights is a python scalar.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], [0.2, -1.8, 4.0]]) labels = constant_op.constant([1, 0, 2], shape=(3, 1)) @@ -195,7 +195,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testNonZeroLossWithScalarTensorWeights(self): """Weighted loss is correctly computed when weights is a rank-0 tensor.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], [0.2, -1.8, 4.0]]) labels = constant_op.constant([1, 0, 2], shape=(3, 1)) @@ -205,7 +205,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testNonZeroLossWith1DTensorWeightsColumnLabels(self): """Weighted loss is correctly computed when weights is a rank-0 tensor.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], [0.2, -1.8, 4.0]]) labels = constant_op.constant([1, 0, 2], shape=(3, 1)) @@ -216,7 +216,7 @@ class SparseMulticlassHingeLossTest(test.TestCase): def testNonZeroLossWith2DTensorWeights1DLabelsSomeWeightsMissing(self): """Weighted loss is correctly computed when weights is a rank-0 tensor.""" - with self.test_session(): + with self.cached_session(): logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], [0.2, -1.8, 4.0], [1.6, 1.8, -4.0]]) labels = constant_op.constant([1, 0, 2, 1]) diff --git a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py index 2ff4d41d75..bad0a596a7 100644 --- a/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py +++ b/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features_test.py @@ -58,7 +58,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): def testInvalidInputShape(self): x = constant_op.constant([[2.0, 1.0]]) - with self.test_session(): + with self.cached_session(): rffm = RandomFourierFeatureMapper(3, 10) with self.assertRaisesWithPredicateMatch( dense_kernel_mapper.InvalidShapeError, @@ -70,7 +70,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): x2 = constant_op.constant([[1.0, -1.0, 2.0], [-1.0, 10.0, 1.0], [4.0, -2.0, -1.0]]) - with self.test_session(): + with self.cached_session(): rffm = RandomFourierFeatureMapper(3, 10, 1.0) mapped_x1 = rffm.map(x1) mapped_x2 = rffm.map(x2) @@ -80,7 +80,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): def testSameOmegaReused(self): x = constant_op.constant([[2.0, 1.0, 0.0]]) - with self.test_session(): + with self.cached_session(): rffm = RandomFourierFeatureMapper(3, 100) mapped_x = rffm.map(x) mapped_x_copy = rffm.map(x) @@ -93,7 +93,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): y = constant_op.constant([[1.0, -1.0, 2.0]]) stddev = 3.0 - with self.test_session(): + with self.cached_session(): # The mapped dimension is fairly small, so the kernel approximation is # very rough. rffm1 = RandomFourierFeatureMapper(3, 100, stddev) @@ -113,7 +113,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): y = constant_op.constant([[1.0, -1.0, 2.0]]) stddev = 3.0 - with self.test_session(): + with self.cached_session(): # The mapped dimension is fairly small, so the kernel approximation is # very rough. rffm = RandomFourierFeatureMapper(3, 100, stddev, seed=0) @@ -139,7 +139,7 @@ class RandomFourierFeatureMapperTest(TensorFlowTestCase): normalized_points = [nn.l2_normalize(point, dim=1) for point in points] total_absolute_error = 0.0 - with self.test_session(): + with self.cached_session(): rffm = RandomFourierFeatureMapper(input_dim, mapped_dim, stddev, seed=0) # Cache mappings so that they are not computed multiple times. cached_mappings = dict((point, rffm.map(point)) diff --git a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py index e4db5f2e3c..e6a0b30567 100644 --- a/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py +++ b/tensorflow/contrib/stat_summarizer/python/stat_summarizer_test.py @@ -38,7 +38,7 @@ class StatSummarizerTest(test.TestCase): graph_def = graph.as_graph_def() ss = pywrap_tensorflow.NewStatSummarizer(graph_def.SerializeToString()) - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) for _ in range(20): diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py index e429d12e96..1c4e18dbda 100644 --- a/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py +++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py @@ -32,7 +32,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase): indices = [[1], [10]] updates = [100., 200.] - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run() self.assertAllEqual( @@ -45,7 +45,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase): indices = [[0, 0, 1], [1, 1, 2]] updates = [100., 200.] - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run() self.assertAllEqual([[[1., 102., 3.], [4., 5., 6.]], @@ -57,7 +57,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase): indices = [] updates = [] - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run() self.assertAllEqual(init_val, input_data.eval()) @@ -67,7 +67,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase): input_data = variables.Variable(init_val) indices = [[0, 0, 1], [1, 1, 2]] updates = [100.] - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() with self.assertRaisesOpError( 'Number of updates should be same as number of indices.'): @@ -80,7 +80,7 @@ class ScatterAddNdimTest(test_util.TensorFlowTestCase): indices = [[0, 0], [1, 1]] updates = [[100., 200., 300.], [400., 500., 600.]] - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() tensor_forest_ops.scatter_add_ndim(input_data, indices, updates).run() self.assertAllEqual([[[101., 202., 303.], [4., 5., 6.]], diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index 1c9c81827e..e0f0c0d4ff 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -149,7 +149,7 @@ class TensorForestTest(test_util.TensorFlowTestCase): self.assertTrue(isinstance(probs, ops.Tensor)) self.assertTrue(isinstance(paths, ops.Tensor)) self.assertTrue(isinstance(var, ops.Tensor)) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() resources.initialize_resources(resources.shared_resources()).run() self.assertEquals(probs.eval().shape, (4, 2)) diff --git a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py index 84e36146d5..832d34d60d 100644 --- a/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py +++ b/tensorflow/contrib/text/python/ops/skip_gram_ops_test.py @@ -63,7 +63,7 @@ class SkipGramOpsTest(test.TestCase): (b"jumps", b"brown"), (b"jumps", b"fox"), ]) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_tokens, tokens.eval()) self.assertAllEqual(expected_labels, labels.eval()) @@ -94,7 +94,7 @@ class SkipGramOpsTest(test.TestCase): (b"jumps", b"fox"), (b"jumps", b"jumps"), ]) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_tokens, tokens.eval()) self.assertAllEqual(expected_labels, labels.eval()) @@ -105,7 +105,7 @@ class SkipGramOpsTest(test.TestCase): # If emit_self_as_target is False (default), output will be empty. tokens, labels = text.skip_gram_sample( input_tensor, min_skips=0, max_skips=0, emit_self_as_target=False) - with self.test_session(): + with self.cached_session(): self.assertEqual(0, tokens.eval().size) self.assertEqual(0, labels.eval().size) @@ -117,7 +117,7 @@ class SkipGramOpsTest(test.TestCase): (b"quick", b"quick"), (b"brown", b"brown"), ]) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_tokens, tokens.eval()) self.assertAllEqual(expected_labels, labels.eval()) @@ -134,7 +134,7 @@ class SkipGramOpsTest(test.TestCase): (b"brown", b"the"), (b"brown", b"quick"), ]) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_tokens, tokens.eval()) self.assertAllEqual(expected_labels, labels.eval()) @@ -150,7 +150,7 @@ class SkipGramOpsTest(test.TestCase): (b"quick", b"brown"), (b"brown", b"quick"), ]) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_tokens, tokens.eval()) self.assertAllEqual(expected_labels, labels.eval()) @@ -165,7 +165,7 @@ class SkipGramOpsTest(test.TestCase): (b"quick", b"brown"), (b"brown", b"quick"), ]) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_tokens, tokens.eval()) self.assertAllEqual(expected_labels, labels.eval()) @@ -196,7 +196,7 @@ class SkipGramOpsTest(test.TestCase): (b"over", b"fox"), (b"over", b"jumps"), ]) - with self.test_session() as sess: + with self.cached_session() as sess: tokens_eval, labels_eval = sess.run([tokens, labels]) self.assertAllEqual(expected_tokens, tokens_eval) self.assertAllEqual(expected_labels, labels_eval) @@ -222,7 +222,7 @@ class SkipGramOpsTest(test.TestCase): tokens_2, labels_2 = text.skip_gram_sample( input_tensor, min_skips=1, max_skips=5) - with self.test_session() as sess: + with self.cached_session() as sess: tokens_1_eval, labels_1_eval, tokens_2_eval, labels_2_eval = sess.run( [tokens_1, labels_1, tokens_2, labels_2]) @@ -244,7 +244,7 @@ class SkipGramOpsTest(test.TestCase): (b"brown", b"fox"), (b"fox", b"brown"), ]) - with self.test_session() as sess: + with self.cached_session() as sess: coord = coordinator.Coordinator() threads = queue_runner_impl.start_queue_runners(sess=sess, coord=coord) @@ -269,7 +269,7 @@ class SkipGramOpsTest(test.TestCase): (2, 3), (3, 2), ]) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_tokens, tokens.eval()) self.assertAllEqual(expected_labels, labels.eval()) @@ -286,7 +286,7 @@ class SkipGramOpsTest(test.TestCase): for min_skips, max_skips in invalid_skips: tokens, labels = text.skip_gram_sample( input_tensor, min_skips=min_skips, max_skips=max_skips) - with self.test_session() as sess, self.assertRaises( + with self.cached_session() as sess, self.assertRaises( errors.InvalidArgumentError): sess.run([tokens, labels]) @@ -338,7 +338,7 @@ class SkipGramOpsTest(test.TestCase): vocab_freq_table = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), -1) - with self.test_session(): + with self.cached_session(): vocab_freq_table.init.run() # No vocab_freq_table specified - output should be the same as input. @@ -395,7 +395,7 @@ class SkipGramOpsTest(test.TestCase): vocab_freq_table = lookup.HashTable( lookup.KeyValueTensorInitializer(keys, values), -1) - with self.test_session(): + with self.cached_session(): vocab_freq_table.init.run() output = skip_gram_ops._filter_input( input_tensor=input_tensor, @@ -464,7 +464,7 @@ class SkipGramOpsTest(test.TestCase): (b"life", b"and"), (b"and", b"life"), ]) - with self.test_session(): + with self.cached_session(): lookup_ops.tables_initializer().run() self.assertAllEqual(expected_tokens, tokens.eval()) self.assertAllEqual(expected_labels, labels.eval()) @@ -510,7 +510,7 @@ class SkipGramOpsTest(test.TestCase): (b"to", b"life"), (b"life", b"to"), ]) - with self.test_session() as sess: + with self.cached_session() as sess: lookup_ops.tables_initializer().run() tokens_eval, labels_eval = sess.run([tokens, labels]) self.assertAllEqual(expected_tokens, tokens_eval) |