aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-06-12 09:35:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-12 09:40:38 -0700
commit02ebc44be9a660a54793973110ae26cf948ffcea (patch)
treecfbffd272575ae581b59e0395dad5227775eb8c0 /tensorflow
parent9d6bcefde28b7f7caf3a0bdd8a6c95974df72a72 (diff)
Handled export output of binary classification which can be both a classifier or a regressor.
PiperOrigin-RevId: 158723034
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/estimator/BUILD1
-rw-r--r--tensorflow/python/estimator/canned/head.py21
-rw-r--r--tensorflow/python/estimator/canned/head_test.py32
3 files changed, 41 insertions, 13 deletions
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index f662aca4ca..0cfd02466e 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -385,6 +385,7 @@ py_library(
"//tensorflow/python:weights_broadcast_ops",
"//tensorflow/python/feature_column",
"//tensorflow/python/ops/losses",
+ "//tensorflow/python/saved_model:signature_constants",
],
)
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 8b24aefeb1..a1c1f1be0b 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -42,6 +42,9 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import signature_constants
+
+_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
class _Head(object):
@@ -603,13 +606,25 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
pred_keys.CLASSES: classes,
}
if mode == model_fn.ModeKeys.PREDICT:
+ batch_size = array_ops.shape(logistic)[0]
+ export_class_list = self._label_vocabulary
+ if not export_class_list:
+ export_class_list = string_ops.as_string([0, 1])
+ export_output_classes = array_ops.tile(
+ input=array_ops.expand_dims(input=export_class_list, axis=0),
+ multiples=[batch_size, 1])
+ classifier_output = export_output.ClassificationOutput(
+ scores=scores,
+ # `ClassificationOutput` requires string classes.
+ classes=export_output_classes)
return model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs={
- '':
- export_output.ClassificationOutput(
- scores=scores, classes=classes)
+ '': classifier_output, # to be same as other heads.
+ 'classification': classifier_output, # to be called by name.
+ _DEFAULT_SERVING_KEY: classifier_output, # default
+ 'regression': export_output.RegressionOutput(value=logistic)
})
# Eval.
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index 5940a745f9..c6ea54f08e 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -807,7 +807,12 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertEqual(1, head.logits_dimension)
# Create estimator spec.
- logits = [[45.], [-41.]]
+ logits = [[0.3], [-0.4]]
+ expected_logistics = [[0.574443], [0.401312]]
+ expected_probabilities = [[0.425557, 0.574443], [0.598688, 0.401312]]
+ expected_class_ids = [[1], [0]]
+ expected_classes = [[b'1'], [b'0']]
+ expected_export_classes = [[b'0', b'1']] * 2
spec = head.create_estimator_spec(
features={'x': np.array(((42,),), dtype=np.int32)},
mode=model_fn.ModeKeys.PREDICT,
@@ -817,8 +822,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
self.assertIsNone(spec.loss)
self.assertEqual({}, spec.eval_metric_ops)
self.assertIsNone(spec.train_op)
- self.assertItemsEqual(
- ('', _DEFAULT_SERVING_KEY), spec.export_outputs.keys())
+ self.assertItemsEqual(('', 'classification', 'regression',
+ _DEFAULT_SERVING_KEY), spec.export_outputs.keys())
_assert_no_hooks(self, spec)
# Assert predictions.
@@ -828,16 +833,23 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
predictions = sess.run(spec.predictions)
self.assertAllClose(logits,
predictions[prediction_keys.PredictionKeys.LOGITS])
+ self.assertAllClose(expected_logistics,
+ predictions[prediction_keys.PredictionKeys.LOGISTIC])
self.assertAllClose(
- _sigmoid(np.array(logits)),
- predictions[prediction_keys.PredictionKeys.LOGISTIC])
- self.assertAllClose(
- [[0., 1.],
- [1., 0.]], predictions[prediction_keys.PredictionKeys.PROBABILITIES])
- self.assertAllClose([[1], [0]],
+ expected_probabilities,
+ predictions[prediction_keys.PredictionKeys.PROBABILITIES])
+ self.assertAllClose(expected_class_ids,
predictions[prediction_keys.PredictionKeys.CLASS_IDS])
- self.assertAllEqual([[b'1'], [b'0']],
+ self.assertAllEqual(expected_classes,
predictions[prediction_keys.PredictionKeys.CLASSES])
+ self.assertAllClose(
+ expected_probabilities,
+ sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores))
+ self.assertAllEqual(
+ expected_export_classes,
+ sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes))
+ self.assertAllClose(expected_logistics,
+ sess.run(spec.export_outputs['regression'].value))
def test_predict_with_vocabulary_list(self):
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(