diff options
Diffstat (limited to 'tensorflow/contrib/factorization/python/ops/gmm_ops.py')
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/gmm_ops.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index e795c0aac7..fbf7afc125 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -92,6 +92,7 @@ def _init_clusters_random(data, num_clusters, random_seed): class GmmAlgorithm(object): """Tensorflow Gaussian mixture model clustering class.""" + CLUSTERS_WEIGHT = 'alphas' CLUSTERS_VARIABLE = 'clusters' CLUSTERS_COVS_VARIABLE = 'clusters_covs' @@ -187,11 +188,13 @@ class GmmAlgorithm(object): array_ops.expand_dims(array_ops.diag_part(cov), 0), [self._num_classes, 1]) self._covs = variables.Variable( - covs, name='clusters_covs', validate_shape=False) + covs, name=self.CLUSTERS_COVS_VARIABLE, validate_shape=False) # Mixture weights, representing the probability that a randomly # selected unobservable data (in EM terms) was generated by component k. self._alpha = variables.Variable( - array_ops.tile([1.0 / self._num_classes], [self._num_classes])) + array_ops.tile([1.0 / self._num_classes], [self._num_classes]), + name=self.CLUSTERS_WEIGHT, + validate_shape=False) def training_ops(self): """Returns the training operation.""" |