diff options
Diffstat (limited to 'tensorflow/python/debug/wrappers/framework.py')
-rw-r--r-- | tensorflow/python/debug/wrappers/framework.py | 87 |
1 files changed, 68 insertions, 19 deletions
diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index c530204bbf..b9524ce649 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -392,6 +392,9 @@ class BaseDebugWrapperSession(session.SessionInterface): self._default_session_context_manager = None + # A cache for callables created from CallableOptions. + self._cached_callables_from_options = dict() + @property def graph(self): return self._sess.graph @@ -414,7 +417,8 @@ class BaseDebugWrapperSession(session.SessionInterface): options=None, run_metadata=None, callable_runner=None, - callable_runner_args=None): + callable_runner_args=None, + callable_options=None): """Wrapper around Session.run() that inserts tensor watch options. Args: @@ -424,7 +428,12 @@ class BaseDebugWrapperSession(session.SessionInterface): run_metadata: Same as the `run_metadata` arg to regular `Session.run()`. callable_runner: A `callable` returned by `Session.make_callable()`. If not `None`, `fetches` and `feed_dict` must both be `None`. - callable_runner_args: An optional list of arguments to `callable_runner`. + Mutually exclusive with `callable_options`. + callable_runner_args: An optional list of arguments to `callable_runner` + or for `callable_options`. + callable_options: An instance of `config_pb2.CallableOptions`, to be + used with `Session._make_callable_from_options()`. Mutually exclusive + with `callable_runner`. Returns: Simply forwards the output of the wrapped `Session.run()` call. @@ -433,13 +442,17 @@ class BaseDebugWrapperSession(session.SessionInterface): ValueError: On invalid `OnRunStartAction` value. Or if `callable_runner` is not `None` and either or both of `fetches` and `feed_dict` is `None`. """ - if not callable_runner: + if callable_runner and callable_options: + raise ValueError( + "callable_runner and callable_options are mutually exclusive, but " + "are both specified in this call to BaseDebugWrapperSession.run().") + + if not (callable_runner or callable_options): self.increment_run_call_count() - else: - if fetches or feed_dict: - raise ValueError( - "callable_runner and fetches/feed_dict are mutually exclusive, but " - "are used simultaneously.") + elif callable_runner and (fetches or feed_dict): + raise ValueError( + "callable_runner and fetches/feed_dict are mutually exclusive, " + "but are used simultaneously.") empty_fetches = not nest.flatten(fetches) if empty_fetches: @@ -449,6 +462,11 @@ class BaseDebugWrapperSession(session.SessionInterface): if self._is_disabled_thread() or empty_fetches: if callable_runner: return callable_runner(*callable_runner_args) + elif callable_options: + # pylint:disable=protected-access + return self._sess._make_callable_from_options( + callable_options)(*callable_runner_args) + # pylint:enable=protected-access else: return self._sess.run(fetches, feed_dict=feed_dict, @@ -464,19 +482,30 @@ class BaseDebugWrapperSession(session.SessionInterface): if run_start_resp.action == OnRunStartAction.DEBUG_RUN: # Decorate RunOption to fill in debugger tensor watch specifications. - decorated_run_options = options or config_pb2.RunOptions() + decorated_run_options = None + if callable_options: + callable_options_id = id(callable_options) + if callable_options_id not in self._cached_callables_from_options: + # Make a copy of callable_options to avoid mutating it. + new_callable_options = config_pb2.CallableOptions() + new_callable_options.CopyFrom(callable_options) + decorated_run_options = new_callable_options.run_options + else: + decorated_run_options = options or config_pb2.RunOptions() + run_metadata = run_metadata or config_pb2.RunMetadata() - self._decorate_run_options_for_debug( - decorated_run_options, - run_start_resp.debug_urls, - debug_ops=run_start_resp.debug_ops, - node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist, - op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist, - tensor_dtype_regex_whitelist=( - run_start_resp.tensor_dtype_regex_whitelist), - tolerate_debug_op_creation_failures=( - run_start_resp.tolerate_debug_op_creation_failures)) + if decorated_run_options: + self._decorate_run_options_for_debug( + decorated_run_options, + run_start_resp.debug_urls, + debug_ops=run_start_resp.debug_ops, + node_name_regex_whitelist=run_start_resp.node_name_regex_whitelist, + op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist, + tensor_dtype_regex_whitelist=( + run_start_resp.tensor_dtype_regex_whitelist), + tolerate_debug_op_creation_failures=( + run_start_resp.tolerate_debug_op_creation_failures)) # Invoke the run() method of the wrapped Session. Catch any TensorFlow # runtime errors. @@ -486,6 +515,19 @@ class BaseDebugWrapperSession(session.SessionInterface): retvals = callable_runner(*callable_runner_args, options=decorated_run_options, run_metadata=run_metadata) + elif callable_options: + # pylint:disable=protected-access + if callable_options_id in self._cached_callables_from_options: + callable_object = self._cached_callables_from_options[ + callable_options_id] + else: + callable_object = self._sess._make_callable_from_options( + new_callable_options) + self._cached_callables_from_options[ + callable_options_id] = callable_object + # pylint:enable=protected-access + retvals = callable_object( + *callable_runner_args, run_metadata=run_metadata) else: retvals = self._sess.run(fetches, feed_dict=feed_dict, @@ -590,7 +632,14 @@ class BaseDebugWrapperSession(session.SessionInterface): run_metadata=kwargs.get("run_metadata", None), callable_runner=runner, callable_runner_args=runner_args) + return wrapped_runner + def _make_callable_from_options(self, callable_options): + def wrapped_runner(*feed_values, **kwargs): + return self.run(None, + run_metadata=kwargs.get("run_metadata", None), + callable_options=callable_options, + callable_runner_args=feed_values) return wrapped_runner @property |