diff options
author | 2018-09-13 11:30:45 -0700 | |
---|---|---|
committer | 2018-09-13 11:34:41 -0700 | |
commit | ee72b6a204232532e64221f1b9db7843ee13c312 (patch) | |
tree | e5124841e8d65375ac2f4c66637f268115c63308 /tensorflow/python/estimator | |
parent | 2f886d17f1990da418366bd093a09fb01fe5e777 (diff) |
Automated rollback of commit 56d4fc8ff67f48294ae5cb0a7f9ff3d954463aa3
PiperOrigin-RevId: 212847619
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/model_fn.py | 93 |
1 files changed, 21 insertions, 72 deletions
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index 728de65559..439cc2e3a4 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -33,7 +33,6 @@ from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook from tensorflow.python.util import nest -from tensorflow.python.util.collections import tf_namedtuple from tensorflow.python.util.tf_export import estimator_export @@ -63,65 +62,14 @@ EXPORT_TAG_MAP = { ModeKeys.EVAL: [tag_constants.EVAL], } -# pylint: disable=line-too-long - -_EstimatorSpecNamedTuple = tf_namedtuple('EstimatorSpec', [ # pylint: disable=invalid-name - ('mode', - 'A `ModeKeys`. Specifies if this is training, evaluation or prediction.' - ), - ('predictions', 'Predictions `Tensor` or dict of `Tensor`.'), - ('loss', - 'Training loss `Tensor`. Must be either scalar, or with shape `[1]`.'), - ('train_op', 'Op to run one training step.'), - ('eval_metric_ops', - """Dict of metric results keyed by name. - - The values of the dict are the results of calling a metric function, - namely a `(metric_tensor, update_op)` tuple. - - `metric_tensor` should be evaluated without any impact on state - (typically is a pure computation results based on variables.). - For example, it should not trigger the `update_op` or requires any - input fetching.""" - ), - ('export_outputs', - """Describes the output signatures to be exported to `SavedModel`. - - A dict `{name: output}` where: - - * `name` is An arbitrary name for this output. - * `output` is an `ExportOutput` object such as `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - - Single-headed models only need to specify one entry in this dictionary. - Multi-headed models should specify one entry for each head, one of - which must be named using - `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. If no entry is - provided, a default `PredictOutput` mapping to `predictions` will be - created.""" - ), - ('training_chief_hooks', - 'Iterable of `tf.train.SessionRunHook` objects to run on the chief worker during training.' - ), - ('training_hooks', - 'Iterable of `tf.train.SessionRunHook` objects to run on all workers during training.' - ), - ('scaffold', - 'A `tf.train.Scaffold` object that can be used to set initialization, saver, and more to be used in training.' - ), - ('evaluation_hooks', - 'Iterable of `tf.train.SessionRunHook` objects to run during evaluation.' - ), - ('prediction_hooks', - 'Iterable of `tf.train.SessionRunHook` objects to run during predictions.' - ), -]) - -# pylint: enable=line-too-long - @estimator_export('estimator.EstimatorSpec') -class EstimatorSpec(_EstimatorSpecNamedTuple): +class EstimatorSpec( + collections.namedtuple('EstimatorSpec', [ + 'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops', + 'export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold', + 'evaluation_hooks', 'prediction_hooks' + ])): """Ops and objects returned from a `model_fn` and passed to an `Estimator`. `EstimatorSpec` fully defines the model to be run by an `Estimator`. @@ -208,22 +156,23 @@ class EstimatorSpec(_EstimatorSpecNamedTuple): A dict `{name: output}` where: * name: An arbitrary name for this output. * output: an `ExportOutput` object such as `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. Single-headed models only need - to specify one entry in this dictionary. Multi-headed models should - specify one entry for each head, one of which must be named using - `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. If no entry - is provided, a default `PredictOutput` mapping to `predictions` will - be created. - training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to run - on the chief worker during training. - training_hooks: Iterable of `tf.train.SessionRunHook` objects to run on - all workers during training. + `RegressionOutput`, or `PredictOutput`. + Single-headed models only need to specify one entry in this dictionary. + Multi-headed models should specify one entry for each head, one of + which must be named using + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY. + If no entry is provided, a default `PredictOutput` mapping to + `predictions` will be created. + training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to + run on the chief worker during training. + training_hooks: Iterable of `tf.train.SessionRunHook` objects to run + on all workers during training. scaffold: A `tf.train.Scaffold` object that can be used to set initialization, saver, and more to be used in training. - evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to run - during evaluation. - prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to run - during predictions. + evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to + run during evaluation. + prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to + run during predictions. Returns: A validated `EstimatorSpec` object. |