author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-09 17:31:28 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-10-09 17:38:43 -0700
commit     07d78ddeafe41bc0363ac92efd7ca8ea60478989 (patch)
tree       31b41c3b2acc121e570a948e03967a8f94a528d9
parent     485cb179ea84c8de26263628510f930d07a98c4a (diff)
Removes the use of tf.cond in the SweepHook used in the WALSMatrixFactorization estimator, to prevent a rare but possible race condition.
PiperOrigin-RevId: 171612114
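In a nutshell, the change replaces the tf.cond branches that used to mark processed rows or columns with a pair of unconditional scatter updates whose written values are masked by the sweep flag, and moves the sweep switch out of the training graph into the hook's before_run(). The following minimal sketch illustrates that cond-free masking pattern against the TF 1.x API; the toy sizes, index values, and variable names are illustrative only and are not taken from wals.py.

# Minimal sketch of the cond-free bookkeeping pattern this commit adopts
# (illustrative only -- toy sizes and indices, not the wals.py graph).
# Both scatter updates always run; the values written are masked by the
# row/column sweep flag, so no tf.cond branch is needed.
import tensorflow as tf  # TF 1.x API, as used by contrib/factorization

num_rows, num_cols = 5, 4
is_row_sweep = tf.Variable(True, trainable=False, name="is_row_sweep")
processed_rows = tf.Variable(tf.fill([num_rows], False), trainable=False)
processed_cols = tf.Variable(tf.fill([num_cols], False), trainable=False)

row_indices = tf.constant([0, 1], dtype=tf.int64)  # rows touched this step
col_indices = tf.constant([2, 3], dtype=tf.int64)  # cols touched this step

# During a row sweep the column update writes False (a no-op), and vice versa.
update_rows = tf.scatter_update(
    processed_rows, row_indices,
    tf.logical_and(is_row_sweep,
                   tf.ones_like(row_indices, dtype=tf.bool)))
update_cols = tf.scatter_update(
    processed_cols, col_indices,
    tf.logical_and(tf.logical_not(is_row_sweep),
                   tf.ones_like(col_indices, dtype=tf.bool)))
update_op = tf.group(update_rows, update_cols)

is_sweep_done = tf.logical_or(tf.reduce_all(processed_rows),
                              tf.reduce_all(processed_cols))

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update_op)
  print(sess.run(processed_rows))   # rows 0 and 1 are now marked
  print(sess.run(processed_cols))   # unchanged: all False
  print(sess.run(is_sweep_done))    # False: neither sweep is finished yet

Because both updates always run and only the masked values differ, the graph no longer relies on conditional execution of stateful ops, which is presumably the source of the rare race the commit message mentions.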
-rw-r--r--  tensorflow/contrib/factorization/BUILD                    |   1
-rw-r--r--  tensorflow/contrib/factorization/python/ops/wals.py       | 250
-rw-r--r--  tensorflow/contrib/factorization/python/ops/wals_test.py  |  14
3 files changed, 111 insertions(+), 154 deletions(-)
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index c741815042..44095bd00a 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -246,7 +246,6 @@ tf_py_test(
"manual",
"noasan", # times out b/63678675
"nomsan",
- "notsan",
],
)
diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py
index 3e3ee5fa57..3976395d78 100644
--- a/tensorflow/contrib/factorization/python/ops/wals.py
+++ b/tensorflow/contrib/factorization/python/ops/wals.py
@@ -26,7 +26,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
@@ -38,31 +37,30 @@ from tensorflow.python.training import session_run_hook
class _SweepHook(session_run_hook.SessionRunHook):
"""Keeps track of row/col sweeps, and runs prep ops before each sweep."""
- def __init__(self, is_row_sweep_var, train_op, num_rows, num_cols,
- processed_row_indices, processed_col_indices, row_prep_ops,
- col_prep_ops, cache_init_ops, completed_sweeps_var):
+ def __init__(self, is_row_sweep_var, train_ops, num_rows, num_cols,
+ input_row_indices, input_col_indices, row_prep_ops,
+ col_prep_ops, init_op, completed_sweeps_var):
"""Initializes SweepHook.
Args:
is_row_sweep_var: A Boolean tf.Variable, determines whether we are
currently doing a row or column sweep. It is updated by the hook.
- train_op: An op. All the ops created by the hook will have
- control_dependencies on train_op.
+ train_ops: A list of ops. The ops created by this hook will have
+ control dependencies on `train_ops`.
num_rows: int, the total number of rows to be processed.
num_cols: int, the total number of columns to be processed.
- processed_row_indices: A Tensor of type int64. The indices of the input
- rows that are processed during the current sweep. All elements of
- processed_row_indices must be in [0, num_rows).
- processed_col_indices: A Tensor of type int64. The indices of the input
+ input_row_indices: A Tensor of type int64. The indices of the input rows
+ that are processed during the current sweep. All elements of
+ `input_row_indices` must be in [0, num_rows).
+ input_col_indices: A Tensor of type int64. The indices of the input
columns that are processed during the current sweep. All elements of
- processed_col_indices must be in [0, num_cols).
+ `input_col_indices` must be in [0, num_cols).
row_prep_ops: list of ops, to be run before the beginning of each row
sweep, in the given order.
col_prep_ops: list of ops, to be run before the beginning of each column
sweep, in the given order.
- cache_init_ops: list of ops, to be run once before training, in the given
- order. These are typically local initialization ops (such as cache
- initialization).
+ init_op: op to be run once before training. This is typically a local
+ initialization op (such as cache initialization).
completed_sweeps_var: An integer tf.Variable, indicates the number of
completed sweeps. It is updated by the hook.
"""
@@ -70,55 +68,45 @@ class _SweepHook(session_run_hook.SessionRunHook):
self._num_cols = num_cols
self._row_prep_ops = row_prep_ops
self._col_prep_ops = col_prep_ops
- self._cache_init_ops = cache_init_ops
+ self._init_op = init_op
self._is_row_sweep_var = is_row_sweep_var
self._completed_sweeps_var = completed_sweeps_var
- # Boolean variable that determines whether the cache_init_ops have been run.
+ # Boolean variable that determines whether the init_ops have been run.
self._is_initialized = False
- # Boolean variable that is set to True when a sweep is completed.
- # Used to run the prep_ops at the beginning of a sweep, in before_run().
- self._is_sweep_done = False
- # Ops to run jointly with train_op, responsible for updating
- # _is_row_sweep_var and incrementing the global_step and completed_sweeps
- # counters. They have control_dependencies on train_op.
- self._fetches = self._create_switch_ops(processed_row_indices,
- processed_col_indices, train_op)
-
- def _create_switch_ops(self, processed_row_indices, processed_col_indices,
- train_op):
+ # Ops to run jointly with train_ops, responsible for updating
+ # `is_row_sweep_var` and incrementing the `global_step` and
+ # `completed_sweeps` counters.
+ self._update_op, self._is_sweep_done_var, self._switch_op = (
+ self._create_hook_ops(input_row_indices, input_col_indices, train_ops))
+
+ def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops):
"""Creates ops to update is_row_sweep_var, global_step and completed_sweeps.
- Creates two boolean tensors processed_rows and processed_cols, which keep
- track of which rows/cols have been processed during the current sweep.
+ Creates two boolean tensors `processed_rows` and `processed_cols`, which
+ keep track of which rows/cols have been processed during the current sweep.
Returns ops that should be run after each row / col update.
- - When is_row_sweep_var is True, it sets
- processed_rows[processed_row_indices] to True.
- - When is_row_sweep_var is False, it sets
- processed_cols[processed_col_indices] to True .
- When all rows or all cols have been processed, negates is_row_sweep_var,
- increments the completed_sweeps counter, and resets processed_rows and
- processed_cols to False.
- All of the ops created by this function have control_dependencies on
- train_op.
+ - When `self._is_row_sweep_var` is True, it sets
+ processed_rows[input_row_indices] to True.
+ - When `self._is_row_sweep_var` is False, it sets
+ processed_cols[input_col_indices] to True.
Args:
- processed_row_indices: A Tensor. The indices of the input rows that are
+ input_row_indices: A Tensor. The indices of the input rows that are
processed during the current sweep.
- processed_col_indices: A Tensor. The indices of the input columns that
+ input_col_indices: A Tensor. The indices of the input columns that
are processed during the current sweep.
- train_op: An op. All the ops created by this function have
- control_dependencies on train_op.
+ train_ops: A list of ops. The ops created by this function have control
+ dependencies on `train_ops`.
+
Returns:
- A list consisting of:
- is_sweep_done: A Boolean tensor, determines whether the sweep is done,
- i.e. all rows (during a row sweep) or all columns (during a column
- sweep) have been processed.
- switch_ops: An op that updates is_row_sweep_var when is_sweep_done is
- True. Has control_dependencies on train_op.
- incr_ops: An op that increments the global_step and completed_sweeps
- counters. Has control_dependenciens on switch_ops.
+ A tuple consisting of:
+ update_op: An op to be run jointly with training. It updates the state
+ and increments counters (global step and completed sweeps).
+ is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is
+ done, i.e. all rows (during a row sweep) or all columns (during a
+ column sweep) have been processed.
+ switch_op: An op to be run in `self.before_run` when the sweep is done.
"""
-
processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False)
with ops.colocate_with(processed_rows_init):
processed_rows = variable_scope.variable(
@@ -133,97 +121,72 @@ class _SweepHook(session_run_hook.SessionRunHook):
collections=[ops.GraphKeys.GLOBAL_VARIABLES],
trainable=False,
name="sweep_hook_processed_cols")
- # After running the train_op, update processed_rows or processed_cols
- # tensors, depending on whether we are currently doing a row or a col sweep
- with ops.control_dependencies([train_op]):
-
- def get_row_update_op():
- with ops.colocate_with(processed_rows):
- return state_ops.scatter_update(processed_rows, processed_row_indices,
- array_ops.ones_like(
- processed_row_indices,
- dtype=dtypes.bool))
-
- def get_col_update_op():
- with ops.colocate_with(processed_cols):
- return state_ops.scatter_update(processed_cols, processed_col_indices,
- array_ops.ones_like(
- processed_col_indices,
- dtype=dtypes.bool))
-
- update_processed_op = control_flow_ops.cond(
- self._is_row_sweep_var, get_row_update_op, get_col_update_op)
-
- # After update_processed_op, check whether we have completed a sweep.
- # If this is the case, flip the is_row_sweep_var and reset processed_rows
- # and processed_cols tensors.
- with ops.control_dependencies([update_processed_op]):
-
- def get_switch_op():
- return state_ops.assign(
- self._is_row_sweep_var,
- gen_math_ops.logical_not(self._is_row_sweep_var)).op
-
- def get_reset_op():
- return control_flow_ops.group(
- state_ops.assign(processed_rows, processed_rows_init).op,
- state_ops.assign(processed_cols, processed_cols_init).op)
-
- is_sweep_done = control_flow_ops.cond(
+ switch_ops = control_flow_ops.group(
+ state_ops.assign(
self._is_row_sweep_var,
- lambda: math_ops.reduce_all(processed_rows),
- lambda: math_ops.reduce_all(processed_cols),
- name="sweep_hook_is_sweep_done")
- switch_op = control_flow_ops.cond(
- is_sweep_done,
- get_switch_op,
- control_flow_ops.no_op,
- name="sweep_hook_switch_op")
- reset_op = control_flow_ops.cond(
- is_sweep_done,
- get_reset_op,
- control_flow_ops.no_op,
- name="sweep_hook_reset_op")
- switch_ops = control_flow_ops.group(
- switch_op, reset_op, name="sweep_hook_switch_ops")
-
- with ops.control_dependencies([switch_ops]):
- # Op to increment the completed_sweeps counter.
- completed_sweeps_incr_op = control_flow_ops.cond(
- is_sweep_done,
- lambda: state_ops.assign_add(self._completed_sweeps_var, 1).op,
- control_flow_ops.no_op,
- name="completed_sweeps_incr")
-
- # Op to increment the global_step counter.
- global_step = framework_variables.get_global_step()
- if global_step is not None:
- global_step_incr_op = state_ops.assign_add(
- global_step, 1, name="global_step_incr").op
- else:
- global_step_incr_op = control_flow_ops.no_op(
- name="global_step_incr")
-
- incr_ops = control_flow_ops.group(
- completed_sweeps_incr_op,
- global_step_incr_op,
- name="counter_incr_ops")
-
- return [is_sweep_done, switch_ops, incr_ops]
+ math_ops.logical_not(self._is_row_sweep_var)),
+ state_ops.assign(processed_rows, processed_rows_init),
+ state_ops.assign(processed_cols, processed_cols_init))
+ is_sweep_done_var = variable_scope.variable(
+ False,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+ trainable=False,
+ name="is_sweep_done")
+
+ # After running the `train_ops`, updates `processed_rows` or
+ # `processed_cols` tensors, depending on whether this is a row or col sweep.
+ with ops.control_dependencies(train_ops):
+ with ops.colocate_with(processed_rows):
+ update_processed_rows = state_ops.scatter_update(
+ processed_rows,
+ input_row_indices,
+ math_ops.logical_and(
+ self._is_row_sweep_var,
+ array_ops.ones_like(input_row_indices, dtype=dtypes.bool)))
+ with ops.colocate_with(processed_cols):
+ update_processed_cols = state_ops.scatter_update(
+ processed_cols,
+ input_col_indices,
+ math_ops.logical_and(
+ math_ops.logical_not(self._is_row_sweep_var),
+ array_ops.ones_like(input_col_indices, dtype=dtypes.bool)))
+ update_processed_op = control_flow_ops.group(
+ update_processed_rows, update_processed_cols)
- def begin(self):
- pass
+ with ops.control_dependencies([update_processed_op]):
+ is_sweep_done = math_ops.logical_or(
+ math_ops.reduce_all(processed_rows),
+ math_ops.reduce_all(processed_cols))
+ # Increments global step.
+ global_step = framework_variables.get_global_step()
+ if global_step is not None:
+ global_step_incr_op = state_ops.assign_add(
+ global_step, 1, name="global_step_incr").op
+ else:
+ global_step_incr_op = control_flow_ops.no_op()
+ # Increments completed sweeps.
+ completed_sweeps_incr_op = state_ops.assign_add(
+ self._completed_sweeps_var,
+ math_ops.cast(is_sweep_done, dtypes.int32),
+ use_locking=True).op
+ update_ops = control_flow_ops.group(
+ global_step_incr_op,
+ completed_sweeps_incr_op,
+ state_ops.assign(is_sweep_done_var, is_sweep_done))
+
+ return update_ops, is_sweep_done_var, switch_ops
def before_run(self, run_context):
"""Runs the appropriate prep ops, and requests running update ops."""
- # Run the appropriate cache_init and prep ops
+ # Runs the appropriate init ops and prep ops.
sess = run_context.session
+ is_sweep_done = sess.run(self._is_sweep_done_var)
if not self._is_initialized:
- logging.info("SweepHook running cache init ops.")
- for init_op in self._cache_init_ops:
- sess.run(init_op)
-
- if self._is_sweep_done or not self._is_initialized:
+ logging.info("SweepHook running cache init op.")
+ sess.run(self._init_op)
+ if is_sweep_done:
+ sess.run(self._switch_op)
+ if is_sweep_done or not self._is_initialized:
logging.info("SweepHook running sweep prep ops.")
row_sweep = sess.run(self._is_row_sweep_var)
prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops
@@ -232,13 +195,12 @@ class _SweepHook(session_run_hook.SessionRunHook):
self._is_initialized = True
- # Request running the switch_ops and the incr_ops
- logging.info("Partial fit starting.")
- return session_run_hook.SessionRunArgs(fetches=self._fetches)
+ # Requests running `self._update_op` jointly with the training op.
+ logging.info("Next fit step starting.")
+ return session_run_hook.SessionRunArgs(fetches=[self._update_op])
def after_run(self, run_context, run_values):
- self._is_sweep_done = run_values.results[0]
- logging.info("Partial fit done.")
+ logging.info("Fit step done.")
class _StopAtSweepHook(session_run_hook.SessionRunHook):
@@ -360,19 +322,19 @@ def _wals_factorization_model_function(features, labels, mode, params):
col_prep_ops = [
model.col_update_prep_gramian_op, model.initialize_col_update_op
]
- cache_init_ops = [model.worker_init]
+ init_ops = [model.worker_init]
sweep_hook = _SweepHook(
is_row_sweep_var,
- train_op,
+ [train_op, loss],
params["num_rows"],
params["num_cols"],
input_row_indices,
input_col_indices,
row_prep_ops,
col_prep_ops,
- cache_init_ops,
- completed_sweeps_var,)
+ init_ops,
+ completed_sweeps_var)
training_hooks = [sweep_hook]
if max_sweeps is not None:
training_hooks.append(_StopAtSweepHook(max_sweeps))
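To make the per-step sequencing in the wals.py hunk above easier to follow, here is the flow of the reworked hook condensed into a small stand-alone SessionRunHook. This is a simplified sketch that assumes the init, prep, switch, and update ops are built elsewhere (as _wals_factorization_model_function does); it is not the _SweepHook class itself.

# Simplified sketch of the reworked hook's control flow (TF 1.x API).
import tensorflow as tf

class SweepHookSketch(tf.train.SessionRunHook):
  """Runs init/prep ops at sweep boundaries; fetches the update op each step."""

  def __init__(self, init_op, row_prep_ops, col_prep_ops, switch_op,
               update_op, is_sweep_done_var, is_row_sweep_var):
    self._init_op = init_op              # run once before training
    self._row_prep_ops = row_prep_ops
    self._col_prep_ops = col_prep_ops
    self._switch_op = switch_op          # flips the sweep flag, resets bookkeeping
    self._update_op = update_op          # run jointly with the train op
    self._is_sweep_done_var = is_sweep_done_var
    self._is_row_sweep_var = is_row_sweep_var
    self._is_initialized = False

  def before_run(self, run_context):
    sess = run_context.session
    is_sweep_done = sess.run(self._is_sweep_done_var)
    if not self._is_initialized:
      sess.run(self._init_op)
    if is_sweep_done:
      # The switch happens here, via a plain session.run, instead of inside
      # a tf.cond in the training graph.
      sess.run(self._switch_op)
    if is_sweep_done or not self._is_initialized:
      row_sweep = sess.run(self._is_row_sweep_var)
      prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops
      for prep_op in prep_ops:
        sess.run(prep_op)
    self._is_initialized = True
    return tf.train.SessionRunArgs(fetches=[self._update_op])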
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py
index b5c1bb1151..8bd72b7025 100644
--- a/tensorflow/contrib/factorization/python/ops/wals_test.py
+++ b/tensorflow/contrib/factorization/python/ops/wals_test.py
@@ -357,7 +357,7 @@ class WALSMatrixFactorizationTest(test.TestCase):
self.assertNear(
loss, true_loss, err=.001,
- msg="""After row update, eval loss = {}, does not match the true
+ msg="""After col update, eval loss = {}, does not match the true
loss = {}.""".format(loss, true_loss))
@@ -442,7 +442,7 @@ class SweepHookTest(test.TestCase):
completed_sweeps_var = variables.Variable(0)
sweep_hook = wals_lib._SweepHook(
is_row_sweep_var,
- self._train_op,
+ [self._train_op],
self._num_rows,
self._num_cols,
self._input_row_indices_ph,
@@ -465,11 +465,9 @@ class SweepHookTest(test.TestCase):
'False.')
# Row sweep completed.
mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6]))
- self.assertFalse(sess.run(is_row_sweep_var),
- msg='Row sweep is complete but is_row_sweep is True.')
self.assertTrue(sess.run(completed_sweeps_var) == 1,
msg='Completed sweeps should be equal to 1.')
- self.assertTrue(sweep_hook._is_sweep_done,
+ self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
msg='Sweep is complete but is_sweep_done is False.')
# Col init ops should run. Col sweep not completed.
mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4]))
@@ -478,13 +476,11 @@ class SweepHookTest(test.TestCase):
self.assertFalse(sess.run(is_row_sweep_var),
msg='Col sweep is not complete but is_row_sweep is '
'True.')
- self.assertFalse(sweep_hook._is_sweep_done,
+ self.assertFalse(sess.run(sweep_hook._is_sweep_done_var),
msg='Sweep is not complete but is_sweep_done is True.')
# Col sweep completed.
mon_sess.run(self._train_op, ind_feed([], [4, 5, 6]))
- self.assertTrue(sess.run(is_row_sweep_var),
- msg='Col sweep is complete but is_row_sweep is False')
- self.assertTrue(sweep_hook._is_sweep_done,
+ self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
msg='Sweep is complete but is_sweep_done is False.')
self.assertTrue(sess.run(completed_sweeps_var) == 2,
msg='Completed sweeps should be equal to 2.')