aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/batching
diff options
context:
space:
mode:
authorGravatar Vinu Rajashekhar <vinuraja@google.com>2018-06-01 15:44:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-01 15:47:13 -0700
commit25486ef05d59265b769684589b738636b3207cc7 (patch)
tree4ea6b51a76542aa43522ebbc798153aa0c369267 /tensorflow/contrib/batching
parent2d71691dad337c4e7a6b5dbf18fd0ab0e6bd7cf6 (diff)
Adds a batch-op implemented using TF functions.
o This has a couple of important advantages over the current implementation: 1. The existing batch-op waits for the batch to be created and then forwards the tensors to the rest of the graph, which causes a lot of batches to be created, because there is no way for the op to know if the other batches are being queued up. A mitigation, which we have seen working in practice, is to actually wait for the graph to finish processing the batch. So there is a sort of flow-control happening, and meanwhile the batches get coalesced, which improves latency and throughput as well. Using functions makes this kind of approach easier. 2. The existing op passes empty tensors around the graph to make the TF executor happy, which has sometimes worked not well with some Ops (like Reshape). Using functions means that we don't need to rely on this mechanism as well. PiperOrigin-RevId: 198937594
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops_test.py87
1 file changed, 87 insertions, 0 deletions
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index e22f978dde..68e8a88ca0 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
@@ -23,7 +23,9 @@ import time
from tensorflow.contrib.batching.python.ops import batch_ops
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_batch_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
@@ -205,6 +207,91 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
+ def testBatchFunctionOp(self):
+    """Tests that the batch_function op works."""
+ with self.test_session() as sess:
+
+ @function.Defun(dtypes.int32)
+ def computation(in_t):
+ return in_t + 1
+
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+ result = gen_batch_ops.batch_function(
+ [inp],
+ num_batch_threads=1,
+ max_batch_size=10,
+ batch_timeout_micros=100000,
+ Tout=[dtypes.int32],
+ f=computation,
+ captured_tensors=computation.captured_inputs)
+ thread_results = []
+
+ def worker():
+ thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
+
+ worker_thread = threading.Thread(target=worker)
+ worker_thread.start()
+ main_results = sess.run([result], feed_dict={inp: [2]})
+ worker_thread.join()
+ self.assertEqual(thread_results[0], [2])
+ self.assertEqual(main_results[0], [3])
+
+ def testBatchFunctionOpWithCapturedInput(self):
+    """Tests that the batch_function op works with captured inputs."""
+ with self.test_session() as sess:
+ captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
+ captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+
+ @function.Defun(dtypes.int32)
+ def computation(inp):
+ return inp + captured_inp0 - captured_inp1
+
+ result = gen_batch_ops.batch_function(
+ num_batch_threads=1,
+ max_batch_size=10,
+ batch_timeout_micros=100000, # 100ms
+ allowed_batch_sizes=[3, 10],
+ batching_queue="",
+ f=computation,
+ in_tensors=[inp],
+ captured_tensors=computation.captured_inputs,
+ Tout=[o.type for o in computation.definition.signature.output_arg])
+
+ thread_results = []
+
+ def worker():
+ thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
+
+ worker_thread = threading.Thread(target=worker)
+ worker_thread.start()
+ main_results = sess.run([result], feed_dict={inp: [2]})
+ worker_thread.join()
+ self.assertEqual(thread_results[0], [2])
+ self.assertEqual(main_results[0], [3])
+
+ def testBasicUnbatchDecoratedWithReshape(self):
+    """Tests that the batch_function decorator works with a Reshape op."""
+ with self.test_session() as sess:
+
+ @batch_ops.batch_function(1, 10, 100000)
+ def computation(in_t):
+ return array_ops.reshape(in_t, [-1]) + 1
+
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1])
+ result = computation(inp)
+ thread_results = []
+
+ def worker():
+ thread_results.extend(sess.run([result], feed_dict={inp: [[1]]}))
+
+ worker_thread = threading.Thread(target=worker)
+ worker_thread.start()
+ main_results = sess.run([result], feed_dict={inp: [[2]]})
+ worker_thread.join()
+ self.assertEqual(thread_results[0], [2])
+ self.assertEqual(main_results[0], [3])
+
def testUnbatchTimeout(self):
"""Tests that the unbatch timeout works."""
with self.test_session() as sess: