aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/factorization/python/ops/gmm.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/factorization/python/ops/gmm.py')
-rw-r--r--tensorflow/contrib/factorization/python/ops/gmm.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py
index eddce45c88..72d01fbb2a 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm.py
@@ -102,7 +102,12 @@ class GMM(estimator.Estimator):
results = self.evaluate(input_fn=input_fn, batch_size=batch_size,
steps=steps)
return np.sum(results[GMM.SCORES])
-
+
+ def weights(self):
+ """Returns the cluster weights."""
+ return checkpoint_utils.load_variable(
+ self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT)
+
def clusters(self):
"""Returns cluster centers."""
clusters = checkpoint_utils.load_variable(