diff options
-rw-r--r-- | tensorflow/python/estimator/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/estimator/canned/head.py | 21 | ||||
-rw-r--r-- | tensorflow/python/estimator/canned/head_test.py | 32 |
3 files changed, 41 insertions, 13 deletions
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index f662aca4ca..0cfd02466e 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -385,6 +385,7 @@ py_library( "//tensorflow/python:weights_broadcast_ops", "//tensorflow/python/feature_column", "//tensorflow/python/ops/losses", + "//tensorflow/python/saved_model:signature_constants", ], ) diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index 8b24aefeb1..a1c1f1be0b 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -42,6 +42,9 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import signature_constants + +_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY class _Head(object): @@ -603,13 +606,25 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head): pred_keys.CLASSES: classes, } if mode == model_fn.ModeKeys.PREDICT: + batch_size = array_ops.shape(logistic)[0] + export_class_list = self._label_vocabulary + if not export_class_list: + export_class_list = string_ops.as_string([0, 1]) + export_output_classes = array_ops.tile( + input=array_ops.expand_dims(input=export_class_list, axis=0), + multiples=[batch_size, 1]) + classifier_output = export_output.ClassificationOutput( + scores=scores, + # `ClassificationOutput` requires string classes. + classes=export_output_classes) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={ - '': - export_output.ClassificationOutput( - scores=scores, classes=classes) + '': classifier_output, # to be same as other heads. + 'classification': classifier_output, # to be called by name. + _DEFAULT_SERVING_KEY: classifier_output, # default + 'regression': export_output.RegressionOutput(value=logistic) }) # Eval. diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index 5940a745f9..c6ea54f08e 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -807,7 +807,12 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): self.assertEqual(1, head.logits_dimension) # Create estimator spec. - logits = [[45.], [-41.]] + logits = [[0.3], [-0.4]] + expected_logistics = [[0.574443], [0.401312]] + expected_probabilities = [[0.425557, 0.574443], [0.598688, 0.401312]] + expected_class_ids = [[1], [0]] + expected_classes = [[b'1'], [b'0']] + expected_export_classes = [[b'0', b'1']] * 2 spec = head.create_estimator_spec( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.PREDICT, @@ -817,8 +822,8 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): self.assertIsNone(spec.loss) self.assertEqual({}, spec.eval_metric_ops) self.assertIsNone(spec.train_op) - self.assertItemsEqual( - ('', _DEFAULT_SERVING_KEY), spec.export_outputs.keys()) + self.assertItemsEqual(('', 'classification', 'regression', + _DEFAULT_SERVING_KEY), spec.export_outputs.keys()) _assert_no_hooks(self, spec) # Assert predictions. @@ -828,16 +833,23 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): predictions = sess.run(spec.predictions) self.assertAllClose(logits, predictions[prediction_keys.PredictionKeys.LOGITS]) + self.assertAllClose(expected_logistics, + predictions[prediction_keys.PredictionKeys.LOGISTIC]) self.assertAllClose( - _sigmoid(np.array(logits)), - predictions[prediction_keys.PredictionKeys.LOGISTIC]) - self.assertAllClose( - [[0., 1.], - [1., 0.]], predictions[prediction_keys.PredictionKeys.PROBABILITIES]) - self.assertAllClose([[1], [0]], + expected_probabilities, + predictions[prediction_keys.PredictionKeys.PROBABILITIES]) + self.assertAllClose(expected_class_ids, predictions[prediction_keys.PredictionKeys.CLASS_IDS]) - self.assertAllEqual([[b'1'], [b'0']], + self.assertAllEqual(expected_classes, predictions[prediction_keys.PredictionKeys.CLASSES]) + self.assertAllClose( + expected_probabilities, + sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores)) + self.assertAllEqual( + expected_export_classes, + sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes)) + self.assertAllClose(expected_logistics, + sess.run(spec.export_outputs['regression'].value)) def test_predict_with_vocabulary_list(self): head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( |