diff options
author | 2018-03-14 14:34:15 -0700 | |
---|---|---|
committer | 2018-03-14 14:42:55 -0700 | |
commit | ac8ce1fe760efff6585d790b784ec67255198879 (patch) | |
tree | daef1c1f2ddb0a090c843eea47d1d51b3d801900 /tensorflow/contrib/factorization | |
parent | 61eab3f8c4ed8bdf4324e99508a104307483da2a (diff) |
Rename KMeans _parse_tensor_or_dict to _parse_features_if_necessary and add a unit test.
PiperOrigin-RevId: 189087384
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/kmeans.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/kmeans_test.py | 23 |
2 files changed, 24 insertions, 3 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index 7319eaa7de..c092f85d35 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -105,7 +105,7 @@ class _InitializeClustersHook(session_run_hook.SessionRunHook): logging.info(e) -def _parse_tensor_or_dict(features): +def _parse_features_if_necessary(features): """Helper function to convert the input points into a usable format. Args: @@ -166,7 +166,7 @@ class _ModelFn(object): # input_points is a single Tensor. Therefore, the sharding functionality # in clustering_ops is unused, and some of the values below are lists of a # single item. - input_points = _parse_tensor_or_dict(features) + input_points = _parse_features_if_necessary(features) # Let N = the number of input_points. # all_distances: A list of one matrix of shape (N, num_clusters). Each value diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py index f9598bfc08..06a2c52c11 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -226,6 +226,28 @@ class KMeansTest(KMeansTestBase): self._infer_helper(kmeans, clusters, 10) self._infer_helper(kmeans, clusters, 1) + def test_parse_features(self): + """Tests the various behaviours of kmeans._parse_features_if_necessary.""" + + # No-op if a tensor is passed in. + features = constant_op.constant(self.points) + parsed_features = kmeans_lib._parse_features_if_necessary(features) + self.assertAllEqual(features, parsed_features) + + # A dict is transformed into a tensor. + feature_dict = { + 'x': [[point[0]] for point in self.points], + 'y': [[point[1]] for point in self.points] + } + parsed_feature_dict = kmeans_lib._parse_features_if_necessary(feature_dict) + # Perform a sanity check. + self.assertEqual(features.shape, parsed_feature_dict.shape) + self.assertEqual(features.dtype, parsed_feature_dict.dtype) + # Then check that running the tensor yields the original list of points. + with self.test_session() as sess: + parsed_points = sess.run(parsed_feature_dict) + self.assertAllEqual(self.points, parsed_points) + class KMeansTestMultiStageInit(KMeansTestBase): @@ -394,7 +416,6 @@ class KMeansCosineDistanceTest(KMeansTestBase): true_assignments = [0] * 2 + [1] * 2 + [2] * 8 true_score = len(points) - np.tensordot( normalize(points), true_centers[true_assignments]) - kmeans = kmeans_lib.KMeansClustering( 3, initial_clusters=self.initial_clusters, |