aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/backend.py
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-06-12 15:57:06 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-06-12 16:11:14 +0800
commita5840964e0eb422fbe73dd3738c8d14c1147276f (patch)
tree1eb723dc968ffa07cd987000884b451f9f24a140 /tensorflow/python/keras/backend.py
parentffc5c4e845d8bfb36e6c56d904cba3bd8e1de94e (diff)
ENH: support run options for function
Diffstat (limited to 'tensorflow/python/keras/backend.py')
-rw-r--r--tensorflow/python/keras/backend.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 2a4a1c861c..8abdd5238a 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -2759,7 +2759,8 @@ class Function(object):
outputs: Output tensors to fetch.
updates: Additional update ops to be run at function call.
name: A name to help users identify what this function does.
- session_kwargs: Arguments to `tf.Session.run()`: `fetches`, `feed_dict`.
+ session_kwargs: Arguments to `tf.Session.run()`:
+ `fetches`, `feed_dict`, `options`.
"""
def __init__(self, inputs, outputs, updates=None, name=None,
@@ -2793,6 +2794,7 @@ class Function(object):
self.fetches = session_kwargs.pop('fetches', [])
if not isinstance(self.fetches, list):
self.fetches = [self.fetches]
+ self.run_options = session_kwargs.pop('options', None)
# The main use case of `fetches` being passed to a model is the ability
# to run custom updates (since the outputs of fetches are never returned).
# This requires us to wrap fetches in `identity` ops.
@@ -2844,6 +2846,9 @@ class Function(object):
callable_opts.fetch.append(x.name)
# Handle updates.
callable_opts.target.append(self.updates_op.name)
+ # Handle run_options.
+ if self.run_options:
+ callable_opts.run_options.CopyFrom(self.run_options)
# Create callable.
callable_fn = session._make_callable_from_options(callable_opts)
# Cache parameters corresponding to the generated callable, so that