aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-29 17:12:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-29 18:33:33 -0700
commit79d5e63091418d20b6ee01b5cc7e598a0606d691 (patch)
treeb6ef49d8071110c711a7aa3cc7aca93d0fd504ee /tensorflow
parent1f6c19e3a9131db4fa0353982b7bf44d61fe2ce8 (diff)
Add a SessionRunHook to manage synchronization of the row and column sweeps in the WALS solver.
Change: 151652662
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/factorization/BUILD21
-rw-r--r--tensorflow/contrib/factorization/__init__.py1
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals.py223
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals_test.py129
4 files changed, 374 insertions, 0 deletions
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index 338e67e953..aa26c6060f 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -23,6 +23,7 @@ tf_custom_op_py_library(
"python/ops/factorization_ops.py",
"python/ops/gmm.py",
"python/ops/gmm_ops.py",
+ "python/ops/wals.py",
],
dso = [
":python/ops/_clustering_ops.so",
@@ -176,6 +177,26 @@ tf_py_test(
],
)
+# Estimators tests
+tf_py_test(
+ name = "wals_test",
+ size = "medium",
+ srcs = ["python/ops/wals_test.py"],
+ additional_deps = [
+ ":factorization_py",
+ ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_benchmark",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:variables",
+ ],
+)
+
# Kernel tests
tf_py_test(
name = "wals_solver_ops_test",
diff --git a/tensorflow/contrib/factorization/__init__.py b/tensorflow/contrib/factorization/__init__.py
index 63243a7fdb..f0ca879259 100644
--- a/tensorflow/contrib/factorization/__init__.py
+++ b/tensorflow/contrib/factorization/__init__.py
@@ -23,4 +23,5 @@ from tensorflow.contrib.factorization.python.ops.clustering_ops import *
from tensorflow.contrib.factorization.python.ops.factorization_ops import *
from tensorflow.contrib.factorization.python.ops.gmm import *
from tensorflow.contrib.factorization.python.ops.gmm_ops import *
+from tensorflow.contrib.factorization.python.ops.wals import *
# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py
new file mode 100644
index 0000000000..3fd2cbbec2
--- /dev/null
+++ b/tensorflow/contrib/factorization/python/ops/wals.py
@@ -0,0 +1,223 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Weighted Alternating Least Squares (WALS) on the tf.learn API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.framework.python.ops import variables as framework_variables
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import session_run_hook
+
+
+class _SweepHook(session_run_hook.SessionRunHook):
+ """Keeps track of row/col sweeps, and runs prep ops before each sweep."""
+
+ def __init__(self,
+ is_row_sweep_var,
+ train_op,
+ num_rows,
+ num_cols,
+ processed_row_indices,
+ processed_col_indices,
+ row_prep_ops,
+ col_prep_ops,
+ cache_init_ops):
+ """Initializes SweepHook.
+
+ Args:
+ is_row_sweep_var: A Boolean tf.Variable, determines whether we are
+ currently doing a row or column sweep. It is updated by the hook.
+ train_op: An op. All the ops created by the hook will have
+ control_dependencies on train_op.
+ num_rows: int, the total number of rows to be processed.
+ num_cols: int, the total number of columns to be processed.
+ processed_row_indices: A Tensor of type int64. The indices of the input
+ rows that are processed during the current sweep. All elements of
+ processed_row_indices must be in [0, num_rows).
+ processed_col_indices: A Tensor of type int64. The indices of the input
+ columns that are processed during the current sweep. All elements of
+ processed_col_indices must be in [0, num_cols).
+ row_prep_ops: list of ops, to be run before the beginning of each row
+ sweep, in the given order.
+ col_prep_ops: list of ops, to be run before the beginning of each column
+ sweep, in the given order.
+ cache_init_ops: list of ops, to be run once before training, in the given
+ order. These are typically local initialization ops (such as cache
+ initialization).
+ """
+ self._num_rows = num_rows
+ self._num_cols = num_cols
+ self._row_prep_ops = row_prep_ops
+ self._col_prep_ops = col_prep_ops
+ self._cache_init_ops = cache_init_ops
+ self._is_row_sweep_var = is_row_sweep_var
+ # Python bool (not a tf.Variable): whether the cache_init_ops have been run.
+ self._is_initialized = False
+ # Python bool, refreshed in after_run() from the fetched is_sweep_done.
+ # Used to run the prep_ops at the beginning of a sweep, in before_run().
+ self._is_sweep_done = False
+ # Fetches requested in before_run(): is_sweep_done, the op that flips
+ # _is_row_sweep_var, and the op that increments the global_step counter.
+ # They have control_dependencies on train_op.
+ self._fetches = self._create_switch_ops(processed_row_indices,
+ processed_col_indices,
+ train_op)
+
+ def _create_switch_ops(self,
+ processed_row_indices,
+ processed_col_indices,
+ train_op):
+ """Creates ops to update is_row_sweep_var and to increment global_step.
+
+ Creates two boolean tensors processed_rows and processed_cols, which keep
+ track of which rows/cols have been processed during the current sweep.
+ Returns ops that should be run after each row / col update.
+ - When is_row_sweep_var is True, it sets
+ processed_rows[processed_row_indices] to True.
+ - When is_row_sweep_var is False, it sets
+ processed_cols[processed_col_indices] to True.
+ When all rows or all cols have been processed, negates is_row_sweep_var and
+ resets processed_rows and processed_cols to False.
+ All of the ops created by this function have control_dependencies on
+ train_op.
+
+ Args:
+ processed_row_indices: A Tensor. The indices of the input rows that are
+ processed during the current sweep.
+ processed_col_indices: A Tensor. The indices of the input columns that
+ are processed during the current sweep.
+ train_op: An op. All the ops created by this function have
+ control_dependencies on train_op.
+ Returns:
+ A list consisting of:
+ is_sweep_done: A Boolean tensor, determines whether the sweep is done,
+ i.e. all rows (during a row sweep) or all columns (during a column
+ sweep) have been processed.
+ switch_ops: An op that flips is_row_sweep_var and resets processed_rows /
+ processed_cols when is_sweep_done is True. Depends on train_op.
+ global_step_incr_op: An op that increments the global_step counter (a
+ no_op if there is no global step). Has control_dependencies on switch_ops.
+ """
+ processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False)
+ with ops.colocate_with(processed_rows_init):
+ processed_rows = variables.Variable(
+ processed_rows_init,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+ trainable=False,
+ name="sweep_hook_processed_rows")
+ processed_cols_init = array_ops.fill(dims=[self._num_cols], value=False)
+ with ops.colocate_with(processed_cols_init):
+ processed_cols = variables.Variable(
+ processed_cols_init,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES],
+ trainable=False,
+ name="sweep_hook_processed_cols")
+ # After running the train_op, update processed_rows or processed_cols
+ # tensors, depending on whether we are currently doing a row or a col sweep.
+ with ops.control_dependencies([train_op]):
+ def get_row_update_op():
+ with ops.colocate_with(processed_rows):
+ return state_ops.scatter_update(
+ processed_rows, processed_row_indices,
+ array_ops.ones_like(processed_row_indices, dtype=dtypes.bool))
+
+ def get_col_update_op():
+ with ops.colocate_with(processed_cols):
+ return state_ops.scatter_update(
+ processed_cols, processed_col_indices,
+ array_ops.ones_like(processed_col_indices, dtype=dtypes.bool))
+
+ update_processed_op = control_flow_ops.cond(
+ self._is_row_sweep_var, get_row_update_op, get_col_update_op)
+
+ # After update_processed_op, check whether we have completed a sweep.
+ # If this is the case, flip the is_row_sweep_var and reset processed_rows
+ # and processed_cols tensors.
+ with ops.control_dependencies([update_processed_op]):
+ def get_switch_op():
+ return state_ops.assign(
+ self._is_row_sweep_var,
+ gen_math_ops.logical_not(self._is_row_sweep_var)).op
+
+ def get_reset_op():
+ return control_flow_ops.group(
+ state_ops.assign(processed_rows, processed_rows_init).op,
+ state_ops.assign(processed_cols, processed_cols_init).op)
+
+ is_sweep_done = control_flow_ops.cond(
+ self._is_row_sweep_var,
+ lambda: math_ops.reduce_all(processed_rows),
+ lambda: math_ops.reduce_all(processed_cols),
+ name="sweep_hook_is_sweep_done")
+ switch_op = control_flow_ops.cond(
+ is_sweep_done, get_switch_op, control_flow_ops.no_op,
+ name="sweep_hook_switch_op")
+ reset_op = control_flow_ops.cond(
+ is_sweep_done, get_reset_op, control_flow_ops.no_op,
+ name="sweep_hook_reset_op")
+ switch_ops = control_flow_ops.group(switch_op, reset_op,
+ name="sweep_hook_switch_ops")
+
+ # Op to increment the global step (a no_op if no global step exists).
+ global_step = framework_variables.get_global_step()
+ with ops.control_dependencies([switch_ops]):
+ if global_step is not None:
+ global_step_incr_op = state_ops.assign_add(
+ global_step, 1, name="global_step_incr").op
+ else:
+ global_step_incr_op = control_flow_ops.no_op(
+ name="global_step_incr")
+
+ return [is_sweep_done, switch_ops, global_step_incr_op]
+
+ def begin(self):
+ pass  # Nothing to set up here; the hook's ops are created in __init__.
+
+ def before_run(self, run_context):
+ """Runs the appropriate prep ops, and requests running update ops."""
+ # Run cache_init ops on the first call; run prep ops at each sweep start.
+ sess = run_context.session
+ if not self._is_initialized:
+ logging.info("SweepHook running cache init ops.")
+ for init_op in self._cache_init_ops:
+ sess.run(init_op)
+
+ if self._is_sweep_done or not self._is_initialized:
+ logging.info("SweepHook running sweep prep ops.")
+ row_sweep = sess.run(self._is_row_sweep_var)
+ prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops
+ for prep_op in prep_ops:
+ sess.run(prep_op)
+
+ self._is_initialized = True
+
+ # Request is_sweep_done, the switch_ops and the global_step_incr_op.
+ logging.info("Partial fit starting.")
+ return session_run_hook.SessionRunArgs(fetches=self._fetches)
+
+ def after_run(self, run_context, run_values):
+ self._is_sweep_done = run_values.results[0]  # The fetched is_sweep_done.
+ logging.info("Partial fit done.")
+
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py
new file mode 100644
index 0000000000..2ae2d3ab05
--- /dev/null
+++ b/tensorflow/contrib/factorization/python/ops/wals_test.py
@@ -0,0 +1,129 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for WALSMatrixFactorization."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.factorization.python.ops import wals as wals_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import session_run_hook
+
+
+class SweepHookTest(test.TestCase):
+
+ def setUp(self):
+ self._num_rows = 5
+ self._num_cols = 7
+ self._train_op = control_flow_ops.no_op()
+ self._row_prep_done = variables.Variable(False)
+ self._col_prep_done = variables.Variable(False)
+ self._init_done = variables.Variable(False)
+ self._row_prep_ops = [state_ops.assign(self._row_prep_done, True)]
+ self._col_prep_ops = [state_ops.assign(self._col_prep_done, True)]
+ self._init_ops = [state_ops.assign(self._init_done, True)]
+ self._input_row_indices_ph = array_ops.placeholder(dtypes.int64)
+ self._input_col_indices_ph = array_ops.placeholder(dtypes.int64)
+
+ def run_hook_with_indices(self, sweep_hook, row_indices, col_indices):
+ with self.test_session() as sess:
+ # Before run: the hook runs its init/prep ops and returns its fetches.
+ run_context = session_run_hook.SessionRunContext(
+ original_args=None, session=sess)
+ sess_run_args = sweep_hook.before_run(run_context)
+ feed_dict = {
+ self._input_row_indices_ph: row_indices,
+ self._input_col_indices_ph: col_indices
+ }
+ # Run the hook's fetches, feeding the processed row/col indices.
+ run_results = sess.run(sess_run_args.fetches, feed_dict=feed_dict)
+ run_values = session_run_hook.SessionRunValues(
+ results=run_results, options=None, run_metadata=None)
+ # After run: pass the fetched values back to the hook.
+ sweep_hook.after_run(run_context, run_values)
+
+ def test_row_sweep(self):
+ with self.test_session() as sess:
+ is_row_sweep_var = variables.Variable(True)
+ sweep_hook = wals_lib._SweepHook(
+ is_row_sweep_var,
+ self._train_op,
+ self._num_rows,
+ self._num_cols,
+ self._input_row_indices_ph,
+ self._input_col_indices_ph,
+ self._row_prep_ops,
+ self._col_prep_ops,
+ self._init_ops)
+
+ # Initialize variables
+ sess.run([variables.global_variables_initializer()])
+ # Row sweep: the first step feeds no indices but triggers init and row prep.
+ self.run_hook_with_indices(sweep_hook, [], [])
+ self.assertTrue(sess.run(self._init_done),
+ msg='init ops not run by the sweep_hook')
+ self.assertTrue(sess.run(self._row_prep_done),
+ msg='row_prep not run by the sweep_hook')
+ self.run_hook_with_indices(sweep_hook, [0, 1, 2], [])
+ self.assertTrue(sess.run(is_row_sweep_var),
+ msg='Row sweep is not complete but is_row_sweep is '
+ 'False.')
+ self.run_hook_with_indices(sweep_hook, [3, 4], [0, 1, 2, 3, 4, 5, 6])
+ self.assertFalse(sess.run(is_row_sweep_var),
+ msg='Row sweep is complete but is_row_sweep is True.')
+ self.assertTrue(sweep_hook._is_sweep_done,
+ msg='Sweep is complete but is_sweep_done is False.')
+
+ def test_col_sweep(self):
+ with self.test_session() as sess:
+ is_row_sweep_var = variables.Variable(False)
+ sweep_hook = wals_lib._SweepHook(
+ is_row_sweep_var,
+ self._train_op,
+ self._num_rows,
+ self._num_cols,
+ self._input_row_indices_ph,
+ self._input_col_indices_ph,
+ self._row_prep_ops,
+ self._col_prep_ops,
+ self._init_ops)
+
+ # Initialize variables
+ sess.run([variables.global_variables_initializer()])
+ # Col sweep: the first step feeds no indices but triggers col prep.
+ self.run_hook_with_indices(sweep_hook, [], [])
+ self.assertTrue(sess.run(self._col_prep_done),
+ msg='col_prep not run by the sweep_hook')
+ self.run_hook_with_indices(sweep_hook, [], [0, 1, 2, 3, 4])
+ self.assertFalse(sess.run(is_row_sweep_var),
+ msg='Col sweep is not complete but is_row_sweep is '
+ 'True.')
+ self.assertFalse(sweep_hook._is_sweep_done,
+ msg='Sweep is not complete but is_sweep_done is True.')
+ self.run_hook_with_indices(sweep_hook, [], [4, 5, 6])
+ self.assertTrue(sess.run(is_row_sweep_var),
+ msg='Col sweep is complete but is_row_sweep is False')
+ self.assertTrue(sweep_hook._is_sweep_done,
+ msg='Sweep is complete but is_sweep_done is False.')
+
+
+if __name__ == '__main__':
+ test.main()  # Entry point when this test file is executed directly.