aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/factorization
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-04 07:29:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 07:34:57 -0700
commit956159119b6c81ec500dc541f6f5ea3f776f2d0a (patch)
tree83012fe884f948e7ec3f3c9a89ea3588a2b62d08 /tensorflow/contrib/factorization
parent1678f76bfb7c1b2ba46fa50af7cb548859179d8f (diff)
Minor fix in the WALS estimator test, to pass the correct shape of the transposed matrix.
PiperOrigin-RevId: 211453816
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals.py70
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals_test.py10
2 files changed, 43 insertions, 37 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py
index ca46c39baa..b82bf1188f 100644
--- a/tensorflow/contrib/factorization/python/ops/wals.py
+++ b/tensorflow/contrib/factorization/python/ops/wals.py
@@ -377,64 +377,68 @@ class WALSMatrixFactorization(estimator.Estimator):
WALS (Weighted Alternating Least Squares) is an algorithm for weighted matrix
factorization. It computes a low-rank approximation of a given sparse (n x m)
- matrix A, by a product of two matrices, U * V^T, where U is a (n x k) matrix
- and V is a (m x k) matrix. Here k is the rank of the approximation, also
- called the embedding dimension. We refer to U as the row factors, and V as the
- column factors.
+ matrix `A`, by a product of two matrices, `U * V^T`, where `U` is a (n x k)
+ matrix and `V` is a (m x k) matrix. Here k is the rank of the approximation,
+ also called the embedding dimension. We refer to `U` as the row factors, and
+ `V` as the column factors.
See tensorflow/contrib/factorization/g3doc/wals.md for the precise problem
formulation.
- The training proceeds in sweeps: during a row_sweep, we fix V and solve for U.
- During a column sweep, we fix U and solve for V. Each one of these problems is
- an unconstrained quadratic minimization problem and can be solved exactly (it
- can also be solved in mini-batches, since the solution decouples nicely).
+ The training proceeds in sweeps: during a row_sweep, we fix `V` and solve for
+ `U`. During a column sweep, we fix `U` and solve for `V`. Each one of these
+ problems is an unconstrained quadratic minimization problem and can be solved
+ exactly (it can also be solved in mini-batches, since the solution decouples
+ across rows of each matrix).
The alternating between sweeps is achieved by using a hook during training,
which is responsible for keeping track of the sweeps and running preparation
ops at the beginning of each sweep. It also updates the global_step variable,
which keeps track of the number of batches processed since the beginning of
training.
The current implementation assumes that the training is run on a single
- machine, and will fail if config.num_worker_replicas is not equal to one.
- Training is done by calling self.fit(input_fn=input_fn), where input_fn
+ machine, and will fail if `config.num_worker_replicas` is not equal to one.
+ Training is done by calling `self.fit(input_fn=input_fn)`, where `input_fn`
provides two tensors: one for rows of the input matrix, and one for rows of
the transposed input matrix (i.e. columns of the original matrix). Note that
during a row sweep, only row batches are processed (ignoring column batches)
and vice-versa.
Also note that every row (respectively every column) of the input matrix
must be processed at least once for the sweep to be considered complete. In
- particular, training will not make progress if input_fn does not generate some
- rows.
-
- For prediction, given a new set of input rows A' (e.g. new rows of the A
- matrix), we compute a corresponding set of row factors U', such that U' * V^T
- is a good approximation of A'. We call this operation a row projection. A
- similar operation is defined for columns.
- Projection is done by calling self.get_projections(input_fn=input_fn), where
- input_fn satisfies the constraints given below.
-
- The input functions must satisfy the following constraints: Calling input_fn
- must return a tuple (features, labels) where labels is None, and features is
- a dict containing the following keys:
+ particular, training will not make progress if some rows are not generated by
+ the `input_fn`.
+
+ For prediction, given a new set of input rows `A'`, we compute a corresponding
+ set of row factors `U'`, such that `U' * V^T` is a good approximation of `A'`.
+ We call this operation a row projection. A similar operation is defined for
+ columns. Projection is done by calling
+ `self.get_projections(input_fn=input_fn)`, where `input_fn` satisfies the
+ constraints given below.
+
+ The input functions must satisfy the following constraints: Calling `input_fn`
+ must return a tuple `(features, labels)` where `labels` is None, and
+ `features` is a dict containing the following keys:
+
TRAIN:
- - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
Rows of the input matrix to process (or to project).
- - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
Columns of the input matrix to process (or to project), transposed.
+
INFER:
- - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
Rows to project.
- - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
Columns to project.
- - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project
+ * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project
the rows or columns.
- - WALSMatrixFactorization.PROJECTION_WEIGHTS (Optional): float32 Tensor
+ * `WALSMatrixFactorization.PROJECTION_WEIGHTS` (Optional): float32 Tensor
(vector). The weights to use in the projection.
+
EVAL:
- - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
Rows to project.
- - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
Columns to project.
- - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project
+ * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project
the rows or columns.
"""
# Keys to be used in model_fn
@@ -469,7 +473,7 @@ class WALSMatrixFactorization(estimator.Estimator):
max_sweeps=None,
model_dir=None,
config=None):
- """Creates a model for matrix factorization using the WALS method.
+ r"""Creates a model for matrix factorization using the WALS method.
Args:
num_rows: Total number of rows for input matrix.
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py
index 36b483c6d7..31820a18b4 100644
--- a/tensorflow/contrib/factorization/python/ops/wals_test.py
+++ b/tensorflow/contrib/factorization/python/ops/wals_test.py
@@ -125,11 +125,13 @@ class WALSMatrixFactorizationTest(test.TestCase):
nz_row_ids = np.arange(np.shape(np_matrix)[0])
nz_col_ids = np.arange(np.shape(np_matrix)[1])
- def extract_features(row_batch, col_batch, shape):
+ def extract_features(row_batch, col_batch, num_rows, num_cols):
row_ids = row_batch[0]
col_ids = col_batch[0]
- rows = self.remap_sparse_tensor_rows(row_batch[1], row_ids, shape)
- cols = self.remap_sparse_tensor_rows(col_batch[1], col_ids, shape)
+ rows = self.remap_sparse_tensor_rows(
+ row_batch[1], row_ids, shape=[num_rows, num_cols])
+ cols = self.remap_sparse_tensor_rows(
+ col_batch[1], col_ids, shape=[num_cols, num_rows])
features = {
wals_lib.WALSMatrixFactorization.INPUT_ROWS: rows,
wals_lib.WALSMatrixFactorization.INPUT_COLS: cols,
@@ -154,7 +156,7 @@ class WALSMatrixFactorizationTest(test.TestCase):
capacity=10,
enqueue_many=True)
- features = extract_features(row_batch, col_batch, sp_mat.dense_shape)
+ features = extract_features(row_batch, col_batch, num_rows, num_cols)
if mode == model_fn.ModeKeys.INFER or mode == model_fn.ModeKeys.EVAL:
self.assertTrue(