diff options
author | Tiezhen WANG <wangtz@google.com> | 2018-08-13 09:49:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-13 09:53:54 -0700 |
commit | a983448448d3030674b2a13c5723a4e9db756003 (patch) | |
tree | ca9e97f05c4c0a27db4ef7d65d69e2470c3433f4 /tensorflow/contrib/training | |
parent | 9299a96a2bd07680cce655867e77b18977b7ed78 (diff) |
TF train: allow passing in run_metadata to unblock tf profiler.
PiperOrigin-RevId: 208495685
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r-- | tensorflow/contrib/training/python/training/training.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py index f72e0a3f83..c272a2ac14 100644 --- a/tensorflow/contrib/training/python/training/training.py +++ b/tensorflow/contrib/training/python/training/training.py @@ -484,7 +484,8 @@ def train(train_op, save_checkpoint_secs=600, save_summaries_steps=100, config=None, - max_wait_secs=7200): + max_wait_secs=7200, + run_metadata=None): """Runs the training loop. Args: @@ -511,6 +512,7 @@ def train(train_op, become available. This should be kept relatively short to help detect incorrect code, but sometimes may need to be increased if the chief takes a while to start up. + run_metadata: A [`RunMetadata`] protocol buffer. Returns: the value of the loss function after training. @@ -541,5 +543,5 @@ def train(train_op, max_wait_secs=max_wait_secs) as session: loss = None while not session.should_stop(): - loss = session.run(train_op) + loss = session.run(train_op, run_metadata=run_metadata) return loss |