diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/confusion_matrix_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/confusion_matrix_test.py | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/python/kernel_tests/confusion_matrix_test.py b/tensorflow/python/kernel_tests/confusion_matrix_test.py index 93f5323c41..bc24345261 100644 --- a/tensorflow/python/kernel_tests/confusion_matrix_test.py +++ b/tensorflow/python/kernel_tests/confusion_matrix_test.py @@ -37,7 +37,7 @@ class ConfusionMatrixTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testExample(self): """This is a test of the example provided in pydoc.""" - with self.test_session(): + with self.cached_session(): self.assertAllEqual([ [0, 0, 0, 0, 0], [0, 0, 1, 0, 0], @@ -49,7 +49,7 @@ class ConfusionMatrixTest(test.TestCase): def _testConfMatrix(self, labels, predictions, truth, weights=None, num_classes=None): - with self.test_session(): + with self.cached_session(): dtype = predictions.dtype ans = confusion_matrix.confusion_matrix( labels, predictions, dtype=dtype, weights=weights, @@ -78,7 +78,7 @@ class ConfusionMatrixTest(test.TestCase): self._testBasic(dtype=np.int64) def _testConfMatrixOnTensors(self, tf_dtype, np_dtype): - with self.test_session() as sess: + with self.cached_session() as sess: m_neg = array_ops.placeholder(dtype=dtypes.float32) m_pos = array_ops.placeholder(dtype=dtypes.float32) s = array_ops.placeholder(dtype=dtypes.float32) @@ -229,7 +229,7 @@ class ConfusionMatrixTest(test.TestCase): def testOutputIsInt32(self): labels = np.arange(2) predictions = np.arange(2) - with self.test_session(): + with self.cached_session(): cm = confusion_matrix.confusion_matrix( labels, predictions, dtype=dtypes.int32) tf_cm = cm.eval() @@ -238,7 +238,7 @@ class ConfusionMatrixTest(test.TestCase): def testOutputIsInt64(self): labels = np.arange(2) predictions = np.arange(2) - with self.test_session(): + with self.cached_session(): cm = confusion_matrix.confusion_matrix( labels, predictions, dtype=dtypes.int64) tf_cm = cm.eval() @@ -260,7 +260,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase): confusion_matrix.remove_squeezable_dimensions( labels_placeholder, predictions_placeholder)) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(label_values, static_labels.eval()) self.assertAllEqual(prediction_values, static_predictions.eval()) feed_dict = { @@ -285,7 +285,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase): confusion_matrix.remove_squeezable_dimensions( labels_placeholder, predictions_placeholder)) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(label_values, static_labels.eval()) self.assertAllEqual(prediction_values, static_predictions.eval()) feed_dict = { @@ -310,7 +310,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase): confusion_matrix.remove_squeezable_dimensions( labels_placeholder, predictions_placeholder, expected_rank_diff=0)) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(label_values, static_labels.eval()) self.assertAllEqual(prediction_values, static_predictions.eval()) feed_dict = { @@ -336,7 +336,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase): labels_placeholder, predictions_placeholder)) expected_label_values = np.reshape(label_values, newshape=(2, 3)) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_label_values, static_labels.eval()) self.assertAllEqual(prediction_values, static_predictions.eval()) feed_dict = { @@ -362,7 +362,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase): labels_placeholder, predictions_placeholder, expected_rank_diff=1)) expected_label_values = np.reshape(label_values, newshape=(2, 3)) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(expected_label_values, static_labels.eval()) self.assertAllEqual(prediction_values, static_predictions.eval()) feed_dict = { @@ -388,7 +388,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase): labels_placeholder, predictions_placeholder)) expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3)) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(label_values, static_labels.eval()) self.assertAllEqual(expected_prediction_values, static_predictions.eval()) feed_dict = { @@ -415,7 +415,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase): labels_placeholder, predictions_placeholder, expected_rank_diff=-1)) expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3)) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(label_values, static_labels.eval()) self.assertAllEqual(expected_prediction_values, static_predictions.eval()) feed_dict = { @@ -441,7 +441,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase): confusion_matrix.remove_squeezable_dimensions( labels_placeholder, predictions_placeholder)) - with self.test_session(): + with self.cached_session(): feed_dict = { labels_placeholder: label_values, predictions_placeholder: prediction_values @@ -466,7 +466,7 @@ class RemoveSqueezableDimensionsTest(test.TestCase): confusion_matrix.remove_squeezable_dimensions( labels_placeholder, predictions_placeholder)) - with self.test_session(): + with self.cached_session(): feed_dict = { labels_placeholder: label_values, predictions_placeholder: prediction_values |