diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-08-23 06:48:52 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-23 08:02:54 -0700 |
commit | 44c0237eafaea616707bb3e8f9ec27402b4b33ef (patch) | |
tree | 4bfd77154f76607b24b4067706efcb97953a9d25 | |
parent | a58c688e1c5b0c8d9887dd5db17fb78e9f238dd7 (diff) |
Added config property to BaseEstimator to allow retrieval of RunConfig.
Change: 131054755
3 files changed, 22 insertions, 10 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index b60c4e91e6..c6c894542c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import abc +import copy import inspect import itertools import os @@ -224,6 +225,11 @@ class BaseEstimator( self._graph = None + @property + def config(self): + # TODO(wicke): make RunConfig immutable, and then return it without a copy. + return copy.deepcopy(self._config) + def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None): # pylint: disable=g-doc-args,g-doc-return-or-yield diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 0f96b70fae..f26a194bb4 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -89,6 +89,10 @@ class Experiment(object): self._eval_delay_secs = eval_delay_secs self._continuous_eval_throttle_secs = continuous_eval_throttle_secs + @property + def estimator(self): + return self._estimator + def train(self, delay_secs=None): """Fit the estimator using the training data. @@ -102,10 +106,8 @@ class Experiment(object): The trained estimator. """ if delay_secs is None: - task_id = 0 - if hasattr(FLAGS, "task"): - task_id = FLAGS.task - delay_secs = min(60, task_id*5) + task_id = self._estimator.config.task or 0 + delay_secs = min(60, task_id * 5) if delay_secs: logging.info("Waiting %d secs before starting training.", delay_secs) diff --git a/tensorflow/contrib/learn/python/learn/tests/experiment_test.py b/tensorflow/contrib/learn/python/learn/tests/experiment_test.py index 11feab7c48..7980881ea6 100644 --- a/tensorflow/contrib/learn/python/learn/tests/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/experiment_test.py @@ -19,16 +19,20 @@ from __future__ import print_function import time import tensorflow as tf -# importing to get flags. -from tensorflow.contrib.learn.python.learn import learn_runner # pylint: disable=unused-import +from tensorflow.contrib.learn.python.learn import run_config class TestEstimator(tf.contrib.learn.Evaluable, tf.contrib.learn.Trainable): - def __init__(self): + def __init__(self, config=None): self.eval_count = 0 self.fit_count = 0 self.monitors = [] + self._config = config or run_config.RunConfig() + + @property + def config(self): + return self._config def evaluate(self, **kwargs): tf.logging.info('evaluate called with args: %s' % kwargs) @@ -72,14 +76,14 @@ class ExperimentTest(tf.test.TestCase): self.assertAlmostEqual(duration, delay, delta=0.5) def test_train_default_delay(self): - est = TestEstimator() + config = run_config.RunConfig() + est = TestEstimator(config) ex = tf.contrib.learn.Experiment(est, train_input_fn='train_input', eval_input_fn='eval_input') - tf.flags.DEFINE_integer('task', 0, 'task') for task in [0, 1, 3]: start = time.time() - tf.flags.FLAGS.task = task + config.task = task ex.train() duration = time.time() - start self.assertAlmostEqual(duration, task*5, delta=0.5) |