diff options
Diffstat (limited to 'tensorflow/contrib/factorization/python/ops/gmm.py')
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/gmm.py | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py index eddce45c88..72d01fbb2a 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm.py +++ b/tensorflow/contrib/factorization/python/ops/gmm.py @@ -102,7 +102,12 @@ class GMM(estimator.Estimator): results = self.evaluate(input_fn=input_fn, batch_size=batch_size, steps=steps) return np.sum(results[GMM.SCORES]) - + + def weights(self): + """Returns the cluster weights.""" + return checkpoint_utils.load_variable( + self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT) + def clusters(self): """Returns cluster centers.""" clusters = checkpoint_utils.load_variable( |