about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-03-15 08:24:08 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-03-15 09:44:49 -0700
commit e6126230200e2ce9c96da5c9e4dc7f104c645d11 (patch)
tree 293e126fad9cf1e74c2b613d23dc4a956154f483
parent 496a6ac03200bb518b4a9ea74b7b24ef58cbf918 (diff)
Use TensorFlow's numerically stable log(sum(exp)) function (math_ops.reduce_logsumexp) instead of the hand-written log(reduce_sum(exp(...))).
Add a test on random input data that exposed a bug.
Change: 150200259
-rw-r--r-- tensorflow/contrib/factorization/BUILD 2
-rw-r--r-- tensorflow/contrib/factorization/python/ops/gmm.py 4
-rw-r--r-- tensorflow/contrib/factorization/python/ops/gmm_ops.py 9
-rw-r--r-- tensorflow/contrib/factorization/python/ops/gmm_test.py 21
4 files changed, 29 insertions, 7 deletions
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index 5f09851360..b329bb9595 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -114,6 +114,7 @@ tf_gen_op_wrapper_py(
# Ops tests
tf_py_test(
name = "gmm_test",
+ size = "large",
srcs = [
"python/ops/gmm_test.py",
],
@@ -136,6 +137,7 @@ tf_py_test(
tf_py_test(
name = "gmm_ops_test",
+ size = "large",
srcs = [
"python/ops/gmm_ops_test.py",
],
diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py
index 72d01fbb2a..396dd286b6 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm.py
@@ -102,12 +102,12 @@ class GMM(estimator.Estimator):
results = self.evaluate(input_fn=input_fn, batch_size=batch_size,
steps=steps)
return np.sum(results[GMM.SCORES])
-
+
def weights(self):
"""Returns the cluster weights."""
return checkpoint_utils.load_variable(
self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT)
-
+
def clusters(self):
"""Returns cluster centers."""
clusters = checkpoint_utils.load_variable(
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py
index fbf7afc125..8d78067b9a 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py
@@ -193,8 +193,8 @@ class GmmAlgorithm(object):
# selected unobservable data (in EM terms) was generated by component k.
self._alpha = variables.Variable(
array_ops.tile([1.0 / self._num_classes], [self._num_classes]),
- name=self.CLUSTERS_WEIGHT,
- validate_shape=False)
+ name=self.CLUSTERS_WEIGHT,
+ validate_shape=False)
def training_ops(self):
"""Returns the training operation."""
@@ -315,9 +315,8 @@ class GmmAlgorithm(object):
Args:
shard_id: id of current shard_id.
"""
- self._prior_probs[shard_id] = math_ops.log(
- math_ops.reduce_sum(
- math_ops.exp(self._probs[shard_id]), 1, keep_dims=True))
+ self._prior_probs[shard_id] = math_ops.reduce_logsumexp(
+ self._probs[shard_id], axis=1, keep_dims=True)
def _define_expectation_operation(self, shard_id):
# Shape broadcasting.
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py
index 889d162200..758c54fbf4 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm_test.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm_test.py
@@ -201,6 +201,27 @@ class GMMTest(test.TestCase):
def test_compare_diag(self):
self._compare_with_sklearn('diag')
+ def test_random_input_large(self):
+ # Set up seeded random input data (no sklearn reference model in this test).
+ iterations = 5 # that should be enough to know whether this diverges
+ np.random.seed(5)
+ num_classes = 20
+ x = np.array([[np.random.random() for _ in range(100)]
+ for _ in range(num_classes)], dtype=np.float32)
+
+ # TensorFlow GMM estimator under test (full covariance).
+ gmm = gmm_lib.GMM(num_classes,
+ covariance_type='full',
+ config=run_config.RunConfig(tf_random_seed=2))
+
+ def get_input_fn(x):
+ def input_fn():
+ return constant_op.constant(x.astype(np.float32)), None
+ return input_fn
+
+ gmm.fit(input_fn=get_input_fn(x), steps=iterations)
+ self.assertFalse(np.isnan(gmm.clusters()).any())
+
if __name__ == '__main__':
test.main()