aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/factorization
diff options
context:
space:
mode:
authorGravatar Yutaka Leon <yleon@google.com>2018-01-24 15:45:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-24 16:18:49 -0800
commit7bf8ccdb4ef5b0b28c1cf0d5084e07ffbf0e2703 (patch)
tree60924e727d52b71196e8eb0abeaf050fa05e01dc /tensorflow/contrib/factorization
parent1df1544aeb8a6311c98a0d9ee9b6946e035fdbeb (diff)
Set export_outs in KMeans' EstimatorSpec.
PiperOrigin-RevId: 183154542
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r--tensorflow/contrib/factorization/python/ops/kmeans.py14
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py
index 9a5413fc3f..4d0f9b2424 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans.py
@@ -25,6 +25,7 @@ import time
from tensorflow.contrib.factorization.python.ops import clustering_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -32,6 +33,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import signature_constants
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
@@ -207,6 +209,15 @@ class _ModelFn(object):
training_hooks.append(
_LossRelativeChangeHook(loss, self._relative_tolerance))
+ export_outputs = {
+ KMeansClustering.ALL_DISTANCES:
+ export_output.PredictOutput(all_distances[0]),
+ KMeansClustering.CLUSTER_INDEX:
+ export_output.PredictOutput(model_predictions[0]),
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ export_output.PredictOutput(model_predictions[0])
+ }
+
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions={
@@ -216,7 +227,8 @@ class _ModelFn(object):
loss=loss,
train_op=training_op,
eval_metric_ops={KMeansClustering.SCORE: metrics.mean(loss)},
- training_hooks=training_hooks)
+ training_hooks=training_hooks,
+ export_outputs=export_outputs)
# TODO(agarwal,ands): support sharded input.