aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/factorization
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-14 14:44:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-14 14:57:41 -0800
commita5a192865e4c1732b414d6a503d07775f7163a5c (patch)
tree16179cd6b477eb09882b5fa61132427eb5280f5f /tensorflow/contrib/factorization
parentb8054c19b7d72cfb7eb07552a8d0385ffd8810d7 (diff)
Refactors the WALS estimator so that part of the control flow logic happens in the SweepHook. This fixes a bug that causes both input batches (rows and columns) to be fetched during any given sweep.
PiperOrigin-RevId: 175738242
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals.py449
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals_test.py112
2 files changed, 274 insertions, 287 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py
index 3976395d78..b2f22eb2fc 100644
--- a/tensorflow/contrib/factorization/python/ops/wals.py
+++ b/tensorflow/contrib/factorization/python/ops/wals.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.factorization.python.ops import factorization_ops
-from tensorflow.contrib.framework.python.ops import variables as framework_variables
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.python.framework import dtypes
@@ -32,175 +31,64 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
class _SweepHook(session_run_hook.SessionRunHook):
"""Keeps track of row/col sweeps, and runs prep ops before each sweep."""
- def __init__(self, is_row_sweep_var, train_ops, num_rows, num_cols,
- input_row_indices, input_col_indices, row_prep_ops,
- col_prep_ops, init_op, completed_sweeps_var):
+ def __init__(self, is_row_sweep_var, is_sweep_done_var, init_op,
+ row_prep_ops, col_prep_ops, row_train_op, col_train_op,
+ switch_op):
"""Initializes SweepHook.
Args:
is_row_sweep_var: A Boolean tf.Variable, determines whether we are
currently doing a row or column sweep. It is updated by the hook.
- train_ops: A list of ops. The ops created by this hook will have
- control dependencies on `train_ops`.
- num_rows: int, the total number of rows to be processed.
- num_cols: int, the total number of columns to be processed.
- input_row_indices: A Tensor of type int64. The indices of the input rows
- that are processed during the current sweep. All elements of
- `input_row_indices` must be in [0, num_rows).
- input_col_indices: A Tensor of type int64. The indices of the input
- columns that are processed during the current sweep. All elements of
- `input_col_indices` must be in [0, num_cols).
- row_prep_ops: list of ops, to be run before the beginning of each row
- sweep, in the given order.
- col_prep_ops: list of ops, to be run before the beginning of each column
- sweep, in the given order.
+ is_sweep_done_var: A Boolean tf.Variable, determines whether we are
+ starting a new sweep (this is used to determine when to run the prep ops
+ below).
init_op: op to be run once before training. This is typically a local
initialization op (such as cache initialization).
- completed_sweeps_var: An integer tf.Variable, indicates the number of
- completed sweeps. It is updated by the hook.
+ row_prep_ops: A list of TensorFlow ops, to be run before the beginning of
+ each row sweep (and during initialization), in the given order.
+ col_prep_ops: A list of TensorFlow ops, to be run before the beginning of
+ each column sweep (and during initialization), in the given order.
+ row_train_op: A TensorFlow op to be run during row sweeps.
+ col_train_op: A TensorFlow op to be run during column sweeps.
+ switch_op: A TensorFlow op to be run before each sweep.
"""
- self._num_rows = num_rows
- self._num_cols = num_cols
+ self._is_row_sweep_var = is_row_sweep_var
+ self._is_sweep_done_var = is_sweep_done_var
+ self._init_op = init_op
self._row_prep_ops = row_prep_ops
self._col_prep_ops = col_prep_ops
- self._init_op = init_op
- self._is_row_sweep_var = is_row_sweep_var
- self._completed_sweeps_var = completed_sweeps_var
- # Boolean variable that determines whether the init_ops have been run.
+ self._row_train_op = row_train_op
+ self._col_train_op = col_train_op
+ self._switch_op = switch_op
+ # Boolean variable that determines whether the init_op has been run.
self._is_initialized = False
- # Ops to run jointly with train_ops, responsible for updating
- # `is_row_sweep_var` and incrementing the `global_step` and
- # `completed_sweeps` counters.
- self._update_op, self._is_sweep_done_var, self._switch_op = (
- self._create_hook_ops(input_row_indices, input_col_indices, train_ops))
-
- def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops):
- """Creates ops to update is_row_sweep_var, global_step and completed_sweeps.
-
- Creates two boolean tensors `processed_rows` and `processed_cols`, which
- keep track of which rows/cols have been processed during the current sweep.
- Returns ops that should be run after each row / col update.
- - When `self._is_row_sweep_var` is True, it sets
- processed_rows[input_row_indices] to True.
- - When `self._is_row_sweep_var` is False, it sets
- processed_cols[input_col_indices] to True.
-
- Args:
- input_row_indices: A Tensor. The indices of the input rows that are
- processed during the current sweep.
- input_col_indices: A Tensor. The indices of the input columns that
- are processed during the current sweep.
- train_ops: A list of ops. The ops created by this function have control
- dependencies on `train_ops`.
-
- Returns:
- A tuple consisting of:
- update_op: An op to be run jointly with training. It updates the state
- and increments counters (global step and completed sweeps).
- is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is
- done, i.e. all rows (during a row sweep) or all columns (during a
- column sweep) have been processed.
- switch_op: An op to be run in `self.before_run` when the sweep is done.
- """
- processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False)
- with ops.colocate_with(processed_rows_init):
- processed_rows = variable_scope.variable(
- processed_rows_init,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES],
- trainable=False,
- name="sweep_hook_processed_rows")
- processed_cols_init = array_ops.fill(dims=[self._num_cols], value=False)
- with ops.colocate_with(processed_cols_init):
- processed_cols = variable_scope.variable(
- processed_cols_init,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES],
- trainable=False,
- name="sweep_hook_processed_cols")
- switch_ops = control_flow_ops.group(
- state_ops.assign(
- self._is_row_sweep_var,
- math_ops.logical_not(self._is_row_sweep_var)),
- state_ops.assign(processed_rows, processed_rows_init),
- state_ops.assign(processed_cols, processed_cols_init))
- is_sweep_done_var = variable_scope.variable(
- False,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES],
- trainable=False,
- name="is_sweep_done")
-
- # After running the `train_ops`, updates `processed_rows` or
- # `processed_cols` tensors, depending on whether this is a row or col sweep.
- with ops.control_dependencies(train_ops):
- with ops.colocate_with(processed_rows):
- update_processed_rows = state_ops.scatter_update(
- processed_rows,
- input_row_indices,
- math_ops.logical_and(
- self._is_row_sweep_var,
- array_ops.ones_like(input_row_indices, dtype=dtypes.bool)))
- with ops.colocate_with(processed_cols):
- update_processed_cols = state_ops.scatter_update(
- processed_cols,
- input_col_indices,
- math_ops.logical_and(
- math_ops.logical_not(self._is_row_sweep_var),
- array_ops.ones_like(input_col_indices, dtype=dtypes.bool)))
- update_processed_op = control_flow_ops.group(
- update_processed_rows, update_processed_cols)
-
- with ops.control_dependencies([update_processed_op]):
- is_sweep_done = math_ops.logical_or(
- math_ops.reduce_all(processed_rows),
- math_ops.reduce_all(processed_cols))
- # Increments global step.
- global_step = framework_variables.get_global_step()
- if global_step is not None:
- global_step_incr_op = state_ops.assign_add(
- global_step, 1, name="global_step_incr").op
- else:
- global_step_incr_op = control_flow_ops.no_op()
- # Increments completed sweeps.
- completed_sweeps_incr_op = state_ops.assign_add(
- self._completed_sweeps_var,
- math_ops.cast(is_sweep_done, dtypes.int32),
- use_locking=True).op
- update_ops = control_flow_ops.group(
- global_step_incr_op,
- completed_sweeps_incr_op,
- state_ops.assign(is_sweep_done_var, is_sweep_done))
-
- return update_ops, is_sweep_done_var, switch_ops
def before_run(self, run_context):
"""Runs the appropriate prep ops, and requests running update ops."""
- # Runs the appropriate init ops and prep ops.
sess = run_context.session
is_sweep_done = sess.run(self._is_sweep_done_var)
if not self._is_initialized:
- logging.info("SweepHook running cache init op.")
+ logging.info("SweepHook running init op.")
sess.run(self._init_op)
if is_sweep_done:
sess.run(self._switch_op)
+ is_row_sweep = sess.run(self._is_row_sweep_var)
if is_sweep_done or not self._is_initialized:
- logging.info("SweepHook running sweep prep ops.")
- row_sweep = sess.run(self._is_row_sweep_var)
- prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops
+ logging.info("SweepHook running prep ops for the {} sweep.".format(
+ "row" if is_row_sweep else "col"))
+ prep_ops = self._row_prep_ops if is_row_sweep else self._col_prep_ops
for prep_op in prep_ops:
sess.run(prep_op)
-
self._is_initialized = True
-
- # Requests running `self._update_op` jointly with the training op.
logging.info("Next fit step starting.")
- return session_run_hook.SessionRunArgs(fetches=[self._update_op])
-
- def after_run(self, run_context, run_values):
- logging.info("Fit step done.")
+ return session_run_hook.SessionRunArgs(
+ fetches=[self._row_train_op if is_row_sweep else self._col_train_op])
class _StopAtSweepHook(session_run_hook.SessionRunHook):
@@ -246,6 +134,9 @@ def _wals_factorization_model_function(features, labels, mode, params):
Returns:
A ModelFnOps object.
+
+ Raises:
+ ValueError: If `mode` is not recognized.
"""
assert labels is None
use_factors_weights_cache = (params["use_factors_weights_cache_for_training"]
@@ -269,86 +160,156 @@ def _wals_factorization_model_function(features, labels, mode, params):
use_gramian_cache=use_gramian_cache)
# Get input rows and cols. We either update rows or columns depending on
- # the value of row_sweep, which is maintained using a session hook
+ # the value of row_sweep, which is maintained using a session hook.
input_rows = features[WALSMatrixFactorization.INPUT_ROWS]
input_cols = features[WALSMatrixFactorization.INPUT_COLS]
- input_row_indices, _ = array_ops.unique(input_rows.indices[:, 0])
- input_col_indices, _ = array_ops.unique(input_cols.indices[:, 0])
-
- # Train ops, controlled using the SweepHook
- # We need to run the following ops:
- # Before a row sweep:
- # row_update_prep_gramian_op
- # initialize_row_update_op
- # During a row sweep:
- # update_row_factors_op
- # Before a col sweep:
- # col_update_prep_gramian_op
- # initialize_col_update_op
- # During a col sweep:
- # update_col_factors_op
-
- is_row_sweep_var = variable_scope.variable(
- True,
- trainable=False,
- name="is_row_sweep",
- collections=[ops.GraphKeys.GLOBAL_VARIABLES])
- completed_sweeps_var = variable_scope.variable(
- 0,
- trainable=False,
- name=WALSMatrixFactorization.COMPLETED_SWEEPS,
- collections=[ops.GraphKeys.GLOBAL_VARIABLES])
-
- # The row sweep is determined by is_row_sweep_var (controlled by the
- # sweep_hook) in TRAIN mode, and manually in EVAL mode.
- is_row_sweep = (features[WALSMatrixFactorization.PROJECT_ROW]
- if mode == model_fn.ModeKeys.EVAL else is_row_sweep_var)
-
- def update_row_factors():
- return model.update_row_factors(sp_input=input_rows, transpose_input=False)
-
- def update_col_factors():
- return model.update_col_factors(sp_input=input_cols, transpose_input=True)
-
- (_, train_op,
- unregularized_loss, regularization, sum_weights) = control_flow_ops.cond(
- is_row_sweep, update_row_factors, update_col_factors)
- loss = unregularized_loss + regularization
- root_weighted_squared_error = math_ops.sqrt(unregularized_loss / sum_weights)
-
- row_prep_ops = [
- model.row_update_prep_gramian_op, model.initialize_row_update_op
- ]
- col_prep_ops = [
- model.col_update_prep_gramian_op, model.initialize_col_update_op
- ]
- init_ops = [model.worker_init]
-
- sweep_hook = _SweepHook(
- is_row_sweep_var,
- [train_op, loss],
- params["num_rows"],
- params["num_cols"],
- input_row_indices,
- input_col_indices,
- row_prep_ops,
- col_prep_ops,
- init_ops,
- completed_sweeps_var)
- training_hooks = [sweep_hook]
- if max_sweeps is not None:
- training_hooks.append(_StopAtSweepHook(max_sweeps))
-
- # The root weighted squared error =
- # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij )
- summary.scalar("loss", loss) # the estimated total training loss
- summary.scalar("root_weighted_squared_error", root_weighted_squared_error)
- summary.scalar("completed_sweeps", completed_sweeps_var)
-
- # Prediction ops (only return predictions in INFER mode)
- predictions = {}
- if mode == model_fn.ModeKeys.INFER:
- project_row = features[WALSMatrixFactorization.PROJECT_ROW]
+
+ # TRAIN mode:
+ if mode == model_fn.ModeKeys.TRAIN:
+    # Training consists of the following ops (controlled using a SweepHook).
+ # Before a row sweep:
+ # row_update_prep_gramian_op
+ # initialize_row_update_op
+ # During a row sweep:
+ # update_row_factors_op
+ # Before a col sweep:
+ # col_update_prep_gramian_op
+ # initialize_col_update_op
+ # During a col sweep:
+ # update_col_factors_op
+
+ is_row_sweep_var = variable_scope.variable(
+ True,
+ trainable=False,
+ name="is_row_sweep",
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ is_sweep_done_var = variable_scope.variable(
+ False,
+ trainable=False,
+ name="is_sweep_done",
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ completed_sweeps_var = variable_scope.variable(
+ 0,
+ trainable=False,
+ name=WALSMatrixFactorization.COMPLETED_SWEEPS,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ loss_var = variable_scope.variable(
+ 0.,
+ trainable=False,
+ name=WALSMatrixFactorization.LOSS,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ # The root weighted squared error =
+ # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij )
+ rwse_var = variable_scope.variable(
+ 0.,
+ trainable=False,
+ name=WALSMatrixFactorization.RWSE,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+
+ summary.scalar("loss", loss_var)
+ summary.scalar("root_weighted_squared_error", rwse_var)
+ summary.scalar("completed_sweeps", completed_sweeps_var)
+
+ # Increments global step.
+ global_step = training_util.get_global_step()
+ if global_step:
+ global_step_incr_op = state_ops.assign_add(
+ global_step, 1, name="global_step_incr").op
+ else:
+ global_step_incr_op = control_flow_ops.no_op()
+
+ def create_axis_ops(sp_input, num_items, update_fn, axis_name):
+ """Creates book-keeping and training ops for a given axis.
+
+ Args:
+ sp_input: A SparseTensor corresponding to the row or column batch.
+ num_items: An integer, the total number of items of this axis.
+ update_fn: A function that takes one argument (`sp_input`), and that
+ returns a tuple of
+        * new_factors: A float Tensor of the factor values after update.
+ * update_op: a TensorFlow op which updates the factors.
+ * loss: A float Tensor, the unregularized loss.
+ * reg_loss: A float Tensor, the regularization loss.
+ * sum_weights: A float Tensor, the sum of factor weights.
+ axis_name: A string that specifies the name of the axis.
+
+ Returns:
+ A tuple consisting of:
+ * reset_processed_items_op: A TensorFlow op, to be run before the
+ beginning of any sweep. It marks all items as not-processed.
+      * axis_train_op: A TensorFlow op, to be run during this axis' sweeps.
+ """
+ processed_items_init = array_ops.fill(dims=[num_items], value=False)
+ with ops.colocate_with(processed_items_init):
+ processed_items = variable_scope.variable(
+ processed_items_init,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+ trainable=False,
+ name="processed_" + axis_name)
+ reset_processed_items_op = state_ops.assign(
+ processed_items, processed_items_init,
+ name="reset_processed_" + axis_name)
+ _, update_op, loss, reg, sum_weights = update_fn(sp_input)
+ input_indices = sp_input.indices[:, 0]
+ with ops.control_dependencies([
+ update_op,
+ state_ops.assign(loss_var, loss + reg),
+ state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))]):
+ with ops.colocate_with(processed_items):
+ update_processed_items = state_ops.scatter_update(
+ processed_items,
+ input_indices,
+ array_ops.ones_like(input_indices, dtype=dtypes.bool),
+ name="update_processed_{}_indices".format(axis_name))
+ with ops.control_dependencies([update_processed_items]):
+ is_sweep_done = math_ops.reduce_all(processed_items)
+ axis_train_op = control_flow_ops.group(
+ global_step_incr_op,
+ state_ops.assign(is_sweep_done_var, is_sweep_done),
+ state_ops.assign_add(
+ completed_sweeps_var,
+ math_ops.cast(is_sweep_done, dtypes.int32)),
+ name="{}_sweep_train_op".format(axis_name))
+ return reset_processed_items_op, axis_train_op
+
+ reset_processed_rows_op, row_train_op = create_axis_ops(
+ input_rows,
+ params["num_rows"],
+ lambda x: model.update_row_factors(sp_input=x, transpose_input=False),
+ "rows")
+ reset_processed_cols_op, col_train_op = create_axis_ops(
+ input_cols,
+ params["num_cols"],
+ lambda x: model.update_col_factors(sp_input=x, transpose_input=True),
+ "cols")
+ switch_op = control_flow_ops.group(
+ state_ops.assign(
+ is_row_sweep_var, math_ops.logical_not(is_row_sweep_var)),
+ reset_processed_rows_op,
+ reset_processed_cols_op,
+ name="sweep_switch_op")
+ row_prep_ops = [
+ model.row_update_prep_gramian_op, model.initialize_row_update_op]
+ col_prep_ops = [
+ model.col_update_prep_gramian_op, model.initialize_col_update_op]
+ init_op = model.worker_init
+ sweep_hook = _SweepHook(
+ is_row_sweep_var, is_sweep_done_var, init_op,
+ row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op)
+ training_hooks = [sweep_hook]
+ if max_sweeps is not None:
+ training_hooks.append(_StopAtSweepHook(max_sweeps))
+
+ return model_fn.ModelFnOps(
+ mode=model_fn.ModeKeys.TRAIN,
+ predictions={},
+ loss=loss_var,
+ eval_metric_ops={},
+ train_op=control_flow_ops.no_op(),
+ training_hooks=training_hooks)
+
+ # INFER mode
+ elif mode == model_fn.ModeKeys.INFER:
projection_weights = features.get(
WALSMatrixFactorization.PROJECTION_WEIGHTS)
@@ -364,17 +325,45 @@ def _wals_factorization_model_function(features, labels, mode, params):
projection_weights=projection_weights,
transpose_input=True)
- predictions[WALSMatrixFactorization.PROJECTION_RESULT] = (
- control_flow_ops.cond(project_row, get_row_projection,
- get_col_projection))
+ predictions = {
+ WALSMatrixFactorization.PROJECTION_RESULT: control_flow_ops.cond(
+ features[WALSMatrixFactorization.PROJECT_ROW],
+ get_row_projection,
+ get_col_projection)
+ }
- return model_fn.ModelFnOps(
- mode=mode,
- predictions=predictions,
- loss=loss,
- eval_metric_ops={},
- train_op=train_op,
- training_hooks=training_hooks)
+ return model_fn.ModelFnOps(
+ mode=model_fn.ModeKeys.INFER,
+ predictions=predictions,
+ loss=None,
+ eval_metric_ops={},
+ train_op=control_flow_ops.no_op(),
+ training_hooks=[])
+
+ # EVAL mode
+ elif mode == model_fn.ModeKeys.EVAL:
+ def get_row_loss():
+ _, _, loss, reg, _ = model.update_row_factors(
+ sp_input=input_rows, transpose_input=False)
+ return loss + reg
+ def get_col_loss():
+ _, _, loss, reg, _ = model.update_col_factors(
+ sp_input=input_cols, transpose_input=True)
+ return loss + reg
+ loss = control_flow_ops.cond(
+ features[WALSMatrixFactorization.PROJECT_ROW],
+ get_row_loss,
+ get_col_loss)
+ return model_fn.ModelFnOps(
+ mode=model_fn.ModeKeys.EVAL,
+ predictions={},
+ loss=loss,
+ eval_metric_ops={},
+ train_op=control_flow_ops.no_op(),
+ training_hooks=[])
+
+ else:
+ raise ValueError("mode=%s is not recognized." % str(mode))
class WALSMatrixFactorization(estimator.Estimator):
@@ -452,6 +441,10 @@ class WALSMatrixFactorization(estimator.Estimator):
PROJECTION_RESULT = "projection"
# Name of the completed_sweeps variable
COMPLETED_SWEEPS = "completed_sweeps"
+ # Name of the loss variable
+ LOSS = "WALS_loss"
+ # Name of the Root Weighted Squared Error variable
+ RWSE = "WALS_RWSE"
def __init__(self,
num_rows,
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py
index 8bd72b7025..36b483c6d7 100644
--- a/tensorflow/contrib/factorization/python/ops/wals_test.py
+++ b/tensorflow/contrib/factorization/python/ops/wals_test.py
@@ -417,73 +417,67 @@ class WALSMatrixFactorizationUnsupportedTest(test.TestCase):
class SweepHookTest(test.TestCase):
- def setUp(self):
- self._num_rows = 5
- self._num_cols = 7
- self._train_op = control_flow_ops.no_op()
- self._row_prep_done = variables.Variable(False)
- self._col_prep_done = variables.Variable(False)
- self._init_done = variables.Variable(False)
- self._row_prep_ops = [state_ops.assign(self._row_prep_done, True)]
- self._col_prep_ops = [state_ops.assign(self._col_prep_done, True)]
- self._init_ops = [state_ops.assign(self._init_done, True)]
- self._input_row_indices_ph = array_ops.placeholder(dtypes.int64)
- self._input_col_indices_ph = array_ops.placeholder(dtypes.int64)
-
def test_sweeps(self):
- def ind_feed(row_indices, col_indices):
- return {
- self._input_row_indices_ph: row_indices,
- self._input_col_indices_ph: col_indices
- }
+ is_row_sweep_var = variables.Variable(True)
+ is_sweep_done_var = variables.Variable(False)
+ init_done = variables.Variable(False)
+ row_prep_done = variables.Variable(False)
+ col_prep_done = variables.Variable(False)
+ row_train_done = variables.Variable(False)
+ col_train_done = variables.Variable(False)
+
+ init_op = state_ops.assign(init_done, True)
+ row_prep_op = state_ops.assign(row_prep_done, True)
+ col_prep_op = state_ops.assign(col_prep_done, True)
+ row_train_op = state_ops.assign(row_train_done, True)
+ col_train_op = state_ops.assign(col_train_done, True)
+ train_op = control_flow_ops.no_op()
+ switch_op = control_flow_ops.group(
+ state_ops.assign(is_sweep_done_var, False),
+ state_ops.assign(is_row_sweep_var,
+ math_ops.logical_not(is_row_sweep_var)))
+ mark_sweep_done = state_ops.assign(is_sweep_done_var, True)
with self.test_session() as sess:
- is_row_sweep_var = variables.Variable(True)
- completed_sweeps_var = variables.Variable(0)
sweep_hook = wals_lib._SweepHook(
is_row_sweep_var,
- [self._train_op],
- self._num_rows,
- self._num_cols,
- self._input_row_indices_ph,
- self._input_col_indices_ph,
- self._row_prep_ops,
- self._col_prep_ops,
- self._init_ops,
- completed_sweeps_var)
+ is_sweep_done_var,
+ init_op,
+ [row_prep_op],
+ [col_prep_op],
+ row_train_op,
+ col_train_op,
+ switch_op)
mon_sess = monitored_session._HookedSession(sess, [sweep_hook])
sess.run([variables.global_variables_initializer()])
- # Init ops should run before the first run. Row sweep not completed.
- mon_sess.run(self._train_op, ind_feed([0, 1, 2], []))
- self.assertTrue(sess.run(self._init_done),
- msg='init ops not run by the sweep_hook')
- self.assertTrue(sess.run(self._row_prep_done),
- msg='row_prep not run by the sweep_hook')
- self.assertTrue(sess.run(is_row_sweep_var),
- msg='Row sweep is not complete but is_row_sweep is '
- 'False.')
- # Row sweep completed.
- mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6]))
- self.assertTrue(sess.run(completed_sweeps_var) == 1,
- msg='Completed sweeps should be equal to 1.')
- self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
- msg='Sweep is complete but is_sweep_done is False.')
- # Col init ops should run. Col sweep not completed.
- mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4]))
- self.assertTrue(sess.run(self._col_prep_done),
- msg='col_prep not run by the sweep_hook')
- self.assertFalse(sess.run(is_row_sweep_var),
- msg='Col sweep is not complete but is_row_sweep is '
- 'True.')
- self.assertFalse(sess.run(sweep_hook._is_sweep_done_var),
- msg='Sweep is not complete but is_sweep_done is True.')
- # Col sweep completed.
- mon_sess.run(self._train_op, ind_feed([], [4, 5, 6]))
- self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
- msg='Sweep is complete but is_sweep_done is False.')
- self.assertTrue(sess.run(completed_sweeps_var) == 2,
- msg='Completed sweeps should be equal to 2.')
+ # Row sweep.
+ mon_sess.run(train_op)
+ self.assertTrue(sess.run(init_done),
+ msg='init op not run by the Sweephook')
+ self.assertTrue(sess.run(row_prep_done),
+ msg='row_prep_op not run by the SweepHook')
+ self.assertTrue(sess.run(row_train_done),
+ msg='row_train_op not run by the SweepHook')
+ self.assertTrue(
+ sess.run(is_row_sweep_var),
+ msg='Row sweep is not complete but is_row_sweep_var is False.')
+ # Col sweep.
+ mon_sess.run(mark_sweep_done)
+ mon_sess.run(train_op)
+ self.assertTrue(sess.run(col_prep_done),
+ msg='col_prep_op not run by the SweepHook')
+ self.assertTrue(sess.run(col_train_done),
+ msg='col_train_op not run by the SweepHook')
+ self.assertFalse(
+ sess.run(is_row_sweep_var),
+ msg='Col sweep is not complete but is_row_sweep_var is True.')
+ # Row sweep.
+ mon_sess.run(mark_sweep_done)
+ mon_sess.run(train_op)
+ self.assertTrue(
+ sess.run(is_row_sweep_var),
+ msg='Col sweep is complete but is_row_sweep_var is False.')
class StopAtSweepHookTest(test.TestCase):