From dc81a24204fd93f2dcc90e11c28566b9567767f2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 6 Jun 2017 10:19:56 -0700 Subject: Updates to the WALSMatrixFactorization estimator: - Add a completed_sweeps variable to keep track of sweeps that have been completed during training. - Add a StopAtSweepHook, which can request a stop after completing a specified number of sweeps. PiperOrigin-RevId: 158156347 --- .../contrib/factorization/python/ops/wals.py | 103 ++++++++++++++++++--- .../contrib/factorization/python/ops/wals_test.py | 46 ++++++++- 2 files changed, 133 insertions(+), 16 deletions(-) (limited to 'tensorflow/contrib/factorization') diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 2f7bf48041..0bc0ef39ec 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook @@ -46,7 +47,8 @@ class _SweepHook(session_run_hook.SessionRunHook): processed_col_indices, row_prep_ops, col_prep_ops, - cache_init_ops): + cache_init_ops, + completed_sweeps_var): """Initializes SweepHook. Args: @@ -69,22 +71,24 @@ class _SweepHook(session_run_hook.SessionRunHook): cache_init_ops: list of ops, to be run once before training, in the given order. These are typically local initialization ops (such as cache initialization). + completed_sweeps_var: An integer tf.Variable, indicates the number of + completed sweeps. It is updated by the hook. """ - # TODO(walidk): Provide a counter for the number of completed sweeps. 
self._num_rows = num_rows self._num_cols = num_cols self._row_prep_ops = row_prep_ops self._col_prep_ops = col_prep_ops self._cache_init_ops = cache_init_ops self._is_row_sweep_var = is_row_sweep_var + self._completed_sweeps_var = completed_sweeps_var # Boolean variable that determines whether the cache_init_ops have been run. self._is_initialized = False # Boolean variable that is set to True when a sweep is completed. # Used to run the prep_ops at the beginning of a sweep, in before_run(). self._is_sweep_done = False # Ops to run jointly with train_op, responsible for updating - # _is_row_sweep_var and incrementing the global_step counter. They have - # control_dependencies on train_op. + # _is_row_sweep_var and incrementing the global_step and completed_sweeps + # counters. They have control_dependencies on train_op. self._fetches = self._create_switch_ops(processed_row_indices, processed_col_indices, train_op) @@ -93,7 +97,7 @@ class _SweepHook(session_run_hook.SessionRunHook): processed_row_indices, processed_col_indices, train_op): - """Creates ops to update is_row_sweep_var and to increment global_step. + """Creates ops to update is_row_sweep_var, global_step and completed_sweeps. Creates two boolean tensors processed_rows and processed_cols, which keep track of which rows/cols have been processed during the current sweep. @@ -102,8 +106,9 @@ class _SweepHook(session_run_hook.SessionRunHook): processed_rows[processed_row_indices] to True. - When is_row_sweep_var is False, it sets processed_cols[processed_col_indices] to True . - When all rows or all cols have been processed, negates is_row_sweep_var and - resets processed_rows and processed_cols to False. + When all rows or all cols have been processed, negates is_row_sweep_var, + increments the completed_sweeps counter, and resets processed_rows and + processed_cols to False. All of the ops created by this function have control_dependencies on train_op. 
@@ -121,9 +126,10 @@ class _SweepHook(session_run_hook.SessionRunHook): sweep) have been processed. switch_ops: An op that updates is_row_sweep_var when is_sweep_done is True. Has control_dependencies on train_op. - global_step_incr_op: An op that increments the global_step counter. Has - control_dependenciens on switch_ops. + incr_ops: An op that increments the global_step and completed_sweeps + counters. Has control_dependencies on switch_ops. """ + processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False) with ops.colocate_with(processed_rows_init): processed_rows = variable_scope.variable( @@ -184,9 +190,16 @@ class _SweepHook(session_run_hook.SessionRunHook): switch_ops = control_flow_ops.group(switch_op, reset_op, name="sweep_hook_switch_ops") - # Op to increment the global step - global_step = framework_variables.get_global_step() with ops.control_dependencies([switch_ops]): + # Op to increment the completed_sweeps counter. + completed_sweeps_incr_op = control_flow_ops.cond( + is_sweep_done, + lambda: state_ops.assign_add(self._completed_sweeps_var, 1).op, + control_flow_ops.no_op, + name="completed_sweeps_incr") + + # Op to increment the global_step counter. 
+ global_step = framework_variables.get_global_step() if global_step is not None: global_step_incr_op = state_ops.assign_add( global_step, 1, name="global_step_incr").op @@ -194,7 +207,11 @@ class _SweepHook(session_run_hook.SessionRunHook): global_step_incr_op = control_flow_ops.no_op( name="global_step_incr") - return [is_sweep_done, switch_ops, global_step_incr_op] + incr_ops = control_flow_ops.group( + completed_sweeps_incr_op, global_step_incr_op, + name="counter_incr_ops") + + return [is_sweep_done, switch_ops, incr_ops] def begin(self): pass @@ -217,7 +234,7 @@ class _SweepHook(session_run_hook.SessionRunHook): self._is_initialized = True - # Request running the switch_ops and the global_step_incr_op + # Request running the switch_ops and the incr_ops logging.info("Partial fit starting.") return session_run_hook.SessionRunArgs(fetches=self._fetches) @@ -226,6 +243,37 @@ class _SweepHook(session_run_hook.SessionRunHook): logging.info("Partial fit done.") +class _StopAtSweepHook(session_run_hook.SessionRunHook): + """Hook that requests stop at a given sweep.""" + + def __init__(self, last_sweep): + """Initializes a `StopAtSweepHook`. + + This hook requests stop at a given sweep. Relies on the tensor named + COMPLETED_SWEEPS in the default graph. + + Args: + last_sweep: Integer, number of the last sweep to run. 
+ """ + self._last_sweep = last_sweep + + def begin(self): + try: + self._completed_sweeps_var = ops.get_default_graph().get_tensor_by_name( + WALSMatrixFactorization.COMPLETED_SWEEPS+":0") + except KeyError: + raise RuntimeError(WALSMatrixFactorization.COMPLETED_SWEEPS + + " counter should be created to use StopAtSweepHook.") + + def before_run(self, run_context): + return session_run_hook.SessionRunArgs(self._completed_sweeps_var) + + def after_run(self, run_context, run_values): + completed_sweeps = run_values.results + if completed_sweeps >= self._last_sweep: + run_context.request_stop() + + def _wals_factorization_model_function(features, labels, mode, params): """Model function for the WALSFactorization estimator. @@ -246,6 +294,7 @@ def _wals_factorization_model_function(features, labels, mode, params): use_gramian_cache = ( params["use_gramian_cache_for_training"] and mode == model_fn.ModeKeys.TRAIN) + max_sweeps = params["max_sweeps"] model = factorization_ops.WALSModel( params["num_rows"], params["num_cols"], @@ -282,8 +331,16 @@ def _wals_factorization_model_function(features, labels, mode, params): # update_col_factors_op is_row_sweep_var = variable_scope.variable( - True, "is_row_sweep", + True, + trainable=False, + name="is_row_sweep", + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + completed_sweeps_var = variable_scope.variable( + 0, + trainable=False, + name=WALSMatrixFactorization.COMPLETED_SWEEPS, collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + # The row sweep is determined by is_row_sweep_var (controlled by the # sweep_hook) in TRAIN mode, and manually in EVAL mode. 
is_row_sweep = (features[WALSMatrixFactorization.PROJECT_ROW] @@ -312,7 +369,14 @@ def _wals_factorization_model_function(features, labels, mode, params): row_prep_ops, col_prep_ops, cache_init_ops, + completed_sweeps_var, ) + training_hooks = [sweep_hook] + if max_sweeps is not None: + training_hooks.append(_StopAtSweepHook(max_sweeps)) + + summary.scalar("loss", loss) + summary.scalar("completed_sweeps", completed_sweeps_var) # Prediction ops (only return predictions in INFER mode) predictions = {} @@ -341,7 +405,7 @@ def _wals_factorization_model_function(features, labels, mode, params): loss=loss, eval_metric_ops={}, train_op=train_op, - training_hooks=[sweep_hook]) + training_hooks=training_hooks) class WALSMatrixFactorization(estimator.Estimator): @@ -417,6 +481,8 @@ class WALSMatrixFactorization(estimator.Estimator): PROJECTION_WEIGHTS = "projection_weights" # Predictions key PROJECTION_RESULT = "projection" + # Name of the completed_sweeps variable + COMPLETED_SWEEPS = "completed_sweeps" def __init__(self, num_rows, @@ -432,6 +498,7 @@ class WALSMatrixFactorization(estimator.Estimator): col_weights=1, use_factors_weights_cache_for_training=True, use_gramian_cache_for_training=True, + max_sweeps=None, model_dir=None, config=None): """Creates a model for matrix factorization using the WALS method. @@ -471,6 +538,11 @@ class WALSMatrixFactorization(estimator.Estimator): use_gramian_cache_for_training: Boolean, whether the Gramians will be cached on the workers before the updates start, during training. Defaults to True. Note that caching is disabled during prediction. + max_sweeps: integer, optional. Specifies the number of sweeps for which + to train the model, where a sweep is defined as a full update of all the + row factors (resp. column factors). + If `steps` or `max_steps` is also specified in model.fit(), training + stops when either of the steps condition or sweeps condition is met. model_dir: The directory to save the model results and log files. 
config: A Configuration object. See Estimator. @@ -495,6 +567,7 @@ class WALSMatrixFactorization(estimator.Estimator): "num_col_shards": num_col_shards, "row_weights": row_weights, "col_weights": col_weights, + "max_sweeps": max_sweeps, "use_factors_weights_cache_for_training": use_factors_weights_cache_for_training, "use_gramian_cache_for_training": use_gramian_cache_for_training diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index 25e3ee8bae..b5c1bb1151 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -197,6 +197,10 @@ class WALSMatrixFactorizationTest(test.TestCase): def use_cache(self): return False + @property + def max_sweeps(self): + return None + def setUp(self): self._num_rows = 5 self._num_cols = 7 @@ -245,6 +249,7 @@ class WALSMatrixFactorizationTest(test.TestCase): num_col_shards=self._num_col_shards, row_weights=self._row_weights, col_weights=self._col_weights, + max_sweeps=self.max_sweeps, use_factors_weights_cache_for_training=self.use_cache, use_gramian_cache_for_training=self.use_cache) @@ -356,6 +361,19 @@ class WALSMatrixFactorizationTest(test.TestCase): loss = {}.""".format(loss, true_loss)) +class WALSMatrixFactorizationTestSweeps(WALSMatrixFactorizationTest): + + @property + def max_sweeps(self): + return 2 + + # We set the column steps to None so that we rely only on max_sweeps to stop + # training. 
+ @property + def col_steps(self): + return None + + class WALSMatrixFactorizationTestCached(WALSMatrixFactorizationTest): @property @@ -421,6 +439,7 @@ class SweepHookTest(test.TestCase): with self.test_session() as sess: is_row_sweep_var = variables.Variable(True) + completed_sweeps_var = variables.Variable(0) sweep_hook = wals_lib._SweepHook( is_row_sweep_var, self._train_op, @@ -430,7 +449,8 @@ class SweepHookTest(test.TestCase): self._input_col_indices_ph, self._row_prep_ops, self._col_prep_ops, - self._init_ops) + self._init_ops, + completed_sweeps_var) mon_sess = monitored_session._HookedSession(sess, [sweep_hook]) sess.run([variables.global_variables_initializer()]) @@ -447,6 +467,8 @@ class SweepHookTest(test.TestCase): mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6])) self.assertFalse(sess.run(is_row_sweep_var), msg='Row sweep is complete but is_row_sweep is True.') + self.assertTrue(sess.run(completed_sweeps_var) == 1, + msg='Completed sweeps should be equal to 1.') self.assertTrue(sweep_hook._is_sweep_done, msg='Sweep is complete but is_sweep_done is False.') # Col init ops should run. Col sweep not completed. @@ -464,6 +486,28 @@ class SweepHookTest(test.TestCase): msg='Col sweep is complete but is_row_sweep is False') self.assertTrue(sweep_hook._is_sweep_done, msg='Sweep is complete but is_sweep_done is False.') + self.assertTrue(sess.run(completed_sweeps_var) == 2, + msg='Completed sweeps should be equal to 2.') + + +class StopAtSweepHookTest(test.TestCase): + + def test_stop(self): + hook = wals_lib._StopAtSweepHook(last_sweep=10) + completed_sweeps = variables.Variable( + 8, name=wals_lib.WALSMatrixFactorization.COMPLETED_SWEEPS) + train_op = state_ops.assign_add(completed_sweeps, 1) + hook.begin() + + with self.test_session() as sess: + sess.run([variables.global_variables_initializer()]) + mon_sess = monitored_session._HookedSession(sess, [hook]) + mon_sess.run(train_op) + # completed_sweeps is 9 after running train_op. 
+ self.assertFalse(mon_sess.should_stop()) + mon_sess.run(train_op) + # completed_sweeps is 10 after running train_op. + self.assertTrue(mon_sess.should_stop()) if __name__ == '__main__': -- cgit v1.2.3