diff options
-rw-r--r-- | tensorflow/contrib/factorization/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/wals.py | 250 | ||||
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/wals_test.py | 14 |
3 files changed, 111 insertions, 154 deletions
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index c741815042..44095bd00a 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -246,7 +246,6 @@ tf_py_test( "manual", "noasan", # times out b/63678675 "nomsan", - "notsan", ], ) diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 3e3ee5fa57..3976395d78 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -26,7 +26,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope @@ -38,31 +37,30 @@ from tensorflow.python.training import session_run_hook class _SweepHook(session_run_hook.SessionRunHook): """Keeps track of row/col sweeps, and runs prep ops before each sweep.""" - def __init__(self, is_row_sweep_var, train_op, num_rows, num_cols, - processed_row_indices, processed_col_indices, row_prep_ops, - col_prep_ops, cache_init_ops, completed_sweeps_var): + def __init__(self, is_row_sweep_var, train_ops, num_rows, num_cols, + input_row_indices, input_col_indices, row_prep_ops, + col_prep_ops, init_op, completed_sweeps_var): """Initializes SweepHook. Args: is_row_sweep_var: A Boolean tf.Variable, determines whether we are currently doing a row or column sweep. It is updated by the hook. - train_op: An op. All the ops created by the hook will have - control_dependencies on train_op. + train_ops: A list of ops. The ops created by this hook will have + control dependencies on `train_ops`. num_rows: int, the total number of rows to be processed. num_cols: int, the total number of columns to be processed. - processed_row_indices: A Tensor of type int64. The indices of the input - rows that are processed during the current sweep. All elements of - processed_row_indices must be in [0, num_rows). - processed_col_indices: A Tensor of type int64. The indices of the input + input_row_indices: A Tensor of type int64. The indices of the input rows + that are processed during the current sweep. All elements of + `input_row_indices` must be in [0, num_rows). + input_col_indices: A Tensor of type int64. The indices of the input columns that are processed during the current sweep. All elements of - processed_col_indices must be in [0, num_cols). + `input_col_indices` must be in [0, num_cols). row_prep_ops: list of ops, to be run before the beginning of each row sweep, in the given order. col_prep_ops: list of ops, to be run before the beginning of each column sweep, in the given order. - cache_init_ops: list of ops, to be run once before training, in the given - order. These are typically local initialization ops (such as cache - initialization). + init_op: op to be run once before training. This is typically a local + initialization op (such as cache initialization). completed_sweeps_var: An integer tf.Variable, indicates the number of completed sweeps. It is updated by the hook. """ @@ -70,55 +68,45 @@ class _SweepHook(session_run_hook.SessionRunHook): self._num_cols = num_cols self._row_prep_ops = row_prep_ops self._col_prep_ops = col_prep_ops - self._cache_init_ops = cache_init_ops + self._init_op = init_op self._is_row_sweep_var = is_row_sweep_var self._completed_sweeps_var = completed_sweeps_var - # Boolean variable that determines whether the cache_init_ops have been run. + # Boolean variable that determines whether the init_ops have been run. self._is_initialized = False - # Boolean variable that is set to True when a sweep is completed. - # Used to run the prep_ops at the beginning of a sweep, in before_run(). - self._is_sweep_done = False - # Ops to run jointly with train_op, responsible for updating - # _is_row_sweep_var and incrementing the global_step and completed_sweeps - # counters. They have control_dependencies on train_op. - self._fetches = self._create_switch_ops(processed_row_indices, - processed_col_indices, train_op) - - def _create_switch_ops(self, processed_row_indices, processed_col_indices, - train_op): + # Ops to run jointly with train_ops, responsible for updating + # `is_row_sweep_var` and incrementing the `global_step` and + # `completed_sweeps` counters. + self._update_op, self._is_sweep_done_var, self._switch_op = ( + self._create_hook_ops(input_row_indices, input_col_indices, train_ops)) + + def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops): """Creates ops to update is_row_sweep_var, global_step and completed_sweeps. - Creates two boolean tensors processed_rows and processed_cols, which keep - track of which rows/cols have been processed during the current sweep. + Creates two boolean tensors `processed_rows` and `processed_cols`, which + keep track of which rows/cols have been processed during the current sweep. Returns ops that should be run after each row / col update. - - When is_row_sweep_var is True, it sets - processed_rows[processed_row_indices] to True. - - When is_row_sweep_var is False, it sets - processed_cols[processed_col_indices] to True . - When all rows or all cols have been processed, negates is_row_sweep_var, - increments the completed_sweeps counter, and resets processed_rows and - processed_cols to False. - All of the ops created by this function have control_dependencies on - train_op. + - When `self._is_row_sweep_var` is True, it sets + processed_rows[input_row_indices] to True. + - When `self._is_row_sweep_var` is False, it sets + processed_cols[input_col_indices] to True. Args: - processed_row_indices: A Tensor. The indices of the input rows that are + input_row_indices: A Tensor. The indices of the input rows that are processed during the current sweep. - processed_col_indices: A Tensor. The indices of the input columns that + input_col_indices: A Tensor. The indices of the input columns that are processed during the current sweep. - train_op: An op. All the ops created by this function have - control_dependencies on train_op. + train_ops: A list of ops. The ops created by this function have control + dependencies on `train_ops`. + Returns: - A list consisting of: - is_sweep_done: A Boolean tensor, determines whether the sweep is done, - i.e. all rows (during a row sweep) or all columns (during a column - sweep) have been processed. - switch_ops: An op that updates is_row_sweep_var when is_sweep_done is - True. Has control_dependencies on train_op. - incr_ops: An op that increments the global_step and completed_sweeps - counters. Has control_dependenciens on switch_ops. + A tuple consisting of: + update_op: An op to be run jointly with training. It updates the state + and increments counters (global step and completed sweeps). + is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is + done, i.e. all rows (during a row sweep) or all columns (during a + column sweep) have been processed. + switch_op: An op to be run in `self.before_run` when the sweep is done. """ - processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False) with ops.colocate_with(processed_rows_init): processed_rows = variable_scope.variable( @@ -133,97 +121,72 @@ class _SweepHook(session_run_hook.SessionRunHook): collections=[ops.GraphKeys.GLOBAL_VARIABLES], trainable=False, name="sweep_hook_processed_cols") - # After running the train_op, update processed_rows or processed_cols - # tensors, depending on whether we are currently doing a row or a col sweep - with ops.control_dependencies([train_op]): - - def get_row_update_op(): - with ops.colocate_with(processed_rows): - return state_ops.scatter_update(processed_rows, processed_row_indices, - array_ops.ones_like( - processed_row_indices, - dtype=dtypes.bool)) - - def get_col_update_op(): - with ops.colocate_with(processed_cols): - return state_ops.scatter_update(processed_cols, processed_col_indices, - array_ops.ones_like( - processed_col_indices, - dtype=dtypes.bool)) - - update_processed_op = control_flow_ops.cond( - self._is_row_sweep_var, get_row_update_op, get_col_update_op) - - # After update_processed_op, check whether we have completed a sweep. - # If this is the case, flip the is_row_sweep_var and reset processed_rows - # and processed_cols tensors. - with ops.control_dependencies([update_processed_op]): - - def get_switch_op(): - return state_ops.assign( - self._is_row_sweep_var, - gen_math_ops.logical_not(self._is_row_sweep_var)).op - - def get_reset_op(): - return control_flow_ops.group( - state_ops.assign(processed_rows, processed_rows_init).op, - state_ops.assign(processed_cols, processed_cols_init).op) - - is_sweep_done = control_flow_ops.cond( + switch_ops = control_flow_ops.group( + state_ops.assign( self._is_row_sweep_var, - lambda: math_ops.reduce_all(processed_rows), - lambda: math_ops.reduce_all(processed_cols), - name="sweep_hook_is_sweep_done") - switch_op = control_flow_ops.cond( - is_sweep_done, - get_switch_op, - control_flow_ops.no_op, - name="sweep_hook_switch_op") - reset_op = control_flow_ops.cond( - is_sweep_done, - get_reset_op, - control_flow_ops.no_op, - name="sweep_hook_reset_op") - switch_ops = control_flow_ops.group( - switch_op, reset_op, name="sweep_hook_switch_ops") - - with ops.control_dependencies([switch_ops]): - # Op to increment the completed_sweeps counter. - completed_sweeps_incr_op = control_flow_ops.cond( - is_sweep_done, - lambda: state_ops.assign_add(self._completed_sweeps_var, 1).op, - control_flow_ops.no_op, - name="completed_sweeps_incr") - - # Op to increment the global_step counter. - global_step = framework_variables.get_global_step() - if global_step is not None: - global_step_incr_op = state_ops.assign_add( - global_step, 1, name="global_step_incr").op - else: - global_step_incr_op = control_flow_ops.no_op( - name="global_step_incr") - - incr_ops = control_flow_ops.group( - completed_sweeps_incr_op, - global_step_incr_op, - name="counter_incr_ops") - - return [is_sweep_done, switch_ops, incr_ops] + math_ops.logical_not(self._is_row_sweep_var)), + state_ops.assign(processed_rows, processed_rows_init), + state_ops.assign(processed_cols, processed_cols_init)) + is_sweep_done_var = variable_scope.variable( + False, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False, + name="is_sweep_done") + + # After running the `train_ops`, updates `processed_rows` or + # `processed_cols` tensors, depending on whether this is a row or col sweep. + with ops.control_dependencies(train_ops): + with ops.colocate_with(processed_rows): + update_processed_rows = state_ops.scatter_update( + processed_rows, + input_row_indices, + math_ops.logical_and( + self._is_row_sweep_var, + array_ops.ones_like(input_row_indices, dtype=dtypes.bool))) + with ops.colocate_with(processed_cols): + update_processed_cols = state_ops.scatter_update( + processed_cols, + input_col_indices, + math_ops.logical_and( + math_ops.logical_not(self._is_row_sweep_var), + array_ops.ones_like(input_col_indices, dtype=dtypes.bool))) + update_processed_op = control_flow_ops.group( + update_processed_rows, update_processed_cols) - def begin(self): - pass + with ops.control_dependencies([update_processed_op]): + is_sweep_done = math_ops.logical_or( + math_ops.reduce_all(processed_rows), + math_ops.reduce_all(processed_cols)) + # Increments global step. + global_step = framework_variables.get_global_step() + if global_step is not None: + global_step_incr_op = state_ops.assign_add( + global_step, 1, name="global_step_incr").op + else: + global_step_incr_op = control_flow_ops.no_op() + # Increments completed sweeps. + completed_sweeps_incr_op = state_ops.assign_add( + self._completed_sweeps_var, + math_ops.cast(is_sweep_done, dtypes.int32), + use_locking=True).op + update_ops = control_flow_ops.group( + global_step_incr_op, + completed_sweeps_incr_op, + state_ops.assign(is_sweep_done_var, is_sweep_done)) + + return update_ops, is_sweep_done_var, switch_ops def before_run(self, run_context): """Runs the appropriate prep ops, and requests running update ops.""" - # Run the appropriate cache_init and prep ops + # Runs the appropriate init ops and prep ops. sess = run_context.session + is_sweep_done = sess.run(self._is_sweep_done_var) if not self._is_initialized: - logging.info("SweepHook running cache init ops.") - for init_op in self._cache_init_ops: - sess.run(init_op) - - if self._is_sweep_done or not self._is_initialized: + logging.info("SweepHook running cache init op.") + sess.run(self._init_op) + if is_sweep_done: + sess.run(self._switch_op) + if is_sweep_done or not self._is_initialized: logging.info("SweepHook running sweep prep ops.") row_sweep = sess.run(self._is_row_sweep_var) prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops @@ -232,13 +195,12 @@ class _SweepHook(session_run_hook.SessionRunHook): self._is_initialized = True - # Request running the switch_ops and the incr_ops - logging.info("Partial fit starting.") - return session_run_hook.SessionRunArgs(fetches=self._fetches) + # Requests running `self._update_op` jointly with the training op. + logging.info("Next fit step starting.") + return session_run_hook.SessionRunArgs(fetches=[self._update_op]) def after_run(self, run_context, run_values): - self._is_sweep_done = run_values.results[0] - logging.info("Partial fit done.") + logging.info("Fit step done.") class _StopAtSweepHook(session_run_hook.SessionRunHook): @@ -360,19 +322,19 @@ def _wals_factorization_model_function(features, labels, mode, params): col_prep_ops = [ model.col_update_prep_gramian_op, model.initialize_col_update_op ] - cache_init_ops = [model.worker_init] + init_ops = [model.worker_init] sweep_hook = _SweepHook( is_row_sweep_var, - train_op, + [train_op, loss], params["num_rows"], params["num_cols"], input_row_indices, input_col_indices, row_prep_ops, col_prep_ops, - cache_init_ops, - completed_sweeps_var,) + init_ops, + completed_sweeps_var) training_hooks = [sweep_hook] if max_sweeps is not None: training_hooks.append(_StopAtSweepHook(max_sweeps)) diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index b5c1bb1151..8bd72b7025 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -357,7 +357,7 @@ class WALSMatrixFactorizationTest(test.TestCase): self.assertNear( loss, true_loss, err=.001, - msg="""After row update, eval loss = {}, does not match the true + msg="""After col update, eval loss = {}, does not match the true loss = {}.""".format(loss, true_loss)) @@ -442,7 +442,7 @@ class SweepHookTest(test.TestCase): completed_sweeps_var = variables.Variable(0) sweep_hook = wals_lib._SweepHook( is_row_sweep_var, - self._train_op, + [self._train_op], self._num_rows, self._num_cols, self._input_row_indices_ph, @@ -465,11 +465,9 @@ class SweepHookTest(test.TestCase): 'False.') # Row sweep completed. mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6])) - self.assertFalse(sess.run(is_row_sweep_var), - msg='Row sweep is complete but is_row_sweep is True.') self.assertTrue(sess.run(completed_sweeps_var) == 1, msg='Completed sweeps should be equal to 1.') - self.assertTrue(sweep_hook._is_sweep_done, + self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), msg='Sweep is complete but is_sweep_done is False.') # Col init ops should run. Col sweep not completed. mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4])) @@ -478,13 +476,11 @@ class SweepHookTest(test.TestCase): self.assertFalse(sess.run(is_row_sweep_var), msg='Col sweep is not complete but is_row_sweep is ' 'True.') - self.assertFalse(sweep_hook._is_sweep_done, + self.assertFalse(sess.run(sweep_hook._is_sweep_done_var), msg='Sweep is not complete but is_sweep_done is True.') # Col sweep completed. mon_sess.run(self._train_op, ind_feed([], [4, 5, 6])) - self.assertTrue(sess.run(is_row_sweep_var), - msg='Col sweep is complete but is_row_sweep is False') - self.assertTrue(sweep_hook._is_sweep_done, + self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), msg='Sweep is complete but is_sweep_done is False.') self.assertTrue(sess.run(completed_sweeps_var) == 2, msg='Completed sweeps should be equal to 2.') |