diff options
author | 2017-03-29 17:12:58 -0800 | |
---|---|---|
committer | 2017-03-29 18:33:33 -0700 | |
commit | 79d5e63091418d20b6ee01b5cc7e598a0606d691 (patch) | |
tree | b6ef49d8071110c711a7aa3cc7aca93d0fd504ee /tensorflow | |
parent | 1f6c19e3a9131db4fa0353982b7bf44d61fe2ce8 (diff) |
Add a SessionRunHook to manage synchronization of the row and column sweeps in the WALS solver.
Change: 151652662
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/factorization/BUILD | 21 | ||||
-rw-r--r-- | tensorflow/contrib/factorization/__init__.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/wals.py | 223 | ||||
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/wals_test.py | 129 |
4 files changed, 374 insertions, 0 deletions
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 338e67e953..aa26c6060f 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -23,6 +23,7 @@ tf_custom_op_py_library( "python/ops/factorization_ops.py", "python/ops/gmm.py", "python/ops/gmm_ops.py", + "python/ops/wals.py", ], dso = [ ":python/ops/_clustering_ops.so", @@ -176,6 +177,26 @@ tf_py_test( ], ) +# Estimators tests +tf_py_test( + name = "wals_test", + size = "medium", + srcs = ["python/ops/wals_test.py"], + additional_deps = [ + ":factorization_py", + ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform", + "//tensorflow/python:platform_benchmark", + "//tensorflow/python:state_ops", + "//tensorflow/python:variables", + ], +) + # Kernel tests tf_py_test( name = "wals_solver_ops_test", diff --git a/tensorflow/contrib/factorization/__init__.py b/tensorflow/contrib/factorization/__init__.py index 63243a7fdb..f0ca879259 100644 --- a/tensorflow/contrib/factorization/__init__.py +++ b/tensorflow/contrib/factorization/__init__.py @@ -23,4 +23,5 @@ from tensorflow.contrib.factorization.python.ops.clustering_ops import * from tensorflow.contrib.factorization.python.ops.factorization_ops import * from tensorflow.contrib.factorization.python.ops.gmm import * from tensorflow.contrib.factorization.python.ops.gmm_ops import * +from tensorflow.contrib.factorization.python.ops.wals import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py new file mode 100644 index 0000000000..3fd2cbbec2 --- /dev/null +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -0,0 +1,223 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Weighted Alternating Least Squares (WALS) on the tf.learn API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.ops import variables as framework_variables +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import session_run_hook + + +class _SweepHook(session_run_hook.SessionRunHook): + """Keeps track of row/col sweeps, and runs prep ops before each sweep.""" + + def __init__(self, + is_row_sweep_var, + train_op, + num_rows, + num_cols, + processed_row_indices, + processed_col_indices, + row_prep_ops, + col_prep_ops, + cache_init_ops): + """Initializes SweepHook. + + Args: + is_row_sweep_var: A Boolean tf.Variable, determines whether we are + currently doing a row or column sweep. It is updated by the hook. + train_op: An op. All the ops created by the hook will have + control_dependencies on train_op. + num_rows: int, the total number of rows to be processed. + num_cols: int, the total number of columns to be processed. + processed_row_indices: A Tensor of type int64. The indices of the input + rows that are processed during the current sweep. All elements of + processed_row_indices must be in [0, num_rows). + processed_col_indices: A Tensor of type int64. The indices of the input + columns that are processed during the current sweep. All elements of + processed_col_indices must be in [0, num_cols). + row_prep_ops: list of ops, to be run before the beginning of each row + sweep, in the given order. + col_prep_ops: list of ops, to be run before the beginning of each column + sweep, in the given order. + cache_init_ops: list of ops, to be run once before training, in the given + order. These are typically local initialization ops (such as cache + initialization). + """ + self._num_rows = num_rows + self._num_cols = num_cols + self._row_prep_ops = row_prep_ops + self._col_prep_ops = col_prep_ops + self._cache_init_ops = cache_init_ops + self._is_row_sweep_var = is_row_sweep_var + # Boolean variable that determines whether the cache_init_ops have been run. + self._is_initialized = False + # Boolean variable that is set to True when a sweep is completed. + # Used to run the prep_ops at the beginning of a sweep, in before_run(). + self._is_sweep_done = False + # Ops to run jointly with train_op, responsible for updating + # _is_row_sweep_var and incrementing the global_step counter. They have + # control_dependencies on train_op. + self._fetches = self._create_switch_ops(processed_row_indices, + processed_col_indices, + train_op) + + def _create_switch_ops(self, + processed_row_indices, + processed_col_indices, + train_op): + """Creates ops to update is_row_sweep_var and to increment global_step. + + Creates two boolean tensors processed_rows and processed_cols, which keep + track of which rows/cols have been processed during the current sweep. + Returns ops that should be run after each row / col update. + - When is_row_sweep_var is True, it sets + processed_rows[processed_row_indices] to True. + - When is_row_sweep_var is False, it sets + processed_cols[processed_col_indices] to True . + When all rows or all cols have been processed, negates is_row_sweep_var and + resets processed_rows and processed_cols to False. + All of the ops created by this function have control_dependencies on + train_op. + + Args: + processed_row_indices: A Tensor. The indices of the input rows that are + processed during the current sweep. + processed_col_indices: A Tensor. The indices of the input columns that + are processed during the current sweep. + train_op: An op. All the ops created by this function have + control_dependencies on train_op. + Returns: + A list consisting of: + is_sweep_done: A Boolean tensor, determines whether the sweep is done, + i.e. all rows (during a row sweep) or all columns (during a column + sweep) have been processed. + switch_ops: An op that updates is_row_sweep_var when is_sweep_done is + True. Has control_dependencies on train_op. + global_step_incr_op: An op that increments the global_step counter. Has + control_dependenciens on switch_ops. + """ + processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False) + with ops.colocate_with(processed_rows_init): + processed_rows = variables.Variable( + processed_rows_init, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False, + name="sweep_hook_processed_rows") + processed_cols_init = array_ops.fill(dims=[self._num_cols], value=False) + with ops.colocate_with(processed_cols_init): + processed_cols = variables.Variable( + processed_cols_init, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False, + name="sweep_hook_processed_cols") + # After running the train_op, update processed_rows or processed_cols + # tensors, depending on whether we are currently doing a row or a col sweep + with ops.control_dependencies([train_op]): + def get_row_update_op(): + with ops.colocate_with(processed_rows): + return state_ops.scatter_update( + processed_rows, processed_row_indices, + array_ops.ones_like(processed_row_indices, dtype=dtypes.bool)) + + def get_col_update_op(): + with ops.colocate_with(processed_cols): + return state_ops.scatter_update( + processed_cols, processed_col_indices, + array_ops.ones_like(processed_col_indices, dtype=dtypes.bool)) + + update_processed_op = control_flow_ops.cond( + self._is_row_sweep_var, get_row_update_op, get_col_update_op) + + # After update_processed_op, check whether we have completed a sweep. + # If this is the case, flip the is_row_sweep_var and reset processed_rows + # and processed_cols tensors. + with ops.control_dependencies([update_processed_op]): + def get_switch_op(): + return state_ops.assign( + self._is_row_sweep_var, + gen_math_ops.logical_not(self._is_row_sweep_var)).op + + def get_reset_op(): + return control_flow_ops.group( + state_ops.assign(processed_rows, processed_rows_init).op, + state_ops.assign(processed_cols, processed_cols_init).op) + + is_sweep_done = control_flow_ops.cond( + self._is_row_sweep_var, + lambda: math_ops.reduce_all(processed_rows), + lambda: math_ops.reduce_all(processed_cols), + name="sweep_hook_is_sweep_done") + switch_op = control_flow_ops.cond( + is_sweep_done, get_switch_op, control_flow_ops.no_op, + name="sweep_hook_switch_op") + reset_op = control_flow_ops.cond( + is_sweep_done, get_reset_op, control_flow_ops.no_op, + name="sweep_hook_reset_op") + switch_ops = control_flow_ops.group(switch_op, reset_op, + name="sweep_hook_switch_ops") + + # Op to increment the global step + global_step = framework_variables.get_global_step() + with ops.control_dependencies([switch_ops]): + if global_step is not None: + global_step_incr_op = state_ops.assign_add( + global_step, 1, name="global_step_incr").op + else: + global_step_incr_op = control_flow_ops.no_op( + name="global_step_incr") + + return [is_sweep_done, switch_ops, global_step_incr_op] + + def begin(self): + pass + + def before_run(self, run_context): + """Runs the appropriate prep ops, and requests running update ops.""" + # Run the appropriate cache_init and prep ops + sess = run_context.session + if not self._is_initialized: + logging.info("SweepHook running cache init ops.") + for init_op in self._cache_init_ops: + sess.run(init_op) + + if self._is_sweep_done or not self._is_initialized: + logging.info("SweepHook running sweep prep ops.") + row_sweep = sess.run(self._is_row_sweep_var) + prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops + for prep_op in prep_ops: + sess.run(prep_op) + + self._is_initialized = True + + # Request running the switch_ops and the global_step_incr_op + logging.info("Partial fit starting.") + return session_run_hook.SessionRunArgs(fetches=self._fetches) + + def after_run(self, run_context, run_values): + self._is_sweep_done = run_values.results[0] + logging.info("Partial fit done.") + diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py new file mode 100644 index 0000000000..2ae2d3ab05 --- /dev/null +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -0,0 +1,129 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for WALSMatrixFactorization.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.factorization.python.ops import wals as wals_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import session_run_hook + + +class SweepHookTest(test.TestCase): + + def setUp(self): + self._num_rows = 5 + self._num_cols = 7 + self._train_op = control_flow_ops.no_op() + self._row_prep_done = variables.Variable(False) + self._col_prep_done = variables.Variable(False) + self._init_done = variables.Variable(False) + self._row_prep_ops = [state_ops.assign(self._row_prep_done, True)] + self._col_prep_ops = [state_ops.assign(self._col_prep_done, True)] + self._init_ops = [state_ops.assign(self._init_done, True)] + self._input_row_indices_ph = array_ops.placeholder(dtypes.int64) + self._input_col_indices_ph = array_ops.placeholder(dtypes.int64) + + def run_hook_with_indices(self, sweep_hook, row_indices, col_indices): + with self.test_session() as sess: + # Before run + run_context = session_run_hook.SessionRunContext( + original_args=None, session=sess) + sess_run_args = sweep_hook.before_run(run_context) + feed_dict = { + self._input_row_indices_ph: row_indices, + self._input_col_indices_ph: col_indices + } + # Run + run_results = sess.run(sess_run_args.fetches, feed_dict=feed_dict) + run_values = session_run_hook.SessionRunValues( + results=run_results, options=None, run_metadata=None) + # After run + sweep_hook.after_run(run_context, run_values) + + def test_row_sweep(self): + with self.test_session() as sess: + is_row_sweep_var = variables.Variable(True) + sweep_hook = wals_lib._SweepHook( + is_row_sweep_var, + self._train_op, + self._num_rows, + self._num_cols, + self._input_row_indices_ph, + self._input_col_indices_ph, + self._row_prep_ops, + self._col_prep_ops, + self._init_ops) + + # Initialize variables + sess.run([variables.global_variables_initializer()]) + # Row sweep + self.run_hook_with_indices(sweep_hook, [], []) + self.assertTrue(sess.run(self._init_done), + msg='init ops not run by the sweep_hook') + self.assertTrue(sess.run(self._row_prep_done), + msg='row_prep not run by the sweep_hook') + self.run_hook_with_indices(sweep_hook, [0, 1, 2], []) + self.assertTrue(sess.run(is_row_sweep_var), + msg='Row sweep is not complete but is_row_sweep is ' + 'False.') + self.run_hook_with_indices(sweep_hook, [3, 4], [0, 1, 2, 3, 4, 5, 6]) + self.assertFalse(sess.run(is_row_sweep_var), + msg='Row sweep is complete but is_row_sweep is True.') + self.assertTrue(sweep_hook._is_sweep_done, + msg='Sweep is complete but is_sweep_done is False.') + + def test_col_sweep(self): + with self.test_session() as sess: + is_row_sweep_var = variables.Variable(False) + sweep_hook = wals_lib._SweepHook( + is_row_sweep_var, + self._train_op, + self._num_rows, + self._num_cols, + self._input_row_indices_ph, + self._input_col_indices_ph, + self._row_prep_ops, + self._col_prep_ops, + self._init_ops) + + # Initialize variables + sess.run([variables.global_variables_initializer()]) + # Col sweep + self.run_hook_with_indices(sweep_hook, [], []) + self.assertTrue(sess.run(self._col_prep_done), + msg='col_prep not run by the sweep_hook') + self.run_hook_with_indices(sweep_hook, [], [0, 1, 2, 3, 4]) + self.assertFalse(sess.run(is_row_sweep_var), + msg='Col sweep is not complete but is_row_sweep is ' + 'True.') + self.assertFalse(sweep_hook._is_sweep_done, + msg='Sweep is not complete but is_sweep_done is True.') + self.run_hook_with_indices(sweep_hook, [], [4, 5, 6]) + self.assertTrue(sess.run(is_row_sweep_var), + msg='Col sweep is complete but is_row_sweep is False') + self.assertTrue(sweep_hook._is_sweep_done, + msg='Sweep is complete but is_sweep_done is False.') + + +if __name__ == '__main__': + test.main() |