aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-08 11:57:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-08 12:07:57 -0800
commit543454b282bbcffd63d1348204662dbfed82fb86 (patch)
tree530ff37d76b596f2cbe288c80a705cec2a9ccd9c /tensorflow/contrib/learn
parent6e3a43f4b7a1288c878b5daff274f1229256fbe8 (diff)
Expose a version of model_fn for contrib Estimators. Make the body of get_timestamped_export_dir an Estimator util.
PiperOrigin-RevId: 188366199
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py22
1 files changed, 20 insertions, 2 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 5262e04e16..d8ccb1e7dc 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -470,6 +470,20 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
# TODO(wicke): make RunConfig immutable, and then return it without a copy.
return copy.deepcopy(self._config)
+ @property
+ def model_fn(self):
+ """Returns the model_fn which is bound to self.params.
+
+ Returns:
+ The model_fn with the following signature:
+ `def model_fn(features, labels, mode, metrics)`
+ """
+
+ def public_model_fn(features, labels, mode, config):
+ return self._call_model_fn(features, labels, mode, config=config)
+
+ return public_model_fn
+
@deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS,
('x', None), ('y', None), ('batch_size', None))
def fit(self,
@@ -1179,7 +1193,7 @@ class Estimator(BaseEstimator):
self._feature_engineering_fn = (
feature_engineering_fn or _identity_feature_engineering_fn)
- def _call_model_fn(self, features, labels, mode, metrics=None):
+ def _call_model_fn(self, features, labels, mode, metrics=None, config=None):
"""Calls model function with support of 2, 3 or 4 arguments.
Args:
@@ -1187,6 +1201,7 @@ class Estimator(BaseEstimator):
labels: labels dict.
mode: ModeKeys
metrics: Dict of metrics.
+ config: RunConfig.
Returns:
A `ModelFnOps` object. If model_fn returns a tuple, wraps them up in a
@@ -1203,7 +1218,10 @@ class Estimator(BaseEstimator):
if 'params' in model_fn_args:
kwargs['params'] = self.params
if 'config' in model_fn_args:
- kwargs['config'] = self.config
+ if config:
+ kwargs['config'] = config
+ else:
+ kwargs['config'] = self.config
if 'model_dir' in model_fn_args:
kwargs['model_dir'] = self.model_dir
model_fn_results = self._model_fn(features, labels, **kwargs)