aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/factorization
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-17 21:35:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-17 21:41:13 -0800
commit091291b70b567a37d33bf03b71bee9715e7a80bb (patch)
tree0fb136a802c69bb7d0506493063a7139af14ca5b /tensorflow/contrib/factorization
parent9fd424e4871e5a60b6e1985d70c31960a2df80d8 (diff)
In WALSMatrixFactorization: moves the op that updates the global_step to a session run hook.
PiperOrigin-RevId: 176200549
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r--tensorflow/contrib/factorization/BUILD1
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals.py34
2 files changed, 20 insertions, 15 deletions
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index 29a0a4221a..fe86a20ab1 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -270,7 +270,6 @@ tf_py_test(
"manual",
"noasan", # times out b/63678675
"nomsan",
- "notsan", # b/69374301
],
)
diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py
index b2f22eb2fc..2bde3e0dd7 100644
--- a/tensorflow/contrib/factorization/python/ops/wals.py
+++ b/tensorflow/contrib/factorization/python/ops/wals.py
@@ -77,6 +77,7 @@ class _SweepHook(session_run_hook.SessionRunHook):
logging.info("SweepHook running init op.")
sess.run(self._init_op)
if is_sweep_done:
+ logging.info("SweepHook starting the next sweep.")
sess.run(self._switch_op)
is_row_sweep = sess.run(self._is_row_sweep_var)
if is_sweep_done or not self._is_initialized:
@@ -91,6 +92,22 @@ class _SweepHook(session_run_hook.SessionRunHook):
fetches=[self._row_train_op if is_row_sweep else self._col_train_op])
+class _IncrementGlobalStepHook(session_run_hook.SessionRunHook):
+ """Hook that increments the global step."""
+
+ def __init__(self):
+ global_step = training_util.get_global_step()
+ if global_step:
+ self._global_step_incr_op = state_ops.assign_add(
+ global_step, 1, name="global_step_incr").op
+ else:
+ self._global_step_incr_op = None
+
+ def before_run(self, run_context):
+ if self._global_step_incr_op:
+ run_context.session.run(self._global_step_incr_op)
+
+
class _StopAtSweepHook(session_run_hook.SessionRunHook):
"""Hook that requests stop at a given sweep."""
@@ -210,14 +227,6 @@ def _wals_factorization_model_function(features, labels, mode, params):
summary.scalar("root_weighted_squared_error", rwse_var)
summary.scalar("completed_sweeps", completed_sweeps_var)
- # Increments global step.
- global_step = training_util.get_global_step()
- if global_step:
- global_step_incr_op = state_ops.assign_add(
- global_step, 1, name="global_step_incr").op
- else:
- global_step_incr_op = control_flow_ops.no_op()
-
def create_axis_ops(sp_input, num_items, update_fn, axis_name):
"""Creates book-keeping and training ops for a given axis.
@@ -246,9 +255,6 @@ def _wals_factorization_model_function(features, labels, mode, params):
collections=[ops.GraphKeys.GLOBAL_VARIABLES],
trainable=False,
name="processed_" + axis_name)
- reset_processed_items_op = state_ops.assign(
- processed_items, processed_items_init,
- name="reset_processed_" + axis_name)
_, update_op, loss, reg, sum_weights = update_fn(sp_input)
input_indices = sp_input.indices[:, 0]
with ops.control_dependencies([
@@ -264,13 +270,12 @@ def _wals_factorization_model_function(features, labels, mode, params):
with ops.control_dependencies([update_processed_items]):
is_sweep_done = math_ops.reduce_all(processed_items)
axis_train_op = control_flow_ops.group(
- global_step_incr_op,
state_ops.assign(is_sweep_done_var, is_sweep_done),
state_ops.assign_add(
completed_sweeps_var,
math_ops.cast(is_sweep_done, dtypes.int32)),
name="{}_sweep_train_op".format(axis_name))
- return reset_processed_items_op, axis_train_op
+ return processed_items.initializer, axis_train_op
reset_processed_rows_op, row_train_op = create_axis_ops(
input_rows,
@@ -296,7 +301,8 @@ def _wals_factorization_model_function(features, labels, mode, params):
sweep_hook = _SweepHook(
is_row_sweep_var, is_sweep_done_var, init_op,
row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op)
- training_hooks = [sweep_hook]
+ global_step_hook = _IncrementGlobalStepHook()
+ training_hooks = [sweep_hook, global_step_hook]
if max_sweeps is not None:
training_hooks.append(_StopAtSweepHook(max_sweeps))