diff options
author | Yutaka Leon <yleon@google.com> | 2018-01-24 15:45:55 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-24 16:18:49 -0800 |
commit | 7bf8ccdb4ef5b0b28c1cf0d5084e07ffbf0e2703 (patch) | |
tree | 60924e727d52b71196e8eb0abeaf050fa05e01dc /tensorflow/contrib/factorization | |
parent | 1df1544aeb8a6311c98a0d9ee9b6946e035fdbeb (diff) |
Set export_outs in KMeans' EstimatorSpec.
PiperOrigin-RevId: 183154542
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/kmeans.py | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index 9a5413fc3f..4d0f9b2424 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -25,6 +25,7 @@ import time from tensorflow.contrib.factorization.python.ops import clustering_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -32,6 +33,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util @@ -207,6 +209,15 @@ class _ModelFn(object): training_hooks.append( _LossRelativeChangeHook(loss, self._relative_tolerance)) + export_outputs = { + KMeansClustering.ALL_DISTANCES: + export_output.PredictOutput(all_distances[0]), + KMeansClustering.CLUSTER_INDEX: + export_output.PredictOutput(model_predictions[0]), + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_output.PredictOutput(model_predictions[0]) + } + return model_fn_lib.EstimatorSpec( mode=mode, predictions={ @@ -216,7 +227,8 @@ class _ModelFn(object): loss=loss, train_op=training_op, eval_metric_ops={KMeansClustering.SCORE: metrics.mean(loss)}, - training_hooks=training_hooks) + training_hooks=training_hooks, + export_outputs=export_outputs) # TODO(agarwal,ands): support sharded input. |