aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-01-10 10:50:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-10 11:13:50 -0800
commit7636fd6e21935379f5b3ed45720e781001eda46b (patch)
treebe71347e190a02ed33d78550dab000d9d7cd02c7
parentd73a2668b5e7b65d54576be461e1107f27ea1365 (diff)
Make SecondOrStepTimer public.
Change: 144101397
-rw-r--r--tensorflow/python/training/basic_session_run_hooks.py18
-rw-r--r--tensorflow/python/training/basic_session_run_hooks_test.py10
-rw-r--r--tensorflow/python/training/training.py1
3 files changed, 15 insertions, 14 deletions
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index a1742621b3..d7ec4795ec 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -46,7 +46,7 @@ from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
-class _SecondOrStepTimer(object):
+class SecondOrStepTimer(object):
"""Timer that triggers at most once every N seconds or once every N steps.
"""
@@ -145,8 +145,8 @@ class LoggingTensorHook(session_run_hook.SessionRunHook):
if not isinstance(tensors, dict):
tensors = {item: item for item in tensors}
self._tensors = tensors
- self._timer = _SecondOrStepTimer(every_secs=every_n_secs,
- every_steps=every_n_iter)
+ self._timer = SecondOrStepTimer(every_secs=every_n_secs,
+ every_steps=every_n_iter)
def begin(self):
self._iter_count = 0
@@ -317,8 +317,8 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
self._summary_writer = SummaryWriterCache.get(checkpoint_dir)
self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
self._scaffold = scaffold
- self._timer = _SecondOrStepTimer(every_secs=save_secs,
- every_steps=save_steps)
+ self._timer = SecondOrStepTimer(every_secs=save_secs,
+ every_steps=save_steps)
self._listeners = listeners or []
def begin(self):
@@ -397,8 +397,8 @@ class StepCounterHook(session_run_hook.SessionRunHook):
if (every_n_steps is None) == (every_n_secs is None):
raise ValueError(
"exactly one of every_n_steps and every_n_secs should be provided.")
- self._timer = _SecondOrStepTimer(every_steps=every_n_steps,
- every_secs=every_n_secs)
+ self._timer = SecondOrStepTimer(every_steps=every_n_steps,
+ every_secs=every_n_secs)
self._summary_writer = summary_writer
if summary_writer is None and output_dir:
@@ -507,8 +507,8 @@ class SummarySaverHook(session_run_hook.SessionRunHook):
if summary_writer is None and output_dir:
self._summary_writer = SummaryWriterCache.get(output_dir)
self._scaffold = scaffold
- self._timer = _SecondOrStepTimer(every_secs=save_secs,
- every_steps=save_steps)
+ self._timer = SecondOrStepTimer(every_secs=save_secs,
+ every_steps=save_steps)
# TODO(mdan): Throw an error if output_dir and summary_writer are None.
def begin(self):
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index a39fbeb2d0..a1532d2873 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -81,14 +81,14 @@ class SecondOrStepTimerTest(test.TestCase):
def test_raise_in_both_secs_and_steps(self):
with self.assertRaises(ValueError):
- basic_session_run_hooks._SecondOrStepTimer(every_secs=2.0, every_steps=10)
+ basic_session_run_hooks.SecondOrStepTimer(every_secs=2.0, every_steps=10)
def test_raise_in_none_secs_and_steps(self):
with self.assertRaises(ValueError):
- basic_session_run_hooks._SecondOrStepTimer()
+ basic_session_run_hooks.SecondOrStepTimer()
def test_every_secs(self):
- timer = basic_session_run_hooks._SecondOrStepTimer(every_secs=1.0)
+ timer = basic_session_run_hooks.SecondOrStepTimer(every_secs=1.0)
self.assertTrue(timer.should_trigger_for_step(1))
timer.update_last_triggered_step(1)
@@ -100,7 +100,7 @@ class SecondOrStepTimerTest(test.TestCase):
self.assertTrue(timer.should_trigger_for_step(2))
def test_every_steps(self):
- timer = basic_session_run_hooks._SecondOrStepTimer(every_steps=3)
+ timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=3)
self.assertTrue(timer.should_trigger_for_step(1))
timer.update_last_triggered_step(1)
@@ -110,7 +110,7 @@ class SecondOrStepTimerTest(test.TestCase):
self.assertTrue(timer.should_trigger_for_step(4))
def test_update_last_triggered_step(self):
- timer = basic_session_run_hooks._SecondOrStepTimer(every_steps=1)
+ timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=1)
elapsed_secs, elapsed_steps = timer.update_last_triggered_step(1)
self.assertEqual(None, elapsed_secs)
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index f7afb8d287..3a2415629a 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -181,6 +181,7 @@ from tensorflow.python.training import input as _input
from tensorflow.python.training.input import *
# pylint: enable=wildcard-import
+from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer
from tensorflow.python.training.basic_session_run_hooks import LoggingTensorHook
from tensorflow.python.training.basic_session_run_hooks import StopAtStepHook
from tensorflow.python.training.basic_session_run_hooks import CheckpointSaverHook