diff options
Diffstat (limited to 'tensorflow/contrib/factorization/python/ops/wals_test.py')
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/wals_test.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index 36b483c6d7..31820a18b4 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -125,11 +125,13 @@ class WALSMatrixFactorizationTest(test.TestCase): nz_row_ids = np.arange(np.shape(np_matrix)[0]) nz_col_ids = np.arange(np.shape(np_matrix)[1]) - def extract_features(row_batch, col_batch, shape): + def extract_features(row_batch, col_batch, num_rows, num_cols): row_ids = row_batch[0] col_ids = col_batch[0] - rows = self.remap_sparse_tensor_rows(row_batch[1], row_ids, shape) - cols = self.remap_sparse_tensor_rows(col_batch[1], col_ids, shape) + rows = self.remap_sparse_tensor_rows( + row_batch[1], row_ids, shape=[num_rows, num_cols]) + cols = self.remap_sparse_tensor_rows( + col_batch[1], col_ids, shape=[num_cols, num_rows]) features = { wals_lib.WALSMatrixFactorization.INPUT_ROWS: rows, wals_lib.WALSMatrixFactorization.INPUT_COLS: cols, @@ -154,7 +156,7 @@ class WALSMatrixFactorizationTest(test.TestCase): capacity=10, enqueue_many=True) - features = extract_features(row_batch, col_batch, sp_mat.dense_shape) + features = extract_features(row_batch, col_batch, num_rows, num_cols) if mode == model_fn.ModeKeys.INFER or mode == model_fn.ModeKeys.EVAL: self.assertTrue( |