diff options
author | 2016-07-19 15:32:31 -0800 | |
---|---|---|
committer | 2016-07-19 16:46:11 -0700 | |
commit | a22f30df314964263f280cf86fe453f119f5a965 (patch) | |
tree | f6ab87fcbf0ff7de8d2aa2ace0eb5fd93caa6c60 | |
parent | b0a7d7586f0a60edcb4d4431d63ecccf95705c41 (diff) |
Add tf.contrib.learn.utils.export.regression_signature_fn
Change: 127888554
-rw-r--r-- | tensorflow/contrib/learn/python/learn/utils/export.py | 36 |
1 files changed, 32 insertions, 4 deletions
diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index 2dbbb19361..40e28f6601 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -99,7 +99,7 @@ def generic_signature_fn(examples, unused_features, predictions): def logistic_regression_signature_fn(examples, unused_features, predictions): - """Creates regression signature from given examples and predictions. + """Creates logistic regression signature from given examples and predictions. Args: examples: `Tensor`. @@ -109,18 +109,46 @@ def logistic_regression_signature_fn(examples, unused_features, predictions): Returns: Tuple of default classification signature and named signature. """ - # predictions has shape [batch_size, 2] where first column is P(Y=0|x) + # predictions should have shape [batch_size, 2] where first column is P(Y=0|x) # while second column is P(Y=1|x). We are only interested in the second # column for inference. - assert predictions.get_shape()[1] == 2 - positive_predictions = predictions[:, 1] + predictions_shape = predictions.get_shape() + predictions_rank = len(predictions_shape) + if predictions_rank != 2: + logging.fatal( + 'Expected predictions to have rank 2, but received predictions with ' + 'rank: {} and shape: {}'.format(predictions_rank, predictions_shape)) + if predictions_shape[1] != 2: + logging.fatal( + 'Expected predictions to have 2nd dimension: 2, but received ' + 'predictions with 2nd dimension: {} and shape: {}. Did you mean to use ' + 'regression_signature_fn instead?'.format(predictions_shape[1], + predictions_shape)) + positive_predictions = predictions[:, 1] signatures = {} signatures['regression'] = exporter.regression_signature(examples, positive_predictions) return signatures['regression'], signatures +def regression_signature_fn(examples, unused_features, predictions): + """Creates regression signature from given examples and predictions. + + Args: + examples: `Tensor`. + unused_features: `dict` of `Tensor`s. + predictions: `dict` of `Tensor`s. + + Returns: + Tuple of default regression signature and named signature. + """ + signatures = {} + signatures['regression'] = exporter.regression_signature( + input_tensor=examples, output_tensor=predictions) + return signatures['regression'], signatures + + def classification_signature_fn(examples, unused_features, predictions): """Creates classification signature from given examples and predictions. |