author    Zakaria Haque <zakaria@google.com>    2017-03-22 15:23:52 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-03-22 16:51:32 -0700
commit    7d0360a3dc78ff454e3c2783831b970b1b6306f0 (patch)
tree      f9a89b407b9cbfc8c75302c946193b17d11b9033
parent    42c204df8f3e40dffad8ddd2770c0ab881b5a4d8 (diff)
Fixes a bug where heads/pre-canned estimators were not exporting a proper classes tensor.
Servo expects classes to be a string tensor with the same shape as scores, containing the label for each corresponding score. While creating output alternatives, if the classes tensor does not satisfy these conditions, we create a new tensor with these properties.
Change: 150943474
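For illustration, a minimal sketch of the transformation this change implements, written against the TF 1.x API used here; the standalone helper name below is hypothetical and not part of the commit:

    # Illustrative sketch only (TF 1.x API); `make_servo_classes` is a hypothetical helper.
    import tensorflow as tf

    def make_servo_classes(probabilities, label_keys=None):
      """Builds a string `classes` tensor with the same shape as `probabilities`."""
      batch_size = tf.shape(probabilities)[0]
      if label_keys:
        # Tile the provided label strings to [batch_size, n_classes].
        classes = tf.tile(tf.expand_dims(label_keys, 0), [batch_size, 1])
      else:
        # Fall back to stringified class indices: "0", "1", ..., "n_classes - 1".
        n_classes = tf.shape(probabilities)[1]
        classes = tf.as_string(
            tf.tile(tf.expand_dims(tf.range(n_classes), 0), [batch_size, 1]))
      return classes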
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/head.py       | 95
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/head_test.py  | 60
2 files changed, 125 insertions, 30 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 028a13ca20..65f5b49b0e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -42,6 +42,7 @@ from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.summary import summary
@@ -816,7 +817,8 @@ class _BinaryLogisticHead(_SingleHead):
loss_fn=self._loss_fn,
logits_to_predictions_fn=self._logits_to_predictions,
metrics_fn=self._metrics,
- create_output_alternatives_fn=self._create_output_alternatives,
+ create_output_alternatives_fn=_classification_output_alternatives(
+ self.head_name, self._problem_type),
labels=labels,
train_op_fn=train_op_fn,
logits=logits,
@@ -1009,7 +1011,8 @@ class _MultiClassHead(_SingleHead):
loss_fn=self._wrapped_loss_fn,
logits_to_predictions_fn=self._logits_to_predictions,
metrics_fn=self._metrics,
- create_output_alternatives_fn=self._create_output_alternatives,
+ create_output_alternatives_fn=_classification_output_alternatives(
+ self.head_name, self._problem_type, self._label_keys),
labels=labels,
train_op_fn=train_op_fn,
logits=logits,
@@ -1113,25 +1116,6 @@ class _MultiClassHead(_SingleHead):
return metrics
- def _create_output_alternatives(self, predictions):
- """See superclass."""
- probabilities = predictions[prediction_key.PredictionKey.PROBABILITIES]
- batch_size = array_ops.shape(probabilities)[0]
- if self._label_keys:
- classes = array_ops.tile(
- input=array_ops.expand_dims(input=self._label_keys, axis=0),
- multiples=[batch_size, 1])
- else:
- classes = array_ops.tile(
- input=array_ops.expand_dims(
- input=math_ops.range(self.logits_dimension), axis=0),
- multiples=[batch_size, 1])
- predictions_for_serving = {
- prediction_key.PredictionKey.CLASSES: classes,
- prediction_key.PredictionKey.PROBABILITIES: probabilities,
- }
- return {self._head_name: (self._problem_type, predictions_for_serving)}
-
def _to_labels_tensor(labels, label_name):
"""Returns label as a tensor.
@@ -1226,6 +1210,7 @@ class _BinarySvmHead(_SingleHead):
loss_fn=self._loss_fn,
logits_to_predictions_fn=self._logits_to_predictions,
metrics_fn=self._metrics,
+ # TODO(zakaria): Handle labels for export.
create_output_alternatives_fn=self._create_output_alternatives,
labels=labels,
train_op_fn=train_op_fn,
@@ -1325,7 +1310,8 @@ class _MultiLabelHead(_SingleHead):
loss_fn=self._loss_fn,
logits_to_predictions_fn=self._logits_to_predictions,
metrics_fn=self._metrics,
- create_output_alternatives_fn=self._create_output_alternatives,
+ create_output_alternatives_fn=_classification_output_alternatives(
+ self.head_name, self._problem_type),
labels=labels,
train_op_fn=train_op_fn,
logits=logits,
@@ -1901,6 +1887,71 @@ def _streaming_recall_at_threshold(predictions, labels, weights, threshold):
return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op)
+def _classification_output_alternatives(head_name, problem_type,
+ label_keys=None):
+ """Creates a func to generate output alternatives for classification.
+
+ Servo expects classes to be a string tensor, and have the same dimensions
+ as the probabilities tensor. It should contain the labels of the corresponding
+ entries in probabilities. This function creates a new classes tensor that
+ satisfies these conditions and can be exported.
+
+ Args:
+ head_name: Name of the head.
+ problem_type: `ProblemType` enum value for the head.
+ label_keys: Optional list of string label keys to export as classes.
+
+ Returns:
+ A function to generate output alternatives.
+ """
+ def _create_output_alternatives(predictions):
+ """Creates output alternative for the Head.
+
+ Args:
+ predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a
+ symbolic name for an output Tensor possibly but not necessarily taken
+ from `PredictionKey`, and 'Tensor' is the corresponding output Tensor
+ itself.
+
+ Returns:
+ `dict` of {submodel_name: (problem_type, {tensor_name: Tensor})}, where
+ 'submodel_name' is a submodel identifier that should be consistent across
+ the pipeline (here likely taken from the head_name),
+ 'problem_type' is a `ProblemType`,
+ 'tensor_name' is a symbolic name for an output Tensor possibly but not
+ necessarily taken from `PredictionKey`, and
+ 'Tensor' is the corresponding output Tensor itself.
+
+ Raises:
+ ValueError: if predictions does not contain the PredictionKey.PROBABILITIES key.
+ """
+ probabilities = predictions.get(prediction_key.PredictionKey.PROBABILITIES)
+ if probabilities is None:
+ raise ValueError("%s missing in predictions" %
+ prediction_key.PredictionKey.PROBABILITIES)
+
+ with ops.name_scope(None, "_classification_output_alternatives",
+ (probabilities,)):
+ batch_size = array_ops.shape(probabilities)[0]
+ if label_keys:
+ classes = array_ops.tile(
+ input=array_ops.expand_dims(input=label_keys, axis=0),
+ multiples=[batch_size, 1],
+ name="classes_tensor")
+ else:
+ n = array_ops.shape(probabilities)[1]
+ classes = array_ops.tile(
+ input=array_ops.expand_dims(input=math_ops.range(n), axis=0),
+ multiples=[batch_size, 1])
+ classes = string_ops.as_string(classes, name="classes_tensor")
+
+ exported_predictions = {
+ prediction_key.PredictionKey.PROBABILITIES: probabilities,
+ prediction_key.PredictionKey.CLASSES: classes}
+ return {head_name: (problem_type, exported_predictions)}
+
+ return _create_output_alternatives
+
# Aliases
# TODO(zakaria): Remove these aliases, See b/34751732
_regression_head = regression_head
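For context, a hedged usage sketch of the factory added above: the closure it returns is plugged in as create_output_alternatives_fn and, when called on the predictions dict, yields the exported alternatives. The tensor value below is a placeholder, not part of the commit:

    # Illustrative usage only; `probabilities` stands in for a real [batch, n_classes] float tensor.
    create_fn = _classification_output_alternatives(
        head_name="head_name",
        problem_type=constants.ProblemType.CLASSIFICATION,
        label_keys=None)
    output_alternatives = create_fn(
        {prediction_key.PredictionKey.PROBABILITIES: probabilities})
    # -> {"head_name": (ProblemType.CLASSIFICATION,
    #                   {"probabilities": <float Tensor [batch, n_classes]>,
    #                    "classes": <string Tensor [batch, n_classes]>})}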
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index ecc1d9ff9e..749508147b 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -417,7 +417,7 @@ class MultiLabelHeadTest(test.TestCase):
{}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
logits_input=((0., 0.),), logits=self._logits)
- def testMultiLabelEvalMode(self):
+ def testMultiLabelEval(self):
n_classes = 3
head = head_lib.multi_label_head(
n_classes=n_classes, metric_class_ids=range(n_classes))
@@ -433,7 +433,7 @@ class MultiLabelHeadTest(test.TestCase):
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
- def testMultiClassEvalModeWithLargeLogits(self):
+ def testMultiClassEvalWithLargeLogits(self):
n_classes = 3
head = head_lib.multi_label_head(
n_classes=n_classes, metric_class_ids=range(n_classes))
@@ -472,6 +472,36 @@ class MultiLabelHeadTest(test.TestCase):
_assert_metrics(self, expected_loss,
expected_eval_metrics, model_fn_ops)
+ def testMultiLabelInfer(self):
+ n_classes = 3
+ head = head_lib.multi_label_head(n_classes=n_classes, head_name="head_name")
+ with ops.Graph().as_default(), session.Session():
+ model_fn_ops = head.create_model_fn_ops(
+ {}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,
+ logits=((1., 0., 0.), (0., 0., 1)))
+ self.assertIsNone(model_fn_ops.train_op)
+ _assert_no_variables(self)
+ with session.Session():
+ self.assertListEqual(
+ [1, 0, 0], model_fn_ops.predictions["classes"].eval().tolist()[0])
+ self.assertItemsEqual(
+ ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))
+ self.assertEqual(
+ constants.ProblemType.CLASSIFICATION,
+ model_fn_ops.output_alternatives["head_name"][0])
+
+ predictions_for_serving = (
+ model_fn_ops.output_alternatives["head_name"][1])
+ self.assertIn("classes", six.iterkeys(predictions_for_serving))
+ self.assertAllEqual(
+ [["0", "1", "2"], ["0", "1", "2"]],
+ predictions_for_serving["classes"].eval())
+ self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
+ self.assertAllClose(
+ [[0.731059, 0.5, 0.5],
+ [0.5, 0.5, 0.731059,]],
+ predictions_for_serving["probabilities"].eval())
+
def testMultiLabelWithLabelName(self):
n_classes = 3
label_name = "my_label"
@@ -691,7 +721,7 @@ class BinaryClassificationHeadTest(test.TestCase):
{}, model_fn.ModeKeys.TRAIN, self._labels, head_lib.no_op_train_fn,
logits_input=((0., 0.), (0., 0.)), logits=self._logits)
- def testBinaryClassificationEvalMode(self):
+ def testBinaryClassificationEval(self):
n_classes = 2
head = head_lib.multi_class_head(n_classes=n_classes)
with ops.Graph().as_default(), session.Session():
@@ -708,18 +738,32 @@ class BinaryClassificationHeadTest(test.TestCase):
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
- def testBinaryClassificationInferMode(self):
+ def testBinaryClassificationInfer(self):
n_classes = 2
- head = head_lib.multi_class_head(n_classes=n_classes)
+ head = head_lib.multi_class_head(n_classes=n_classes, head_name="head_name")
with ops.Graph().as_default(), session.Session():
# logloss: z:label, x:logit
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
model_fn_ops = head.create_model_fn_ops(
{}, model_fn.ModeKeys.INFER, self._labels, head_lib.no_op_train_fn,
logits=self._logits)
- self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
+ with session.Session():
+ self.assertListEqual(
+ [1, 1], list(model_fn_ops.predictions["classes"].eval()))
+ self.assertItemsEqual(
+ ["head_name"], six.iterkeys(model_fn_ops.output_alternatives))
+ self.assertEqual(
+ constants.ProblemType.LOGISTIC_REGRESSION,
+ model_fn_ops.output_alternatives["head_name"][0])
+ predictions_for_serving = (
+ model_fn_ops.output_alternatives["head_name"][1])
+ self.assertIn("classes", six.iterkeys(predictions_for_serving))
+ predicted_classes = predictions_for_serving["classes"].eval().tolist()
+ self.assertListEqual(
+ ["0", "1"], predicted_classes[0])
+ self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
def testBinaryClassificationInferMode_withWeightColumn(self):
n_classes = 2
@@ -1006,7 +1050,7 @@ class MultiClassHeadTest(test.TestCase):
"multi_class_head/centered_bias/bias_1",
"multi_class_head/centered_bias/bias_2"])
- def testMultiClassEvalMode(self):
+ def testMultiClassEval(self):
n_classes = 3
head = head_lib.multi_class_head(
n_classes=n_classes, metric_class_ids=range(n_classes))
@@ -1131,7 +1175,7 @@ class MultiClassHeadTest(test.TestCase):
model_fn_ops.output_alternatives["head_name"][1])
self.assertIn("classes", six.iterkeys(predictions_for_serving))
self.assertAllEqual(
- [[0, 1, 2], [0, 1, 2]],
+ [["0", "1", "2"], ["0", "1", "2"]],
predictions_for_serving["classes"].eval())
self.assertIn("probabilities", six.iterkeys(predictions_for_serving))
self.assertAllClose(