diff options
author | 2017-10-06 11:37:42 -0700 | |
---|---|---|
committer | 2017-10-06 11:53:53 -0700 | |
commit | 3110185270e93e0b6a3e82be9199febed1239602 (patch) | |
tree | b9f06efb708c1dbe672bd8ac1e36a5492875f812 /tensorflow/contrib/factorization/python | |
parent | 7fceb8d879dd23a2fd15403d216367e5e8f52b56 (diff) |
Use the new Estimator.get_variable_value() method to get the kmeans cluster centers.
PiperOrigin-RevId: 171320755
Diffstat (limited to 'tensorflow/contrib/factorization/python')
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/clustering_ops.py | 8 | ||||
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/kmeans.py | 28 |
2 files changed, 9 insertions, 27 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
index e5c9180662..d7320aeb3d 100644
--- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
@@ -51,6 +51,9 @@ COSINE_DISTANCE = 'cosine'
 RANDOM_INIT = 'random'
 KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
 
+# The name of the variable holding the cluster centers. Used by the Estimator.
+CLUSTERS_VAR_NAME = 'clusters'
+
 
 class KMeans(object):
   """Creates the graph for k-means clustering."""
@@ -279,7 +282,7 @@ class KMeans(object):
     """
     init_value = array_ops.constant([], dtype=dtypes.float32)
     cluster_centers = variable_scope.variable(
-        init_value, name='clusters', validate_shape=False)
+        init_value, name=CLUSTERS_VAR_NAME, validate_shape=False)
     cluster_centers_initialized = variable_scope.variable(
         False, dtype=dtypes.bool, name='initialized')
 
@@ -337,7 +340,6 @@ class KMeans(object):
         assigned cluster instead.
       cluster_centers_initialized: scalar indicating whether clusters have been
         initialized.
-      cluster_centers_var: a Variable holding the cluster centers.
       init_op: an op to initialize the clusters.
       training_op: an op that runs an iteration of training.
     """
@@ -381,7 +383,7 @@ class KMeans(object):
         inputs, num_clusters, cluster_idx, cluster_centers_var)
 
     return (all_scores, cluster_idx, scores, cluster_centers_initialized,
-            cluster_centers_var, init_op, training_op)
+            init_op, training_op)
 
   def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
                                   cluster_centers_updated, total_counts):
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py
index 6284768bdd..9a5413fc3f 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans.py
@@ -21,12 +21,10 @@ from __future__ import division
 from __future__ import print_function
 
 import time
-import numpy as np
 
 from tensorflow.contrib.factorization.python.ops import clustering_ops
 from tensorflow.python.estimator import estimator
 from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -161,8 +159,7 @@ class _ModelFn(object):
       * `eval_metric_ops`: Maps `SCORE` to `loss`.
       * `predictions`: Maps `ALL_DISTANCES` to the distance from each input
         point to each cluster center; maps `CLUSTER_INDEX` to the index of
-        the closest cluster center for each input point; maps `CLUSTERS` to
-        the cluster centers (which ignores the input points).
+        the closest cluster center for each input point.
     """
     # input_points is a single Tensor. Therefore, the sharding functionality
     # in clustering_ops is unused, and some of the values below are lists of a
@@ -184,8 +181,8 @@ class _ModelFn(object):
     # training_op: an op that runs an iteration of training, either an entire
     # Lloyd iteration or a mini-batch of a Lloyd iteration. Multiple workers
     # may execute this op, but only after is_initialized becomes True.
-    (all_distances, model_predictions, losses, is_initialized,
-     cluster_centers_var, init_op, training_op) = clustering_ops.KMeans(
+    (all_distances, model_predictions, losses, is_initialized, init_op,
+     training_op) = clustering_ops.KMeans(
          inputs=input_points,
          num_clusters=self._num_clusters,
          initial_clusters=self._initial_clusters,
@@ -215,7 +212,6 @@ class _ModelFn(object):
         predictions={
             KMeansClustering.ALL_DISTANCES: all_distances[0],
             KMeansClustering.CLUSTER_INDEX: model_predictions[0],
-            KMeansClustering.CLUSTERS: cluster_centers_var.value(),
         },
         loss=loss,
         train_op=training_op,
@@ -242,9 +238,7 @@ class KMeansClustering(estimator.Estimator):
   # Keys returned by predict().
   # ALL_DISTANCES: The distance from each input point to each cluster center.
   # CLUSTER_INDEX: The index of the closest cluster center for each input point.
-  # CLUSTERS: The cluster centers (which ignores the input points).
   CLUSTER_INDEX = 'cluster_index'
-  CLUSTERS = 'clusters'
   ALL_DISTANCES = 'all_distances'
 
   def __init__(self,
@@ -400,18 +394,4 @@ class KMeansClustering(estimator.Estimator):
 
   def cluster_centers(self):
     """Returns the cluster centers."""
-
-    # TODO(ccolby): Fix this clunky code once cl/168262087 is submitted.
-    # Discussion: go/estimator-get-variable-value
-    class RunOnceHook(session_run_hook.SessionRunHook):
-      """Stops after a single run."""
-
-      def after_run(self, run_context, run_values):
-        del run_values  # unused
-        run_context.request_stop()
-
-    result = self.predict(
-        input_fn=lambda: (constant_op.constant([], shape=[0, 1]), None),
-        predict_keys=[KMeansClustering.CLUSTERS],
-        hooks=[RunOnceHook()])
-    return np.array([r[KMeansClustering.CLUSTERS] for r in result])
+    return self.get_variable_value(clustering_ops.CLUSTERS_VAR_NAME)