author     A. Unique TensorFlower <gardener@tensorflow.org>  2016-07-01 17:05:28 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-07-01 18:18:15 -0700
commit     23fdab705d7caed3ff3d955b39c3cfc3f5e40678 (patch)
tree       7cd484755b9bc5005ef5d9b3e80726663bdf1bdb
parent     281cd0ae22f05275fcfe1fd8176ebd3769f80043 (diff)
Add K-Means clustering and WALS matrix factorization to tensorflow.
Change: 126465430
-rw-r--r--  tensorflow/BUILD  2
-rw-r--r--  tensorflow/contrib/BUILD  3
-rw-r--r--  tensorflow/contrib/__init__.py  1
-rw-r--r--  tensorflow/contrib/factorization/BUILD  135
-rw-r--r--  tensorflow/contrib/factorization/__init__.py  23
-rw-r--r--  tensorflow/contrib/factorization/examples/BUILD  22
-rw-r--r--  tensorflow/contrib/factorization/examples/mnist.py  292
-rw-r--r--  tensorflow/contrib/factorization/kernels/BUILD  67
-rw-r--r--  tensorflow/contrib/factorization/kernels/clustering_ops.cc  522
-rw-r--r--  tensorflow/contrib/factorization/kernels/clustering_ops_test.cc  176
-rw-r--r--  tensorflow/contrib/factorization/kernels/wals_solver_ops.cc  271
-rw-r--r--  tensorflow/contrib/factorization/ops/clustering_ops.cc  69
-rw-r--r--  tensorflow/contrib/factorization/ops/factorization_ops.cc  46
-rw-r--r--  tensorflow/contrib/factorization/python/__init__.py  20
-rw-r--r--  tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py  157
-rw-r--r--  tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py  85
-rw-r--r--  tensorflow/contrib/factorization/python/ops/clustering_ops.py  409
-rw-r--r--  tensorflow/contrib/factorization/python/ops/factorization_ops.py  467
-rw-r--r--  tensorflow/contrib/factorization/python/ops/factorization_ops_test.py  456
-rw-r--r--  tensorflow/contrib/factorization/python/ops/kmeans.py  238
-rw-r--r--  tensorflow/contrib/factorization/python/ops/kmeans_test.py  283
21 files changed, 3744 insertions, 0 deletions
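
A minimal usage sketch of the new clustering API (plain Python, mirroring the mnist.py example added in this change; only the constructor arguments exercised by that example are shown):

import tensorflow as tf
from tensorflow.contrib.factorization.python.ops import clustering_ops

points = tf.placeholder(tf.float32, shape=(None, 28 * 28))
kmeans = clustering_ops.KMeans(
    points,
    384,  # number of clusters
    distance_metric=clustering_ops.COSINE_DISTANCE,
    use_mini_batch=True)
# training_graph() returns per-cluster scores, a second value (unused here),
# per-example clustering scores, and an op that updates the cluster centers.
all_scores, _, clustering_scores, kmeans_training_op = kmeans.training_graph()
clustering_loss = tf.reduce_sum(clustering_scores[0])
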
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 763ae340dd..68f14676ac 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -67,6 +67,8 @@ filegroup(
"//tensorflow/contrib:all_files",
"//tensorflow/contrib/copy_graph:all_files",
"//tensorflow/contrib/distributions:all_files",
+ "//tensorflow/contrib/factorization:all_files",
+ "//tensorflow/contrib/factorization/kernels:all_files",
"//tensorflow/contrib/ffmpeg:all_files",
"//tensorflow/contrib/ffmpeg/default:all_files",
"//tensorflow/contrib/framework:all_files",
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 2702ddde01..8b3761ae0f 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -16,6 +16,7 @@ py_library(
"//tensorflow/contrib/bayesflow:bayesflow_py",
"//tensorflow/contrib/copy_graph:copy_graph_py",
"//tensorflow/contrib/distributions:distributions_py",
+ "//tensorflow/contrib/factorization:factorization_py",
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/graph_editor:graph_editor_py",
@@ -40,6 +41,7 @@ cc_library(
name = "contrib_kernels",
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/contrib/factorization/kernels:all_kernels",
"//tensorflow/contrib/layers:bucketization_op_kernel",
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
"//tensorflow/contrib/linear_optimizer:sdca_op_kernels",
@@ -51,6 +53,7 @@ cc_library(
name = "contrib_ops_op_lib",
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/contrib/factorization:all_ops",
"//tensorflow/contrib/layers:bucketization_op_op_lib",
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
"//tensorflow/contrib/linear_optimizer:sdca_ops_op_lib",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index e71adbb5bc..6ec786430f 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.contrib import bayesflow
from tensorflow.contrib import copy_graph
from tensorflow.contrib import distributions
+from tensorflow.contrib import factorization
from tensorflow.contrib import framework
from tensorflow.contrib import graph_editor
from tensorflow.contrib import grid_rnn
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
new file mode 100644
index 0000000000..4be7a966c6
--- /dev/null
+++ b/tensorflow/contrib/factorization/BUILD
@@ -0,0 +1,135 @@
+# Description:
+# Contains ops for factorization of data, including matrix factorization and
+# clustering.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+py_library(
+ name = "factorization_py",
+ srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+ data = [
+ ":python/ops/_clustering_ops.so",
+ ":python/ops/_factorization_ops.so",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_clustering_ops",
+ ":gen_factorization_ops",
+ ],
+)
+
+# Ops
+tf_custom_op_library(
+ name = "python/ops/_clustering_ops.so",
+ srcs = [
+ "ops/clustering_ops.cc",
+ ],
+ deps = [
+ "//tensorflow/contrib/factorization/kernels:clustering_ops",
+ ],
+)
+
+tf_custom_op_library(
+ name = "python/ops/_factorization_ops.so",
+ srcs = [
+ "ops/factorization_ops.cc",
+ ],
+ deps = [
+ "//tensorflow/contrib/factorization/kernels:wals_solver_ops",
+ ],
+)
+
+tf_gen_op_libs([
+ "clustering_ops",
+ "factorization_ops",
+])
+
+cc_library(
+ name = "all_ops",
+ deps = [
+ ":clustering_ops_op_lib",
+ ":factorization_ops_op_lib",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_clustering_ops",
+ out = "python/ops/gen_clustering_ops.py",
+ deps = [
+ ":clustering_ops_op_lib",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_factorization_ops",
+ out = "python/ops/gen_factorization_ops.py",
+ deps = [
+ ":factorization_ops_op_lib",
+ ],
+)
+
+# Ops tests
+tf_py_test(
+ name = "kmeans_test",
+ srcs = [
+ "python/ops/kmeans_test.py",
+ ],
+ additional_deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_py_test(
+ name = "factorization_ops_test",
+ srcs = ["python/ops/factorization_ops_test.py"],
+ additional_deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+# Kernel tests
+tf_py_test(
+ name = "wals_solver_ops_test",
+ srcs = ["python/kernel_tests/wals_solver_ops_test.py"],
+ additional_deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_py_test(
+ name = "clustering_ops_test",
+ srcs = ["python/kernel_tests/clustering_ops_test.py"],
+ additional_deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+# All files
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/factorization/__init__.py b/tensorflow/contrib/factorization/__init__.py
new file mode 100644
index 0000000000..101de6e7c6
--- /dev/null
+++ b/tensorflow/contrib/factorization/__init__.py
@@ -0,0 +1,23 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops and modules related to factorization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+from tensorflow.contrib.factorization.python.ops import *
diff --git a/tensorflow/contrib/factorization/examples/BUILD b/tensorflow/contrib/factorization/examples/BUILD
new file mode 100644
index 0000000000..bf5b829e32
--- /dev/null
+++ b/tensorflow/contrib/factorization/examples/BUILD
@@ -0,0 +1,22 @@
+# Example TensorFlow models using factorization ops.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+tf_py_test(
+ name = "mnist",
+ size = "medium",
+ srcs = [
+ "mnist.py",
+ ],
+ additional_deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/examples/tutorials/mnist",
+ "//tensorflow/examples/tutorials/mnist:input_data",
+ ],
+)
diff --git a/tensorflow/contrib/factorization/examples/mnist.py b/tensorflow/contrib/factorization/examples/mnist.py
new file mode 100644
index 0000000000..f5f4c23502
--- /dev/null
+++ b/tensorflow/contrib/factorization/examples/mnist.py
@@ -0,0 +1,292 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Example mnist model with jointly computed k-means clustering.
+
+This is a toy example of how clustering can be embedded into larger tensorflow
+graphs. In this case, we learn a clustering on-the-fly and transform the input
+into the 'distance to clusters' space. These are then fed into hidden layers to
+learn the supervised objective.
+
+To train this model on real mnist data, run this model as follows:
+ mnist --nofake_data --max_steps=2000
+"""
+
+# pylint: disable=missing-docstring
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import tempfile
+import time
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+from tensorflow.contrib.factorization.python.ops import clustering_ops
+from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.examples.tutorials.mnist import mnist
+
+# Basic model parameters as external flags.
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+flags.DEFINE_float('learning_rate', 0.3, 'Initial learning rate.')
+flags.DEFINE_integer('max_steps', 200, 'Number of steps to run trainer.')
+flags.DEFINE_integer('num_clusters', 384, 'Number of input feature clusters')
+flags.DEFINE_integer('hidden1', 256, 'Number of units in hidden layer 1.')
+flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
+flags.DEFINE_integer('batch_size', 100, 'Batch size. '
+ 'Must divide evenly into the dataset sizes.')
+flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
+flags.DEFINE_bool('fake_data', True, 'Use fake input data.')
+
+# The MNIST dataset has 10 classes, representing the digits 0 through 9.
+NUM_CLASSES = 10
+
+# The MNIST images are always 28x28 pixels.
+IMAGE_SIZE = 28
+IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
+
+
+def placeholder_inputs():
+ """Generate placeholder variables to represent the input tensors.
+
+ Returns:
+ images_placeholder: Images placeholder.
+ labels_placeholder: Labels placeholder.
+ """
+ images_placeholder = tf.placeholder(tf.float32, shape=(None,
+ mnist.IMAGE_PIXELS))
+ labels_placeholder = tf.placeholder(tf.int32, shape=(None,))
+ return images_placeholder, labels_placeholder
+
+
+def fill_feed_dict(data_set, images_pl, labels_pl, batch_size):
+ """Fills the feed_dict for training the given step.
+
+ Args:
+ data_set: The set of images and labels, from input_data.read_data_sets()
+ images_pl: The images placeholder, from placeholder_inputs().
+ labels_pl: The labels placeholder, from placeholder_inputs().
+ batch_size: Batch size of data to feed.
+
+ Returns:
+ feed_dict: The feed dictionary mapping from placeholders to values.
+ """
+ # Create the feed_dict for the placeholders filled with the next
+ # `batch_size` examples.
+ images_feed, labels_feed = data_set.next_batch(batch_size, FLAGS.fake_data)
+ feed_dict = {
+ images_pl: images_feed,
+ labels_pl: labels_feed,
+ }
+ return feed_dict
+
+
+def do_eval(sess,
+ eval_correct,
+ images_placeholder,
+ labels_placeholder,
+ data_set):
+ """Runs one evaluation against the full epoch of data.
+
+ Args:
+ sess: The session in which the model has been trained.
+ eval_correct: The Tensor that returns the number of correct predictions.
+ images_placeholder: The images placeholder.
+ labels_placeholder: The labels placeholder.
+ data_set: The set of images and labels to evaluate, from
+ input_data.read_data_sets().
+ Returns:
+ Precision value on the dataset.
+ """
+ # And run one epoch of eval.
+ true_count = 0 # Counts the number of correct predictions.
+ steps_per_epoch = data_set.num_examples // FLAGS.batch_size
+ num_examples = steps_per_epoch * FLAGS.batch_size
+ for step in xrange(steps_per_epoch):
+ feed_dict = fill_feed_dict(data_set,
+ images_placeholder,
+ labels_placeholder,
+ FLAGS.batch_size)
+ true_count += sess.run(eval_correct, feed_dict=feed_dict)
+ precision = true_count / num_examples
+ print(' Num examples: %d Num correct: %d Precision @ 1: %0.04f' %
+ (num_examples, true_count, precision))
+ return precision
+
+
+def inference(inp, num_clusters, hidden1_units, hidden2_units):
+ """Build the MNIST model up to where it may be used for inference.
+
+ Args:
+ inp: input data
+ num_clusters: number of clusters of input features to train.
+ hidden1_units: Size of the first hidden layer.
+ hidden2_units: Size of the second hidden layer.
+
+ Returns:
+ logits: Output tensor with the computed logits.
+ clustering_loss: Clustering loss.
+ kmeans_training_op: An op to train the clustering.
+ """
+ # Clustering
+ kmeans = clustering_ops.KMeans(
+ inp,
+ num_clusters,
+ distance_metric=clustering_ops.COSINE_DISTANCE,
+ # TODO(agarwal): kmeans++ is currently causing crash in dbg mode.
+ # Enable this after fixing.
+ # initial_clusters=clustering_ops.KMEANS_PLUS_PLUS_INIT,
+ use_mini_batch=True)
+
+ all_scores, _, clustering_scores, kmeans_training_op = kmeans.training_graph()
+ # Some heuristics to approximately whiten this output.
+ all_scores = (all_scores[0] - 0.5) * 5
+ # Here we avoid passing the gradients from the supervised objective back to
+ # the clusters by creating a stop_gradient node.
+ all_scores = tf.stop_gradient(all_scores)
+ clustering_loss = tf.reduce_sum(clustering_scores[0])
+ # Hidden 1
+ with tf.name_scope('hidden1'):
+ weights = tf.Variable(
+ tf.truncated_normal([num_clusters, hidden1_units],
+ stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
+ name='weights')
+ biases = tf.Variable(tf.zeros([hidden1_units]),
+ name='biases')
+ hidden1 = tf.nn.relu(tf.matmul(all_scores, weights) + biases)
+ # Hidden 2
+ with tf.name_scope('hidden2'):
+ weights = tf.Variable(
+ tf.truncated_normal([hidden1_units, hidden2_units],
+ stddev=1.0 / math.sqrt(float(hidden1_units))),
+ name='weights')
+ biases = tf.Variable(tf.zeros([hidden2_units]),
+ name='biases')
+ hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
+ # Linear
+ with tf.name_scope('softmax_linear'):
+ weights = tf.Variable(
+ tf.truncated_normal([hidden2_units, NUM_CLASSES],
+ stddev=1.0 / math.sqrt(float(hidden2_units))),
+ name='weights')
+ biases = tf.Variable(tf.zeros([NUM_CLASSES]),
+ name='biases')
+ logits = tf.matmul(hidden2, weights) + biases
+ return logits, clustering_loss, kmeans_training_op
+
+
+def run_training():
+ """Train MNIST for a number of steps."""
+ # Get the sets of images and labels for training, validation, and
+ # test on MNIST.
+ train_dir = tempfile.mkdtemp()
+ data_sets = input_data.read_data_sets(train_dir, FLAGS.fake_data)
+
+ # Tell TensorFlow that the model will be built into the default Graph.
+ with tf.Graph().as_default():
+ # Generate placeholders for the images and labels.
+ images_placeholder, labels_placeholder = placeholder_inputs()
+
+ # Build a Graph that computes predictions from the inference model.
+ logits, clustering_loss, kmeans_training_op = inference(images_placeholder,
+ FLAGS.num_clusters,
+ FLAGS.hidden1,
+ FLAGS.hidden2)
+
+ # Add to the Graph the Ops for loss calculation.
+ loss = mnist.loss(logits, labels_placeholder)
+
+ # Add to the Graph the Ops that calculate and apply gradients.
+ train_op = tf.group(mnist.training(loss, FLAGS.learning_rate),
+ kmeans_training_op)
+
+ # Add the Op to compare the logits to the labels during evaluation.
+ eval_correct = mnist.evaluation(logits, labels_placeholder)
+
+ # Add the variable initializer Op.
+ init = tf.initialize_all_variables()
+
+ # Create a session for running Ops on the Graph.
+ sess = tf.Session()
+
+ feed_dict = fill_feed_dict(data_sets.train,
+ images_placeholder,
+ labels_placeholder,
+ batch_size=5000)
+ # Run the Op to initialize the variables.
+ sess.run(init, feed_dict=feed_dict)
+
+ # Start the training loop.
+ max_test_prec = 0
+ for step in xrange(FLAGS.max_steps):
+ start_time = time.time()
+
+ # Fill a feed dictionary with the actual set of images and labels
+ # for this particular training step.
+ feed_dict = fill_feed_dict(data_sets.train,
+ images_placeholder,
+ labels_placeholder,
+ FLAGS.batch_size)
+
+ # Run one step of the model.
+ _, loss_value, clustering_loss_value = sess.run([train_op,
+ loss,
+ clustering_loss],
+ feed_dict=feed_dict)
+
+ duration = time.time() - start_time
+ if step % 100 == 0:
+ # Print status to stdout.
+ print('Step %d: loss = %.2f, clustering_loss = %.2f (%.3f sec)' % (
+ step, loss_value, clustering_loss_value, duration))
+
+ # Save a checkpoint and evaluate the model periodically.
+ if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
+ # Evaluate against the training set.
+ print('Training Data Eval:')
+ do_eval(sess,
+ eval_correct,
+ images_placeholder,
+ labels_placeholder,
+ data_sets.train)
+ # Evaluate against the validation set.
+ print('Validation Data Eval:')
+ do_eval(sess,
+ eval_correct,
+ images_placeholder,
+ labels_placeholder,
+ data_sets.validation)
+ # Evaluate against the test set.
+ print('Test Data Eval:')
+ test_prec = do_eval(sess,
+ eval_correct,
+ images_placeholder,
+ labels_placeholder,
+ data_sets.test)
+ max_test_prec = max(max_test_prec, test_prec)
+ return max_test_prec
+
+
+class MnistTest(tf.test.TestCase):
+
+ def test_train(self):
+ self.assertTrue(run_training() > 0.6)
+
+
+if __name__ == '__main__':
+ tf.test.main()
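
The example above is also registered as a test target in examples/BUILD earlier in this change, so with a standard Bazel checkout it should be runnable as `bazel test //tensorflow/contrib/factorization/examples:mnist`, or directly against real MNIST data with the flags documented in its docstring, e.g. `mnist --nofake_data --max_steps=2000`.
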
diff --git a/tensorflow/contrib/factorization/kernels/BUILD b/tensorflow/contrib/factorization/kernels/BUILD
new file mode 100644
index 0000000000..301ab4c95e
--- /dev/null
+++ b/tensorflow/contrib/factorization/kernels/BUILD
@@ -0,0 +1,67 @@
+# OpKernels for data factorization and clustering.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+
+cc_library(
+ name = "all_kernels",
+ deps = [
+ ":clustering_ops",
+ ":wals_solver_ops",
+ "@protobuf//:protobuf",
+ ],
+)
+
+cc_library(
+ name = "wals_solver_ops",
+ srcs = ["wals_solver_ops.cc"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@protobuf//:protobuf",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "clustering_ops",
+ srcs = ["clustering_ops.cc"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@protobuf//:protobuf",
+ ],
+ alwayslink = 1,
+)
+
+cc_test(
+ name = "clustering_ops_test",
+ srcs = ["clustering_ops_test.cc"],
+ deps = [
+ ":clustering_ops",
+ "//tensorflow/contrib/factorization:clustering_ops_op_lib",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+)
diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc
new file mode 100644
index 0000000000..5f680ceadb
--- /dev/null
+++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc
@@ -0,0 +1,522 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy
+// of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+// ==============================================================================
+
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <memory>
+#include <tuple>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/gtl/top_n.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/host_info.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace {
+using errors::InvalidArgument;
+
+template <typename Scalar>
+using RowMajorMatrix =
+ Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+
+using MatrixXfRowMajor = RowMajorMatrix<float>;
+using MatrixXi64RowMajor = RowMajorMatrix<int64>;
+
+// Ideally this should be computed by dividing L3 cache size by the number of
+// physical CPUs. Since there isn't a portable method to do this, we are using
+// a conservative estimate here.
+const int64 kDefaultL3CachePerCpu = 1 << 20;
+
+// These values were determined by performing a parameter sweep on the
+// NearestNeighborsOp benchmark.
+const int64 kNearestNeighborsCentersMaxBlockSize = 1024;
+const int64 kNearestNeighborsPointsMinBlockSize = 16;
+
+// Returns the smallest multiple of a that is not smaller than b.
+int64 NextMultiple(int64 a, int64 b) {
+ const int64 remainder = b % a;
+ return remainder == 0 ? b : (b + a - remainder);
+}
+
+// Returns a / b rounded up to the next higher integer.
+int64 CeilOfRatio(int64 a, int64 b) { return (a + b - 1) / b; }
+
+} // namespace
+
+// Implementation of K-means++ initialization. Samples points iteratively in
+// proportion to the squared distances from selected points.
+// TODO(ands): Add support for other distance metrics.
+class KmeansPlusPlusInitializationOp : public OpKernel {
+ public:
+ explicit KmeansPlusPlusInitializationOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->MatchSignature(
+ {DT_FLOAT, DT_INT64, DT_INT64, DT_INT64}, {DT_FLOAT}));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& points_tensor = context->input(0);
+ const Tensor& num_to_sample_tensor = context->input(1);
+ const Tensor& seed_tensor = context->input(2);
+ const Tensor& num_retries_per_sample_tensor = context->input(3);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
+ InvalidArgument("Input points should be a matrix."));
+ OP_REQUIRES(context,
+ TensorShapeUtils::IsScalar(num_to_sample_tensor.shape()),
+ InvalidArgument("Input num_to_sample should be a scalar."));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
+ InvalidArgument("Input seed should be a scalar."));
+ OP_REQUIRES(
+ context,
+ TensorShapeUtils::IsScalar(num_retries_per_sample_tensor.shape()),
+ InvalidArgument("Input num_retries_per_sample should be a scalar."));
+
+ const int64 num_points = points_tensor.dim_size(0);
+ const int64 point_dimensions = points_tensor.dim_size(1);
+ const int64 num_to_sample = num_to_sample_tensor.scalar<int64>()();
+ const int64 seed = seed_tensor.scalar<int64>()();
+ const int64 num_retries_per_sample = [&]() {
+ const int64 value = num_retries_per_sample_tensor.scalar<int64>()();
+ return value >= 0 ? value
+ : 2 + static_cast<int64>(std::log(num_to_sample));
+ }();
+
+ OP_REQUIRES(context, num_points > 0,
+ InvalidArgument("Expected points.rows() > 0."));
+ OP_REQUIRES(context, num_to_sample > 0,
+ InvalidArgument("Expected num_to_sample > 0."));
+ OP_REQUIRES(context, num_to_sample <= num_points,
+ InvalidArgument("Expected num_to_sample <= points.rows(). ",
+ num_to_sample, " vs ", num_points, "."));
+
+ Tensor* output_sampled_points_tensor;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 0, TensorShape({num_to_sample, point_dimensions}),
+ &output_sampled_points_tensor));
+
+ const Eigen::Map<const MatrixXfRowMajor> points(
+ points_tensor.matrix<float>().data(), num_points, point_dimensions);
+ const Eigen::VectorXf points_half_squared_norm =
+ 0.5 * points.rowwise().squaredNorm();
+
+ Eigen::Map<MatrixXfRowMajor> sampled_points(
+ output_sampled_points_tensor->matrix<float>().data(), num_to_sample,
+ point_dimensions);
+ std::unordered_set<int64> sampled_indices;
+
+ random::PhiloxRandom random(seed);
+ random::SimplePhilox rng(&random);
+
+ auto add_one_point = [&](int64 from, int64 to) {
+ from = std::min(from, num_points - 1);
+ sampled_points.row(to) = points.row(from);
+ sampled_indices.insert(from);
+ };
+
+ // Distances from all points to nearest selected point. Initialize with
+ // distances to first selected point.
+ Eigen::VectorXf min_distances(num_points);
+ min_distances.fill(std::numeric_limits<float>::infinity());
+ Eigen::VectorXf min_distances_cumsum(num_points);
+
+ auto draw_one_sample = [&]() -> int64 {
+ if (sampled_indices.empty()) return rng.Uniform64(num_points);
+ int64 index = 0;
+ do {
+ // If v is drawn from Uniform[0, distances.sum()), then
+ // Prob[cumsum(distances)(i - 1) <= v < cumsum(distances)(i)] is
+ // proportional to distances(i).
+ index = std::upper_bound(
+ min_distances_cumsum.data(),
+ min_distances_cumsum.data() + num_points,
+ rng.RandFloat() * min_distances_cumsum(num_points - 1)) -
+ min_distances_cumsum.data();
+ } while (sampled_indices.find(index) != sampled_indices.end());
+ return index;
+ };
+
+ auto sample_one_point = [&]() {
+ const int64 sampled_index = draw_one_sample();
+ min_distances = min_distances.cwiseMin(GetHalfSquaredDistancesToY(
+ points, points_half_squared_norm, points.row(sampled_index),
+ points_half_squared_norm(sampled_index)));
+ return sampled_index;
+ };
+
+ auto sample_one_point_with_retries = [&]() {
+ Eigen::VectorXf best_new_min_distances(num_points);
+ float best_potential = std::numeric_limits<float>::infinity();
+ int64 best_sampled_index = 0;
+ for (int i = 1 + num_retries_per_sample; i > 0; --i) {
+ const int64 sampled_index = draw_one_sample();
+ Eigen::VectorXf new_min_distances =
+ min_distances.cwiseMin(GetHalfSquaredDistancesToY(
+ points, points_half_squared_norm, points.row(sampled_index),
+ points_half_squared_norm(sampled_index)));
+ const float potential = new_min_distances.sum();
+ if (potential < best_potential) {
+ best_potential = potential;
+ best_sampled_index = sampled_index;
+ best_new_min_distances.swap(new_min_distances);
+ }
+ }
+ min_distances.swap(best_new_min_distances);
+ return best_sampled_index;
+ };
+
+ for (int64 i = 0; i < num_to_sample; ++i) {
+ if (i > 0) {
+ std::partial_sum(min_distances.data(),
+ min_distances.data() + num_points,
+ min_distances_cumsum.data());
+ }
+ int64 next = num_retries_per_sample == 0
+ ? sample_one_point()
+ : sample_one_point_with_retries();
+ add_one_point(next, i);
+ }
+ }
+
+ private:
+ // Returns a column vector with the i-th element set to half the squared
+ // Euclidean distance between the i-th row of xs and y. Precomputed norms for
+ // each row of xs and y must be provided for efficiency.
+ // TODO(ands): Parallelize this for large xs.
+ static Eigen::VectorXf GetHalfSquaredDistancesToY(
+ const Eigen::Ref<const MatrixXfRowMajor>& xs,
+ const Eigen::Ref<const Eigen::VectorXf>& xs_half_squared_norm,
+ const Eigen::Ref<const Eigen::RowVectorXf>& y,
+ float y_half_squared_norm) {
+ // Squared distance between points xs_i and y is:
+ // || xs_i ||^2 - 2 <xs_i, y> + || y ||^2
+ return (xs_half_squared_norm - xs * y.transpose()).array() +
+ y_half_squared_norm;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
+ KmeansPlusPlusInitializationOp);
+
+// Operator for computing the nearest neighbors for a set of points.
+class NearestNeighborsOp : public OpKernel {
+ public:
+ explicit NearestNeighborsOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->MatchSignature({DT_FLOAT, DT_FLOAT, DT_INT64},
+ {DT_INT64, DT_FLOAT}));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& points_tensor = context->input(0);
+ const Tensor& centers_tensor = context->input(1);
+ const Tensor& k_tensor = context->input(2);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
+ InvalidArgument("Input points should be a matrix."));
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(centers_tensor.shape()),
+ InvalidArgument("Input centers should be a matrix."));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_tensor.shape()),
+ InvalidArgument("Input k should be a scalar."));
+
+ const int64 num_points = points_tensor.dim_size(0);
+ const int64 point_dimensions = points_tensor.dim_size(1);
+ const int64 num_centers = centers_tensor.dim_size(0);
+ const int64 center_dimensions = centers_tensor.dim_size(1);
+
+ OP_REQUIRES(context, num_points > 0,
+ InvalidArgument("Expected points.rows() > 0."));
+ OP_REQUIRES(
+ context, point_dimensions == center_dimensions,
+ InvalidArgument("Expected point_dimensions == center_dimensions: ",
+ point_dimensions, " vs ", center_dimensions, "."));
+
+ const Eigen::Map<const MatrixXfRowMajor> points(
+ points_tensor.matrix<float>().data(), num_points, point_dimensions);
+ const Eigen::Map<const MatrixXfRowMajor> centers(
+ centers_tensor.matrix<float>().data(), num_centers, center_dimensions);
+ const int64 k = std::min<int64>(num_centers, k_tensor.scalar<int64>()());
+
+ Tensor* output_nearest_center_indices_tensor;
+ Tensor* output_nearest_center_distances_tensor;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 0, TensorShape({num_points, k}),
+ &output_nearest_center_indices_tensor));
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 1, TensorShape({num_points, k}),
+ &output_nearest_center_distances_tensor));
+
+ if (k == 0) return;
+
+ Eigen::Map<MatrixXi64RowMajor> nearest_center_indices(
+ output_nearest_center_indices_tensor->matrix<int64>().data(),
+ num_points, k);
+ Eigen::Map<MatrixXfRowMajor> nearest_center_distances(
+ output_nearest_center_distances_tensor->matrix<float>().data(),
+ num_points, k);
+
+ const Eigen::VectorXf centers_half_squared_norm =
+ 0.5 * centers.rowwise().squaredNorm();
+
+ // The distance computation is sharded to take advantage of multiple cores
+ // and to allow intermediate values to reside in L3 cache. This is done by
+ // sharding the points and centers as follows:
+ //
+ // 1. Centers are sharded such that each block of centers has at most
+ // kNearestNeighborsCentersMaxBlockSize rows.
+ // 2. Points are sharded, and each block of points is multiplied with each
+ // block of centers. The block size of points is chosen such that the
+ // point coordinates (point_dimensions) and the matrix of distances to
+ // each center in one block -- the intermediate data -- fits in L3 cache.
+ // 3. After performing each block-block distance computation, the results
+ // are reduced to a set of k nearest centers as soon as possible. This
+ // decreases total memory I/O.
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+ const int64 num_threads = worker_threads.num_threads;
+ // This kernel might be configured to use fewer than the total number of
+ // available CPUs on the host machine. To avoid destructive interference
+ // with other jobs running on the host machine, we must only use a fraction
+ // of total available L3 cache. Unfortunately, we cannot query the host
+ // machine to get the number of physical CPUs. So, we use a fixed per-CPU
+ // budget and scale it by the number of CPUs available to this operation.
+ const int64 total_memory_budget =
+ kDefaultL3CachePerCpu * port::NumSchedulableCPUs();
+ // Compute the number of blocks into which rows of points must be split so
+ // that the distance matrix and the block of points can fit in cache. One
+ // row of points will yield a vector of distances to each center in a block.
+ const int64 bytes_per_row =
+ (std::min(kNearestNeighborsCentersMaxBlockSize,
+ num_centers) /* centers in a block */
+ + point_dimensions /* coordinates of one point */) *
+ sizeof(float);
+ // The memory needed for storing the centers being processed. This is shared
+ // by all workers. Adding slack to the number of threads to avoid incorrect
+ // cache eviction when a new block of centers is loaded.
+ const int64 bytes_for_centers =
+ std::min(num_centers,
+ (num_threads + 2) * kNearestNeighborsCentersMaxBlockSize) *
+ point_dimensions * sizeof(float);
+ // The memory budget available for workers to store their distance matrices.
+ const int64 available_memory_budget =
+ total_memory_budget - bytes_for_centers;
+ // That memory budget is shared by all threads.
+ const int64 rows_per_block =
+ std::max<int64>(kNearestNeighborsPointsMinBlockSize,
+ available_memory_budget / num_threads / bytes_per_row);
+ // Divide rows into almost uniformly-sized units of work that are small
+ // enough for the memory budget (rows_per_block). Round up to a multiple of
+ // the number of threads.
+ const int64 num_units =
+ NextMultiple(num_threads, CeilOfRatio(num_points, rows_per_block));
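+ // Illustrative arithmetic (assuming 16 schedulable CPUs and the 1M-point,
+ // 100-dim, 10k-center configuration from the benchmarks in
+ // clustering_ops_test.cc): total_memory_budget = 16 * 1MiB; bytes_per_row =
+ // (1024 + 100) * 4 = 4496; bytes_for_centers = min(10000, 18 * 1024) * 100 * 4
+ // = 4MB; available_memory_budget is then ~12.8MB, so rows_per_block =
+ // max(16, 12.8MB / 16 / 4496) = 177 and the 1M rows are split into
+ // NextMultiple(16, CeilOfRatio(1M, 177)) = 5664 work units.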
+ auto work = [&](int64 start, int64 limit) {
+ for (; start < limit; ++start) {
+ const int64 start_row = num_points * start / num_units;
+ const int64 limit_row = num_points * (start + 1) / num_units;
+ CHECK_LE(limit_row, num_points);
+ const int64 num_rows = limit_row - start_row;
+ auto points_shard = points.middleRows(start_row, num_rows);
+ const Eigen::VectorXf points_half_squared_norm =
+ 0.5 * points_shard.rowwise().squaredNorm();
+ auto nearest_center_indices_shard =
+ nearest_center_indices.middleRows(start_row, num_rows);
+ auto nearest_center_distances_shard =
+ nearest_center_distances.middleRows(start_row, num_rows);
+ FindKNearestCenters(k, points_shard, points_half_squared_norm, centers,
+ centers_half_squared_norm,
+ nearest_center_indices_shard,
+ nearest_center_distances_shard);
+ }
+ };
+
+ const int64 units_per_thread = num_units / num_threads;
+ BlockingCounter counter(num_threads - 1);
+ for (int64 i = 1; i < num_threads; ++i) {
+ const int64 start = i * units_per_thread;
+ const int64 limit = start + units_per_thread;
+ worker_threads.workers->Schedule([work, &counter, start, limit]() {
+ work(start, limit);
+ counter.DecrementCount();
+ });
+ }
+ work(0, units_per_thread);
+ counter.Wait();
+ }
+
+ private:
+ static void FindKNearestCenters(
+ int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
+ const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
+ const Eigen::Ref<const MatrixXfRowMajor>& centers,
+ const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
+ Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
+ Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
+ CHECK_LE(k, centers.rows());
+ if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
+ FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,
+ centers_half_squared_norm,
+ nearest_center_indices,
+ nearest_center_distances);
+ } else {
+ FindKNearestCentersBlockwise(k, points, points_half_squared_norm, centers,
+ centers_half_squared_norm,
+ nearest_center_indices,
+ nearest_center_distances);
+ }
+ }
+
+ static void FindKNearestCentersOneBlock(
+ int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
+ const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
+ const Eigen::Ref<const MatrixXfRowMajor>& centers,
+ const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
+ Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
+ Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
+ CHECK_LE(k, centers.rows());
+ const int64 num_points = points.rows();
+ const MatrixXfRowMajor inner_product = points * centers.transpose();
+ // Find nearest neighbors.
+ if (k == 1) {
+ for (int i = 0; i < num_points; ++i) {
+ int64 index;
+ nearest_center_distances(i, 0) =
+ 2.0 *
+ (points_half_squared_norm(i) +
+ (centers_half_squared_norm.transpose() - inner_product.row(i))
+ .minCoeff(&index));
+ nearest_center_indices(i, 0) = index;
+ }
+ } else {
+ // Select k nearest centers for each point.
+ using Center = std::pair<float, int64>;
+ const int64 num_centers = centers.rows();
+ gtl::TopN<Center, std::less<Center>> selector(k);
+ std::unique_ptr<std::vector<Center>> nearest_centers;
+ for (int i = 0; i < num_points; ++i) {
+ selector.reserve(num_centers);
+ for (int j = 0; j < num_centers; ++j) {
+ const float partial_distance =
+ centers_half_squared_norm(j) - inner_product(i, j);
+ selector.push(Center(partial_distance, j));
+ }
+ nearest_centers.reset(selector.Extract());
+ selector.Reset();
+ const float point_half_squared_norm = points_half_squared_norm(i);
+ for (int j = 0; j < k; ++j) {
+ const Center& center = (*nearest_centers)[j];
+ nearest_center_distances(i, j) =
+ 2.0 * (point_half_squared_norm + center.first);
+ nearest_center_indices(i, j) = center.second;
+ }
+ }
+ }
+ }
+
+ static void FindKNearestCentersBlockwise(
+ int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
+ const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
+ const Eigen::Ref<const MatrixXfRowMajor>& centers,
+ const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
+ Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
+ Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
+ const int64 num_points = points.rows();
+ const int64 num_centers = centers.rows();
+ CHECK_LE(k, num_centers);
+ CHECK_GT(num_centers, kNearestNeighborsCentersMaxBlockSize);
+ // Store nearest neighbors with first block of centers directly into the
+ // output matrices.
+ int64 out_k = std::min(k, kNearestNeighborsCentersMaxBlockSize);
+ FindKNearestCentersOneBlock(
+ out_k, points, points_half_squared_norm,
+ centers.topRows(kNearestNeighborsCentersMaxBlockSize),
+ centers_half_squared_norm.head(kNearestNeighborsCentersMaxBlockSize),
+ nearest_center_indices, nearest_center_distances);
+ // Iteratively compute nearest neighbors with other blocks of centers, and
+ // update the output matrices.
+ MatrixXi64RowMajor block_nearest_center_indices(num_points, k);
+ MatrixXfRowMajor block_nearest_center_distances(num_points, k);
+ Eigen::Matrix<int64, 1, Eigen::Dynamic> merged_indices(k);
+ Eigen::Matrix<float, 1, Eigen::Dynamic> merged_distances(k);
+ for (int64 centers_start = kNearestNeighborsCentersMaxBlockSize;
+ centers_start < num_centers;
+ centers_start += kNearestNeighborsCentersMaxBlockSize) {
+ const int64 centers_block_size = std::min(
+ kNearestNeighborsCentersMaxBlockSize, num_centers - centers_start);
+ const int64 block_k = std::min(k, centers_block_size);
+ FindKNearestCentersOneBlock(
+ block_k, points, points_half_squared_norm,
+ centers.middleRows(centers_start, centers_block_size),
+ centers_half_squared_norm.segment(centers_start, centers_block_size),
+ block_nearest_center_indices, block_nearest_center_distances);
+ if (k == 1) {
+ for (int i = 0; i < num_points; ++i) {
+ if (block_nearest_center_distances(i, 0) <
+ nearest_center_distances(i, 0)) {
+ nearest_center_indices(i, 0) =
+ block_nearest_center_indices(i, 0) + centers_start;
+ nearest_center_distances(i, 0) =
+ block_nearest_center_distances(i, 0);
+ }
+ }
+ } else {
+ for (int i = 0; i < num_points; ++i) {
+ // Merge and accumulate top-k list from block_nearest_center_indices
+ // into nearest_center_indices.
+ for (int64 j_out = 0, j_block = 0, j_merged = 0;
+ (j_out < out_k || j_block < block_k) && j_merged < k;
+ ++j_merged) {
+ const float distance_out =
+ j_out < out_k ? nearest_center_distances(i, j_out)
+ : std::numeric_limits<float>::infinity();
+ const float distance_block =
+ j_block < block_k ? block_nearest_center_distances(i, j_block)
+ : std::numeric_limits<float>::infinity();
+ if (distance_out <= distance_block) {
+ merged_indices(j_merged) = nearest_center_indices(i, j_out);
+ merged_distances(j_merged) = distance_out;
+ ++j_out;
+ } else {
+ merged_indices(j_merged) =
+ block_nearest_center_indices(i, j_block) + centers_start;
+ merged_distances(j_merged) = distance_block;
+ ++j_block;
+ }
+ }
+ nearest_center_indices.row(i) = merged_indices;
+ nearest_center_distances.row(i) = merged_distances;
+ out_k = std::min(k, out_k + block_k);
+ }
+ }
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("NearestNeighbors").Device(DEVICE_CPU),
+ NearestNeighborsOp);
+
+} // namespace tensorflow
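
For readers less familiar with the seeding rule implemented by KmeansPlusPlusInitializationOp above, a rough NumPy sketch of the zero-retry sampling loop (illustrative only; it ignores the duplicate-index and retry handling in the kernel):

import numpy as np

def kmeans_plus_plus(points, num_to_sample, seed=0):
  """Samples rows of `points` with probability proportional to their squared
  distance from the rows already selected (the k-means++ rule)."""
  rng = np.random.RandomState(seed)
  num_points = points.shape[0]
  sampled = [rng.randint(num_points)]  # The first center is drawn uniformly.
  min_sq_dist = np.full(num_points, np.inf)
  for _ in range(num_to_sample - 1):
    last = points[sampled[-1]]
    min_sq_dist = np.minimum(min_sq_dist,
                             np.sum((points - last) ** 2, axis=1))
    # Inverse-CDF sampling: index i is drawn with probability
    # min_sq_dist[i] / min_sq_dist.sum(), matching the upper_bound-on-cumsum
    # trick used in the kernel.
    cumsum = np.cumsum(min_sq_dist)
    sampled.append(int(np.searchsorted(cumsum, rng.rand() * cumsum[-1],
                                       side='right')))
  return points[sampled]
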
diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc b/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc
new file mode 100644
index 0000000000..c4a96b048d
--- /dev/null
+++ b/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc
@@ -0,0 +1,176 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy
+// of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+// ==============================================================================
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr int k100Dim = 100;
+// Number of points for tests.
+constexpr int k10Points = 10;
+constexpr int k100Points = 100;
+constexpr int k1kPoints = 1000;
+constexpr int k10kPoints = 10000;
+constexpr int k1MPoints = 1000000;
+// Number of centers for tests.
+constexpr int k2Centers = 2;
+constexpr int k5Centers = 5;
+constexpr int k10Centers = 10;
+constexpr int k20Centers = 20;
+constexpr int k50Centers = 50;
+constexpr int k100Centers = 100;
+constexpr int k200Centers = 200;
+constexpr int k500Centers = 500;
+constexpr int k1kCenters = 1000;
+constexpr int k10kCenters = 10000;
+// Number of retries for tests.
+constexpr int k0RetriesPerSample = 0;
+constexpr int k3RetriesPerSample = 3;
+
+Graph* SetUpKmeansPlusPlusInitialization(int num_dims, int num_points,
+ int num_to_sample,
+ int retries_per_sample) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor points(DT_FLOAT, TensorShape({num_points, num_dims}));
+ Tensor sample_size(DT_INT64, TensorShape({}));
+ Tensor seed(DT_INT64, TensorShape({}));
+ Tensor num_retries_per_sample(DT_INT64, TensorShape({}));
+ points.flat<float>().setRandom();
+ sample_size.flat<int64>().setConstant(num_to_sample);
+ seed.flat<int64>().setConstant(12345);
+ num_retries_per_sample.flat<int64>().setConstant(retries_per_sample);
+
+ TF_CHECK_OK(NodeBuilder("kmeans_plus_plus_initialization_op",
+ "KmeansPlusPlusInitialization")
+ .Input(test::graph::Constant(g, points))
+ .Input(test::graph::Constant(g, sample_size))
+ .Input(test::graph::Constant(g, seed))
+ .Input(test::graph::Constant(g, num_retries_per_sample))
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+template <int num_points, int num_to_sample, int num_dims,
+ int retries_per_sample>
+void BM_KmeansPlusPlusInitialization(int iters) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters) * num_points * num_dims *
+ num_to_sample);
+ testing::UseRealTime();
+ Graph* g = SetUpKmeansPlusPlusInitialization(
+ num_dims, num_points, num_to_sample, retries_per_sample);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+#define BENCHMARK_KMEANS_PLUS_PLUS(p, c, d, r) \
+ void BM_KmeansPlusPlusInitialization_##p##_##c##_##d##_##r(int iters) { \
+ BM_KmeansPlusPlusInitialization<p, c, d, r>(iters); \
+ } \
+ BENCHMARK(BM_KmeansPlusPlusInitialization_##p##_##c##_##d##_##r);
+
+#define RUN_BM_KmeansPlusPlusInitialization(retries) \
+ BENCHMARK_KMEANS_PLUS_PLUS(k10Points, k2Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k10Points, k5Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k10Points, k10Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k100Points, k10Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k100Points, k20Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k100Points, k50Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k100Points, k100Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k1kPoints, k100Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k1kPoints, k200Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k1kPoints, k500Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k1kPoints, k1kCenters, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k10kPoints, k100Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k10kPoints, k200Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k10kPoints, k500Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k10kPoints, k1kCenters, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k1MPoints, k100Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k1MPoints, k200Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k1MPoints, k500Centers, k100Dim, retries); \
+ BENCHMARK_KMEANS_PLUS_PLUS(k1MPoints, k1kCenters, k100Dim, retries)
+
+RUN_BM_KmeansPlusPlusInitialization(k0RetriesPerSample);
+RUN_BM_KmeansPlusPlusInitialization(k3RetriesPerSample);
+
+#undef RUN_BM_KmeansPlusPlusInitialization
+#undef BENCHMARK_KMEANS_PLUS_PLUS
+
+Graph* SetUpNearestNeighbors(int num_dims, int num_points, int num_centers,
+ int k) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor points(DT_FLOAT, TensorShape({num_points, num_dims}));
+ Tensor centers(DT_FLOAT, TensorShape({num_centers, num_dims}));
+ Tensor top(DT_INT64, TensorShape({}));
+ points.flat<float>().setRandom();
+ centers.flat<float>().setRandom();
+ top.flat<int64>().setConstant(k);
+
+ TF_CHECK_OK(NodeBuilder("nearest_centers_op", "NearestNeighbors")
+ .Input(test::graph::Constant(g, points))
+ .Input(test::graph::Constant(g, centers))
+ .Input(test::graph::Constant(g, top))
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+template <int num_dims, int num_points, int num_centers, int k>
+void BM_NearestNeighbors(int iters) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters) * num_points * num_dims *
+ num_centers);
+ testing::UseRealTime();
+ Graph* g = SetUpNearestNeighbors(num_dims, num_points, num_centers, k);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+constexpr int kTop1 = 1;
+constexpr int kTop2 = 2;
+constexpr int kTop5 = 5;
+constexpr int kTop10 = 10;
+
+#define BENCHMARK_NEAREST_NEIGHBORS(d, p, c, k) \
+ void BM_NearestNeighbors##d##_##p##_##c##_##k(int iters) { \
+ BM_NearestNeighbors<d, p, c, k>(iters); \
+ } \
+ BENCHMARK(BM_NearestNeighbors##d##_##p##_##c##_##k);
+
+#define RUN_BM_NearestNeighbors(k) \
+ BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1kPoints, k100Centers, k); \
+ BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1kPoints, k1kCenters, k); \
+ BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1kPoints, k10kCenters, k); \
+ BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1MPoints, k100Centers, k); \
+ BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1MPoints, k1kCenters, k); \
+ BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1MPoints, k10kCenters, k)
+
+RUN_BM_NearestNeighbors(kTop1);
+// k > 1
+RUN_BM_NearestNeighbors(kTop2);
+RUN_BM_NearestNeighbors(kTop5);
+RUN_BM_NearestNeighbors(kTop10);
+
+#undef RUN_BM_NearestNeighbors
+#undef BENCHMARK_NEAREST_NEIGHBORS
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc
new file mode 100644
index 0000000000..6797a88fd3
--- /dev/null
+++ b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc
@@ -0,0 +1,271 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy
+// of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+// ==============================================================================
+
+// TensorFlow kernels and Ops for constructing WALS normal equations.
+// TODO(agarwal,rmlarsen): Add security checks to the code.
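+//
+// In rough terms: in weighted ALS, re-solving for the factors of one row of the
+// input with the column factors V held fixed reduces to a small k x k linear
+// system. The kernel below accumulates the observed-entry contribution to that
+// system: for every non-zero (i, j), with weight input_weight(i) *
+// factor_weight(j), it adds weight * v_j * v_j^T to the per-row left-hand side
+// (via batched rank-k updates) and the corresponding term to the right-hand
+// side, processing one run of entries with the same row (or column, when
+// transposed) index per shard.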
+
+#include <algorithm>
+#include <vector>
+
+// This is only used for std::this_thread::get_id()
+#include <thread> // NOLINT
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
+
+using tensorflow::DEVICE_CPU;
+using tensorflow::DT_BOOL;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT64;
+using tensorflow::OpKernel;
+using tensorflow::OpKernelConstruction;
+using tensorflow::OpKernelContext;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::TensorShapeUtils;
+using tensorflow::errors::InvalidArgument;
+
+namespace tensorflow {
+
+// TODO(ataei): Consider using RowMajor maps.
+typedef Eigen::Map<
+ Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
+ EigenMatrixFloatMap;
+typedef Eigen::Map<
+ const Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
+ ConstEigenMatrixInt64Map;
+typedef Eigen::Map<
+ const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
+ ConstEigenMatrixFloatMap;
+
+class WALSComputePartialLhsAndRhsOp : public OpKernel {
+ public:
+ explicit WALSComputePartialLhsAndRhsOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->MatchSignature(
+ {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT,
+ DT_INT64, DT_FLOAT, DT_INT64, DT_BOOL},
+ {DT_FLOAT, DT_FLOAT}));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& factors = context->input(0);
+ const Tensor& factor_weights = context->input(1);
+ const Tensor& unobserved_weights = context->input(2);
+ const Tensor& input_weights = context->input(3);
+ const Tensor& input_indices = context->input(4);
+ const Tensor& input_values = context->input(5);
+ const Tensor& input_block_size = context->input(6);
+ const Tensor& input_is_transpose = context->input(7);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(factors.shape()),
+ InvalidArgument("Input factors should be a matrix."));
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(factor_weights.shape()),
+ InvalidArgument("Input factor_weights should be a vector."));
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(unobserved_weights.shape()),
+ InvalidArgument("Input unobserved_weights should be a scalar."));
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(input_weights.shape()),
+ InvalidArgument("Input input_weights should be a vector."));
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
+ InvalidArgument("Input input_indices should be a matrix."));
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()),
+ InvalidArgument("Input input_values should be a vector"));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_block_size.shape()),
+ InvalidArgument("Input input_block_size should be a scalar."));
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsScalar(input_is_transpose.shape()),
+ InvalidArgument("Input input_is_transpose should be a scalar."));
+
+ const int64 factor_dim = factors.dim_size(1);
+ const int64 factors_size = factors.dim_size(0);
+ const int64 num_nonzero_elements = input_indices.dim_size(0);
+ const int64 block_size = input_block_size.scalar<int64>()();
+ const auto& factor_weights_vec = factor_weights.vec<float>();
+ const auto& input_weights_vec = input_weights.vec<float>();
+ const float w_0 = unobserved_weights.scalar<float>()();
+ const auto& input_values_vec = input_values.vec<float>();
+
+ ConstEigenMatrixFloatMap factors_mat(factors.matrix<float>().data(),
+ factor_dim, factors_size);
+ ConstEigenMatrixInt64Map indices_mat(input_indices.matrix<int64>().data(),
+ 2, num_nonzero_elements);
+
+ Tensor* output_lhs_tensor;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ 0, TensorShape({block_size, factor_dim, factor_dim}),
+ &output_lhs_tensor));
+ auto output_lhs_t = output_lhs_tensor->tensor<float, 3>();
+ output_lhs_t.setZero();
+ Tensor* output_rhs_tensor;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 1, TensorShape({block_size, factor_dim}),
+ &output_rhs_tensor));
+ EigenMatrixFloatMap rhs_mat(output_rhs_tensor->matrix<float>().data(),
+ factor_dim, block_size);
+ rhs_mat.setZero();
+ const bool is_transpose = input_is_transpose.scalar<bool>()();
+
+ auto get_input_index = [is_transpose, &indices_mat](int64 i) {
+ return is_transpose ? indices_mat(1, i) : indices_mat(0, i);
+ };
+ auto get_factor_index = [is_transpose, &indices_mat](int64 i) {
+ return is_transpose ? indices_mat(0, i) : indices_mat(1, i);
+ };
+
+ // TODO(rmlarsen): In principle, we should be using the SparseTensor class
+ // and machinery for iterating over groups, but the fact that class
+ // SparseTensor makes a complete copy of the matrix makes me reluctant to
+ // use it.
+ std::vector<int64> perm(num_nonzero_elements);
+ std::iota(perm.begin(), perm.end(), 0);
+
+ typedef std::pair<int64, int64> Shard;
+ std::vector<Shard> shards;
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+ const int num_threads = worker_threads.num_threads;
+ int64 shard_total = 0;
+ if (num_threads == 1) {
+ shards.emplace_back(0, num_nonzero_elements);
+ shard_total += num_nonzero_elements;
+ } else {
+ // Compute a permutation such that get_input_index(perm[i]) is sorted;
+ // stable_sort is used to preserve spatial locality.
+ std::stable_sort(perm.begin(), perm.end(),
+ [&get_input_index](int64 i, int64 j) {
+ return get_input_index(i) < get_input_index(j);
+ });
+
+ // Compute the start and end of runs with identical input_index.
+ // These are the shards of work that can be processed in parallel
+ // without locking.
+ int64 start = 0;
+ int64 end = 0;
+ while (end < num_nonzero_elements) {
+ start = end;
+ while (end < num_nonzero_elements &&
+ get_input_index(perm[start]) == get_input_index(perm[end])) {
+ ++end;
+ }
+ shards.emplace_back(start, end);
+ shard_total += end - start;
+ }
+ }
+ CHECK_EQ(shard_total, num_nonzero_elements);
+ CHECK_LE(shards.size(), num_nonzero_elements);
+ CHECK_GT(shards.size(), 0);
+
+ // Batch the rank-one updates into a rank-k update to lower memory traffic
+ const int kMaxBatchSize = 128;
+
+ // Since we do not have an easy way of generating thread id's within the
+ // range [0,num_threads), we can instead call out to an std::unordered_map
+ // of matrices and initialize the matrix on the first call.
+ // However, this might have a performance penalty, as memory allocation can
+ // cause the OS kernel to enter a critical section and temporarily disable
+ // parallelism, and the unordered_map must be protected with a read/write
+ // mutex.
+ //
+ // TODO(jpoulson): Simplify after the thread rank can be queried
+ std::unordered_map<size_t, Eigen::MatrixXf> factor_batch_map;
+ mutex map_mutex;
+
+ BlockingCounter counter(shards.size());
+ // Lambda encapsulating the per-shard computation.
+ auto work = [&](const Shard& shard) {
+ const std::thread::id thread_id = std::this_thread::get_id();
+ const size_t id_hash = std::hash<std::thread::id>()(thread_id);
+ // If this thread's unique factors_mat.rows() x kMaxBatchSize
+ // batching matrix has not yet been created, then emplace it into the
+ // map using the hash of the thread id as the key.
+ //
+ // TODO(jpoulson): Switch to try_emplace once C++17 is supported
+ map_mutex.lock();
+ const auto key_count = factor_batch_map.count(id_hash);
+ map_mutex.unlock();
+ if (!key_count) {
+ map_mutex.lock();
+ factor_batch_map.emplace(
+ std::piecewise_construct, std::forward_as_tuple(id_hash),
+ std::forward_as_tuple(factors_mat.rows(), kMaxBatchSize));
+ map_mutex.unlock();
+ }
+ map_mutex.lock();
+ auto& factor_batch = factor_batch_map[id_hash];
+ map_mutex.unlock();
+
+ CHECK_GE(shard.first, 0);
+ CHECK_LE(shard.second, perm.size());
+ CHECK_LE(shard.first, shard.second);
+ const int64 input_index = get_input_index(perm[shard.first]);
+    // Accumulate the rhs and lhs terms in the normal equations
+ // for the non-zero elements in the row or column of the sparse matrix
+ // corresponding to input_index.
+ int num_batched = 0;
+ EigenMatrixFloatMap lhs_mat(output_lhs_tensor->flat<float>().data() +
+ input_index * factor_dim * factor_dim,
+ factor_dim, factor_dim);
+ auto lhs_symm = lhs_mat.selfadjointView<Eigen::Lower>();
+ for (int64 p = shard.first; p < shard.second; ++p) {
+ const int64 i = perm[p];
+ // Check that all entries in the shard have the same input index.
+ CHECK_EQ(input_index, get_input_index(i));
+ const int64 factor_index = get_factor_index(i);
+ const float input_value = input_values_vec(i);
+ const float weight =
+ input_weights_vec(input_index) * factor_weights_vec(factor_index);
+ CHECK_GE(weight, 0);
+ factor_batch.col(num_batched) =
+ factors_mat.col(factor_index) * std::sqrt(weight);
+ ++num_batched;
+ if (num_batched == kMaxBatchSize) {
+ lhs_symm.rankUpdate(factor_batch);
+ num_batched = 0;
+ }
+
+ rhs_mat.col(input_index) +=
+ input_value * (w_0 + weight) * factors_mat.col(factor_index);
+ }
+ if (num_batched != 0) {
+ auto factor_block =
+ factor_batch.block(0, 0, factors_mat.rows(), num_batched);
+ lhs_symm.rankUpdate(factor_block);
+ }
+ // Copy lower triangular to upper triangular part of normal equation
+ // matrix.
+ lhs_mat = lhs_symm;
+ counter.DecrementCount();
+ };
+ for (int i = 1; i < shards.size(); ++i) {
+ worker_threads.workers->Schedule(std::bind(work, shards[i]));
+ }
+ // Inline execute the 1st shard.
+ work(shards[0]);
+ counter.Wait();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("WALSComputePartialLhsAndRhs").Device(DEVICE_CPU),
+ WALSComputePartialLhsAndRhsOp);
+
+} // namespace tensorflow
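
For reference, the quantity the shard loop above accumulates per block row can be written down directly: with weight w_ij = input_weights[i] * factor_weights[j], the kernel builds lhs[i] = sum_j w_ij * f_j f_j^T (via batched rank-k updates) and rhs[i] = sum_j (w_0 + w_ij) * a_ij * f_j. The following is a minimal single-threaded NumPy sketch of the same math; the transpose handling, sharding and batching of the kernel are omitted and the names are illustrative only.

import numpy as np

def partial_lhs_and_rhs(factors, factor_weights, w_0, input_weights,
                        indices, values, block_size):
  # factors: (m, k); indices: (nnz, 2) rows of (input_index, factor_index).
  k = factors.shape[1]
  lhs = np.zeros((block_size, k, k), dtype=np.float32)
  rhs = np.zeros((block_size, k), dtype=np.float32)
  for (i, j), a_ij in zip(indices, values):
    f_j = factors[j]
    w = input_weights[i] * factor_weights[j]
    lhs[i] += w * np.outer(f_j, f_j)            # rank-one lhs update
    rhs[i] += (w_0 + w) * a_ij * f_j            # (w_0 + w) weighted rhs term
  return lhs, rhs
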
diff --git a/tensorflow/contrib/factorization/ops/clustering_ops.cc b/tensorflow/contrib/factorization/ops/clustering_ops.cc
new file mode 100644
index 0000000000..5f0d05254e
--- /dev/null
+++ b/tensorflow/contrib/factorization/ops/clustering_ops.cc
@@ -0,0 +1,69 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy
+// of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+// ==============================================================================
+
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("KmeansPlusPlusInitialization")
+ .Input("points: float32")
+ .Input("num_to_sample: int64")
+ .Input("seed: int64")
+ .Input("num_retries_per_sample: int64")
+ .Output("samples: float32")
+ .Doc(R"(
+Selects num_to_sample rows of input using the KMeans++ criterion.
+
+Rows of points are assumed to be input points. One row is selected at random.
+Subsequent rows are sampled with probability proportional to the squared L2
+distance from the nearest row selected thus far, until num_to_sample rows have
+been sampled.
+
+points: Matrix of shape (n, d). Rows are assumed to be input points.
+num_to_sample: Scalar. The number of rows to sample. This value must not be
+ larger than n.
+seed: Scalar. Seed for initializing the random number generator.
+num_retries_per_sample: Scalar. For each row that is sampled, this parameter
+ specifies the number of additional points to draw from the current
+ distribution before selecting the best. If a negative value is specified, a
+ heuristic is used to sample O(log(num_to_sample)) additional points.
+samples: Matrix of shape (num_to_sample, d). The sampled rows.
+)");
+
+REGISTER_OP("NearestNeighbors")
+ .Input("points: float32")
+ .Input("centers: float32")
+ .Input("k: int64")
+ .Output("nearest_center_indices: int64")
+ .Output("nearest_center_distances: float32")
+ .Doc(R"(
+Selects the k nearest centers for each point.
+
+Rows of points are assumed to be input points. Rows of centers are assumed to be
+the list of candidate centers. For each point, the k centers with the smallest
+L2 distance to it are computed.
+
+points: Matrix of shape (n, d). Rows are assumed to be input points.
+centers: Matrix of shape (m, d). Rows are assumed to be centers.
+k: Scalar. Number of nearest centers to return for each point. If k is larger
+ than m, then only m centers are returned.
+nearest_center_indices: Matrix of shape (n, min(m, k)). Each row contains the
+ indices of the centers closest to the corresponding point, ordered by
+ increasing distance.
+nearest_center_distances: Matrix of shape (n, min(m, k)). Each row contains the
+ squared L2 distance to the corresponding center in nearest_center_indices.
+)");
+
+} // namespace tensorflow
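
A rough sketch of how these two ops are driven from Python, using the wrappers in contrib/factorization/python/ops/clustering_ops.py that the unit tests below exercise. It assumes the contrib package and its _clustering_ops.so library are built; shapes, seed and sample counts are illustrative only.

import numpy as np
import tensorflow as tf
import tensorflow.contrib.factorization.python.ops.clustering_ops as clustering_ops

points = np.random.standard_normal([100, 8]).astype(np.float32)
with tf.Session():
  # Sample 4 rows of `points` with the kmeans++ criterion (seed 0, heuristic
  # number of extra candidates per sample).
  centers = clustering_ops.kmeans_plus_plus_initialization(points, 4, 0, -1).eval()
  # Index and squared L2 distance of the nearest center for every point;
  # both results have shape (100, 1).
  indices, distances = clustering_ops.nearest_neighbors(points, centers, 1)
  indices, distances = indices.eval(), distances.eval()
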
diff --git a/tensorflow/contrib/factorization/ops/factorization_ops.cc b/tensorflow/contrib/factorization/ops/factorization_ops.cc
new file mode 100644
index 0000000000..f72ea536fb
--- /dev/null
+++ b/tensorflow/contrib/factorization/ops/factorization_ops.cc
@@ -0,0 +1,46 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy
+// of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+// ==============================================================================
+
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("WALSComputePartialLhsAndRhs")
+ .Input("factors: float32")
+ .Input("factor_weights: float32")
+ .Input("unobserved_weights: float32")
+ .Input("input_weights: float32")
+ .Input("input_indices: int64")
+ .Input("input_values: float32")
+ .Input("input_block_size: int64")
+ .Input("input_is_transpose: bool")
+ .Output("partial_lhs: float32")
+ .Output("partial_rhs: float32")
+ .Doc(R"(
+Computes the partial left-hand side and right-hand side of WALS update.
+
+factors: Matrix of size m * k.
+factor_weights: Vector of size m. Corresponds to column weights.
+unobserved_weights: Scalar. Weight for unobserved input entries.
+input_weights: Vector of size n. Corresponds to row weights.
+input_indices: Indices for the input SparseTensor.
+input_values: Values for the input SparseTensor.
+input_block_size: Scalar. Number of rows spanned by input.
+input_is_transpose: If true, logically transposes the input for processing.
+partial_lhs: 3-D tensor with size input_block_size x k x k.
+partial_rhs: Matrix with size input_block_size x k.
+)");
+
+} // namespace tensorflow
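
In the notation used by the WALSModel docstring further below (w_0 the unobserved weight, R_i = input_weights[i], C_j = factor_weights[j], and v_j the j-th row of factors), the two outputs for each row i of the block can be written as

    partial_lhs[i] = \sum_{j : A_{ij} \neq 0} R_i C_j \, v_j v_j^T
    partial_rhs[i] = \sum_{j : A_{ij} \neq 0} (w_0 + R_i C_j) \, A_{ij} \, v_j

The Python layer in factorization_ops.py below adds the unobserved-weight Gram matrix w_0 * V^T V and any regularization term to partial_lhs before solving the per-row normal equations.
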
diff --git a/tensorflow/contrib/factorization/python/__init__.py b/tensorflow/contrib/factorization/python/__init__.py
new file mode 100644
index 0000000000..c44845bb08
--- /dev/null
+++ b/tensorflow/contrib/factorization/python/__init__.py
@@ -0,0 +1,20 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The python module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py
new file mode 100644
index 0000000000..8d1337c153
--- /dev/null
+++ b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py
@@ -0,0 +1,157 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+
+"""Tests for clustering_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+# pylint: disable=unused-import
+import tensorflow as tf
+# pylint: enable=unused-import
+import tensorflow.contrib.factorization.python.ops.clustering_ops as clustering_ops
+
+
+class KmeansPlusPlusInitializationTest(tf.test.TestCase):
+
+  # All input points but one are close to (101, 1). With uniform random
+  # sampling it would be highly improbable for (-1, -1) to be selected, but
+  # kmeans++ sampling picks it with high probability because of its large
+  # squared distance from the other points.
+ def setUp(self):
+ self._points = np.array([
+ [100., 0.],
+ [101., 2.],
+ [102., 0.],
+ [100., 1.],
+ [100., 2.],
+ [101., 0.],
+ [101., 0.],
+ [101., 1.],
+ [102., 0.],
+ [-1., -1.]
+ ]).astype(np.float32)
+
+ def runTestWithSeed(self, seed):
+ with self.test_session():
+ sampled_points = clustering_ops.kmeans_plus_plus_initialization(
+ self._points, 3, seed, (seed % 5) - 1)
+ self.assertAllClose(sorted(sampled_points.eval().tolist()), [
+ [-1., -1.],
+ [101., 1.],
+ [101., 1.]
+ ], atol=1.0)
+
+ def testBasic(self):
+ for seed in range(100):
+ self.runTestWithSeed(seed)
+
+
+# A simple test that can be verified by hand.
+class NearestCentersTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._points = np.array([
+ [100., 0.],
+ [101., 2.],
+ [99., 2.],
+ [1., 1.]
+ ]).astype(np.float32)
+
+ self._centers = np.array([
+ [100., 0.],
+ [99., 1.],
+ [50., 50.],
+ [0., 0.],
+ [1., 1.]
+ ]).astype(np.float32)
+
+ def testNearest1(self):
+ with self.test_session():
+ [indices, distances] = clustering_ops.nearest_neighbors(self._points,
+ self._centers, 1)
+ self.assertAllClose(indices.eval(), [[0], [0], [1], [4]])
+ self.assertAllClose(distances.eval(), [[0.], [5.], [1.], [0.]])
+
+ def testNearest2(self):
+ with self.test_session():
+ [indices, distances] = clustering_ops.nearest_neighbors(self._points,
+ self._centers, 2)
+ self.assertAllClose(indices.eval(),
+ [[0, 1], [0, 1], [1, 0], [4, 3]])
+ self.assertAllClose(distances.eval(),
+ [[0., 2.], [5., 5.], [1., 5.], [0., 2.]])
+
+
+# A test with large inputs.
+class NearestCentersLargeTest(tf.test.TestCase):
+
+ def setUp(self):
+ num_points = 1000
+ num_centers = 2000
+ num_dim = 100
+ max_k = 5
+ # Construct a small number of random points and later tile them.
+ points_per_tile = 10
+ assert num_points % points_per_tile == 0
+ points = np.random.standard_normal([points_per_tile, num_dim]).astype(
+ np.float32)
+ # Construct random centers.
+ self._centers = np.random.standard_normal([num_centers, num_dim]).astype(
+ np.float32)
+ # Exhaustively compute expected nearest neighbors.
+ def squared_distance(x, y):
+ return np.linalg.norm(x - y, ord=2) ** 2
+ nearest_neighbors = [sorted([(squared_distance(point, self._centers[j]), j)
+ for j in range(num_centers)])[:max_k]
+ for point in points]
+ expected_nearest_neighbor_indices = np.array(
+ [[i for _, i in nn] for nn in nearest_neighbors])
+ expected_nearest_neighbor_squared_distances = np.array(
+ [[dist for dist, _ in nn] for nn in nearest_neighbors])
+ # Tile points and expected results to reach requested size (num_points)
+ (self._points,
+ self._expected_nearest_neighbor_indices,
+ self._expected_nearest_neighbor_squared_distances) = (
+        np.tile(x, (num_points // points_per_tile, 1))
+ for x in (points,
+ expected_nearest_neighbor_indices,
+ expected_nearest_neighbor_squared_distances))
+
+ def testNearest1(self):
+ with self.test_session():
+ [indices, distances] = clustering_ops.nearest_neighbors(
+ self._points, self._centers, 1)
+ self.assertAllClose(indices.eval(),
+ self._expected_nearest_neighbor_indices[:, [0]])
+ self.assertAllClose(
+ distances.eval(),
+ self._expected_nearest_neighbor_squared_distances[:, [0]])
+
+ def testNearest5(self):
+ with self.test_session():
+ [indices, distances] = clustering_ops.nearest_neighbors(
+ self._points, self._centers, 5)
+ self.assertAllClose(indices.eval(),
+ self._expected_nearest_neighbor_indices[:, 0:5])
+ self.assertAllClose(
+ distances.eval(),
+ self._expected_nearest_neighbor_squared_distances[:, 0:5])
+
+
+if __name__ == '__main__':
+ np.random.seed(0)
+ tf.test.main()
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
new file mode 100644
index 0000000000..1ce1a7a2d5
--- /dev/null
+++ b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
@@ -0,0 +1,85 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+
+"""Tests for wals_solver_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import tensorflow as tf
+from tensorflow.contrib.factorization.python.ops.factorization_ops import wals_compute_partial_lhs_and_rhs
+
+
+def SparseBlock3x3():
+ ind = np.array([[0, 0], [0, 2], [1, 1], [2, 0], [2, 1], [3, 2]]).astype(
+ np.int64)
+ val = np.array([0.1, 0.2, 1.1, 2.0, 2.1, 3.2]).astype(np.float32)
+ shape = np.array([4, 3]).astype(np.int64)
+ return tf.SparseTensor(ind, val, shape)
+
+
+class WalsSolverOpsTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._column_factors = np.array([
+ [0.1, 0.2, 0.3],
+ [0.4, 0.5, 0.6],
+ [0.7, 0.8, 0.9],
+ ]).astype(np.float32)
+ self._row_factors = np.array([
+ [0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.1, 1.2, 1.3]
+ ]).astype(np.float32)
+ self._column_weights = np.array([0.1, 0.2, 0.3]).astype(np.float32)
+ self._row_weights = np.array([0.1, 0.2, 0.3, 0.4]).astype(np.float32)
+ self._unobserved_weights = 0.1
+
+ def testWalsSolverLhs(self):
+ sparse_block = SparseBlock3x3()
+ with self.test_session():
+ [lhs_tensor, rhs_matrix] = wals_compute_partial_lhs_and_rhs(
+ self._column_factors, self._column_weights, self._unobserved_weights,
+ self._row_weights, sparse_block.indices, sparse_block.values,
+ sparse_block.shape[0], False)
+ self.assertAllClose(lhs_tensor.eval(), [
+ [
+ [0.014800, 0.017000, 0.019200],
+ [0.017000, 0.019600, 0.022200],
+ [0.019200, 0.022200, 0.025200],
+ ], [
+ [0.0064000, 0.0080000, 0.0096000],
+ [0.0080000, 0.0100000, 0.0120000],
+ [0.0096000, 0.0120000, 0.0144000],
+ ], [
+ [0.0099000, 0.0126000, 0.0153000],
+ [0.0126000, 0.0162000, 0.0198000],
+ [0.0153000, 0.0198000, 0.0243000],
+ ], [
+ [0.058800, 0.067200, 0.075600],
+ [0.067200, 0.076800, 0.086400],
+ [0.075600, 0.086400, 0.097200],
+ ]
+ ])
+ self.assertAllClose(
+ rhs_matrix.eval(),
+ [[0.019300, 0.023000, 0.026700], [0.061600, 0.077000, 0.092400],
+ [0.160400, 0.220000, 0.279600], [0.492800, 0.563200, 0.633600]])
+
+
+if __name__ == '__main__':
+ tf.test.main()
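
As a hand check of the expected values above, the first lhs block and rhs row follow directly from the two nonzeros in row 0 of SparseBlock3x3 (columns 0 and 2 with values 0.1 and 0.2), using the column factors and weights from setUp:

import numpy as np

c = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
w0, row_w, col_w = 0.1, 0.1, [0.1, 0.2, 0.3]
lhs0 = (row_w * col_w[0] * np.outer(c[0], c[0])
        + row_w * col_w[2] * np.outer(c[2], c[2]))
rhs0 = ((w0 + row_w * col_w[0]) * 0.1 * c[0]
        + (w0 + row_w * col_w[2]) * 0.2 * c[2])
print(np.round(lhs0, 6))  # first 3x3 block of lhs_tensor above
print(np.round(rhs0, 6))  # [0.0193  0.023   0.0267]
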
diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
new file mode 100644
index 0000000000..5f1ac69017
--- /dev/null
+++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py
@@ -0,0 +1,409 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Clustering Operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.factorization.python.ops import gen_clustering_ops
+# go/tf-wildcard-import
+# pylint: disable=wildcard-import
+from tensorflow.contrib.factorization.python.ops.gen_clustering_ops import *
+# pylint: enable=wildcard-import
+from tensorflow.python.framework import ops
+from tensorflow.python.framework.load_library import load_op_library
+from tensorflow.python.ops.embedding_ops import embedding_lookup
+from tensorflow.python.platform import resource_loader
+
+_clustering_ops = load_op_library(resource_loader.get_path_to_datafile(
+ '_clustering_ops.so'))
+assert _clustering_ops, 'Could not load _clustering_ops.so'
+
+# Euclidean distance between vectors U and V is defined as ||U - V||_F which is
+# the square root of the sum of the squared element-wise differences.
+SQUARED_EUCLIDEAN_DISTANCE = 'squared_euclidean'
+# Cosine distance between vectors U and V is defined as
+# 1 - (U \dot V) / (||U||_F ||V||_F)
+COSINE_DISTANCE = 'cosine'
+
+RANDOM_INIT = 'random'
+KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
+
+
+class KMeans(object):
+ """Creates the graph for k-means clustering."""
+
+ def __init__(self,
+ inputs,
+ num_clusters,
+ initial_clusters=RANDOM_INIT,
+ distance_metric=SQUARED_EUCLIDEAN_DISTANCE,
+ use_mini_batch=False,
+ random_seed=0,
+ kmeans_plus_plus_num_retries=2):
+ """Creates an object for generating KMeans clustering graph.
+
+ Args:
+      inputs: An input tensor or list of input tensors.
+ num_clusters: number of clusters.
+ initial_clusters: Specifies the clusters used during initialization. Can
+ be a tensor or numpy array, or a function that generates the clusters.
+ Can also be "random" to specify that clusters should be chosen randomly
+ from input data.
+ distance_metric: distance metric used for clustering.
+ use_mini_batch: If true, use the mini-batch k-means algorithm. Else assume
+ full batch.
+      random_seed: Seed for PRNG used to initialize cluster centers.
+ kmeans_plus_plus_num_retries: For each point that is sampled during
+ kmeans++ initialization, this parameter specifies the number of
+ additional points to draw from the current distribution before selecting
+ the best. If a negative value is specified, a heuristic is used to
+ sample O(log(num_to_sample)) additional points.
+ """
+ self._inputs = inputs if isinstance(inputs, list) else [inputs]
+ assert num_clusters > 0, num_clusters
+ self._num_clusters = num_clusters
+ self._initial_clusters = initial_clusters
+ assert distance_metric in [SQUARED_EUCLIDEAN_DISTANCE, COSINE_DISTANCE]
+ self._distance_metric = distance_metric
+ self._use_mini_batch = use_mini_batch
+ self._random_seed = random_seed
+ self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
+
+ @classmethod
+ def _distance_graph(cls, inputs, clusters, distance_metric):
+ """Computes distance between each input and each cluster center.
+
+ Args:
+ inputs: list of input Tensors.
+ clusters: cluster Tensor.
+ distance_metric: distance metric used for clustering
+
+ Returns:
+ list of Tensors, where each element corresponds to each element in inputs.
+ The value is the distance of each row to all the cluster centers.
+ Currently only Euclidean distance and cosine distance are supported.
+ """
+ assert isinstance(inputs, list)
+ if distance_metric == SQUARED_EUCLIDEAN_DISTANCE:
+ return cls._compute_euclidean_distance(inputs, clusters)
+ elif distance_metric == COSINE_DISTANCE:
+ return cls._compute_cosine_distance(inputs, clusters,
+ inputs_normalized=True)
+ else:
+ assert False, ('Unsupported distance metric passed to Kmeans %s'
+ % str(distance_metric))
+
+ @classmethod
+ def _compute_euclidean_distance(cls, inputs, clusters):
+ """Computes Euclidean distance between each input and each cluster center.
+
+ Args:
+ inputs: list of input Tensors.
+ clusters: cluster Tensor.
+
+ Returns:
+ list of Tensors, where each element corresponds to each element in inputs.
+ The value is the distance of each row to all the cluster centers.
+ """
+ output = []
+ for inp in inputs:
+ with ops.colocate_with(inp):
+ # Computes Euclidean distance. Note the first and third terms are
+ # broadcast additions.
+ squared_distance = (tf.reduce_sum(tf.square(inp), 1, keep_dims=True) -
+ 2 * tf.matmul(inp, clusters, transpose_b=True) +
+ tf.transpose(tf.reduce_sum(tf.square(clusters),
+ 1,
+ keep_dims=True)))
+ output.append(squared_distance)
+
+ return output
+
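
The broadcast expansion used in _compute_euclidean_distance is the identity ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2 applied to every (point, center) pair. A quick NumPy check of that identity, with illustrative shapes:

import numpy as np

x = np.random.randn(4, 3)  # 4 points
c = np.random.randn(5, 3)  # 5 centers
expanded = (np.sum(x * x, 1, keepdims=True)       # (4, 1), broadcast over centers
            - 2 * x.dot(c.T)                      # (4, 5)
            + np.sum(c * c, 1, keepdims=True).T)  # (1, 5), broadcast over points
direct = ((x[:, None, :] - c[None, :, :]) ** 2).sum(-1)
assert np.allclose(expanded, direct)
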
+ @classmethod
+ def _compute_cosine_distance(cls, inputs, clusters, inputs_normalized=True):
+ """Computes cosine distance between each input and each cluster center.
+
+ Args:
+ inputs: list of input Tensor.
+ clusters: cluster Tensor
+ inputs_normalized: if True, it assumes that inp and clusters are
+      normalized and computes 1 - the dot product, which then equals the cosine
+      distance. Else it L2 normalizes the inputs first.
+
+ Returns:
+ list of Tensors, where each element corresponds to each element in inp.
+ The value is the distance of each row to all the cluster centers.
+ """
+ output = []
+ if not inputs_normalized:
+ with ops.colocate_with(clusters):
+ clusters = tf.nn.l2_normalize(clusters, dim=1)
+ for inp in inputs:
+ with ops.colocate_with(inp):
+ if not inputs_normalized:
+ inp = tf.nn.l2_normalize(inp, dim=1)
+ output.append(1 - tf.matmul(inp, clusters, transpose_b=True))
+ return output
+
+ def _infer_graph(self, inputs, clusters):
+ """Maps input to closest cluster and the score.
+
+ Args:
+ inputs: list of input Tensors.
+ clusters: Tensor of cluster centers.
+
+ Returns:
+ List of tuple, where each value in tuple corresponds to a value in inp.
+ The tuple has following three elements:
+ all_scores: distance of each input to each cluster center.
+ score: distance of each input to closest cluster center.
+ cluster_idx: index of cluster center closest to the corresponding input.
+ """
+ assert isinstance(inputs, list)
+ # Pairwise distances are used only by transform(). In all other cases, this
+ # sub-graph is not evaluated.
+ scores = self._distance_graph(inputs, clusters, self._distance_metric)
+ output = []
+ if (self._distance_metric == COSINE_DISTANCE and
+ not self._clusters_l2_normalized()):
+      # For normalized vectors x and y, the squared Euclidean distance equals
+      # 2 * the cosine distance, so we reuse the nearest_neighbors op and halve
+      # its distances below.
+ # TODO(ands): Support COSINE distance in nearest_neighbors and remove
+ # this.
+ with ops.colocate_with(clusters):
+ clusters = tf.nn.l2_normalize(clusters, dim=1)
+ for inp, score in zip(inputs, scores):
+ with ops.colocate_with(inp):
+ (indices,
+ distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1)
+ if self._distance_metric == COSINE_DISTANCE:
+ distances *= 0.5
+ output.append((score, tf.squeeze(distances), tf.squeeze(indices)))
+ return zip(*output)
+
+ def _init_clusters_random(self):
+ """Does random initialization of clusters.
+
+ Returns:
+ Tensor of randomly initialized clusters.
+ """
+ num_data = tf.add_n([tf.shape(inp)[0] for inp in self._inputs])
+ # Note that for mini-batch k-means, we should ensure that the batch size of
+ # data used during initialization is sufficiently large to avoid duplicated
+ # clusters.
+ with tf.control_dependencies(
+ [tf.assert_less_equal(self._num_clusters, num_data)]):
+ indices = tf.random_uniform(tf.reshape(self._num_clusters, [-1]),
+ minval=0,
+ maxval=tf.cast(num_data, tf.int64),
+ seed=self._random_seed,
+ dtype=tf.int64)
+ clusters_init = embedding_lookup(self._inputs, indices,
+ partition_strategy='div')
+ return clusters_init
+
+ def _clusters_l2_normalized(self):
+ """Returns True if clusters centers are kept normalized."""
+ return self._distance_metric == COSINE_DISTANCE and not self._use_mini_batch
+
+ def _init_clusters(self):
+ """Initialization of clusters.
+
+ Returns:
+ Tuple with following elements:
+ cluster_centers: a Tensor for storing cluster centers
+ cluster_counts: a Tensor for storing counts of points assigned to this
+ cluster. This is used by mini-batch training.
+ """
+ init = self._initial_clusters
+ if init == RANDOM_INIT:
+ clusters_init = self._init_clusters_random()
+ elif init == KMEANS_PLUS_PLUS_INIT:
+ # Points from only the first shard are used for initializing centers.
+ # TODO(ands): Use all points.
+ clusters_init = gen_clustering_ops.kmeans_plus_plus_initialization(
+ self._inputs[0], self._num_clusters, self._random_seed,
+ self._kmeans_plus_plus_num_retries)
+ elif callable(init):
+ clusters_init = init(self._inputs, self._num_clusters)
+ elif not isinstance(init, str):
+ clusters_init = init
+ else:
+ assert False, 'Unsupported init passed to Kmeans %s' % str(init)
+ if self._distance_metric == COSINE_DISTANCE and clusters_init is not None:
+ clusters_init = tf.nn.l2_normalize(clusters_init, dim=1)
+ clusters_init = clusters_init if clusters_init is not None else []
+ cluster_centers = tf.Variable(clusters_init,
+ name='clusters',
+ validate_shape=False)
+ cluster_counts = (tf.Variable(tf.zeros([self._num_clusters],
+ dtype=tf.int64))
+ if self._use_mini_batch else None)
+ return cluster_centers, cluster_counts
+
+ @classmethod
+ def _l2_normalize_data(cls, inputs):
+ """Normalized the input data."""
+ output = []
+ for inp in inputs:
+ with ops.colocate_with(inp):
+ output.append(tf.nn.l2_normalize(inp, dim=1))
+ return output
+
+ def training_graph(self):
+ """Generate a training graph for kmeans algorithm.
+
+ Returns:
+ A tuple consisting of:
+ all_scores: A matrix (or list of matrices) of dimensions (num_input,
+ num_clusters) where the value is the distance of an input vector and a
+ cluster center.
+ cluster_idx: A vector (or list of vectors). Each element in the vector
+ corresponds to an input row in 'inp' and specifies the cluster id
+ corresponding to the input.
+ scores: Similar to cluster_idx but specifies the distance to the
+ assigned cluster instead.
+ training_op: an op that runs an iteration of training.
+ """
+ # Implementation of kmeans.
+ inputs = self._inputs
+ cluster_centers_var, total_counts = self._init_clusters()
+ cluster_centers = cluster_centers_var
+
+ if self._distance_metric == COSINE_DISTANCE:
+ inputs = self._l2_normalize_data(inputs)
+ if not self._clusters_l2_normalized():
+ cluster_centers = tf.nn.l2_normalize(cluster_centers, dim=1)
+
+ all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers)
+ if self._use_mini_batch:
+ training_op = self._mini_batch_training_op(
+ inputs, cluster_idx, cluster_centers, cluster_centers_var,
+ total_counts)
+ else:
+ assert cluster_centers == cluster_centers_var
+ training_op = self._full_batch_training_op(inputs, cluster_idx,
+ cluster_centers_var)
+ return all_scores, cluster_idx, scores, training_op
+
+ def _mini_batch_training_op(self, inputs, cluster_idx_list,
+ cluster_centers, cluster_centers_var,
+ total_counts):
+ """Creates an op for training for mini batch case.
+
+ Args:
+ inputs: list of input Tensors.
+ cluster_idx_list: A vector (or list of vectors). Each element in the
+ vector corresponds to an input row in 'inp' and specifies the cluster id
+ corresponding to the input.
+ cluster_centers: Tensor of cluster centers, possibly normalized.
+ cluster_centers_var: Tensor Ref of cluster centers.
+ total_counts: Tensor Ref of cluster counts.
+
+ Returns:
+ An op for doing an update of mini-batch k-means.
+ """
+ update_ops = []
+ for inp, cluster_idx in zip(inputs, cluster_idx_list):
+ with ops.colocate_with(inp):
+ assert total_counts is not None
+ cluster_idx = tf.reshape(cluster_idx, [-1])
+        # Deduplicate cluster_idx to get the unique ids of the cluster centers
+        # being updated, so that the updates can be locally aggregated.
+ unique_ids, unique_idx = tf.unique(cluster_idx)
+ num_unique_cluster_idx = tf.size(unique_ids)
+ # Fetch the old values of counts and cluster_centers.
+ with ops.colocate_with(total_counts):
+ old_counts = tf.gather(total_counts, unique_ids)
+ with ops.colocate_with(cluster_centers):
+ old_cluster_centers = tf.gather(cluster_centers, unique_ids)
+ # Locally aggregate the increment to counts.
+ count_updates = tf.unsorted_segment_sum(
+ tf.ones_like(unique_idx, dtype=total_counts.dtype),
+ unique_idx,
+ num_unique_cluster_idx)
+ # Locally compute the sum of inputs mapped to each id.
+ # For a cluster with old cluster value x, old count n, and with data
+ # d_1,...d_k newly assigned to it, we recompute the new value as
+ # x += (sum_i(d_i) - k * x) / (n + k).
+ # Compute sum_i(d_i), see comment above.
+ cluster_center_updates = tf.unsorted_segment_sum(
+ inp,
+ unique_idx,
+ num_unique_cluster_idx)
+ # Shape to enable broadcasting count_updates and learning_rate to inp.
+ # It extends the shape with 1's to match the rank of inp.
+ broadcast_shape = tf.concat(
+ 0,
+ [tf.reshape(num_unique_cluster_idx, [1]),
+ tf.ones(tf.reshape(tf.rank(inp) - 1, [1]), dtype=tf.int32)])
+ # Subtract k * x, see comment above.
+ cluster_center_updates -= tf.cast(
+ tf.reshape(count_updates, broadcast_shape),
+ inp.dtype) * old_cluster_centers
+ learning_rate = tf.inv(tf.cast(old_counts + count_updates, inp.dtype))
+ learning_rate = tf.reshape(learning_rate, broadcast_shape)
+ # scale by 1 / (n + k), see comment above.
+ cluster_center_updates *= learning_rate
+ # Apply the updates.
+ update_counts = tf.scatter_add(
+ total_counts,
+ unique_ids,
+ count_updates)
+ update_cluster_centers = tf.scatter_add(
+ cluster_centers_var,
+ unique_ids,
+ cluster_center_updates)
+ update_ops.extend([update_counts, update_cluster_centers])
+ return tf.group(*update_ops)
+
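
The update rule in the comments of _mini_batch_training_op, x += (sum_i(d_i) - k * x) / (n + k), keeps each center equal to the running mean of every point assigned to it so far. A minimal NumPy check of that algebra only (values illustrative, no TensorFlow involved):

import numpy as np

n, x = 3, np.array([1.0, 2.0])           # old count and center
d = np.array([[2.0, 2.0], [4.0, 0.0]])   # k = 2 newly assigned points
k = len(d)
x_new = x + (d.sum(0) - k * x) / (n + k)
assert np.allclose(x_new, (n * x + d.sum(0)) / (n + k))  # running mean
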
+ def _full_batch_training_op(self, inputs, cluster_idx_list, cluster_centers):
+ """Creates an op for training for full batch case.
+
+ Args:
+ inputs: list of input Tensors.
+ cluster_idx_list: A vector (or list of vectors). Each element in the
+ vector corresponds to an input row in 'inp' and specifies the cluster id
+ corresponding to the input.
+ cluster_centers: Tensor Ref of cluster centers.
+
+ Returns:
+      An op for doing an update of full-batch k-means.
+ """
+ cluster_sums = []
+ cluster_counts = []
+ epsilon = tf.constant(1e-6, dtype=inputs[0].dtype)
+ for inp, cluster_idx in zip(inputs, cluster_idx_list):
+ with ops.colocate_with(inp):
+ cluster_sums.append(tf.unsorted_segment_sum(inp,
+ cluster_idx,
+ self._num_clusters))
+ cluster_counts.append(tf.unsorted_segment_sum(
+ tf.reshape(tf.ones(tf.reshape(tf.shape(inp)[0], [-1])), [-1, 1]),
+ cluster_idx,
+ self._num_clusters))
+ with ops.colocate_with(cluster_centers):
+ new_clusters_centers = tf.add_n(cluster_sums) / (
+ tf.cast(tf.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
+ if self._clusters_l2_normalized():
+ new_clusters_centers = tf.nn.l2_normalize(new_clusters_centers, dim=1)
+ return tf.assign(cluster_centers, new_clusters_centers)
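
A rough end-to-end sketch of the KMeans graph builder above, using a single dense input tensor, full-batch training and random initialization. It assumes the contrib ops library loads as in the tests; the data and the number of iterations are illustrative only.

import numpy as np
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops import clustering_ops

data = np.random.standard_normal([200, 4]).astype(np.float32)
kmeans = clustering_ops.KMeans(tf.constant(data), num_clusters=5,
                               initial_clusters=clustering_ops.RANDOM_INIT,
                               use_mini_batch=False)
all_scores, cluster_idx, scores, train_op = kmeans.training_graph()
with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  for _ in range(10):             # a few Lloyd iterations
    sess.run(train_op)
  assignments = sess.run(cluster_idx)
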
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
new file mode 100644
index 0000000000..f68869f528
--- /dev/null
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
@@ -0,0 +1,467 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Ops for matrix factorization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.contrib.factorization.python.ops.gen_factorization_ops import *
+from tensorflow.python.framework import ops
+from tensorflow.python.framework.load_library import load_op_library
+from tensorflow.python.platform import resource_loader
+
+_factorization_ops = load_op_library(resource_loader.get_path_to_datafile(
+ "_factorization_ops.so"))
+assert _factorization_ops, "Could not load _factorization_ops.so"
+
+
+class WALSModel(object):
+ r"""A model for Weighted Alternating Least Squares matrix factorization.
+
+ It minimizes the following loss function over U, V:
+    \\( ||W \odot (A - U V^T) ||_F^2 + \lambda (||U||_F^2 + ||V||_F^2) \\)
+  where,
+  A: input matrix,
+  W: weight matrix,
+  U, V: row_factors and column_factors matrices,
+  \\(\lambda\\): regularization.
+  Also we assume that W is of the following special form:
+  \\( W_{ij} = W_0 + R_i * C_j \\) if \\(A_{ij} \ne 0\\),
+  \\(W_{ij} = W_0\\) otherwise.
+  where,
+  \\(W_0\\): unobserved_weight,
+  \\(R_i\\): row_weights,
+  \\(C_j\\): col_weights.
+
+ Note that the current implementation assumes that row_factors and col_factors
+ can individually fit into the memory of each worker.
+ """
+
+ def __init__(self,
+ input_rows,
+ input_cols,
+ n_components,
+ unobserved_weight=0.1,
+ regularization=None,
+ row_init="random",
+ col_init="random",
+ num_row_shards=1,
+ num_col_shards=1,
+ row_weights=None,
+ col_weights=None):
+ """Creates model for WALS matrix factorization.
+
+ Args:
+ input_rows: total number of rows for input matrix.
+ input_cols: total number of cols for input matrix.
+ n_components: number of dimensions to use for the factors.
+ unobserved_weight: weight given to unobserved entries of matrix.
+ regularization: weight of L2 regularization term. If None, no
+ regularization is done.
+ row_init: initializer for row factor. Can be a tensor or numpy constant.
+ If set to "random", the value is initialized randomly.
+ col_init: initializer for column factor. See row_init for details.
+ num_row_shards: number of shards to use for row factors.
+ num_col_shards: number of shards to use for column factors.
+ row_weights: If not None, along with col_weights, used to compute the
+ weight of an observed entry. w_ij = unobserved_weight + row_weights[i] *
+ col_weights[j]. If None, then w_ij = unobserved_weight, which simplifies
+ to ALS.
+ col_weights: See row_weights
+ """
+ self._input_rows = input_rows
+ self._input_cols = input_cols
+ self._num_row_shards = num_row_shards
+ self._num_col_shards = num_col_shards
+ self._n_components = n_components
+ self._unobserved_weight = unobserved_weight
+ self._regularization = (tf.diag(tf.constant(regularization,
+ shape=[self._n_components],
+ dtype=tf.float32))
+ if regularization is not None else None)
+ assert (row_weights is None) == (col_weights is None)
+ self._row_weights = WALSModel._create_weights(row_weights,
+ self._input_rows,
+ self._num_row_shards,
+ "row_weights")
+ self._col_weights = WALSModel._create_weights(col_weights,
+ self._input_cols,
+ self._num_col_shards,
+ "col_weights")
+ self._row_factors = self._create_factors(self._input_rows,
+ self._n_components,
+ self._num_row_shards,
+ row_init,
+ "row_factors")
+ self._col_factors = self._create_factors(self._input_cols,
+ self._n_components,
+ self._num_col_shards,
+ col_init,
+ "col_factors")
+ self._create_transient_vars()
+
+ @property
+ def row_factors(self):
+ """Returns a list of tensors corresponding to row factor shards."""
+ return self._row_factors
+
+ @property
+ def col_factors(self):
+ """Returns a list of tensors corresponding to column factor shards."""
+ return self._col_factors
+
+ @property
+ def initialize_op(self):
+ """Returns an op for initializing tensorflow variables."""
+ all_vars = self._row_factors + self._col_factors
+ if self._row_weights is not None:
+ assert self._col_weights is not None
+ all_vars.extend(self._row_weights + self._col_weights)
+ return tf.initialize_variables(all_vars)
+
+ @classmethod
+ def _shard_sizes(cls, dims, num_shards):
+ """Helper function to split dims values into num_shards."""
+ shard_size, residual = divmod(dims, num_shards)
+ return [shard_size + 1] * residual + [shard_size] * (num_shards - residual)
+
+ @classmethod
+ def _create_factors(cls, rows, cols, num_shards, init, name):
+ """Helper function to create row and column factors."""
+ if callable(init):
+ init = init()
+ if isinstance(init, list):
+ assert len(init) == num_shards
+ elif isinstance(init, str) and init == "random":
+ pass
+ elif num_shards == 1:
+ init = [init]
+ sharded_matrix = []
+ sizes = cls._shard_sizes(rows, num_shards)
+ assert len(sizes) == num_shards
+
+ def make_initializer(i, size):
+ def initializer():
+ if init == "random":
+ return tf.random_normal([size, cols])
+ else:
+ return init[i]
+ return initializer
+
+ for i, size in enumerate(sizes):
+ var_name = "%s_shard_%d" % (name, i)
+ var_init = make_initializer(i, size)
+ sharded_matrix.append(tf.Variable(
+ var_init,
+ dtype=tf.float32,
+ name=var_name))
+
+ return sharded_matrix
+
+ @staticmethod
+ def _create_weights(wt_init, num_wts, num_shards, name):
+ """Helper functions to create sharded weight vector.
+
+ Args:
+ wt_init: init value for the weight. If None, weights are not created.
+ num_wts: total size of all the weight shards
+ num_shards: number of shards for the weights
+ name: name for the new Variables.
+
+ Returns:
+ A list of weight shard Tensors.
+ """
+ if wt_init is None:
+ return None
+ if num_shards == 1 and len(wt_init) == num_wts:
+ wt_init = [wt_init]
+ assert len(wt_init) == num_shards
+ return [tf.Variable(wt_init[i],
+ dtype=tf.float32,
+ name="%s_shard_%d" % (name, i))
+ for i in xrange(num_shards)]
+
+ @staticmethod
+ def _transient_var(name):
+ """Helper function to create a Variable."""
+ return tf.Variable(1.0,
+ trainable=False,
+ collections=[tf.GraphKeys.LOCAL_VARIABLES],
+ validate_shape=False,
+ name=name)
+
+ def _cached_copy(self, var, name):
+ """Helper function to create a worker cached copy of a Variable.
+
+ Args:
+ var: Variable or list of Variable to cache. If a list, the items are
+ concatenated along dimension 0 to get the cached entry.
+ name: name of cached variable.
+
+ Returns:
+ Tuple consisting of following three entries:
+ cache: the new transient Variable.
+ cache_init: op to initialize the Variable
+ cache_reset: op to reset the Variable to some default value
+ """
+ if var is None:
+ return None, None, None
+ else:
+ cache = WALSModel._transient_var(name)
+ with ops.colocate_with(cache):
+ if isinstance(var, list):
+ assert var
+ if len(var) == 1:
+ var = var[0]
+ else:
+ var = tf.concat(0, var)
+
+ cache_init = tf.assign(cache, var, validate_shape=False)
+ cache_reset = tf.assign(cache, 1.0, validate_shape=False)
+ return cache, cache_init, cache_reset
+
+ def _create_transient_vars(self):
+ """Creates local cache of row and column factors and weights.
+
+ Note that currently the caching strategy is as follows:
+ When initiating a row update, column factors are cached while row factors
+ cache is reset. Similarly when initiating a column update, row factors are
+ cached while cached column factors are flushed.
+ Column and row weights are always cached. If memory becomes a bottleneck,
+ they could be similarly flushed.
+ """
+ (self._row_factors_cache,
+ row_factors_cache_init,
+ row_factors_cache_reset) = self._cached_copy(self._row_factors,
+ "row_factors_cache")
+ (self._col_factors_cache,
+ col_factors_cache_init,
+ col_factors_cache_reset) = self._cached_copy(self._col_factors,
+ "col_factors_cache")
+ (self._row_wt_cache,
+ row_wt_cache_init,
+ _) = self._cached_copy(self._row_weights, "row_wt_cache")
+ (self._col_wt_cache,
+ col_wt_cache_init,
+ _) = self._cached_copy(self._col_weights, "col_wt_cache")
+
+ if self._row_wt_cache is not None:
+ assert self._col_wt_cache is not None
+ self._worker_init = tf.group(row_wt_cache_init,
+ col_wt_cache_init,
+ name="worker_init")
+ else:
+ self._worker_init = tf.no_op(name="worker_init")
+
+ self._row_updates_init = tf.group(col_factors_cache_init,
+ row_factors_cache_reset)
+ self._col_updates_init = tf.group(row_factors_cache_init,
+ col_factors_cache_reset)
+
+ @property
+ def worker_init(self):
+ """Op to initialize worker state once before starting any updates."""
+ return self._worker_init
+
+ @property
+ def initialize_row_update_op(self):
+ """Op to initialize worker state before starting row updates."""
+ return self._row_updates_init
+
+ @property
+ def initialize_col_update_op(self):
+ """Op to initialize worker state before starting column updates."""
+ return self._col_updates_init
+
+ @staticmethod
+ def _get_sharding_func(size, num_shards):
+ """Create sharding function for scatter update."""
+ def func(ids):
+ if num_shards == 1:
+ return None, ids
+ else:
+ ids_per_shard = size // num_shards
+ extras = size % num_shards
+ assignments = tf.maximum(ids // (ids_per_shard + 1),
+ (ids - extras) // ids_per_shard)
+ new_ids = tf.select(assignments < extras,
+ ids % (ids_per_shard + 1),
+ (ids - extras) % ids_per_shard)
+ return assignments, new_ids
+ return func
+
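
The arithmetic inside func splits `size` ids into `num_shards` nearly equal contiguous shards, consistent with _shard_sizes above. A plain-Python trace of the same formulas for size=5 and num_shards=2 (shard sizes [3, 2]), shown only to illustrate the mapping:

size, num_shards = 5, 2
ids_per_shard, extras = size // num_shards, size % num_shards
mapping = []
for i in range(size):
  assignment = max(i // (ids_per_shard + 1), (i - extras) // ids_per_shard)
  new_id = (i % (ids_per_shard + 1) if assignment < extras
            else (i - extras) % ids_per_shard)
  mapping.append((i, assignment, new_id))
# mapping == [(0, 0, 0), (1, 0, 1), (2, 0, 2), (3, 1, 0), (4, 1, 1)]
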
+ @classmethod
+ def scatter_update(cls, factor, indices, values, sharding_func):
+ """Helper function for doing sharded scatter update."""
+ assert isinstance(factor, list)
+ if len(factor) == 1:
+ with ops.colocate_with(factor[0]):
+ # TODO(agarwal): assign instead of scatter update for full batch update.
+ return tf.scatter_update(factor[0], indices, values).op
+ else:
+ num_shards = len(factor)
+ assignments, new_ids = sharding_func(indices)
+ assert assignments is not None
+ assignments = tf.cast(assignments, tf.int32)
+ sharded_ids = tf.dynamic_partition(new_ids, assignments, num_shards)
+ sharded_values = tf.dynamic_partition(values, assignments, num_shards)
+ updates = []
+ for i in xrange(num_shards):
+ updates.append(tf.scatter_update(factor[i],
+ sharded_ids[i],
+ sharded_values[i]))
+ return tf.group(*updates)
+
+ def update_row_factors(self, sp_input=None, transpose_input=False):
+ """Updates the row factors.
+
+ Args:
+ sp_input: A SparseTensor representing a subset of rows of the full input
+ in any order. Please note that this SparseTensor must retain the
+ indexing as the original input.
+ transpose_input: If true, logically transposes the input.
+
+ Returns:
+ A tuple consisting of the following two elements:
+ new_values: New values for the row factors.
+ update_op: An op that assigns the newly computed values to the row
+ factors.
+ """
+ return self._process_input_helper(True, sp_input=sp_input,
+ transpose_input=transpose_input)
+
+ def update_col_factors(self, sp_input=None, transpose_input=False):
+ """Updates the column factors.
+
+ Args:
+ sp_input: A SparseTensor representing a subset of columns of the full
+ input. Please refer to comments for update_row_factors for
+ restrictions.
+ transpose_input: If true, logically transposes the input.
+
+ Returns:
+ A tuple consisting of the following two elements:
+ new_values: New values for the column factors.
+ update_op: An op that assigns the newly computed values to the column
+ factors.
+ """
+ return self._process_input_helper(False, sp_input=sp_input,
+ transpose_input=transpose_input)
+
+ def _process_input_helper(self, update_row_factors,
+ sp_input=None, transpose_input=False):
+ """Creates the graph for processing a sparse slice of input.
+
+ Args:
+ update_row_factors: if True, update the row_factors, else update the
+ column factors.
+ sp_input: Please refer to comments for update_row_factors and
+ update_col_factors.
+ transpose_input: If true, logically transpose the input.
+
+ Returns:
+ A tuple consisting of the following two elements:
+ new_values: New values for the row/column factors.
+ update_op: An op that assigns the newly computed values to the row/column
+ factors.
+ """
+ assert isinstance(sp_input, ops.SparseTensor)
+
+ if update_row_factors:
+ left = self._row_factors
+ right = self._col_factors_cache
+ row_weights = self._row_wt_cache
+ col_weights = self._col_wt_cache
+ sharding_func = WALSModel._get_sharding_func(self._input_rows,
+ self._num_row_shards)
+ right_length = self._input_cols
+ else:
+ left = self._col_factors
+ right = self._row_factors_cache
+ row_weights = self._col_wt_cache
+ col_weights = self._row_wt_cache
+ sharding_func = WALSModel._get_sharding_func(self._input_cols,
+ self._num_col_shards)
+ right_length = self._input_rows
+ transpose_input = not transpose_input
+
+ # Note that the row indices of sp_input are based on the original full input
+ # Here we reindex the rows and give them contiguous ids starting at 0.
+ # We use tf.unique to achieve this reindexing. Note that this is done so
+ # that the downstream kernel can assume that the input is "dense" along the
+ # row dimension.
+ row_ids, col_ids = tf.split(1, 2, sp_input.indices)
+
+ if transpose_input:
+ update_indices, all_ids = tf.unique(col_ids[:, 0])
+ col_ids = tf.expand_dims(tf.cast(all_ids, tf.int64), 1)
+ else:
+ update_indices, all_ids = tf.unique(row_ids[:, 0])
+ row_ids = tf.expand_dims(tf.cast(all_ids, tf.int64), 1)
+
+ num_rows = tf.cast(tf.shape(update_indices)[0], tf.int64)
+ row_shape = tf.constant([right_length], tf.int64)
+ col_shape = [num_rows]
+
+ new_sp_indices = tf.concat(1, [row_ids, col_ids])
+ new_sp_shape = (tf.concat(0, [row_shape, col_shape]) if transpose_input
+ else tf.concat(0, [col_shape, row_shape]))
+ new_sp_input = tf.SparseTensor(indices=new_sp_indices,
+ values=sp_input.values, shape=new_sp_shape)
+
+ # Compute lhs and rhs of the normal equations
+ total_lhs = (self._unobserved_weight *
+ tf.matmul(right, right, transpose_a=True))
+ if self._regularization is not None:
+ total_lhs += self._regularization
+ if self._row_weights is None:
+ # Special case of ALS. Use a much simpler update rule.
+ total_rhs = (self._unobserved_weight *
+ tf.sparse_tensor_dense_matmul(new_sp_input, right,
+ adjoint_a=transpose_input))
+ # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
+ # transposing explicitly.
+ # TODO(rmlarsen): multi-thread tf.matrix_solve.
+ new_left_values = tf.transpose(tf.matrix_solve(total_lhs,
+ tf.transpose(total_rhs)))
+ else:
+ row_weights_slice = tf.gather(row_weights, update_indices)
+ partial_lhs, total_rhs = wals_compute_partial_lhs_and_rhs(
+ right,
+ col_weights,
+ self._unobserved_weight,
+ row_weights_slice,
+ new_sp_input.indices,
+ new_sp_input.values,
+ num_rows,
+ transpose_input,
+ name="wals_compute_partial_lhs_rhs")
+ total_lhs = tf.expand_dims(total_lhs, 0) + partial_lhs
+ total_rhs = tf.expand_dims(total_rhs, -1)
+ new_left_values = tf.squeeze(tf.batch_matrix_solve(total_lhs, total_rhs),
+ [2])
+
+ return (new_left_values,
+ self.scatter_update(left,
+ update_indices,
+ new_left_values,
+ sharding_func))
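
Putting the pieces of _process_input_helper together, one weighted row update solves, per row i, the normal equations (w_0 * V^T V + reg + partial_lhs_i) u_i = partial_rhs_i. Below is a dense NumPy sketch of a single row update under those assumptions, purely for illustration; the graph version batches this over every row in the input slice and works on sparse input.

import numpy as np

def update_one_row(a_row, V, w_0, reg, row_weight, col_weights):
  # V: column factors (n x k); a_row: dense row of A; reg: scalar lambda.
  k = V.shape[1]
  lhs = w_0 * V.T.dot(V) + reg * np.eye(k)
  rhs = np.zeros(k)
  for j in np.nonzero(a_row)[0]:
    w = row_weight * col_weights[j]
    lhs += w * np.outer(V[j], V[j])
    rhs += (w_0 + w) * a_row[j] * V[j]
  return np.linalg.solve(lhs, rhs)
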
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
new file mode 100644
index 0000000000..0ef49b7891
--- /dev/null
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
@@ -0,0 +1,456 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for factorization_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+from tensorflow.contrib.factorization.python.ops import factorization_ops
+
+INPUT_MATRIX = np.array(
+ [[0.1, 0.0, 0.2, 0.0, 0.4, 0.5, 0.0],
+ [0.0, 1.1, 0.0, 1.3, 1.4, 0.0, 1.6],
+ [2.0, 0.0, 0.0, 2.3, 0.0, 2.5, 0.0],
+ [3.0, 0.0, 3.2, 3.3, 0.0, 3.5, 0.0],
+ [0.0, 4.1, 0.0, 0.0, 4.4, 0.0, 4.6]]).astype(np.float32)
+
+
+def np_matrix_to_tf_sparse(np_matrix, row_slices=None,
+ col_slices=None, transpose=False,
+ shuffle=False):
+ """Simple util to slice non-zero np matrix elements as tf.SparseTensor."""
+ indices = np.nonzero(np_matrix)
+
+ # Only allow slices of whole rows or whole columns.
+ assert not (row_slices is not None and col_slices is not None)
+
+ if row_slices is not None:
+ selected_ind = np.concatenate(
+ [np.where(indices[0] == r)[0] for r in row_slices], 0)
+ indices = (indices[0][selected_ind], indices[1][selected_ind])
+
+ if col_slices is not None:
+ selected_ind = np.concatenate(
+ [np.where(indices[1] == c)[0] for c in col_slices], 0)
+ indices = (indices[0][selected_ind], indices[1][selected_ind])
+
+ if shuffle:
+ shuffled_ind = [x for x in range(len(indices[0]))]
+ random.shuffle(shuffled_ind)
+ indices = (indices[0][shuffled_ind], indices[1][shuffled_ind])
+
+ ind = (np.concatenate(
+ (np.expand_dims(indices[1], 1),
+ np.expand_dims(indices[0], 1)), 1).astype(np.int64) if transpose else
+ np.concatenate((np.expand_dims(indices[0], 1),
+ np.expand_dims(indices[1], 1)), 1).astype(np.int64))
+ val = np_matrix[indices].astype(np.float32)
+ shape = (np.array(
+ [max(indices[1]) + 1, max(indices[0]) + 1]).astype(np.int64) if transpose
+ else np.array(
+ [max(indices[0]) + 1, max(indices[1]) + 1]).astype(np.int64))
+ return tf.SparseTensor(ind, val, shape)
+
+
+def sparse_input():
+ return np_matrix_to_tf_sparse(INPUT_MATRIX)
+
+
+class WalsModelTest(tf.test.TestCase):
+
+ def setUp(self):
+ self.col_init = [
+ # shard 0
+ [[-0.36444709, -0.39077035, -0.32528427],
+ [1.19056475, 0.07231052, 2.11834812],
+ [0.93468881, -0.71099287, 1.91826844]],
+ # shard 1
+ [[1.18160152, 1.52490723, -0.50015002],
+ [1.82574749, -0.57515913, -1.32810032]],
+ # shard 2
+ [[-0.15515432, -0.84675711, 0.13097958],
+ [-0.9246484, 0.69117504, 1.2036494]]
+ ]
+
+ self.row_wts = [[0.1, 0.2, 0.3], [0.4, 0.5]]
+ self.col_wts = [[0.1, 0.2, 0.3],
+ [0.4, 0.5],
+ [0.6, 0.7]]
+ self._wals_inputs = sparse_input()
+
+ # Values of factor shards after running one iteration of row and column
+ # updates.
+ self._row_factors_0 = [[0.097689, -0.219293, -0.020780],
+ [0.50842, 0.64626, 0.22364],
+ [0.401159, -0.046558, -0.192854]]
+ self._row_factors_1 = [[1.20597, -0.48025, 0.35582],
+ [1.5564, 1.2528, 1.0528]]
+ self._col_factors_0 = [[2.4725, -1.2950, -1.9980],
+ [0.44625, 1.50771, 1.27118],
+ [1.39801, -2.10134, 0.73572]]
+ self._col_factors_1 = [[3.36509, -0.66595, -3.51208],
+ [0.57191, 1.59407, 1.33020]]
+ self._col_factors_2 = [[3.3459, -1.3341, -3.3008],
+ [0.57366, 1.83729, 1.26798]]
+
+ def test_process_input(self):
+ with self.test_session():
+ sp_feeder = tf.sparse_placeholder(tf.float32)
+ wals_model = factorization_ops.WALSModel(5, 7, 3,
+ num_row_shards=2,
+ num_col_shards=3,
+ regularization=0.01,
+ unobserved_weight=0.1,
+ col_init=self.col_init,
+ row_weights=self.row_wts,
+ col_weights=self.col_wts)
+
+ wals_model.initialize_op.run()
+ wals_model.worker_init.run()
+
+ # Split input into multiple sparse tensors with scattered rows. Note that
+ # this split can be different than the factor sharding and the inputs can
+ # consist of non-consecutive rows. Each row needs to include all non-zero
+ # elements in that row.
+ sp_r0 = np_matrix_to_tf_sparse(INPUT_MATRIX, [0, 2]).eval()
+ sp_r1 = np_matrix_to_tf_sparse(INPUT_MATRIX, [1, 4], shuffle=True).eval()
+ sp_r2 = np_matrix_to_tf_sparse(INPUT_MATRIX, [3], shuffle=True).eval()
+ input_scattered_rows = [sp_r0, sp_r1, sp_r2]
+
+ # Test updating row factors.
+ # Here we feed in scattered rows of the input.
+ wals_model.initialize_row_update_op.run()
+ process_input_op = wals_model.update_row_factors(sp_input=sp_feeder,
+ transpose_input=False)[1]
+ for inp in input_scattered_rows:
+ feed_dict = {sp_feeder: inp}
+ process_input_op.run(feed_dict=feed_dict)
+ row_factors = [x.eval() for x in wals_model.row_factors]
+
+ self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3)
+ self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3)
+
+      # Split the input into multiple sparse tensors with scattered columns.
+      # Note that here the elements in the sparse tensors are not ordered and
+      # need not consist of consecutive columns. However, each column needs to
+      # include all non-zero elements in that column.
+ sp_c0 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[2, 0]).eval()
+ sp_c1 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[5, 3, 1],
+ shuffle=True).eval()
+ sp_c2 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[4, 6]).eval()
+ sp_c3 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[3, 6],
+ shuffle=True).eval()
+
+ input_scattered_cols = [sp_c0, sp_c1, sp_c2, sp_c3]
+
+ # Test updating column factors.
+ # Here we feed in scattered columns of the input.
+ wals_model.initialize_col_update_op.run()
+ process_input_op = wals_model.update_col_factors(sp_input=sp_feeder,
+ transpose_input=False)[1]
+ for inp in input_scattered_cols:
+ feed_dict = {sp_feeder: inp}
+ process_input_op.run(feed_dict=feed_dict)
+ col_factors = [x.eval() for x in wals_model.col_factors]
+
+ self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3)
+ self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3)
+ self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3)
+
+ def test_process_input_transposed(self):
+ with self.test_session():
+ sp_feeder = tf.sparse_placeholder(tf.float32)
+ wals_model = factorization_ops.WALSModel(5, 7, 3,
+ num_row_shards=2,
+ num_col_shards=3,
+ regularization=0.01,
+ unobserved_weight=0.1,
+ col_init=self.col_init,
+ row_weights=self.row_wts,
+ col_weights=self.col_wts)
+
+ wals_model.initialize_op.run()
+ wals_model.worker_init.run()
+
+ # Split input into multiple SparseTensors with scattered rows.
+ # Here the inputs are transposed. But the same constraints as described in
+ # the previous non-transposed test case apply to these inputs (before they
+ # are transposed).
+ sp_r0_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [0, 3],
+ transpose=True).eval()
+ sp_r1_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [4, 1],
+ shuffle=True, transpose=True).eval()
+ sp_r2_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [2], transpose=True).eval()
+ sp_r3_t = sp_r1_t
+ input_scattered_rows = [sp_r0_t, sp_r1_t, sp_r2_t, sp_r3_t]
+
+ # Test updating row factors.
+ # Here we feed in scattered rows of the input.
+      # Note that the numeric suffixes of the placeholders follow the
+      # lexicographical order of the test case names and then the order in
+      # which the placeholders appear within each test.
+ wals_model.initialize_row_update_op.run()
+ process_input_op = wals_model.update_row_factors(sp_input=sp_feeder,
+ transpose_input=True)[1]
+ for inp in input_scattered_rows:
+ feed_dict = {sp_feeder: inp}
+ process_input_op.run(feed_dict=feed_dict)
+ row_factors = [x.eval() for x in wals_model.row_factors]
+
+ self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3)
+ self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3)
+
+ # Split input into multiple SparseTensors with scattered columns.
+ # Here the inputs are transposed. But the same constraints as described in
+ # the previous non-transposed test case apply to these inputs (before they
+ # are transposed).
+ sp_c0_t = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[0, 1],
+ transpose=True).eval()
+ sp_c1_t = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[4, 2],
+ transpose=True).eval()
+ sp_c2_t = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[5],
+ transpose=True, shuffle=True).eval()
+ sp_c3_t = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[3, 6],
+ transpose=True).eval()
+
+ sp_c4_t = sp_c2_t
+ input_scattered_cols = [sp_c0_t, sp_c1_t, sp_c2_t, sp_c3_t,
+ sp_c4_t]
+
+ # Test updating column factors.
+ # Here we feed in scattered columns of the input.
+ wals_model.initialize_col_update_op.run()
+ process_input_op = wals_model.update_col_factors(sp_input=sp_feeder,
+ transpose_input=True)[1]
+ for inp in input_scattered_cols:
+ feed_dict = {sp_feeder: inp}
+ process_input_op.run(feed_dict=feed_dict)
+ col_factors = [x.eval() for x in wals_model.col_factors]
+
+ self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3)
+ self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3)
+ self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3)
+
+  # Note that when row_weights and col_weights are 0, WALS gives results
+  # identical to ALS (Alternating Least Squares). However, our implementation
+  # does not special-case zero weights. Instead, when row_weights and
+  # col_weights are set to None, we interpret that as the ALS case and trigger
+  # the more efficient ALS updates.
+  # Here we test that the two give identical results.
+ def test_als(self):
+ with self.test_session():
+ col_init = np.random.rand(7, 3)
+ als_model = factorization_ops.WALSModel(5, 7, 3,
+ col_init=col_init,
+ row_weights=None,
+ col_weights=None)
+
+ als_model.initialize_op.run()
+ als_model.worker_init.run()
+ als_model.initialize_row_update_op.run()
+ process_input_op = als_model.update_row_factors(self._wals_inputs)[1]
+ process_input_op.run()
+ row_factors1 = [x.eval() for x in als_model.row_factors]
+
+ wals_model = factorization_ops.WALSModel(5, 7, 3,
+ col_init=col_init,
+ row_weights=[0] * 5,
+ col_weights=[0] * 7)
+ wals_model.initialize_op.run()
+ wals_model.worker_init.run()
+ wals_model.initialize_row_update_op.run()
+ process_input_op = wals_model.update_row_factors(self._wals_inputs)[1]
+ process_input_op.run()
+ row_factors2 = [x.eval() for x in wals_model.row_factors]
+
+ for r1, r2 in zip(row_factors1, row_factors2):
+ self.assertAllClose(r1, r2, atol=1e-3)
+
+ # Here we test partial column updates.
+ sp_c = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[2, 0],
+ shuffle=True).eval()
+
+ sp_feeder = tf.sparse_placeholder(tf.float32)
+ feed_dict = {sp_feeder: sp_c}
+ als_model.initialize_col_update_op.run()
+ process_input_op = als_model.update_col_factors(sp_input=sp_feeder)[1]
+ process_input_op.run(feed_dict=feed_dict)
+ col_factors1 = [x.eval() for x in als_model.col_factors]
+
+ feed_dict = {sp_feeder: sp_c}
+ wals_model.initialize_col_update_op.run()
+ process_input_op = wals_model.update_col_factors(sp_input=sp_feeder)[1]
+ process_input_op.run(feed_dict=feed_dict)
+ col_factors2 = [x.eval() for x in wals_model.col_factors]
+
+ for c1, c2 in zip(col_factors1, col_factors2):
+ self.assertAllClose(c1, c2, atol=1e-3)
+
+ def test_als_transposed(self):
+ with self.test_session():
+ col_init = np.random.rand(7, 3)
+ als_model = factorization_ops.WALSModel(5, 7, 3,
+ col_init=col_init,
+ row_weights=None,
+ col_weights=None)
+
+ als_model.initialize_op.run()
+ als_model.worker_init.run()
+
+ wals_model = factorization_ops.WALSModel(5, 7, 3,
+ col_init=col_init,
+ row_weights=[0] * 5,
+ col_weights=[0] * 7)
+ wals_model.initialize_op.run()
+ wals_model.worker_init.run()
+ sp_feeder = tf.sparse_placeholder(tf.float32)
+      # Here we test a partial row update with identical inputs, but with the
+      # input transposed for ALS.
+ sp_r_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [3, 1],
+ transpose=True).eval()
+ sp_r = np_matrix_to_tf_sparse(INPUT_MATRIX, [3, 1]).eval()
+
+ feed_dict = {sp_feeder: sp_r_t}
+ als_model.initialize_row_update_op.run()
+ process_input_op = als_model.update_row_factors(sp_input=sp_feeder,
+ transpose_input=True)[1]
+ process_input_op.run(feed_dict=feed_dict)
+      # Only rows 1 and 3 were updated, so compare only these rows; the others
+      # still have randomly initialized values.
+ row_factors1 = [als_model.row_factors[0].eval()[1],
+ als_model.row_factors[0].eval()[3]]
+
+ feed_dict = {sp_feeder: sp_r}
+ wals_model.initialize_row_update_op.run()
+ process_input_op = wals_model.update_row_factors(sp_input=sp_feeder)[1]
+ process_input_op.run(feed_dict=feed_dict)
+      # Only rows 1 and 3 were updated, so compare only these rows; the others
+      # still have randomly initialized values.
+ row_factors2 = [wals_model.row_factors[0].eval()[1],
+ wals_model.row_factors[0].eval()[3]]
+ for r1, r2 in zip(row_factors1, row_factors2):
+ self.assertAllClose(r1, r2, atol=1e-3)
+
+ def simple_train(self,
+ model,
+ inp,
+ num_iterations):
+ """Helper function to train model on inp for num_iterations."""
+ row_update_op = model.update_row_factors(sp_input=inp)[1]
+ col_update_op = model.update_col_factors(sp_input=inp)[1]
+
+ model.initialize_op.run()
+ model.worker_init.run()
+ for _ in xrange(num_iterations):
+ model.initialize_row_update_op.run()
+ row_update_op.run()
+ model.initialize_col_update_op.run()
+ col_update_op.run()
+
+  # Trains an ALS model on a low-rank matrix and makes sure the product of the
+  # factors is close to the original input.
+ def test_train_full_low_rank_als(self):
+ rows = 15
+ cols = 11
+ dims = 3
+ with self.test_session():
+ data = np.dot(np.random.rand(rows, 3),
+ np.random.rand(3, cols)).astype(np.float32) / 3.0
+ indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
+ values = data.reshape(-1)
+ inp = tf.SparseTensor(indices, values, [rows, cols])
+ model = factorization_ops.WALSModel(rows, cols, dims,
+ regularization=1e-5,
+ row_weights=None,
+ col_weights=None)
+ self.simple_train(model, inp, 15)
+ row_factor = model.row_factors[0].eval()
+ col_factor = model.col_factors[0].eval()
+ self.assertAllClose(data,
+ np.dot(row_factor, np.transpose(col_factor)),
+ rtol=0.01, atol=0.01)
+
+  # Trains a WALS model on a low-rank matrix and makes sure the product of the
+  # factors is close to the original input.
+ def test_train_full_low_rank_wals(self):
+ rows = 15
+ cols = 11
+ dims = 3
+
+ with self.test_session():
+ data = np.dot(np.random.rand(rows, 3),
+ np.random.rand(3, cols)).astype(np.float32) / 3.0
+ indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
+ values = data.reshape(-1)
+ inp = tf.SparseTensor(indices, values, [rows, cols])
+ model = factorization_ops.WALSModel(rows, cols, dims,
+ regularization=1e-5,
+ row_weights=[0] * rows,
+ col_weights=[0] * cols)
+ self.simple_train(model, inp, 15)
+ row_factor = model.row_factors[0].eval()
+ col_factor = model.col_factors[0].eval()
+ self.assertAllClose(data,
+ np.dot(row_factor, np.transpose(col_factor)),
+ rtol=0.01, atol=0.01)
+
+ # Trains a WALS model for a partially observed low-rank matrix and makes
+ # sure the product of factors is reasonably close to the original input.
+ def test_train_matrix_completion_wals(self):
+ rows = 11
+ cols = 9
+ dims = 4
+ def keep_index(x):
+ return not (x[0] + x[1]) % 4
+
+ with self.test_session():
+ row_wts = 0.1 + np.random.rand(rows)
+ col_wts = 0.1 + np.random.rand(cols)
+ data = np.dot(np.random.rand(rows, 3),
+ np.random.rand(3, cols)).astype(np.float32) / 3.0
+ indices = np.array(
+ list(filter(keep_index,
+ [[i, j] for i in xrange(rows) for j in xrange(cols)])))
+ values = data[indices[:, 0], indices[:, 1]]
+ inp = tf.SparseTensor(indices, values, [rows, cols])
+ model = factorization_ops.WALSModel(rows, cols, dims,
+ unobserved_weight=0.01,
+ regularization=0.001,
+ row_weights=row_wts,
+ col_weights=col_wts)
+ self.simple_train(model, inp, 10)
+ row_factor = model.row_factors[0].eval()
+ col_factor = model.col_factors[0].eval()
+ out = np.dot(row_factor, np.transpose(col_factor))
+ for i in xrange(rows):
+ for j in xrange(cols):
+ if keep_index([i, j]):
+ self.assertNear(data[i][j], out[i][j],
+ err=0.2, msg="%d, %d" % (i, j))
+ else:
+ self.assertNear(0, out[i][j], err=0.5, msg="%d, %d" % (i, j))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py
new file mode 100644
index 0000000000..4504afe289
--- /dev/null
+++ b/tensorflow/contrib/factorization/python/ops/kmeans.py
@@ -0,0 +1,238 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Implementation of k-means clustering on top of learn (aka skflow) API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import tensorflow as tf
+
+from tensorflow.contrib.factorization.python.ops import clustering_ops
+from tensorflow.contrib.learn.python.learn.estimators import estimator
+from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
+from tensorflow.contrib.learn.python.learn.io import data_feeder
+from tensorflow.contrib.learn.python.learn.utils import checkpoints
+from tensorflow.python.ops.control_flow_ops import with_dependencies
+
+SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE
+COSINE_DISTANCE = clustering_ops.COSINE_DISTANCE
+RANDOM_INIT = clustering_ops.RANDOM_INIT
+KMEANS_PLUS_PLUS_INIT = clustering_ops.KMEANS_PLUS_PLUS_INIT
+
+
+# TODO(agarwal,ands): support sharded input.
+# TODO(agarwal,ands): enable stopping criteria based on improvements to cost.
+# TODO(agarwal,ands): support random restarts.
+class KMeansClustering(estimator.Estimator,
+ TransformerMixin):
+ """K-Means clustering."""
+ SCORES = 'scores'
+ CLUSTER_IDX = 'cluster_idx'
+ CLUSTERS = 'clusters'
+ ALL_SCORES = 'all_scores'
+
+ def __init__(self,
+ num_clusters,
+ model_dir=None,
+ initial_clusters=clustering_ops.RANDOM_INIT,
+ distance_metric=clustering_ops.SQUARED_EUCLIDEAN_DISTANCE,
+ random_seed=0,
+ use_mini_batch=True,
+ batch_size=128,
+ steps=10,
+ kmeans_plus_plus_num_retries=2,
+ continue_training=False,
+ config=None,
+ verbose=1):
+ """Creates a model for running KMeans training and inference.
+
+ Args:
+ num_clusters: number of clusters to train.
+ model_dir: the directory to save the model results and log files.
+ initial_clusters: specifies how to initialize the clusters for training.
+ See clustering_ops.kmeans for the possible values.
+ distance_metric: the distance metric used for clustering.
+ See clustering_ops.kmeans for the possible values.
+ random_seed: Python integer. Seed for PRNG used to initialize centers.
+      use_mini_batch: If true, use the mini-batch k-means algorithm. Otherwise,
+        assume full batch.
+ batch_size: See TensorFlowEstimator
+ steps: See TensorFlowEstimator
+ kmeans_plus_plus_num_retries: For each point that is sampled during
+ kmeans++ initialization, this parameter specifies the number of
+ additional points to draw from the current distribution before selecting
+ the best. If a negative value is specified, a heuristic is used to
+ sample O(log(num_to_sample)) additional points.
+ continue_training: See TensorFlowEstimator
+ config: See TensorFlowEstimator
+ verbose: See TensorFlowEstimator
+ """
+ super(KMeansClustering, self).__init__(
+ model_dir=model_dir,
+ config=config)
+ self.batch_size = batch_size
+ self.steps = steps
+ self.kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
+ self.continue_training = continue_training
+ self.verbose = verbose
+ self._num_clusters = num_clusters
+ self._training_initial_clusters = initial_clusters
+ self._training_graph = None
+ self._distance_metric = distance_metric
+ self._use_mini_batch = use_mini_batch
+ self._random_seed = random_seed
+ self._initialized = False
+
+ def fit(self, x, y=None, monitors=None, logdir=None, steps=None):
+ """Trains a k-means clustering on x.
+
+    Note: See TensorFlowEstimator for the logic governing continuous training
+    and graph construction across multiple calls to fit.
+
+ Args:
+ x: training input matrix of shape [n_samples, n_features].
+ y: labels. Should be None.
+      monitors: Monitor objects used to print training progress and invoke
+        early stopping.
+ logdir: the directory to save the log file that can be used for optional
+ visualization.
+ steps: number of training steps. If not None, overrides the value passed
+        to the constructor.
+
+ Returns:
+ Returns self.
+ """
+ assert y is None
+ if logdir is not None:
+ self._model_dir = logdir
+ self._data_feeder = data_feeder.setup_train_data_feeder(
+ x, None, self._num_clusters, self.batch_size)
+ self._train_model(input_fn=self._data_feeder.input_builder,
+ feed_fn=self._data_feeder.get_feed_dict_fn(),
+ steps=steps or self.steps,
+ monitors=monitors,
+ init_feed_fn=self._data_feeder.get_feed_dict_fn())
+ return self
+
+ def predict(self, x, batch_size=None):
+ """Predict cluster id for each element in x.
+
+ Args:
+ x: 2-D matrix or iterator.
+ batch_size: size to use for batching up x for querying the model.
+
+ Returns:
+      Array with the same number of rows as x, containing cluster ids.
+ """
+ return super(KMeansClustering, self).predict(
+ x=x, batch_size=batch_size)[KMeansClustering.CLUSTER_IDX]
+
+ def score(self, x, batch_size=None):
+ """Predict total sum of distances to nearest clusters.
+
+    Note that this function is different from the corresponding one in sklearn,
+ which returns the negative of the sum of distances.
+
+ Args:
+ x: 2-D matrix or iterator.
+ batch_size: size to use for batching up x for querying the model.
+
+ Returns:
+ Total sum of distances to nearest clusters.
+ """
+ return np.sum(
+ self.evaluate(x=x, batch_size=batch_size)[KMeansClustering.SCORES])
+
+ def transform(self, x, batch_size=None):
+ """Transforms each element in x to distances to cluster centers.
+
+ Note that this function is different from the corresponding one in sklearn.
+ For SQUARED_EUCLIDEAN distance metric, sklearn transform returns the
+ EUCLIDEAN distance, while this function returns the SQUARED_EUCLIDEAN
+ distance.
+
+ Args:
+ x: 2-D matrix or iterator.
+ batch_size: size to use for batching up x for querying the model.
+
+ Returns:
+      Array with the same number of rows as x and num_clusters columns,
+      containing distances to the cluster centers.
+ """
+ return super(KMeansClustering, self).predict(
+ x=x, batch_size=batch_size)[KMeansClustering.ALL_SCORES]
+
+ def clusters(self):
+ """Returns cluster centers."""
+ return checkpoints.load_variable(self.model_dir, self.CLUSTERS)
+
+ def _get_train_ops(self, features, _):
+ (_,
+ _,
+ losses,
+ training_op) = clustering_ops.KMeans(
+ features,
+ self._num_clusters,
+ self._training_initial_clusters,
+ self._distance_metric,
+ self._use_mini_batch,
+ random_seed=self._random_seed,
+ kmeans_plus_plus_num_retries=self.kmeans_plus_plus_num_retries
+ ).training_graph()
+ incr_step = tf.assign_add(tf.contrib.framework.get_global_step(), 1)
+ loss = tf.reduce_sum(losses)
+ training_op = with_dependencies([training_op, incr_step], loss)
+ return training_op, loss
+
+ def _get_predict_ops(self, features):
+ (all_scores,
+ model_predictions,
+ _,
+ _) = clustering_ops.KMeans(
+ features,
+ self._num_clusters,
+ self._training_initial_clusters,
+ self._distance_metric,
+ self._use_mini_batch,
+ random_seed=self._random_seed,
+ kmeans_plus_plus_num_retries=self.kmeans_plus_plus_num_retries
+ ).training_graph()
+ return {
+ KMeansClustering.ALL_SCORES: all_scores[0],
+ KMeansClustering.CLUSTER_IDX: model_predictions[0]
+ }
+
+ def _get_eval_ops(self, features, _, unused_metrics):
+ (_,
+ _,
+ losses,
+ _) = clustering_ops.KMeans(
+ features,
+ self._num_clusters,
+ self._training_initial_clusters,
+ self._distance_metric,
+ self._use_mini_batch,
+ random_seed=self._random_seed,
+ kmeans_plus_plus_num_retries=self.kmeans_plus_plus_num_retries
+ ).training_graph()
+ return {
+ KMeansClustering.SCORES: tf.reduce_sum(losses),
+ }
+
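+# A minimal usage sketch (illustrative only, not part of the module API). It
+# assumes `points` is a numpy array of shape [n_samples, n_features]:
+#
+#   kmeans = KMeansClustering(num_clusters=3, use_mini_batch=False)
+#   kmeans.fit(x=points)
+#   assignments = kmeans.predict(points)   # cluster index for each row
+#   centers = kmeans.clusters()            # [num_clusters, n_features] centers
+#   distances = kmeans.transform(points)   # (squared) distances to each center
+#   total = kmeans.score(points)           # sum of distances to nearest center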
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
new file mode 100644
index 0000000000..b4ade5ea0b
--- /dev/null
+++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
@@ -0,0 +1,283 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for KMeans."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_ops
+from tensorflow.contrib.factorization.python.ops.kmeans import KMeansClustering as KMeans
+from tensorflow.contrib.learn.python.learn.estimators import run_config
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def normalize(x):
+ return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True))
+
+
+def cosine_similarity(x, y):
+ return np.dot(normalize(x), np.transpose(normalize(y)))
+
+
+class KMeansTest(tf.test.TestCase):
+
+ def setUp(self):
+ np.random.seed(3)
+ self.num_centers = 5
+ self.num_dims = 2
+ self.num_points = 10000
+ self.true_centers = self.make_random_centers(self.num_centers,
+ self.num_dims)
+ self.points, _, self.scores = self.make_random_points(
+ self.true_centers,
+ self.num_points)
+ self.true_score = np.add.reduce(self.scores)
+
+ self.kmeans = KMeans(self.num_centers,
+ initial_clusters=kmeans_ops.RANDOM_INIT,
+ batch_size=self.batch_size,
+ use_mini_batch=self.use_mini_batch,
+ steps=30,
+ continue_training=True,
+ config=run_config.RunConfig(tf_random_seed=14),
+ random_seed=12)
+
+ @property
+ def batch_size(self):
+ return self.num_points
+
+ @property
+ def use_mini_batch(self):
+ return False
+
+ @staticmethod
+ def make_random_centers(num_centers, num_dims):
+ return np.round(np.random.rand(num_centers,
+ num_dims).astype(np.float32) * 500)
+
+ @staticmethod
+ def make_random_points(centers, num_points, max_offset=20):
+ num_centers, num_dims = centers.shape
+ assignments = np.random.choice(num_centers, num_points)
+ offsets = np.round(np.random.randn(
+ num_points,
+ num_dims).astype(np.float32) * max_offset)
+ return (centers[assignments] + offsets,
+ assignments,
+ np.add.reduce(offsets * offsets, 1))
+
+ def test_clusters(self):
+ kmeans = self.kmeans
+ kmeans.fit(x=self.points, steps=0)
+ clusters = kmeans.clusters()
+ self.assertAllEqual(list(clusters.shape),
+ [self.num_centers, self.num_dims])
+
+ def test_fit(self):
+ if self.batch_size != self.num_points:
+ # TODO(agarwal): Doesn't work with mini-batch.
+ return
+ kmeans = self.kmeans
+ kmeans.fit(x=self.points,
+ steps=1)
+ score1 = kmeans.score(x=self.points)
+ kmeans.fit(x=self.points,
+ steps=15 * self.num_points // self.batch_size)
+ score2 = kmeans.score(x=self.points)
+ self.assertTrue(score1 > score2)
+ self.assertNear(self.true_score, score2, self.true_score * 0.05)
+
+ def test_infer(self):
+ kmeans = self.kmeans
+ kmeans.fit(x=self.points)
+ clusters = kmeans.clusters()
+
+ # Make a small test set
+ points, true_assignments, true_offsets = self.make_random_points(clusters,
+ 10)
+ # Test predict
+ assignments = kmeans.predict(points)
+ self.assertAllEqual(assignments, true_assignments)
+
+ # Test score
+ score = kmeans.score(points)
+ self.assertNear(score, np.sum(true_offsets), 0.01 * score)
+
+ # Test transform
+ transform = kmeans.transform(points)
+ true_transform = np.maximum(
+ 0,
+ np.sum(np.square(points), axis=1, keepdims=True) -
+ 2 * np.dot(points, np.transpose(clusters)) +
+ np.transpose(np.sum(np.square(clusters), axis=1, keepdims=True)))
+ self.assertAllClose(transform, true_transform, rtol=0.05, atol=10)
+
+ def test_fit_with_cosine_distance(self):
+    # Create points on the lines y=x and y=1.5x to check the cosine similarity.
+    # Note that Euclidean distance would give different results in this case.
+ points = np.array([[9, 9], [0.5, 0.5], [10, 15], [0.4, 0.6]])
+    # The true centers are the unit vectors on the lines y=x and y=1.5x.
+ true_centers = np.array([[0.70710678, 0.70710678], [0.5547002, 0.83205029]])
+ kmeans = KMeans(2,
+ initial_clusters=kmeans_ops.RANDOM_INIT,
+ distance_metric=kmeans_ops.COSINE_DISTANCE,
+ use_mini_batch=self.use_mini_batch,
+ batch_size=4,
+ steps=30,
+ continue_training=True,
+ config=run_config.RunConfig(tf_random_seed=2),
+ random_seed=12)
+ kmeans.fit(x=points)
+ centers = normalize(kmeans.clusters())
+ self.assertAllClose(np.sort(centers, axis=0),
+ np.sort(true_centers, axis=0))
+
+ def test_transform_with_cosine_distance(self):
+ points = np.array([[2.5, 3.5], [2, 8], [3, 1], [3, 18],
+ [-2.5, -3.5], [-2, -8], [-3, -1], [-3, -18]])
+
+ true_centers = [normalize(np.mean(normalize(points)[4:, :], axis=0,
+ keepdims=True))[0],
+ normalize(np.mean(normalize(points)[0:4, :], axis=0,
+ keepdims=True))[0]]
+
+ kmeans = KMeans(2,
+ initial_clusters=kmeans_ops.RANDOM_INIT,
+ distance_metric=kmeans_ops.COSINE_DISTANCE,
+ use_mini_batch=self.use_mini_batch,
+ batch_size=8,
+ continue_training=True,
+ config=run_config.RunConfig(tf_random_seed=3))
+ kmeans.fit(x=points, steps=30)
+
+ centers = normalize(kmeans.clusters())
+ self.assertAllClose(np.sort(centers, axis=0),
+ np.sort(true_centers, axis=0),
+ atol=1e-2)
+
+ true_transform = 1 - cosine_similarity(points, centers)
+ transform = kmeans.transform(points)
+ self.assertAllClose(transform, true_transform, atol=1e-3)
+
+ def test_predict_with_cosine_distance(self):
+ points = np.array([[2.5, 3.5], [2, 8], [3, 1], [3, 18],
+ [-2.5, -3.5], [-2, -8], [-3, -1], [-3, -18]]).astype(
+ np.float32)
+ true_centers = np.array(
+ [normalize(np.mean(normalize(points)[0:4, :],
+ axis=0,
+ keepdims=True))[0],
+ normalize(np.mean(normalize(points)[4:, :],
+ axis=0,
+ keepdims=True))[0]])
+ true_assignments = [0] * 4 + [1] * 4
+ true_score = len(points) - np.tensordot(normalize(points),
+ true_centers[true_assignments])
+
+ kmeans = KMeans(2,
+ initial_clusters=kmeans_ops.RANDOM_INIT,
+ distance_metric=kmeans_ops.COSINE_DISTANCE,
+ use_mini_batch=self.use_mini_batch,
+ batch_size=8,
+ continue_training=True,
+ config=run_config.RunConfig(tf_random_seed=3))
+ kmeans.fit(x=points, steps=30)
+
+ centers = normalize(kmeans.clusters())
+ self.assertAllClose(np.sort(centers, axis=0),
+ np.sort(true_centers, axis=0), rtol=1e-3)
+
+ assignments = kmeans.predict(points)
+ self.assertAllClose(centers[assignments],
+ true_centers[true_assignments], rtol=1e-3)
+
+ score = kmeans.score(points)
+ self.assertAllClose(score, true_score)
+
+ def test_predict_with_cosine_distance_and_kmeans_plus_plus(self):
+    # Most points are concentrated near one center. KMeans++ is likely to find
+    # the less populated centers.
+ points = np.array([[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3],
+ [-3.1, -3.2], [-2.8, -3.], [-2.9, -3.1], [-3., -3.1],
+ [-3., -3.1], [-3.2, -3.], [-3., -3.]]).astype(np.float32)
+ true_centers = np.array(
+ [normalize(np.mean(normalize(points)[0:2, :], axis=0,
+ keepdims=True))[0],
+ normalize(np.mean(normalize(points)[2:4, :], axis=0,
+ keepdims=True))[0],
+ normalize(np.mean(normalize(points)[4:, :], axis=0,
+ keepdims=True))[0]])
+ true_assignments = [0] * 2 + [1] * 2 + [2] * 8
+ true_score = len(points) - np.tensordot(normalize(points),
+ true_centers[true_assignments])
+
+ kmeans = KMeans(3,
+ initial_clusters=kmeans_ops.KMEANS_PLUS_PLUS_INIT,
+ distance_metric=kmeans_ops.COSINE_DISTANCE,
+ use_mini_batch=self.use_mini_batch,
+ batch_size=12,
+ continue_training=True,
+ config=run_config.RunConfig(tf_random_seed=3))
+ kmeans.fit(x=points, steps=30)
+
+ centers = normalize(kmeans.clusters())
+ self.assertAllClose(sorted(centers.tolist()),
+ sorted(true_centers.tolist()),
+ rtol=1e-3)
+
+ assignments = kmeans.predict(points)
+ self.assertAllClose(centers[assignments],
+ true_centers[true_assignments], rtol=1e-3)
+
+ score = kmeans.score(points)
+ self.assertAllClose(score, true_score)
+
+ def test_fit_raise_if_num_clusters_larger_than_num_points_random_init(self):
+ points = np.array([[2.0, 3.0], [1.6, 8.2]])
+
+ with self.assertRaisesOpError('less'):
+ kmeans = KMeans(num_clusters=3, initial_clusters=kmeans_ops.RANDOM_INIT)
+ kmeans.fit(x=points)
+
+ def test_fit_raise_if_num_clusters_larger_than_num_points_kmeans_plus_plus(
+ self):
+ points = np.array([[2.0, 3.0], [1.6, 8.2]])
+
+ with self.assertRaisesOpError(AssertionError):
+ kmeans = KMeans(num_clusters=3,
+ initial_clusters=kmeans_ops.KMEANS_PLUS_PLUS_INIT)
+ kmeans.fit(x=points)
+
+
+class MiniBatchKMeansTest(KMeansTest):
+
+ @property
+ def batch_size(self):
+ return 50
+
+ @property
+ def use_mini_batch(self):
+ return True
+
+
+if __name__ == '__main__':
+ tf.test.main()