aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/factorization
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-21 15:20:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-21 15:25:20 -0800
commite528973d975ea5b0cb872397d276883d4f9bdd52 (patch)
tree263a3e82a141773561d77225a317237af4ddd14e /tensorflow/contrib/factorization
parent83b4195ec1ee8647f61b8b3b42d6fccd39e36dde (diff)
Only squeeze the last dimension of outputs and indices in infer_graph of
kmeans. PiperOrigin-RevId: 179865588
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r--tensorflow/contrib/factorization/python/ops/clustering_ops.py4
-rw-r--r--tensorflow/contrib/factorization/python/ops/kmeans_test.py21
2 files changed, 14 insertions, 11 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
index 96cc80ce24..6d3acb2750 100644
--- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
@@ -261,8 +261,8 @@ class KMeans(object):
inp, clusters, 1)
if self._distance_metric == COSINE_DISTANCE:
distances *= 0.5
- output.append((score, array_ops.squeeze(distances),
- array_ops.squeeze(indices)))
+ output.append((score, array_ops.squeeze(distances, [-1]),
+ array_ops.squeeze(indices, [-1])))
return zip(*output)
def _clusters_l2_normalized(self):
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
index 4709d79425..f9598bfc08 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
@@ -194,15 +194,7 @@ class KMeansTest(KMeansTestBase):
score = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points))
self.assertNear(self.true_score, score, self.true_score * 0.01)
- def test_infer(self):
- kmeans = self._kmeans()
- # Make a call to fit to initialize the cluster centers.
- max_steps = 1
- kmeans.train(input_fn=self.input_fn(), max_steps=max_steps)
- clusters = kmeans.cluster_centers()
-
- # Make a small test set
- num_points = 10
+ def _infer_helper(self, kmeans, clusters, num_points):
points, true_assignments, true_offsets = make_random_points(
clusters, num_points)
input_fn = self.input_fn(batch_size=num_points, points=points, num_epochs=1)
@@ -223,6 +215,17 @@ class KMeansTest(KMeansTestBase):
np.sum(np.square(clusters), axis=1, keepdims=True)))
self.assertAllClose(transform, true_transform, rtol=0.05, atol=10)
+ def test_infer(self):
+ kmeans = self._kmeans()
+ # Make a call to fit to initialize the cluster centers.
+ max_steps = 1
+ kmeans.train(input_fn=self.input_fn(), max_steps=max_steps)
+ clusters = kmeans.cluster_centers()
+
+ # Run inference on small datasets.
+ self._infer_helper(kmeans, clusters, 10)
+ self._infer_helper(kmeans, clusters, 1)
+
class KMeansTestMultiStageInit(KMeansTestBase):