diff options
author | 2017-03-15 08:24:08 -0800 | |
---|---|---|
committer | 2017-03-15 09:44:49 -0700 | |
commit | e6126230200e2ce9c96da5c9e4dc7f104c645d11 (patch) | |
tree | 293e126fad9cf1e74c2b613d23dc4a956154f483 | |
parent | 496a6ac03200bb518b4a9ea74b7b24ef58cbf918 (diff) |
Use the TensorFlow log(sum(exp)) function, which is numerically stable.
Add some random test data that revealed a bug.
Change: 150200259
4 files changed, 29 insertions, 7 deletions
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 5f09851360..b329bb9595 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -114,6 +114,7 @@ tf_gen_op_wrapper_py( # Ops tests tf_py_test( name = "gmm_test", + size = "large", srcs = [ "python/ops/gmm_test.py", ], @@ -136,6 +137,7 @@ tf_py_test( tf_py_test( name = "gmm_ops_test", + size = "large", srcs = [ "python/ops/gmm_ops_test.py", ], diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py index 72d01fbb2a..396dd286b6 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm.py +++ b/tensorflow/contrib/factorization/python/ops/gmm.py @@ -102,12 +102,12 @@ class GMM(estimator.Estimator): results = self.evaluate(input_fn=input_fn, batch_size=batch_size, steps=steps) return np.sum(results[GMM.SCORES]) - + def weights(self): """Returns the cluster weights.""" return checkpoint_utils.load_variable( self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT) - + def clusters(self): """Returns cluster centers.""" clusters = checkpoint_utils.load_variable( diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index fbf7afc125..8d78067b9a 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -193,8 +193,8 @@ class GmmAlgorithm(object): # selected unobservable data (in EM terms) was generated by component k. self._alpha = variables.Variable( array_ops.tile([1.0 / self._num_classes], [self._num_classes]), - name=self.CLUSTERS_WEIGHT, - validate_shape=False) + name=self.CLUSTERS_WEIGHT, + validate_shape=False) def training_ops(self): """Returns the training operation.""" @@ -315,9 +315,8 @@ class GmmAlgorithm(object): Args: shard_id: id of current shard_id. 
""" - self._prior_probs[shard_id] = math_ops.log( - math_ops.reduce_sum( - math_ops.exp(self._probs[shard_id]), 1, keep_dims=True)) + self._prior_probs[shard_id] = math_ops.reduce_logsumexp( + self._probs[shard_id], axis=1, keep_dims=True) def _define_expectation_operation(self, shard_id): # Shape broadcasting. diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py index 889d162200..758c54fbf4 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_test.py @@ -201,6 +201,27 @@ class GMMTest(test.TestCase): def test_compare_diag(self): self._compare_with_sklearn('diag') + def test_random_input_large(self): + # sklearn version. + iterations = 5 # that should be enough to know whether this diverges + np.random.seed(5) + num_classes = 20 + x = np.array([[np.random.random() for _ in range(100)] + for _ in range(num_classes)], dtype=np.float32) + + # skflow version. + gmm = gmm_lib.GMM(num_classes, + covariance_type='full', + config=run_config.RunConfig(tf_random_seed=2)) + + def get_input_fn(x): + def input_fn(): + return constant_op.constant(x.astype(np.float32)), None + return input_fn + + gmm.fit(input_fn=get_input_fn(x), steps=iterations) + self.assertFalse(np.isnan(gmm.clusters()).any()) + if __name__ == '__main__': test.main() |