aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/debug/wrappers/framework.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/debug/wrappers/framework.py')
-rw-r--r--tensorflow/python/debug/wrappers/framework.py87
1 files changed, 68 insertions, 19 deletions
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py
index c530204bbf..b9524ce649 100644
--- a/tensorflow/python/debug/wrappers/framework.py
+++ b/tensorflow/python/debug/wrappers/framework.py
@@ -392,6 +392,9 @@ class BaseDebugWrapperSession(session.SessionInterface):
self._default_session_context_manager = None
+ # A cache for callables created from CallableOptions.
+ self._cached_callables_from_options = dict()
+
@property
def graph(self):
return self._sess.graph
@@ -414,7 +417,8 @@ class BaseDebugWrapperSession(session.SessionInterface):
options=None,
run_metadata=None,
callable_runner=None,
- callable_runner_args=None):
+ callable_runner_args=None,
+ callable_options=None):
"""Wrapper around Session.run() that inserts tensor watch options.
Args:
@@ -424,7 +428,12 @@ class BaseDebugWrapperSession(session.SessionInterface):
run_metadata: Same as the `run_metadata` arg to regular `Session.run()`.
callable_runner: A `callable` returned by `Session.make_callable()`.
If not `None`, `fetches` and `feed_dict` must both be `None`.
- callable_runner_args: An optional list of arguments to `callable_runner`.
+ Mutually exclusive with `callable_options`.
+ callable_runner_args: An optional list of arguments to `callable_runner`
+ or for `callable_options`.
+ callable_options: An instance of `config_pb2.CallableOptions`, to be
+ used with `Session._make_callable_from_options()`. Mutually exclusive
+ with `callable_runner`.
Returns:
Simply forwards the output of the wrapped `Session.run()` call.
@@ -433,13 +442,17 @@ class BaseDebugWrapperSession(session.SessionInterface):
ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner`
is not `None` and either or both of `fetches` and `feed_dict` is `None`.
"""
- if not callable_runner:
+ if callable_runner and callable_options:
+ raise ValueError(
+ "callable_runner and callable_options are mutually exclusive, but "
+ "are both specified in this call to BaseDebugWrapperSession.run().")
+
+ if not (callable_runner or callable_options):
self.increment_run_call_count()
- else:
- if fetches or feed_dict:
- raise ValueError(
- "callable_runner and fetches/feed_dict are mutually exclusive, but "
- "are used simultaneously.")
+ elif callable_runner and (fetches or feed_dict):
+ raise ValueError(
+ "callable_runner and fetches/feed_dict are mutually exclusive, "
+ "but are used simultaneously.")
empty_fetches = not nest.flatten(fetches)
if empty_fetches:
@@ -449,6 +462,11 @@ class BaseDebugWrapperSession(session.SessionInterface):
if self._is_disabled_thread() or empty_fetches:
if callable_runner:
return callable_runner(*callable_runner_args)
+ elif callable_options:
+ # pylint:disable=protected-access
+ return self._sess._make_callable_from_options(
+ callable_options)(*callable_runner_args)
+ # pylint:enable=protected-access
else:
return self._sess.run(fetches,
feed_dict=feed_dict,
@@ -464,19 +482,30 @@ class BaseDebugWrapperSession(session.SessionInterface):
if run_start_resp.action == OnRunStartAction.DEBUG_RUN:
# Decorate RunOption to fill in debugger tensor watch specifications.
- decorated_run_options = options or config_pb2.RunOptions()
+ decorated_run_options = None
+ if callable_options:
+ callable_options_id = id(callable_options)
+ if callable_options_id not in self._cached_callables_from_options:
+ # Make a copy of callable_options to avoid mutating it.
+ new_callable_options = config_pb2.CallableOptions()
+ new_callable_options.CopyFrom(callable_options)
+ decorated_run_options = new_callable_options.run_options
+ else:
+ decorated_run_options = options or config_pb2.RunOptions()
+
run_metadata = run_metadata or config_pb2.RunMetadata()
- self._decorate_run_options_for_debug(
- decorated_run_options,
- run_start_resp.debug_urls,
- debug_ops=run_start_resp.debug_ops,
- node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
- op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
- tensor_dtype_regex_whitelist=(
- run_start_resp.tensor_dtype_regex_whitelist),
- tolerate_debug_op_creation_failures=(
- run_start_resp.tolerate_debug_op_creation_failures))
+ if decorated_run_options:
+ self._decorate_run_options_for_debug(
+ decorated_run_options,
+ run_start_resp.debug_urls,
+ debug_ops=run_start_resp.debug_ops,
+ node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist,
+ op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist,
+ tensor_dtype_regex_whitelist=(
+ run_start_resp.tensor_dtype_regex_whitelist),
+ tolerate_debug_op_creation_failures=(
+ run_start_resp.tolerate_debug_op_creation_failures))
# Invoke the run() method of the wrapped Session. Catch any TensorFlow
# runtime errors.
@@ -486,6 +515,19 @@ class BaseDebugWrapperSession(session.SessionInterface):
retvals = callable_runner(*callable_runner_args,
options=decorated_run_options,
run_metadata=run_metadata)
+ elif callable_options:
+ # pylint:disable=protected-access
+ if callable_options_id in self._cached_callables_from_options:
+ callable_object = self._cached_callables_from_options[
+ callable_options_id]
+ else:
+ callable_object = self._sess._make_callable_from_options(
+ new_callable_options)
+ self._cached_callables_from_options[
+ callable_options_id] = callable_object
+ # pylint:enable=protected-access
+ retvals = callable_object(
+ *callable_runner_args, run_metadata=run_metadata)
else:
retvals = self._sess.run(fetches,
feed_dict=feed_dict,
@@ -590,7 +632,14 @@ class BaseDebugWrapperSession(session.SessionInterface):
run_metadata=kwargs.get("run_metadata", None),
callable_runner=runner,
callable_runner_args=runner_args)
+ return wrapped_runner
+ def _make_callable_from_options(self, callable_options):
+ def wrapped_runner(*feed_values, **kwargs):
+ return self.run(None,
+ run_metadata=kwargs.get("run_metadata", None),
+ callable_options=callable_options,
+ callable_runner_args=feed_values)
return wrapped_runner
@property