# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Some common SessionRunHook classes. @@LoggingTensorHook @@StopAtStepHook @@CheckpointSaverHook @@StepCounterHook @@NanLossDuringTrainingError @@NanTensorHook @@SummarySaverHook @@GlobalStepWaiterHook @@ProfilerHook """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import time import numpy as np import six from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.core.protobuf import config_pb2 from tensorflow.core.util.event_pb2 import SessionLog from tensorflow.python.client import timeline from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.training.session_run_hook import SessionRunArgs from tensorflow.python.training.summary_io import SummaryWriterCache from tensorflow.python.util.tf_export import tf_export class _HookTimer(object): """Base timer for determining when Hooks should trigger. Should not be instantiated directly. """ def __init__(self): pass def reset(self): """Resets the timer.""" pass def should_trigger_for_step(self, step): """Return true if the timer should trigger for the specified step.""" raise NotImplementedError def update_last_triggered_step(self, step): """Update the last triggered time and step number. Args: step: The current step. Returns: A pair `(elapsed_time, elapsed_steps)`, where `elapsed_time` is the number of seconds between the current trigger and the last one (a float), and `elapsed_steps` is the number of steps between the current trigger and the last one. Both values will be set to `None` on the first trigger. """ raise NotImplementedError def last_triggered_step(self): """Returns the last triggered time step or None if never triggered.""" raise NotImplementedError @tf_export("train.SecondOrStepTimer") class SecondOrStepTimer(_HookTimer): """Timer that triggers at most once every N seconds or once every N steps. """ def __init__(self, every_secs=None, every_steps=None): self.reset() self._every_secs = every_secs self._every_steps = every_steps if self._every_secs is None and self._every_steps is None: raise ValueError("Either every_secs or every_steps should be provided.") if (self._every_secs is not None) and (self._every_steps is not None): raise ValueError("Can not provide both every_secs and every_steps.") super(SecondOrStepTimer, self).__init__() def reset(self): self._last_triggered_step = None self._last_triggered_time = None def should_trigger_for_step(self, step): """Return true if the timer should trigger for the specified step. Args: step: Training step to trigger on. 
    Returns:
      True if the difference between the current time and the time of the last
      trigger exceeds `every_secs`, or if the difference between the current
      step and the last triggered step exceeds `every_steps`. False otherwise.
    """
    if self._last_triggered_step is None:
      return True

    if self._last_triggered_step == step:
      return False

    if self._every_secs is not None:
      if time.time() >= self._last_triggered_time + self._every_secs:
        return True

    if self._every_steps is not None:
      if step >= self._last_triggered_step + self._every_steps:
        return True

    return False

  def update_last_triggered_step(self, step):
    current_time = time.time()
    if self._last_triggered_time is None:
      elapsed_secs = None
      elapsed_steps = None
    else:
      elapsed_secs = current_time - self._last_triggered_time
      elapsed_steps = step - self._last_triggered_step

    self._last_triggered_time = current_time
    self._last_triggered_step = step
    return (elapsed_secs, elapsed_steps)

  def last_triggered_step(self):
    return self._last_triggered_step


class NeverTriggerTimer(_HookTimer):
  """Timer that never triggers."""

  def should_trigger_for_step(self, step):
    _ = step
    return False

  def update_last_triggered_step(self, step):
    _ = step
    return (None, None)

  def last_triggered_step(self):
    return None


@tf_export("train.LoggingTensorHook")
class LoggingTensorHook(session_run_hook.SessionRunHook):
  """Prints the given tensors every N local steps, every N seconds, or at end.

  The tensors will be printed to the log, with `INFO` severity. If you are not
  seeing the logs, you might want to add the following line after your imports:

  ```python
  tf.logging.set_verbosity(tf.logging.INFO)
  ```

  Note that if `at_end` is True, `tensors` should not include any tensor
  whose evaluation produces a side effect such as consuming additional inputs.
  """

  def __init__(self, tensors, every_n_iter=None, every_n_secs=None,
               at_end=False, formatter=None):
    """Initializes a `LoggingTensorHook`.

    Args:
      tensors: `dict` that maps string-valued tags to tensors/tensor names,
        or `iterable` of tensors/tensor names.
      every_n_iter: `int`, print the values of `tensors` once every N local
        steps taken on the current worker.
      every_n_secs: `int` or `float`, print the values of `tensors` once every
        N seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
        provided.
      at_end: `bool` specifying whether to print the values of `tensors` at the
        end of the run.
      formatter: function that takes a dict of `tag`->`Tensor` and returns a
        string. If `None`, uses a default formatter that prints all tensors.

    Raises:
      ValueError: if `every_n_iter` is non-positive.
    """
    only_log_at_end = (
        at_end and (every_n_iter is None) and (every_n_secs is None))
    if (not only_log_at_end and
        (every_n_iter is None) == (every_n_secs is None)):
      raise ValueError(
          "either at_end and/or exactly one of every_n_iter and every_n_secs "
          "must be provided.")
    if every_n_iter is not None and every_n_iter <= 0:
      raise ValueError("invalid every_n_iter=%s."
                       % every_n_iter)
    if not isinstance(tensors, dict):
      self._tag_order = tensors
      tensors = {item: item for item in tensors}
    else:
      self._tag_order = tensors.keys()
    self._tensors = tensors
    self._formatter = formatter
    self._timer = (
        NeverTriggerTimer() if only_log_at_end else
        SecondOrStepTimer(every_secs=every_n_secs, every_steps=every_n_iter))
    self._log_at_end = at_end

  def begin(self):
    self._timer.reset()
    self._iter_count = 0
    # Convert names to tensors if given.
    self._current_tensors = {tag: _as_graph_element(tensor)
                             for (tag, tensor) in self._tensors.items()}

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)
    if self._should_trigger:
      return SessionRunArgs(self._current_tensors)
    else:
      return None

  def _log_tensors(self, tensor_values):
    original = np.get_printoptions()
    np.set_printoptions(suppress=True)
    elapsed_secs, _ = self._timer.update_last_triggered_step(self._iter_count)
    if self._formatter:
      logging.info(self._formatter(tensor_values))
    else:
      stats = []
      for tag in self._tag_order:
        stats.append("%s = %s" % (tag, tensor_values[tag]))
      if elapsed_secs is not None:
        logging.info("%s (%.3f sec)", ", ".join(stats), elapsed_secs)
      else:
        logging.info("%s", ", ".join(stats))
    np.set_printoptions(**original)

  def after_run(self, run_context, run_values):
    _ = run_context
    if self._should_trigger:
      self._log_tensors(run_values.results)

    self._iter_count += 1

  def end(self, session):
    if self._log_at_end:
      values = session.run(self._current_tensors)
      self._log_tensors(values)


@tf_export("train.StopAtStepHook")
class StopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.

    This hook requests stop after either a number of steps have been
    executed or a last step has been reached. Only one of the two options can
    be specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `after_run()`
    call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")

  def after_create_session(self, session, coord):
    if self._last_step is None:
      global_step = session.run(self._global_step_tensor)
      self._last_step = global_step + self._num_steps

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results + 1
    if global_step >= self._last_step:
      # Check the latest global step to ensure that the targeted last step is
      # reached. The global_step read tensor holds the value of the global
      # step *before* this run, so we cannot tell from it alone whether the
      # current session.run incremented the global step; re-read it to check.
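      # For example, if last_step is 100 and the fetched (pre-run) value was
      # 99, this run may or may not have advanced the global step to 100; the
      # re-read below settles it.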
      step = run_context.session.run(self._global_step_tensor)
      if step >= self._last_step:
        run_context.request_stop()


@tf_export("train.CheckpointSaverListener")
class CheckpointSaverListener(object):
  """Interface for listeners that take action before or after checkpoint save.

  `CheckpointSaverListener` triggers only in steps when `CheckpointSaverHook`
  is triggered, and provides callbacks at the following points:
   - before using the session
   - before each call to `Saver.save()`
   - after each call to `Saver.save()`
   - at the end of session

  To use a listener, implement a class and pass the listener to a
  `CheckpointSaverHook`, as in this example:

  ```python
  class ExampleCheckpointSaverListener(CheckpointSaverListener):
    def begin(self):
      # You can add ops to the graph here.
      print('Starting the session.')
      self.your_tensor = ...

    def before_save(self, session, global_step_value):
      print('About to write a checkpoint')

    def after_save(self, session, global_step_value):
      print('Done writing checkpoint.')

    def end(self, session, global_step_value):
      print('Done with the session.')

  ...
  listener = ExampleCheckpointSaverListener()
  saver_hook = tf.train.CheckpointSaverHook(
      checkpoint_dir, listeners=[listener])
  with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
    ...
  ```

  A `CheckpointSaverListener` may simply take some action after every
  checkpoint save. It is also possible for the listener to use its own schedule
  to act less frequently, e.g. based on global_step_value. In this case,
  implementors should implement the `end()` method to handle actions related to
  the last checkpoint save. But the listener should not act twice if
  `after_save()` already handled this last checkpoint save.
  """

  def begin(self):
    pass

  def before_save(self, session, global_step_value):
    pass

  def after_save(self, session, global_step_value):
    pass

  def end(self, session, global_step_value):
    pass


@tf_export("train.CheckpointSaverHook")
class CheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, use to get saver object.
      listeners: List of `CheckpointSaverListener` subclass instances.
        Used for callbacks that run immediately before or after this hook saves
        the checkpoint.

    Raises:
      ValueError: Exactly one of `save_steps` and `save_secs` should be set.
      ValueError: At most one of `saver` and `scaffold` should be set.
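
    A minimal usage sketch (the directory and save cadence below are
    illustrative, not prescribed by this class):

    ```python
    saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_dir="/tmp/model", save_secs=600)
    with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
      ...
    ```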
""" logging.info("Create CheckpointSaverHook.") if saver is not None and scaffold is not None: raise ValueError("You cannot provide both saver and scaffold.") self._saver = saver self._checkpoint_dir = checkpoint_dir self._save_path = os.path.join(checkpoint_dir, checkpoint_basename) self._scaffold = scaffold self._timer = SecondOrStepTimer(every_secs=save_secs, every_steps=save_steps) self._listeners = listeners or [] def begin(self): self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir) self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access if self._global_step_tensor is None: raise RuntimeError( "Global step should be created to use CheckpointSaverHook.") for l in self._listeners: l.begin() def before_run(self, run_context): # pylint: disable=unused-argument if self._timer.last_triggered_step() is None: # We do write graph and saver_def at the first call of before_run. # We cannot do this in begin, since we let other hooks to change graph and # add variables in begin. Graph is finalized after all begin calls. training_util.write_graph( ops.get_default_graph().as_graph_def(add_shapes=True), self._checkpoint_dir, "graph.pbtxt") saver_def = self._get_saver().saver_def if self._get_saver() else None graph = ops.get_default_graph() meta_graph_def = meta_graph.create_meta_graph_def( graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def) self._summary_writer.add_graph(graph) self._summary_writer.add_meta_graph(meta_graph_def) return SessionRunArgs(self._global_step_tensor) def after_run(self, run_context, run_values): stale_global_step = run_values.results if self._timer.should_trigger_for_step(stale_global_step+1): # get the real value after train op. global_step = run_context.session.run(self._global_step_tensor) if self._timer.should_trigger_for_step(global_step): self._timer.update_last_triggered_step(global_step) self._save(run_context.session, global_step) def end(self, session): last_step = session.run(self._global_step_tensor) if last_step != self._timer.last_triggered_step(): self._save(session, last_step) for l in self._listeners: l.end(session, last_step) def _save(self, session, step): """Saves the latest checkpoint.""" logging.info("Saving checkpoints for %d into %s.", step, self._save_path) for l in self._listeners: l.before_save(session, step) self._get_saver().save(session, self._save_path, global_step=step) self._summary_writer.add_session_log( SessionLog( status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path), step) for l in self._listeners: l.after_save(session, step) def _get_saver(self): if self._saver is not None: return self._saver elif self._scaffold is not None: return self._scaffold.saver # Get saver from the SAVERS collection if present. collection_key = ops.GraphKeys.SAVERS savers = ops.get_collection(collection_key) if not savers: raise RuntimeError( "No items in collection {}. Please add a saver to the collection " "or provide a saver or scaffold.".format(collection_key)) elif len(savers) > 1: raise RuntimeError( "More than one item in collection {}. " "Please indicate which one to use by passing it to the constructor.". 
          format(collection_key))
    self._saver = savers[0]
    return savers[0]


@tf_export("train.StepCounterHook")
class StepCounterHook(session_run_hook.SessionRunHook):
  """Hook that counts steps per second."""

  def __init__(self,
               every_n_steps=100,
               every_n_secs=None,
               output_dir=None,
               summary_writer=None):

    if (every_n_steps is None) == (every_n_secs is None):
      raise ValueError(
          "exactly one of every_n_steps and every_n_secs should be provided.")
    self._timer = SecondOrStepTimer(every_steps=every_n_steps,
                                    every_secs=every_n_secs)

    self._summary_writer = summary_writer
    self._output_dir = output_dir
    self._last_global_step = None
    self._global_step_check_count = 0

  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use StepCounterHook.")
    self._summary_tag = training_util.get_global_step().op.name + "/sec"

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
    steps_per_sec = elapsed_steps / elapsed_time
    if self._summary_writer is not None:
      summary = Summary(value=[Summary.Value(
          tag=self._summary_tag, simple_value=steps_per_sec)])
      self._summary_writer.add_summary(summary, global_step)
    logging.info("%s: %g", self._summary_tag, steps_per_sec)

  def after_run(self, run_context, run_values):
    _ = run_context

    stale_global_step = run_values.results
    if self._timer.should_trigger_for_step(stale_global_step + 1):
      # Get the real value after the train op has run.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
            global_step)
        if elapsed_time is not None:
          self._log_and_record(elapsed_steps, elapsed_time, global_step)

    # Check whether the global step has been increased. Here, we do not use
    # the timer.last_triggered_step as the timer might record a different
    # global step value such that the comparison could be unreliable. For
    # simplicity, we just compare the stale_global_step with the previously
    # recorded version.
    if stale_global_step == self._last_global_step:
      # Here, we use a counter to count how many times we have observed that
      # the global step has not been increased. For some Optimizers, the
      # global step is not increased each time by design. For example,
      # SyncReplicasOptimizer doesn't increase the global step in a worker's
      # main train step.
      self._global_step_check_count += 1
      if self._global_step_check_count % 20 == 0:
        self._global_step_check_count = 0
        logging.warning(
            "It seems that global step (tf.train.get_global_step) has not "
            "been increased. Current value (could be stale): %s vs previous "
            "value: %s. You could increase the global step by passing "
            "tf.train.get_global_step() to Optimizer.apply_gradients or "
            "Optimizer.minimize.", stale_global_step, self._last_global_step)
    else:
      # Whenever we observe the increment, reset the counter.
      self._global_step_check_count = 0

    self._last_global_step = stale_global_step


@tf_export("train.NanLossDuringTrainingError")
class NanLossDuringTrainingError(RuntimeError):

  def __str__(self):
    return "NaN loss during training."
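

# `NanTensorHook` below raises `NanLossDuringTrainingError` when
# `fail_on_nan_loss=True`. An illustrative (not prescribed) way to handle it
# around a training loop; `train_op` and `loss` are hypothetical names:
#
#   hooks = [tf.train.NanTensorHook(loss)]
#   try:
#     with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
#       while not sess.should_stop():
#         sess.run(train_op)
#   except tf.train.NanLossDuringTrainingError:
#     tf.logging.error("Training stopped: loss became NaN.")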
@tf_export("train.NanTensorHook") class NanTensorHook(session_run_hook.SessionRunHook): """Monitors the loss tensor and stops training if loss is NaN. Can either fail with exception or just stop training. """ def __init__(self, loss_tensor, fail_on_nan_loss=True): """Initializes a `NanTensorHook`. Args: loss_tensor: `Tensor`, the loss tensor. fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN. """ self._loss_tensor = loss_tensor self._fail_on_nan_loss = fail_on_nan_loss def before_run(self, run_context): # pylint: disable=unused-argument return SessionRunArgs(self._loss_tensor) def after_run(self, run_context, run_values): if np.isnan(run_values.results): failure_message = "Model diverged with loss = NaN." if self._fail_on_nan_loss: logging.error(failure_message) raise NanLossDuringTrainingError else: logging.warning(failure_message) # We don't raise an error but we request stop without an exception. run_context.request_stop() @tf_export("train.SummarySaverHook") class SummarySaverHook(session_run_hook.SessionRunHook): """Saves summaries every N steps.""" def __init__(self, save_steps=None, save_secs=None, output_dir=None, summary_writer=None, scaffold=None, summary_op=None): """Initializes a `SummarySaverHook`. Args: save_steps: `int`, save summaries every N steps. Exactly one of `save_secs` and `save_steps` should be set. save_secs: `int`, save summaries every N seconds. output_dir: `string`, the directory to save the summaries to. Only used if no `summary_writer` is supplied. summary_writer: `SummaryWriter`. If `None` and an `output_dir` was passed, one will be created accordingly. scaffold: `Scaffold` to get summary_op if it's not provided. summary_op: `Tensor` of type `string` containing the serialized `Summary` protocol buffer or a list of `Tensor`. They are most likely an output by TF summary methods like `tf.summary.scalar` or `tf.summary.merge_all`. It can be passed in as one tensor; if more than one, they must be passed in as a list. Raises: ValueError: Exactly one of scaffold or summary_op should be set. """ if ((scaffold is None and summary_op is None) or (scaffold is not None and summary_op is not None)): raise ValueError( "Exactly one of scaffold or summary_op must be provided.") self._summary_op = summary_op self._summary_writer = summary_writer self._output_dir = output_dir self._scaffold = scaffold self._timer = SecondOrStepTimer(every_secs=save_secs, every_steps=save_steps) # TODO(mdan): Throw an error if output_dir and summary_writer are None. 
  def begin(self):
    if self._summary_writer is None and self._output_dir:
      self._summary_writer = SummaryWriterCache.get(self._output_dir)
    self._next_step = None
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use SummarySaverHook.")

  def before_run(self, run_context):  # pylint: disable=unused-argument
    self._request_summary = (
        self._next_step is None or
        self._timer.should_trigger_for_step(self._next_step))
    requests = {"global_step": self._global_step_tensor}
    if self._request_summary:
      if self._get_summary_op() is not None:
        requests["summary"] = self._get_summary_op()

    return SessionRunArgs(requests)

  def after_run(self, run_context, run_values):
    _ = run_context
    if not self._summary_writer:
      return

    stale_global_step = run_values.results["global_step"]
    global_step = stale_global_step + 1
    if self._next_step is None or self._request_summary:
      global_step = run_context.session.run(self._global_step_tensor)

    if self._next_step is None:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), global_step)

    if self._request_summary:
      self._timer.update_last_triggered_step(global_step)
      if "summary" in run_values.results:
        for summary in run_values.results["summary"]:
          self._summary_writer.add_summary(summary, global_step)

    self._next_step = global_step + 1

  def end(self, session=None):
    if self._summary_writer:
      self._summary_writer.flush()

  def _get_summary_op(self):
    """Fetches the summary op either from self._summary_op or self._scaffold.

    Returns:
      A list of summary `Tensor`.
    """
    summary_op = None
    if self._summary_op is not None:
      summary_op = self._summary_op
    elif self._scaffold.summary_op is not None:
      summary_op = self._scaffold.summary_op

    if summary_op is None:
      return None

    if not isinstance(summary_op, list):
      return [summary_op]
    return summary_op


@tf_export("train.GlobalStepWaiterHook")
class GlobalStepWaiterHook(session_run_hook.SessionRunHook):
  """Delays execution until global step reaches `wait_until_step`.

  This hook delays execution until the global step reaches `wait_until_step`.
  It is used to gradually start workers in distributed settings. One example
  usage would be setting `wait_until_step=int(K*log(task_id+1))` assuming that
  task_id=0 is the chief.
  """

  def __init__(self, wait_until_step):
    """Initializes a `GlobalStepWaiterHook`.

    Args:
      wait_until_step: an `int`, the global step to wait for before starting
        training.
    """
    self._wait_until_step = wait_until_step

  def begin(self):
    self._worker_is_started = False
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use GlobalStepWaiterHook.")

  def before_run(self, run_context):
    if self._worker_is_started:
      return None

    if self._wait_until_step <= 0:
      self._worker_is_started = True
      return None

    logging.info("Waiting for global step %d before starting training.",
                 self._wait_until_step)
    last_logged_step = 0
    while True:
      current_step = run_context.session.run(self._global_step_tensor)
      if current_step >= self._wait_until_step:
        self._worker_is_started = True
        return None
      if current_step - last_logged_step > 1000:
        logging.info("Waiting for global step %d before starting training. "
" "Current step is %d.", self._wait_until_step, current_step) last_logged_step = current_step time.sleep(0.5) @tf_export("train.FinalOpsHook") class FinalOpsHook(session_run_hook.SessionRunHook): """A hook which evaluates `Tensors` at the end of a session.""" def __init__(self, final_ops, final_ops_feed_dict=None): """Initializes `FinalOpHook` with ops to run at the end of the session. Args: final_ops: A single `Tensor`, a list of `Tensors` or a dictionary of names to `Tensors`. final_ops_feed_dict: A feed dictionary to use when running `final_ops_dict`. """ self._final_ops = final_ops self._final_ops_feed_dict = final_ops_feed_dict self._final_ops_values = None @property def final_ops_values(self): return self._final_ops_values def end(self, session): if self._final_ops is not None: self._final_ops_values = session.run(self._final_ops, feed_dict=self._final_ops_feed_dict) @tf_export("train.FeedFnHook") class FeedFnHook(session_run_hook.SessionRunHook): """Runs `feed_fn` and sets the `feed_dict` accordingly.""" def __init__(self, feed_fn): """Initializes a `FeedFnHook`. Args: feed_fn: function that takes no arguments and returns `dict` of `Tensor` to feed. """ self.feed_fn = feed_fn def before_run(self, run_context): # pylint: disable=unused-argument return session_run_hook.SessionRunArgs( fetches=None, feed_dict=self.feed_fn()) @tf_export("train.ProfilerHook") class ProfilerHook(session_run_hook.SessionRunHook): """Captures CPU/GPU profiling information every N steps or seconds. This produces files called "timeline-.json", which are in Chrome Trace format. For more information see: https://github.com/catapult-project/catapult/blob/master/tracing/README.md """ def __init__(self, save_steps=None, save_secs=None, output_dir="", show_dataflow=True, show_memory=False): """Initializes a hook that takes periodic profiling snapshots. `options.run_metadata` argument of `tf.Session.Run` is used to collect metadata about execution. This hook sets the metadata and dumps it in Chrome Trace format. Args: save_steps: `int`, save profile traces every N steps. Exactly one of `save_secs` and `save_steps` should be set. save_secs: `int` or `float`, save profile traces every N seconds. output_dir: `string`, the directory to save the profile traces to. Defaults to the current directory. show_dataflow: `bool`, if True, add flow events to the trace connecting producers and consumers of tensors. show_memory: `bool`, if True, add object snapshot events to the trace showing the sizes and lifetimes of tensors. 
""" self._output_file = os.path.join(output_dir, "timeline-{}.json") self._show_dataflow = show_dataflow self._show_memory = show_memory self._timer = SecondOrStepTimer( every_secs=save_secs, every_steps=save_steps) def begin(self): self._next_step = None self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access if self._global_step_tensor is None: raise RuntimeError("Global step should be created to use ProfilerHook.") def before_run(self, run_context): self._request_summary = ( self._next_step is None or self._timer.should_trigger_for_step(self._next_step)) requests = {"global_step": self._global_step_tensor} opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) if self._request_summary else None) return SessionRunArgs(requests, options=opts) def after_run(self, run_context, run_values): stale_global_step = run_values.results["global_step"] global_step = stale_global_step + 1 if self._request_summary: global_step = run_context.session.run(self._global_step_tensor) self._timer.update_last_triggered_step(global_step) self._save(global_step, self._output_file.format(global_step), run_values.run_metadata.step_stats) self._next_step = global_step + 1 def _save(self, step, save_path, step_stats): logging.info("Saving timeline for %d into '%s'.", step, save_path) with gfile.Open(save_path, "w") as f: trace = timeline.Timeline(step_stats) f.write( trace.generate_chrome_trace_format( show_dataflow=self._show_dataflow, show_memory=self._show_memory)) def _as_graph_element(obj): """Retrieves Graph element.""" graph = ops.get_default_graph() if not isinstance(obj, six.string_types): if not hasattr(obj, "graph") or obj.graph != graph: raise ValueError("Passed %s should have graph attribute that is equal " "to current graph %s." % (obj, graph)) return obj if ":" in obj: element = graph.as_graph_element(obj) else: element = graph.as_graph_element(obj + ":0") # Check that there is no :1 (e.g. it's single output). try: graph.as_graph_element(obj + ":1") except (KeyError, ValueError): pass else: raise ValueError("Name %s is ambiguous, " "as this `Operation` has multiple outputs " "(at least 2)." % obj) return element