aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py')
-rw-r--r--tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
new file mode 100644
index 0000000000..ad0724931e
--- /dev/null
+++ b/tensorflow/python/kernel_tests/dense_update_ops_no_tsan_test.py
@@ -0,0 +1,60 @@
+"""Tests for state updating ops that may have benign race conditions."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class AssignOpTest(tf.test.TestCase):
+
+ # NOTE(mrry): We exclude thess tests from the TSAN TAP target, because they
+ # contain benign and deliberate data races when multiple threads update
+ # the same parameters without a lock.
+ def testParallelUpdateWithoutLocking(self):
+ with self.test_session() as sess:
+ ones_t = tf.fill([1024, 1024], 1.0)
+ p = tf.Variable(tf.zeros([1024, 1024]))
+ adds = [tf.assign_add(p, ones_t, use_locking=False)
+ for _ in range(20)]
+ tf.initialize_all_variables().run()
+
+ def run_add(add_op):
+ sess.run(add_op)
+ threads = [self.checkedThread(target=run_add, args=(add_op,))
+ for add_op in adds]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ vals = p.eval()
+ ones = np.ones((1024, 1024)).astype(np.float32)
+ self.assertTrue((vals >= ones).all())
+ self.assertTrue((vals <= ones * 20).all())
+
+ def testParallelAssignWithoutLocking(self):
+ with self.test_session() as sess:
+ ones_t = tf.fill([1024, 1024], float(1))
+ p = tf.Variable(tf.zeros([1024, 1024]))
+ assigns = [tf.assign(p, tf.mul(ones_t, float(i)), False)
+ for i in range(1, 21)]
+ tf.initialize_all_variables().run()
+
+ def run_assign(assign_op):
+ sess.run(assign_op)
+ threads = [self.checkedThread(target=run_assign, args=(assign_op,))
+ for assign_op in assigns]
+ for t in threads:
+ t.start()
+ for t in threads:
+ t.join()
+
+ vals = p.eval()
+
+ # Assert every element is taken from one of the assignments.
+ self.assertTrue((vals > 0).all())
+ self.assertTrue((vals <= 20).all())
+
+
+if __name__ == "__main__":
+ tf.test.main()