diff options
author | 2017-12-21 15:20:42 -0800 | |
---|---|---|
committer | 2017-12-21 15:25:20 -0800 | |
commit | e528973d975ea5b0cb872397d276883d4f9bdd52 (patch) | |
tree | 263a3e82a141773561d77225a317237af4ddd14e /tensorflow/contrib/factorization | |
parent | 83b4195ec1ee8647f61b8b3b42d6fccd39e36dde (diff) |
Only squeeze the last dimension of outputs and indices in infer_graph of
kmeans.
PiperOrigin-RevId: 179865588
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/clustering_ops.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/kmeans_test.py | 21 |
2 files changed, 14 insertions, 11 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index 96cc80ce24..6d3acb2750 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -261,8 +261,8 @@ class KMeans(object): inp, clusters, 1) if self._distance_metric == COSINE_DISTANCE: distances *= 0.5 - output.append((score, array_ops.squeeze(distances), - array_ops.squeeze(indices))) + output.append((score, array_ops.squeeze(distances, [-1]), + array_ops.squeeze(indices, [-1]))) return zip(*output) def _clusters_l2_normalized(self): diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py index 4709d79425..f9598bfc08 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -194,15 +194,7 @@ class KMeansTest(KMeansTestBase): score = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points)) self.assertNear(self.true_score, score, self.true_score * 0.01) - def test_infer(self): - kmeans = self._kmeans() - # Make a call to fit to initialize the cluster centers. - max_steps = 1 - kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) - clusters = kmeans.cluster_centers() - - # Make a small test set - num_points = 10 + def _infer_helper(self, kmeans, clusters, num_points): points, true_assignments, true_offsets = make_random_points( clusters, num_points) input_fn = self.input_fn(batch_size=num_points, points=points, num_epochs=1) @@ -223,6 +215,17 @@ class KMeansTest(KMeansTestBase): np.sum(np.square(clusters), axis=1, keepdims=True))) self.assertAllClose(transform, true_transform, rtol=0.05, atol=10) + def test_infer(self): + kmeans = self._kmeans() + # Make a call to fit to initialize the cluster centers. + max_steps = 1 + kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) + clusters = kmeans.cluster_centers() + + # Run inference on small datasets. + self._infer_helper(kmeans, clusters, 10) + self._infer_helper(kmeans, clusters, 1) + class KMeansTestMultiStageInit(KMeansTestBase): |