diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-04 07:29:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-04 07:34:57 -0700 |
commit | 956159119b6c81ec500dc541f6f5ea3f776f2d0a (patch) | |
tree | 83012fe884f948e7ec3f3c9a89ea3588a2b62d08 /tensorflow/contrib/factorization | |
parent | 1678f76bfb7c1b2ba46fa50af7cb548859179d8f (diff) |
Minor fix in the WALS estimator test, to pass the correct shape of the transposed matrix.
PiperOrigin-RevId: 211453816
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/wals.py | 70 | ||||
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/wals_test.py | 10 |
2 files changed, 43 insertions, 37 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index ca46c39baa..b82bf1188f 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -377,64 +377,68 @@ class WALSMatrixFactorization(estimator.Estimator): WALS (Weighted Alternating Least Squares) is an algorithm for weighted matrix factorization. It computes a low-rank approximation of a given sparse (n x m) - matrix A, by a product of two matrices, U * V^T, where U is a (n x k) matrix - and V is a (m x k) matrix. Here k is the rank of the approximation, also - called the embedding dimension. We refer to U as the row factors, and V as the - column factors. + matrix `A`, by a product of two matrices, `U * V^T`, where `U` is a (n x k) + matrix and `V` is a (m x k) matrix. Here k is the rank of the approximation, + also called the embedding dimension. We refer to `U` as the row factors, and + `V` as the column factors. See tensorflow/contrib/factorization/g3doc/wals.md for the precise problem formulation. - The training proceeds in sweeps: during a row_sweep, we fix V and solve for U. - During a column sweep, we fix U and solve for V. Each one of these problems is - an unconstrained quadratic minimization problem and can be solved exactly (it - can also be solved in mini-batches, since the solution decouples nicely). + The training proceeds in sweeps: during a row_sweep, we fix `V` and solve for + `U`. During a column sweep, we fix `U` and solve for `V`. Each one of these + problems is an unconstrained quadratic minimization problem and can be solved + exactly (it can also be solved in mini-batches, since the solution decouples + across rows of each matrix). The alternating between sweeps is achieved by using a hook during training, which is responsible for keeping track of the sweeps and running preparation ops at the beginning of each sweep. It also updates the global_step variable, which keeps track of the number of batches processed since the beginning of training. The current implementation assumes that the training is run on a single - machine, and will fail if config.num_worker_replicas is not equal to one. - Training is done by calling self.fit(input_fn=input_fn), where input_fn + machine, and will fail if `config.num_worker_replicas` is not equal to one. + Training is done by calling `self.fit(input_fn=input_fn)`, where `input_fn` provides two tensors: one for rows of the input matrix, and one for rows of the transposed input matrix (i.e. columns of the original matrix). Note that during a row sweep, only row batches are processed (ignoring column batches) and vice-versa. Also note that every row (respectively every column) of the input matrix must be processed at least once for the sweep to be considered complete. In - particular, training will not make progress if input_fn does not generate some - rows. - - For prediction, given a new set of input rows A' (e.g. new rows of the A - matrix), we compute a corresponding set of row factors U', such that U' * V^T - is a good approximation of A'. We call this operation a row projection. A - similar operation is defined for columns. - Projection is done by calling self.get_projections(input_fn=input_fn), where - input_fn satisfies the constraints given below. - - The input functions must satisfy the following constraints: Calling input_fn - must return a tuple (features, labels) where labels is None, and features is - a dict containing the following keys: + particular, training will not make progress if some rows are not generated by + the `input_fn`. + + For prediction, given a new set of input rows `A'`, we compute a corresponding + set of row factors `U'`, such that `U' * V^T` is a good approximation of `A'`. + We call this operation a row projection. A similar operation is defined for + columns. Projection is done by calling + `self.get_projections(input_fn=input_fn)`, where `input_fn` satisfies the + constraints given below. + + The input functions must satisfy the following constraints: Calling `input_fn` + must return a tuple `(features, labels)` where `labels` is None, and + `features` is a dict containing the following keys: + TRAIN: - - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix). + * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix). Rows of the input matrix to process (or to project). - - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix). + * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix). Columns of the input matrix to process (or to project), transposed. + INFER: - - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix). + * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix). Rows to project. - - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix). + * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix). Columns to project. - - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project + * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project the rows or columns. - - WALSMatrixFactorization.PROJECTION_WEIGHTS (Optional): float32 Tensor + * `WALSMatrixFactorization.PROJECTION_WEIGHTS` (Optional): float32 Tensor (vector). The weights to use in the projection. + EVAL: - - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix). + * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix). Rows to project. - - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix). + * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix). Columns to project. - - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project + * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project the rows or columns. """ # Keys to be used in model_fn @@ -469,7 +473,7 @@ class WALSMatrixFactorization(estimator.Estimator): max_sweeps=None, model_dir=None, config=None): - """Creates a model for matrix factorization using the WALS method. + r"""Creates a model for matrix factorization using the WALS method. Args: num_rows: Total number of rows for input matrix. diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index 36b483c6d7..31820a18b4 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -125,11 +125,13 @@ class WALSMatrixFactorizationTest(test.TestCase): nz_row_ids = np.arange(np.shape(np_matrix)[0]) nz_col_ids = np.arange(np.shape(np_matrix)[1]) - def extract_features(row_batch, col_batch, shape): + def extract_features(row_batch, col_batch, num_rows, num_cols): row_ids = row_batch[0] col_ids = col_batch[0] - rows = self.remap_sparse_tensor_rows(row_batch[1], row_ids, shape) - cols = self.remap_sparse_tensor_rows(col_batch[1], col_ids, shape) + rows = self.remap_sparse_tensor_rows( + row_batch[1], row_ids, shape=[num_rows, num_cols]) + cols = self.remap_sparse_tensor_rows( + col_batch[1], col_ids, shape=[num_cols, num_rows]) features = { wals_lib.WALSMatrixFactorization.INPUT_ROWS: rows, wals_lib.WALSMatrixFactorization.INPUT_COLS: cols, @@ -154,7 +156,7 @@ class WALSMatrixFactorizationTest(test.TestCase): capacity=10, enqueue_many=True) - features = extract_features(row_batch, col_batch, sp_mat.dense_shape) + features = extract_features(row_batch, col_batch, num_rows, num_cols) if mode == model_fn.ModeKeys.INFER or mode == model_fn.ModeKeys.EVAL: self.assertTrue( |