aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-23 06:48:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-23 08:02:54 -0700
commit44c0237eafaea616707bb3e8f9ec27402b4b33ef (patch)
tree4bfd77154f76607b24b4067706efcb97953a9d25
parenta58c688e1c5b0c8d9887dd5db17fb78e9f238dd7 (diff)
Added config property to BaseEstimator to allow retrieval of RunConfig.
Change: 131054755
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py10
-rw-r--r--tensorflow/contrib/learn/python/learn/tests/experiment_test.py16
3 files changed, 22 insertions, 10 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index b60c4e91e6..c6c894542c 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import abc
+import copy
import inspect
import itertools
import os
@@ -224,6 +225,11 @@ class BaseEstimator(
self._graph = None
+ @property
+ def config(self):
+ # TODO(wicke): make RunConfig immutable, and then return it without a copy.
+ return copy.deepcopy(self._config)
+
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
monitors=None, max_steps=None):
# pylint: disable=g-doc-args,g-doc-return-or-yield
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index 0f96b70fae..f26a194bb4 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -89,6 +89,10 @@ class Experiment(object):
self._eval_delay_secs = eval_delay_secs
self._continuous_eval_throttle_secs = continuous_eval_throttle_secs
+ @property
+ def estimator(self):
+ return self._estimator
+
def train(self, delay_secs=None):
"""Fit the estimator using the training data.
@@ -102,10 +106,8 @@ class Experiment(object):
The trained estimator.
"""
if delay_secs is None:
- task_id = 0
- if hasattr(FLAGS, "task"):
- task_id = FLAGS.task
- delay_secs = min(60, task_id*5)
+ task_id = self._estimator.config.task or 0
+ delay_secs = min(60, task_id * 5)
if delay_secs:
logging.info("Waiting %d secs before starting training.", delay_secs)
diff --git a/tensorflow/contrib/learn/python/learn/tests/experiment_test.py b/tensorflow/contrib/learn/python/learn/tests/experiment_test.py
index 11feab7c48..7980881ea6 100644
--- a/tensorflow/contrib/learn/python/learn/tests/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/experiment_test.py
@@ -19,16 +19,20 @@ from __future__ import print_function
import time
import tensorflow as tf
-# importing to get flags.
-from tensorflow.contrib.learn.python.learn import learn_runner # pylint: disable=unused-import
+from tensorflow.contrib.learn.python.learn import run_config
class TestEstimator(tf.contrib.learn.Evaluable, tf.contrib.learn.Trainable):
- def __init__(self):
+ def __init__(self, config=None):
self.eval_count = 0
self.fit_count = 0
self.monitors = []
+ self._config = config or run_config.RunConfig()
+
+ @property
+ def config(self):
+ return self._config
def evaluate(self, **kwargs):
tf.logging.info('evaluate called with args: %s' % kwargs)
@@ -72,14 +76,14 @@ class ExperimentTest(tf.test.TestCase):
self.assertAlmostEqual(duration, delay, delta=0.5)
def test_train_default_delay(self):
- est = TestEstimator()
+ config = run_config.RunConfig()
+ est = TestEstimator(config)
ex = tf.contrib.learn.Experiment(est,
train_input_fn='train_input',
eval_input_fn='eval_input')
- tf.flags.DEFINE_integer('task', 0, 'task')
for task in [0, 1, 3]:
start = time.time()
- tf.flags.FLAGS.task = task
+ config.task = task
ex.train()
duration = time.time() - start
self.assertAlmostEqual(duration, task*5, delta=0.5)