aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Wei Ho <weiho@google.com>2016-07-19 15:32:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-19 16:46:11 -0700
commita22f30df314964263f280cf86fe453f119f5a965 (patch)
treef6ab87fcbf0ff7de8d2aa2ace0eb5fd93caa6c60
parentb0a7d7586f0a60edcb4d4431d63ecccf95705c41 (diff)
Add tf.contrib.learn.utils.export.regression_signature_fn
Change: 127888554
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/export.py36
1 files changed, 32 insertions, 4 deletions
diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py
index 2dbbb19361..40e28f6601 100644
--- a/tensorflow/contrib/learn/python/learn/utils/export.py
+++ b/tensorflow/contrib/learn/python/learn/utils/export.py
@@ -99,7 +99,7 @@ def generic_signature_fn(examples, unused_features, predictions):
def logistic_regression_signature_fn(examples, unused_features, predictions):
- """Creates regression signature from given examples and predictions.
+ """Creates logistic regression signature from given examples and predictions.
Args:
examples: `Tensor`.
@@ -109,18 +109,46 @@ def logistic_regression_signature_fn(examples, unused_features, predictions):
Returns:
Tuple of default classification signature and named signature.
"""
- # predictions has shape [batch_size, 2] where first column is P(Y=0|x)
+ # predictions should have shape [batch_size, 2] where first column is P(Y=0|x)
# while second column is P(Y=1|x). We are only interested in the second
# column for inference.
- assert predictions.get_shape()[1] == 2
- positive_predictions = predictions[:, 1]
+ predictions_shape = predictions.get_shape()
+ predictions_rank = len(predictions_shape)
+ if predictions_rank != 2:
+ logging.fatal(
+ 'Expected predictions to have rank 2, but received predictions with '
+ 'rank: {} and shape: {}'.format(predictions_rank, predictions_shape))
+ if predictions_shape[1] != 2:
+ logging.fatal(
+ 'Expected predictions to have 2nd dimension: 2, but received '
+ 'predictions with 2nd dimension: {} and shape: {}. Did you mean to use '
+ 'regression_signature_fn instead?'.format(predictions_shape[1],
+ predictions_shape))
+ positive_predictions = predictions[:, 1]
signatures = {}
signatures['regression'] = exporter.regression_signature(examples,
positive_predictions)
return signatures['regression'], signatures
+def regression_signature_fn(examples, unused_features, predictions):
+ """Creates regression signature from given examples and predictions.
+
+ Args:
+ examples: `Tensor`.
+ unused_features: `dict` of `Tensor`s.
+ predictions: `dict` of `Tensor`s.
+
+ Returns:
+ Tuple of default regression signature and named signature.
+ """
+ signatures = {}
+ signatures['regression'] = exporter.regression_signature(
+ input_tensor=examples, output_tensor=predictions)
+ return signatures['regression'], signatures
+
+
def classification_signature_fn(examples, unused_features, predictions):
"""Creates classification signature from given examples and predictions.