author     2017-02-15 15:03:07 -0800
committer  2017-02-15 15:24:24 -0800
commit     df5d3cd42335e31bccb6c796169d000d73c747d3 (patch)
tree       e2d694e040008cda4bf2dabe80e38b5a6b668566 /tensorflow/contrib/factorization
parent     41128b0f1b3e1c7ce1f125caa53499f7a22f0c01 (diff)
Add ops for efficient WALS loss computation in factorization_ops.
Add unit tests for loss computation.
Improve performance of unit tests.
Update trainer to add an option for loss computation (turned off by default).
Change: 147649717
Diffstat (limited to 'tensorflow/contrib/factorization')
4 files changed, 380 insertions, 108 deletions
diff --git a/tensorflow/contrib/factorization/g3doc/wals.md b/tensorflow/contrib/factorization/g3doc/wals.md
index 461aa8a744..a428b393ba 100644
--- a/tensorflow/contrib/factorization/g3doc/wals.md
+++ b/tensorflow/contrib/factorization/g3doc/wals.md
@@ -1,21 +1,67 @@
 # WALS Factorization
 
-WALS (Weighed Alternating Least Squares) is an algorithm for factorizing a
-sparse matrix $$A$$ into low rank factors, $$U$$ and $$V$$, such that the
-product of these factors is a "good" approximation of the full matrix.
+$$
+% commands
+\newcommand\bracket[2]{\left\langle #1, #2 \right\rangle}
+\newcommand\trace{\text{trace}}
+\newcommand\Rbb{\mathbb{R}}
+$$
+
+### Problem formulation
+WALS (Weighted Alternating Least Squares) is an algorithm for factorizing a
+sparse matrix $$A \in \Rbb^{n \times m}$$ into low rank factors, $$U \in \Rbb^{n
+\times k}$$ and $$V \in \Rbb^{m \times k}$$, such that the product $$UV^T$$ is a
+"good" approximation of the full matrix.
 
 ![wals](wals.png)
 
 Typically, it involves minimizing the following loss function:
-$$ min_{U,V} (||\sqrt{W} \odot (A- UV^T)||_F^2 + \lambda (||U||_F^2 + ||V||_F^2)) $$,
-where $$\lambda$$ is a regularization parameter, and $$\odot$$ represents a
-component-wise product. Assuming $$W$$ is of the form
-$$W_{i, j} = w_0 + 1_{A_{ij} \neq 0}R_i C_j$$,
-where $$w_0$$ is the weight of unobserved entries, and $$R$$ and $$C$$ are
-the row and column weights respectively, lends this equation to an efficient
-implementation.
-
-The algorithm proceeds in phases, or "sweeps", where each sweep involves
+
+$$
+\min_{U,V}
+\|\sqrt{W} \odot (A - UV^T)\|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2),
+$$
+
+where
+
+- $$\lambda$$ is a regularization parameter,
+- $$\odot$$ denotes the component-wise product,
+- $$W$$ is a weight matrix of the form $$W_{i, j} = w_0 + 1_{A_{ij} \neq 0}R_i
+  C_j$$, where $$w_0$$ is the weight of unobserved entries, and $$R \in
+  \Rbb^n$$ and $$C \in \Rbb^m$$ are the row and column weights respectively.
+  This form of the weight matrix lends the problem to an efficient
+  implementation.
+
+### Solution method
+
+The WALS algorithm proceeds in phases, or "sweeps", where each sweep involves
 fixing $$U$$ and solving for $$V$$, and then fixing $$V$$ and solving for $$U$$.
-Convergence is typically pretty fast (10-20 sweeps).
+Note that each subproblem is an unconstrained quadratic minimization problem and
+can be solved exactly. Convergence is typically fast (10-20 sweeps).
+
+### Loss computation
+
+The product $$UV^T$$ is dense and can be too large to materialize, so we use the
+following reformulation of the loss to compute it efficiently. First, we
+decompose the norm into two terms, corresponding to the sparsity pattern of
+$$A$$. Let $$S = \{(i, j) : A_{i, j} \neq 0\}$$. Then
+
+$$
+\begin{align}
+\|\sqrt W \odot (A - UV^T)\|_F^2
+&= \sum_{(i, j) \in S} (w_0 + R_i C_j) (A_{ij} - \bracket{u_i}{v_j})^2 +
+\sum_{(i, j) \not\in S} w_0 (\bracket{u_i}{v_j})^2 \\
+&= \sum_{(i, j) \in S} \left( (w_0 + R_i C_j) (A_{ij} - \bracket{u_i}{v_j})^2 -
+w_0 \bracket{u_i}{v_j}^2 \right) + w_0\|UV^T\|_F^2
+\end{align}
+$$
+
+The last term can be computed efficiently by observing that
+
+$$
+\|UV^T\|_F^2 = \trace(UV^TVU^T) = \trace(U^TUV^TV).
+$$
+
+The Gramian matrices $$U^TU$$ and $$V^TV$$ are $$k\times k$$ and are cheap to
+store. Additionally, $$\|U\|_F^2 = \trace(U^TU)$$ and similarly for $$V$$, so
+the traces of the individual Gramians can also be used to compute the
+regularization terms.
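The reformulation documented above is easy to sanity-check numerically. The following NumPy sketch (illustrative only, not part of this change; all sizes, weights, and the seed are arbitrary) verifies that the sparse-plus-Gramian decomposition and the trace identity reproduce the directly evaluated loss:

```python
# Illustrative check of the loss reformulation in wals.md, in plain NumPy.
import numpy as np

rng = np.random.RandomState(0)
n, m, k, w0, lam = 6, 5, 3, 0.1, 0.01
A = rng.rand(n, m) * (rng.rand(n, m) < 0.4)   # sparse-ish input matrix
R, C = rng.rand(n), rng.rand(m)               # row and column weights
U, V = rng.rand(n, k), rng.rand(m, k)         # current factors

# Direct evaluation with W_ij = w0 + 1{A_ij != 0} * R_i * C_j.
W = w0 + (A != 0) * np.outer(R, C)
P = U.dot(V.T)                                # dense product UV^T
direct = (W * (A - P) ** 2).sum()

# Reformulated evaluation: a sum over the sparsity pattern S plus the
# Gramian term w0 * ||U V^T||_F^2 = w0 * trace(U^T U V^T V).
S = A != 0
sparse_term = ((w0 + np.outer(R, C)) * (A - P) ** 2 - w0 * P ** 2)[S].sum()
gramian_term = w0 * np.trace(U.T.dot(U).dot(V.T.dot(V)))
assert np.allclose(direct, sparse_term + gramian_term)

# The regularizer also reduces to traces of the k x k Gramians:
# ||U||_F^2 = trace(U^T U), and similarly for V.
reg = lam * (np.trace(U.T.dot(U)) + np.trace(V.T.dot(V)))
print(direct + reg)
```

The point of the identity is that only the k x k Gramians, not the dense n x m product, are needed for the second term.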
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
index 28bcccbde6..80aee4c904 100644
--- a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
+++ b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
@@ -27,7 +27,7 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
 
 import numpy as np
 
-from tensorflow.contrib.factorization.python.ops import factorization_ops
+from tensorflow.contrib.factorization.python.ops import gen_factorization_ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.platform import test
 
@@ -59,7 +59,7 @@ class WalsSolverOpsTest(test.TestCase):
     sparse_block = SparseBlock3x3()
     with self.test_session():
       [lhs_tensor,
-       rhs_matrix] = factorization_ops.wals_compute_partial_lhs_and_rhs(
+       rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
            self._column_factors, self._column_weights, self._unobserved_weights,
            self._row_weights, sparse_block.indices, sparse_block.values,
            sparse_block.dense_shape[0], False)
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
index 25cc66ca81..167b442dbc 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
@@ -23,10 +23,8 @@ import numbers
 
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
-# pylint: disable=wildcard-import,undefined-variable
-# pylint: enable=wildcard-import
-from tensorflow.contrib.factorization.python.ops.gen_factorization_ops import *
+from tensorflow.contrib.factorization.python.ops import gen_factorization_ops
 from tensorflow.contrib.util import loader
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -53,10 +51,13 @@ class WALSModel(object):
   r"""A model for Weighted Alternating Least Squares matrix factorization.
 
   It minimizes the following loss function over U, V:
-    \\( ||W \odot (A - U V^T) ||_F^2 + \lambda (||U||_F^2 + ||V||_F^2) )\\
+    \\(
+     \|\sqrt W \odot (A - U V^T) \|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2)
+    )\\
   where,
   A: input matrix,
-  W: weight matrix,
+  W: weight matrix. Note that the (element-wise) square root of the weights
+    is used in the objective function.
   U, V: row_factors and column_factors matrices,
   \\(\lambda)\\: regularization.
   Also we assume that W is of the following special form:
@@ -75,6 +76,18 @@ class WALSModel(object):
   creating the worker caches and instead the relevant weight and factor
   values are looked up from parameter servers at each step.
 
+  Loss computation: The loss can be computed efficiently by decomposing it into
+  a sparse term and a Gramian term, see wals.md.
+  The loss is returned by update_{col, row}_factors(sp_input), and is
+  normalized as follows:
+    _, _, minibatch_loss = update_row_factors(sp_input)
+    if sp_input contains the rows {A_i, i \in I}, and the input matrix A has n
+    total rows, then minibatch_loss is
+    \\(
+    (\|\sqrt W \odot (A_I - U_I V^T)\|_F^2 + \lambda \|U_I\|_F^2) * n / |I| +
+    \lambda \|V\|_F^2
+    )\\
+
   A typical usage example (pseudocode):
 
     with tf.Graph().as_default():
@@ -87,15 +100,15 @@ class WALSModel(object):
       model_init_op = model.initialize_op
 
       # To be run once per worker after session is available, prior to
-      # the gramian_prep_ops for row(column) can be run.
+      # the prep_gramian_op for row(column) can be run.
       worker_init_op = model.worker_init
 
       # To be run once per iteration sweep before the row(column) update
       # initialize ops can be run. Note that in the distributed training
       # situations, this should only be run by the chief trainer. All other
       # trainers need to block until this is done.
-      row_update_gramian_prep_op = model.row_update_prep_gramian_op
-      col_update_gramian_prep_op = model.col_update_prep_gramian_op
+      row_update_prep_gramian_op = model.row_update_prep_gramian_op
+      col_update_prep_gramian_op = model.col_update_prep_gramian_op
 
       # To be run once per worker per iteration sweep. Must be run before
       # any actual update ops can be run.
@@ -105,11 +118,11 @@ class WALSModel(object):
       # Ops to update row(column). This can either take the entire sparse
       # tensor or slices of sparse tensor. For distributed trainer, each
      # trainer handles just part of the matrix.
-      row_update_op = model.update_row_factors(
-          sp_input=matrix_slices_from_queue_for_worker_shard)[1]
-      col_update_op = model.update_col_factors(
+      _, row_update_op, row_loss = model.update_row_factors(
+          sp_input=matrix_slices_from_queue_for_worker_shard)
+      _, col_update_op, col_loss = model.update_col_factors(
           sp_input=transposed_matrix_slices_from_queue_for_worker_shard,
-          transpose_input=True)[1]
+          transpose_input=True)
 
       ...
 
@@ -134,7 +147,7 @@ class WALSModel(object):
 
       # Row update sweep.
       if is_chief:
-        row_update_gramian_prep_op.run(session=sess)
+        row_update_prep_gramian_op.run(session=sess)
       else:
         wait_for_chief
 
@@ -152,7 +165,7 @@ class WALSModel(object):
 
       # Column update sweep.
      if is_chief:
-        col_update_gramian_prep_op.run(session=sess)
+        col_update_prep_gramian_op.run(session=sess)
       else:
         wait_for_chief
 
@@ -218,10 +231,10 @@ class WALSModel(object):
     self._num_col_shards = num_col_shards
     self._n_components = n_components
     self._unobserved_weight = unobserved_weight
-    self._regularization = (array_ops.diag(
-        constant_op.constant(
-            regularization, shape=[self._n_components], dtype=dtypes.float32))
-                            if regularization is not None else None)
+    self._regularization = regularization
+    self._regularization_matrix = (
+        regularization * linalg_ops.eye(self._n_components)
+        if regularization is not None else None)
     assert (row_weights is None) == (col_weights is None)
     self._row_weights = WALSModel._create_weights(row_weights,
                                                   self._input_rows,
                                                   self._num_row_shards,
@@ -584,13 +597,14 @@ class WALSModel(object):
     return func
 
   @classmethod
-  def scatter_update(cls, factor, indices, values, sharding_func):
+  def scatter_update(cls, factor, indices, values, sharding_func, name=None):
     """Helper function for doing sharded scatter update."""
     assert isinstance(factor, list)
     if len(factor) == 1:
       with ops.colocate_with(factor[0]):
         # TODO(agarwal): assign instead of scatter update for full batch update.
-        return state_ops.scatter_update(factor[0], indices, values).op
+        return state_ops.scatter_update(factor[0], indices, values,
+                                        name=name).op
     else:
       num_shards = len(factor)
       assignments, new_ids = sharding_func(indices)
@@ -602,13 +616,12 @@ class WALSModel(object):
                                                 num_shards)
       updates = []
       for i in xrange(num_shards):
-        updates.append(
-            state_ops.scatter_update(factor[i], sharded_ids[i], sharded_values[
-                i]))
-      return control_flow_ops.group(*updates)
+        updates.append(state_ops.scatter_update(factor[i], sharded_ids[i],
+                                                sharded_values[i]))
+      return control_flow_ops.group(*updates, name=name)
 
   def update_row_factors(self, sp_input=None, transpose_input=False):
-    """Updates the row factors.
+    r"""Updates the row factors.
 
     Args:
       sp_input: A SparseTensor representing a subset of rows of the full input
        in any order. Please note that this SparseTensor must retain the
        indexing as the original input.
       transpose_input: If true, the input will be logically transposed and the
         rows corresponding to the transposed input are updated.
 
     Returns:
-      A tuple consisting of the following two elements:
+      A tuple consisting of the following elements:
       new_values: New values for the row factors.
       update_op: An op that assigns the newly computed values to the row
         factors.
+      loss: A tensor (scalar) that contains the normalized minibatch loss,
+        corresponding to sp_input.
+        if sp_input contains the rows {A_{i, :}, i \in I}, and the input matrix
+        A has n total rows, then loss is:
+        (\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 + \lambda \|U_I\|_F^2) *
+        n / |I| + \lambda \|V\|_F^2.
     """
-    return self._process_input_helper(
-        True, sp_input=sp_input, transpose_input=transpose_input)
+    return self._process_input_helper(True, sp_input=sp_input,
+                                      transpose_input=transpose_input)
 
   def update_col_factors(self, sp_input=None, transpose_input=False):
-    """Updates the column factors.
+    r"""Updates the column factors.
 
     Args:
       sp_input: A SparseTensor representing a subset of columns of the full
         input. Please refer to comments for update_row_factors for
         restrictions.
       transpose_input: If true, the input will be logically transposed and the
         columns corresponding to the transposed input are updated.
 
     Returns:
       A tuple consisting of the following two elements:
       new_values: New values for the column factors.
       update_op: An op that assigns the newly computed values to the column
         factors.
+      loss: A tensor (scalar) that contains the normalized minibatch loss,
+        corresponding to sp_input.
+        If sp_input contains the columns {A_{:, j}, j \in J}, and the input
+        matrix A has m total columns, then loss is:
+        (\|\sqrt W_J \odot (A_J - U V_J^T)\|_F^2 + \lambda \|V_J\|_F^2) *
+        m / |J| + \lambda \|U\|_F^2.
     """
-    return self._process_input_helper(
-        False, sp_input=sp_input, transpose_input=transpose_input)
+    return self._process_input_helper(False, sp_input=sp_input,
+                                      transpose_input=transpose_input)
 
-  def project_row_factors(self,
-                          sp_input=None,
-                          transpose_input=False,
+  def project_row_factors(self, sp_input=None, transpose_input=False,
                           projection_weights=None):
     """Projects the row factors.
 
@@ -672,11 +695,9 @@ class WALSModel(object):
     """
     if projection_weights is None:
       projection_weights = 1
-    return self._process_input_helper(
-        True,
-        sp_input=sp_input,
-        transpose_input=transpose_input,
-        row_weights=projection_weights)[0]
+    return self._process_input_helper(True, sp_input=sp_input,
+                                      transpose_input=transpose_input,
+                                      row_weights=projection_weights)[0]
 
   def project_col_factors(self,
                           sp_input=None,
@@ -705,16 +726,12 @@ class WALSModel(object):
     """
     if projection_weights is None:
       projection_weights = 1
-    return self._process_input_helper(
-        False,
-        sp_input=sp_input,
-        transpose_input=transpose_input,
-        row_weights=projection_weights)[0]
-
-  def _process_input_helper(self,
-                            update_row_factors,
-                            sp_input=None,
-                            transpose_input=False,
+    return self._process_input_helper(False, sp_input=sp_input,
+                                      transpose_input=transpose_input,
+                                      row_weights=projection_weights)[0]
+
+  def _process_input_helper(self, update_row_factors,
+                            sp_input=None, transpose_input=False,
                             row_weights=None):
     """Creates the graph for processing a sparse slice of input.
 
@@ -734,10 +751,12 @@ class WALSModel(object):
         of columns to be updated/projected.
 
     Returns:
-      A tuple consisting of the following two elements:
+      A tuple consisting of the following three elements:
       new_values: New values for the row/column factors.
       update_op: An op that assigns the newly computed values to the
         row/column factors.
+      loss: A tensor (scalar) that contains the normalized minibatch loss,
+        corresponding to sp_input.
     """
     assert isinstance(sp_input, sparse_tensor.SparseTensor)
 
@@ -746,6 +765,7 @@ class WALSModel(object):
       right_factors = self._col_factors_cache
       row_wt = self._row_wt_cache
       col_wt = self._col_wt_cache
+      total_rows = self._input_rows
       sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                    self._num_row_shards)
       gramian = self._col_gramian_cache
@@ -754,6 +774,7 @@ class WALSModel(object):
       right_factors = self._row_factors_cache
       row_wt = self._col_wt_cache
      col_wt = self._row_wt_cache
+      total_rows = self._input_cols
       sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                    self._num_col_shards)
       gramian = self._row_gramian_cache
@@ -799,8 +820,8 @@ class WALSModel(object):
 
     # Compute lhs and rhs of the normal equations
     total_lhs = (self._unobserved_weight * gramian)
-    if self._regularization is not None:
-      total_lhs += self._regularization
+    if self._regularization_matrix is not None:
+      total_lhs += self._regularization_matrix
     if self._row_weights is None:
       # Special case of ALS. Use a much simpler update rule.
       total_rhs = (self._unobserved_weight *
@@ -819,30 +840,68 @@ class WALSModel(object):
         row_weights_slice = embedding_ops.embedding_lookup(
            row_wt, update_indices, partition_strategy="div")
       else:
+        num_indices = array_ops.shape(update_indices)[0]
         with ops.control_dependencies(
             [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
           row_weights_slice = control_flow_ops.cond(
               math_ops.equal(array_ops.rank(row_weights), 0),
-              lambda: (array_ops.ones([array_ops.shape(update_indices)[0]]) * row_weights),
+              lambda: (array_ops.ones([num_indices]) * row_weights),
               lambda: math_ops.cast(row_weights, dtypes.float32))
 
       col_weights = embedding_ops.embedding_lookup(
           col_wt, gather_indices, partition_strategy="div")
-      partial_lhs, total_rhs = wals_compute_partial_lhs_and_rhs(
-          right,
-          col_weights,
-          self._unobserved_weight,
-          row_weights_slice,
-          new_sp_input.indices,
-          new_sp_input.values,
-          num_rows,
-          transpose_input,
-          name="wals_compute_partial_lhs_rhs")
+      partial_lhs, total_rhs = (
+          gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
+              right,
+              col_weights,
+              self._unobserved_weight,
+              row_weights_slice,
+              new_sp_input.indices,
+              new_sp_input.values,
+              num_rows,
+              transpose_input,
+              name="wals_compute_partial_lhs_rhs"))
       total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
       total_rhs = array_ops.expand_dims(total_rhs, -1)
       new_left_values = array_ops.squeeze(
           linalg_ops.matrix_solve(total_lhs, total_rhs), [2])
 
-    return (new_left_values, self.scatter_update(left, update_indices,
-                                                 new_left_values,
-                                                 sharding_func))
+    update_op_name = "row_update" if update_row_factors else "col_update"
+    update_op = self.scatter_update(left, update_indices, new_left_values,
+                                    sharding_func, name=update_op_name)
+
+    # Create the loss subgraph
+    loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
+                     if transpose_input else new_sp_input)
+    # sp_approx is the low rank estimate of the input matrix, formed by
+    # computing the product <u_i, v_j> for (i, j) in loss_sp_input.indices.
+    sp_approx_vals = gen_factorization_ops.masked_matmul(
+        new_left_values, right, loss_sp_input.indices, transpose_a=False,
+        transpose_b=True)
+    sp_approx = sparse_tensor.SparseTensor(
+        loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
+    sp_approx_sq = math_ops.square(sp_approx)
+    sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
+    sp_residual_sq = math_ops.square(sp_residual)
+    row_wt_mat = (constant_op.constant(0.) if self._row_weights is None else
+                  array_ops.expand_dims(row_weights_slice, 1))
+    col_wt_mat = (constant_op.constant(0.) if self._col_weights is None else
+                  array_ops.expand_dims(col_weights, 0))
+    # We return the normalized loss
+    partial_row_gramian = math_ops.matmul(
+        new_left_values, new_left_values, transpose_a=True)
+    normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)
+    loss = (
+        self._unobserved_weight * (
+            sparse_ops.sparse_reduce_sum(sp_residual_sq) -
+            sparse_ops.sparse_reduce_sum(sp_approx_sq) +
+            math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))
+        ) +
+        sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
+    ) * normalization_factor
+    if self._regularization is not None:
+      loss += self._regularization * (
+          math_ops.trace(partial_row_gramian) * normalization_factor +
+          math_ops.trace(gramian)
+      )
+    return (new_left_values, update_op, loss)
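The arithmetic in the loss subgraph above can be mirrored outside the graph. This NumPy sketch (illustrative, not part of the diff; the minibatch row set I and all sizes are arbitrary) checks that the sparse-plus-Gramian terms, scaled by n / |I| and combined with the regularization terms, agree with the docstring formula evaluated directly:

```python
# NumPy mirror of the loss subgraph for a row update on a subset I of rows.
import numpy as np

rng = np.random.RandomState(1)
n, m, k, w0, lam = 8, 5, 3, 0.1, 0.01
A = rng.rand(n, m) * (rng.rand(n, m) < 0.5)
R, C = rng.rand(n), rng.rand(m)
U, V = rng.rand(n, k), rng.rand(m, k)

I = np.array([0, 2, 3])           # rows present in this minibatch
A_I, U_I, R_I = A[I], U[I], R[I]
S = A_I != 0                      # sparsity pattern of the minibatch
P = U_I.dot(V.T)                  # counterpart of the masked_matmul values

# Counterparts of sp_residual_sq, sp_approx_sq, and the Gramian trace.
residual_sq = (A_I - P) ** 2
unobserved = w0 * (residual_sq[S].sum() - (P ** 2)[S].sum() +
                   np.trace(U_I.T.dot(U_I).dot(V.T.dot(V))))
observed = (np.outer(R_I, C) * residual_sq)[S].sum()

# Normalize by n / |I| and add the regularization terms, as above.
scale = float(n) / len(I)
loss = (unobserved + observed) * scale
loss += lam * (np.trace(U_I.T.dot(U_I)) * scale + np.trace(V.T.dot(V)))

# Cross-check against the docstring formula evaluated directly:
# (||sqrt(W_I) . (A_I - U_I V^T)||_F^2 + lam ||U_I||_F^2) * n/|I|
#   + lam ||V||_F^2.
W_I = w0 + (A_I != 0) * np.outer(R_I, C)
direct = (((W_I * residual_sq).sum() + lam * (U_I ** 2).sum()) * scale
          + lam * (V ** 2).sum())
assert np.allclose(loss, direct)
```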
+ """ + wr = (array_ops.expand_dims(row_weights, 1) if row_weights is not None + else constant_op.constant(1.)) + wc = (array_ops.expand_dims(col_weights, 0) if col_weights is not None + else constant_op.constant(1.)) + reg = (regularization if regularization is not None + else constant_op.constant(0.)) + + row_indices, col_indices = array_ops.split(input_mat.indices, + axis=1, + num_or_size_splits=2) + gathered_row_factors = array_ops.gather(row_factors, row_indices) + gathered_col_factors = array_ops.gather(col_factors, col_indices) + sp_approx_vals = array_ops.squeeze(math_ops.matmul( + gathered_row_factors, gathered_col_factors, adjoint_b=True)) + sp_approx = sparse_tensor.SparseTensor( + indices=input_mat.indices, + values=sp_approx_vals, + dense_shape=input_mat.dense_shape) + + sp_approx_sq = math_ops.square(sp_approx) + row_norm = math_ops.reduce_sum(math_ops.square(row_factors)) + col_norm = math_ops.reduce_sum(math_ops.square(col_factors)) + row_col_norm = math_ops.reduce_sum(math_ops.square(math_ops.matmul( + row_factors, col_factors, transpose_b=True))) + + resid = sparse_ops.sparse_add(input_mat, sp_approx * (-1)) + resid_sq = math_ops.square(resid) + loss = w0 * ( + sparse_ops.sparse_reduce_sum(resid_sq) - + sparse_ops.sparse_reduce_sum(sp_approx_sq) + ) + loss += (sparse_ops.sparse_reduce_sum(wr * (resid_sq * wc)) + + w0 * row_col_norm + reg * (row_norm + col_norm)) + return loss.eval() + + +def calculate_loss_from_wals_model(wals_model, sp_inputs): + current_rows = embedding_ops.embedding_lookup( + wals_model.row_factors, math_ops.range(wals_model._input_rows), + partition_strategy="div") + current_cols = embedding_ops.embedding_lookup( + wals_model.col_factors, math_ops.range(wals_model._input_cols), + partition_strategy="div") + row_wts = embedding_ops.embedding_lookup( + wals_model._row_weights, math_ops.range(wals_model._input_rows), + partition_strategy="div") + col_wts = embedding_ops.embedding_lookup( + wals_model._col_weights, math_ops.range(wals_model._input_cols), + partition_strategy="div") + return calculate_loss( + sp_inputs, current_rows, current_cols, wals_model._regularization, + wals_model._unobserved_weight, row_wts, col_wts) + + class WalsModelTest(test.TestCase): def setUp(self): @@ -103,7 +194,6 @@ class WalsModelTest(test.TestCase): self.row_wts = [[0.1, 0.2, 0.3], [0.4, 0.5]] self.col_wts = [[0.1, 0.2, 0.3], [0.4, 0.5], [0.6, 0.7]] - self._wals_inputs = sparse_input() # Values of factor shards after running one iteration of row and column # updates. @@ -120,13 +210,19 @@ class WalsModelTest(test.TestCase): self._col_factors_2 = [[3.3459, -1.3341, -3.3008], [0.57366, 1.83729, 1.26798]] - def _run_test_process_input(self, use_factors_weights_cache): - with self.test_session(): + def _run_test_process_input(self, + use_factors_weights_cache, + compute_loss=False): + with ops.Graph().as_default(), self.test_session() as sess: + self._wals_inputs = sparse_input() sp_feeder = array_ops.sparse_placeholder(dtypes.float32) + num_rows = 5 + num_cols = 7 + factor_dim = 3 wals_model = factorization_ops.WALSModel( - 5, - 7, - 3, + num_rows, + num_cols, + factor_dim, num_row_shards=2, num_col_shards=3, regularization=0.01, @@ -152,8 +248,8 @@ class WalsModelTest(test.TestCase): # Here we feed in scattered rows of the input. 
wals_model.row_update_prep_gramian_op.run() wals_model.initialize_row_update_op.run() - process_input_op = wals_model.update_row_factors( - sp_input=sp_feeder, transpose_input=False)[1] + _, process_input_op, factor_loss = wals_model.update_row_factors( + sp_input=sp_feeder, transpose_input=False) for inp in input_scattered_rows: feed_dict = {sp_feeder: inp} process_input_op.run(feed_dict=feed_dict) @@ -189,6 +285,19 @@ class WalsModelTest(test.TestCase): [[0.569082, 0.715088, 0.31777], [1.915879, 1.992677, 1.109057]], atol=1e-3) + if compute_loss: + # Test loss computation after the row update + loss = sum( + sess.run(factor_loss * count_rows(inp) / num_rows, + feed_dict={sp_feeder: inp}) + for inp in input_scattered_rows) + true_loss = calculate_loss_from_wals_model( + wals_model, self._wals_inputs) + self.assertNear( + loss, true_loss, err=.001, + msg="""After row update, computed loss = {}, does not match + the true loss = {}.""".format(loss, true_loss)) + # Split input into multiple sparse tensors with scattered columns. Note # that here the elements in the sparse tensors are not ordered and also # do not need to consist of consecutive columns. However, each column @@ -201,13 +310,14 @@ class WalsModelTest(test.TestCase): INPUT_MATRIX, col_slices=[3, 6], shuffle=True).eval() input_scattered_cols = [sp_c0, sp_c1, sp_c2, sp_c3] + input_scattered_cols_non_duplicate = [sp_c0, sp_c1, sp_c2] # Test updating column factors. # Here we feed in scattered columns of the input. wals_model.col_update_prep_gramian_op.run() wals_model.initialize_col_update_op.run() - process_input_op = wals_model.update_col_factors( - sp_input=sp_feeder, transpose_input=False)[1] + _, process_input_op, factor_loss = wals_model.update_col_factors( + sp_input=sp_feeder, transpose_input=False) for inp in input_scattered_cols: feed_dict = {sp_feeder: inp} process_input_op.run(feed_dict=feed_dict) @@ -248,13 +358,31 @@ class WalsModelTest(test.TestCase): [0.346433, 1.360644, 1.677121]], atol=1e-3) - def _run_test_process_input_transposed(self, use_factors_weights_cache): - with self.test_session(): + if compute_loss: + # Test loss computation after the column update. + loss = sum( + sess.run(factor_loss * count_cols(inp) / num_cols, + feed_dict={sp_feeder: inp}) + for inp in input_scattered_cols_non_duplicate) + true_loss = calculate_loss_from_wals_model( + wals_model, self._wals_inputs) + self.assertNear( + loss, true_loss, err=.001, + msg="""After col update, computed loss = {}, does not match the true + loss = {}.""".format(loss, true_loss)) + + def _run_test_process_input_transposed(self, use_factors_weights_cache, + compute_loss=False): + with ops.Graph().as_default(), self.test_session() as sess: + self._wals_inputs = sparse_input() sp_feeder = array_ops.sparse_placeholder(dtypes.float32) + num_rows = 5 + num_cols = 7 + factor_dim = 3 wals_model = factorization_ops.WALSModel( - 5, - 7, - 3, + num_rows, + num_cols, + factor_dim, num_row_shards=2, num_col_shards=3, regularization=0.01, @@ -278,7 +406,7 @@ class WalsModelTest(test.TestCase): sp_r2_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [2], transpose=True).eval() sp_r3_t = sp_r1_t input_scattered_rows = [sp_r0_t, sp_r1_t, sp_r2_t, sp_r3_t] - + input_scattered_rows_non_duplicate = [sp_r0_t, sp_r1_t, sp_r2_t] # Test updating row factors. # Here we feed in scattered rows of the input. # Note that the needed suffix of placeholder are in the order of test @@ -286,8 +414,8 @@ class WalsModelTest(test.TestCase): # they appear. 
wals_model.row_update_prep_gramian_op.run() wals_model.initialize_row_update_op.run() - process_input_op = wals_model.update_row_factors( - sp_input=sp_feeder, transpose_input=True)[1] + _, process_input_op, factor_loss = wals_model.update_row_factors( + sp_input=sp_feeder, transpose_input=True) for inp in input_scattered_rows: feed_dict = {sp_feeder: inp} process_input_op.run(feed_dict=feed_dict) @@ -323,6 +451,19 @@ class WalsModelTest(test.TestCase): [[1.915879, 1.992677, 1.109057], [0.569082, 0.715088, 0.31777]], atol=1e-3) + if compute_loss: + # Test loss computation after the row update + loss = sum( + sess.run(factor_loss * count_cols(inp) / num_rows, + feed_dict={sp_feeder: inp}) + for inp in input_scattered_rows_non_duplicate) + true_loss = calculate_loss_from_wals_model( + wals_model, self._wals_inputs) + self.assertNear( + loss, true_loss, err=.001, + msg="""After row update, computed loss = {}, does not match the true + loss = {}.""".format(loss, true_loss)) + # Split input into multiple SparseTensors with scattered columns. # Here the inputs are transposed. But the same constraints as described in # the previous non-transposed test case apply to these inputs (before they @@ -338,13 +479,14 @@ class WalsModelTest(test.TestCase): sp_c4_t = sp_c2_t input_scattered_cols = [sp_c0_t, sp_c1_t, sp_c2_t, sp_c3_t, sp_c4_t] + input_scattered_cols_non_duplicate = [sp_c0_t, sp_c1_t, sp_c2_t, sp_c3_t] # Test updating column factors. # Here we feed in scattered columns of the input. wals_model.col_update_prep_gramian_op.run() wals_model.initialize_col_update_op.run() - process_input_op = wals_model.update_col_factors( - sp_input=sp_feeder, transpose_input=True)[1] + _, process_input_op, factor_loss = wals_model.update_col_factors( + sp_input=sp_feeder, transpose_input=True) for inp in input_scattered_cols: feed_dict = {sp_feeder: inp} process_input_op.run(feed_dict=feed_dict) @@ -377,15 +519,28 @@ class WalsModelTest(test.TestCase): [[3.585139, -0.487476, -3.852232], [0.557937, 1.813907, 1.331171]], atol=1e-3) - - # Note that when row_weights and col_weights are 0, WALS gives dentical + if compute_loss: + # Test loss computation after the col update + loss = sum( + sess.run(factor_loss * count_rows(inp) / num_cols, + feed_dict={sp_feeder: inp}) + for inp in input_scattered_cols_non_duplicate) + true_loss = calculate_loss_from_wals_model( + wals_model, self._wals_inputs) + self.assertNear( + loss, true_loss, err=.001, + msg="""After col update, computed loss = {}, does not match the true + loss = {}.""".format(loss, true_loss)) + + # Note that when row_weights and col_weights are 0, WALS gives identical # results as ALS (Alternating Least Squares). However our implementation does # not handle the case of zero weights differently. Instead, when row_weights # and col_weights are set to None, we interpret that as the ALS case, and # trigger the more efficient ALS updates. # Here we test that those two give identical results. 
def _run_test_als(self, use_factors_weights_cache): - with self.test_session(): + with ops.Graph().as_default(), self.test_session(): + self._wals_inputs = sparse_input() col_init = np.random.rand(7, 3) als_model = factorization_ops.WALSModel( 5, @@ -463,7 +618,8 @@ class WalsModelTest(test.TestCase): atol=1e-2) def _run_test_als_transposed(self, use_factors_weights_cache): - with self.test_session(): + with ops.Graph().as_default(), self.test_session(): + self._wals_inputs = sparse_input() col_init = np.random.rand(7, 3) als_model = factorization_ops.WALSModel( 5, @@ -552,7 +708,7 @@ class WalsModelTest(test.TestCase): rows = 15 cols = 11 dims = 3 - with self.test_session(): + with ops.Graph().as_default(), self.test_session(): data = np.dot(np.random.rand(rows, 3), np.random.rand(3, cols)).astype(np.float32) / 3.0 indices = [[i, j] for i in xrange(rows) for j in xrange(cols)] @@ -582,7 +738,7 @@ class WalsModelTest(test.TestCase): cols = 11 dims = 3 - with self.test_session(): + with ops.Graph().as_default(), self.test_session(): data = np.dot(np.random.rand(rows, 3), np.random.rand(3, cols)).astype(np.float32) / 3.0 indices = [[i, j] for i in xrange(rows) for j in xrange(cols)] @@ -615,7 +771,7 @@ class WalsModelTest(test.TestCase): def keep_index(x): return not (x[0] + x[1]) % 4 - with self.test_session(): + with ops.Graph().as_default(), self.test_session(): row_wts = 0.1 + np.random.rand(rows) col_wts = 0.1 + np.random.rand(cols) data = np.dot(np.random.rand(rows, 3), @@ -683,6 +839,17 @@ class WalsModelTest(test.TestCase): def test_train_matrix_completion_wals_without_cache(self): self._run_test_train_matrix_completion_wals(False) + def test_loss_transposed_with_cache(self): + self._run_test_process_input_transposed(True, compute_loss=True) + + def test_loss_transposed_without_cache(self): + self._run_test_process_input_transposed(False, compute_loss=True) + + def test_loss_with_cache(self): + self._run_test_process_input(True, compute_loss=True) + + def test_loss_without_cache(self): + self._run_test_process_input(False, compute_loss=True) if __name__ == "__main__": test.main() |
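To round out the picture, here is a hypothetical driver showing how a caller consumes the new three-element return value and de-normalizes the per-minibatch losses, loosely following the docstring pseudocode and the test pattern above. It is a sketch, not part of this commit: `row_minibatches` and `count_distinct_rows` are placeholders, and the constructor arguments are illustrative.

```python
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops import factorization_ops

num_rows, num_cols, k = 5, 7, 3
model = factorization_ops.WALSModel(
    num_rows, num_cols, k,
    regularization=0.01, unobserved_weight=0.1,
    row_weights=0.5, col_weights=0.5)
sp_feeder = tf.sparse_placeholder(tf.float32)

# update_row_factors now returns (new_values, update_op, loss) instead of
# (new_values, update_op); loss is the normalized minibatch loss.
_, row_update_op, row_loss = model.update_row_factors(sp_input=sp_feeder)

with tf.Session() as sess:
  model.initialize_op.run()
  model.worker_init.run()
  model.row_update_prep_gramian_op.run()   # chief only, in distributed runs
  model.initialize_row_update_op.run()
  total_loss = 0.
  for inp in row_minibatches:              # SparseTensorValues, disjoint rows
    _, minibatch = sess.run([row_update_op, row_loss],
                            feed_dict={sp_feeder: inp})
    # row_loss is scaled by num_rows / |I|; undo that scaling to sum the
    # minibatches, mirroring factor_loss * count_rows(inp) / num_rows in
    # the tests above.
    total_loss += minibatch * count_distinct_rows(inp) / num_rows
```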