aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/coordinator.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-05-10 15:04:34 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-10 16:12:01 -0700
commitc4520616b388dbb97427a936ee35e26d006248ec (patch)
tree9fcb0376926d704eab979707e502cb1497d0c4f9 /tensorflow/python/training/coordinator.py
parent39e7b4bb6b34c949ea2df825dfb18a694d8b41f4 (diff)
Add optional kwargs argument to LooperThread and Supervisor.loop().
This is more in line with the Python threading.Thread() paradigm that allows passing both 'args' and 'kwargs'. Add tests for LooperThread. Change: 121998795
Diffstat (limited to 'tensorflow/python/training/coordinator.py')
-rw-r--r--tensorflow/python/training/coordinator.py23
1 files changed, 13 insertions, 10 deletions
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py
index b8a83af985..4b8e4b638f 100644
--- a/tensorflow/python/training/coordinator.py
+++ b/tensorflow/python/training/coordinator.py
@@ -343,7 +343,8 @@ class LooperThread(threading.Thread):
You typically pass looper threads to the supervisor `Join()` method.
"""
- def __init__(self, coord, timer_interval_secs, target=None, args=None):
+ def __init__(self, coord, timer_interval_secs, target=None, args=None,
+ kwargs=None):
"""Create a LooperThread.
Args:
@@ -352,6 +353,7 @@ class LooperThread(threading.Thread):
if it should be called back to back.
target: Optional callable object that will be executed in the thread.
args: Optional arguments to pass to `target` when calling it.
+ kwargs: Optional keyword arguments to pass to `target` when calling it.
Raises:
ValueError: If one of the arguments is invalid.
@@ -364,15 +366,14 @@ class LooperThread(threading.Thread):
self._timer_interval_secs = timer_interval_secs
self._target = target
if self._target:
- if args is None:
- self._args = ()
- else:
- self._args = args
- elif args:
- raise ValueError("'args' argument require that you also pass 'target'")
+ self._args = args or ()
+ self._kwargs = kwargs or {}
+ elif args or kwargs:
+ raise ValueError("'args' and 'kwargs' argument require that you also "
+ "pass 'target'")
@staticmethod
- def loop(coord, timer_interval_secs, target, args=None):
+ def loop(coord, timer_interval_secs, target, args=None, kwargs=None):
"""Start a LooperThread that calls a function periodically.
If `timer_interval_secs` is None the thread calls `target(args)`
@@ -385,11 +386,13 @@ class LooperThread(threading.Thread):
timer_interval_secs: Number. Time boundaries at which to call `target`.
target: A callable object.
args: Optional arguments to pass to `target` when calling it.
+ kwargs: Optional keyword arguments to pass to `target` when calling it.
Returns:
The started thread.
"""
- looper = LooperThread(coord, timer_interval_secs, target=target, args=args)
+ looper = LooperThread(coord, timer_interval_secs, target=target, args=args,
+ kwargs=kwargs)
looper.start()
return looper
@@ -419,4 +422,4 @@ class LooperThread(threading.Thread):
def run_loop(self):
"""Called at 'timer_interval_secs' boundaries."""
if self._target:
- self._target(*self._args)
+ self._target(*self._args, **self._kwargs)