Diffstat (limited to 'tensorflow/contrib/training/python/training/evaluation_test.py')
-rw-r--r--  tensorflow/contrib/training/python/training/evaluation_test.py  16
1 file changed, 9 insertions(+), 7 deletions(-)
diff --git a/tensorflow/contrib/training/python/training/evaluation_test.py b/tensorflow/contrib/training/python/training/evaluation_test.py
index b07039916c..c36d00e842 100644
--- a/tensorflow/contrib/training/python/training/evaluation_test.py
+++ b/tensorflow/contrib/training/python/training/evaluation_test.py
@@ -27,7 +27,6 @@ import numpy as np
 from tensorflow.contrib.framework.python.ops import variables
 from tensorflow.contrib.layers.python.layers import layers
 from tensorflow.contrib.losses.python.losses import loss_ops
-from tensorflow.contrib.metrics.python.ops import metric_ops
 from tensorflow.contrib.training.python.training import evaluation
 from tensorflow.contrib.training.python.training import training
 from tensorflow.core.protobuf import config_pb2
@@ -38,6 +37,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.platform import gfile
@@ -196,7 +196,8 @@ class EvaluateOnceTest(test.TestCase):
     logits = logistic_classifier(inputs)
     predictions = math_ops.round(logits)
-    accuracy, update_op = metric_ops.streaming_accuracy(predictions, labels)
+    accuracy, update_op = metrics.accuracy(
+        predictions=predictions, labels=labels)
     checkpoint_path = evaluation.wait_for_new_checkpoint(checkpoint_dir)
@@ -311,7 +312,8 @@ class EvaluateRepeatedlyTest(test.TestCase):
     logits = logistic_classifier(inputs)
     predictions = math_ops.round(logits)
-    accuracy, update_op = metric_ops.streaming_accuracy(predictions, labels)
+    accuracy, update_op = metrics.accuracy(
+        predictions=predictions, labels=labels)
     final_values = evaluation.evaluate_repeatedly(
         checkpoint_dir=checkpoint_dir,
@@ -365,7 +367,8 @@ class EvaluateRepeatedlyTest(test.TestCase):
     logits = logistic_classifier(inputs)
     predictions = math_ops.round(logits)
-    accuracy, update_op = metric_ops.streaming_accuracy(predictions, labels)
+    accuracy, update_op = metrics.accuracy(
+        predictions=predictions, labels=labels)
     timeout_fn_calls = [0]
     def timeout_fn():
@@ -417,9 +420,8 @@ class EvaluateRepeatedlyTest(test.TestCase):
     self.assertEqual(final_values['my_var'], expected_value)
   def _create_names_to_metrics(self, predictions, labels):
-    accuracy0, update_op0 = metric_ops.streaming_accuracy(predictions, labels)
-    accuracy1, update_op1 = metric_ops.streaming_accuracy(
-        predictions + 1, labels)
+    accuracy0, update_op0 = metrics.accuracy(labels, predictions)
+    accuracy1, update_op1 = metrics.accuracy(labels, predictions + 1)
     names_to_values = {'Accuracy': accuracy0, 'Another_accuracy': accuracy1}
     names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1}
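
All four hunks make the same substitution: the deprecated contrib metric metric_ops.streaming_accuracy, which takes (predictions, labels) positionally, is replaced by the core op metrics.accuracy (exposed as tf.metrics.accuracy in TF 1.x), which takes (labels, predictions). That argument-order reversal is why the first three hunks switch to explicit keyword arguments and the last hunk swaps the positional order. A minimal standalone sketch of the new call pattern, assuming TF 1.x graph mode; the tensor values below are illustrative, not taken from the test:

import tensorflow as tf

labels = tf.constant([1.0, 0.0, 1.0, 1.0])
predictions = tf.constant([1.0, 0.0, 0.0, 1.0])

# Old contrib form (argument order: predictions, labels):
#   accuracy, update_op = metric_ops.streaming_accuracy(predictions, labels)
# Core form (argument order: labels, predictions), as used in the diff:
accuracy, update_op = tf.metrics.accuracy(labels=labels, predictions=predictions)

with tf.Session() as sess:
  # The metric's total/count accumulators are local variables.
  sess.run(tf.local_variables_initializer())
  sess.run(update_op)          # accumulate one batch
  print(sess.run(accuracy))    # 0.75 (3 of 4 predictions match)

Like streaming_accuracy, metrics.accuracy returns an (accuracy, update_op) pair, so the surrounding test code needs no change beyond the call itself.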