diff options
Diffstat (limited to 'tensorflow/contrib/factorization/python/ops/gmm_test.py')
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/gmm_test.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py index 1452c90072..c951a6981f 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_test.py @@ -109,6 +109,16 @@ class GMMTest(test.TestCase): np.linalg.inv(covs[assignments[r]])), points[r, :] - means[assignments[r]]))) return (points, assignments, scores) + + def test_weights(self): + """Tests the shape of the weights.""" + gmm = gmm_lib.GMM(self.num_centers, + initial_clusters=self.initial_means, + random_seed=4, + config=run_config.RunConfig(tf_random_seed=2)) + gmm.fit(input_fn=self.input_fn(), steps=0) + weights = gmm.weights() + self.assertAllEqual(list(weights.shape), [self.num_centers]) def test_clusters(self): """Tests the shape of the clusters.""" |