diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-08 11:57:36 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-08 12:07:57 -0800 |
commit | 543454b282bbcffd63d1348204662dbfed82fb86 (patch) | |
tree | 530ff37d76b596f2cbe288c80a705cec2a9ccd9c /tensorflow/contrib/learn | |
parent | 6e3a43f4b7a1288c878b5daff274f1229256fbe8 (diff) |
Expose a version of model_fn for contrib Estimators. Make the body of get_timestamped_export_dir an Estimator util.
PiperOrigin-RevId: 188366199
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/estimator.py | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 5262e04e16..d8ccb1e7dc 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -470,6 +470,20 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, # TODO(wicke): make RunConfig immutable, and then return it without a copy. return copy.deepcopy(self._config) + @property + def model_fn(self): + """Returns the model_fn which is bound to self.params. + + Returns: + The model_fn with the following signature: + `def model_fn(features, labels, mode, metrics)` + """ + + def public_model_fn(features, labels, mode, config): + return self._call_model_fn(features, labels, mode, config=config) + + return public_model_fn + @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), ('y', None), ('batch_size', None)) def fit(self, @@ -1179,7 +1193,7 @@ class Estimator(BaseEstimator): self._feature_engineering_fn = ( feature_engineering_fn or _identity_feature_engineering_fn) - def _call_model_fn(self, features, labels, mode, metrics=None): + def _call_model_fn(self, features, labels, mode, metrics=None, config=None): """Calls model function with support of 2, 3 or 4 arguments. Args: @@ -1187,6 +1201,7 @@ class Estimator(BaseEstimator): labels: labels dict. mode: ModeKeys metrics: Dict of metrics. + config: RunConfig. Returns: A `ModelFnOps` object. If model_fn returns a tuple, wraps them up in a @@ -1203,7 +1218,10 @@ class Estimator(BaseEstimator): if 'params' in model_fn_args: kwargs['params'] = self.params if 'config' in model_fn_args: - kwargs['config'] = self.config + if config: + kwargs['config'] = config + else: + kwargs['config'] = self.config if 'model_dir' in model_fn_args: kwargs['model_dir'] = self.model_dir model_fn_results = self._model_fn(features, labels, **kwargs) |