aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/fifo_queue_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tests/fifo_queue_test.py')
-rw-r--r--tensorflow/compiler/tests/fifo_queue_test.py201
1 files changed, 201 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/fifo_queue_test.py b/tensorflow/compiler/tests/fifo_queue_test.py
new file mode 100644
index 0000000000..0f64cc87cd
--- /dev/null
+++ b/tensorflow/compiler/tests/fifo_queue_test.py
@@ -0,0 +1,201 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.platform import test
+
+
+class FIFOQueueTest(xla_test.XLATestCase):
+
+ def testEnqueue(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ enqueue_op = q.enqueue((10.0,))
+ enqueue_op.run()
+
+ def testEnqueueWithShape(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
+ enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
+ enqueue_correct_op.run()
+ with self.assertRaises(ValueError):
+ q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
+ self.assertEqual(1, q.size().eval())
+
+ def testMultipleDequeues(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue([1]))
+ self.evaluate(q.enqueue([2]))
+ self.evaluate(q.enqueue([3]))
+ a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
+ self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))
+
+ def testQueuesDontShare(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q.enqueue(1))
+ q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
+ self.evaluate(q2.enqueue(2))
+ self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
+ self.assertAllEqual(self.evaluate(q.dequeue()), 1)
+
+ def testEnqueueDictWithoutNames(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ with self.assertRaisesRegexp(ValueError, "must have names"):
+ q.enqueue({"a": 12.0})
+
+ def testParallelEnqueue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ # Run one producer thread for each element in elems.
+ def enqueue(enqueue_op):
+ sess.run(enqueue_op)
+
+ threads = [
+ self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops
+ ]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ # Dequeue every element using a single thread.
+ results = []
+ for _ in xrange(len(elems)):
+ results.append(dequeued_t.eval())
+ self.assertItemsEqual(elems, results)
+
+ def testParallelDequeue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ # Enqueue every element using a single thread.
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ # Run one consumer thread for each element in elems.
+ results = []
+
+ def dequeue():
+ results.append(sess.run(dequeued_t))
+
+ threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ self.assertItemsEqual(elems, results)
+
+ def testDequeue(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ for i in xrange(len(elems)):
+ vals = dequeued_t.eval()
+ self.assertEqual([elems[i]], vals)
+
+ def testEnqueueAndBlockingDequeue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
+ elems = [10.0, 20.0, 30.0]
+ enqueue_ops = [q.enqueue((x,)) for x in elems]
+ dequeued_t = q.dequeue()
+
+ def enqueue():
+ # The enqueue_ops should run after the dequeue op has blocked.
+ # TODO(mrry): Figure out how to do this without sleeping.
+ time.sleep(0.1)
+ for enqueue_op in enqueue_ops:
+ sess.run(enqueue_op)
+
+ results = []
+
+ def dequeue():
+ for _ in xrange(len(elems)):
+ results.append(sess.run(dequeued_t))
+
+ enqueue_thread = self.checkedThread(target=enqueue)
+ dequeue_thread = self.checkedThread(target=dequeue)
+ enqueue_thread.start()
+ dequeue_thread.start()
+ enqueue_thread.join()
+ dequeue_thread.join()
+
+ for elem, result in zip(elems, results):
+ self.assertEqual([elem], result)
+
+ def testMultiEnqueueAndDequeue(self):
+ with self.test_session() as sess, self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
+ elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
+ enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
+ dequeued_t = q.dequeue()
+
+ for enqueue_op in enqueue_ops:
+ enqueue_op.run()
+
+ for i in xrange(len(elems)):
+ x_val, y_val = sess.run(dequeued_t)
+ x, y = elems[i]
+ self.assertEqual([x], x_val)
+ self.assertEqual([y], y_val)
+
+ def testQueueSizeEmpty(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ self.assertEqual([0], q.size().eval())
+
+ def testQueueSizeAfterEnqueueAndDequeue(self):
+ with self.test_session(), self.test_scope():
+ q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
+ enqueue_op = q.enqueue((10.0,))
+ dequeued_t = q.dequeue()
+ size = q.size()
+ self.assertEqual([], size.get_shape())
+
+ enqueue_op.run()
+ self.assertEqual(1, size.eval())
+ dequeued_t.op.run()
+ self.assertEqual(0, size.eval())
+
+
+if __name__ == "__main__":
+ test.main()