diff options
author | Mustafa Ispir <ispir@google.com> | 2017-09-07 14:44:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-07 14:49:07 -0700 |
commit | b6215a6681fdd1c3d56dc1c5e6acc33d69f91b4d (patch) | |
tree | 1d70a7ff4907493fb0d1f841d5a42881800af470 /tensorflow/contrib/predictor | |
parent | eb75ded6d71e2351be4b8694b60b5e473c05027d (diff) |
Provide a public property to access model_fn of an Estimator. An example usage of it for extending model_fn is in tf.contrib.estimator.add_metrics function.
Design decisions:
* `params` is bound to the returned function: model_fn is made public to help framework developers to extend/compose estimators. We assumed that framework developer does not know much information about how one estimator/model_fn is constructed. We think `params` is part of model_fn construction. If a framework combines estimators, each estimator may be created with different params. Combining those params into single one may create conflicts.
* model_fn property have a fixed signature instead of optional arguments. We made model_fn arguments as optional to simplify model_fn implementations. In this case users are not model_fn developers but model_fn users. If we kept the arguments optional, then framework developers should handle that optional logic for every model_fn usage. Also model_fn is called by Estimator. Estimator have all those arguments. So there is no need to make them optional.
PiperOrigin-RevId: 167914400
Diffstat (limited to 'tensorflow/contrib/predictor')
-rw-r--r-- | tensorflow/contrib/predictor/core_estimator_predictor.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/predictor/core_estimator_predictor.py b/tensorflow/contrib/predictor/core_estimator_predictor.py index 5557ef5101..bd5174aef8 100644 --- a/tensorflow/contrib/predictor/core_estimator_predictor.py +++ b/tensorflow/contrib/predictor/core_estimator_predictor.py @@ -32,8 +32,9 @@ def _get_signature_def( if output_key is None: output_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY # pylint: disable=protected-access - estimator_spec = estimator._call_model_fn( - serving_input_receiver.features, None, model_fn.ModeKeys.PREDICT) + estimator_spec = estimator.model_fn( + serving_input_receiver.features, None, model_fn.ModeKeys.PREDICT, + estimator.config) # pylint: enable=protected-access export_outputs = estimator_spec.export_outputs export_output = export_outputs.get(output_key) |