diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-11-14 14:44:28 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-14 14:57:41 -0800 |
commit | a5a192865e4c1732b414d6a503d07775f7163a5c (patch) | |
tree | 16179cd6b477eb09882b5fa61132427eb5280f5f /tensorflow/contrib/factorization | |
parent | b8054c19b7d72cfb7eb07552a8d0385ffd8810d7 (diff) |
Refactors the WALS estimator so that part of the control flow logic happens in the SweepHook. This fixes a bug that causes both input batches (rows and columns) to be fetched during any given sweep.
PiperOrigin-RevId: 175738242
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/wals.py | 449 | ||||
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/wals_test.py | 112 |
2 files changed, 274 insertions, 287 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 3976395d78..b2f22eb2fc 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.factorization.python.ops import factorization_ops -from tensorflow.contrib.framework.python.ops import variables as framework_variables from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.framework import dtypes @@ -32,175 +31,64 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util class _SweepHook(session_run_hook.SessionRunHook): """Keeps track of row/col sweeps, and runs prep ops before each sweep.""" - def __init__(self, is_row_sweep_var, train_ops, num_rows, num_cols, - input_row_indices, input_col_indices, row_prep_ops, - col_prep_ops, init_op, completed_sweeps_var): + def __init__(self, is_row_sweep_var, is_sweep_done_var, init_op, + row_prep_ops, col_prep_ops, row_train_op, col_train_op, + switch_op): """Initializes SweepHook. Args: is_row_sweep_var: A Boolean tf.Variable, determines whether we are currently doing a row or column sweep. It is updated by the hook. - train_ops: A list of ops. The ops created by this hook will have - control dependencies on `train_ops`. - num_rows: int, the total number of rows to be processed. - num_cols: int, the total number of columns to be processed. - input_row_indices: A Tensor of type int64. The indices of the input rows - that are processed during the current sweep. 
All elements of - `input_row_indices` must be in [0, num_rows). - input_col_indices: A Tensor of type int64. The indices of the input - columns that are processed during the current sweep. All elements of - `input_col_indices` must be in [0, num_cols). - row_prep_ops: list of ops, to be run before the beginning of each row - sweep, in the given order. - col_prep_ops: list of ops, to be run before the beginning of each column - sweep, in the given order. + is_sweep_done_var: A Boolean tf.Variable, determines whether we are + starting a new sweep (this is used to determine when to run the prep ops + below). init_op: op to be run once before training. This is typically a local initialization op (such as cache initialization). - completed_sweeps_var: An integer tf.Variable, indicates the number of - completed sweeps. It is updated by the hook. + row_prep_ops: A list of TensorFlow ops, to be run before the beginning of + each row sweep (and during initialization), in the given order. + col_prep_ops: A list of TensorFlow ops, to be run before the beginning of + each column sweep (and during initialization), in the given order. + row_train_op: A TensorFlow op to be run during row sweeps. + col_train_op: A TensorFlow op to be run during column sweeps. + switch_op: A TensorFlow op to be run before each sweep. """ - self._num_rows = num_rows - self._num_cols = num_cols + self._is_row_sweep_var = is_row_sweep_var + self._is_sweep_done_var = is_sweep_done_var + self._init_op = init_op self._row_prep_ops = row_prep_ops self._col_prep_ops = col_prep_ops - self._init_op = init_op - self._is_row_sweep_var = is_row_sweep_var - self._completed_sweeps_var = completed_sweeps_var - # Boolean variable that determines whether the init_ops have been run. + self._row_train_op = row_train_op + self._col_train_op = col_train_op + self._switch_op = switch_op + # Boolean variable that determines whether the init_op has been run. 
self._is_initialized = False - # Ops to run jointly with train_ops, responsible for updating - # `is_row_sweep_var` and incrementing the `global_step` and - # `completed_sweeps` counters. - self._update_op, self._is_sweep_done_var, self._switch_op = ( - self._create_hook_ops(input_row_indices, input_col_indices, train_ops)) - - def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops): - """Creates ops to update is_row_sweep_var, global_step and completed_sweeps. - - Creates two boolean tensors `processed_rows` and `processed_cols`, which - keep track of which rows/cols have been processed during the current sweep. - Returns ops that should be run after each row / col update. - - When `self._is_row_sweep_var` is True, it sets - processed_rows[input_row_indices] to True. - - When `self._is_row_sweep_var` is False, it sets - processed_cols[input_col_indices] to True. - - Args: - input_row_indices: A Tensor. The indices of the input rows that are - processed during the current sweep. - input_col_indices: A Tensor. The indices of the input columns that - are processed during the current sweep. - train_ops: A list of ops. The ops created by this function have control - dependencies on `train_ops`. - - Returns: - A tuple consisting of: - update_op: An op to be run jointly with training. It updates the state - and increments counters (global step and completed sweeps). - is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is - done, i.e. all rows (during a row sweep) or all columns (during a - column sweep) have been processed. - switch_op: An op to be run in `self.before_run` when the sweep is done. 
- """ - processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False) - with ops.colocate_with(processed_rows_init): - processed_rows = variable_scope.variable( - processed_rows_init, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="sweep_hook_processed_rows") - processed_cols_init = array_ops.fill(dims=[self._num_cols], value=False) - with ops.colocate_with(processed_cols_init): - processed_cols = variable_scope.variable( - processed_cols_init, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="sweep_hook_processed_cols") - switch_ops = control_flow_ops.group( - state_ops.assign( - self._is_row_sweep_var, - math_ops.logical_not(self._is_row_sweep_var)), - state_ops.assign(processed_rows, processed_rows_init), - state_ops.assign(processed_cols, processed_cols_init)) - is_sweep_done_var = variable_scope.variable( - False, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="is_sweep_done") - - # After running the `train_ops`, updates `processed_rows` or - # `processed_cols` tensors, depending on whether this is a row or col sweep. - with ops.control_dependencies(train_ops): - with ops.colocate_with(processed_rows): - update_processed_rows = state_ops.scatter_update( - processed_rows, - input_row_indices, - math_ops.logical_and( - self._is_row_sweep_var, - array_ops.ones_like(input_row_indices, dtype=dtypes.bool))) - with ops.colocate_with(processed_cols): - update_processed_cols = state_ops.scatter_update( - processed_cols, - input_col_indices, - math_ops.logical_and( - math_ops.logical_not(self._is_row_sweep_var), - array_ops.ones_like(input_col_indices, dtype=dtypes.bool))) - update_processed_op = control_flow_ops.group( - update_processed_rows, update_processed_cols) - - with ops.control_dependencies([update_processed_op]): - is_sweep_done = math_ops.logical_or( - math_ops.reduce_all(processed_rows), - math_ops.reduce_all(processed_cols)) - # Increments global step. 
- global_step = framework_variables.get_global_step() - if global_step is not None: - global_step_incr_op = state_ops.assign_add( - global_step, 1, name="global_step_incr").op - else: - global_step_incr_op = control_flow_ops.no_op() - # Increments completed sweeps. - completed_sweeps_incr_op = state_ops.assign_add( - self._completed_sweeps_var, - math_ops.cast(is_sweep_done, dtypes.int32), - use_locking=True).op - update_ops = control_flow_ops.group( - global_step_incr_op, - completed_sweeps_incr_op, - state_ops.assign(is_sweep_done_var, is_sweep_done)) - - return update_ops, is_sweep_done_var, switch_ops def before_run(self, run_context): """Runs the appropriate prep ops, and requests running update ops.""" - # Runs the appropriate init ops and prep ops. sess = run_context.session is_sweep_done = sess.run(self._is_sweep_done_var) if not self._is_initialized: - logging.info("SweepHook running cache init op.") + logging.info("SweepHook running init op.") sess.run(self._init_op) if is_sweep_done: sess.run(self._switch_op) + is_row_sweep = sess.run(self._is_row_sweep_var) if is_sweep_done or not self._is_initialized: - logging.info("SweepHook running sweep prep ops.") - row_sweep = sess.run(self._is_row_sweep_var) - prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops + logging.info("SweepHook running prep ops for the {} sweep.".format( + "row" if is_row_sweep else "col")) + prep_ops = self._row_prep_ops if is_row_sweep else self._col_prep_ops for prep_op in prep_ops: sess.run(prep_op) - self._is_initialized = True - - # Requests running `self._update_op` jointly with the training op. 
logging.info("Next fit step starting.") - return session_run_hook.SessionRunArgs(fetches=[self._update_op]) - - def after_run(self, run_context, run_values): - logging.info("Fit step done.") + return session_run_hook.SessionRunArgs( + fetches=[self._row_train_op if is_row_sweep else self._col_train_op]) class _StopAtSweepHook(session_run_hook.SessionRunHook): @@ -246,6 +134,9 @@ def _wals_factorization_model_function(features, labels, mode, params): Returns: A ModelFnOps object. + + Raises: + ValueError: If `mode` is not recognized. """ assert labels is None use_factors_weights_cache = (params["use_factors_weights_cache_for_training"] @@ -269,86 +160,156 @@ def _wals_factorization_model_function(features, labels, mode, params): use_gramian_cache=use_gramian_cache) # Get input rows and cols. We either update rows or columns depending on - # the value of row_sweep, which is maintained using a session hook + # the value of row_sweep, which is maintained using a session hook. input_rows = features[WALSMatrixFactorization.INPUT_ROWS] input_cols = features[WALSMatrixFactorization.INPUT_COLS] - input_row_indices, _ = array_ops.unique(input_rows.indices[:, 0]) - input_col_indices, _ = array_ops.unique(input_cols.indices[:, 0]) - - # Train ops, controlled using the SweepHook - # We need to run the following ops: - # Before a row sweep: - # row_update_prep_gramian_op - # initialize_row_update_op - # During a row sweep: - # update_row_factors_op - # Before a col sweep: - # col_update_prep_gramian_op - # initialize_col_update_op - # During a col sweep: - # update_col_factors_op - - is_row_sweep_var = variable_scope.variable( - True, - trainable=False, - name="is_row_sweep", - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - completed_sweeps_var = variable_scope.variable( - 0, - trainable=False, - name=WALSMatrixFactorization.COMPLETED_SWEEPS, - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - - # The row sweep is determined by is_row_sweep_var (controlled by the - # sweep_hook) 
in TRAIN mode, and manually in EVAL mode. - is_row_sweep = (features[WALSMatrixFactorization.PROJECT_ROW] - if mode == model_fn.ModeKeys.EVAL else is_row_sweep_var) - - def update_row_factors(): - return model.update_row_factors(sp_input=input_rows, transpose_input=False) - - def update_col_factors(): - return model.update_col_factors(sp_input=input_cols, transpose_input=True) - - (_, train_op, - unregularized_loss, regularization, sum_weights) = control_flow_ops.cond( - is_row_sweep, update_row_factors, update_col_factors) - loss = unregularized_loss + regularization - root_weighted_squared_error = math_ops.sqrt(unregularized_loss / sum_weights) - - row_prep_ops = [ - model.row_update_prep_gramian_op, model.initialize_row_update_op - ] - col_prep_ops = [ - model.col_update_prep_gramian_op, model.initialize_col_update_op - ] - init_ops = [model.worker_init] - - sweep_hook = _SweepHook( - is_row_sweep_var, - [train_op, loss], - params["num_rows"], - params["num_cols"], - input_row_indices, - input_col_indices, - row_prep_ops, - col_prep_ops, - init_ops, - completed_sweeps_var) - training_hooks = [sweep_hook] - if max_sweeps is not None: - training_hooks.append(_StopAtSweepHook(max_sweeps)) - - # The root weighted squared error = - # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij ) - summary.scalar("loss", loss) # the estimated total training loss - summary.scalar("root_weighted_squared_error", root_weighted_squared_error) - summary.scalar("completed_sweeps", completed_sweeps_var) - - # Prediction ops (only return predictions in INFER mode) - predictions = {} - if mode == model_fn.ModeKeys.INFER: - project_row = features[WALSMatrixFactorization.PROJECT_ROW] + + # TRAIN mode: + if mode == model_fn.ModeKeys.TRAIN: + # Training consists of the following ops (controlled using a SweepHook). 
+ # Before a row sweep: + # row_update_prep_gramian_op + # initialize_row_update_op + # During a row sweep: + # update_row_factors_op + # Before a col sweep: + # col_update_prep_gramian_op + # initialize_col_update_op + # During a col sweep: + # update_col_factors_op + + is_row_sweep_var = variable_scope.variable( + True, + trainable=False, + name="is_row_sweep", + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + is_sweep_done_var = variable_scope.variable( + False, + trainable=False, + name="is_sweep_done", + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + completed_sweeps_var = variable_scope.variable( + 0, + trainable=False, + name=WALSMatrixFactorization.COMPLETED_SWEEPS, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + loss_var = variable_scope.variable( + 0., + trainable=False, + name=WALSMatrixFactorization.LOSS, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + # The root weighted squared error = + # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij ) + rwse_var = variable_scope.variable( + 0., + trainable=False, + name=WALSMatrixFactorization.RWSE, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + + summary.scalar("loss", loss_var) + summary.scalar("root_weighted_squared_error", rwse_var) + summary.scalar("completed_sweeps", completed_sweeps_var) + + # Increments global step. + global_step = training_util.get_global_step() + if global_step: + global_step_incr_op = state_ops.assign_add( + global_step, 1, name="global_step_incr").op + else: + global_step_incr_op = control_flow_ops.no_op() + + def create_axis_ops(sp_input, num_items, update_fn, axis_name): + """Creates book-keeping and training ops for a given axis. + + Args: + sp_input: A SparseTensor corresponding to the row or column batch. + num_items: An integer, the total number of items of this axis. + update_fn: A function that takes one argument (`sp_input`), and that + returns a tuple of + * new_factors: A float Tensor of the factor values after update. 
+ * update_op: a TensorFlow op which updates the factors. + * loss: A float Tensor, the unregularized loss. + * reg_loss: A float Tensor, the regularization loss. + * sum_weights: A float Tensor, the sum of factor weights. + axis_name: A string that specifies the name of the axis. + + Returns: + A tuple consisting of: + * reset_processed_items_op: A TensorFlow op, to be run before the + beginning of any sweep. It marks all items as not-processed. + * axis_train_op: A Tensorflow op, to be run during this axis' sweeps. + """ + processed_items_init = array_ops.fill(dims=[num_items], value=False) + with ops.colocate_with(processed_items_init): + processed_items = variable_scope.variable( + processed_items_init, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False, + name="processed_" + axis_name) + reset_processed_items_op = state_ops.assign( + processed_items, processed_items_init, + name="reset_processed_" + axis_name) + _, update_op, loss, reg, sum_weights = update_fn(sp_input) + input_indices = sp_input.indices[:, 0] + with ops.control_dependencies([ + update_op, + state_ops.assign(loss_var, loss + reg), + state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))]): + with ops.colocate_with(processed_items): + update_processed_items = state_ops.scatter_update( + processed_items, + input_indices, + array_ops.ones_like(input_indices, dtype=dtypes.bool), + name="update_processed_{}_indices".format(axis_name)) + with ops.control_dependencies([update_processed_items]): + is_sweep_done = math_ops.reduce_all(processed_items) + axis_train_op = control_flow_ops.group( + global_step_incr_op, + state_ops.assign(is_sweep_done_var, is_sweep_done), + state_ops.assign_add( + completed_sweeps_var, + math_ops.cast(is_sweep_done, dtypes.int32)), + name="{}_sweep_train_op".format(axis_name)) + return reset_processed_items_op, axis_train_op + + reset_processed_rows_op, row_train_op = create_axis_ops( + input_rows, + params["num_rows"], + lambda x: 
model.update_row_factors(sp_input=x, transpose_input=False), + "rows") + reset_processed_cols_op, col_train_op = create_axis_ops( + input_cols, + params["num_cols"], + lambda x: model.update_col_factors(sp_input=x, transpose_input=True), + "cols") + switch_op = control_flow_ops.group( + state_ops.assign( + is_row_sweep_var, math_ops.logical_not(is_row_sweep_var)), + reset_processed_rows_op, + reset_processed_cols_op, + name="sweep_switch_op") + row_prep_ops = [ + model.row_update_prep_gramian_op, model.initialize_row_update_op] + col_prep_ops = [ + model.col_update_prep_gramian_op, model.initialize_col_update_op] + init_op = model.worker_init + sweep_hook = _SweepHook( + is_row_sweep_var, is_sweep_done_var, init_op, + row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op) + training_hooks = [sweep_hook] + if max_sweeps is not None: + training_hooks.append(_StopAtSweepHook(max_sweeps)) + + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.TRAIN, + predictions={}, + loss=loss_var, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=training_hooks) + + # INFER mode + elif mode == model_fn.ModeKeys.INFER: projection_weights = features.get( WALSMatrixFactorization.PROJECTION_WEIGHTS) @@ -364,17 +325,45 @@ def _wals_factorization_model_function(features, labels, mode, params): projection_weights=projection_weights, transpose_input=True) - predictions[WALSMatrixFactorization.PROJECTION_RESULT] = ( - control_flow_ops.cond(project_row, get_row_projection, - get_col_projection)) + predictions = { + WALSMatrixFactorization.PROJECTION_RESULT: control_flow_ops.cond( + features[WALSMatrixFactorization.PROJECT_ROW], + get_row_projection, + get_col_projection) + } - return model_fn.ModelFnOps( - mode=mode, - predictions=predictions, - loss=loss, - eval_metric_ops={}, - train_op=train_op, - training_hooks=training_hooks) + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.INFER, + predictions=predictions, + loss=None, + eval_metric_ops={}, 
+ train_op=control_flow_ops.no_op(), + training_hooks=[]) + + # EVAL mode + elif mode == model_fn.ModeKeys.EVAL: + def get_row_loss(): + _, _, loss, reg, _ = model.update_row_factors( + sp_input=input_rows, transpose_input=False) + return loss + reg + def get_col_loss(): + _, _, loss, reg, _ = model.update_col_factors( + sp_input=input_cols, transpose_input=True) + return loss + reg + loss = control_flow_ops.cond( + features[WALSMatrixFactorization.PROJECT_ROW], + get_row_loss, + get_col_loss) + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.EVAL, + predictions={}, + loss=loss, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=[]) + + else: + raise ValueError("mode=%s is not recognized." % str(mode)) class WALSMatrixFactorization(estimator.Estimator): @@ -452,6 +441,10 @@ class WALSMatrixFactorization(estimator.Estimator): PROJECTION_RESULT = "projection" # Name of the completed_sweeps variable COMPLETED_SWEEPS = "completed_sweeps" + # Name of the loss variable + LOSS = "WALS_loss" + # Name of the Root Weighted Squared Error variable + RWSE = "WALS_RWSE" def __init__(self, num_rows, diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index 8bd72b7025..36b483c6d7 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -417,73 +417,67 @@ class WALSMatrixFactorizationUnsupportedTest(test.TestCase): class SweepHookTest(test.TestCase): - def setUp(self): - self._num_rows = 5 - self._num_cols = 7 - self._train_op = control_flow_ops.no_op() - self._row_prep_done = variables.Variable(False) - self._col_prep_done = variables.Variable(False) - self._init_done = variables.Variable(False) - self._row_prep_ops = [state_ops.assign(self._row_prep_done, True)] - self._col_prep_ops = [state_ops.assign(self._col_prep_done, True)] - self._init_ops = [state_ops.assign(self._init_done, True)] - 
self._input_row_indices_ph = array_ops.placeholder(dtypes.int64) - self._input_col_indices_ph = array_ops.placeholder(dtypes.int64) - def test_sweeps(self): - def ind_feed(row_indices, col_indices): - return { - self._input_row_indices_ph: row_indices, - self._input_col_indices_ph: col_indices - } + is_row_sweep_var = variables.Variable(True) + is_sweep_done_var = variables.Variable(False) + init_done = variables.Variable(False) + row_prep_done = variables.Variable(False) + col_prep_done = variables.Variable(False) + row_train_done = variables.Variable(False) + col_train_done = variables.Variable(False) + + init_op = state_ops.assign(init_done, True) + row_prep_op = state_ops.assign(row_prep_done, True) + col_prep_op = state_ops.assign(col_prep_done, True) + row_train_op = state_ops.assign(row_train_done, True) + col_train_op = state_ops.assign(col_train_done, True) + train_op = control_flow_ops.no_op() + switch_op = control_flow_ops.group( + state_ops.assign(is_sweep_done_var, False), + state_ops.assign(is_row_sweep_var, + math_ops.logical_not(is_row_sweep_var))) + mark_sweep_done = state_ops.assign(is_sweep_done_var, True) with self.test_session() as sess: - is_row_sweep_var = variables.Variable(True) - completed_sweeps_var = variables.Variable(0) sweep_hook = wals_lib._SweepHook( is_row_sweep_var, - [self._train_op], - self._num_rows, - self._num_cols, - self._input_row_indices_ph, - self._input_col_indices_ph, - self._row_prep_ops, - self._col_prep_ops, - self._init_ops, - completed_sweeps_var) + is_sweep_done_var, + init_op, + [row_prep_op], + [col_prep_op], + row_train_op, + col_train_op, + switch_op) mon_sess = monitored_session._HookedSession(sess, [sweep_hook]) sess.run([variables.global_variables_initializer()]) - # Init ops should run before the first run. Row sweep not completed. 
- mon_sess.run(self._train_op, ind_feed([0, 1, 2], [])) - self.assertTrue(sess.run(self._init_done), - msg='init ops not run by the sweep_hook') - self.assertTrue(sess.run(self._row_prep_done), - msg='row_prep not run by the sweep_hook') - self.assertTrue(sess.run(is_row_sweep_var), - msg='Row sweep is not complete but is_row_sweep is ' - 'False.') - # Row sweep completed. - mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6])) - self.assertTrue(sess.run(completed_sweeps_var) == 1, - msg='Completed sweeps should be equal to 1.') - self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is complete but is_sweep_done is False.') - # Col init ops should run. Col sweep not completed. - mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4])) - self.assertTrue(sess.run(self._col_prep_done), - msg='col_prep not run by the sweep_hook') - self.assertFalse(sess.run(is_row_sweep_var), - msg='Col sweep is not complete but is_row_sweep is ' - 'True.') - self.assertFalse(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is not complete but is_sweep_done is True.') - # Col sweep completed. - mon_sess.run(self._train_op, ind_feed([], [4, 5, 6])) - self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is complete but is_sweep_done is False.') - self.assertTrue(sess.run(completed_sweeps_var) == 2, - msg='Completed sweeps should be equal to 2.') + # Row sweep. + mon_sess.run(train_op) + self.assertTrue(sess.run(init_done), + msg='init op not run by the Sweephook') + self.assertTrue(sess.run(row_prep_done), + msg='row_prep_op not run by the SweepHook') + self.assertTrue(sess.run(row_train_done), + msg='row_train_op not run by the SweepHook') + self.assertTrue( + sess.run(is_row_sweep_var), + msg='Row sweep is not complete but is_row_sweep_var is False.') + # Col sweep. 
+ mon_sess.run(mark_sweep_done) + mon_sess.run(train_op) + self.assertTrue(sess.run(col_prep_done), + msg='col_prep_op not run by the SweepHook') + self.assertTrue(sess.run(col_train_done), + msg='col_train_op not run by the SweepHook') + self.assertFalse( + sess.run(is_row_sweep_var), + msg='Col sweep is not complete but is_row_sweep_var is True.') + # Row sweep. + mon_sess.run(mark_sweep_done) + mon_sess.run(train_op) + self.assertTrue( + sess.run(is_row_sweep_var), + msg='Col sweep is complete but is_row_sweep_var is False.') class StopAtSweepHookTest(test.TestCase): |