about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/factorization
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-06-06 10:19:56 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-06-06 10:23:32 -0700
commit: dc81a24204fd93f2dcc90e11c28566b9567767f2 (patch)
tree: cb6dd6e9bf8c10dc6610ad3cc8ddb6df1ed670c9 /tensorflow/contrib/factorization
parent: 74220616c3b6739b89c352f25fb6dccf483f7fab (diff)
Updates to the WALSMatrixFactorization estimator:
- Add a completed_sweeps variable to keep track of sweeps that have been completed during training.
- Add a StopAtSweepHook, which can request a stop after completing a specified number of sweeps.

PiperOrigin-RevId: 158156347
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals.py103
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals_test.py46
2 files changed, 133 insertions, 16 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py
index 2f7bf48041..0bc0ef39ec 100644
--- a/tensorflow/contrib/factorization/python/ops/wals.py
+++ b/tensorflow/contrib/factorization/python/ops/wals.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
@@ -46,7 +47,8 @@ class _SweepHook(session_run_hook.SessionRunHook):
processed_col_indices,
row_prep_ops,
col_prep_ops,
- cache_init_ops):
+ cache_init_ops,
+ completed_sweeps_var):
"""Initializes SweepHook.
Args:
@@ -69,22 +71,24 @@ class _SweepHook(session_run_hook.SessionRunHook):
cache_init_ops: list of ops, to be run once before training, in the given
order. These are typically local initialization ops (such as cache
initialization).
+ completed_sweeps_var: An integer tf.Variable, indicates the number of
+ completed sweeps. It is updated by the hook.
"""
- # TODO(walidk): Provide a counter for the number of completed sweeps.
self._num_rows = num_rows
self._num_cols = num_cols
self._row_prep_ops = row_prep_ops
self._col_prep_ops = col_prep_ops
self._cache_init_ops = cache_init_ops
self._is_row_sweep_var = is_row_sweep_var
+ self._completed_sweeps_var = completed_sweeps_var
# Boolean variable that determines whether the cache_init_ops have been run.
self._is_initialized = False
# Boolean variable that is set to True when a sweep is completed.
# Used to run the prep_ops at the beginning of a sweep, in before_run().
self._is_sweep_done = False
# Ops to run jointly with train_op, responsible for updating
- # _is_row_sweep_var and incrementing the global_step counter. They have
- # control_dependencies on train_op.
+ # _is_row_sweep_var and incrementing the global_step and completed_sweeps
+ # counters. They have control_dependencies on train_op.
self._fetches = self._create_switch_ops(processed_row_indices,
processed_col_indices,
train_op)
@@ -93,7 +97,7 @@ class _SweepHook(session_run_hook.SessionRunHook):
processed_row_indices,
processed_col_indices,
train_op):
- """Creates ops to update is_row_sweep_var and to increment global_step.
+ """Creates ops to update is_row_sweep_var, global_step and completed_sweeps.
Creates two boolean tensors processed_rows and processed_cols, which keep
track of which rows/cols have been processed during the current sweep.
@@ -102,8 +106,9 @@ class _SweepHook(session_run_hook.SessionRunHook):
processed_rows[processed_row_indices] to True.
- When is_row_sweep_var is False, it sets
processed_cols[processed_col_indices] to True .
- When all rows or all cols have been processed, negates is_row_sweep_var and
- resets processed_rows and processed_cols to False.
+ When all rows or all cols have been processed, negates is_row_sweep_var,
+ increments the completed_sweeps counter, and resets processed_rows and
+ processed_cols to False.
All of the ops created by this function have control_dependencies on
train_op.
@@ -121,9 +126,10 @@ class _SweepHook(session_run_hook.SessionRunHook):
sweep) have been processed.
switch_ops: An op that updates is_row_sweep_var when is_sweep_done is
True. Has control_dependencies on train_op.
- global_step_incr_op: An op that increments the global_step counter. Has
- control_dependenciens on switch_ops.
+ incr_ops: An op that increments the global_step and completed_sweeps
+ counters. Has control_dependenciens on switch_ops.
"""
+
processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False)
with ops.colocate_with(processed_rows_init):
processed_rows = variable_scope.variable(
@@ -184,9 +190,16 @@ class _SweepHook(session_run_hook.SessionRunHook):
switch_ops = control_flow_ops.group(switch_op, reset_op,
name="sweep_hook_switch_ops")
- # Op to increment the global step
- global_step = framework_variables.get_global_step()
with ops.control_dependencies([switch_ops]):
+ # Op to increment the completed_sweeps counter.
+ completed_sweeps_incr_op = control_flow_ops.cond(
+ is_sweep_done,
+ lambda: state_ops.assign_add(self._completed_sweeps_var, 1).op,
+ control_flow_ops.no_op,
+ name="completed_sweeps_incr")
+
+ # Op to increment the global_step counter.
+ global_step = framework_variables.get_global_step()
if global_step is not None:
global_step_incr_op = state_ops.assign_add(
global_step, 1, name="global_step_incr").op
@@ -194,7 +207,11 @@ class _SweepHook(session_run_hook.SessionRunHook):
global_step_incr_op = control_flow_ops.no_op(
name="global_step_incr")
- return [is_sweep_done, switch_ops, global_step_incr_op]
+ incr_ops = control_flow_ops.group(
+ completed_sweeps_incr_op, global_step_incr_op,
+ name="counter_incr_ops")
+
+ return [is_sweep_done, switch_ops, incr_ops]
def begin(self):
pass
@@ -217,7 +234,7 @@ class _SweepHook(session_run_hook.SessionRunHook):
self._is_initialized = True
- # Request running the switch_ops and the global_step_incr_op
+ # Request running the switch_ops and the incr_ops
logging.info("Partial fit starting.")
return session_run_hook.SessionRunArgs(fetches=self._fetches)
@@ -226,6 +243,37 @@ class _SweepHook(session_run_hook.SessionRunHook):
logging.info("Partial fit done.")
+class _StopAtSweepHook(session_run_hook.SessionRunHook):
+ """Hook that requests stop at a given sweep."""
+
+ def __init__(self, last_sweep):
+ """Initializes a `StopAtSweepHook`.
+
+ This hook requests stop at a given sweep. Relies on the tensor named
+ COMPLETED_SWEEPS in the default graph.
+
+ Args:
+ last_sweep: Integer, number of the last sweep to run.
+ """
+ self._last_sweep = last_sweep
+
+ def begin(self):
+ try:
+ self._completed_sweeps_var = ops.get_default_graph().get_tensor_by_name(
+ WALSMatrixFactorization.COMPLETED_SWEEPS+":0")
+ except KeyError:
+ raise RuntimeError(WALSMatrixFactorization.COMPLETED_SWEEPS +
+ " counter should be created to use StopAtSweepHook.")
+
+ def before_run(self, run_context):
+ return session_run_hook.SessionRunArgs(self._completed_sweeps_var)
+
+ def after_run(self, run_context, run_values):
+ completed_sweeps = run_values.results
+ if completed_sweeps >= self._last_sweep:
+ run_context.request_stop()
+
+
def _wals_factorization_model_function(features, labels, mode, params):
"""Model function for the WALSFactorization estimator.
@@ -246,6 +294,7 @@ def _wals_factorization_model_function(features, labels, mode, params):
use_gramian_cache = (
params["use_gramian_cache_for_training"]
and mode == model_fn.ModeKeys.TRAIN)
+ max_sweeps = params["max_sweeps"]
model = factorization_ops.WALSModel(
params["num_rows"],
params["num_cols"],
@@ -282,8 +331,16 @@ def _wals_factorization_model_function(features, labels, mode, params):
# update_col_factors_op
is_row_sweep_var = variable_scope.variable(
- True, "is_row_sweep",
+ True,
+ trainable=False,
+ name="is_row_sweep",
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ completed_sweeps_var = variable_scope.variable(
+ 0,
+ trainable=False,
+ name=WALSMatrixFactorization.COMPLETED_SWEEPS,
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+
# The row sweep is determined by is_row_sweep_var (controlled by the
# sweep_hook) in TRAIN mode, and manually in EVAL mode.
is_row_sweep = (features[WALSMatrixFactorization.PROJECT_ROW]
@@ -312,7 +369,14 @@ def _wals_factorization_model_function(features, labels, mode, params):
row_prep_ops,
col_prep_ops,
cache_init_ops,
+ completed_sweeps_var,
)
+ training_hooks = [sweep_hook]
+ if max_sweeps is not None:
+ training_hooks.append(_StopAtSweepHook(max_sweeps))
+
+ summary.scalar("loss", loss)
+ summary.scalar("completed_sweeps", completed_sweeps_var)
# Prediction ops (only return predictions in INFER mode)
predictions = {}
@@ -341,7 +405,7 @@ def _wals_factorization_model_function(features, labels, mode, params):
loss=loss,
eval_metric_ops={},
train_op=train_op,
- training_hooks=[sweep_hook])
+ training_hooks=training_hooks)
class WALSMatrixFactorization(estimator.Estimator):
@@ -417,6 +481,8 @@ class WALSMatrixFactorization(estimator.Estimator):
PROJECTION_WEIGHTS = "projection_weights"
# Predictions key
PROJECTION_RESULT = "projection"
+ # Name of the completed_sweeps variable
+ COMPLETED_SWEEPS = "completed_sweeps"
def __init__(self,
num_rows,
@@ -432,6 +498,7 @@ class WALSMatrixFactorization(estimator.Estimator):
col_weights=1,
use_factors_weights_cache_for_training=True,
use_gramian_cache_for_training=True,
+ max_sweeps=None,
model_dir=None,
config=None):
"""Creates a model for matrix factorization using the WALS method.
@@ -471,6 +538,11 @@ class WALSMatrixFactorization(estimator.Estimator):
use_gramian_cache_for_training: Boolean, whether the Gramians will be
cached on the workers before the updates start, during training.
Defaults to True. Note that caching is disabled during prediction.
+ max_sweeps: integer, optional. Specifies the number of sweeps for which
+ to train the model, where a sweep is defined as a full update of all the
+ row factors (resp. column factors).
+ If `steps` or `max_steps` is also specified in model.fit(), training
+ stops when either of the steps condition or sweeps condition is met.
model_dir: The directory to save the model results and log files.
config: A Configuration object. See Estimator.
@@ -495,6 +567,7 @@ class WALSMatrixFactorization(estimator.Estimator):
"num_col_shards": num_col_shards,
"row_weights": row_weights,
"col_weights": col_weights,
+ "max_sweeps": max_sweeps,
"use_factors_weights_cache_for_training":
use_factors_weights_cache_for_training,
"use_gramian_cache_for_training": use_gramian_cache_for_training
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py
index 25e3ee8bae..b5c1bb1151 100644
--- a/tensorflow/contrib/factorization/python/ops/wals_test.py
+++ b/tensorflow/contrib/factorization/python/ops/wals_test.py
@@ -197,6 +197,10 @@ class WALSMatrixFactorizationTest(test.TestCase):
def use_cache(self):
return False
+ @property
+ def max_sweeps(self):
+ return None
+
def setUp(self):
self._num_rows = 5
self._num_cols = 7
@@ -245,6 +249,7 @@ class WALSMatrixFactorizationTest(test.TestCase):
num_col_shards=self._num_col_shards,
row_weights=self._row_weights,
col_weights=self._col_weights,
+ max_sweeps=self.max_sweeps,
use_factors_weights_cache_for_training=self.use_cache,
use_gramian_cache_for_training=self.use_cache)
@@ -356,6 +361,19 @@ class WALSMatrixFactorizationTest(test.TestCase):
loss = {}.""".format(loss, true_loss))
+class WALSMatrixFactorizationTestSweeps(WALSMatrixFactorizationTest):
+
+ @property
+ def max_sweeps(self):
+ return 2
+
+ # We set the column steps to None so that we rely only on max_sweeps to stop
+ # training.
+ @property
+ def col_steps(self):
+ return None
+
+
class WALSMatrixFactorizationTestCached(WALSMatrixFactorizationTest):
@property
@@ -421,6 +439,7 @@ class SweepHookTest(test.TestCase):
with self.test_session() as sess:
is_row_sweep_var = variables.Variable(True)
+ completed_sweeps_var = variables.Variable(0)
sweep_hook = wals_lib._SweepHook(
is_row_sweep_var,
self._train_op,
@@ -430,7 +449,8 @@ class SweepHookTest(test.TestCase):
self._input_col_indices_ph,
self._row_prep_ops,
self._col_prep_ops,
- self._init_ops)
+ self._init_ops,
+ completed_sweeps_var)
mon_sess = monitored_session._HookedSession(sess, [sweep_hook])
sess.run([variables.global_variables_initializer()])
@@ -447,6 +467,8 @@ class SweepHookTest(test.TestCase):
mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6]))
self.assertFalse(sess.run(is_row_sweep_var),
msg='Row sweep is complete but is_row_sweep is True.')
+ self.assertTrue(sess.run(completed_sweeps_var) == 1,
+ msg='Completed sweeps should be equal to 1.')
self.assertTrue(sweep_hook._is_sweep_done,
msg='Sweep is complete but is_sweep_done is False.')
# Col init ops should run. Col sweep not completed.
@@ -464,6 +486,28 @@ class SweepHookTest(test.TestCase):
msg='Col sweep is complete but is_row_sweep is False')
self.assertTrue(sweep_hook._is_sweep_done,
msg='Sweep is complete but is_sweep_done is False.')
+ self.assertTrue(sess.run(completed_sweeps_var) == 2,
+ msg='Completed sweeps should be equal to 2.')
+
+
+class StopAtSweepHookTest(test.TestCase):
+
+ def test_stop(self):
+ hook = wals_lib._StopAtSweepHook(last_sweep=10)
+ completed_sweeps = variables.Variable(
+ 8, name=wals_lib.WALSMatrixFactorization.COMPLETED_SWEEPS)
+ train_op = state_ops.assign_add(completed_sweeps, 1)
+ hook.begin()
+
+ with self.test_session() as sess:
+ sess.run([variables.global_variables_initializer()])
+ mon_sess = monitored_session._HookedSession(sess, [hook])
+ mon_sess.run(train_op)
+ # completed_sweeps is 9 after running train_op.
+ self.assertFalse(mon_sess.should_stop())
+ mon_sess.run(train_op)
+ # completed_sweeps is 10 after running train_op.
+ self.assertTrue(mon_sess.should_stop())
if __name__ == '__main__':