aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/factorization/python/ops/gmm_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/factorization/python/ops/gmm_test.py')
-rw-r--r--tensorflow/contrib/factorization/python/ops/gmm_test.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py
index 1452c90072..c951a6981f 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm_test.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm_test.py
@@ -109,6 +109,16 @@ class GMMTest(test.TestCase):
np.linalg.inv(covs[assignments[r]])), points[r, :] -
means[assignments[r]])))
return (points, assignments, scores)
+
+ def test_weights(self):
+ """Tests the shape of the weights."""
+ gmm = gmm_lib.GMM(self.num_centers,
+ initial_clusters=self.initial_means,
+ random_seed=4,
+ config=run_config.RunConfig(tf_random_seed=2))
+ gmm.fit(input_fn=self.input_fn(), steps=0)
+ weights = gmm.weights()
+ self.assertAllEqual(list(weights.shape), [self.num_centers])
def test_clusters(self):
"""Tests the shape of the clusters."""