author    A. Unique TensorFlower <gardener@tensorflow.org>    2016-07-14 19:57:00 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2016-07-14 21:03:56 -0700
commit 333b69580537bf14a3072f1388de64eb3fb5ebc2 (patch)
tree   d128511a6db7b2105c59198496792f91b7d4b40a /tensorflow
parent 0f9c0e3f5f47ae1d349d657111d39a89d7fcaa34 (diff)
Moves eval op from dnn_linear_combined to TargetColumn.
Change: 127504952
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/layers/python/layers/target_column.py                | 149
-rw-r--r--  tensorflow/contrib/layers/python/layers/target_column_test.py           |  75
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py |  84
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py  |   2
-rw-r--r--  tensorflow/contrib/metrics/python/ops/metric_ops.py                     |   9
-rw-r--r--  tensorflow/contrib/metrics/python/ops/set_ops.py                        |   9
6 files changed, 227 insertions, 101 deletions
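The refactor in brief: _get_eval_ops in both estimators collapses to a single delegation, and TargetColumn.get_eval_ops builds the metric dict. A minimal sketch of the resulting call pattern, mirroring the tests added below (every entry in the returned dict is a streaming (value_op, update_op) pair):

    import tensorflow as tf

    target_column = tf.contrib.layers.multi_class_target(n_classes=2)
    with tf.Graph().as_default(), tf.Session() as sess:
      logits = tf.constant([[1.], [-1.]])
      targets = tf.constant([[1.], [0.]])
      # One call now replaces the metric plumbing that previously lived
      # in dnn_linear_combined.py; a "loss" entry is always included.
      eval_ops = target_column.get_eval_ops({}, logits, targets)
      loss_op, update_op = eval_ops["loss"]
      sess.run(tf.initialize_all_variables())
      sess.run(tf.initialize_local_variables())
      sess.run(update_op)
      print(sess.run(loss_op))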
diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py
index 34d419b91b..521e00707b 100644
--- a/tensorflow/contrib/layers/python/layers/target_column.py
+++ b/tensorflow/contrib/layers/python/layers/target_column.py
@@ -18,6 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import inspect
+
+import six
+
+from tensorflow.contrib import metrics as metrics_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -116,9 +121,9 @@ class _TargetColumn(object):
# Abstract. Subclasses must implement.
raise NotImplementedError()
- def eval_metrics(self, logits):
- # Do nothing by defalut, subclasses can override.
- pass
+ def get_eval_ops(self, features, logits, targets, metrics=None):
+ """Returns eval op."""
+ raise NotImplementedError
@property
def label_name(self):
@@ -187,6 +192,15 @@ class _RegressionTargetColumn(_TargetColumn):
return array_ops.squeeze(logits, squeeze_dims=[1])
return logits
+ def get_eval_ops(self, features, logits, targets, metrics=None):
+ loss = self.loss(logits, targets, features)
+ result = {"loss": metrics_lib.streaming_mean(loss)}
+ if metrics:
+ predictions = self.logits_to_predictions(logits, proba=False)
+ result.update(_run_metrics(predictions, targets, metrics,
+ self.get_weight_tensor(features)))
+ return result
+
class _MultiClassTargetColumn(_TargetColumn):
"""_TargetColumn for classification."""
@@ -214,9 +228,59 @@ class _MultiClassTargetColumn(_TargetColumn):
else:
return math_ops.argmax(logits, 1)
- def eval_metrics(self, logits, targets):
- # TODO(zakaria): Handle eval metric in target column.
- raise NotImplementedError
+ def _default_eval_metrics(self):
+ if self._num_label_columns == 1:
+ return _get_default_binary_metrics_for_eval(thresholds=[.5])
+ return {}
+
+ def get_eval_ops(self, features, logits, targets, metrics=None):
+ loss = self.loss(logits, targets, features)
+ result = {"loss": metrics_lib.streaming_mean(loss)}
+
+ # Adds default metrics.
+ if metrics is None:
+ # TODO(b/29366811): This currently results in both an "accuracy" and an
+ # "accuracy/threshold_0.500000_mean" metric for binary classification.
+ metrics = {("accuracy", "classes"): metrics_lib.streaming_accuracy}
+
+ predictions = math_ops.sigmoid(logits)
+ targets_float = math_ops.to_float(targets)
+
+ default_metrics = self._default_eval_metrics()
+ for metric_name, metric_op in default_metrics.items():
+ result[metric_name] = metric_op(predictions, targets_float)
+
+ class_metrics = {}
+ proba_metrics = {}
+ for name, metric_op in six.iteritems(metrics):
+ if isinstance(name, tuple):
+ if len(name) != 2:
+ raise ValueError("Ignoring metric {}. It returned a tuple with "
+ "len {}, expected 2.".format(name, len(name)))
+ else:
+ if name[1] not in ["classes", "probabilities"]:
+ raise ValueError("Ignoring metric {}. The 2nd element of its "
+ "name should be either 'classes' or "
+ "'probabilities'.".format(name))
+ elif name[1] == "classes":
+ class_metrics[name[0]] = metric_op
+ else:
+ proba_metrics[name[0]] = metric_op
+ elif isinstance(name, str):
+ class_metrics[name] = metric_op
+ else:
+ raise ValueError("Ignoring metric {}. Its name is not in the correct "
+ "form.".format(name))
+ if class_metrics:
+ class_predictions = self.logits_to_predictions(logits, proba=False)
+ result.update(_run_metrics(class_predictions, targets,
+ class_metrics,
+ self.get_weight_tensor(features)))
+ if proba_metrics:
+ predictions = self.logits_to_predictions(logits, proba=True)
+ result.update(_run_metrics(predictions, targets, proba_metrics,
+ self.get_weight_tensor(features)))
+ return result
# TODO(zakaria): use contrib losses.
@@ -250,3 +314,76 @@ def _softmax_cross_entropy_loss(logits, target):
target = array_ops.squeeze(target, squeeze_dims=[1])
loss_vec = nn.sparse_softmax_cross_entropy_with_logits(logits, target)
return loss_vec
+
+
+def _run_metrics(predictions, targets, metrics, weights):
+ result = {}
+ targets = math_ops.cast(targets, predictions.dtype)
+ for name, metric in six.iteritems(metrics or {}):
+ if "weights" in inspect.getargspec(metric)[0]:
+ result[name] = metric(predictions, targets, weights=weights)
+ else:
+ result[name] = metric(predictions, targets)
+
+ return result
+
+
+def _get_default_binary_metrics_for_eval(thresholds):
+ """Returns a dictionary of basic metrics for logistic regression.
+
+ Args:
+ thresholds: List of floating point thresholds to use for accuracy,
+ precision, and recall metrics. If None, defaults to [0.5].
+
+ Returns:
+ Dictionary mapping metrics string names to metrics functions.
+ """
+ metrics = {}
+ metrics[_MetricKeys.PREDICTION_MEAN] = _predictions_streaming_mean
+ metrics[_MetricKeys.TARGET_MEAN] = _targets_streaming_mean
+ # Also include the streaming mean of the label as an accuracy baseline, as
+ # a reminder to users.
+ metrics[_MetricKeys.ACCURACY_BASELINE] = _targets_streaming_mean
+
+ metrics[_MetricKeys.AUC] = metrics_lib.streaming_auc
+
+ for threshold in thresholds:
+ metrics[_MetricKeys.ACCURACY_MEAN % threshold] = _streaming_with_threshold(
+ metrics_lib.streaming_accuracy, threshold)
+ # Precision for positive examples.
+ metrics[_MetricKeys.PRECISION_MEAN % threshold] = _streaming_with_threshold(
+ metrics_lib.streaming_precision, threshold)
+ # Recall for positive examples.
+ metrics[_MetricKeys.RECALL_MEAN % threshold] = _streaming_with_threshold(
+ metrics_lib.streaming_recall, threshold)
+
+ return metrics
+
+
+# TODO(zakaria): support weights.
+def _targets_streaming_mean(unused_predictions, targets):
+ return metrics_lib.streaming_mean(targets)
+
+
+def _predictions_streaming_mean(predictions, unused_targets):
+ return metrics_lib.streaming_mean(predictions)
+
+
+def _streaming_with_threshold(streaming_metrics_fn, threshold):
+
+ def _streaming_metrics(predictions, targets):
+ return streaming_metrics_fn(predictions=math_ops.to_float(
+ math_ops.greater_equal(predictions, threshold)),
+ labels=targets)
+
+ return _streaming_metrics
+
+
+class _MetricKeys(object):
+ AUC = "auc"
+ PREDICTION_MEAN = "labels/prediction_mean"
+ TARGET_MEAN = "labels/actual_target_mean"
+ ACCURACY_BASELINE = "accuracy/baseline_target_mean"
+ ACCURACY_MEAN = "accuracy/threshold_%f_mean"
+ PRECISION_MEAN = "precision/positive_threshold_%f_mean"
+ RECALL_MEAN = "recall/positive_threshold_%f_mean"
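
For reference, the metrics argument validated above accepts two key forms. A construction-only sketch ("my_accuracy" and "my_auc" are made-up names; for binary targets the default metrics such as "auc" are added regardless, so custom names should avoid colliding with them):

    import tensorflow as tf
    from tensorflow.contrib import metrics as metrics_lib

    target_column = tf.contrib.layers.multi_class_target(n_classes=2)
    logits = tf.constant([[1.], [1.], [-1.]])
    targets = tf.constant([[1.], [0.], [1.]])
    metrics = {
        # Plain string key: the metric sees class predictions.
        "my_accuracy": metrics_lib.streaming_accuracy,
        # (name, kind) tuple: kind must be "classes" or "probabilities".
        ("my_auc", "probabilities"): metrics_lib.streaming_auc,
    }
    eval_ops = target_column.get_eval_ops({}, logits, targets, metrics)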
diff --git a/tensorflow/contrib/layers/python/layers/target_column_test.py b/tensorflow/contrib/layers/python/layers/target_column_test.py
index 27a23a0a1e..3e86e0cce2 100644
--- a/tensorflow/contrib/layers/python/layers/target_column_test.py
+++ b/tensorflow/contrib/layers/python/layers/target_column_test.py
@@ -21,8 +21,9 @@ from __future__ import print_function
import tensorflow as tf
-class TargetColumnTest(tf.test.TestCase):
+class RegressionTargetColumnTest(tf.test.TestCase):
+ # TODO(zakaria): test multilabel regression.
def testRegression(self):
target_column = tf.contrib.layers.regression_target()
with tf.Graph().as_default(), tf.Session() as sess:
@@ -42,11 +43,47 @@ class TargetColumnTest(tf.test.TestCase):
1.,
sess.run(target_column.loss(logits, targets, features)))
- # TODO(zakaria): test multlabel regresssion.
- def testSoftmax(self):
+class MultiClassTargetColumnTest(tf.test.TestCase):
+
+ def testBinaryClassification(self):
+ target_column = tf.contrib.layers.multi_class_target(n_classes=2)
+ with tf.Graph().as_default(), tf.Session() as sess:
+ logits = tf.constant([[1.], [1.]])
+ targets = tf.constant([[1.], [0.]])
+ # logloss: z:label, x:logit
+ # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+ self.assertAlmostEqual(.81326163,
+ sess.run(target_column.loss(logits, targets, {})))
+
+ def testBinaryClassificationWithWeights(self):
target_column = tf.contrib.layers.multi_class_target(
- n_classes=3)
+ n_classes=2,
+ weight_column_name="label_weight")
+ with tf.Graph().as_default(), tf.Session() as sess:
+ features = {"label_weight": tf.constant([[1.], [0.]])}
+ logits = tf.constant([[1.], [1.]])
+ targets = tf.constant([[1.], [0.]])
+ # logloss: z:label, x:logit
+ # z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+ self.assertAlmostEqual(
+ .31326166, sess.run(target_column.loss(logits, targets, features)))
+
+ def testBinaryEvalMetrics(self):
+ target_column = tf.contrib.layers.multi_class_target(n_classes=2)
+ with tf.Graph().as_default(), tf.Session() as sess:
+ logits = tf.constant([[1.], [1.], [-1.]])
+ targets = tf.constant([[1.], [0.], [1.]])
+ eval_dict = target_column.get_eval_ops({}, logits, targets)
+ # TODO(zakaria): test all metrics
+ accuracy_op, update_op = eval_dict["accuracy/threshold_0.500000_mean"]
+ sess.run(tf.initialize_all_variables())
+ sess.run(tf.initialize_local_variables())
+ sess.run(update_op)
+ self.assertAlmostEqual(1.0/3, sess.run(accuracy_op))
+
+ def testMultiClass(self):
+ target_column = tf.contrib.layers.multi_class_target(n_classes=3)
with tf.Graph().as_default(), tf.Session() as sess:
logits = tf.constant([[1., 0., 0.]])
targets = tf.constant([2])
@@ -55,7 +92,21 @@ class TargetColumnTest(tf.test.TestCase):
self.assertAlmostEqual(1.5514446,
sess.run(target_column.loss(logits, targets, {})))
- def testSoftmaxWithInvalidNClass(self):
+ def testMultiClassWithWeight(self):
+ target_column = tf.contrib.layers.multi_class_target(
+ n_classes=3,
+ weight_column_name="label_weight")
+ with tf.Graph().as_default(), tf.Session() as sess:
+ features = {"label_weight": tf.constant([0.1])}
+ logits = tf.constant([[1., 0., 0.]])
+ targets = tf.constant([2])
+ # softmax cross entropy: -log(exp(logits[2]) / sum(exp(logits)))
+ # = log(exp(1) + 2); the example weight cancels in the weighted mean.
+ self.assertAlmostEqual(1.5514446,
+ sess.run(target_column.loss(
+ logits, targets, features)))
+
+ def testMultiClassWithInvalidNClass(self):
try:
tf.contrib.layers.multi_class_target(n_classes=1)
self.fail("Softmax with no n_classes did not raise error.")
@@ -63,6 +114,20 @@ class TargetColumnTest(tf.test.TestCase):
# Expected
pass
+ def testMultiClassEvalMetrics(self):
+ target_column = tf.contrib.layers.multi_class_target(n_classes=3)
+ with tf.Graph().as_default(), tf.Session() as sess:
+ logits = tf.constant([[1., 0., 0.]])
+ targets = tf.constant([2])
+ eval_dict = target_column.get_eval_ops({}, logits, targets)
+ loss_op, update_op = eval_dict["loss"]
+ sess.run(tf.initialize_all_variables())
+ sess.run(tf.initialize_local_variables())
+ sess.run(update_op)
+ # softmax cross entropy: -log(exp(logits[2]) / sum(exp(logits)))
+ # = log(exp(1) + 2) = 1.5514446
+ self.assertAlmostEqual(1.5514446, sess.run(loss_op))
+
if __name__ == "__main__":
tf.test.main()
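
The 1/3 accuracy asserted in testBinaryEvalMetrics comes from the threshold wrapper: probabilities are binarized at 0.5 before the streaming metric sees them. A standalone re-statement of _streaming_with_threshold (the helper name streaming_accuracy_at is hypothetical):

    import tensorflow as tf
    from tensorflow.contrib import metrics as metrics_lib

    def streaming_accuracy_at(threshold):
      def _metric(predictions, targets):
        # Binarize probabilities, then defer to the streaming metric.
        binarized = tf.to_float(tf.greater_equal(predictions, threshold))
        return metrics_lib.streaming_accuracy(binarized, targets)
      return _metric

    with tf.Graph().as_default(), tf.Session() as sess:
      probabilities = tf.sigmoid(tf.constant([[1.], [1.], [-1.]]))
      targets = tf.constant([[1.], [0.], [1.]])
      accuracy_op, update_op = streaming_accuracy_at(0.5)(probabilities,
                                                          targets)
      sess.run(tf.initialize_local_variables())
      sess.run(update_op)
      # sigmoid(1) ~ 0.73 -> 1 (hit vs 1), 0.73 -> 1 (miss vs 0),
      # sigmoid(-1) ~ 0.27 -> 0 (miss vs 1): accuracy = 1/3.
      print(sess.run(accuracy_op))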
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index 749d643b89..e5d840c895 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -19,22 +19,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import inspect
-
import numpy as np
import six
from tensorflow.contrib import layers
-from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.contrib.learn.python.learn.estimators import composable_model
from tensorflow.contrib.learn.python.learn.estimators import estimator
-from tensorflow.contrib.learn.python.learn.estimators import logistic_regressor
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import logging_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import state_ops
@@ -200,17 +195,6 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
with ops.get_default_graph().colocate_with(global_step):
return state_ops.assign_add(global_step, 1).op, loss
- def _run_metrics(self, predictions, targets, metrics, weights):
- result = {}
- targets = math_ops.cast(targets, predictions.dtype)
- for name, metric in six.iteritems(metrics or {}):
- if "weights" in inspect.getargspec(metric)[0]:
- result[name] = metric(predictions, targets, weights=weights)
- else:
- result[name] = metric(predictions, targets)
-
- return result
-
def _get_eval_ops(self, features, targets, metrics=None):
raise NotImplementedError
@@ -459,63 +443,7 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
"""See base class."""
features = self._get_feature_dict(features)
logits = self._logits(features)
- loss = self._target_column.loss(logits, targets, features)
- result = {"loss": metrics_lib.streaming_mean(loss)}
-
- # Adds default metrics.
- if metrics is None:
- # TODO(b/29366811): This currently results in both an "accuracy" and an
- # "accuracy/threshold_0.500000_mean" metric for binary classification.
- metrics = {("accuracy", "classes"): metrics_lib.streaming_accuracy}
-
- # Adds additional useful metrics for the special case of binary
- # classification.
- # TODO(zakaria): Move LogisticRegressor.get_default_metrics to metrics
- # and handle eval metric from targetcolumn.
- if self._target_column.num_label_columns == 1:
- predictions = math_ops.sigmoid(logits)
- targets_float = math_ops.to_float(targets)
- default_metrics = (
- logistic_regressor.LogisticRegressor.get_default_metrics())
- for metric_name, metric_op in default_metrics.items():
- result[metric_name] = metric_op(predictions, targets_float)
-
- if metrics:
- class_metrics = {}
- proba_metrics = {}
- for name, metric_op in six.iteritems(metrics):
- if isinstance(name, tuple):
- if len(name) != 2:
- raise ValueError("Ignoring metric {}. It returned a tuple with "
- "len {}, expected 2.".format(name, len(name)))
- else:
- if name[1] not in ["classes", "probabilities"]:
- raise ValueError("Ignoring metric {}. The 2nd element of its "
- "name should be either 'classes' or "
- "'probabilities'.".format(name))
- elif name[1] == "classes":
- class_metrics[name[0]] = metric_op
- else:
- proba_metrics[name[0]] = metric_op
- elif isinstance(name, str):
- class_metrics[name] = metric_op
- else:
- raise ValueError("Ignoring metric {}. Its name is not in the correct "
- "form.".format(name))
- if class_metrics:
- predictions = self._target_column.logits_to_predictions(logits,
- proba=False)
- result.update(self._run_metrics(predictions, targets, class_metrics,
- self._target_column.get_weight_tensor(
- features)))
- if proba_metrics:
- predictions = self._target_column.logits_to_predictions(logits,
- proba=True)
- result.update(self._run_metrics(predictions, targets, proba_metrics,
- self._target_column.get_weight_tensor(
- features)))
-
- return result
+ return self._target_column.get_eval_ops(features, logits, targets, metrics)
class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
@@ -650,14 +578,6 @@ class DNNLinearCombinedRegressor(_DNNLinearCombinedBaseEstimator):
"""See base class."""
features = self._get_feature_dict(features)
logits = self._logits(features)
- loss = self._target_column.loss(logits, targets, features)
- result = {"loss": metrics_lib.streaming_mean(loss)}
+ return self._target_column.get_eval_ops(features, logits, targets, metrics)

- if metrics:
- predictions = self._target_column.logits_to_predictions(logits,
- proba=False)
- result.update(self._run_metrics(predictions, targets, metrics,
- self._target_column.get_weight_tensor(
- features)))
- return result
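
Callers of evaluate() are unaffected; custom metrics keep the naming convention the target column now enforces. A hedged end-to-end sketch against the contrib.learn API of this era (the feature column, input_fn, and "my_auc" key are illustrative, not part of the commit):

    import tensorflow as tf
    from tensorflow.contrib import metrics as metrics_lib

    age = tf.contrib.layers.real_valued_column("age")
    classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
        linear_feature_columns=[age],
        dnn_feature_columns=[age],
        dnn_hidden_units=[4])

    def input_fn():
      features = {"age": tf.constant([[20.], [40.], [60.]])}
      targets = tf.constant([[0], [1], [1]])
      return features, targets

    classifier.fit(input_fn=input_fn, steps=10)
    # The metrics dict flows through _get_eval_ops into
    # TargetColumn.get_eval_ops, which validates the key form.
    results = classifier.evaluate(
        input_fn=input_fn, steps=1,
        metrics={("my_auc", "probabilities"): metrics_lib.streaming_auc})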
diff --git a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py
index 596697a841..d4e0bb6283 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py
@@ -67,6 +67,8 @@ class LogisticRegressor(estimator.Estimator):
model_dir=model_dir,
config=config)
+ # TODO(zakaria): use target column.
+
# Metrics string keys.
AUC = "auc"
PREDICTION_MEAN = "labels/prediction_mean"
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 87f15e13d5..d579eb43e6 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -22,7 +22,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import framework
+from tensorflow.contrib.framework.python.ops import variables as contrib_variables
+
from tensorflow.contrib.metrics.python.ops import confusion_matrix_ops
from tensorflow.contrib.metrics.python.ops import metric_ops_util
from tensorflow.contrib.metrics.python.ops import set_ops
@@ -1217,7 +1218,7 @@ def _streaming_sparse_true_positive_at_k(predictions_idx,
batch_total_tp = math_ops.cast(
math_ops.reduce_sum(tp), dtype=dtypes.float64)
- var = framework.local_variable(
+ var = contrib_variables.local_variable(
array_ops.zeros([], dtype=dtypes.float64), name=scope)
return var, state_ops.assign_add(var, batch_total_tp, name='update')
@@ -1268,7 +1269,7 @@ def _streaming_sparse_false_positive_at_k(predictions_idx,
batch_total_fp = math_ops.cast(
math_ops.reduce_sum(fp), dtype=dtypes.float64)
- var = framework.local_variable(
+ var = contrib_variables.local_variable(
array_ops.zeros([], dtype=dtypes.float64), name=scope)
return var, state_ops.assign_add(var, batch_total_fp, name='update')
@@ -1319,7 +1320,7 @@ def _streaming_sparse_false_negative_at_k(predictions_idx,
batch_total_fn = math_ops.cast(
math_ops.reduce_sum(fn), dtype=dtypes.float64)
- var = framework.local_variable(
+ var = contrib_variables.local_variable(
array_ops.zeros([], dtype=dtypes.float64), name=scope)
return var, state_ops.assign_add(var, batch_total_fn, name='update')
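
The metric_ops change is a pure import swap; contrib_variables.local_variable behaves exactly as before, creating an accumulator in the LOCAL_VARIABLES collection so it is reset by initialize_local_variables() rather than saved with the model. A minimal sketch:

    import tensorflow as tf
    from tensorflow.contrib.framework.python.ops import variables as contrib_variables

    with tf.Graph().as_default(), tf.Session() as sess:
      total = contrib_variables.local_variable(
          tf.zeros([], dtype=tf.float64), name="total")
      update = tf.assign_add(total, tf.constant(1.0, dtype=tf.float64))
      sess.run(tf.initialize_local_variables())
      sess.run(update)
      print(sess.run(total))  # 1.0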
diff --git a/tensorflow/contrib/metrics/python/ops/set_ops.py b/tensorflow/contrib/metrics/python/ops/set_ops.py
index 396d4fb57a..4ed4370d92 100644
--- a/tensorflow/contrib/metrics/python/ops/set_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/set_ops.py
@@ -17,7 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import framework
+from tensorflow.contrib.framework.python.framework import tensor_util
+
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import load_library
from tensorflow.python.framework import ops
@@ -56,7 +57,7 @@ def set_size(a, validate_indices=True):
Raises:
TypeError: If `a` is an invalid type.
"""
- a = framework.convert_to_tensor_or_sparse_tensor(a, name="a")
+ a = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a")
if not isinstance(a, ops.SparseTensor):
raise TypeError("Expected `SparseTensor`, got %s." % a)
if a.values.dtype.base_dtype not in _VALID_DTYPES:
@@ -106,10 +107,10 @@ def _set_operation(a, b, set_operation, validate_indices=True):
TypeError: If inputs are invalid types.
ValueError: If `a` is sparse and `b` is dense.
"""
- a = framework.convert_to_tensor_or_sparse_tensor(a, name="a")
+ a = tensor_util.convert_to_tensor_or_sparse_tensor(a, name="a")
if a.dtype.base_dtype not in _VALID_DTYPES:
raise TypeError("'a' invalid dtype %s." % a.dtype)
- b = framework.convert_to_tensor_or_sparse_tensor(b, name="b")
+ b = tensor_util.convert_to_tensor_or_sparse_tensor(b, name="b")
if b.dtype.base_dtype != a.dtype.base_dtype:
raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
# pylint: disable=protected-access
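
set_ops likewise just reaches the same helper through tensor_util: the function passes SparseTensors through and converts anything else with convert_to_tensor. A small sketch (note the TF 0.9-era SparseTensor still uses the shape keyword):

    import tensorflow as tf
    from tensorflow.contrib.framework.python.framework import tensor_util

    # A dense input becomes an ordinary Tensor.
    dense = tensor_util.convert_to_tensor_or_sparse_tensor([[1, 2]], name="a")
    # A SparseTensor comes back as a SparseTensor.
    sparse = tensor_util.convert_to_tensor_or_sparse_tensor(
        tf.SparseTensor(indices=[[0, 0]], values=[1], shape=[1, 1]), name="b")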