aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2016-12-08 11:08:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-08 11:24:01 -0800
commitb59082b8afa5130d455e97b017cc4548201cb658 (patch)
tree3075c8afac302767ad3f314f8a78e26ed28b1d52
parentf15b389de5decd9c2a709d7a128e6fcd49ff808c (diff)
Pass estimator config to the model-fn.
Change: 141460215
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py21
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py15
2 files changed, 28 insertions, 8 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index a7af33eb44..91d900395b 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -970,9 +970,12 @@ class Estimator(BaseEstimator):
`labels=None`.
* `mode` specifies if this training, evaluation or
prediction. See `ModeKeys`.
- * `params` is a `dict` of hyperparameters. Will receive what
+ * `params` is a `dict` of hyperparameters. Will receive what
is passed to Estimator in `params` parameter. This allows
to configure Estimators from hyper parameter tuning.
+ * `config` is a Configuration object. Will receive what is passed to
+ Estimator in `config` parameter. This allows updating things in
+ your model_fn based on configuration such as num_ps_replicas.
* Returns:
`ModelFnOps`
@@ -990,6 +993,8 @@ class Estimator(BaseEstimator):
* `(features, labels) -> (predictions, loss, train_op)`
* `(features, labels, mode) -> (predictions, loss, train_op)`
* `(features, labels, mode, params) -> (predictions, loss, train_op)`
+ * `(features, labels, mode, params, config) ->
+ (predictions, loss, train_op)`
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator to
@@ -1040,14 +1045,14 @@ class Estimator(BaseEstimator):
"""
features, labels = self._feature_engineering_fn(features, labels)
model_fn_args = _get_arguments(self._model_fn)
+ kwargs = {}
if 'mode' in model_fn_args:
- if 'params' in model_fn_args:
- model_fn_results = self._model_fn(features, labels, mode=mode,
- params=self.params)
- else:
- model_fn_results = self._model_fn(features, labels, mode=mode)
- else:
- model_fn_results = self._model_fn(features, labels)
+ kwargs['mode'] = mode
+ if 'params' in model_fn_args:
+ kwargs['params'] = self.params
+ if 'config' in model_fn_args:
+ kwargs['config'] = self.config
+ model_fn_results = self._model_fn(features, labels, **kwargs)
if isinstance(model_fn_results, model_fn_lib.ModelFnOps):
return model_fn_results
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index 9060657bbf..5ebc299b57 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -218,6 +218,21 @@ class CheckCallsMonitor(tf.contrib.learn.monitors.BaseMonitor):
class EstimatorTest(tf.test.TestCase):
+ def testModelFnArgs(self):
+ expected_param = {'some_param': 'some_value'}
+ expected_config = tf.contrib.learn.RunConfig()
+ expected_config.i_am_test = True
+ def _argument_checker(features, labels, mode, params, config):
+ _, _ = features, labels
+ self.assertEqual(tf.contrib.learn.ModeKeys.TRAIN, mode)
+ self.assertEqual(expected_param, params)
+ self.assertTrue(config.i_am_test)
+ return tf.constant(0.), tf.constant(0.), tf.constant(0.)
+ est = tf.contrib.learn.Estimator(model_fn=_argument_checker,
+ params=expected_param,
+ config=expected_config)
+ est.fit(input_fn=boston_input_fn, steps=1)
+
def testInvalidModelFn_no_train_op(self):
def _invalid_model_fn(features, labels):
# pylint: disable=unused-argument