Diffstat (limited to 'tensorflow/contrib/factorization/python/ops/kmeans.py')
-rw-r--r--  tensorflow/contrib/factorization/python/ops/kmeans.py | 25 +++++++------------------
1 file changed, 7 insertions(+), 18 deletions(-)
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py
index cb5173ce2c..99227a0442 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans.py
@@ -55,12 +55,8 @@ class KMeansClustering(estimator.Estimator,
                distance_metric=clustering_ops.SQUARED_EUCLIDEAN_DISTANCE,
                random_seed=0,
                use_mini_batch=True,
-               batch_size=128,
-               steps=10,
                kmeans_plus_plus_num_retries=2,
-               continue_training=False,
-               config=None,
-               verbose=1):
+               config=None):
     """Creates a model for running KMeans training and inference.
 
     Args:
@@ -73,25 +69,17 @@ class KMeansClustering(estimator.Estimator,
       random_seed: Python integer. Seed for PRNG used to initialize centers.
       use_mini_batch: If true, use the mini-batch k-means algorithm. Else assume
         full batch.
-      batch_size: See TensorFlowEstimator
-      steps: See TensorFlowEstimator
       kmeans_plus_plus_num_retries: For each point that is sampled during
         kmeans++ initialization, this parameter specifies the number of
         additional points to draw from the current distribution before selecting
         the best. If a negative value is specified, a heuristic is used to
         sample O(log(num_to_sample)) additional points.
-      continue_training: See TensorFlowEstimator
-      config: See TensorFlowEstimator
-      verbose: See TensorFlowEstimator
+      config: See Estimator
     """
     super(KMeansClustering, self).__init__(
         model_dir=model_dir,
         config=config)
-    self.batch_size = batch_size
-    self.steps = steps
     self.kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
-    self.continue_training = continue_training
-    self.verbose = verbose
     self._num_clusters = num_clusters
     self._training_initial_clusters = initial_clusters
     self._training_graph = None
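After this change the constructor no longer accepts batch_size, steps, continue_training, or verbose. A minimal sketch of a post-change constructor call, assuming the contrib module paths shown in this diff (argument values are illustrative, not part of the commit):

# Sketch only: module aliases and argument values are assumptions.
from tensorflow.contrib.factorization.python.ops import clustering_ops
from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_lib

kmeans = kmeans_lib.KMeansClustering(
    num_clusters=10,
    distance_metric=clustering_ops.SQUARED_EUCLIDEAN_DISTANCE,
    use_mini_batch=True,
    kmeans_plus_plus_num_retries=2,  # extra kmeans++ candidates per sampled point
    config=None)  # batch_size/steps/continue_training/verbose would now raise TypeError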
@@ -135,11 +123,11 @@ class KMeansClustering(estimator.Estimator,
       return relative_change < self._tolerance
       # pylint: enable=protected-access
 
-  def fit(self, x, y=None, monitors=None, logdir=None, steps=None,
+  def fit(self, x, y=None, monitors=None, logdir=None, steps=None, batch_size=128,
           relative_tolerance=None):
     """Trains a k-means clustering on x.
 
-    Note: See TensorFlowEstimator for logic for continuous training and graph
+    Note: See Estimator for logic for continuous training and graph
     construction across multiple calls to fit.
 
     Args:
@@ -151,6 +139,7 @@ class KMeansClustering(estimator.Estimator,
         visualization.
       steps: number of training steps. If not None, overrides the value passed
         in constructor.
+      batch_size: mini-batch size to use. Requires `use_mini_batch=True`.
       relative_tolerance: A relative tolerance of change in the loss between
         iterations. Stops learning if the loss changes less than this amount.
         Note that this may not work correctly if use_mini_batch=True.
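With batch_size now a fit() argument, a training call after this change might look like the sketch below, reusing the kmeans object from the sketch above (the numpy input and counts are illustrative assumptions):

# Sketch only: input data and step count are assumptions.
import numpy as np

x = np.random.rand(1000, 4).astype(np.float32)  # 1000 points in 4 dimensions
kmeans.fit(x, steps=10, batch_size=128)  # batch_size only takes effect with use_mini_batch=True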
@@ -162,7 +151,7 @@ class KMeansClustering(estimator.Estimator,
     if logdir is not None:
       self._model_dir = logdir
     self._data_feeder = data_feeder.setup_train_data_feeder(
-        x, None, self._num_clusters, self.batch_size)
+        x, None, self._num_clusters, batch_size if self._use_mini_batch else None)
     if relative_tolerance is not None:
       if monitors is not None:
         monitors += [self._StopWhenConverged(relative_tolerance)]
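The replaced feeder line passes the batch size through only when the estimator was constructed with use_mini_batch=True; otherwise the feeder receives None. A hedged contrast of the two modes, reusing kmeans_lib and x from the sketches above:

# Sketch only: illustrates the conditional in the replaced line above.
mini = kmeans_lib.KMeansClustering(num_clusters=10, use_mini_batch=True)
mini.fit(x, steps=10, batch_size=128)   # feeder is built with batch size 128

full = kmeans_lib.KMeansClustering(num_clusters=10, use_mini_batch=False)
full.fit(x, steps=10, batch_size=128)   # feeder gets None: full-batch updates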
@@ -173,7 +162,7 @@ class KMeansClustering(estimator.Estimator,
             or (self.steps is not None))
     self._train_model(input_fn=self._data_feeder.input_builder,
                       feed_fn=self._data_feeder.get_feed_dict_fn(),
-                      steps=steps or self.steps,
+                      steps=steps,
                       monitors=monitors,
                       init_feed_fn=self._data_feeder.get_feed_dict_fn())
     return self
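With `steps=steps or self.steps` reduced to `steps=steps`, the step budget comes solely from the fit() call; there is no constructor-level fallback anymore. A short sketch, reusing the objects above:

# Sketch only: steps must now be supplied on each fit() call, or a
# convergence monitor used; self.steps no longer exists as a fallback.
kmeans.fit(x, steps=10)                  # explicit step budget
kmeans.fit(x, relative_tolerance=1e-4)   # or stop when the loss converges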