about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2016-12-13 21:35:34 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-12-13 21:45:24 -0800
commit 2a6bd09d05881d3160885a386b3ac1fb7cf6a6e1 (patch)
tree 18407e88e37f2fd6dcdfaa197f0c3806556c5f55
parent 6a91496ba80995c8aaeddece71738ae209ae3653 (diff)
Refactor KMeansClustering estimator from inheritance to composition and update users to use input_fn instead of x/batch_size.
Change: 141978044
-rw-r--r-- tensorflow/contrib/factorization/python/ops/gmm_test.py | 2
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/kmeans.py | 77
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py | 19
3 files changed, 27 insertions, 71 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py
index f95e17e474..d543cf5129 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm_test.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm_test.py
@@ -45,7 +45,7 @@ class GMMTest(tf.test.TestCase):
# Use initial means from kmeans (just like scikit-learn does).
clusterer = tf.contrib.learn.KMeansClustering(
num_clusters=self.num_centers)
- clusterer.fit(self.points, steps=30)
+ clusterer.fit(input_fn=lambda: (tf.constant(self.points), None), steps=30)
self.initial_means = clusterer.clusters()
@staticmethod
diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py
index 7533173980..5d5c5985dc 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py
@@ -30,7 +30,6 @@ from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModelFnOps
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.control_flow_ops import with_dependencies
-from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training.session_run_hook import SessionRunArgs
@@ -83,8 +82,8 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
self._random_seed = random_seed
self._use_mini_batch = use_mini_batch
self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
- self._estimator = InitializingEstimator(model_fn=self._get_model_function(),
- model_dir=model_dir)
+ self._estimator = estimator.Estimator(model_fn=self._get_model_function(),
+ model_dir=model_dir)
class LossRelativeChangeHook(session_run_hook.SessionRunHook):
"""Stops when the change in loss goes below a tolerance."""
@@ -123,19 +122,16 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
"""See Evaluable."""
return self._estimator.model_dir
- def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
- monitors=None, max_steps=None, relative_tolerance=None):
+ def fit(self, input_fn=None, steps=None, monitors=None, max_steps=None,
+ relative_tolerance=None):
"""Trains a k-means clustering on x.
Note: See Estimator for logic for continuous training and graph
construction across multiple calls to fit.
Args:
- x: see Trainable.fit.
- y: labels. Should be None.
input_fn: see Trainable.fit.
steps: see Trainable.fit.
- batch_size: see Trainable.fit.
monitors: see Trainable.fit.
max_steps: see Trainable.fit.
relative_tolerance: A relative tolerance of change in the loss between
@@ -145,7 +141,6 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
Returns:
Returns self.
"""
- assert y is None
if relative_tolerance is not None:
if monitors is None:
monitors = []
@@ -153,32 +148,24 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
# Make sure that we will eventually terminate.
assert ((monitors is not None and len(monitors)) or (steps is not None)
or (max_steps is not None))
- if not self._use_mini_batch:
- assert batch_size is None
- self._estimator.fit(input_fn=input_fn, x=x, y=y, batch_size=batch_size,
- steps=steps, max_steps=max_steps, monitors=monitors)
+ self._estimator.fit(input_fn=input_fn, steps=steps, max_steps=max_steps,
+ monitors=monitors)
return self
- def evaluate(self, x=None, y=None, input_fn=None, feed_fn=None,
- batch_size=None, steps=None, metrics=None, name=None,
- checkpoint_path=None):
+ def evaluate(self, input_fn=None, feed_fn=None, steps=None, metrics=None,
+ name=None, checkpoint_path=None):
"""See Evaluable.evaluate."""
- assert y is None
- return self._estimator.evaluate(input_fn=input_fn, x=x, y=y,
- feed_fn=feed_fn, batch_size=batch_size,
+ return self._estimator.evaluate(input_fn=input_fn, feed_fn=feed_fn,
steps=steps, metrics=metrics, name=name,
checkpoint_path=checkpoint_path)
- def predict(self, x=None, input_fn=None, batch_size=None, outputs=None,
- as_iterable=False):
+ def predict(self, input_fn=None, outputs=None, as_iterable=False):
"""See BaseEstimator.predict."""
outputs = outputs or [KMeansClustering.CLUSTER_IDX]
assert isinstance(outputs, list)
- results = self._estimator.predict(x=x,
- input_fn=input_fn,
- batch_size=batch_size,
+ results = self._estimator.predict(input_fn=input_fn,
outputs=outputs,
as_iterable=as_iterable)
if len(outputs) == 1 and not as_iterable:
@@ -186,27 +173,24 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
else:
return results
- def score(self, x=None, input_fn=None, batch_size=None, steps=None):
+ def score(self, input_fn=None, steps=None):
"""Predict total sum of distances to nearest clusters.
Note that this function is different from the corresponding one in sklearn
which returns the negative of the sum of distances.
Args:
- x: see predict.
input_fn: see predict.
- batch_size: see predict.
steps: see predict.
Returns:
Total sum of distances to nearest clusters.
"""
- return np.sum(self.evaluate(x=x, input_fn=input_fn, batch_size=batch_size,
+ return np.sum(self.evaluate(input_fn=input_fn,
steps=steps)[KMeansClustering.SCORES])
- def transform(self, x=None, input_fn=None, batch_size=None,
- as_iterable=False):
- """Transforms each element in x to distances to cluster centers.
+ def transform(self, input_fn=None, as_iterable=False):
+ """Transforms each element to distances to cluster centers.
Note that this function is different from the corresponding one in sklearn.
For SQUARED_EUCLIDEAN distance metric, sklearn transform returns the
@@ -214,16 +198,14 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
distance.
Args:
- x: see predict.
input_fn: see predict.
- batch_size: see predict.
as_iterable: see predict
Returns:
Array with same number of rows as x, and num_clusters columns, containing
distances to the cluster centers.
"""
- return self.predict(x=x, input_fn=input_fn, batch_size=batch_size,
+ return self.predict(input_fn=input_fn,
outputs=[KMeansClustering.ALL_SCORES],
as_iterable=as_iterable)
@@ -268,30 +250,3 @@ class KMeansClustering(evaluable.Evaluable, trainable.Trainable):
eval_metric_ops=eval_metric_ops,
loss=loss, train_op=training_op)
return _model_fn
-
-
-# TODO(agarwal): Push the initialization logic inside the KMeans graph itself
-# and avoid having this custom Estimator.
-class InitializingEstimator(estimator.Estimator):
- """Estimator subclass that allows looking at inputs during initialization."""
-
- def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
- monitors=None, max_steps=None):
- """See Trainable.fit."""
-
- if (steps is not None) and (max_steps is not None):
- raise ValueError('Can not provide both steps and max_steps.')
-
- input_fn, feed_fn = estimator._get_input_fn( # pylint: disable=protected-access
- x, y, input_fn, feed_fn=None,
- batch_size=batch_size, shuffle=True,
- epochs=None)
- loss = self._train_model(input_fn=input_fn,
- feed_fn=feed_fn,
- init_feed_fn=feed_fn,
- steps=steps,
- monitors=monitors,
- max_steps=max_steps)
- logging.info('Loss for final step: %s.', loss)
- return self
-
diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py
index d3c9d3b110..60f2d49ceb 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py
@@ -58,6 +58,7 @@ def make_random_points(centers, num_points, max_offset=20):
class KMeansTestBase(tf.test.TestCase):
def input_fn(self, batch_size=None, points=None):
+ """Returns an input_fn that randomly selects batches from given points."""
batch_size = batch_size or self.batch_size
points = points if points is not None else self.points
num_points = points.shape[0]
@@ -156,11 +157,11 @@ class KMeansTest(KMeansTestBase):
self.assertAllEqual(assignments, true_assignments)
# Test score
- score = kmeans.score(points, batch_size=128)
+ score = kmeans.score(input_fn=lambda: (tf.constant(points), None), steps=1)
self.assertNear(score, np.sum(true_offsets), 0.01 * score)
# Test transform
- transform = kmeans.transform(points, batch_size=128)
+ transform = kmeans.transform(input_fn=lambda: (tf.constant(points), None))
true_transform = np.maximum(
0,
np.sum(np.square(points), axis=1, keepdims=True) -
@@ -263,20 +264,18 @@ class KMeansTestCosineDistance(KMeansTestBase):
distance_metric=tf.contrib.factorization.COSINE_DISTANCE,
use_mini_batch=self.use_mini_batch,
config=self.config(3))
- batch_size = 12 if self.use_mini_batch else None
- kmeans.fit(input_fn=lambda: (tf.constant(points), None), steps=30,
- batch_size=batch_size)
+ kmeans.fit(input_fn=lambda: (tf.constant(points), None), steps=30)
centers = normalize(kmeans.clusters())
self.assertAllClose(sorted(centers.tolist()),
sorted(true_centers.tolist()),
atol=1e-2)
- assignments = kmeans.predict(points, batch_size=12)
+ assignments = kmeans.predict(input_fn=lambda: (tf.constant(points), None))
self.assertAllClose(centers[assignments],
true_centers[true_assignments], atol=1e-2)
- score = kmeans.score(points, batch_size=12)
+ score = kmeans.score(input_fn=lambda: (tf.constant(points), None), steps=1)
self.assertAllClose(score, true_score, atol=1e-2)
@@ -354,10 +353,12 @@ class TensorflowKMeansBenchmark(KMeansBenchmark):
kmeans_plus_plus_num_retries=int(math.log(self.num_clusters) + 2),
random_seed=i * 42,
config=tf.contrib.learn.RunConfig(tf_random_seed=3))
- tf_kmeans.fit(x=self.points, batch_size=self.num_points, steps=50,
+ tf_kmeans.fit(input_fn=lambda: (tf.constant(self.points), None),
+ steps=50,
relative_tolerance=1e-6)
_ = tf_kmeans.clusters()
- scores.append(tf_kmeans.score(self.points))
+ scores.append(tf_kmeans.score(
+ input_fn=lambda: (tf.constant(self.points), None), steps=1))
self._report(num_iters, start, time.time(), scores)