aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/factorization/python/ops/wals_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-09 17:31:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-09 17:38:43 -0700
commit07d78ddeafe41bc0363ac92efd7ca8ea60478989 (patch)
tree31b41c3b2acc121e570a948e03967a8f94a528d9 /tensorflow/contrib/factorization/python/ops/wals_test.py
parent485cb179ea84c8de26263628510f930d07a98c4a (diff)
Removes the use of tf.cond in the SweepHook used in the WALSMatrixFactorization estimator, to prevent a rare but possible race condition.
PiperOrigin-RevId: 171612114
Diffstat (limited to 'tensorflow/contrib/factorization/python/ops/wals_test.py')
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals_test.py14
1 files changed, 5 insertions, 9 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py
index b5c1bb1151..8bd72b7025 100644
--- a/tensorflow/contrib/factorization/python/ops/wals_test.py
+++ b/tensorflow/contrib/factorization/python/ops/wals_test.py
@@ -357,7 +357,7 @@ class WALSMatrixFactorizationTest(test.TestCase):
self.assertNear(
loss, true_loss, err=.001,
- msg="""After row update, eval loss = {}, does not match the true
+ msg="""After col update, eval loss = {}, does not match the true
loss = {}.""".format(loss, true_loss))
@@ -442,7 +442,7 @@ class SweepHookTest(test.TestCase):
completed_sweeps_var = variables.Variable(0)
sweep_hook = wals_lib._SweepHook(
is_row_sweep_var,
- self._train_op,
+ [self._train_op],
self._num_rows,
self._num_cols,
self._input_row_indices_ph,
@@ -465,11 +465,9 @@ class SweepHookTest(test.TestCase):
'False.')
# Row sweep completed.
mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6]))
- self.assertFalse(sess.run(is_row_sweep_var),
- msg='Row sweep is complete but is_row_sweep is True.')
self.assertTrue(sess.run(completed_sweeps_var) == 1,
msg='Completed sweeps should be equal to 1.')
- self.assertTrue(sweep_hook._is_sweep_done,
+ self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
msg='Sweep is complete but is_sweep_done is False.')
# Col init ops should run. Col sweep not completed.
mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4]))
@@ -478,13 +476,11 @@ class SweepHookTest(test.TestCase):
self.assertFalse(sess.run(is_row_sweep_var),
msg='Col sweep is not complete but is_row_sweep is '
'True.')
- self.assertFalse(sweep_hook._is_sweep_done,
+ self.assertFalse(sess.run(sweep_hook._is_sweep_done_var),
msg='Sweep is not complete but is_sweep_done is True.')
# Col sweep completed.
mon_sess.run(self._train_op, ind_feed([], [4, 5, 6]))
- self.assertTrue(sess.run(is_row_sweep_var),
- msg='Col sweep is complete but is_row_sweep is False')
- self.assertTrue(sweep_hook._is_sweep_done,
+ self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
msg='Sweep is complete but is_sweep_done is False.')
self.assertTrue(sess.run(completed_sweeps_var) == 2,
msg='Completed sweeps should be equal to 2.')