diff options
author | Mustafa Ispir <ispir@google.com> | 2016-12-08 11:08:52 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-08 11:24:01 -0800 |
commit | b59082b8afa5130d455e97b017cc4548201cb658 (patch) | |
tree | 3075c8afac302767ad3f314f8a78e26ed28b1d52 | |
parent | f15b389de5decd9c2a709d7a128e6fcd49ff808c (diff) |
Pass estimator config to the model-fn.
Change: 141460215
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/estimator.py | 21 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/estimator_test.py | 15 |
2 files changed, 28 insertions, 8 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index a7af33eb44..91d900395b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -970,9 +970,12 @@ class Estimator(BaseEstimator): `labels=None`. * `mode` specifies if this training, evaluation or prediction. See `ModeKeys`. - * `params` is a `dict` of hyperparameters. Will receive what + * `params` is a `dict` of hyperparameters. Will receive what is passed to Estimator in `params` parameter. This allows to configure Estimators from hyper parameter tuning. + * `config` is a Configuration object. Will receive what is passed to + Estimator in `config` parameter. This allows updating things in + your model_fn based on configuration such as num_ps_replicas. * Returns: `ModelFnOps` @@ -990,6 +993,8 @@ class Estimator(BaseEstimator): * `(features, labels) -> (predictions, loss, train_op)` * `(features, labels, mode) -> (predictions, loss, train_op)` * `(features, labels, mode, params) -> (predictions, loss, train_op)` + * `(features, labels, mode, params, config) -> + (predictions, loss, train_op)` model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to @@ -1040,14 +1045,14 @@ class Estimator(BaseEstimator): """ features, labels = self._feature_engineering_fn(features, labels) model_fn_args = _get_arguments(self._model_fn) + kwargs = {} if 'mode' in model_fn_args: - if 'params' in model_fn_args: - model_fn_results = self._model_fn(features, labels, mode=mode, - params=self.params) - else: - model_fn_results = self._model_fn(features, labels, mode=mode) - else: - model_fn_results = self._model_fn(features, labels) + kwargs['mode'] = mode + if 'params' in model_fn_args: + kwargs['params'] = self.params + if 'config' in model_fn_args: + kwargs['config'] = self.config + model_fn_results = self._model_fn(features, labels, **kwargs) if isinstance(model_fn_results, model_fn_lib.ModelFnOps): return model_fn_results diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index 9060657bbf..5ebc299b57 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -218,6 +218,21 @@ class CheckCallsMonitor(tf.contrib.learn.monitors.BaseMonitor): class EstimatorTest(tf.test.TestCase): + def testModelFnArgs(self): + expected_param = {'some_param': 'some_value'} + expected_config = tf.contrib.learn.RunConfig() + expected_config.i_am_test = True + def _argument_checker(features, labels, mode, params, config): + _, _ = features, labels + self.assertEqual(tf.contrib.learn.ModeKeys.TRAIN, mode) + self.assertEqual(expected_param, params) + self.assertTrue(config.i_am_test) + return tf.constant(0.), tf.constant(0.), tf.constant(0.) + est = tf.contrib.learn.Estimator(model_fn=_argument_checker, + params=expected_param, + config=expected_config) + est.fit(input_fn=boston_input_fn, steps=1) + def testInvalidModelFn_no_train_op(self): def _invalid_model_fn(features, labels): # pylint: disable=unused-argument |