aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/factorization
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-14 14:34:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-14 14:42:55 -0700
commitac8ce1fe760efff6585d790b784ec67255198879 (patch)
treedaef1c1f2ddb0a090c843eea47d1d51b3d801900 /tensorflow/contrib/factorization
parent61eab3f8c4ed8bdf4324e99508a104307483da2a (diff)
Rename KMeans _parse_tensor_or_dict to _parse_features_if_necessary and add a unit test.
PiperOrigin-RevId: 189087384
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r--tensorflow/contrib/factorization/python/ops/kmeans.py4
-rw-r--r--tensorflow/contrib/factorization/python/ops/kmeans_test.py23
2 files changed, 24 insertions, 3 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py
index 7319eaa7de..c092f85d35 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans.py
@@ -105,7 +105,7 @@ class _InitializeClustersHook(session_run_hook.SessionRunHook):
logging.info(e)
-def _parse_tensor_or_dict(features):
+def _parse_features_if_necessary(features):
"""Helper function to convert the input points into a usable format.
Args:
@@ -166,7 +166,7 @@ class _ModelFn(object):
# input_points is a single Tensor. Therefore, the sharding functionality
# in clustering_ops is unused, and some of the values below are lists of a
# single item.
- input_points = _parse_tensor_or_dict(features)
+ input_points = _parse_features_if_necessary(features)
# Let N = the number of input_points.
# all_distances: A list of one matrix of shape (N, num_clusters). Each value
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
index f9598bfc08..06a2c52c11 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
@@ -226,6 +226,28 @@ class KMeansTest(KMeansTestBase):
self._infer_helper(kmeans, clusters, 10)
self._infer_helper(kmeans, clusters, 1)
+ def test_parse_features(self):
+ """Tests the various behaviours of kmeans._parse_features_if_necessary."""
+
+ # No-op if a tensor is passed in.
+ features = constant_op.constant(self.points)
+ parsed_features = kmeans_lib._parse_features_if_necessary(features)
+ self.assertAllEqual(features, parsed_features)
+
+ # A dict is transformed into a tensor.
+ feature_dict = {
+ 'x': [[point[0]] for point in self.points],
+ 'y': [[point[1]] for point in self.points]
+ }
+ parsed_feature_dict = kmeans_lib._parse_features_if_necessary(feature_dict)
+ # Perform a sanity check.
+ self.assertEqual(features.shape, parsed_feature_dict.shape)
+ self.assertEqual(features.dtype, parsed_feature_dict.dtype)
+ # Then check that running the tensor yields the original list of points.
+ with self.test_session() as sess:
+ parsed_points = sess.run(parsed_feature_dict)
+ self.assertAllEqual(self.points, parsed_points)
+
class KMeansTestMultiStageInit(KMeansTestBase):
@@ -394,7 +416,6 @@ class KMeansCosineDistanceTest(KMeansTestBase):
true_assignments = [0] * 2 + [1] * 2 + [2] * 8
true_score = len(points) - np.tensordot(
normalize(points), true_centers[true_assignments])
-
kmeans = kmeans_lib.KMeansClustering(
3,
initial_clusters=self.initial_clusters,