aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/training
diff options
context:
space:
mode:
authorGravatar Tiezhen WANG <wangtz@google.com>2018-08-13 09:49:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-13 09:53:54 -0700
commita983448448d3030674b2a13c5723a4e9db756003 (patch)
treeca9e97f05c4c0a27db4ef7d65d69e2470c3433f4 /tensorflow/contrib/training
parent9299a96a2bd07680cce655867e77b18977b7ed78 (diff)
TF train: allow passing in run_metadata to unblock tf profiler.
PiperOrigin-RevId: 208495685
Diffstat (limited to 'tensorflow/contrib/training')
-rw-r--r--tensorflow/contrib/training/python/training/training.py6
1 file changed, 4 insertions, 2 deletions
diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py
index f72e0a3f83..c272a2ac14 100644
--- a/tensorflow/contrib/training/python/training/training.py
+++ b/tensorflow/contrib/training/python/training/training.py
@@ -484,7 +484,8 @@ def train(train_op,
save_checkpoint_secs=600,
save_summaries_steps=100,
config=None,
- max_wait_secs=7200):
+ max_wait_secs=7200,
+ run_metadata=None):
"""Runs the training loop.
Args:
@@ -511,6 +512,7 @@ def train(train_op,
become available. This should be kept relatively short to help detect
incorrect code, but sometimes may need to be increased if the chief takes
a while to start up.
+ run_metadata: A [`RunMetadata`] protocol buffer.
Returns:
the value of the loss function after training.
@@ -541,5 +543,5 @@ def train(train_op,
max_wait_secs=max_wait_secs) as session:
loss = None
while not session.should_stop():
- loss = session.run(train_op)
+ loss = session.run(train_op, run_metadata=run_metadata)
return loss