path: root/tensorflow/contrib/factorization
author     A. Unique TensorFlower <gardener@tensorflow.org>   2017-02-15 15:03:07 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-02-15 15:24:24 -0800
commit     df5d3cd42335e31bccb6c796169d000d73c747d3 (patch)
tree       e2d694e040008cda4bf2dabe80e38b5a6b668566 /tensorflow/contrib/factorization
parent     41128b0f1b3e1c7ce1f125caa53499f7a22f0c01 (diff)
Add ops for efficient WALS loss computation in factorization_ops.
Add unit tests for loss computation. Improve performance of unit tests. Update trainer to add an option for loss computation (turned off by default). Change: 147649717
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r--  tensorflow/contrib/factorization/g3doc/wals.md                                |  72
-rw-r--r--  tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py  |   4
-rw-r--r--  tensorflow/contrib/factorization/python/ops/factorization_ops.py              | 191
-rw-r--r--  tensorflow/contrib/factorization/python/ops/factorization_ops_test.py         | 221
4 files changed, 380 insertions, 108 deletions
diff --git a/tensorflow/contrib/factorization/g3doc/wals.md b/tensorflow/contrib/factorization/g3doc/wals.md
index 461aa8a744..a428b393ba 100644
--- a/tensorflow/contrib/factorization/g3doc/wals.md
+++ b/tensorflow/contrib/factorization/g3doc/wals.md
@@ -1,21 +1,67 @@
# WALS Factorization
-WALS (Weighed Alternating Least Squares) is an algorithm for factorizing a
-sparse matrix $$A$$ into low rank factors, $$U$$ and $$V$$, such that the
-product of these factors is a "good" approximation of the full matrix.
+$$
+% commands
+\newcommand\bracket[2]{\left\langle #1, #2 \right\rangle}
+\newcommand\trace{\text{trace}}
+\newcommand\Rbb{\mathbb{R}}
+$$
+
+### Problem formulation
+WALS (Weighted Alternating Least Squares) is an algorithm for factorizing a
+sparse matrix $$A \in \Rbb^{n \times m}$$ into low rank factors, $$U \in \Rbb^{n
+\times k}$$ and $$V \in \Rbb^{m \times k}$$, such that the product $$UV^T$$ is a
+"good" approximation of the full matrix.
![wals](wals.png)
Typically, it involves minimizing the following loss function:
-$$ min_{U,V} (||\sqrt{W} \odot (A- UV^T)||_F^2 + \lambda (||U||_F^2 + ||V||_F^2)) $$,
-where $$\lambda$$ is a regularization parameter, and $$\odot$$ represents a
-component-wise product. Assuming $$W$$ is of the form
-$$W_{i, j} = w_0 + 1_{A_{ij} \neq 0}R_i C_j$$,
-where $$w_0$$ is the weight of unobserved entries, and $$R$$ and $$C$$ are
-the row and column weights respectively, lends this equation to an efficient
-implementation.
-
-The algorithm proceeds in phases, or "sweeps", where each sweep involves
+
+$$ \min_{U,V}
+(\|\sqrt{W} \odot (A- UV^T)\|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2)),
+$$
+
+where
+
+- $$\lambda$$ is a regularization parameter,
+- $$\odot$$ denotes the component-wise product,
+- $$W$$ is a weight matrix of the form $$W_{i, j} = w_0 + 1_{A_{ij} \neq 0}R_i
+ C_j$$, where $$w_0$$ is the weight of unobserved entries, and $$R \in
+ \Rbb^n$$ and $$C \in \Rbb^m$$ are the row and column weights respectively.
+ This form of the weight matrix lends this problem to an efficient
+ implementation.
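+
+For concreteness, here is a minimal NumPy sketch of this objective. It is an
+illustrative example only, not the TensorFlow implementation, and all names in
+it are local to the snippet:
+
+```python
+import numpy as np
+
+def wals_loss(A, U, V, R, C, w0, lam):
+  """Returns ||sqrt(W) . (A - U V^T)||_F^2 + lam * (||U||_F^2 + ||V||_F^2)."""
+  # W_ij = w0 + 1[A_ij != 0] * R_i * C_j
+  W = w0 + (A != 0) * np.outer(R, C)
+  residual = A - U.dot(V.T)
+  return (W * residual ** 2).sum() + lam * ((U ** 2).sum() + (V ** 2).sum())
+```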
+
+### Solution method
+
+The WALS algorithm proceeds in phases, or "sweeps", where each sweep involves
fixing $$U$$ and solving for $$V$$, and then fixing $$V$$ and solving for $$U$$.
-Convergence is typically pretty fast (10-20 sweeps).
+Note that each subproblem is an unconstrained quadratic minimization problem and
+can be solved exactly. Convergence is typically pretty fast (10-20 sweeps).
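+
+To make each subproblem concrete, here is an illustrative (dense, unoptimized)
+NumPy sketch of one half-sweep that solves for the row factors with the column
+factors held fixed; the column half-sweep is symmetric. This is only a sketch
+of the update rule, not the library implementation:
+
+```python
+import numpy as np
+
+def solve_row_factors(A, V, R, C, w0, lam):
+  """Solves min_U ||sqrt(W) . (A - U V^T)||_F^2 + lam * ||U||_F^2 for fixed V."""
+  n, k = A.shape[0], V.shape[1]
+  U = np.zeros((n, k))
+  for i in range(n):
+    w_i = w0 + (A[i] != 0) * R[i] * C           # weights of row i
+    lhs = (V.T * w_i).dot(V) + lam * np.eye(k)  # V^T diag(w_i) V + lam * I
+    rhs = (V.T * w_i).dot(A[i])                 # V^T diag(w_i) a_i
+    U[i] = np.linalg.solve(lhs, rhs)
+  return U
+
+# One full sweep alternates the two half-sweeps:
+#   U = solve_row_factors(A, V, R, C, w0, lam)
+#   V = solve_row_factors(A.T, U, C, R, w0, lam)
+```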
+
+### Loss computation
+
+The product $$UV^T$$ is dense and can be too large to materialize, so we use
+the following reformulation of the loss to compute it efficiently. First, we
+decompose the norm into two terms, corresponding to the sparsity pattern of
+$$A$$. Let $$S = \{(i, j) : A_{i, j} \neq 0\}$$.
+
+$$
+\begin{align}
+\|\sqrt W \odot (A - UV^T)\|_F^2
+&= \sum_{(i, j) \in S} (w_0 + R_i C_j) (A_{ij} - \bracket{u_i}{v_j})^2 +
+\sum_{(i, j) \not\in S} w_0 (\bracket{u_i}{v_j})^2 \\
+&= \sum_{(i, j) \in S} \left( (w_0 + R_i C_j) (A_{ij} - \bracket{u_i}{v_j})^2 -
+w_0 \bracket{u_i}{v_j}^2 \right) + w_0\|UV^T\|_F^2
+\end{align}
+$$
+
+The last term can be computed efficiently by observing that
+
+$$
+\|UV^T\|_F^2 = \trace(UV^TVU^T) = \trace(U^TUV^TV)
+$$
+
+Each of the Gramian matrices $$U^TU$$ and $$V^TV$$ is $$k\times k$$ and
+cheap to store. Additionally, $$\|U\|_F^2 = \trace(U^TU)$$ and similarly for
+$$V$$, so we can use the traces of the individual Gramians to compute the norms.
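+
+The following small NumPy check (illustrative only, self-contained) confirms
+numerically that the sparse-plus-Gramian decomposition above matches the dense
+weighted norm:
+
+```python
+import numpy as np
+
+rng = np.random.RandomState(0)
+n, m, k, w0 = 6, 5, 3, 0.1
+A = rng.rand(n, m) * (rng.rand(n, m) > 0.5)   # a sparse-ish test matrix
+U, V = rng.rand(n, k), rng.rand(m, k)
+R, C = rng.rand(n), rng.rand(m)
+
+W = w0 + (A != 0) * np.outer(R, C)
+dense_loss = (W * (A - U.dot(V.T)) ** 2).sum()
+
+approx = U.dot(V.T)
+observed = (A != 0)
+sparse_term = (observed *
+               ((w0 + np.outer(R, C)) * (A - approx) ** 2
+                - w0 * approx ** 2)).sum()
+gramian_term = w0 * np.trace(U.T.dot(U).dot(V.T.dot(V)))  # w0 * ||U V^T||_F^2
+assert np.allclose(dense_loss, sparse_term + gramian_term)
+```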
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
index 28bcccbde6..80aee4c904 100644
--- a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
+++ b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
@@ -27,7 +27,7 @@ if hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags"):
import numpy as np
-from tensorflow.contrib.factorization.python.ops import factorization_ops
+from tensorflow.contrib.factorization.python.ops import gen_factorization_ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
@@ -59,7 +59,7 @@ class WalsSolverOpsTest(test.TestCase):
sparse_block = SparseBlock3x3()
with self.test_session():
[lhs_tensor,
- rhs_matrix] = factorization_ops.wals_compute_partial_lhs_and_rhs(
+ rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
self._column_factors, self._column_weights, self._unobserved_weights,
self._row_weights, sparse_block.indices, sparse_block.values,
sparse_block.dense_shape[0], False)
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
index 25cc66ca81..167b442dbc 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py
@@ -23,10 +23,8 @@ import numbers
from six.moves import xrange # pylint: disable=redefined-builtin
-# pylint: disable=wildcard-import,undefined-variable
-# pylint: enable=wildcard-import
-from tensorflow.contrib.factorization.python.ops.gen_factorization_ops import *
+from tensorflow.contrib.factorization.python.ops import gen_factorization_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -53,10 +51,13 @@ class WALSModel(object):
r"""A model for Weighted Alternating Least Squares matrix factorization.
It minimizes the following loss function over U, V:
- \\( ||W \odot (A - U V^T) ||_F^2 + \lambda (||U||_F^2 + ||V||_F^2) )\\
+ \\(
+ \|\sqrt W \odot (A - U V^T) \|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2)
+ )\\
where,
A: input matrix,
- W: weight matrix,
+ W: weight matrix. Note that the (element-wise) square root of the weights
+ is used in the objective function.
U, V: row_factors and column_factors matrices,
\\(\lambda)\\: regularization.
Also we assume that W is of the following special form:
@@ -75,6 +76,18 @@ class WALSModel(object):
creating the worker caches and instead the relevant weight and factor values
are looked up from parameter servers at each step.
+ Loss computation: The loss can be computed efficiently by decomposing it into
+ a sparse term and a Gramian term, see wals.md.
+ The loss is returned by the update_{col, row}_factors(sp_input), and is
+ normalized as follows:
+ _, _, minibatch_loss = update_row_factors(sp_input)
+ if sp_input contains the rows {A_i, i \in I}, and the input matrix A has n
+ total rows, then minibatch_loss is
+ \\(
+ (\|\sqrt W \odot (A_I - U_I V^T)\|_F^2 + \lambda \|U_I\|_F^2) * n / |I| +
+ \lambda \|V\|_F^2
+ )\\
+
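+  These minibatch losses can be combined into the loss over the full matrix: if
+  the row sets I of the minibatches partition the rows of A, then summing
+  minibatch_loss * |I| / n over the minibatches recovers the full WALS loss
+  (this is how the loss is checked in factorization_ops_test.py).
+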
A typical usage example (pseudocode):
with tf.Graph().as_default():
@@ -87,15 +100,15 @@ class WALSModel(object):
model_init_op = model.initialize_op
# To be run once per worker after session is available, prior to
- # the gramian_prep_ops for row(column) can be run.
+ # the prep_gramian_op for row(column) can be run.
worker_init_op = model.worker_init
# To be run once per iteration sweep before the row(column) update
# initialize ops can be run. Note that in the distributed training
# situations, this should only be run by the chief trainer. All other
# trainers need to block until this is done.
- row_update_gramian_prep_op = model.row_update_prep_gramian_op
- col_update_gramian_prep_op = model.col_update_prep_gramian_op
+ row_update_prep_gramian_op = model.row_update_prep_gramian_op
+ col_update_prep_gramian_op = model.col_update_prep_gramian_op
# To be run once per worker per iteration sweep. Must be run before
# any actual update ops can be run.
@@ -105,11 +118,11 @@ class WALSModel(object):
# Ops to update row(column). This can either take the entire sparse tensor
# or slices of sparse tensor. For distributed trainer, each trainer
# handles just part of the matrix.
- row_update_op = model.update_row_factors(
- sp_input=matrix_slices_from_queue_for_worker_shard)[1]
- col_update_op = model.update_col_factors(
+ _, row_update_op, row_loss = model.update_row_factors(
+ sp_input=matrix_slices_from_queue_for_worker_shard)
+ _, col_update_op, col_loss = model.update_col_factors(
sp_input=transposed_matrix_slices_from_queue_for_worker_shard,
- transpose_input=True)[1]
+ transpose_input=True)
...
@@ -134,7 +147,7 @@ class WALSModel(object):
# Row update sweep.
if is_chief:
- row_update_gramian_prep_op.run(session=sess)
+ row_update_prep_gramian_op.run(session=sess)
else:
wait_for_chief
@@ -152,7 +165,7 @@ class WALSModel(object):
# Column update sweep.
if is_chief:
- col_update_gramian_prep_op.run(session=sess)
+ col_update_prep_gramian_op.run(session=sess)
else:
wait_for_chief
@@ -218,10 +231,10 @@ class WALSModel(object):
self._num_col_shards = num_col_shards
self._n_components = n_components
self._unobserved_weight = unobserved_weight
- self._regularization = (array_ops.diag(
- constant_op.constant(
- regularization, shape=[self._n_components], dtype=dtypes.float32))
- if regularization is not None else None)
+ self._regularization = regularization
+ self._regularization_matrix = (
+ regularization * linalg_ops.eye(self._n_components)
+ if regularization is not None else None)
assert (row_weights is None) == (col_weights is None)
self._row_weights = WALSModel._create_weights(row_weights, self._input_rows,
self._num_row_shards,
@@ -584,13 +597,14 @@ class WALSModel(object):
return func
@classmethod
- def scatter_update(cls, factor, indices, values, sharding_func):
+ def scatter_update(cls, factor, indices, values, sharding_func, name=None):
"""Helper function for doing sharded scatter update."""
assert isinstance(factor, list)
if len(factor) == 1:
with ops.colocate_with(factor[0]):
# TODO(agarwal): assign instead of scatter update for full batch update.
- return state_ops.scatter_update(factor[0], indices, values).op
+ return state_ops.scatter_update(factor[0], indices, values,
+ name=name).op
else:
num_shards = len(factor)
assignments, new_ids = sharding_func(indices)
@@ -602,13 +616,12 @@ class WALSModel(object):
num_shards)
updates = []
for i in xrange(num_shards):
- updates.append(
- state_ops.scatter_update(factor[i], sharded_ids[i], sharded_values[
- i]))
- return control_flow_ops.group(*updates)
+ updates.append(state_ops.scatter_update(factor[i], sharded_ids[i],
+ sharded_values[i]))
+ return control_flow_ops.group(*updates, name=name)
def update_row_factors(self, sp_input=None, transpose_input=False):
- """Updates the row factors.
+ r"""Updates the row factors.
Args:
sp_input: A SparseTensor representing a subset of rows of the full input
@@ -618,16 +631,22 @@ class WALSModel(object):
rows corresponding to the transposed input are updated.
Returns:
- A tuple consisting of the following two elements:
+ A tuple consisting of the following elements:
new_values: New values for the row factors.
update_op: An op that assigns the newly computed values to the row
factors.
+ loss: A tensor (scalar) that contains the normalized minibatch loss,
+ corresponding to sp_input.
+        If sp_input contains the rows {A_{i, :}, i \in I}, and the input matrix
+ A has n total rows, then loss is:
+ (\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 + \lambda \|U_I\|_F^2) *
+ n / |I| + \lambda \|V\|_F^2.
"""
- return self._process_input_helper(
- True, sp_input=sp_input, transpose_input=transpose_input)
+ return self._process_input_helper(True, sp_input=sp_input,
+ transpose_input=transpose_input)
def update_col_factors(self, sp_input=None, transpose_input=False):
- """Updates the column factors.
+ r"""Updates the column factors.
Args:
sp_input: A SparseTensor representing a subset of columns of the full
@@ -641,13 +660,17 @@ class WALSModel(object):
new_values: New values for the column factors.
update_op: An op that assigns the newly computed values to the column
factors.
+ loss: A tensor (scalar) that contains the normalized minibatch loss,
+ corresponding to sp_input.
+ If sp_input contains the columns {A_{:, j}, j \in J}, and the input
+ matrix A has m total columns, then loss is:
+ (\|\sqrt W_J \odot (A_J - U V_J^T)\|_F^2 + \lambda \|V_J\|_F^2) *
+ m / |J| + \lambda \|U\|_F^2.
"""
- return self._process_input_helper(
- False, sp_input=sp_input, transpose_input=transpose_input)
+ return self._process_input_helper(False, sp_input=sp_input,
+ transpose_input=transpose_input)
- def project_row_factors(self,
- sp_input=None,
- transpose_input=False,
+ def project_row_factors(self, sp_input=None, transpose_input=False,
projection_weights=None):
"""Projects the row factors.
@@ -672,11 +695,9 @@ class WALSModel(object):
"""
if projection_weights is None:
projection_weights = 1
- return self._process_input_helper(
- True,
- sp_input=sp_input,
- transpose_input=transpose_input,
- row_weights=projection_weights)[0]
+ return self._process_input_helper(True, sp_input=sp_input,
+ transpose_input=transpose_input,
+ row_weights=projection_weights)[0]
def project_col_factors(self,
sp_input=None,
@@ -705,16 +726,12 @@ class WALSModel(object):
"""
if projection_weights is None:
projection_weights = 1
- return self._process_input_helper(
- False,
- sp_input=sp_input,
- transpose_input=transpose_input,
- row_weights=projection_weights)[0]
-
- def _process_input_helper(self,
- update_row_factors,
- sp_input=None,
- transpose_input=False,
+ return self._process_input_helper(False, sp_input=sp_input,
+ transpose_input=transpose_input,
+ row_weights=projection_weights)[0]
+
+ def _process_input_helper(self, update_row_factors,
+ sp_input=None, transpose_input=False,
row_weights=None):
"""Creates the graph for processing a sparse slice of input.
@@ -734,10 +751,12 @@ class WALSModel(object):
of columns to be updated/projected.
Returns:
- A tuple consisting of the following two elements:
+ A tuple consisting of the following three elements:
new_values: New values for the row/column factors.
update_op: An op that assigns the newly computed values to the row/column
factors.
+ loss: A tensor (scalar) that contains the normalized minibatch loss,
+ corresponding to sp_input.
"""
assert isinstance(sp_input, sparse_tensor.SparseTensor)
@@ -746,6 +765,7 @@ class WALSModel(object):
right_factors = self._col_factors_cache
row_wt = self._row_wt_cache
col_wt = self._col_wt_cache
+ total_rows = self._input_rows
sharding_func = WALSModel._get_sharding_func(self._input_rows,
self._num_row_shards)
gramian = self._col_gramian_cache
@@ -754,6 +774,7 @@ class WALSModel(object):
right_factors = self._row_factors_cache
row_wt = self._col_wt_cache
col_wt = self._row_wt_cache
+ total_rows = self._input_cols
sharding_func = WALSModel._get_sharding_func(self._input_cols,
self._num_col_shards)
gramian = self._row_gramian_cache
@@ -799,8 +820,8 @@ class WALSModel(object):
# Compute lhs and rhs of the normal equations
total_lhs = (self._unobserved_weight * gramian)
- if self._regularization is not None:
- total_lhs += self._regularization
+ if self._regularization_matrix is not None:
+ total_lhs += self._regularization_matrix
if self._row_weights is None:
# Special case of ALS. Use a much simpler update rule.
total_rhs = (self._unobserved_weight *
@@ -819,30 +840,68 @@ class WALSModel(object):
row_weights_slice = embedding_ops.embedding_lookup(
row_wt, update_indices, partition_strategy="div")
else:
+ num_indices = array_ops.shape(update_indices)[0]
with ops.control_dependencies(
[check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
row_weights_slice = control_flow_ops.cond(
math_ops.equal(array_ops.rank(row_weights), 0),
- lambda: (array_ops.ones([array_ops.shape(update_indices)[0]]) * row_weights),
+ lambda: (array_ops.ones([num_indices]) * row_weights),
lambda: math_ops.cast(row_weights, dtypes.float32))
col_weights = embedding_ops.embedding_lookup(
col_wt, gather_indices, partition_strategy="div")
- partial_lhs, total_rhs = wals_compute_partial_lhs_and_rhs(
- right,
- col_weights,
- self._unobserved_weight,
- row_weights_slice,
- new_sp_input.indices,
- new_sp_input.values,
- num_rows,
- transpose_input,
- name="wals_compute_partial_lhs_rhs")
+ partial_lhs, total_rhs = (
+ gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
+ right,
+ col_weights,
+ self._unobserved_weight,
+ row_weights_slice,
+ new_sp_input.indices,
+ new_sp_input.values,
+ num_rows,
+ transpose_input,
+ name="wals_compute_partial_lhs_rhs"))
total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
total_rhs = array_ops.expand_dims(total_rhs, -1)
new_left_values = array_ops.squeeze(
linalg_ops.matrix_solve(total_lhs, total_rhs), [2])
- return (new_left_values, self.scatter_update(left, update_indices,
- new_left_values,
- sharding_func))
+ update_op_name = "row_update" if update_row_factors else "col_update"
+ update_op = self.scatter_update(left, update_indices, new_left_values,
+ sharding_func, name=update_op_name)
+
+ # Create the loss subgraph
+ loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
+ if transpose_input else new_sp_input)
+ # sp_approx is the low rank estimate of the input matrix, formed by
+ # computing the product <u_i, v_j> for (i, j) in loss_sp_input.indices.
+ sp_approx_vals = gen_factorization_ops.masked_matmul(
+ new_left_values, right, loss_sp_input.indices, transpose_a=False,
+ transpose_b=True)
+ sp_approx = sparse_tensor.SparseTensor(
+ loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
+ sp_approx_sq = math_ops.square(sp_approx)
+ sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
+ sp_residual_sq = math_ops.square(sp_residual)
+ row_wt_mat = (constant_op.constant(0.) if self._row_weights is None else
+ array_ops.expand_dims(row_weights_slice, 1))
+ col_wt_mat = (constant_op.constant(0.) if self._col_weights is None else
+ array_ops.expand_dims(col_weights, 0))
+ # We return the normalized loss
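+      # The unobserved_weight terms implement the decomposition in
+      # g3doc/wals.md: squared residuals minus squared approximations over the
+      # observed entries, plus trace(U_I^T U_I V^T V) = ||U_I V^T||_F^2 (the
+      # Gramian term). The R_i * C_j weighted residuals are added outside that
+      # factor, and everything is scaled by total_rows / num_rows so that the
+      # minibatch loss estimates the loss over the full input.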
+ partial_row_gramian = math_ops.matmul(
+ new_left_values, new_left_values, transpose_a=True)
+ normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)
+ loss = (
+ self._unobserved_weight * (
+ sparse_ops.sparse_reduce_sum(sp_residual_sq) -
+ sparse_ops.sparse_reduce_sum(sp_approx_sq) +
+ math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))
+ ) +
+ sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
+ ) * normalization_factor
+ if self._regularization is not None:
+ loss += self._regularization * (
+ math_ops.trace(partial_row_gramian) * normalization_factor +
+ math_ops.trace(gramian)
+ )
+ return (new_left_values, update_op, loss)
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
index 75ef87d15d..bbfcfabf40 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
@@ -30,9 +30,14 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.factorization.python.ops import factorization_ops
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
INPUT_MATRIX = np.array(
@@ -85,6 +90,92 @@ def sparse_input():
return np_matrix_to_tf_sparse(INPUT_MATRIX)
+def count_rows(sp_input):
+ return math_ops.cast(
+ array_ops.shape(array_ops.unique(sp_input.indices[:, 0])[0])[0],
+ dtypes.float32)
+
+
+def count_cols(sp_input):
+ return math_ops.cast(
+ array_ops.shape(array_ops.unique(sp_input.indices[:, 1])[0])[0],
+ dtypes.float32)
+
+
+def calculate_loss(input_mat, row_factors, col_factors, regularization=None,
+ w0=1., row_weights=None, col_weights=None):
+ """Calculates the loss of a given factorization.
+
+  This uses a non-distributed method, different from the one implemented in
+  the WALS model. The weight of an observed entry (i, j) (i.e. such that
+  input_mat[i, j] is nonzero) is (w0 + row_weights[i] * col_weights[j]).
+
+ Args:
+ input_mat: The input matrix, a SparseTensor of rank 2.
+ row_factors: The row factors, a dense Tensor of rank 2.
+ col_factors: The col factors, a dense Tensor of rank 2.
+ regularization: the regularization coefficient, a scalar.
+ w0: the weight of unobserved entries. A scalar.
+ row_weights: A dense tensor of rank 1.
+ col_weights: A dense tensor of rank 1.
+
+ Returns:
+ The total loss.
+ """
+ wr = (array_ops.expand_dims(row_weights, 1) if row_weights is not None
+ else constant_op.constant(1.))
+ wc = (array_ops.expand_dims(col_weights, 0) if col_weights is not None
+ else constant_op.constant(1.))
+ reg = (regularization if regularization is not None
+ else constant_op.constant(0.))
+
+ row_indices, col_indices = array_ops.split(input_mat.indices,
+ axis=1,
+ num_or_size_splits=2)
+ gathered_row_factors = array_ops.gather(row_factors, row_indices)
+ gathered_col_factors = array_ops.gather(col_factors, col_indices)
+ sp_approx_vals = array_ops.squeeze(math_ops.matmul(
+ gathered_row_factors, gathered_col_factors, adjoint_b=True))
+ sp_approx = sparse_tensor.SparseTensor(
+ indices=input_mat.indices,
+ values=sp_approx_vals,
+ dense_shape=input_mat.dense_shape)
+
+ sp_approx_sq = math_ops.square(sp_approx)
+ row_norm = math_ops.reduce_sum(math_ops.square(row_factors))
+ col_norm = math_ops.reduce_sum(math_ops.square(col_factors))
+ row_col_norm = math_ops.reduce_sum(math_ops.square(math_ops.matmul(
+ row_factors, col_factors, transpose_b=True)))
+
+ resid = sparse_ops.sparse_add(input_mat, sp_approx * (-1))
+ resid_sq = math_ops.square(resid)
+ loss = w0 * (
+ sparse_ops.sparse_reduce_sum(resid_sq) -
+ sparse_ops.sparse_reduce_sum(sp_approx_sq)
+ )
+ loss += (sparse_ops.sparse_reduce_sum(wr * (resid_sq * wc)) +
+ w0 * row_col_norm + reg * (row_norm + col_norm))
+ return loss.eval()
+
+
+def calculate_loss_from_wals_model(wals_model, sp_inputs):
+ current_rows = embedding_ops.embedding_lookup(
+ wals_model.row_factors, math_ops.range(wals_model._input_rows),
+ partition_strategy="div")
+ current_cols = embedding_ops.embedding_lookup(
+ wals_model.col_factors, math_ops.range(wals_model._input_cols),
+ partition_strategy="div")
+ row_wts = embedding_ops.embedding_lookup(
+ wals_model._row_weights, math_ops.range(wals_model._input_rows),
+ partition_strategy="div")
+ col_wts = embedding_ops.embedding_lookup(
+ wals_model._col_weights, math_ops.range(wals_model._input_cols),
+ partition_strategy="div")
+ return calculate_loss(
+ sp_inputs, current_rows, current_cols, wals_model._regularization,
+ wals_model._unobserved_weight, row_wts, col_wts)
+
+
class WalsModelTest(test.TestCase):
def setUp(self):
@@ -103,7 +194,6 @@ class WalsModelTest(test.TestCase):
self.row_wts = [[0.1, 0.2, 0.3], [0.4, 0.5]]
self.col_wts = [[0.1, 0.2, 0.3], [0.4, 0.5], [0.6, 0.7]]
- self._wals_inputs = sparse_input()
# Values of factor shards after running one iteration of row and column
# updates.
@@ -120,13 +210,19 @@ class WalsModelTest(test.TestCase):
self._col_factors_2 = [[3.3459, -1.3341, -3.3008],
[0.57366, 1.83729, 1.26798]]
- def _run_test_process_input(self, use_factors_weights_cache):
- with self.test_session():
+ def _run_test_process_input(self,
+ use_factors_weights_cache,
+ compute_loss=False):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ self._wals_inputs = sparse_input()
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
+ num_rows = 5
+ num_cols = 7
+ factor_dim = 3
wals_model = factorization_ops.WALSModel(
- 5,
- 7,
- 3,
+ num_rows,
+ num_cols,
+ factor_dim,
num_row_shards=2,
num_col_shards=3,
regularization=0.01,
@@ -152,8 +248,8 @@ class WalsModelTest(test.TestCase):
# Here we feed in scattered rows of the input.
wals_model.row_update_prep_gramian_op.run()
wals_model.initialize_row_update_op.run()
- process_input_op = wals_model.update_row_factors(
- sp_input=sp_feeder, transpose_input=False)[1]
+ _, process_input_op, factor_loss = wals_model.update_row_factors(
+ sp_input=sp_feeder, transpose_input=False)
for inp in input_scattered_rows:
feed_dict = {sp_feeder: inp}
process_input_op.run(feed_dict=feed_dict)
@@ -189,6 +285,19 @@ class WalsModelTest(test.TestCase):
[[0.569082, 0.715088, 0.31777], [1.915879, 1.992677, 1.109057]],
atol=1e-3)
+ if compute_loss:
+ # Test loss computation after the row update
+ loss = sum(
+ sess.run(factor_loss * count_rows(inp) / num_rows,
+ feed_dict={sp_feeder: inp})
+ for inp in input_scattered_rows)
+ true_loss = calculate_loss_from_wals_model(
+ wals_model, self._wals_inputs)
+ self.assertNear(
+ loss, true_loss, err=.001,
+ msg="""After row update, computed loss = {}, does not match
+ the true loss = {}.""".format(loss, true_loss))
+
# Split input into multiple sparse tensors with scattered columns. Note
# that here the elements in the sparse tensors are not ordered and also
# do not need to consist of consecutive columns. However, each column
@@ -201,13 +310,14 @@ class WalsModelTest(test.TestCase):
INPUT_MATRIX, col_slices=[3, 6], shuffle=True).eval()
input_scattered_cols = [sp_c0, sp_c1, sp_c2, sp_c3]
+ input_scattered_cols_non_duplicate = [sp_c0, sp_c1, sp_c2]
# Test updating column factors.
# Here we feed in scattered columns of the input.
wals_model.col_update_prep_gramian_op.run()
wals_model.initialize_col_update_op.run()
- process_input_op = wals_model.update_col_factors(
- sp_input=sp_feeder, transpose_input=False)[1]
+ _, process_input_op, factor_loss = wals_model.update_col_factors(
+ sp_input=sp_feeder, transpose_input=False)
for inp in input_scattered_cols:
feed_dict = {sp_feeder: inp}
process_input_op.run(feed_dict=feed_dict)
@@ -248,13 +358,31 @@ class WalsModelTest(test.TestCase):
[0.346433, 1.360644, 1.677121]],
atol=1e-3)
- def _run_test_process_input_transposed(self, use_factors_weights_cache):
- with self.test_session():
+ if compute_loss:
+ # Test loss computation after the column update.
+ loss = sum(
+ sess.run(factor_loss * count_cols(inp) / num_cols,
+ feed_dict={sp_feeder: inp})
+ for inp in input_scattered_cols_non_duplicate)
+ true_loss = calculate_loss_from_wals_model(
+ wals_model, self._wals_inputs)
+ self.assertNear(
+ loss, true_loss, err=.001,
+ msg="""After col update, computed loss = {}, does not match the true
+ loss = {}.""".format(loss, true_loss))
+
+ def _run_test_process_input_transposed(self, use_factors_weights_cache,
+ compute_loss=False):
+ with ops.Graph().as_default(), self.test_session() as sess:
+ self._wals_inputs = sparse_input()
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
+ num_rows = 5
+ num_cols = 7
+ factor_dim = 3
wals_model = factorization_ops.WALSModel(
- 5,
- 7,
- 3,
+ num_rows,
+ num_cols,
+ factor_dim,
num_row_shards=2,
num_col_shards=3,
regularization=0.01,
@@ -278,7 +406,7 @@ class WalsModelTest(test.TestCase):
sp_r2_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [2], transpose=True).eval()
sp_r3_t = sp_r1_t
input_scattered_rows = [sp_r0_t, sp_r1_t, sp_r2_t, sp_r3_t]
-
+ input_scattered_rows_non_duplicate = [sp_r0_t, sp_r1_t, sp_r2_t]
# Test updating row factors.
# Here we feed in scattered rows of the input.
# Note that the needed suffix of placeholder are in the order of test
@@ -286,8 +414,8 @@ class WalsModelTest(test.TestCase):
# they appear.
wals_model.row_update_prep_gramian_op.run()
wals_model.initialize_row_update_op.run()
- process_input_op = wals_model.update_row_factors(
- sp_input=sp_feeder, transpose_input=True)[1]
+ _, process_input_op, factor_loss = wals_model.update_row_factors(
+ sp_input=sp_feeder, transpose_input=True)
for inp in input_scattered_rows:
feed_dict = {sp_feeder: inp}
process_input_op.run(feed_dict=feed_dict)
@@ -323,6 +451,19 @@ class WalsModelTest(test.TestCase):
[[1.915879, 1.992677, 1.109057], [0.569082, 0.715088, 0.31777]],
atol=1e-3)
+ if compute_loss:
+ # Test loss computation after the row update
+ loss = sum(
+ sess.run(factor_loss * count_cols(inp) / num_rows,
+ feed_dict={sp_feeder: inp})
+ for inp in input_scattered_rows_non_duplicate)
+ true_loss = calculate_loss_from_wals_model(
+ wals_model, self._wals_inputs)
+ self.assertNear(
+ loss, true_loss, err=.001,
+ msg="""After row update, computed loss = {}, does not match the true
+ loss = {}.""".format(loss, true_loss))
+
# Split input into multiple SparseTensors with scattered columns.
# Here the inputs are transposed. But the same constraints as described in
# the previous non-transposed test case apply to these inputs (before they
@@ -338,13 +479,14 @@ class WalsModelTest(test.TestCase):
sp_c4_t = sp_c2_t
input_scattered_cols = [sp_c0_t, sp_c1_t, sp_c2_t, sp_c3_t, sp_c4_t]
+ input_scattered_cols_non_duplicate = [sp_c0_t, sp_c1_t, sp_c2_t, sp_c3_t]
# Test updating column factors.
# Here we feed in scattered columns of the input.
wals_model.col_update_prep_gramian_op.run()
wals_model.initialize_col_update_op.run()
- process_input_op = wals_model.update_col_factors(
- sp_input=sp_feeder, transpose_input=True)[1]
+ _, process_input_op, factor_loss = wals_model.update_col_factors(
+ sp_input=sp_feeder, transpose_input=True)
for inp in input_scattered_cols:
feed_dict = {sp_feeder: inp}
process_input_op.run(feed_dict=feed_dict)
@@ -377,15 +519,28 @@ class WalsModelTest(test.TestCase):
[[3.585139, -0.487476, -3.852232],
[0.557937, 1.813907, 1.331171]],
atol=1e-3)
-
- # Note that when row_weights and col_weights are 0, WALS gives dentical
+ if compute_loss:
+ # Test loss computation after the col update
+ loss = sum(
+ sess.run(factor_loss * count_rows(inp) / num_cols,
+ feed_dict={sp_feeder: inp})
+ for inp in input_scattered_cols_non_duplicate)
+ true_loss = calculate_loss_from_wals_model(
+ wals_model, self._wals_inputs)
+ self.assertNear(
+ loss, true_loss, err=.001,
+ msg="""After col update, computed loss = {}, does not match the true
+ loss = {}.""".format(loss, true_loss))
+
+ # Note that when row_weights and col_weights are 0, WALS gives identical
# results as ALS (Alternating Least Squares). However our implementation does
# not handle the case of zero weights differently. Instead, when row_weights
# and col_weights are set to None, we interpret that as the ALS case, and
# trigger the more efficient ALS updates.
# Here we test that those two give identical results.
def _run_test_als(self, use_factors_weights_cache):
- with self.test_session():
+ with ops.Graph().as_default(), self.test_session():
+ self._wals_inputs = sparse_input()
col_init = np.random.rand(7, 3)
als_model = factorization_ops.WALSModel(
5,
@@ -463,7 +618,8 @@ class WalsModelTest(test.TestCase):
atol=1e-2)
def _run_test_als_transposed(self, use_factors_weights_cache):
- with self.test_session():
+ with ops.Graph().as_default(), self.test_session():
+ self._wals_inputs = sparse_input()
col_init = np.random.rand(7, 3)
als_model = factorization_ops.WALSModel(
5,
@@ -552,7 +708,7 @@ class WalsModelTest(test.TestCase):
rows = 15
cols = 11
dims = 3
- with self.test_session():
+ with ops.Graph().as_default(), self.test_session():
data = np.dot(np.random.rand(rows, 3),
np.random.rand(3, cols)).astype(np.float32) / 3.0
indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
@@ -582,7 +738,7 @@ class WalsModelTest(test.TestCase):
cols = 11
dims = 3
- with self.test_session():
+ with ops.Graph().as_default(), self.test_session():
data = np.dot(np.random.rand(rows, 3),
np.random.rand(3, cols)).astype(np.float32) / 3.0
indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
@@ -615,7 +771,7 @@ class WalsModelTest(test.TestCase):
def keep_index(x):
return not (x[0] + x[1]) % 4
- with self.test_session():
+ with ops.Graph().as_default(), self.test_session():
row_wts = 0.1 + np.random.rand(rows)
col_wts = 0.1 + np.random.rand(cols)
data = np.dot(np.random.rand(rows, 3),
@@ -683,6 +839,17 @@ class WalsModelTest(test.TestCase):
def test_train_matrix_completion_wals_without_cache(self):
self._run_test_train_matrix_completion_wals(False)
+ def test_loss_transposed_with_cache(self):
+ self._run_test_process_input_transposed(True, compute_loss=True)
+
+ def test_loss_transposed_without_cache(self):
+ self._run_test_process_input_transposed(False, compute_loss=True)
+
+ def test_loss_with_cache(self):
+ self._run_test_process_input(True, compute_loss=True)
+
+ def test_loss_without_cache(self):
+ self._run_test_process_input(False, compute_loss=True)
if __name__ == "__main__":
test.main()