about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/factorization/python
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-10-06 11:37:42 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-10-06 11:53:53 -0700
commit: 3110185270e93e0b6a3e82be9199febed1239602 (patch)
tree: b9f06efb708c1dbe672bd8ac1e36a5492875f812 /tensorflow/contrib/factorization/python
parent: 7fceb8d879dd23a2fd15403d216367e5e8f52b56 (diff)
Use the new Estimator.get_variable_value() method to get the kmeans cluster centers.
PiperOrigin-RevId: 171320755
Diffstat (limited to 'tensorflow/contrib/factorization/python')
-rw-r--r-- tensorflow/contrib/factorization/python/ops/clustering_ops.py | 8
-rw-r--r-- tensorflow/contrib/factorization/python/ops/kmeans.py | 28
2 files changed, 9 insertions, 27 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
index e5c9180662..d7320aeb3d 100644
--- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
@@ -51,6 +51,9 @@ COSINE_DISTANCE = 'cosine'
RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
+# The name of the variable holding the cluster centers. Used by the Estimator.
+CLUSTERS_VAR_NAME = 'clusters'
+
class KMeans(object):
"""Creates the graph for k-means clustering."""
@@ -279,7 +282,7 @@ class KMeans(object):
"""
init_value = array_ops.constant([], dtype=dtypes.float32)
cluster_centers = variable_scope.variable(
- init_value, name='clusters', validate_shape=False)
+ init_value, name=CLUSTERS_VAR_NAME, validate_shape=False)
cluster_centers_initialized = variable_scope.variable(
False, dtype=dtypes.bool, name='initialized')
@@ -337,7 +340,6 @@ class KMeans(object):
assigned cluster instead.
cluster_centers_initialized: scalar indicating whether clusters have been
initialized.
- cluster_centers_var: a Variable holding the cluster centers.
init_op: an op to initialize the clusters.
training_op: an op that runs an iteration of training.
"""
@@ -381,7 +383,7 @@ class KMeans(object):
inputs, num_clusters, cluster_idx, cluster_centers_var)
return (all_scores, cluster_idx, scores, cluster_centers_initialized,
- cluster_centers_var, init_op, training_op)
+ init_op, training_op)
def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
cluster_centers_updated, total_counts):
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py
index 6284768bdd..9a5413fc3f 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans.py
@@ -21,12 +21,10 @@ from __future__ import division
from __future__ import print_function
import time
-import numpy as np
from tensorflow.contrib.factorization.python.ops import clustering_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -161,8 +159,7 @@ class _ModelFn(object):
* `eval_metric_ops`: Maps `SCORE` to `loss`.
* `predictions`: Maps `ALL_DISTANCES` to the distance from each input
point to each cluster center; maps `CLUSTER_INDEX` to the index of
- the closest cluster center for each input point; maps `CLUSTERS` to
- the cluster centers (which ignores the input points).
+ the closest cluster center for each input point.
"""
# input_points is a single Tensor. Therefore, the sharding functionality
# in clustering_ops is unused, and some of the values below are lists of a
@@ -184,8 +181,8 @@ class _ModelFn(object):
# training_op: an op that runs an iteration of training, either an entire
# Lloyd iteration or a mini-batch of a Lloyd iteration. Multiple workers
# may execute this op, but only after is_initialized becomes True.
- (all_distances, model_predictions, losses, is_initialized,
- cluster_centers_var, init_op, training_op) = clustering_ops.KMeans(
+ (all_distances, model_predictions, losses, is_initialized, init_op,
+ training_op) = clustering_ops.KMeans(
inputs=input_points,
num_clusters=self._num_clusters,
initial_clusters=self._initial_clusters,
@@ -215,7 +212,6 @@ class _ModelFn(object):
predictions={
KMeansClustering.ALL_DISTANCES: all_distances[0],
KMeansClustering.CLUSTER_INDEX: model_predictions[0],
- KMeansClustering.CLUSTERS: cluster_centers_var.value(),
},
loss=loss,
train_op=training_op,
@@ -242,9 +238,7 @@ class KMeansClustering(estimator.Estimator):
# Keys returned by predict().
# ALL_DISTANCES: The distance from each input point to each cluster center.
# CLUSTER_INDEX: The index of the closest cluster center for each input point.
- # CLUSTERS: The cluster centers (which ignores the input points).
CLUSTER_INDEX = 'cluster_index'
- CLUSTERS = 'clusters'
ALL_DISTANCES = 'all_distances'
def __init__(self,
@@ -400,18 +394,4 @@ class KMeansClustering(estimator.Estimator):
def cluster_centers(self):
"""Returns the cluster centers."""
-
- # TODO(ccolby): Fix this clunky code once cl/168262087 is submitted.
- # Discussion: go/estimator-get-variable-value
- class RunOnceHook(session_run_hook.SessionRunHook):
- """Stops after a single run."""
-
- def after_run(self, run_context, run_values):
- del run_values # unused
- run_context.request_stop()
-
- result = self.predict(
- input_fn=lambda: (constant_op.constant([], shape=[0, 1]), None),
- predict_keys=[KMeansClustering.CLUSTERS],
- hooks=[RunOnceHook()])
- return np.array([r[KMeansClustering.CLUSTERS] for r in result])
+ return self.get_variable_value(clustering_ops.CLUSTERS_VAR_NAME)