aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-13 11:30:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-13 11:34:41 -0700
commitee72b6a204232532e64221f1b9db7843ee13c312 (patch)
treee5124841e8d65375ac2f4c66637f268115c63308 /tensorflow/python/estimator
parent2f886d17f1990da418366bd093a09fb01fe5e777 (diff)
Automated rollback of commit 56d4fc8ff67f48294ae5cb0a7f9ff3d954463aa3
PiperOrigin-RevId: 212847619
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/model_fn.py93
1 files changed, 21 insertions, 72 deletions
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 728de65559..439cc2e3a4 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -33,7 +33,6 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import nest
-from tensorflow.python.util.collections import tf_namedtuple
from tensorflow.python.util.tf_export import estimator_export
@@ -63,65 +62,14 @@ EXPORT_TAG_MAP = {
ModeKeys.EVAL: [tag_constants.EVAL],
}
-# pylint: disable=line-too-long
-
-_EstimatorSpecNamedTuple = tf_namedtuple('EstimatorSpec', [ # pylint: disable=invalid-name
- ('mode',
- 'A `ModeKeys`. Specifies if this is training, evaluation or prediction.'
- ),
- ('predictions', 'Predictions `Tensor` or dict of `Tensor`.'),
- ('loss',
- 'Training loss `Tensor`. Must be either scalar, or with shape `[1]`.'),
- ('train_op', 'Op to run one training step.'),
- ('eval_metric_ops',
- """Dict of metric results keyed by name.
-
- The values of the dict are the results of calling a metric function,
- namely a `(metric_tensor, update_op)` tuple.
-
- `metric_tensor` should be evaluated without any impact on state
- (typically is a pure computation results based on variables.).
- For example, it should not trigger the `update_op` or requires any
- input fetching."""
- ),
- ('export_outputs',
- """Describes the output signatures to be exported to `SavedModel`.
-
- A dict `{name: output}` where:
-
- * `name` is An arbitrary name for this output.
- * `output` is an `ExportOutput` object such as `ClassificationOutput`,
- `RegressionOutput`, or `PredictOutput`.
-
- Single-headed models only need to specify one entry in this dictionary.
- Multi-headed models should specify one entry for each head, one of
- which must be named using
- `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. If no entry is
- provided, a default `PredictOutput` mapping to `predictions` will be
- created."""
- ),
- ('training_chief_hooks',
- 'Iterable of `tf.train.SessionRunHook` objects to run on the chief worker during training.'
- ),
- ('training_hooks',
- 'Iterable of `tf.train.SessionRunHook` objects to run on all workers during training.'
- ),
- ('scaffold',
- 'A `tf.train.Scaffold` object that can be used to set initialization, saver, and more to be used in training.'
- ),
- ('evaluation_hooks',
- 'Iterable of `tf.train.SessionRunHook` objects to run during evaluation.'
- ),
- ('prediction_hooks',
- 'Iterable of `tf.train.SessionRunHook` objects to run during predictions.'
- ),
-])
-
-# pylint: enable=line-too-long
-
@estimator_export('estimator.EstimatorSpec')
-class EstimatorSpec(_EstimatorSpecNamedTuple):
+class EstimatorSpec(
+ collections.namedtuple('EstimatorSpec', [
+ 'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops',
+ 'export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold',
+ 'evaluation_hooks', 'prediction_hooks'
+ ])):
"""Ops and objects returned from a `model_fn` and passed to an `Estimator`.
`EstimatorSpec` fully defines the model to be run by an `Estimator`.
@@ -208,22 +156,23 @@ class EstimatorSpec(_EstimatorSpecNamedTuple):
A dict `{name: output}` where:
* name: An arbitrary name for this output.
* output: an `ExportOutput` object such as `ClassificationOutput`,
- `RegressionOutput`, or `PredictOutput`. Single-headed models only need
- to specify one entry in this dictionary. Multi-headed models should
- specify one entry for each head, one of which must be named using
- `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. If no entry
- is provided, a default `PredictOutput` mapping to `predictions` will
- be created.
- training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to run
- on the chief worker during training.
- training_hooks: Iterable of `tf.train.SessionRunHook` objects to run on
- all workers during training.
+ `RegressionOutput`, or `PredictOutput`.
+ Single-headed models only need to specify one entry in this dictionary.
+ Multi-headed models should specify one entry for each head, one of
+ which must be named using
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.
+ If no entry is provided, a default `PredictOutput` mapping to
+ `predictions` will be created.
+ training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to
+ run on the chief worker during training.
+ training_hooks: Iterable of `tf.train.SessionRunHook` objects to run
+ on all workers during training.
scaffold: A `tf.train.Scaffold` object that can be used to set
initialization, saver, and more to be used in training.
- evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to run
- during evaluation.
- prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to run
- during predictions.
+ evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to
+ run during evaluation.
+ prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to
+ run during predictions.
Returns:
A validated `EstimatorSpec` object.