diff options
author    2018-06-07 18:21:25 -0700
committer 2018-06-07 18:25:53 -0700
commit  2f41346cbc0c8ecb915983a1f8711fd0d0ccc50e (patch)
tree    e08c7194092649b7a092c7efaaa04e7df8fd4942 /tensorflow/contrib/batching
parent  7b9c723c8f5f732f014ba181daf0b96747f291a9 (diff)
Changes the batch_function decorator implementation to use the newly added BatchFunction op.
o Renames the previous version to batch_function_v1.
PiperOrigin-RevId: 199729701
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r--  tensorflow/contrib/batching/__init__.py                  |  1
-rw-r--r--  tensorflow/contrib/batching/python/ops/batch_ops.py      | 69
-rw-r--r--  tensorflow/contrib/batching/python/ops/batch_ops_test.py | 50
3 files changed, 120 insertions, 0 deletions
diff --git a/tensorflow/contrib/batching/__init__.py b/tensorflow/contrib/batching/__init__.py index 44fa5f42a7..1e503a097a 100644 --- a/tensorflow/contrib/batching/__init__.py +++ b/tensorflow/contrib/batching/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """Ops and modules related to batch. +@@batch_function_v1 @@batch_function """ from __future__ import absolute_import diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 921d6917a4..012a51f711 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_batch_ops # go/tf-wildcard-import @@ -102,6 +103,74 @@ def batch_function(num_batch_threads, Returns: The decorated function will return the unbatched computation output Tensors. 
""" + + def decorator(fn): # pylint: disable=missing-docstring + + def decorated(*args): # pylint: disable=missing-docstring + types = [arg.dtype for arg in args] + + @function.Defun(*types) + def computation(*computation_args): + return fn(*computation_args) + + with ops.name_scope("batch") as name: + for a in args: + if not isinstance(a, ops.Tensor): + raise ValueError("All arguments to functions decorated with " + "`batch_function` are supposed to be Tensors; " + "found %s" % repr(a)) + for inp in computation.captured_inputs: + print("inp: %s" % inp) + for op in inp.consumers(): + print("op: %s" % op) + return gen_batch_ops.batch_function( + num_batch_threads=num_batch_threads, + max_batch_size=max_batch_size, + batch_timeout_micros=batch_timeout_micros, + allowed_batch_sizes=allowed_batch_sizes, + max_enqueued_batches=max_enqueued_batches, + shared_name=name, + f=computation, + in_tensors=list(args), + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + return decorated + + return decorator + + +def batch_function_v1(num_batch_threads, + max_batch_size, + batch_timeout_micros, + allowed_batch_sizes=None, + grad_timeout_micros=60 * 1000 * 1000, + unbatch_timeout_micros=60 * 1000 * 1000, + max_enqueued_batches=10): + """Batches the computation done by the decorated function. + + This is the older version of batch_function(). Please use the former instead + of this. + + Args: + num_batch_threads: Number of scheduling threads for processing batches + of work. Determines the number of batches processed in parallel. + max_batch_size: Batch sizes will never be bigger than this. + batch_timeout_micros: Maximum number of microseconds to wait before + outputting an incomplete batch. + allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, + does nothing. Otherwise, supplies a list of batch sizes, causing the op + to pad batches up to one of those sizes. 
The entries must increase + monotonically, and the final entry must equal max_batch_size. + grad_timeout_micros: The timeout to use for the gradient. See the + documentation of the unbatch op for more details. Defaults to 60s. + unbatch_timeout_micros: The timeout to use for unbatching. See the + documentation of the unbatch op for more details. Defaults to 60s. + max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10. + + Returns: + The decorated function will return the unbatched computation output Tensors. + """ def decorator(f): # pylint: disable=missing-docstring def decorated(*args): with ops.name_scope("batch") as name: diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index ea8339334f..7846814546 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -188,12 +188,62 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBasicUnbatchV1Decorated(self): + """Tests that the batch_function_v1 decorator works.""" + with self.test_session() as sess: + @batch_ops.batch_function_v1(1, 10, 100000) + def computation(in_t): + return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + def testBasicUnbatchDecorated(self): """Tests that the batch_function decorator works.""" with self.test_session() as sess: + # TODO(apassos): Removing this line causes test flakiness! Ideally should + # be investigated. 
+ default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable + @batch_ops.batch_function(1, 10, 100000) def computation(in_t): return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchDecoratedWithCapturedInput(self): + """Tests that the batch_function decorator works.""" + with self.test_session() as sess: + captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) + captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + return in_t + captured_inp0 - captured_inp1 + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) result = computation(inp) thread_results = [] |