author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-07-01 17:05:28 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-07-01 18:18:15 -0700
commit | 23fdab705d7caed3ff3d955b39c3cfc3f5e40678 (patch)
tree | 7cd484755b9bc5005ef5d9b3e80726663bdf1bdb
parent | 281cd0ae22f05275fcfe1fd8176ebd3769f80043 (diff)
Add K-Means clustering and WALS matrix factorization to tensorflow.
Change: 126465430
21 files changed, 3744 insertions, 0 deletions
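For orientation before the diff: the new package is driven through the clustering_ops module added below. A minimal sketch of the intended usage, modeled on the examples/mnist.py file in this commit (the KMeans constructor arguments and the four return values of training_graph() are taken from that example; the toy input and loop are illustrative assumptions):

  import tensorflow as tf
  from tensorflow.contrib.factorization.python.ops import clustering_ops

  points = tf.random_uniform([1000, 10])  # toy input, not from the commit
  kmeans = clustering_ops.KMeans(
      points,
      num_clusters=8,
      distance_metric=clustering_ops.COSINE_DISTANCE,
      use_mini_batch=True)
  # training_graph() returns four values, as used in examples/mnist.py below.
  all_scores, _, clustering_scores, kmeans_training_op = kmeans.training_graph()

  with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for _ in range(10):
      sess.run(kmeans_training_op)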
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 763ae340dd..68f14676ac 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -67,6 +67,8 @@ filegroup(
         "//tensorflow/contrib:all_files",
         "//tensorflow/contrib/copy_graph:all_files",
         "//tensorflow/contrib/distributions:all_files",
+        "//tensorflow/contrib/factorization:all_files",
+        "//tensorflow/contrib/factorization/kernels:all_files",
         "//tensorflow/contrib/ffmpeg:all_files",
         "//tensorflow/contrib/ffmpeg/default:all_files",
         "//tensorflow/contrib/framework:all_files",
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 2702ddde01..8b3761ae0f 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -16,6 +16,7 @@ py_library(
         "//tensorflow/contrib/bayesflow:bayesflow_py",
         "//tensorflow/contrib/copy_graph:copy_graph_py",
         "//tensorflow/contrib/distributions:distributions_py",
+        "//tensorflow/contrib/factorization:factorization_py",
         "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
         "//tensorflow/contrib/framework:framework_py",
         "//tensorflow/contrib/graph_editor:graph_editor_py",
@@ -40,6 +41,7 @@ cc_library(
     name = "contrib_kernels",
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow/contrib/factorization/kernels:all_kernels",
         "//tensorflow/contrib/layers:bucketization_op_kernel",
         "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
         "//tensorflow/contrib/linear_optimizer:sdca_op_kernels",
@@ -51,6 +53,7 @@ cc_library(
     name = "contrib_ops_op_lib",
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow/contrib/factorization:all_ops",
         "//tensorflow/contrib/layers:bucketization_op_op_lib",
         "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
         "//tensorflow/contrib/linear_optimizer:sdca_ops_op_lib",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index e71adbb5bc..6ec786430f 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -22,6 +22,7 @@ from __future__ import print_function
 from tensorflow.contrib import bayesflow
 from tensorflow.contrib import copy_graph
 from tensorflow.contrib import distributions
+from tensorflow.contrib import factorization
 from tensorflow.contrib import framework
 from tensorflow.contrib import graph_editor
 from tensorflow.contrib import grid_rnn
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
new file mode 100644
index 0000000000..4be7a966c6
--- /dev/null
+++ b/tensorflow/contrib/factorization/BUILD
@@ -0,0 +1,135 @@
+# Description:
+#   Contains ops for factorization of data, including matrix factorization and
+#   clustering.
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+py_library(
+    name = "factorization_py",
+    srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+    data = [
+        ":python/ops/_clustering_ops.so",
+        ":python/ops/_factorization_ops.so",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":gen_clustering_ops",
+        ":gen_factorization_ops",
+    ],
+)
+
+# Ops
+tf_custom_op_library(
+    name = "python/ops/_clustering_ops.so",
+    srcs = [
+        "ops/clustering_ops.cc",
+    ],
+    deps = [
+        "//tensorflow/contrib/factorization/kernels:clustering_ops",
+    ],
+)
+
+tf_custom_op_library(
+    name = "python/ops/_factorization_ops.so",
+    srcs = [
+        "ops/factorization_ops.cc",
+    ],
+    deps = [
+        "//tensorflow/contrib/factorization/kernels:wals_solver_ops",
+    ],
+)
+
+tf_gen_op_libs([
+    "clustering_ops",
+    "factorization_ops",
+])
+
+cc_library(
+    name = "all_ops",
+    deps = [
+        ":clustering_ops_op_lib",
+        ":factorization_ops_op_lib",
+    ],
+)
+
+tf_gen_op_wrapper_py(
+    name = "gen_clustering_ops",
+    out = "python/ops/gen_clustering_ops.py",
+    deps = [
+        ":clustering_ops_op_lib",
+    ],
+)
+
+tf_gen_op_wrapper_py(
+    name = "gen_factorization_ops",
+    out = "python/ops/gen_factorization_ops.py",
+    deps = [
+        ":factorization_ops_op_lib",
+    ],
+)
+
+# Ops tests
+tf_py_test(
+    name = "kmeans_test",
+    srcs = [
+        "python/ops/kmeans_test.py",
+    ],
+    additional_deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+tf_py_test(
+    name = "factorization_ops_test",
+    srcs = ["python/ops/factorization_ops_test.py"],
+    additional_deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+# Kernel tests
+tf_py_test(
+    name = "wals_solver_ops_test",
+    srcs = ["python/kernel_tests/wals_solver_ops_test.py"],
+    additional_deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+tf_py_test(
+    name = "clustering_ops_test",
+    srcs = ["python/kernel_tests/clustering_ops_test.py"],
+    additional_deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+# All files
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
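A note on the two tf_custom_op_library targets above: they build _clustering_ops.so and _factorization_ops.so, which the Python wrappers load at import time. As a sketch of that mechanism (the path is illustrative; the actual loading code lives in the python/ops wrappers, which are not part of this diff):

  import tensorflow as tf

  # Hypothetical path; in the package the .so ships next to the generated
  # Python wrappers under python/ops/.
  _clustering_ops_module = tf.load_op_library('python/ops/_clustering_ops.so')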
diff --git a/tensorflow/contrib/factorization/__init__.py b/tensorflow/contrib/factorization/__init__.py
new file mode 100644
index 0000000000..101de6e7c6
--- /dev/null
+++ b/tensorflow/contrib/factorization/__init__.py
@@ -0,0 +1,23 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops and modules related to factorization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+from tensorflow.contrib.factorization.python.ops import *
diff --git a/tensorflow/contrib/factorization/examples/BUILD b/tensorflow/contrib/factorization/examples/BUILD
new file mode 100644
index 0000000000..bf5b829e32
--- /dev/null
+++ b/tensorflow/contrib/factorization/examples/BUILD
@@ -0,0 +1,22 @@
+# Example TensorFlow models using factorization ops.
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
+
+tf_py_test(
+    name = "mnist",
+    size = "medium",
+    srcs = [
+        "mnist.py",
+    ],
+    additional_deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/examples/tutorials/mnist",
+        "//tensorflow/examples/tutorials/mnist:input_data",
+    ],
+)
diff --git a/tensorflow/contrib/factorization/examples/mnist.py b/tensorflow/contrib/factorization/examples/mnist.py
new file mode 100644
index 0000000000..f5f4c23502
--- /dev/null
+++ b/tensorflow/contrib/factorization/examples/mnist.py
@@ -0,0 +1,292 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Example mnist model with jointly computed k-means clustering.
+
+This is a toy example of how clustering can be embedded into larger tensorflow
+graphs. In this case, we learn a clustering on-the-fly and transform the input
+into the 'distance to clusters' space. These are then fed into hidden layers to
+learn the supervised objective.
+
+To train this model on real mnist data, run this model as follows:
+  mnist --nofake_data --max_steps=2000
+"""
+
+# pylint: disable=missing-docstring
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import tempfile
+import time
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+from tensorflow.contrib.factorization.python.ops import clustering_ops
+from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.examples.tutorials.mnist import mnist
+
+# Basic model parameters as external flags.
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+flags.DEFINE_float('learning_rate', 0.3, 'Initial learning rate.')
+flags.DEFINE_integer('max_steps', 200, 'Number of steps to run trainer.')
+flags.DEFINE_integer('num_clusters', 384, 'Number of input feature clusters')
+flags.DEFINE_integer('hidden1', 256, 'Number of units in hidden layer 1.')
+flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
+flags.DEFINE_integer('batch_size', 100, 'Batch size. '
+                     'Must divide evenly into the dataset sizes.')
+flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
+flags.DEFINE_bool('fake_data', True, 'Use fake input data.')
+
+# The MNIST dataset has 10 classes, representing the digits 0 through 9.
+NUM_CLASSES = 10
+
+# The MNIST images are always 28x28 pixels.
+IMAGE_SIZE = 28
+IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
+
+
+def placeholder_inputs():
+  """Generate placeholder variables to represent the input tensors.
+
+  Returns:
+    images_placeholder: Images placeholder.
+    labels_placeholder: Labels placeholder.
+  """
+  images_placeholder = tf.placeholder(tf.float32, shape=(None,
+                                                         mnist.IMAGE_PIXELS))
+  labels_placeholder = tf.placeholder(tf.int32, shape=(None))
+  return images_placeholder, labels_placeholder
+
+
+def fill_feed_dict(data_set, images_pl, labels_pl, batch_size):
+  """Fills the feed_dict for training the given step.
+
+  Args:
+    data_set: The set of images and labels, from input_data.read_data_sets()
+    images_pl: The images placeholder, from placeholder_inputs().
+    labels_pl: The labels placeholder, from placeholder_inputs().
+    batch_size: Batch size of data to feed.
+
+  Returns:
+    feed_dict: The feed dictionary mapping from placeholders to values.
+  """
+  # Create the feed_dict for the placeholders filled with the next
+  # `batch_size` examples.
+  images_feed, labels_feed = data_set.next_batch(batch_size, FLAGS.fake_data)
+  feed_dict = {
+      images_pl: images_feed,
+      labels_pl: labels_feed,
+  }
+  return feed_dict
+
+
+def do_eval(sess,
+            eval_correct,
+            images_placeholder,
+            labels_placeholder,
+            data_set):
+  """Runs one evaluation against the full epoch of data.
+
+  Args:
+    sess: The session in which the model has been trained.
+    eval_correct: The Tensor that returns the number of correct predictions.
+    images_placeholder: The images placeholder.
+    labels_placeholder: The labels placeholder.
+    data_set: The set of images and labels to evaluate, from
+      input_data.read_data_sets().
+
+  Returns:
+    Precision value on the dataset.
+  """
+  # And run one epoch of eval.
+  true_count = 0  # Counts the number of correct predictions.
+  steps_per_epoch = data_set.num_examples // FLAGS.batch_size
+  num_examples = steps_per_epoch * FLAGS.batch_size
+  for step in xrange(steps_per_epoch):
+    feed_dict = fill_feed_dict(data_set,
+                               images_placeholder,
+                               labels_placeholder,
+                               FLAGS.batch_size)
+    true_count += sess.run(eval_correct, feed_dict=feed_dict)
+  precision = true_count / num_examples
+  print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
+        (num_examples, true_count, precision))
+  return precision
+
+
+def inference(inp, num_clusters, hidden1_units, hidden2_units):
+  """Build the MNIST model up to where it may be used for inference.
+
+  Args:
+    inp: input data
+    num_clusters: number of clusters of input features to train.
+    hidden1_units: Size of the first hidden layer.
+    hidden2_units: Size of the second hidden layer.
+
+  Returns:
+    logits: Output tensor with the computed logits.
+    clustering_loss: Clustering loss.
+    kmeans_training_op: An op to train the clustering.
+  """
+  # Clustering
+  kmeans = clustering_ops.KMeans(
+      inp,
+      num_clusters,
+      distance_metric=clustering_ops.COSINE_DISTANCE,
+      # TODO(agarwal): kmeans++ is currently causing crash in dbg mode.
+      # Enable this after fixing.
+      # initial_clusters=clustering_ops.KMEANS_PLUS_PLUS_INIT,
+      use_mini_batch=True)
+
+  all_scores, _, clustering_scores, kmeans_training_op = kmeans.training_graph()
+  # Some heuristics to approximately whiten this output.
+  all_scores = (all_scores[0] - 0.5) * 5
+  # Here we avoid passing the gradients from the supervised objective back to
+  # the clusters by creating a stop_gradient node.
+  all_scores = tf.stop_gradient(all_scores)
+  clustering_loss = tf.reduce_sum(clustering_scores[0])
+  # Hidden 1
+  with tf.name_scope('hidden1'):
+    weights = tf.Variable(
+        tf.truncated_normal([num_clusters, hidden1_units],
+                            stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
+        name='weights')
+    biases = tf.Variable(tf.zeros([hidden1_units]),
+                         name='biases')
+    hidden1 = tf.nn.relu(tf.matmul(all_scores, weights) + biases)
+  # Hidden 2
+  with tf.name_scope('hidden2'):
+    weights = tf.Variable(
+        tf.truncated_normal([hidden1_units, hidden2_units],
+                            stddev=1.0 / math.sqrt(float(hidden1_units))),
+        name='weights')
+    biases = tf.Variable(tf.zeros([hidden2_units]),
+                         name='biases')
+    hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
+  # Linear
+  with tf.name_scope('softmax_linear'):
+    weights = tf.Variable(
+        tf.truncated_normal([hidden2_units, NUM_CLASSES],
+                            stddev=1.0 / math.sqrt(float(hidden2_units))),
+        name='weights')
+    biases = tf.Variable(tf.zeros([NUM_CLASSES]),
+                         name='biases')
+    logits = tf.matmul(hidden2, weights) + biases
+  return logits, clustering_loss, kmeans_training_op
+
+
+def run_training():
+  """Train MNIST for a number of steps."""
+  # Get the sets of images and labels for training, validation, and
+  # test on MNIST.
+  train_dir = tempfile.mkdtemp()
+  data_sets = input_data.read_data_sets(train_dir, FLAGS.fake_data)
+
+  # Tell TensorFlow that the model will be built into the default Graph.
+  with tf.Graph().as_default():
+    # Generate placeholders for the images and labels.
+    images_placeholder, labels_placeholder = placeholder_inputs()
+
+    # Build a Graph that computes predictions from the inference model.
+    logits, clustering_loss, kmeans_training_op = inference(images_placeholder,
+                                                            FLAGS.num_clusters,
+                                                            FLAGS.hidden1,
+                                                            FLAGS.hidden2)
+
+    # Add to the Graph the Ops for loss calculation.
+    loss = mnist.loss(logits, labels_placeholder)
+
+    # Add to the Graph the Ops that calculate and apply gradients.
+    train_op = tf.group(mnist.training(loss, FLAGS.learning_rate),
+                        kmeans_training_op)
+
+    # Add the Op to compare the logits to the labels during evaluation.
+    eval_correct = mnist.evaluation(logits, labels_placeholder)
+
+    # Add the variable initializer Op.
+    init = tf.initialize_all_variables()
+
+    # Create a session for running Ops on the Graph.
+    sess = tf.Session()
+
+    feed_dict = fill_feed_dict(data_sets.train,
+                               images_placeholder,
+                               labels_placeholder,
+                               batch_size=5000)
+    # Run the Op to initialize the variables.
+    sess.run(init, feed_dict=feed_dict)
+
+    # Start the training loop.
+    max_test_prec = 0
+    for step in xrange(FLAGS.max_steps):
+      start_time = time.time()
+
+      # Fill a feed dictionary with the actual set of images and labels
+      # for this particular training step.
+      feed_dict = fill_feed_dict(data_sets.train,
+                                 images_placeholder,
+                                 labels_placeholder,
+                                 FLAGS.batch_size)
+
+      # Run one step of the model.
+      _, loss_value, clustering_loss_value = sess.run([train_op,
+                                                       loss,
+                                                       clustering_loss],
+                                                      feed_dict=feed_dict)
+
+      duration = time.time() - start_time
+      if step % 100 == 0:
+        # Print status to stdout.
+        print('Step %d: loss = %.2f, clustering_loss = %.2f (%.3f sec)' % (
+            step, loss_value, clustering_loss_value, duration))
+
+      # Save a checkpoint and evaluate the model periodically.
+      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
+        # Evaluate against the training set.
+        print('Training Data Eval:')
+        do_eval(sess,
+                eval_correct,
+                images_placeholder,
+                labels_placeholder,
+                data_sets.train)
+        # Evaluate against the validation set.
+        print('Validation Data Eval:')
+        do_eval(sess,
+                eval_correct,
+                images_placeholder,
+                labels_placeholder,
+                data_sets.validation)
+        # Evaluate against the test set.
+        print('Test Data Eval:')
+        test_prec = do_eval(sess,
+                            eval_correct,
+                            images_placeholder,
+                            labels_placeholder,
+                            data_sets.test)
+        max_test_prec = max(max_test_prec, test_prec)
+  return max_test_prec
+
+
+class MnistTest(tf.test.TestCase):
+
+  def test_train(self):
+    self.assertTrue(run_training() > 0.6)
+
+
+if __name__ == '__main__':
+  tf.test.main()
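Given the examples/BUILD target above, the example doubles as a test; a standard invocation from a configured TensorFlow source tree (assumed, not stated in the commit) would be:

  bazel test //tensorflow/contrib/factorization/examples:mnist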
diff --git a/tensorflow/contrib/factorization/kernels/BUILD b/tensorflow/contrib/factorization/kernels/BUILD
new file mode 100644
index 0000000000..301ab4c95e
--- /dev/null
+++ b/tensorflow/contrib/factorization/kernels/BUILD
@@ -0,0 +1,67 @@
+# OpKernels for data factorization and clustering.
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+
+cc_library(
+    name = "all_kernels",
+    deps = [
+        ":clustering_ops",
+        ":wals_solver_ops",
+        "@protobuf//:protobuf",
+    ],
+)
+
+cc_library(
+    name = "wals_solver_ops",
+    srcs = ["wals_solver_ops.cc"],
+    deps = [
+        "//tensorflow/core:framework_headers_lib",
+        "//third_party/eigen3",
+        "@protobuf//:protobuf",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "clustering_ops",
+    srcs = ["clustering_ops.cc"],
+    deps = [
+        "//tensorflow/core:framework_headers_lib",
+        "//third_party/eigen3",
+        "@protobuf//:protobuf",
+    ],
+    alwayslink = 1,
+)
+
+cc_test(
+    name = "clustering_ops_test",
+    srcs = ["clustering_ops_test.cc"],
+    deps = [
+        ":clustering_ops",
+        "//tensorflow/contrib/factorization:clustering_ops_op_lib",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+)
diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc
new file mode 100644
index 0000000000..5f680ceadb
--- /dev/null
+++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc
@@ -0,0 +1,522 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy
+// of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+// ==============================================================================
+
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <memory>
+#include <tuple>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/gtl/top_n.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/host_info.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace {
+using errors::InvalidArgument;
+
+template <typename Scalar>
+using RowMajorMatrix =
+    Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+
+using MatrixXfRowMajor = RowMajorMatrix<float>;
+using MatrixXi64RowMajor = RowMajorMatrix<int64>;
+
+// Ideally this should be computed by dividing L3 cache size by the number of
+// physical CPUs. Since there isn't a portable method to do this, we are using
+// a conservative estimate here.
+const int64 kDefaultL3CachePerCpu = 1 << 20;
+
+// These values were determined by performing a parameter sweep on the
+// NearestNeighborsOp benchmark.
+const int64 kNearestNeighborsCentersMaxBlockSize = 1024;
+const int64 kNearestNeighborsPointsMinBlockSize = 16;
+
+// Returns the smallest multiple of a that is not smaller than b.
+int64 NextMultiple(int64 a, int64 b) {
+  const int64 remainder = b % a;
+  return remainder == 0 ? b : (b + a - remainder);
+}
+
+// Returns a / b rounded up to the next higher integer.
+int64 CeilOfRatio(int64 a, int64 b) { return (a + b - 1) / b; }
+
+}  // namespace
+
+// Implementation of K-means++ initialization. Samples points iteratively in
+// proportion to the squared distances from selected points.
+// TODO(ands): Add support for other distance metrics.
+class KmeansPlusPlusInitializationOp : public OpKernel {
+ public:
+  explicit KmeansPlusPlusInitializationOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context,
+                   context->MatchSignature(
+                       {DT_FLOAT, DT_INT64, DT_INT64, DT_INT64}, {DT_FLOAT}));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& points_tensor = context->input(0);
+    const Tensor& num_to_sample_tensor = context->input(1);
+    const Tensor& seed_tensor = context->input(2);
+    const Tensor& num_retries_per_sample_tensor = context->input(3);
+
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
+                InvalidArgument("Input points should be a matrix."));
+    OP_REQUIRES(context,
+                TensorShapeUtils::IsScalar(num_to_sample_tensor.shape()),
+                InvalidArgument("Input num_to_sample should be a scalar."));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
+                InvalidArgument("Input seed should be a scalar."));
+    OP_REQUIRES(
+        context,
+        TensorShapeUtils::IsScalar(num_retries_per_sample_tensor.shape()),
+        InvalidArgument("Input num_retries_per_sample should be a scalar."));
+
+    const int64 num_points = points_tensor.dim_size(0);
+    const int64 point_dimensions = points_tensor.dim_size(1);
+    const int64 num_to_sample = num_to_sample_tensor.scalar<int64>()();
+    const int64 seed = seed_tensor.scalar<int64>()();
+    const int64 num_retries_per_sample = [&]() {
+      const int64 value = num_retries_per_sample_tensor.scalar<int64>()();
+      return value >= 0 ? value
+                        : 2 + static_cast<int64>(std::log(num_to_sample));
+    }();
+
+    OP_REQUIRES(context, num_points > 0,
+                InvalidArgument("Expected points.rows() > 0."));
+    OP_REQUIRES(context, num_to_sample > 0,
+                InvalidArgument("Expected num_to_sample > 0."));
+    OP_REQUIRES(context, num_to_sample <= num_points,
+                InvalidArgument("Expected num_to_sample <= points.rows(). ",
+                                num_to_sample, " vs ", num_points, "."));
+
+    Tensor* output_sampled_points_tensor;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({num_to_sample, point_dimensions}),
+                       &output_sampled_points_tensor));
+
+    const Eigen::Map<const MatrixXfRowMajor> points(
+        points_tensor.matrix<float>().data(), num_points, point_dimensions);
+    const Eigen::VectorXf points_half_squared_norm =
+        0.5 * points.rowwise().squaredNorm();
+
+    Eigen::Map<MatrixXfRowMajor> sampled_points(
+        output_sampled_points_tensor->matrix<float>().data(), num_to_sample,
+        point_dimensions);
+    std::unordered_set<int64> sampled_indices;
+
+    random::PhiloxRandom random(seed);
+    random::SimplePhilox rng(&random);
+
+    auto add_one_point = [&](int64 from, int64 to) {
+      from = std::min(from, num_points - 1);
+      sampled_points.row(to) = points.row(from);
+      sampled_indices.insert(from);
+    };
+
+    // Distances from all points to nearest selected point. Initialize with
+    // distances to first selected point.
+    Eigen::VectorXf min_distances(num_points);
+    min_distances.fill(std::numeric_limits<float>::infinity());
+    Eigen::VectorXf min_distances_cumsum(num_points);
+
+    auto draw_one_sample = [&]() -> int64 {
+      if (sampled_indices.empty()) return rng.Uniform64(num_points);
+      int64 index = 0;
+      do {
+        // If v is drawn from Uniform[0, distances.sum()), then
+        // Prob[cumsum(distances)(i - 1) <= v < cumsum(distances)(i)] is
+        // proportional to distances(i).
+        index = std::upper_bound(
+                    min_distances_cumsum.data(),
+                    min_distances_cumsum.data() + num_points,
+                    rng.RandFloat() * min_distances_cumsum(num_points - 1)) -
+                min_distances_cumsum.data();
+      } while (sampled_indices.find(index) != sampled_indices.end());
+      return index;
+    };
+
+    auto sample_one_point = [&]() {
+      const int64 sampled_index = draw_one_sample();
+      min_distances = min_distances.cwiseMin(GetHalfSquaredDistancesToY(
+          points, points_half_squared_norm, points.row(sampled_index),
+          points_half_squared_norm(sampled_index)));
+      return sampled_index;
+    };
+
+    auto sample_one_point_with_retries = [&]() {
+      Eigen::VectorXf best_new_min_distances(num_points);
+      float best_potential = std::numeric_limits<float>::infinity();
+      int64 best_sampled_index = 0;
+      for (int i = 1 + num_retries_per_sample; i > 0; --i) {
+        const int64 sampled_index = draw_one_sample();
+        Eigen::VectorXf new_min_distances =
+            min_distances.cwiseMin(GetHalfSquaredDistancesToY(
+                points, points_half_squared_norm, points.row(sampled_index),
+                points_half_squared_norm(sampled_index)));
+        const float potential = new_min_distances.sum();
+        if (potential < best_potential) {
+          best_potential = potential;
+          best_sampled_index = sampled_index;
+          best_new_min_distances.swap(new_min_distances);
+        }
+      }
+      min_distances.swap(best_new_min_distances);
+      return best_sampled_index;
+    };
+
+    for (int64 i = 0; i < num_to_sample; ++i) {
+      if (i > 0) {
+        std::partial_sum(min_distances.data(),
+                         min_distances.data() + num_points,
+                         min_distances_cumsum.data());
+      }
+      int64 next = num_retries_per_sample == 0
+                       ? sample_one_point()
+                       : sample_one_point_with_retries();
+      add_one_point(next, i);
+    }
+  }
+
+ private:
+  // Returns a column vector with the i-th element set to half the squared
+  // euclidean distance between the i-th row of xs, and y. Precomputed norms for
+  // each row of xs and y must be provided for efficiency.
+  // TODO(ands): Parallelize this for large xs.
+  static Eigen::VectorXf GetHalfSquaredDistancesToY(
+      const Eigen::Ref<const MatrixXfRowMajor>& xs,
+      const Eigen::Ref<const Eigen::VectorXf>& xs_half_squared_norm,
+      const Eigen::Ref<const Eigen::RowVectorXf>& y,
+      float y_half_squared_norm) {
+    // Squared distance between points xs_i and y is:
+    //   || xs_i ||^2 - 2 <xs_i, y> + || y ||^2
+    return (xs_half_squared_norm - xs * y.transpose()).array() +
+           y_half_squared_norm;
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
+                        KmeansPlusPlusInitializationOp);
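The draw_one_sample lambda above samples an index in proportion to min_distances using a cumulative sum plus binary search (std::upper_bound). A NumPy sketch of the same technique, with illustrative names:

  import numpy as np

  def draw_proportional(weights, rng):
    # If v ~ Uniform[0, weights.sum()), then
    # Prob[cumsum[i - 1] <= v < cumsum[i]] is proportional to weights[i].
    cumsum = np.cumsum(weights)
    v = rng.uniform(0.0, cumsum[-1])
    return int(np.searchsorted(cumsum, v, side='right'))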
+
+// Operator for computing the nearest neighbors for a set of points.
+class NearestNeighborsOp : public OpKernel {
+ public:
+  explicit NearestNeighborsOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context,
+                   context->MatchSignature({DT_FLOAT, DT_FLOAT, DT_INT64},
+                                           {DT_INT64, DT_FLOAT}));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& points_tensor = context->input(0);
+    const Tensor& centers_tensor = context->input(1);
+    const Tensor& k_tensor = context->input(2);
+
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(points_tensor.shape()),
+                InvalidArgument("Input points should be a matrix."));
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(centers_tensor.shape()),
+                InvalidArgument("Input centers should be a matrix."));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(k_tensor.shape()),
+                InvalidArgument("Input k should be a scalar."));
+
+    const int64 num_points = points_tensor.dim_size(0);
+    const int64 point_dimensions = points_tensor.dim_size(1);
+    const int64 num_centers = centers_tensor.dim_size(0);
+    const int64 center_dimensions = centers_tensor.dim_size(1);
+
+    OP_REQUIRES(context, num_points > 0,
+                InvalidArgument("Expected points.rows() > 0."));
+    OP_REQUIRES(
+        context, point_dimensions == center_dimensions,
+        InvalidArgument("Expected point_dimensions == center_dimensions: ",
+                        point_dimensions, " vs ", center_dimensions, "."));
+
+    const Eigen::Map<const MatrixXfRowMajor> points(
+        points_tensor.matrix<float>().data(), num_points, point_dimensions);
+    const Eigen::Map<const MatrixXfRowMajor> centers(
+        centers_tensor.matrix<float>().data(), num_centers, center_dimensions);
+    const int64 k = std::min<int64>(num_centers, k_tensor.scalar<int64>()());
+
+    Tensor* output_nearest_center_indices_tensor;
+    Tensor* output_nearest_center_distances_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                0, TensorShape({num_points, k}),
+                                &output_nearest_center_indices_tensor));
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                1, TensorShape({num_points, k}),
+                                &output_nearest_center_distances_tensor));
+
+    if (k == 0) return;
+
+    Eigen::Map<MatrixXi64RowMajor> nearest_center_indices(
+        output_nearest_center_indices_tensor->matrix<int64>().data(),
+        num_points, k);
+    Eigen::Map<MatrixXfRowMajor> nearest_center_distances(
+        output_nearest_center_distances_tensor->matrix<float>().data(),
+        num_points, k);
+
+    const Eigen::VectorXf centers_half_squared_norm =
+        0.5 * centers.rowwise().squaredNorm();
+
+    // The distance computation is sharded to take advantage of multiple cores
+    // and to allow intermediate values to reside in L3 cache. This is done by
+    // sharding the points and centers as follows:
+    //
+    // 1. Centers are sharded such that each block of centers has at most
+    //    kNearestNeighborsCentersMaxBlockSize rows.
+    // 2. Points are sharded, and each block of points is multiplied with each
+    //    block of centers. The block size of points is chosen such that the
+    //    point coordinates (point_dimensions) and the matrix of distances to
+    //    each center in one block -- the intermediate data -- fits in L3 cache.
+    // 3. After performing each block-block distance computation, the results
+    //    are reduced to a set of k nearest centers as soon as possible. This
+    //    decreases total memory I/O.
+    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+    const int64 num_threads = worker_threads.num_threads;
+    // This kernel might be configured to use fewer than the total number of
+    // available CPUs on the host machine. To avoid destructive interference
+    // with other jobs running on the host machine, we must only use a fraction
+    // of total available L3 cache. Unfortunately, we cannot query the host
+    // machine to get the number of physical CPUs. So, we use a fixed per-CPU
+    // budget and scale it by the number of CPUs available to this operation.
+    const int64 total_memory_budget =
+        kDefaultL3CachePerCpu * port::NumSchedulableCPUs();
+    // Compute the number of blocks into which rows of points must be split so
+    // that the distance matrix and the block of points can fit in cache. One
+    // row of points will yield a vector of distances to each center in a block.
+    const int64 bytes_per_row =
+        (std::min(kNearestNeighborsCentersMaxBlockSize,
+                  num_centers) /* centers in a block */
+         + point_dimensions /* coordinates of one point */) *
+        sizeof(float);
+    // The memory needed for storing the centers being processed. This is shared
+    // by all workers. Adding slack to the number of threads to avoid incorrect
+    // cache eviction when a new block of centers is loaded.
+    const int64 bytes_for_centers =
+        std::min(num_centers,
+                 (num_threads + 2) * kNearestNeighborsCentersMaxBlockSize) *
+        point_dimensions * sizeof(float);
+    // The memory budget available for workers to store their distance matrices.
+    const int64 available_memory_budget =
+        total_memory_budget - bytes_for_centers;
+    // That memory budget is shared by all threads.
+    const int64 rows_per_block =
+        std::max<int64>(kNearestNeighborsPointsMinBlockSize,
+                        available_memory_budget / num_threads / bytes_per_row);
+    // Divide rows into almost uniformly-sized units of work that are small
+    // enough for the memory budget (rows_per_block). Round up to a multiple of
+    // the number of threads.
+    const int64 num_units =
+        NextMultiple(num_threads, CeilOfRatio(num_points, rows_per_block));
+    auto work = [&](int64 start, int64 limit) {
+      for (; start < limit; ++start) {
+        const int64 start_row = num_points * start / num_units;
+        const int64 limit_row = num_points * (start + 1) / num_units;
+        CHECK_LE(limit_row, num_points);
+        const int64 num_rows = limit_row - start_row;
+        auto points_shard = points.middleRows(start_row, num_rows);
+        const Eigen::VectorXf points_half_squared_norm =
+            0.5 * points_shard.rowwise().squaredNorm();
+        auto nearest_center_indices_shard =
+            nearest_center_indices.middleRows(start_row, num_rows);
+        auto nearest_center_distances_shard =
+            nearest_center_distances.middleRows(start_row, num_rows);
+        FindKNearestCenters(k, points_shard, points_half_squared_norm, centers,
+                            centers_half_squared_norm,
+                            nearest_center_indices_shard,
+                            nearest_center_distances_shard);
+      }
+    };
+
+    const int64 units_per_thread = num_units / num_threads;
+    BlockingCounter counter(num_threads - 1);
+    for (int64 i = 1; i < num_threads; ++i) {
+      const int64 start = i * units_per_thread;
+      const int64 limit = start + units_per_thread;
+      worker_threads.workers->Schedule([work, &counter, start, limit]() {
+        work(start, limit);
+        counter.DecrementCount();
+      });
+    }
+    work(0, units_per_thread);
+    counter.Wait();
+  }
+
+ private:
+  static void FindKNearestCenters(
+      int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
+      const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
+      const Eigen::Ref<const MatrixXfRowMajor>& centers,
+      const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
+      Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
+      Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
+    CHECK_LE(k, centers.rows());
+    if (centers.rows() <= kNearestNeighborsCentersMaxBlockSize) {
+      FindKNearestCentersOneBlock(k, points, points_half_squared_norm, centers,
+                                  centers_half_squared_norm,
+                                  nearest_center_indices,
+                                  nearest_center_distances);
+    } else {
+      FindKNearestCentersBlockwise(k, points, points_half_squared_norm, centers,
+                                   centers_half_squared_norm,
+                                   nearest_center_indices,
+                                   nearest_center_distances);
+    }
+  }
+
+  static void FindKNearestCentersOneBlock(
+      int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
+      const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
+      const Eigen::Ref<const MatrixXfRowMajor>& centers,
+      const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
+      Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
+      Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
+    CHECK_LE(k, centers.rows());
+    const int64 num_points = points.rows();
+    const MatrixXfRowMajor inner_product = points * centers.transpose();
+    // Find nearest neighbors.
+    if (k == 1) {
+      for (int i = 0; i < num_points; ++i) {
+        int64 index;
+        nearest_center_distances(i, 0) =
+            2.0 *
+            (points_half_squared_norm(i) +
+             (centers_half_squared_norm.transpose() - inner_product.row(i))
+                 .minCoeff(&index));
+        nearest_center_indices(i, 0) = index;
+      }
+    } else {
+      // Select k nearest centers for each point.
+      using Center = std::pair<float, int64>;
+      const int64 num_centers = centers.rows();
+      gtl::TopN<Center, std::less<Center>> selector(k);
+      std::unique_ptr<std::vector<Center>> nearest_centers;
+      for (int i = 0; i < num_points; ++i) {
+        selector.reserve(num_centers);
+        for (int j = 0; j < num_centers; ++j) {
+          const float partial_distance =
+              centers_half_squared_norm(j) - inner_product(i, j);
+          selector.push(Center(partial_distance, j));
+        }
+        nearest_centers.reset(selector.Extract());
+        selector.Reset();
+        const float point_half_squared_norm = points_half_squared_norm(i);
+        for (int j = 0; j < k; ++j) {
+          const Center& center = (*nearest_centers)[j];
+          nearest_center_distances(i, j) =
+              2.0 * (point_half_squared_norm + center.first);
+          nearest_center_indices(i, j) = center.second;
+        }
+      }
+    }
+  }
+
+  static void FindKNearestCentersBlockwise(
+      int64 k, const Eigen::Ref<const MatrixXfRowMajor>& points,
+      const Eigen::Ref<const Eigen::VectorXf>& points_half_squared_norm,
+      const Eigen::Ref<const MatrixXfRowMajor>& centers,
+      const Eigen::Ref<const Eigen::VectorXf>& centers_half_squared_norm,
+      Eigen::Ref<MatrixXi64RowMajor> nearest_center_indices,
+      Eigen::Ref<MatrixXfRowMajor> nearest_center_distances) {
+    const int64 num_points = points.rows();
+    const int64 num_centers = centers.rows();
+    CHECK_LE(k, num_centers);
+    CHECK_GT(num_centers, kNearestNeighborsCentersMaxBlockSize);
+    // Store nearest neighbors with first block of centers directly into the
+    // output matrices.
+    int64 out_k = std::min(k, kNearestNeighborsCentersMaxBlockSize);
+    FindKNearestCentersOneBlock(
+        out_k, points, points_half_squared_norm,
+        centers.topRows(kNearestNeighborsCentersMaxBlockSize),
+        centers_half_squared_norm.head(kNearestNeighborsCentersMaxBlockSize),
+        nearest_center_indices, nearest_center_distances);
+    // Iteratively compute nearest neighbors with other blocks of centers, and
+    // update the output matrices.
+    MatrixXi64RowMajor block_nearest_center_indices(num_points, k);
+    MatrixXfRowMajor block_nearest_center_distances(num_points, k);
+    Eigen::Matrix<int64, 1, Eigen::Dynamic> merged_indices(k);
+    Eigen::Matrix<float, 1, Eigen::Dynamic> merged_distances(k);
+    for (int64 centers_start = kNearestNeighborsCentersMaxBlockSize;
+         centers_start < num_centers;
+         centers_start += kNearestNeighborsCentersMaxBlockSize) {
+      const int64 centers_block_size = std::min(
+          kNearestNeighborsCentersMaxBlockSize, num_centers - centers_start);
+      const int64 block_k = std::min(k, centers_block_size);
+      FindKNearestCentersOneBlock(
+          block_k, points, points_half_squared_norm,
+          centers.middleRows(centers_start, centers_block_size),
+          centers_half_squared_norm.segment(centers_start, centers_block_size),
+          block_nearest_center_indices, block_nearest_center_distances);
+      if (k == 1) {
+        for (int i = 0; i < num_points; ++i) {
+          if (block_nearest_center_distances(i, 0) <
+              nearest_center_distances(i, 0)) {
+            nearest_center_indices(i, 0) =
+                block_nearest_center_indices(i, 0) + centers_start;
+            nearest_center_distances(i, 0) =
+                block_nearest_center_distances(i, 0);
+          }
+        }
+      } else {
+        for (int i = 0; i < num_points; ++i) {
+          // Merge and accumulate top-k list from block_nearest_center_indices
+          // into nearest_center_indices.
+          for (int64 j_out = 0, j_block = 0, j_merged = 0;
+               (j_out < out_k || j_block < block_k) && j_merged < k;
+               ++j_merged) {
+            const float distance_out =
+                j_out < out_k ? nearest_center_distances(i, j_out)
+                              : std::numeric_limits<float>::infinity();
+            const float distance_block =
+                j_block < block_k ? block_nearest_center_distances(i, j_block)
+                                  : std::numeric_limits<float>::infinity();
+            if (distance_out <= distance_block) {
+              merged_indices(j_merged) = nearest_center_indices(i, j_out);
+              merged_distances(j_merged) = distance_out;
+              ++j_out;
+            } else {
+              merged_indices(j_merged) =
+                  block_nearest_center_indices(i, j_block) + centers_start;
+              merged_distances(j_merged) = distance_block;
+              ++j_block;
+            }
+          }
+          nearest_center_indices.row(i) = merged_indices;
+          nearest_center_distances.row(i) = merged_distances;
+          out_k = std::min(k, out_k + block_k);
+        }
+      }
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("NearestNeighbors").Device(DEVICE_CPU),
+                        NearestNeighborsOp);
+
+}  // namespace tensorflow
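The k > 1 branch above maintains a running top-k list by linearly merging two sorted candidate lists. A standalone Python sketch of that merge step (names illustrative; distances sorted ascending):

  def merge_top_k(current, block, k):
    # current and block are lists of (distance, center_index),
    # each sorted by increasing distance.
    merged = []
    i = j = 0
    while len(merged) < k and (i < len(current) or j < len(block)):
      take_current = j >= len(block) or (
          i < len(current) and current[i][0] <= block[j][0])
      if take_current:
        merged.append(current[i])
        i += 1
      else:
        merged.append(block[j])
        j += 1
    return merged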
diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc b/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc
new file mode 100644
index 0000000000..c4a96b048d
--- /dev/null
+++ b/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc
@@ -0,0 +1,176 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy
+// of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+// ==============================================================================
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+constexpr int k100Dim = 100;
+// Number of points for tests.
+constexpr int k10Points = 10;
+constexpr int k100Points = 100;
+constexpr int k1kPoints = 1000;
+constexpr int k10kPoints = 10000;
+constexpr int k1MPoints = 1000000;
+// Number of centers for tests.
+constexpr int k2Centers = 2;
+constexpr int k5Centers = 5;
+constexpr int k10Centers = 10;
+constexpr int k20Centers = 20;
+constexpr int k50Centers = 50;
+constexpr int k100Centers = 100;
+constexpr int k200Centers = 200;
+constexpr int k500Centers = 500;
+constexpr int k1kCenters = 1000;
+constexpr int k10kCenters = 10000;
+// Number of retries for tests.
+constexpr int k0RetriesPerSample = 0;
+constexpr int k3RetriesPerSample = 3;
+
+Graph* SetUpKmeansPlusPlusInitialization(int num_dims, int num_points,
+                                         int num_to_sample,
+                                         int retries_per_sample) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor points(DT_FLOAT, TensorShape({num_points, num_dims}));
+  Tensor sample_size(DT_INT64, TensorShape({}));
+  Tensor seed(DT_INT64, TensorShape({}));
+  Tensor num_retries_per_sample(DT_INT64, TensorShape({}));
+  points.flat<float>().setRandom();
+  sample_size.flat<int64>().setConstant(num_to_sample);
+  seed.flat<int64>().setConstant(12345);
+  num_retries_per_sample.flat<int64>().setConstant(retries_per_sample);
+
+  TF_CHECK_OK(NodeBuilder("kmeans_plus_plus_initialization_op",
+                          "KmeansPlusPlusInitialization")
+                  .Input(test::graph::Constant(g, points))
+                  .Input(test::graph::Constant(g, sample_size))
+                  .Input(test::graph::Constant(g, seed))
+                  .Input(test::graph::Constant(g, num_retries_per_sample))
+                  .Finalize(g, nullptr /* node */));
+  return g;
+}
+
+template <int num_points, int num_to_sample, int num_dims,
+          int retries_per_sample>
+void BM_KmeansPlusPlusInitialization(int iters) {
+  testing::StopTiming();
+  testing::ItemsProcessed(static_cast<int64>(iters) * num_points * num_dims *
+                          num_to_sample);
+  testing::UseRealTime();
+  Graph* g = SetUpKmeansPlusPlusInitialization(
+      num_dims, num_points, num_to_sample, retries_per_sample);
+  testing::StartTiming();
+  test::Benchmark("cpu", g).Run(iters);
+}
+
+#define BENCHMARK_KMEANS_PLUS_PLUS(p, c, d, r)                            \
+  void BM_KmeansPlusPlusInitialization_##p##_##c##_##d##_##r(int iters) { \
+    BM_KmeansPlusPlusInitialization<p, c, d, r>(iters);                   \
+  }                                                                       \
+  BENCHMARK(BM_KmeansPlusPlusInitialization_##p##_##c##_##d##_##r);
+
+#define RUN_BM_KmeansPlusPlusInitialization(retries)                     \
+  BENCHMARK_KMEANS_PLUS_PLUS(k10Points, k2Centers, k100Dim, retries);    \
+  BENCHMARK_KMEANS_PLUS_PLUS(k10Points, k5Centers, k100Dim, retries);    \
+  BENCHMARK_KMEANS_PLUS_PLUS(k10Points, k10Centers, k100Dim, retries);   \
+  BENCHMARK_KMEANS_PLUS_PLUS(k100Points, k10Centers, k100Dim, retries);  \
+  BENCHMARK_KMEANS_PLUS_PLUS(k100Points, k20Centers, k100Dim, retries);  \
+  BENCHMARK_KMEANS_PLUS_PLUS(k100Points, k50Centers, k100Dim, retries);  \
+  BENCHMARK_KMEANS_PLUS_PLUS(k100Points, k100Centers, k100Dim, retries); \
+  BENCHMARK_KMEANS_PLUS_PLUS(k1kPoints, k100Centers, k100Dim, retries);  \
+  BENCHMARK_KMEANS_PLUS_PLUS(k1kPoints, k200Centers, k100Dim, retries);  \
+  BENCHMARK_KMEANS_PLUS_PLUS(k1kPoints, k500Centers, k100Dim, retries);  \
+  BENCHMARK_KMEANS_PLUS_PLUS(k1kPoints, k1kCenters, k100Dim, retries);   \
+  BENCHMARK_KMEANS_PLUS_PLUS(k10kPoints, k100Centers, k100Dim, retries); \
+  BENCHMARK_KMEANS_PLUS_PLUS(k10kPoints, k200Centers, k100Dim, retries); \
+  BENCHMARK_KMEANS_PLUS_PLUS(k10kPoints, k500Centers, k100Dim, retries); \
+  BENCHMARK_KMEANS_PLUS_PLUS(k10kPoints, k1kCenters, k100Dim, retries);  \
+  BENCHMARK_KMEANS_PLUS_PLUS(k1MPoints, k100Centers, k100Dim, retries);  \
+  BENCHMARK_KMEANS_PLUS_PLUS(k1MPoints, k200Centers, k100Dim, retries);  \
+  BENCHMARK_KMEANS_PLUS_PLUS(k1MPoints, k500Centers, k100Dim, retries);  \
+  BENCHMARK_KMEANS_PLUS_PLUS(k1MPoints, k1kCenters, k100Dim, retries)
+
+RUN_BM_KmeansPlusPlusInitialization(k0RetriesPerSample);
+RUN_BM_KmeansPlusPlusInitialization(k3RetriesPerSample);
+
+#undef RUN_BM_KmeansPlusPlusInitialization
+#undef BENCHMARK_KMEANS_PLUS_PLUS
+
+Graph* SetUpNearestNeighbors(int num_dims, int num_points, int num_centers,
+                             int k) {
+  Graph* g = new Graph(OpRegistry::Global());
+  Tensor points(DT_FLOAT, TensorShape({num_points, num_dims}));
+  Tensor centers(DT_FLOAT, TensorShape({num_centers, num_dims}));
+  Tensor top(DT_INT64, TensorShape({}));
+  points.flat<float>().setRandom();
+  centers.flat<float>().setRandom();
+  top.flat<int64>().setConstant(k);
+
+  TF_CHECK_OK(NodeBuilder("nearest_centers_op", "NearestNeighbors")
+                  .Input(test::graph::Constant(g, points))
+                  .Input(test::graph::Constant(g, centers))
+                  .Input(test::graph::Constant(g, top))
+                  .Finalize(g, nullptr /* node */));
+  return g;
+}
+
+template <int num_dims, int num_points, int num_centers, int k>
+void BM_NearestNeighbors(int iters) {
+  testing::StopTiming();
+  testing::ItemsProcessed(static_cast<int64>(iters) * num_points * num_dims *
+                          num_centers);
+  testing::UseRealTime();
+  Graph* g = SetUpNearestNeighbors(num_dims, num_points, num_centers, k);
+  testing::StartTiming();
+  test::Benchmark("cpu", g).Run(iters);
+}
+
+constexpr int kTop1 = 1;
+constexpr int kTop2 = 2;
+constexpr int kTop5 = 5;
+constexpr int kTop10 = 10;
+
+#define BENCHMARK_NEAREST_NEIGHBORS(d, p, c, k)              \
+  void BM_NearestNeighbors##d##_##p##_##c##_##k(int iters) { \
+    BM_NearestNeighbors<d, p, c, k>(iters);                  \
+  }                                                          \
+  BENCHMARK(BM_NearestNeighbors##d##_##p##_##c##_##k);
+
+#define RUN_BM_NearestNeighbors(k)                                 \
+  BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1kPoints, k100Centers, k); \
+  BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1kPoints, k1kCenters, k);  \
+  BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1kPoints, k10kCenters, k); \
+  BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1MPoints, k100Centers, k); \
+  BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1MPoints, k1kCenters, k);  \
+  BENCHMARK_NEAREST_NEIGHBORS(k100Dim, k1MPoints, k10kCenters, k)
+
+RUN_BM_NearestNeighbors(kTop1);
+// k > 1
+RUN_BM_NearestNeighbors(kTop2);
+RUN_BM_NearestNeighbors(kTop5);
+RUN_BM_NearestNeighbors(kTop10);
+
+#undef RUN_BM_NearestNeighbors
+#undef BENCHMARK_NEAREST_NEIGHBORS
+
+}  // namespace
+}  // namespace tensorflow
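These benchmarks are compiled into the test binary, following the usual TensorFlow test_benchmark convention; a typical invocation (assumed here, not part of the commit) would be:

  bazel run -c opt //tensorflow/contrib/factorization/kernels:clustering_ops_test -- --benchmarks=all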
diff --git a/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc
new file mode 100644
index 0000000000..6797a88fd3
--- /dev/null
+++ b/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc
@@ -0,0 +1,271 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy
+// of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+// ==============================================================================
+
+// TensorFlow kernels and Ops for constructing WALS normal equations.
+// TODO(agarwal,rmlarsen): Add security checks to the code.
+
+#include <algorithm>
+#include <vector>
+
+// This is only used for std::this_thread::get_id()
+#include <thread>  // NOLINT
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
+
+using tensorflow::DEVICE_CPU;
+using tensorflow::DT_BOOL;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT64;
+using tensorflow::OpKernel;
+using tensorflow::OpKernelConstruction;
+using tensorflow::OpKernelContext;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::TensorShapeUtils;
+using tensorflow::errors::InvalidArgument;
+
+namespace tensorflow {
+
+// TODO(ataei): Consider using RowMajor maps.
+typedef Eigen::Map<
+    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
+    EigenMatrixFloatMap;
+typedef Eigen::Map<
+    const Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
+    ConstEigenMatrixInt64Map;
+typedef Eigen::Map<
+    const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
+    ConstEigenMatrixFloatMap;
+
+class WALSComputePartialLhsAndRhsOp : public OpKernel {
+ public:
+  explicit WALSComputePartialLhsAndRhsOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->MatchSignature(
+                                {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT,
+                                 DT_INT64, DT_FLOAT, DT_INT64, DT_BOOL},
+                                {DT_FLOAT, DT_FLOAT}));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& factors = context->input(0);
+    const Tensor& factor_weights = context->input(1);
+    const Tensor& unobserved_weights = context->input(2);
+    const Tensor& input_weights = context->input(3);
+    const Tensor& input_indices = context->input(4);
+    const Tensor& input_values = context->input(5);
+    const Tensor& input_block_size = context->input(6);
+    const Tensor& input_is_transpose = context->input(7);
+
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(factors.shape()),
+                InvalidArgument("Input factors should be a matrix."));
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(factor_weights.shape()),
+                InvalidArgument("Input factor_weights should be a vector."));
+    OP_REQUIRES(
+        context, TensorShapeUtils::IsScalar(unobserved_weights.shape()),
+        InvalidArgument("Input unobserved_weights should be a scalar."));
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_weights.shape()),
+                InvalidArgument("Input input_weights should be a vector."));
+    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
+                InvalidArgument("Input input_indices should be a matrix."));
+    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()),
+                InvalidArgument("Input input_values should be a vector"));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_block_size.shape()),
+                InvalidArgument("Input input_block_size should be a scalar."));
+    OP_REQUIRES(
+        context, TensorShapeUtils::IsScalar(input_is_transpose.shape()),
+        InvalidArgument("Input input_is_transpose should be a scalar."));
+
+    const int64 factor_dim = factors.dim_size(1);
+    const int64 factors_size = factors.dim_size(0);
+    const int64 num_nonzero_elements = input_indices.dim_size(0);
+    const int64 block_size = input_block_size.scalar<int64>()();
+    const auto& factor_weights_vec = factor_weights.vec<float>();
+    const auto& input_weights_vec = input_weights.vec<float>();
+    const float w_0 = unobserved_weights.scalar<float>()();
+    const auto& input_values_vec = input_values.vec<float>();
+
+    ConstEigenMatrixFloatMap factors_mat(factors.matrix<float>().data(),
+                                         factor_dim, factors_size);
+    ConstEigenMatrixInt64Map indices_mat(input_indices.matrix<int64>().data(),
+                                         2, num_nonzero_elements);
+
+    Tensor* output_lhs_tensor;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0, TensorShape({block_size, factor_dim, factor_dim}),
+                       &output_lhs_tensor));
+    auto output_lhs_t = output_lhs_tensor->tensor<float, 3>();
+    output_lhs_t.setZero();
+    Tensor* output_rhs_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                1, TensorShape({block_size, factor_dim}),
+                                &output_rhs_tensor));
+    EigenMatrixFloatMap rhs_mat(output_rhs_tensor->matrix<float>().data(),
+                                factor_dim, block_size);
+    rhs_mat.setZero();
+    const bool is_transpose = input_is_transpose.scalar<bool>()();
+
+    auto get_input_index = [is_transpose, &indices_mat](int64 i) {
+      return is_transpose ? indices_mat(1, i) : indices_mat(0, i);
+    };
+    auto get_factor_index = [is_transpose, &indices_mat](int64 i) {
+      return is_transpose ? indices_mat(0, i) : indices_mat(1, i);
+    };
+
+    // TODO(rmlarsen): In principle, we should be using the SparseTensor class
+    // and machinery for iterating over groups, but the fact that class
+    // SparseTensor makes a complete copy of the matrix makes me reluctant to
+    // use it.
+    std::vector<int64> perm(num_nonzero_elements);
+    std::iota(perm.begin(), perm.end(), 0);
+
+    typedef std::pair<int64, int64> Shard;
+    std::vector<Shard> shards;
+    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+    const int num_threads = worker_threads.num_threads;
+    int64 shard_total = 0;
+    if (num_threads == 1) {
+      shards.emplace_back(0, num_nonzero_elements);
+      shard_total += num_nonzero_elements;
+    } else {
+      // Compute a permutation such that get_input_index(perm[i]) is sorted;
+      // use stable_sort to preserve spatial locality.
+      std::stable_sort(perm.begin(), perm.end(),
+                       [&get_input_index](int64 i, int64 j) {
+                         return get_input_index(i) < get_input_index(j);
+                       });
+
+      // Compute the start and end of runs with identical input_index.
+      // These are the shards of work that can be processed in parallel
+      // without locking.
+      int64 start = 0;
+      int64 end = 0;
+      while (end < num_nonzero_elements) {
+        start = end;
+        while (end < num_nonzero_elements &&
+               get_input_index(perm[start]) == get_input_index(perm[end])) {
+          ++end;
+        }
+        shards.emplace_back(start, end);
+        shard_total += end - start;
+      }
+    }
+    CHECK_EQ(shard_total, num_nonzero_elements);
+    CHECK_LE(shards.size(), num_nonzero_elements);
+    CHECK_GT(shards.size(), 0);
+
+    // Batch the rank-one updates into a rank-k update to lower memory traffic.
+    const int kMaxBatchSize = 128;
+
+    // Since we do not have an easy way of generating thread id's within the
+    // range [0,num_threads), we can instead call out to an std::unordered_map
+    // of matrices and initialize the matrix on the first call.
+    // However, this might have a performance penalty, as memory allocation can
+    // cause the OS kernel to enter a critical section and temporarily disable
+    // parallelism, and the unordered_map must be protected with a read/write
+    // mutex.
+    //
+    // TODO(jpoulson): Simplify after the thread rank can be queried
+    std::unordered_map<size_t, Eigen::MatrixXf> factor_batch_map;
+    mutex map_mutex;
+
+    BlockingCounter counter(shards.size());
+    // Lambda encapsulating the per-shard computation.
+    auto work = [&](const Shard& shard) {
+      const std::thread::id thread_id = std::this_thread::get_id();
+      const size_t id_hash = std::hash<std::thread::id>()(thread_id);
+      // If this thread's unique factors_mat.rows() x kMaxBatchSize
+      // batching matrix has not yet been created, then emplace it into the
+      // map using the hash of the thread id as the key.
+      //
+      // TODO(jpoulson): Switch to try_emplace once C++17 is supported
+      map_mutex.lock();
+      const auto key_count = factor_batch_map.count(id_hash);
+      map_mutex.unlock();
+      if (!key_count) {
+        map_mutex.lock();
+        factor_batch_map.emplace(
+            std::piecewise_construct, std::forward_as_tuple(id_hash),
+            std::forward_as_tuple(factors_mat.rows(), kMaxBatchSize));
+        map_mutex.unlock();
+      }
+      map_mutex.lock();
+      auto& factor_batch = factor_batch_map[id_hash];
+      map_mutex.unlock();
+
+      CHECK_GE(shard.first, 0);
+      CHECK_LE(shard.second, perm.size());
+      CHECK_LE(shard.first, shard.second);
+      const int64 input_index = get_input_index(perm[shard.first]);
+      // Accumulate the rhs and lhs terms in the normal equations
+      // for the non-zero elements in the row or column of the sparse matrix
+      // corresponding to input_index.
+      int num_batched = 0;
+      EigenMatrixFloatMap lhs_mat(output_lhs_tensor->flat<float>().data() +
+                                      input_index * factor_dim * factor_dim,
+                                  factor_dim, factor_dim);
+      auto lhs_symm = lhs_mat.selfadjointView<Eigen::Lower>();
+      for (int64 p = shard.first; p < shard.second; ++p) {
+        const int64 i = perm[p];
+        // Check that all entries in the shard have the same input index.
+        CHECK_EQ(input_index, get_input_index(i));
+        const int64 factor_index = get_factor_index(i);
+        const float input_value = input_values_vec(i);
+        const float weight =
+            input_weights_vec(input_index) * factor_weights_vec(factor_index);
+        CHECK_GE(weight, 0);
+        factor_batch.col(num_batched) =
+            factors_mat.col(factor_index) * std::sqrt(weight);
+        ++num_batched;
+        if (num_batched == kMaxBatchSize) {
+          lhs_symm.rankUpdate(factor_batch);
+          num_batched = 0;
+        }
+
+        rhs_mat.col(input_index) +=
+            input_value * (w_0 + weight) * factors_mat.col(factor_index);
+      }
+      if (num_batched != 0) {
+        auto factor_block =
+            factor_batch.block(0, 0, factors_mat.rows(), num_batched);
+        lhs_symm.rankUpdate(factor_block);
+      }
+      // Copy lower triangular to upper triangular part of normal equation
+      // matrix.
+      lhs_mat = lhs_symm;
+      counter.DecrementCount();
+    };
+    for (int i = 1; i < shards.size(); ++i) {
+      worker_threads.workers->Schedule(std::bind(work, shards[i]));
+    }
+    // Execute the first shard inline.
+    work(shards[0]);
+    counter.Wait();
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("WALSComputePartialLhsAndRhs").Device(DEVICE_CPU),
+                        WALSComputePartialLhsAndRhsOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/factorization/ops/clustering_ops.cc b/tensorflow/contrib/factorization/ops/clustering_ops.cc
new file mode 100644
index 0000000000..5f0d05254e
--- /dev/null
+++ b/tensorflow/contrib/factorization/ops/clustering_ops.cc
@@ -0,0 +1,69 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy
+// of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+// ==============================================================================
+
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("KmeansPlusPlusInitialization")
+    .Input("points: float32")
+    .Input("num_to_sample: int64")
+    .Input("seed: int64")
+    .Input("num_retries_per_sample: int64")
+    .Output("samples: float32")
+    .Doc(R"(
+Selects num_to_sample rows of input using the KMeans++ criterion.
+
+Rows of points are assumed to be input points. One row is selected at random.
+Subsequent rows are sampled with probability proportional to the squared L2
+distance from the nearest row selected thus far, until num_to_sample rows have
+been sampled.
+
+points: Matrix of shape (n, d). Rows are assumed to be input points.
+num_to_sample: Scalar. The number of rows to sample. This value must not be
+  larger than n.
+seed: Scalar. Seed for initializing the random number generator.
+num_retries_per_sample: Scalar. For each row that is sampled, this parameter
+  specifies the number of additional points to draw from the current
+  distribution before selecting the best. If a negative value is specified, a
+  heuristic is used to sample O(log(num_to_sample)) additional points.
+samples: Matrix of shape (num_to_sample, d). The sampled rows.
+)");
+
+REGISTER_OP("NearestNeighbors")
+    .Input("points: float32")
+    .Input("centers: float32")
+    .Input("k: int64")
+    .Output("nearest_center_indices: int64")
+    .Output("nearest_center_distances: float32")
+    .Doc(R"(
+Selects the k nearest centers for each point.
+
+Rows of points are assumed to be input points. Rows of centers are assumed to be
+the list of candidate centers. For each point, the k centers with the smallest
+L2 distance to it are computed.
+
+points: Matrix of shape (n, d). Rows are assumed to be input points.
+centers: Matrix of shape (m, d). Rows are assumed to be centers.
+k: Scalar. Number of nearest centers to return for each point. If k is larger
+  than m, then only m centers are returned.
+nearest_center_indices: Matrix of shape (n, min(m, k)). Each row contains the
+  indices of the centers closest to the corresponding point, ordered by
+  increasing distance.
+nearest_center_distances: Matrix of shape (n, min(m, k)). Each row contains the
+  squared L2 distance to the corresponding center in nearest_center_indices.
+)");
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/factorization/ops/factorization_ops.cc b/tensorflow/contrib/factorization/ops/factorization_ops.cc
new file mode 100644
index 0000000000..f72ea536fb
--- /dev/null
+++ b/tensorflow/contrib/factorization/ops/factorization_ops.cc
@@ -0,0 +1,46 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy
+// of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+// ============================================================================== + +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("WALSComputePartialLhsAndRhs") + .Input("factors: float32") + .Input("factor_weights: float32") + .Input("unobserved_weights: float32") + .Input("input_weights: float32") + .Input("input_indices: int64") + .Input("input_values: float32") + .Input("input_block_size: int64") + .Input("input_is_transpose: bool") + .Output("partial_lhs: float32") + .Output("partial_rhs: float32") + .Doc(R"( +Computes the partial left-hand side and right-hand side of WALS update. + +factors: Matrix of size m * k. +factor_weights: Vector of size m. Corresponds to column weights +unobserved_weights: Scalar. Weight for unobserved input entries. +input_weights: Vector of size n. Corresponds to row weights. +input_indices: Indices for the input SparseTensor. +input_values: Values for the input SparseTensor. +input_block_size: Scalar. Number of rows spanned by input. +input_is_transpose: If true, logically transposes the input for processing. +partial_lhs: 3-D tensor with size input_block_size x k x k. +partial_rhs: Matrix with size input_block_size x k. +)"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/factorization/python/__init__.py b/tensorflow/contrib/factorization/python/__init__.py new file mode 100644 index 0000000000..c44845bb08 --- /dev/null +++ b/tensorflow/contrib/factorization/python/__init__.py @@ -0,0 +1,20 @@ +# pylint: disable=g-bad-file-header +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The python module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py new file mode 100644 index 0000000000..8d1337c153 --- /dev/null +++ b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py @@ -0,0 +1,157 @@ +# pylint: disable=g-bad-file-header +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. 
+# ============================================================================== + +"""Tests for clustering_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +# pylint: disable=unused-import +import tensorflow as tf +# pylint: enable=unused-import +import tensorflow.contrib.factorization.python.ops.clustering_ops as clustering_ops + + +class KmeansPlusPlusInitializationTest(tf.test.TestCase): + + # All but one input point are close to (101, 1). With uniform random sampling, + # it is highly improbable for (-1, -1) to be selected. + def setUp(self): + self._points = np.array([ + [100., 0.], + [101., 2.], + [102., 0.], + [100., 1.], + [100., 2.], + [101., 0.], + [101., 0.], + [101., 1.], + [102., 0.], + [-1., -1.] + ]).astype(np.float32) + + def runTestWithSeed(self, seed): + with self.test_session(): + sampled_points = clustering_ops.kmeans_plus_plus_initialization( + self._points, 3, seed, (seed % 5) - 1) + self.assertAllClose(sorted(sampled_points.eval().tolist()), [ + [-1., -1.], + [101., 1.], + [101., 1.] + ], atol=1.0) + + def testBasic(self): + for seed in range(100): + self.runTestWithSeed(seed) + + +# A simple test that can be verified by hand. +class NearestCentersTest(tf.test.TestCase): + + def setUp(self): + self._points = np.array([ + [100., 0.], + [101., 2.], + [99., 2.], + [1., 1.] + ]).astype(np.float32) + + self._centers = np.array([ + [100., 0.], + [99., 1.], + [50., 50.], + [0., 0.], + [1., 1.] + ]).astype(np.float32) + + def testNearest1(self): + with self.test_session(): + [indices, distances] = clustering_ops.nearest_neighbors(self._points, + self._centers, 1) + self.assertAllClose(indices.eval(), [[0], [0], [1], [4]]) + self.assertAllClose(distances.eval(), [[0.], [5.], [1.], [0.]]) + + def testNearest2(self): + with self.test_session(): + [indices, distances] = clustering_ops.nearest_neighbors(self._points, + self._centers, 2) + self.assertAllClose(indices.eval(), + [[0, 1], [0, 1], [1, 0], [4, 3]]) + self.assertAllClose(distances.eval(), + [[0., 2.], [5., 5.], [1., 5.], [0., 2.]]) + + +# A test with large inputs. +class NearestCentersLargeTest(tf.test.TestCase): + + def setUp(self): + num_points = 1000 + num_centers = 2000 + num_dim = 100 + max_k = 5 + # Construct a small number of random points and later tile them. + points_per_tile = 10 + assert num_points % points_per_tile == 0 + points = np.random.standard_normal([points_per_tile, num_dim]).astype( + np.float32) + # Construct random centers. + self._centers = np.random.standard_normal([num_centers, num_dim]).astype( + np.float32) + # Exhaustively compute expected nearest neighbors. 
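+    # This brute-force reference only needs to cover the points_per_tile
+    # distinct points; both the points and the expected results are tiled
+    # afterwards.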
+    def squared_distance(x, y):
+      return np.linalg.norm(x - y, ord=2) ** 2
+    nearest_neighbors = [sorted([(squared_distance(point, self._centers[j]), j)
+                                 for j in range(num_centers)])[:max_k]
+                         for point in points]
+    expected_nearest_neighbor_indices = np.array(
+        [[i for _, i in nn] for nn in nearest_neighbors])
+    expected_nearest_neighbor_squared_distances = np.array(
+        [[dist for dist, _ in nn] for nn in nearest_neighbors])
+    # Tile points and expected results to reach the requested size
+    # (num_points). Note the integer division: with the __future__ division
+    # import above, "/" would produce a float repetition count for np.tile.
+    (self._points,
+     self._expected_nearest_neighbor_indices,
+     self._expected_nearest_neighbor_squared_distances) = (
+         np.tile(x, (num_points // points_per_tile, 1))
+         for x in (points,
+                   expected_nearest_neighbor_indices,
+                   expected_nearest_neighbor_squared_distances))
+
+  def testNearest1(self):
+    with self.test_session():
+      [indices, distances] = clustering_ops.nearest_neighbors(
+          self._points, self._centers, 1)
+      self.assertAllClose(indices.eval(),
+                          self._expected_nearest_neighbor_indices[:, [0]])
+      self.assertAllClose(
+          distances.eval(),
+          self._expected_nearest_neighbor_squared_distances[:, [0]])
+
+  def testNearest5(self):
+    with self.test_session():
+      [indices, distances] = clustering_ops.nearest_neighbors(
+          self._points, self._centers, 5)
+      self.assertAllClose(indices.eval(),
+                          self._expected_nearest_neighbor_indices[:, 0:5])
+      self.assertAllClose(
+          distances.eval(),
+          self._expected_nearest_neighbor_squared_distances[:, 0:5])
+
+
+if __name__ == '__main__':
+  np.random.seed(0)
+  tf.test.main()
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
new file mode 100644
index 0000000000..1ce1a7a2d5
--- /dev/null
+++ b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
@@ -0,0 +1,85 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ============================================================================== + +"""Tests for wals_solver_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +import tensorflow as tf +from tensorflow.contrib.factorization.python.ops.factorization_ops import wals_compute_partial_lhs_and_rhs + + +def SparseBlock3x3(): + ind = np.array([[0, 0], [0, 2], [1, 1], [2, 0], [2, 1], [3, 2]]).astype( + np.int64) + val = np.array([0.1, 0.2, 1.1, 2.0, 2.1, 3.2]).astype(np.float32) + shape = np.array([4, 3]).astype(np.int64) + return tf.SparseTensor(ind, val, shape) + + +class WalsSolverOpsTest(tf.test.TestCase): + + def setUp(self): + self._column_factors = np.array([ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9], + ]).astype(np.float32) + self._row_factors = np.array([ + [0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.1, 1.2, 1.3] + ]).astype(np.float32) + self._column_weights = np.array([0.1, 0.2, 0.3]).astype(np.float32) + self._row_weights = np.array([0.1, 0.2, 0.3, 0.4]).astype(np.float32) + self._unobserved_weights = 0.1 + + def testWalsSolverLhs(self): + sparse_block = SparseBlock3x3() + with self.test_session(): + [lhs_tensor, rhs_matrix] = wals_compute_partial_lhs_and_rhs( + self._column_factors, self._column_weights, self._unobserved_weights, + self._row_weights, sparse_block.indices, sparse_block.values, + sparse_block.shape[0], False) + self.assertAllClose(lhs_tensor.eval(), [ + [ + [0.014800, 0.017000, 0.019200], + [0.017000, 0.019600, 0.022200], + [0.019200, 0.022200, 0.025200], + ], [ + [0.0064000, 0.0080000, 0.0096000], + [0.0080000, 0.0100000, 0.0120000], + [0.0096000, 0.0120000, 0.0144000], + ], [ + [0.0099000, 0.0126000, 0.0153000], + [0.0126000, 0.0162000, 0.0198000], + [0.0153000, 0.0198000, 0.0243000], + ], [ + [0.058800, 0.067200, 0.075600], + [0.067200, 0.076800, 0.086400], + [0.075600, 0.086400, 0.097200], + ] + ]) + self.assertAllClose( + rhs_matrix.eval(), + [[0.019300, 0.023000, 0.026700], [0.061600, 0.077000, 0.092400], + [0.160400, 0.220000, 0.279600], [0.492800, 0.563200, 0.633600]]) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py new file mode 100644 index 0000000000..5f1ac69017 --- /dev/null +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -0,0 +1,409 @@ +# pylint: disable=g-bad-file-header +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Clustering Operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.factorization.python.ops import gen_clustering_ops +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.factorization.python.ops.gen_clustering_ops import * +# pylint: enable=wildcard-import +from tensorflow.python.framework import ops +from tensorflow.python.framework.load_library import load_op_library +from tensorflow.python.ops.embedding_ops import embedding_lookup +from tensorflow.python.platform import resource_loader + +_clustering_ops = load_op_library(resource_loader.get_path_to_datafile( + '_clustering_ops.so')) +assert _clustering_ops, 'Could not load _clustering_ops.so' + +# Euclidean distance between vectors U and V is defined as ||U - V||_F which is +# the square root of the sum of the absolute squares of the elements difference. +SQUARED_EUCLIDEAN_DISTANCE = 'squared_euclidean' +# Cosine distance between vectors U and V is defined as +# 1 - (U \dot V) / (||U||_F ||V||_F) +COSINE_DISTANCE = 'cosine' + +RANDOM_INIT = 'random' +KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus' + + +class KMeans(object): + """Creates the graph for k-means clustering.""" + + def __init__(self, + inputs, + num_clusters, + initial_clusters=RANDOM_INIT, + distance_metric=SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=False, + random_seed=0, + kmeans_plus_plus_num_retries=2): + """Creates an object for generating KMeans clustering graph. + + Args: + inputs: An input tensor or list of input tensors + num_clusters: number of clusters. + initial_clusters: Specifies the clusters used during initialization. Can + be a tensor or numpy array, or a function that generates the clusters. + Can also be "random" to specify that clusters should be chosen randomly + from input data. + distance_metric: distance metric used for clustering. + use_mini_batch: If true, use the mini-batch k-means algorithm. Else assume + full batch. + random_seed: Seed for PRNG used to initialize seeds. + kmeans_plus_plus_num_retries: For each point that is sampled during + kmeans++ initialization, this parameter specifies the number of + additional points to draw from the current distribution before selecting + the best. If a negative value is specified, a heuristic is used to + sample O(log(num_to_sample)) additional points. + """ + self._inputs = inputs if isinstance(inputs, list) else [inputs] + assert num_clusters > 0, num_clusters + self._num_clusters = num_clusters + self._initial_clusters = initial_clusters + assert distance_metric in [SQUARED_EUCLIDEAN_DISTANCE, COSINE_DISTANCE] + self._distance_metric = distance_metric + self._use_mini_batch = use_mini_batch + self._random_seed = random_seed + self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries + + @classmethod + def _distance_graph(cls, inputs, clusters, distance_metric): + """Computes distance between each input and each cluster center. + + Args: + inputs: list of input Tensors. + clusters: cluster Tensor. + distance_metric: distance metric used for clustering + + Returns: + list of Tensors, where each element corresponds to each element in inputs. + The value is the distance of each row to all the cluster centers. + Currently only Euclidean distance and cosine distance are supported. 
+ """ + assert isinstance(inputs, list) + if distance_metric == SQUARED_EUCLIDEAN_DISTANCE: + return cls._compute_euclidean_distance(inputs, clusters) + elif distance_metric == COSINE_DISTANCE: + return cls._compute_cosine_distance(inputs, clusters, + inputs_normalized=True) + else: + assert False, ('Unsupported distance metric passed to Kmeans %s' + % str(distance_metric)) + + @classmethod + def _compute_euclidean_distance(cls, inputs, clusters): + """Computes Euclidean distance between each input and each cluster center. + + Args: + inputs: list of input Tensors. + clusters: cluster Tensor. + + Returns: + list of Tensors, where each element corresponds to each element in inputs. + The value is the distance of each row to all the cluster centers. + """ + output = [] + for inp in inputs: + with ops.colocate_with(inp): + # Computes Euclidean distance. Note the first and third terms are + # broadcast additions. + squared_distance = (tf.reduce_sum(tf.square(inp), 1, keep_dims=True) - + 2 * tf.matmul(inp, clusters, transpose_b=True) + + tf.transpose(tf.reduce_sum(tf.square(clusters), + 1, + keep_dims=True))) + output.append(squared_distance) + + return output + + @classmethod + def _compute_cosine_distance(cls, inputs, clusters, inputs_normalized=True): + """Computes cosine distance between each input and each cluster center. + + Args: + inputs: list of input Tensor. + clusters: cluster Tensor + inputs_normalized: if True, it assumes that inp and clusters are + normalized and computes the dot product which is equivalent to the cosine + distance. Else it L2 normalizes the inputs first. + + Returns: + list of Tensors, where each element corresponds to each element in inp. + The value is the distance of each row to all the cluster centers. + """ + output = [] + if not inputs_normalized: + with ops.colocate_with(clusters): + clusters = tf.nn.l2_normalize(clusters, dim=1) + for inp in inputs: + with ops.colocate_with(inp): + if not inputs_normalized: + inp = tf.nn.l2_normalize(inp, dim=1) + output.append(1 - tf.matmul(inp, clusters, transpose_b=True)) + return output + + def _infer_graph(self, inputs, clusters): + """Maps input to closest cluster and the score. + + Args: + inputs: list of input Tensors. + clusters: Tensor of cluster centers. + + Returns: + List of tuple, where each value in tuple corresponds to a value in inp. + The tuple has following three elements: + all_scores: distance of each input to each cluster center. + score: distance of each input to closest cluster center. + cluster_idx: index of cluster center closest to the corresponding input. + """ + assert isinstance(inputs, list) + # Pairwise distances are used only by transform(). In all other cases, this + # sub-graph is not evaluated. + scores = self._distance_graph(inputs, clusters, self._distance_metric) + output = [] + if (self._distance_metric == COSINE_DISTANCE and + not self._clusters_l2_normalized()): + # The cosine distance between normalized vectors x and y is the same as + # 2 * squared_euclidian_distance. We are using this fact and reusing the + # nearest_neighbors op. + # TODO(ands): Support COSINE distance in nearest_neighbors and remove + # this. 
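+      # To make the reused fact concrete: for unit-norm x and y,
+      #   ||x - y||^2 = 2 - 2 * <x, y> = 2 * (1 - <x, y>),
+      # i.e. twice the cosine distance, which is why the squared Euclidean
+      # distances returned by nearest_neighbors are halved below.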
+ with ops.colocate_with(clusters): + clusters = tf.nn.l2_normalize(clusters, dim=1) + for inp, score in zip(inputs, scores): + with ops.colocate_with(inp): + (indices, + distances) = gen_clustering_ops.nearest_neighbors(inp, clusters, 1) + if self._distance_metric == COSINE_DISTANCE: + distances *= 0.5 + output.append((score, tf.squeeze(distances), tf.squeeze(indices))) + return zip(*output) + + def _init_clusters_random(self): + """Does random initialization of clusters. + + Returns: + Tensor of randomly initialized clusters. + """ + num_data = tf.add_n([tf.shape(inp)[0] for inp in self._inputs]) + # Note that for mini-batch k-means, we should ensure that the batch size of + # data used during initialization is sufficiently large to avoid duplicated + # clusters. + with tf.control_dependencies( + [tf.assert_less_equal(self._num_clusters, num_data)]): + indices = tf.random_uniform(tf.reshape(self._num_clusters, [-1]), + minval=0, + maxval=tf.cast(num_data, tf.int64), + seed=self._random_seed, + dtype=tf.int64) + clusters_init = embedding_lookup(self._inputs, indices, + partition_strategy='div') + return clusters_init + + def _clusters_l2_normalized(self): + """Returns True if clusters centers are kept normalized.""" + return self._distance_metric == COSINE_DISTANCE and not self._use_mini_batch + + def _init_clusters(self): + """Initialization of clusters. + + Returns: + Tuple with following elements: + cluster_centers: a Tensor for storing cluster centers + cluster_counts: a Tensor for storing counts of points assigned to this + cluster. This is used by mini-batch training. + """ + init = self._initial_clusters + if init == RANDOM_INIT: + clusters_init = self._init_clusters_random() + elif init == KMEANS_PLUS_PLUS_INIT: + # Points from only the first shard are used for initializing centers. + # TODO(ands): Use all points. + clusters_init = gen_clustering_ops.kmeans_plus_plus_initialization( + self._inputs[0], self._num_clusters, self._random_seed, + self._kmeans_plus_plus_num_retries) + elif callable(init): + clusters_init = init(self._inputs, self._num_clusters) + elif not isinstance(init, str): + clusters_init = init + else: + assert False, 'Unsupported init passed to Kmeans %s' % str(init) + if self._distance_metric == COSINE_DISTANCE and clusters_init is not None: + clusters_init = tf.nn.l2_normalize(clusters_init, dim=1) + clusters_init = clusters_init if clusters_init is not None else [] + cluster_centers = tf.Variable(clusters_init, + name='clusters', + validate_shape=False) + cluster_counts = (tf.Variable(tf.zeros([self._num_clusters], + dtype=tf.int64)) + if self._use_mini_batch else None) + return cluster_centers, cluster_counts + + @classmethod + def _l2_normalize_data(cls, inputs): + """Normalized the input data.""" + output = [] + for inp in inputs: + with ops.colocate_with(inp): + output.append(tf.nn.l2_normalize(inp, dim=1)) + return output + + def training_graph(self): + """Generate a training graph for kmeans algorithm. + + Returns: + A tuple consisting of: + all_scores: A matrix (or list of matrices) of dimensions (num_input, + num_clusters) where the value is the distance of an input vector and a + cluster center. + cluster_idx: A vector (or list of vectors). Each element in the vector + corresponds to an input row in 'inp' and specifies the cluster id + corresponding to the input. + scores: Similar to cluster_idx but specifies the distance to the + assigned cluster instead. + training_op: an op that runs an iteration of training. + """ + # Implementation of kmeans. 
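+    # The graph below performs one Lloyd-style iteration:
+    #   1. For cosine distance, l2-normalize the inputs (and the centers, when
+    #      they are not already kept normalized).
+    #   2. Assign each input to its nearest center via _infer_graph.
+    #   3. Re-estimate the centers, incrementally for mini-batch and from the
+    #      full assignment otherwise.
+    # Illustrative usage (a sketch; input tensors and session setup elided):
+    #   kmeans = KMeans(inputs, num_clusters=3, use_mini_batch=False)
+    #   all_scores, cluster_idx, scores, train_op = kmeans.training_graph()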
+ inputs = self._inputs + cluster_centers_var, total_counts = self._init_clusters() + cluster_centers = cluster_centers_var + + if self._distance_metric == COSINE_DISTANCE: + inputs = self._l2_normalize_data(inputs) + if not self._clusters_l2_normalized(): + cluster_centers = tf.nn.l2_normalize(cluster_centers, dim=1) + + all_scores, scores, cluster_idx = self._infer_graph(inputs, cluster_centers) + if self._use_mini_batch: + training_op = self._mini_batch_training_op( + inputs, cluster_idx, cluster_centers, cluster_centers_var, + total_counts) + else: + assert cluster_centers == cluster_centers_var + training_op = self._full_batch_training_op(inputs, cluster_idx, + cluster_centers_var) + return all_scores, cluster_idx, scores, training_op + + def _mini_batch_training_op(self, inputs, cluster_idx_list, + cluster_centers, cluster_centers_var, + total_counts): + """Creates an op for training for mini batch case. + + Args: + inputs: list of input Tensors. + cluster_idx_list: A vector (or list of vectors). Each element in the + vector corresponds to an input row in 'inp' and specifies the cluster id + corresponding to the input. + cluster_centers: Tensor of cluster centers, possibly normalized. + cluster_centers_var: Tensor Ref of cluster centers. + total_counts: Tensor Ref of cluster counts. + + Returns: + An op for doing an update of mini-batch k-means. + """ + update_ops = [] + for inp, cluster_idx in zip(inputs, cluster_idx_list): + with ops.colocate_with(inp): + assert total_counts is not None + cluster_idx = tf.reshape(cluster_idx, [-1]) + # Dedupe the unique ids of cluster_centers being updated so that updates + # can be locally aggregated. + unique_ids, unique_idx = tf.unique(cluster_idx) + num_unique_cluster_idx = tf.size(unique_ids) + # Fetch the old values of counts and cluster_centers. + with ops.colocate_with(total_counts): + old_counts = tf.gather(total_counts, unique_ids) + with ops.colocate_with(cluster_centers): + old_cluster_centers = tf.gather(cluster_centers, unique_ids) + # Locally aggregate the increment to counts. + count_updates = tf.unsorted_segment_sum( + tf.ones_like(unique_idx, dtype=total_counts.dtype), + unique_idx, + num_unique_cluster_idx) + # Locally compute the sum of inputs mapped to each id. + # For a cluster with old cluster value x, old count n, and with data + # d_1,...d_k newly assigned to it, we recompute the new value as + # x += (sum_i(d_i) - k * x) / (n + k). + # Compute sum_i(d_i), see comment above. + cluster_center_updates = tf.unsorted_segment_sum( + inp, + unique_idx, + num_unique_cluster_idx) + # Shape to enable broadcasting count_updates and learning_rate to inp. + # It extends the shape with 1's to match the rank of inp. + broadcast_shape = tf.concat( + 0, + [tf.reshape(num_unique_cluster_idx, [1]), + tf.ones(tf.reshape(tf.rank(inp) - 1, [1]), dtype=tf.int32)]) + # Subtract k * x, see comment above. + cluster_center_updates -= tf.cast( + tf.reshape(count_updates, broadcast_shape), + inp.dtype) * old_cluster_centers + learning_rate = tf.inv(tf.cast(old_counts + count_updates, inp.dtype)) + learning_rate = tf.reshape(learning_rate, broadcast_shape) + # scale by 1 / (n + k), see comment above. + cluster_center_updates *= learning_rate + # Apply the updates. 
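+        # The two scatter_adds below complete the incremental rule
+        #   x += (sum_i(d_i) - k * x) / (n + k)
+        # by adding k to the stored counts and the scaled deltas to the
+        # stored centers.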
+        update_counts = tf.scatter_add(
+            total_counts,
+            unique_ids,
+            count_updates)
+        update_cluster_centers = tf.scatter_add(
+            cluster_centers_var,
+            unique_ids,
+            cluster_center_updates)
+        update_ops.extend([update_counts, update_cluster_centers])
+    return tf.group(*update_ops)
+
+  def _full_batch_training_op(self, inputs, cluster_idx_list, cluster_centers):
+    """Creates an op for training for full batch case.
+
+    Args:
+      inputs: list of input Tensors.
+      cluster_idx_list: A vector (or list of vectors). Each element in the
+        vector corresponds to an input row in 'inp' and specifies the cluster id
+        corresponding to the input.
+      cluster_centers: Tensor Ref of cluster centers.
+
+    Returns:
+      An op for doing a full-batch update of the cluster centers.
+    """
+    cluster_sums = []
+    cluster_counts = []
+    epsilon = tf.constant(1e-6, dtype=inputs[0].dtype)
+    for inp, cluster_idx in zip(inputs, cluster_idx_list):
+      with ops.colocate_with(inp):
+        cluster_sums.append(tf.unsorted_segment_sum(inp,
+                                                    cluster_idx,
+                                                    self._num_clusters))
+        cluster_counts.append(tf.unsorted_segment_sum(
+            tf.reshape(tf.ones(tf.reshape(tf.shape(inp)[0], [-1])), [-1, 1]),
+            cluster_idx,
+            self._num_clusters))
+    with ops.colocate_with(cluster_centers):
+      new_clusters_centers = tf.add_n(cluster_sums) / (
+          tf.cast(tf.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
+      if self._clusters_l2_normalized():
+        new_clusters_centers = tf.nn.l2_normalize(new_clusters_centers, dim=1)
+    return tf.assign(cluster_centers, new_clusters_centers)
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
new file mode 100644
index 0000000000..f68869f528
--- /dev/null
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
@@ -0,0 +1,467 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Ops for matrix factorization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+
+# pylint: disable=wildcard-import,undefined-variable
+from tensorflow.contrib.factorization.python.ops.gen_factorization_ops import *
+from tensorflow.python.framework import ops
+from tensorflow.python.framework.load_library import load_op_library
+from tensorflow.python.platform import resource_loader
+
+_factorization_ops = load_op_library(resource_loader.get_path_to_datafile(
+    "_factorization_ops.so"))
+assert _factorization_ops, "Could not load _factorization_ops.so"
+
+
+class WALSModel(object):
+  r"""A model for Weighted Alternating Least Squares matrix factorization.
+
+  It minimizes the following loss function over U, V:
+   \\( ||W \odot (A - U V^T) ||_F^2 + \lambda (||U||_F^2 + ||V||_F^2) \\)
+  where,
+  A: input matrix,
+  W: weight matrix,
+  U, V: row_factors and column_factors matrices,
+  \\(\lambda\\): regularization.
+  Also we assume that W is of the following special form:
+  \\( W_{ij} = W_0 + R_i * C_j \\) if \\(A_{ij} \ne 0\\),
+  \\(W_{ij} = W_0\\) otherwise.
+  where,
+  \\(W_0\\): unobserved_weight,
+  \\(R_i\\): row_weights,
+  \\(C_j\\): col_weights.
+
+  Note that the current implementation assumes that row_factors and col_factors
+  can individually fit into the memory of each worker.
+  """
+
+  def __init__(self,
+               input_rows,
+               input_cols,
+               n_components,
+               unobserved_weight=0.1,
+               regularization=None,
+               row_init="random",
+               col_init="random",
+               num_row_shards=1,
+               num_col_shards=1,
+               row_weights=None,
+               col_weights=None):
+    """Creates model for WALS matrix factorization.
+
+    Args:
+      input_rows: total number of rows for input matrix.
+      input_cols: total number of cols for input matrix.
+      n_components: number of dimensions to use for the factors.
+      unobserved_weight: weight given to unobserved entries of matrix.
+      regularization: weight of L2 regularization term. If None, no
+        regularization is done.
+      row_init: initializer for row factor. Can be a tensor or numpy constant.
+        If set to "random", the value is initialized randomly.
+      col_init: initializer for column factor. See row_init for details.
+      num_row_shards: number of shards to use for row factors.
+      num_col_shards: number of shards to use for column factors.
+      row_weights: If not None, along with col_weights, used to compute the
+        weight of an observed entry. w_ij = unobserved_weight + row_weights[i] *
+        col_weights[j]. If None, then w_ij = unobserved_weight, which simplifies
+        to ALS.
+ col_weights: See row_weights + """ + self._input_rows = input_rows + self._input_cols = input_cols + self._num_row_shards = num_row_shards + self._num_col_shards = num_col_shards + self._n_components = n_components + self._unobserved_weight = unobserved_weight + self._regularization = (tf.diag(tf.constant(regularization, + shape=[self._n_components], + dtype=tf.float32)) + if regularization is not None else None) + assert (row_weights is None) == (col_weights is None) + self._row_weights = WALSModel._create_weights(row_weights, + self._input_rows, + self._num_row_shards, + "row_weights") + self._col_weights = WALSModel._create_weights(col_weights, + self._input_cols, + self._num_col_shards, + "col_weights") + self._row_factors = self._create_factors(self._input_rows, + self._n_components, + self._num_row_shards, + row_init, + "row_factors") + self._col_factors = self._create_factors(self._input_cols, + self._n_components, + self._num_col_shards, + col_init, + "col_factors") + self._create_transient_vars() + + @property + def row_factors(self): + """Returns a list of tensors corresponding to row factor shards.""" + return self._row_factors + + @property + def col_factors(self): + """Returns a list of tensors corresponding to column factor shards.""" + return self._col_factors + + @property + def initialize_op(self): + """Returns an op for initializing tensorflow variables.""" + all_vars = self._row_factors + self._col_factors + if self._row_weights is not None: + assert self._col_weights is not None + all_vars.extend(self._row_weights + self._col_weights) + return tf.initialize_variables(all_vars) + + @classmethod + def _shard_sizes(cls, dims, num_shards): + """Helper function to split dims values into num_shards.""" + shard_size, residual = divmod(dims, num_shards) + return [shard_size + 1] * residual + [shard_size] * (num_shards - residual) + + @classmethod + def _create_factors(cls, rows, cols, num_shards, init, name): + """Helper function to create row and column factors.""" + if callable(init): + init = init() + if isinstance(init, list): + assert len(init) == num_shards + elif isinstance(init, str) and init == "random": + pass + elif num_shards == 1: + init = [init] + sharded_matrix = [] + sizes = cls._shard_sizes(rows, num_shards) + assert len(sizes) == num_shards + + def make_initializer(i, size): + def initializer(): + if init == "random": + return tf.random_normal([size, cols]) + else: + return init[i] + return initializer + + for i, size in enumerate(sizes): + var_name = "%s_shard_%d" % (name, i) + var_init = make_initializer(i, size) + sharded_matrix.append(tf.Variable( + var_init, + dtype=tf.float32, + name=var_name)) + + return sharded_matrix + + @staticmethod + def _create_weights(wt_init, num_wts, num_shards, name): + """Helper functions to create sharded weight vector. + + Args: + wt_init: init value for the weight. If None, weights are not created. + num_wts: total size of all the weight shards + num_shards: number of shards for the weights + name: name for the new Variables. + + Returns: + A list of weight shard Tensors. 
+ """ + if wt_init is None: + return None + if num_shards == 1 and len(wt_init) == num_wts: + wt_init = [wt_init] + assert len(wt_init) == num_shards + return [tf.Variable(wt_init[i], + dtype=tf.float32, + name="%s_shard_%d" % (name, i)) + for i in xrange(num_shards)] + + @staticmethod + def _transient_var(name): + """Helper function to create a Variable.""" + return tf.Variable(1.0, + trainable=False, + collections=[tf.GraphKeys.LOCAL_VARIABLES], + validate_shape=False, + name=name) + + def _cached_copy(self, var, name): + """Helper function to create a worker cached copy of a Variable. + + Args: + var: Variable or list of Variable to cache. If a list, the items are + concatenated along dimension 0 to get the cached entry. + name: name of cached variable. + + Returns: + Tuple consisting of following three entries: + cache: the new transient Variable. + cache_init: op to initialize the Variable + cache_reset: op to reset the Variable to some default value + """ + if var is None: + return None, None, None + else: + cache = WALSModel._transient_var(name) + with ops.colocate_with(cache): + if isinstance(var, list): + assert var + if len(var) == 1: + var = var[0] + else: + var = tf.concat(0, var) + + cache_init = tf.assign(cache, var, validate_shape=False) + cache_reset = tf.assign(cache, 1.0, validate_shape=False) + return cache, cache_init, cache_reset + + def _create_transient_vars(self): + """Creates local cache of row and column factors and weights. + + Note that currently the caching strategy is as follows: + When initiating a row update, column factors are cached while row factors + cache is reset. Similarly when initiating a column update, row factors are + cached while cached column factors are flushed. + Column and row weights are always cached. If memory becomes a bottleneck, + they could be similarly flushed. 
+ """ + (self._row_factors_cache, + row_factors_cache_init, + row_factors_cache_reset) = self._cached_copy(self._row_factors, + "row_factors_cache") + (self._col_factors_cache, + col_factors_cache_init, + col_factors_cache_reset) = self._cached_copy(self._col_factors, + "col_factors_cache") + (self._row_wt_cache, + row_wt_cache_init, + _) = self._cached_copy(self._row_weights, "row_wt_cache") + (self._col_wt_cache, + col_wt_cache_init, + _) = self._cached_copy(self._col_weights, "col_wt_cache") + + if self._row_wt_cache is not None: + assert self._col_wt_cache is not None + self._worker_init = tf.group(row_wt_cache_init, + col_wt_cache_init, + name="worker_init") + else: + self._worker_init = tf.no_op(name="worker_init") + + self._row_updates_init = tf.group(col_factors_cache_init, + row_factors_cache_reset) + self._col_updates_init = tf.group(row_factors_cache_init, + col_factors_cache_reset) + + @property + def worker_init(self): + """Op to initialize worker state once before starting any updates.""" + return self._worker_init + + @property + def initialize_row_update_op(self): + """Op to initialize worker state before starting row updates.""" + return self._row_updates_init + + @property + def initialize_col_update_op(self): + """Op to initialize worker state before starting column updates.""" + return self._col_updates_init + + @staticmethod + def _get_sharding_func(size, num_shards): + """Create sharding function for scatter update.""" + def func(ids): + if num_shards == 1: + return None, ids + else: + ids_per_shard = size // num_shards + extras = size % num_shards + assignments = tf.maximum(ids // (ids_per_shard + 1), + (ids - extras) // ids_per_shard) + new_ids = tf.select(assignments < extras, + ids % (ids_per_shard + 1), + (ids - extras) % ids_per_shard) + return assignments, new_ids + return func + + @classmethod + def scatter_update(cls, factor, indices, values, sharding_func): + """Helper function for doing sharded scatter update.""" + assert isinstance(factor, list) + if len(factor) == 1: + with ops.colocate_with(factor[0]): + # TODO(agarwal): assign instead of scatter update for full batch update. + return tf.scatter_update(factor[0], indices, values).op + else: + num_shards = len(factor) + assignments, new_ids = sharding_func(indices) + assert assignments is not None + assignments = tf.cast(assignments, tf.int32) + sharded_ids = tf.dynamic_partition(new_ids, assignments, num_shards) + sharded_values = tf.dynamic_partition(values, assignments, num_shards) + updates = [] + for i in xrange(num_shards): + updates.append(tf.scatter_update(factor[i], + sharded_ids[i], + sharded_values[i])) + return tf.group(*updates) + + def update_row_factors(self, sp_input=None, transpose_input=False): + """Updates the row factors. + + Args: + sp_input: A SparseTensor representing a subset of rows of the full input + in any order. Please note that this SparseTensor must retain the + indexing as the original input. + transpose_input: If true, logically transposes the input. + + Returns: + A tuple consisting of the following two elements: + new_values: New values for the row factors. + update_op: An op that assigns the newly computed values to the row + factors. + """ + return self._process_input_helper(True, sp_input=sp_input, + transpose_input=transpose_input) + + def update_col_factors(self, sp_input=None, transpose_input=False): + """Updates the column factors. + + Args: + sp_input: A SparseTensor representing a subset of columns of the full + input. 
Please refer to comments for update_row_factors for + restrictions. + transpose_input: If true, logically transposes the input. + + Returns: + A tuple consisting of the following two elements: + new_values: New values for the column factors. + update_op: An op that assigns the newly computed values to the column + factors. + """ + return self._process_input_helper(False, sp_input=sp_input, + transpose_input=transpose_input) + + def _process_input_helper(self, update_row_factors, + sp_input=None, transpose_input=False): + """Creates the graph for processing a sparse slice of input. + + Args: + update_row_factors: if True, update the row_factors, else update the + column factors. + sp_input: Please refer to comments for update_row_factors and + update_col_factors. + transpose_input: If true, logically transpose the input. + + Returns: + A tuple consisting of the following two elements: + new_values: New values for the row/column factors. + update_op: An op that assigns the newly computed values to the row/column + factors. + """ + assert isinstance(sp_input, ops.SparseTensor) + + if update_row_factors: + left = self._row_factors + right = self._col_factors_cache + row_weights = self._row_wt_cache + col_weights = self._col_wt_cache + sharding_func = WALSModel._get_sharding_func(self._input_rows, + self._num_row_shards) + right_length = self._input_cols + else: + left = self._col_factors + right = self._row_factors_cache + row_weights = self._col_wt_cache + col_weights = self._row_wt_cache + sharding_func = WALSModel._get_sharding_func(self._input_cols, + self._num_col_shards) + right_length = self._input_rows + transpose_input = not transpose_input + + # Note that the row indices of sp_input are based on the original full input + # Here we reindex the rows and give them contiguous ids starting at 0. + # We use tf.unique to achieve this reindexing. Note that this is done so + # that the downstream kernel can assume that the input is "dense" along the + # row dimension. + row_ids, col_ids = tf.split(1, 2, sp_input.indices) + + if transpose_input: + update_indices, all_ids = tf.unique(col_ids[:, 0]) + col_ids = tf.expand_dims(tf.cast(all_ids, tf.int64), 1) + else: + update_indices, all_ids = tf.unique(row_ids[:, 0]) + row_ids = tf.expand_dims(tf.cast(all_ids, tf.int64), 1) + + num_rows = tf.cast(tf.shape(update_indices)[0], tf.int64) + row_shape = tf.constant([right_length], tf.int64) + col_shape = [num_rows] + + new_sp_indices = tf.concat(1, [row_ids, col_ids]) + new_sp_shape = (tf.concat(0, [row_shape, col_shape]) if transpose_input + else tf.concat(0, [col_shape, row_shape])) + new_sp_input = tf.SparseTensor(indices=new_sp_indices, + values=sp_input.values, shape=new_sp_shape) + + # Compute lhs and rhs of the normal equations + total_lhs = (self._unobserved_weight * + tf.matmul(right, right, transpose_a=True)) + if self._regularization is not None: + total_lhs += self._regularization + if self._row_weights is None: + # Special case of ALS. Use a much simpler update rule. + total_rhs = (self._unobserved_weight * + tf.sparse_tensor_dense_matmul(new_sp_input, right, + adjoint_a=transpose_input)) + # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of + # transposing explicitly. + # TODO(rmlarsen): multi-thread tf.matrix_solve. 
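+      # In this ALS branch every row shares the same lhs (w_0 * V^T V, plus
+      # lambda * I when regularization is set), so the single dense solve
+      # below recovers all updated factor rows at once.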
+ new_left_values = tf.transpose(tf.matrix_solve(total_lhs, + tf.transpose(total_rhs))) + else: + row_weights_slice = tf.gather(row_weights, update_indices) + partial_lhs, total_rhs = wals_compute_partial_lhs_and_rhs( + right, + col_weights, + self._unobserved_weight, + row_weights_slice, + new_sp_input.indices, + new_sp_input.values, + num_rows, + transpose_input, + name="wals_compute_partial_lhs_rhs") + total_lhs = tf.expand_dims(total_lhs, 0) + partial_lhs + total_rhs = tf.expand_dims(total_rhs, -1) + new_left_values = tf.squeeze(tf.batch_matrix_solve(total_lhs, total_rhs), + [2]) + + return (new_left_values, + self.scatter_update(left, + update_indices, + new_left_values, + sharding_func)) diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py new file mode 100644 index 0000000000..0ef49b7891 --- /dev/null +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py @@ -0,0 +1,456 @@ +# pylint: disable=g-bad-file-header +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for factorization_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf + +from tensorflow.contrib.factorization.python.ops import factorization_ops + +INPUT_MATRIX = np.array( + [[0.1, 0.0, 0.2, 0.0, 0.4, 0.5, 0.0], + [0.0, 1.1, 0.0, 1.3, 1.4, 0.0, 1.6], + [2.0, 0.0, 0.0, 2.3, 0.0, 2.5, 0.0], + [3.0, 0.0, 3.2, 3.3, 0.0, 3.5, 0.0], + [0.0, 4.1, 0.0, 0.0, 4.4, 0.0, 4.6]]).astype(np.float32) + + +def np_matrix_to_tf_sparse(np_matrix, row_slices=None, + col_slices=None, transpose=False, + shuffle=False): + """Simple util to slice non-zero np matrix elements as tf.SparseTensor.""" + indices = np.nonzero(np_matrix) + + # Only allow slices of whole rows or whole columns. 
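+  # (Slicing by rows and columns at the same time could drop non-zero entries
+  # from the selected rows or columns, which the update ops require to be
+  # complete.)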
+ assert not (row_slices is not None and col_slices is not None) + + if row_slices is not None: + selected_ind = np.concatenate( + [np.where(indices[0] == r)[0] for r in row_slices], 0) + indices = (indices[0][selected_ind], indices[1][selected_ind]) + + if col_slices is not None: + selected_ind = np.concatenate( + [np.where(indices[1] == c)[0] for c in col_slices], 0) + indices = (indices[0][selected_ind], indices[1][selected_ind]) + + if shuffle: + shuffled_ind = [x for x in range(len(indices[0]))] + random.shuffle(shuffled_ind) + indices = (indices[0][shuffled_ind], indices[1][shuffled_ind]) + + ind = (np.concatenate( + (np.expand_dims(indices[1], 1), + np.expand_dims(indices[0], 1)), 1).astype(np.int64) if transpose else + np.concatenate((np.expand_dims(indices[0], 1), + np.expand_dims(indices[1], 1)), 1).astype(np.int64)) + val = np_matrix[indices].astype(np.float32) + shape = (np.array( + [max(indices[1]) + 1, max(indices[0]) + 1]).astype(np.int64) if transpose + else np.array( + [max(indices[0]) + 1, max(indices[1]) + 1]).astype(np.int64)) + return tf.SparseTensor(ind, val, shape) + + +def sparse_input(): + return np_matrix_to_tf_sparse(INPUT_MATRIX) + + +class WalsModelTest(tf.test.TestCase): + + def setUp(self): + self.col_init = [ + # shard 0 + [[-0.36444709, -0.39077035, -0.32528427], + [1.19056475, 0.07231052, 2.11834812], + [0.93468881, -0.71099287, 1.91826844]], + # shard 1 + [[1.18160152, 1.52490723, -0.50015002], + [1.82574749, -0.57515913, -1.32810032]], + # shard 2 + [[-0.15515432, -0.84675711, 0.13097958], + [-0.9246484, 0.69117504, 1.2036494]] + ] + + self.row_wts = [[0.1, 0.2, 0.3], [0.4, 0.5]] + self.col_wts = [[0.1, 0.2, 0.3], + [0.4, 0.5], + [0.6, 0.7]] + self._wals_inputs = sparse_input() + + # Values of factor shards after running one iteration of row and column + # updates. + self._row_factors_0 = [[0.097689, -0.219293, -0.020780], + [0.50842, 0.64626, 0.22364], + [0.401159, -0.046558, -0.192854]] + self._row_factors_1 = [[1.20597, -0.48025, 0.35582], + [1.5564, 1.2528, 1.0528]] + self._col_factors_0 = [[2.4725, -1.2950, -1.9980], + [0.44625, 1.50771, 1.27118], + [1.39801, -2.10134, 0.73572]] + self._col_factors_1 = [[3.36509, -0.66595, -3.51208], + [0.57191, 1.59407, 1.33020]] + self._col_factors_2 = [[3.3459, -1.3341, -3.3008], + [0.57366, 1.83729, 1.26798]] + + def test_process_input(self): + with self.test_session(): + sp_feeder = tf.sparse_placeholder(tf.float32) + wals_model = factorization_ops.WALSModel(5, 7, 3, + num_row_shards=2, + num_col_shards=3, + regularization=0.01, + unobserved_weight=0.1, + col_init=self.col_init, + row_weights=self.row_wts, + col_weights=self.col_wts) + + wals_model.initialize_op.run() + wals_model.worker_init.run() + + # Split input into multiple sparse tensors with scattered rows. Note that + # this split can be different than the factor sharding and the inputs can + # consist of non-consecutive rows. Each row needs to include all non-zero + # elements in that row. + sp_r0 = np_matrix_to_tf_sparse(INPUT_MATRIX, [0, 2]).eval() + sp_r1 = np_matrix_to_tf_sparse(INPUT_MATRIX, [1, 4], shuffle=True).eval() + sp_r2 = np_matrix_to_tf_sparse(INPUT_MATRIX, [3], shuffle=True).eval() + input_scattered_rows = [sp_r0, sp_r1, sp_r2] + + # Test updating row factors. + # Here we feed in scattered rows of the input. 
+ wals_model.initialize_row_update_op.run() + process_input_op = wals_model.update_row_factors(sp_input=sp_feeder, + transpose_input=False)[1] + for inp in input_scattered_rows: + feed_dict = {sp_feeder: inp} + process_input_op.run(feed_dict=feed_dict) + row_factors = [x.eval() for x in wals_model.row_factors] + + self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3) + self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3) + + # Split input into multiple sparse tensors with scattered columns. Note + # that here the elements in the sparse tensors are not ordered and also + # do not need to consist of consecutive columns. However, each column + # needs to include all non-zero elements in that column. + sp_c0 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[2, 0]).eval() + sp_c1 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[5, 3, 1], + shuffle=True).eval() + sp_c2 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[4, 6]).eval() + sp_c3 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[3, 6], + shuffle=True).eval() + + input_scattered_cols = [sp_c0, sp_c1, sp_c2, sp_c3] + + # Test updating column factors. + # Here we feed in scattered columns of the input. + wals_model.initialize_col_update_op.run() + process_input_op = wals_model.update_col_factors(sp_input=sp_feeder, + transpose_input=False)[1] + for inp in input_scattered_cols: + feed_dict = {sp_feeder: inp} + process_input_op.run(feed_dict=feed_dict) + col_factors = [x.eval() for x in wals_model.col_factors] + + self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3) + self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3) + self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3) + + def test_process_input_transposed(self): + with self.test_session(): + sp_feeder = tf.sparse_placeholder(tf.float32) + wals_model = factorization_ops.WALSModel(5, 7, 3, + num_row_shards=2, + num_col_shards=3, + regularization=0.01, + unobserved_weight=0.1, + col_init=self.col_init, + row_weights=self.row_wts, + col_weights=self.col_wts) + + wals_model.initialize_op.run() + wals_model.worker_init.run() + + # Split input into multiple SparseTensors with scattered rows. + # Here the inputs are transposed. But the same constraints as described in + # the previous non-transposed test case apply to these inputs (before they + # are transposed). + sp_r0_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [0, 3], + transpose=True).eval() + sp_r1_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [4, 1], + shuffle=True, transpose=True).eval() + sp_r2_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [2], transpose=True).eval() + sp_r3_t = sp_r1_t + input_scattered_rows = [sp_r0_t, sp_r1_t, sp_r2_t, sp_r3_t] + + # Test updating row factors. + # Here we feed in scattered rows of the input. + # Note that the needed suffix of placeholder are in the order of test + # case name lexicographical order and then in the line order of where + # they appear. + wals_model.initialize_row_update_op.run() + process_input_op = wals_model.update_row_factors(sp_input=sp_feeder, + transpose_input=True)[1] + for inp in input_scattered_rows: + feed_dict = {sp_feeder: inp} + process_input_op.run(feed_dict=feed_dict) + row_factors = [x.eval() for x in wals_model.row_factors] + + self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3) + self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3) + + # Split input into multiple SparseTensors with scattered columns. + # Here the inputs are transposed. 
+      # the previous non-transposed test case apply to these inputs (before
+      # they are transposed).
+      sp_c0_t = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[0, 1],
+                                       transpose=True).eval()
+      sp_c1_t = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[4, 2],
+                                       transpose=True).eval()
+      sp_c2_t = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[5],
+                                       transpose=True, shuffle=True).eval()
+      sp_c3_t = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[3, 6],
+                                       transpose=True).eval()
+
+      sp_c4_t = sp_c2_t
+      input_scattered_cols = [sp_c0_t, sp_c1_t, sp_c2_t, sp_c3_t,
+                              sp_c4_t]
+
+      # Test updating column factors.
+      # Here we feed in scattered columns of the input.
+      wals_model.initialize_col_update_op.run()
+      process_input_op = wals_model.update_col_factors(sp_input=sp_feeder,
+                                                       transpose_input=True)[1]
+      for inp in input_scattered_cols:
+        feed_dict = {sp_feeder: inp}
+        process_input_op.run(feed_dict=feed_dict)
+      col_factors = [x.eval() for x in wals_model.col_factors]
+
+      self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3)
+      self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3)
+      self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3)
+
+  # Note that when row_weights and col_weights are 0, WALS gives identical
+  # results to ALS (Alternating Least Squares). However, our implementation
+  # does not special-case zero weights. Instead, when row_weights and
+  # col_weights are set to None, we interpret that as the ALS case and
+  # trigger the more efficient ALS updates.
+  # Here we test that those two give identical results.
+  def test_als(self):
+    with self.test_session():
+      col_init = np.random.rand(7, 3)
+      als_model = factorization_ops.WALSModel(5, 7, 3,
+                                              col_init=col_init,
+                                              row_weights=None,
+                                              col_weights=None)
+
+      als_model.initialize_op.run()
+      als_model.worker_init.run()
+      als_model.initialize_row_update_op.run()
+      process_input_op = als_model.update_row_factors(self._wals_inputs)[1]
+      process_input_op.run()
+      row_factors1 = [x.eval() for x in als_model.row_factors]
+
+      wals_model = factorization_ops.WALSModel(5, 7, 3,
+                                               col_init=col_init,
+                                               row_weights=[0] * 5,
+                                               col_weights=[0] * 7)
+      wals_model.initialize_op.run()
+      wals_model.worker_init.run()
+      wals_model.initialize_row_update_op.run()
+      process_input_op = wals_model.update_row_factors(self._wals_inputs)[1]
+      process_input_op.run()
+      row_factors2 = [x.eval() for x in wals_model.row_factors]
+
+      for r1, r2 in zip(row_factors1, row_factors2):
+        self.assertAllClose(r1, r2, atol=1e-3)
+
+      # Here we test partial column updates.
+      sp_c = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[2, 0],
+                                    shuffle=True).eval()
+
+      sp_feeder = tf.sparse_placeholder(tf.float32)
+      feed_dict = {sp_feeder: sp_c}
+      als_model.initialize_col_update_op.run()
+      process_input_op = als_model.update_col_factors(sp_input=sp_feeder)[1]
+      process_input_op.run(feed_dict=feed_dict)
+      col_factors1 = [x.eval() for x in als_model.col_factors]
+
+      feed_dict = {sp_feeder: sp_c}
+      wals_model.initialize_col_update_op.run()
+      process_input_op = wals_model.update_col_factors(sp_input=sp_feeder)[1]
+      process_input_op.run(feed_dict=feed_dict)
+      col_factors2 = [x.eval() for x in wals_model.col_factors]
+
+      for c1, c2 in zip(col_factors1, col_factors2):
+        self.assertAllClose(c1, c2, atol=1e-3)
+
+  def test_als_transposed(self):
+    with self.test_session():
+      col_init = np.random.rand(7, 3)
+      als_model = factorization_ops.WALSModel(5, 7, 3,
+                                              col_init=col_init,
+                                              row_weights=None,
+                                              col_weights=None)
+
+      als_model.initialize_op.run()
+      als_model.worker_init.run()
+
+      wals_model = factorization_ops.WALSModel(5, 7, 3,
+                                               col_init=col_init,
+                                               row_weights=[0] * 5,
+                                               col_weights=[0] * 7)
+      wals_model.initialize_op.run()
+      wals_model.worker_init.run()
+      sp_feeder = tf.sparse_placeholder(tf.float32)
+      # Here we test a partial row update with identical inputs, feeding the
+      # ALS model transposed input.
+      sp_r_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [3, 1],
+                                      transpose=True).eval()
+      sp_r = np_matrix_to_tf_sparse(INPUT_MATRIX, [3, 1]).eval()
+
+      feed_dict = {sp_feeder: sp_r_t}
+      als_model.initialize_row_update_op.run()
+      process_input_op = als_model.update_row_factors(sp_input=sp_feeder,
+                                                      transpose_input=True)[1]
+      process_input_op.run(feed_dict=feed_dict)
+      # Only rows 1 and 3 were updated, so only compare those rows; the others
+      # still hold randomly initialized values.
+      row_factors1 = [als_model.row_factors[0].eval()[1],
+                      als_model.row_factors[0].eval()[3]]
+
+      feed_dict = {sp_feeder: sp_r}
+      wals_model.initialize_row_update_op.run()
+      process_input_op = wals_model.update_row_factors(sp_input=sp_feeder)[1]
+      process_input_op.run(feed_dict=feed_dict)
+      # Only rows 1 and 3 were updated, so only compare those rows; the others
+      # still hold randomly initialized values.
+      row_factors2 = [wals_model.row_factors[0].eval()[1],
+                      wals_model.row_factors[0].eval()[3]]
+      for r1, r2 in zip(row_factors1, row_factors2):
+        self.assertAllClose(r1, r2, atol=1e-3)
+
+  def simple_train(self,
+                   model,
+                   inp,
+                   num_iterations):
+    """Helper function to train model on inp for num_iterations."""
+    row_update_op = model.update_row_factors(sp_input=inp)[1]
+    col_update_op = model.update_col_factors(sp_input=inp)[1]
+
+    model.initialize_op.run()
+    model.worker_init.run()
+    for _ in xrange(num_iterations):
+      model.initialize_row_update_op.run()
+      row_update_op.run()
+      model.initialize_col_update_op.run()
+      col_update_op.run()
+
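For orientation before the training tests that follow: with row_weights and col_weights set to None (the ALS path that simple_train above drives), each call to update_row_factors amounts to solving, per row, a small regularized least-squares problem against the fixed column factors. A minimal NumPy sketch of that per-sweep solve, for the simplified fully observed, unsharded case; the helper name als_row_sweep is ours, and the sketch ignores the observation weights and unobserved_weight handling that WALSModel performs:

    import numpy as np

    def als_row_sweep(data, col_factors, regularization=1e-5):
      # Solve (V^T V + reg * I) u_r = V^T a_r for every row a_r of `data`,
      # where V holds the (fixed) column factors.
      rank = col_factors.shape[1]
      gramian = col_factors.T.dot(col_factors) + regularization * np.eye(rank)
      rhs = col_factors.T.dot(data.T)         # shape [rank, num_rows]
      return np.linalg.solve(gramian, rhs).T  # shape [num_rows, rank]

The column sweep is symmetric: als_row_sweep(data.T, row_factors).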
+  # Trains an ALS model on a low-rank matrix and makes sure the product of
+  # factors is close to the original input.
+  def test_train_full_low_rank_als(self):
+    rows = 15
+    cols = 11
+    dims = 3
+    with self.test_session():
+      data = np.dot(np.random.rand(rows, 3),
+                    np.random.rand(3, cols)).astype(np.float32) / 3.0
+      indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
+      values = data.reshape(-1)
+      inp = tf.SparseTensor(indices, values, [rows, cols])
+      model = factorization_ops.WALSModel(rows, cols, dims,
+                                          regularization=1e-5,
+                                          row_weights=None,
+                                          col_weights=None)
+      self.simple_train(model, inp, 15)
+      row_factor = model.row_factors[0].eval()
+      col_factor = model.col_factors[0].eval()
+      self.assertAllClose(data,
+                          np.dot(row_factor, np.transpose(col_factor)),
+                          rtol=0.01, atol=0.01)
+
+  # Trains a WALS model on a low-rank matrix and makes sure the product of
+  # factors is close to the original input.
+  def test_train_full_low_rank_wals(self):
+    rows = 15
+    cols = 11
+    dims = 3
+
+    with self.test_session():
+      data = np.dot(np.random.rand(rows, 3),
+                    np.random.rand(3, cols)).astype(np.float32) / 3.0
+      indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
+      values = data.reshape(-1)
+      inp = tf.SparseTensor(indices, values, [rows, cols])
+      model = factorization_ops.WALSModel(rows, cols, dims,
+                                          regularization=1e-5,
+                                          row_weights=[0] * rows,
+                                          col_weights=[0] * cols)
+      self.simple_train(model, inp, 15)
+      row_factor = model.row_factors[0].eval()
+      col_factor = model.col_factors[0].eval()
+      self.assertAllClose(data,
+                          np.dot(row_factor, np.transpose(col_factor)),
+                          rtol=0.01, atol=0.01)
+
+  # Trains a WALS model on a partially observed low-rank matrix and makes
+  # sure the product of factors is reasonably close to the original input.
+  def test_train_matrix_completion_wals(self):
+    rows = 11
+    cols = 9
+    dims = 4
+    def keep_index(x):
+      return not (x[0] + x[1]) % 4
+
+    with self.test_session():
+      row_wts = 0.1 + np.random.rand(rows)
+      col_wts = 0.1 + np.random.rand(cols)
+      data = np.dot(np.random.rand(rows, 3),
+                    np.random.rand(3, cols)).astype(np.float32) / 3.0
+      indices = np.array(
+          list(filter(keep_index,
+                      [[i, j] for i in xrange(rows) for j in xrange(cols)])))
+      values = data[indices[:, 0], indices[:, 1]]
+      inp = tf.SparseTensor(indices, values, [rows, cols])
+      model = factorization_ops.WALSModel(rows, cols, dims,
+                                          unobserved_weight=0.01,
+                                          regularization=0.001,
+                                          row_weights=row_wts,
+                                          col_weights=col_wts)
+      self.simple_train(model, inp, 10)
+      row_factor = model.row_factors[0].eval()
+      col_factor = model.col_factors[0].eval()
+      out = np.dot(row_factor, np.transpose(col_factor))
+      for i in xrange(rows):
+        for j in xrange(cols):
+          if keep_index([i, j]):
+            self.assertNear(data[i][j], out[i][j],
+                            err=0.2, msg="%d, %d" % (i, j))
+          else:
+            self.assertNear(0, out[i][j], err=0.5, msg="%d, %d" % (i, j))
+
+
+if __name__ == "__main__":
+  tf.test.main()
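A quick sanity check on the recovery property that test_train_full_low_rank_als and test_train_full_low_rank_wals assert above: for a fully observed matrix of rank at most dims, alternating least squares converges essentially exactly, and generically the very first full sweep already places the row factors in the data's column space. A self-contained NumPy illustration mirroring the 15x11 rank-3 setup (regularization omitted as negligible; independent of the TensorFlow ops under test):

    import numpy as np

    np.random.seed(0)
    rows, cols, rank = 15, 11, 3
    data = np.random.rand(rows, rank).dot(np.random.rand(rank, cols)) / 3.0

    row_f = np.random.rand(rows, rank)
    col_f = np.random.rand(cols, rank)
    for _ in range(15):
      # Fix the column factors and solve for the rows, then vice versa.
      row_f = np.linalg.lstsq(col_f, data.T, rcond=None)[0].T
      col_f = np.linalg.lstsq(row_f, data, rcond=None)[0].T

    # Near-exact recovery, well inside the rtol=atol=0.01 used by the tests.
    print(np.abs(data - row_f.dot(col_f.T)).max())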
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py new file mode 100644 index 0000000000..4504afe289 --- /dev/null +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -0,0 +1,238 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Implementation of k-means clustering on top of the learn (aka skflow) API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import tensorflow as tf
+
+from tensorflow.contrib.factorization.python.ops import clustering_ops
+from tensorflow.contrib.learn.python.learn.estimators import estimator
+from tensorflow.contrib.learn.python.learn.estimators._sklearn import TransformerMixin
+from tensorflow.contrib.learn.python.learn.io import data_feeder
+from tensorflow.contrib.learn.python.learn.utils import checkpoints
+from tensorflow.python.ops.control_flow_ops import with_dependencies
+
+SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE
+COSINE_DISTANCE = clustering_ops.COSINE_DISTANCE
+RANDOM_INIT = clustering_ops.RANDOM_INIT
+KMEANS_PLUS_PLUS_INIT = clustering_ops.KMEANS_PLUS_PLUS_INIT
+
+
+# TODO(agarwal,ands): support sharded input.
+# TODO(agarwal,ands): enable stopping criteria based on improvements to cost.
+# TODO(agarwal,ands): support random restarts.
+class KMeansClustering(estimator.Estimator,
+                       TransformerMixin):
+  """K-Means clustering."""
+  SCORES = 'scores'
+  CLUSTER_IDX = 'cluster_idx'
+  CLUSTERS = 'clusters'
+  ALL_SCORES = 'all_scores'
+
+  def __init__(self,
+               num_clusters,
+               model_dir=None,
+               initial_clusters=clustering_ops.RANDOM_INIT,
+               distance_metric=clustering_ops.SQUARED_EUCLIDEAN_DISTANCE,
+               random_seed=0,
+               use_mini_batch=True,
+               batch_size=128,
+               steps=10,
+               kmeans_plus_plus_num_retries=2,
+               continue_training=False,
+               config=None,
+               verbose=1):
+    """Creates a model for running KMeans training and inference.
+
+    Args:
+      num_clusters: number of clusters to train.
+      model_dir: the directory to save the model results and log files.
+      initial_clusters: specifies how to initialize the clusters for training.
+        See clustering_ops.kmeans for the possible values.
+      distance_metric: the distance metric used for clustering.
+        See clustering_ops.kmeans for the possible values.
+      random_seed: Python integer. Seed for PRNG used to initialize centers.
+      use_mini_batch: If true, use the mini-batch k-means algorithm. Otherwise
+        assume full batch.
+      batch_size: See TensorFlowEstimator.
+      steps: See TensorFlowEstimator.
+      kmeans_plus_plus_num_retries: For each point that is sampled during
+        kmeans++ initialization, this parameter specifies the number of
+        additional points to draw from the current distribution before
+        selecting the best. If a negative value is specified, a heuristic is
+        used to sample O(log(num_to_sample)) additional points.
+      continue_training: See TensorFlowEstimator.
+      config: See TensorFlowEstimator.
+      verbose: See TensorFlowEstimator.
+    """
+    super(KMeansClustering, self).__init__(
+        model_dir=model_dir,
+        config=config)
+    self.batch_size = batch_size
+    self.steps = steps
+    self.kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
+    self.continue_training = continue_training
+    self.verbose = verbose
+    self._num_clusters = num_clusters
+    self._training_initial_clusters = initial_clusters
+    self._training_graph = None
+    self._distance_metric = distance_metric
+    self._use_mini_batch = use_mini_batch
+    self._random_seed = random_seed
+    self._initialized = False
+
+  def fit(self, x, y=None, monitors=None, logdir=None, steps=None):
+    """Trains a k-means clustering model on x.
+
+    Note: See TensorFlowEstimator for the logic governing continuous training
+    and graph construction across multiple calls to fit.
+
+    Args:
+      x: training input matrix of shape [n_samples, n_features].
+      y: labels. Should be None.
+      monitors: Monitor object to print training progress and invoke early
+        stopping.
+      logdir: the directory to save the log file that can be used for optional
+        visualization.
+      steps: number of training steps. If not None, overrides the value passed
+        in the constructor.
+
+    Returns:
+      Returns self.
+    """
+    assert y is None
+    if logdir is not None:
+      self._model_dir = logdir
+    self._data_feeder = data_feeder.setup_train_data_feeder(
+        x, None, self._num_clusters, self.batch_size)
+    self._train_model(input_fn=self._data_feeder.input_builder,
+                      feed_fn=self._data_feeder.get_feed_dict_fn(),
+                      steps=steps or self.steps,
+                      monitors=monitors,
+                      init_feed_fn=self._data_feeder.get_feed_dict_fn())
+    return self
+
+  def predict(self, x, batch_size=None):
+    """Predicts the cluster id for each element in x.
+
+    Args:
+      x: 2-D matrix or iterator.
+      batch_size: size to use for batching up x for querying the model.
+
+    Returns:
+      Array with same number of rows as x, containing cluster ids.
+    """
+    return super(KMeansClustering, self).predict(
+        x=x, batch_size=batch_size)[KMeansClustering.CLUSTER_IDX]
+
+  def score(self, x, batch_size=None):
+    """Predicts the total sum of distances to the nearest clusters.
+
+    Note that this function is different from the corresponding one in sklearn,
+    which returns the negative of the sum of distances.
+
+    Args:
+      x: 2-D matrix or iterator.
+      batch_size: size to use for batching up x for querying the model.
+
+    Returns:
+      Total sum of distances to nearest clusters.
+    """
+    return np.sum(
+        self.evaluate(x=x, batch_size=batch_size)[KMeansClustering.SCORES])
+
+  def transform(self, x, batch_size=None):
+    """Transforms each element in x to distances to cluster centers.
+
+    Note that this function is different from the corresponding one in sklearn.
+    For the SQUARED_EUCLIDEAN distance metric, sklearn's transform returns the
+    EUCLIDEAN distance, while this function returns the SQUARED_EUCLIDEAN
+    distance.
+
+    Args:
+      x: 2-D matrix or iterator.
+      batch_size: size to use for batching up x for querying the model.
+
+    Returns:
+      Array with same number of rows as x, and num_clusters columns, containing
+      distances to the cluster centers.
+ """ + return super(KMeansClustering, self).predict( + x=x, batch_size=batch_size)[KMeansClustering.ALL_SCORES] + + def clusters(self): + """Returns cluster centers.""" + return checkpoints.load_variable(self.model_dir, self.CLUSTERS) + + def _get_train_ops(self, features, _): + (_, + _, + losses, + training_op) = clustering_ops.KMeans( + features, + self._num_clusters, + self._training_initial_clusters, + self._distance_metric, + self._use_mini_batch, + random_seed=self._random_seed, + kmeans_plus_plus_num_retries=self.kmeans_plus_plus_num_retries + ).training_graph() + incr_step = tf.assign_add(tf.contrib.framework.get_global_step(), 1) + loss = tf.reduce_sum(losses) + training_op = with_dependencies([training_op, incr_step], loss) + return training_op, loss + + def _get_predict_ops(self, features): + (all_scores, + model_predictions, + _, + _) = clustering_ops.KMeans( + features, + self._num_clusters, + self._training_initial_clusters, + self._distance_metric, + self._use_mini_batch, + random_seed=self._random_seed, + kmeans_plus_plus_num_retries=self.kmeans_plus_plus_num_retries + ).training_graph() + return { + KMeansClustering.ALL_SCORES: all_scores[0], + KMeansClustering.CLUSTER_IDX: model_predictions[0] + } + + def _get_eval_ops(self, features, _, unused_metrics): + (_, + _, + losses, + _) = clustering_ops.KMeans( + features, + self._num_clusters, + self._training_initial_clusters, + self._distance_metric, + self._use_mini_batch, + random_seed=self._random_seed, + kmeans_plus_plus_num_retries=self.kmeans_plus_plus_num_retries + ).training_graph() + return { + KMeansClustering.SCORES: tf.reduce_sum(losses), + } + diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py new file mode 100644 index 0000000000..b4ade5ea0b --- /dev/null +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -0,0 +1,283 @@ +# pylint: disable=g-bad-file-header +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+
+"""Tests for KMeans."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_ops
+from tensorflow.contrib.factorization.python.ops.kmeans import KMeansClustering as KMeans
+from tensorflow.contrib.learn.python.learn.estimators import run_config
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def normalize(x):
+  return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True))
+
+
+def cosine_similarity(x, y):
+  return np.dot(normalize(x), np.transpose(normalize(y)))
+
+
+class KMeansTest(tf.test.TestCase):
+
+  def setUp(self):
+    np.random.seed(3)
+    self.num_centers = 5
+    self.num_dims = 2
+    self.num_points = 10000
+    self.true_centers = self.make_random_centers(self.num_centers,
+                                                 self.num_dims)
+    self.points, _, self.scores = self.make_random_points(
+        self.true_centers,
+        self.num_points)
+    self.true_score = np.add.reduce(self.scores)
+
+    self.kmeans = KMeans(self.num_centers,
+                         initial_clusters=kmeans_ops.RANDOM_INIT,
+                         batch_size=self.batch_size,
+                         use_mini_batch=self.use_mini_batch,
+                         steps=30,
+                         continue_training=True,
+                         config=run_config.RunConfig(tf_random_seed=14),
+                         random_seed=12)
+
+  @property
+  def batch_size(self):
+    return self.num_points
+
+  @property
+  def use_mini_batch(self):
+    return False
+
+  @staticmethod
+  def make_random_centers(num_centers, num_dims):
+    return np.round(np.random.rand(num_centers,
+                                   num_dims).astype(np.float32) * 500)
+
+  @staticmethod
+  def make_random_points(centers, num_points, max_offset=20):
+    num_centers, num_dims = centers.shape
+    assignments = np.random.choice(num_centers, num_points)
+    offsets = np.round(np.random.randn(
+        num_points,
+        num_dims).astype(np.float32) * max_offset)
+    return (centers[assignments] + offsets,
+            assignments,
+            np.add.reduce(offsets * offsets, 1))
+
+  def test_clusters(self):
+    kmeans = self.kmeans
+    kmeans.fit(x=self.points, steps=0)
+    clusters = kmeans.clusters()
+    self.assertAllEqual(list(clusters.shape),
+                        [self.num_centers, self.num_dims])
+
+  def test_fit(self):
+    if self.batch_size != self.num_points:
+      # TODO(agarwal): Doesn't work with mini-batch.
+      return
+    kmeans = self.kmeans
+    kmeans.fit(x=self.points,
+               steps=1)
+    score1 = kmeans.score(x=self.points)
+    kmeans.fit(x=self.points,
+               steps=15 * self.num_points // self.batch_size)
+    score2 = kmeans.score(x=self.points)
+    self.assertTrue(score1 > score2)
+    self.assertNear(self.true_score, score2, self.true_score * 0.05)
+
+  def test_infer(self):
+    kmeans = self.kmeans
+    kmeans.fit(x=self.points)
+    clusters = kmeans.clusters()
+
+    # Make a small test set.
+    points, true_assignments, true_offsets = self.make_random_points(clusters,
+                                                                     10)
+    # Test predict.
+    assignments = kmeans.predict(points)
+    self.assertAllEqual(assignments, true_assignments)
+
+    # Test score.
+    score = kmeans.score(points)
+    self.assertNear(score, np.sum(true_offsets), 0.01 * score)
+
+    # Test transform.
+    transform = kmeans.transform(points)
+    true_transform = np.maximum(
+        0,
+        np.sum(np.square(points), axis=1, keepdims=True) -
+        2 * np.dot(points, np.transpose(clusters)) +
+        np.transpose(np.sum(np.square(clusters), axis=1, keepdims=True)))
+    self.assertAllClose(transform, true_transform, rtol=0.05, atol=10)
+
+  def test_fit_with_cosine_distance(self):
+    # Create points on the lines y=x and y=1.5x to check cosine similarity.
+    # Note that Euclidean distance would give different results in this case.
+    points = np.array([[9, 9], [0.5, 0.5], [10, 15], [0.4, 0.6]])
+    # The true centers are the unit vectors on the lines y=x and y=1.5x.
+    true_centers = np.array([[0.70710678, 0.70710678], [0.5547002, 0.83205029]])
+    kmeans = KMeans(2,
+                    initial_clusters=kmeans_ops.RANDOM_INIT,
+                    distance_metric=kmeans_ops.COSINE_DISTANCE,
+                    use_mini_batch=self.use_mini_batch,
+                    batch_size=4,
+                    steps=30,
+                    continue_training=True,
+                    config=run_config.RunConfig(tf_random_seed=2),
+                    random_seed=12)
+    kmeans.fit(x=points)
+    centers = normalize(kmeans.clusters())
+    self.assertAllClose(np.sort(centers, axis=0),
+                        np.sort(true_centers, axis=0))
+
+  def test_transform_with_cosine_distance(self):
+    points = np.array([[2.5, 3.5], [2, 8], [3, 1], [3, 18],
+                       [-2.5, -3.5], [-2, -8], [-3, -1], [-3, -18]])
+
+    true_centers = [normalize(np.mean(normalize(points)[4:, :], axis=0,
+                                      keepdims=True))[0],
+                    normalize(np.mean(normalize(points)[0:4, :], axis=0,
+                                      keepdims=True))[0]]
+
+    kmeans = KMeans(2,
+                    initial_clusters=kmeans_ops.RANDOM_INIT,
+                    distance_metric=kmeans_ops.COSINE_DISTANCE,
+                    use_mini_batch=self.use_mini_batch,
+                    batch_size=8,
+                    continue_training=True,
+                    config=run_config.RunConfig(tf_random_seed=3))
+    kmeans.fit(x=points, steps=30)
+
+    centers = normalize(kmeans.clusters())
+    self.assertAllClose(np.sort(centers, axis=0),
+                        np.sort(true_centers, axis=0),
+                        atol=1e-2)
+
+    true_transform = 1 - cosine_similarity(points, centers)
+    transform = kmeans.transform(points)
+    self.assertAllClose(transform, true_transform, atol=1e-3)
+
+  def test_predict_with_cosine_distance(self):
+    points = np.array([[2.5, 3.5], [2, 8], [3, 1], [3, 18],
+                       [-2.5, -3.5], [-2, -8], [-3, -1], [-3, -18]]).astype(
+                           np.float32)
+    true_centers = np.array(
+        [normalize(np.mean(normalize(points)[0:4, :],
+                           axis=0,
+                           keepdims=True))[0],
+         normalize(np.mean(normalize(points)[4:, :],
+                           axis=0,
+                           keepdims=True))[0]])
+    true_assignments = [0] * 4 + [1] * 4
+    true_score = len(points) - np.tensordot(normalize(points),
+                                            true_centers[true_assignments])
+
+    kmeans = KMeans(2,
+                    initial_clusters=kmeans_ops.RANDOM_INIT,
+                    distance_metric=kmeans_ops.COSINE_DISTANCE,
+                    use_mini_batch=self.use_mini_batch,
+                    batch_size=8,
+                    continue_training=True,
+                    config=run_config.RunConfig(tf_random_seed=3))
+    kmeans.fit(x=points, steps=30)
+
+    centers = normalize(kmeans.clusters())
+    self.assertAllClose(np.sort(centers, axis=0),
+                        np.sort(true_centers, axis=0), rtol=1e-3)
+
+    assignments = kmeans.predict(points)
+    self.assertAllClose(centers[assignments],
+                        true_centers[true_assignments], rtol=1e-3)
+
+    score = kmeans.score(points)
+    self.assertAllClose(score, true_score)
+
+  def test_predict_with_cosine_distance_and_kmeans_plus_plus(self):
+    # Most points are concentrated near one center. KMeans++ is likely to find
+    # the less populated centers.
+    points = np.array([[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3],
+                       [-3.1, -3.2], [-2.8, -3.], [-2.9, -3.1], [-3., -3.1],
+                       [-3., -3.1], [-3.2, -3.], [-3., -3.]]).astype(np.float32)
+    true_centers = np.array(
+        [normalize(np.mean(normalize(points)[0:2, :], axis=0,
+                           keepdims=True))[0],
+         normalize(np.mean(normalize(points)[2:4, :], axis=0,
+                           keepdims=True))[0],
+         normalize(np.mean(normalize(points)[4:, :], axis=0,
+                           keepdims=True))[0]])
+    true_assignments = [0] * 2 + [1] * 2 + [2] * 8
+    true_score = len(points) - np.tensordot(normalize(points),
+                                            true_centers[true_assignments])
+
+    kmeans = KMeans(3,
+                    initial_clusters=kmeans_ops.KMEANS_PLUS_PLUS_INIT,
+                    distance_metric=kmeans_ops.COSINE_DISTANCE,
+                    use_mini_batch=self.use_mini_batch,
+                    batch_size=12,
+                    continue_training=True,
+                    config=run_config.RunConfig(tf_random_seed=3))
+    kmeans.fit(x=points, steps=30)
+
+    centers = normalize(kmeans.clusters())
+    self.assertAllClose(sorted(centers.tolist()),
+                        sorted(true_centers.tolist()),
+                        rtol=1e-3)
+
+    assignments = kmeans.predict(points)
+    self.assertAllClose(centers[assignments],
+                        true_centers[true_assignments], rtol=1e-3)
+
+    score = kmeans.score(points)
+    self.assertAllClose(score, true_score)
+
+  def test_fit_raise_if_num_clusters_larger_than_num_points_random_init(self):
+    points = np.array([[2.0, 3.0], [1.6, 8.2]])
+
+    with self.assertRaisesOpError('less'):
+      kmeans = KMeans(num_clusters=3, initial_clusters=kmeans_ops.RANDOM_INIT)
+      kmeans.fit(x=points)
+
+  def test_fit_raise_if_num_clusters_larger_than_num_points_kmeans_plus_plus(
+      self):
+    points = np.array([[2.0, 3.0], [1.6, 8.2]])
+
+    with self.assertRaisesOpError(AssertionError):
+      kmeans = KMeans(num_clusters=3,
+                      initial_clusters=kmeans_ops.KMEANS_PLUS_PLUS_INIT)
+      kmeans.fit(x=points)
+
+
+class MiniBatchKMeansTest(KMeansTest):
+
+  @property
+  def batch_size(self):
+    return 50
+
+  @property
+  def use_mini_batch(self):
+    return True
+
+
+if __name__ == '__main__':
+  tf.test.main()
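Taken together, the estimator added above supports the following minimal end-to-end workflow, based only on the API introduced in this change (the data here is synthetic and the hyperparameters are illustrative, not recommendations):

    import numpy as np
    from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_ops
    from tensorflow.contrib.factorization.python.ops.kmeans import KMeansClustering

    # Two well-separated synthetic blobs.
    points = np.vstack([np.random.randn(100, 2) + [10, 10],
                        np.random.randn(100, 2) - [10, 10]]).astype(np.float32)

    clusterer = KMeansClustering(num_clusters=2,
                                 initial_clusters=kmeans_ops.KMEANS_PLUS_PLUS_INIT,
                                 use_mini_batch=False,
                                 batch_size=200,
                                 steps=20)
    clusterer.fit(x=points)
    print(clusterer.clusters())                # [2, 2] array of cluster centers
    print(clusterer.predict(points)[:5])       # cluster index per point
    print(clusterer.score(points))             # total distance to nearest centers
    print(clusterer.transform(points).shape)   # (200, 2) squared distances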