diff options
author | 2016-05-10 15:04:34 -0800 | |
---|---|---|
committer | 2016-05-10 16:12:01 -0700 | |
commit | c4520616b388dbb97427a936ee35e26d006248ec (patch) | |
tree | 9fcb0376926d704eab979707e502cb1497d0c4f9 /tensorflow/python/training/coordinator.py | |
parent | 39e7b4bb6b34c949ea2df825dfb18a694d8b41f4 (diff) |
Add optional kwargs argument to LooperThread and Supervisor.loop().
This is more in line with the Python threading.Thread() paradigm that allows
passing both 'args' and 'kwargs'.
Add tests for LooperThread.
Change: 121998795
Diffstat (limited to 'tensorflow/python/training/coordinator.py')
-rw-r--r-- | tensorflow/python/training/coordinator.py | 23 |
1 files changed, 13 insertions, 10 deletions
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py index b8a83af985..4b8e4b638f 100644 --- a/tensorflow/python/training/coordinator.py +++ b/tensorflow/python/training/coordinator.py @@ -343,7 +343,8 @@ class LooperThread(threading.Thread): You typically pass looper threads to the supervisor `Join()` method. """ - def __init__(self, coord, timer_interval_secs, target=None, args=None): + def __init__(self, coord, timer_interval_secs, target=None, args=None, + kwargs=None): """Create a LooperThread. Args: @@ -352,6 +353,7 @@ class LooperThread(threading.Thread): if it should be called back to back. target: Optional callable object that will be executed in the thread. args: Optional arguments to pass to `target` when calling it. + kwargs: Optional keyword arguments to pass to `target` when calling it. Raises: ValueError: If one of the arguments is invalid. @@ -364,15 +366,14 @@ class LooperThread(threading.Thread): self._timer_interval_secs = timer_interval_secs self._target = target if self._target: - if args is None: - self._args = () - else: - self._args = args - elif args: - raise ValueError("'args' argument require that you also pass 'target'") + self._args = args or () + self._kwargs = kwargs or {} + elif args or kwargs: + raise ValueError("'args' and 'kwargs' argument require that you also " + "pass 'target'") @staticmethod - def loop(coord, timer_interval_secs, target, args=None): + def loop(coord, timer_interval_secs, target, args=None, kwargs=None): """Start a LooperThread that calls a function periodically. If `timer_interval_secs` is None the thread calls `target(args)` @@ -385,11 +386,13 @@ class LooperThread(threading.Thread): timer_interval_secs: Number. Time boundaries at which to call `target`. target: A callable object. args: Optional arguments to pass to `target` when calling it. + kwargs: Optional keyword arguments to pass to `target` when calling it. Returns: The started thread. """ - looper = LooperThread(coord, timer_interval_secs, target=target, args=args) + looper = LooperThread(coord, timer_interval_secs, target=target, args=args, + kwargs=kwargs) looper.start() return looper @@ -419,4 +422,4 @@ class LooperThread(threading.Thread): def run_loop(self): """Called at 'timer_interval_secs' boundaries.""" if self._target: - self._target(*self._args) + self._target(*self._args, **self._kwargs) |