diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-11-17 21:35:52 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-17 21:41:13 -0800 |
commit | 091291b70b567a37d33bf03b71bee9715e7a80bb (patch) | |
tree | 0fb136a802c69bb7d0506493063a7139af14ca5b /tensorflow/contrib/factorization | |
parent | 9fd424e4871e5a60b6e1985d70c31960a2df80d8 (diff) |
In WALSMatrixFactorization: moves the op that updates the global_step to a session run hook.
PiperOrigin-RevId: 176200549
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r-- | tensorflow/contrib/factorization/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/wals.py | 34 |
2 files changed, 20 insertions, 15 deletions
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 29a0a4221a..fe86a20ab1 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -270,7 +270,6 @@ tf_py_test( "manual", "noasan", # times out b/63678675 "nomsan", - "notsan", # b/69374301 ], ) diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index b2f22eb2fc..2bde3e0dd7 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -77,6 +77,7 @@ class _SweepHook(session_run_hook.SessionRunHook): logging.info("SweepHook running init op.") sess.run(self._init_op) if is_sweep_done: + logging.info("SweepHook starting the next sweep.") sess.run(self._switch_op) is_row_sweep = sess.run(self._is_row_sweep_var) if is_sweep_done or not self._is_initialized: @@ -91,6 +92,22 @@ class _SweepHook(session_run_hook.SessionRunHook): fetches=[self._row_train_op if is_row_sweep else self._col_train_op]) +class _IncrementGlobalStepHook(session_run_hook.SessionRunHook): + """Hook that increments the global step.""" + + def __init__(self): + global_step = training_util.get_global_step() + if global_step: + self._global_step_incr_op = state_ops.assign_add( + global_step, 1, name="global_step_incr").op + else: + self._global_step_incr_op = None + + def before_run(self, run_context): + if self._global_step_incr_op: + run_context.session.run(self._global_step_incr_op) + + class _StopAtSweepHook(session_run_hook.SessionRunHook): """Hook that requests stop at a given sweep.""" @@ -210,14 +227,6 @@ def _wals_factorization_model_function(features, labels, mode, params): summary.scalar("root_weighted_squared_error", rwse_var) summary.scalar("completed_sweeps", completed_sweeps_var) - # Increments global step. - global_step = training_util.get_global_step() - if global_step: - global_step_incr_op = state_ops.assign_add( - global_step, 1, name="global_step_incr").op - else: - global_step_incr_op = control_flow_ops.no_op() - def create_axis_ops(sp_input, num_items, update_fn, axis_name): """Creates book-keeping and training ops for a given axis. @@ -246,9 +255,6 @@ def _wals_factorization_model_function(features, labels, mode, params): collections=[ops.GraphKeys.GLOBAL_VARIABLES], trainable=False, name="processed_" + axis_name) - reset_processed_items_op = state_ops.assign( - processed_items, processed_items_init, - name="reset_processed_" + axis_name) _, update_op, loss, reg, sum_weights = update_fn(sp_input) input_indices = sp_input.indices[:, 0] with ops.control_dependencies([ @@ -264,13 +270,12 @@ def _wals_factorization_model_function(features, labels, mode, params): with ops.control_dependencies([update_processed_items]): is_sweep_done = math_ops.reduce_all(processed_items) axis_train_op = control_flow_ops.group( - global_step_incr_op, state_ops.assign(is_sweep_done_var, is_sweep_done), state_ops.assign_add( completed_sweeps_var, math_ops.cast(is_sweep_done, dtypes.int32)), name="{}_sweep_train_op".format(axis_name)) - return reset_processed_items_op, axis_train_op + return processed_items.initializer, axis_train_op reset_processed_rows_op, row_train_op = create_axis_ops( input_rows, @@ -296,7 +301,8 @@ def _wals_factorization_model_function(features, labels, mode, params): sweep_hook = _SweepHook( is_row_sweep_var, is_sweep_done_var, init_op, row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op) - training_hooks = [sweep_hook] + global_step_hook = _IncrementGlobalStepHook() + training_hooks = [sweep_hook, global_step_hook] if max_sweeps is not None: training_hooks.append(_StopAtSweepHook(max_sweeps)) |