diff options
author | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-06-12 15:57:06 +0800 |
---|---|---|
committer | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-06-12 16:11:14 +0800 |
commit | a5840964e0eb422fbe73dd3738c8d14c1147276f (patch) | |
tree | 1eb723dc968ffa07cd987000884b451f9f24a140 /tensorflow/python/keras/backend.py | |
parent | ffc5c4e845d8bfb36e6c56d904cba3bd8e1de94e (diff) |
ENH: support run options for function
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r-- | tensorflow/python/keras/backend.py | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 2a4a1c861c..8abdd5238a 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -2759,7 +2759,8 @@ class Function(object): outputs: Output tensors to fetch. updates: Additional update ops to be run at function call. name: A name to help users identify what this function does. - session_kwargs: Arguments to `tf.Session.run()`: `fetches`, `feed_dict`. + session_kwargs: Arguments to `tf.Session.run()`: + `fetches`, `feed_dict`, `options`. """ def __init__(self, inputs, outputs, updates=None, name=None, @@ -2793,6 +2794,7 @@ class Function(object): self.fetches = session_kwargs.pop('fetches', []) if not isinstance(self.fetches, list): self.fetches = [self.fetches] + self.run_options = session_kwargs.pop('options', None) # The main use case of `fetches` being passed to a model is the ability # to run custom updates (since the outputs of fetches are never returned). # This requires us to wrap fetches in `identity` ops. @@ -2844,6 +2846,9 @@ class Function(object): callable_opts.fetch.append(x.name) # Handle updates. callable_opts.target.append(self.updates_op.name) + # Handle run_options. + if self.run_options: + callable_opts.run_options.CopyFrom(self.run_options) # Create callable. callable_fn = session._make_callable_from_options(callable_opts) # Cache parameters corresponding to the generated callable, so that |