aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/predictor
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-09-07 14:44:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-07 14:49:07 -0700
commitb6215a6681fdd1c3d56dc1c5e6acc33d69f91b4d (patch)
tree1d70a7ff4907493fb0d1f841d5a42881800af470 /tensorflow/contrib/predictor
parenteb75ded6d71e2351be4b8694b60b5e473c05027d (diff)
Provide a public property to access model_fn of an Estimator. An example usage of it for extending model_fn is in tf.contrib.estimator.add_metrics function.
Design decisions: * `params` is bound to the returned function: model_fn is made public to help framework developers to extend/compose estimators. We assumed that framework developer does not know much information about how one estimator/model_fn is constructed. We think `params` is part of model_fn construction. If a framework combines estimators, each estimator may be created with different params. Combining those params into single one may create conflicts. * model_fn property have a fixed signature instead of optional arguments. We made model_fn arguments as optional to simplify model_fn implementations. In this case users are not model_fn developers but model_fn users. If we kept the arguments optional, then framework developers should handle that optional logic for every model_fn usage. Also model_fn is called by Estimator. Estimator have all those arguments. So there is no need to make them optional. PiperOrigin-RevId: 167914400
Diffstat (limited to 'tensorflow/contrib/predictor')
-rw-r--r--tensorflow/contrib/predictor/core_estimator_predictor.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/predictor/core_estimator_predictor.py b/tensorflow/contrib/predictor/core_estimator_predictor.py
index 5557ef5101..bd5174aef8 100644
--- a/tensorflow/contrib/predictor/core_estimator_predictor.py
+++ b/tensorflow/contrib/predictor/core_estimator_predictor.py
@@ -32,8 +32,9 @@ def _get_signature_def(
if output_key is None:
output_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
# pylint: disable=protected-access
- estimator_spec = estimator._call_model_fn(
- serving_input_receiver.features, None, model_fn.ModeKeys.PREDICT)
+ estimator_spec = estimator.model_fn(
+ serving_input_receiver.features, None, model_fn.ModeKeys.PREDICT,
+ estimator.config)
# pylint: enable=protected-access
export_outputs = estimator_spec.export_outputs
export_output = export_outputs.get(output_key)