diff options
Diffstat (limited to 'tensorflow/contrib/factorization/python/ops/kmeans.py')
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/kmeans.py | 25 |
1 files changed, 7 insertions, 18 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index cb5173ce2c..99227a0442 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -55,12 +55,8 @@ class KMeansClustering(estimator.Estimator, distance_metric=clustering_ops.SQUARED_EUCLIDEAN_DISTANCE, random_seed=0, use_mini_batch=True, - batch_size=128, - steps=10, kmeans_plus_plus_num_retries=2, - continue_training=False, - config=None, - verbose=1): + config=None): """Creates a model for running KMeans training and inference. Args: @@ -73,25 +69,17 @@ class KMeansClustering(estimator.Estimator, random_seed: Python integer. Seed for PRNG used to initialize centers. use_mini_batch: If true, use the mini-batch k-means algorithm. Else assume full batch. - batch_size: See TensorFlowEstimator - steps: See TensorFlowEstimator kmeans_plus_plus_num_retries: For each point that is sampled during kmeans++ initialization, this parameter specifies the number of additional points to draw from the current distribution before selecting the best. If a negative value is specified, a heuristic is used to sample O(log(num_to_sample)) additional points. - continue_training: See TensorFlowEstimator - config: See TensorFlowEstimator - verbose: See TensorFlowEstimator + config: See Estimator """ super(KMeansClustering, self).__init__( model_dir=model_dir, config=config) - self.batch_size = batch_size - self.steps = steps self.kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries - self.continue_training = continue_training - self.verbose = verbose self._num_clusters = num_clusters self._training_initial_clusters = initial_clusters self._training_graph = None @@ -135,11 +123,11 @@ class KMeansClustering(estimator.Estimator, return relative_change < self._tolerance # pylint: enable=protected-access - def fit(self, x, y=None, monitors=None, logdir=None, steps=None, + def fit(self, x, y=None, monitors=None, logdir=None, steps=None, batch_size=128, relative_tolerance=None): """Trains a k-means clustering on x. - Note: See TensorFlowEstimator for logic for continuous training and graph + Note: See Estimator for logic for continuous training and graph construction across multiple calls to fit. Args: @@ -151,6 +139,7 @@ class KMeansClustering(estimator.Estimator, visualization. steps: number of training steps. If not None, overrides the value passed in constructor. + batch_size: mini-batch size to use. Requires `use_mini_batch=True`. relative_tolerance: A relative tolerance of change in the loss between iterations. Stops learning if the loss changes less than this amount. Note that this may not work correctly if use_mini_batch=True. @@ -162,7 +151,7 @@ class KMeansClustering(estimator.Estimator, if logdir is not None: self._model_dir = logdir self._data_feeder = data_feeder.setup_train_data_feeder( - x, None, self._num_clusters, self.batch_size) + x, None, self._num_clusters, batch_size if self._use_mini_batch else None) if relative_tolerance is not None: if monitors is not None: monitors += [self._StopWhenConverged(relative_tolerance)] @@ -173,7 +162,7 @@ class KMeansClustering(estimator.Estimator, or (self.steps is not None)) self._train_model(input_fn=self._data_feeder.input_builder, feed_fn=self._data_feeder.get_feed_dict_fn(), - steps=steps or self.steps, + steps=steps, monitors=monitors, init_feed_fn=self._data_feeder.get_feed_dict_fn()) return self |