aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/batching
diff options
context:
space:
mode:
authorGravatar Vinu Rajashekhar <vinuraja@google.com>2018-06-07 18:21:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-07 18:25:53 -0700
commit2f41346cbc0c8ecb915983a1f8711fd0d0ccc50e (patch)
treee08c7194092649b7a092c7efaaa04e7df8fd4942 /tensorflow/contrib/batching
parent7b9c723c8f5f732f014ba181daf0b96747f291a9 (diff)
Changes the batch_function decorator implementation to use the newly added BatchFunction op.
o Renames the previous version to batch_function_v1. PiperOrigin-RevId: 199729701
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r--tensorflow/contrib/batching/__init__.py1
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops.py69
-rw-r--r--tensorflow/contrib/batching/python/ops/batch_ops_test.py50
3 files changed, 120 insertions, 0 deletions
diff --git a/tensorflow/contrib/batching/__init__.py b/tensorflow/contrib/batching/__init__.py
index 44fa5f42a7..1e503a097a 100644
--- a/tensorflow/contrib/batching/__init__.py
+++ b/tensorflow/contrib/batching/__init__.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Ops and modules related to batch.
+@@batch_function_v1
@@batch_function
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py
index 921d6917a4..012a51f711 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_batch_ops
# go/tf-wildcard-import
@@ -102,6 +103,74 @@ def batch_function(num_batch_threads,
Returns:
The decorated function will return the unbatched computation output Tensors.
"""
+
+ def decorator(fn): # pylint: disable=missing-docstring
+
+ def decorated(*args): # pylint: disable=missing-docstring
+ types = [arg.dtype for arg in args]
+
+ @function.Defun(*types)
+ def computation(*computation_args):
+ return fn(*computation_args)
+
+ with ops.name_scope("batch") as name:
+ for a in args:
+ if not isinstance(a, ops.Tensor):
+ raise ValueError("All arguments to functions decorated with "
+ "`batch_function` are supposed to be Tensors; "
+ "found %s" % repr(a))
+ for inp in computation.captured_inputs:
+ print("inp: %s" % inp)
+ for op in inp.consumers():
+ print("op: %s" % op)
+ return gen_batch_ops.batch_function(
+ num_batch_threads=num_batch_threads,
+ max_batch_size=max_batch_size,
+ batch_timeout_micros=batch_timeout_micros,
+ allowed_batch_sizes=allowed_batch_sizes,
+ max_enqueued_batches=max_enqueued_batches,
+ shared_name=name,
+ f=computation,
+ in_tensors=list(args),
+ captured_tensors=computation.captured_inputs,
+ Tout=[o.type for o in computation.definition.signature.output_arg])
+
+ return decorated
+
+ return decorator
+
+
+def batch_function_v1(num_batch_threads,
+ max_batch_size,
+ batch_timeout_micros,
+ allowed_batch_sizes=None,
+ grad_timeout_micros=60 * 1000 * 1000,
+ unbatch_timeout_micros=60 * 1000 * 1000,
+ max_enqueued_batches=10):
+ """Batches the computation done by the decorated function.
+
+ This is the older version of batch_function(). Please use the former instead
+ of this.
+
+ Args:
+ num_batch_threads: Number of scheduling threads for processing batches
+ of work. Determines the number of batches processed in parallel.
+ max_batch_size: Batch sizes will never be bigger than this.
+ batch_timeout_micros: Maximum number of microseconds to wait before
+ outputting an incomplete batch.
+ allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
+ does nothing. Otherwise, supplies a list of batch sizes, causing the op
+ to pad batches up to one of those sizes. The entries must increase
+ monotonically, and the final entry must equal max_batch_size.
+ grad_timeout_micros: The timeout to use for the gradient. See the
+ documentation of the unbatch op for more details. Defaults to 60s.
+ unbatch_timeout_micros: The timeout to use for unbatching. See the
+ documentation of the unbatch op for more details. Defaults to 60s.
+ max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
+
+ Returns:
+ The decorated function will return the unbatched computation output Tensors.
+ """
def decorator(f): # pylint: disable=missing-docstring
def decorated(*args):
with ops.name_scope("batch") as name:
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
index ea8339334f..7846814546 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py
@@ -188,12 +188,62 @@ class BatchOpsTest(test.TestCase):
self.assertEqual(thread_results[0], [2])
self.assertEqual(main_results[0], [3])
+ def testBasicUnbatchV1Decorated(self):
+ """Tests that the batch_function_v1 decorator works."""
+ with self.test_session() as sess:
+ @batch_ops.batch_function_v1(1, 10, 100000)
+ def computation(in_t):
+ return in_t + 1
+
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+ result = computation(inp)
+ thread_results = []
+
+ def worker():
+ thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
+
+ worker_thread = threading.Thread(target=worker)
+ worker_thread.start()
+ main_results = sess.run([result], feed_dict={inp: [2]})
+ worker_thread.join()
+ self.assertEqual(thread_results[0], [2])
+ self.assertEqual(main_results[0], [3])
+
def testBasicUnbatchDecorated(self):
"""Tests that the batch_function decorator works."""
with self.test_session() as sess:
+ # TODO(apassos): Removing this line causes test flakiness! Ideally should
+ # be investigated.
+ default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable
+
@batch_ops.batch_function(1, 10, 100000)
def computation(in_t):
return in_t + 1
+
+ inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
+ result = computation(inp)
+ thread_results = []
+
+ def worker():
+ thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
+
+ worker_thread = threading.Thread(target=worker)
+ worker_thread.start()
+ main_results = sess.run([result], feed_dict={inp: [2]})
+ worker_thread.join()
+ self.assertEqual(thread_results[0], [2])
+ self.assertEqual(main_results[0], [3])
+
+ def testBatchDecoratedWithCapturedInput(self):
+ """Tests that the batch_function decorator works."""
+ with self.test_session() as sess:
+ captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
+ captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
+
+ @batch_ops.batch_function(1, 10, 100000)
+ def computation(in_t):
+ return in_t + captured_inp0 - captured_inp1
+
inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
result = computation(inp)
thread_results = []