aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar Mark Daoust <markdaoust@google.com>2018-09-13 09:47:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-13 09:51:36 -0700
commit56d4fc8ff67f48294ae5cb0a7f9ff3d954463aa3 (patch)
tree70ef23b6614992758aa3ab525b74e162bfd3f7e5 /tensorflow/python/estimator
parent5ae1c93473ae690d4a7b9389b1219179cb2504a3 (diff)
Add a `namedtuple` factory that accepts doc-strings.
PiperOrigin-RevId: 212828094
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/model_fn.py93
1 files changed, 72 insertions, 21 deletions
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 439cc2e3a4..728de65559 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -33,6 +33,7 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import monitored_session
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import nest
+from tensorflow.python.util.collections import tf_namedtuple
from tensorflow.python.util.tf_export import estimator_export
@@ -62,14 +63,65 @@ EXPORT_TAG_MAP = {
ModeKeys.EVAL: [tag_constants.EVAL],
}
+# pylint: disable=line-too-long
+
+_EstimatorSpecNamedTuple = tf_namedtuple('EstimatorSpec', [ # pylint: disable=invalid-name
+ ('mode',
+ 'A `ModeKeys`. Specifies if this is training, evaluation or prediction.'
+ ),
+ ('predictions', 'Predictions `Tensor` or dict of `Tensor`.'),
+ ('loss',
+ 'Training loss `Tensor`. Must be either scalar, or with shape `[1]`.'),
+ ('train_op', 'Op to run one training step.'),
+ ('eval_metric_ops',
+ """Dict of metric results keyed by name.
+
+ The values of the dict are the results of calling a metric function,
+ namely a `(metric_tensor, update_op)` tuple.
+
+ `metric_tensor` should be evaluated without any impact on state
+ (typically is a pure computation results based on variables.).
+ For example, it should not trigger the `update_op` or requires any
+ input fetching."""
+ ),
+ ('export_outputs',
+ """Describes the output signatures to be exported to `SavedModel`.
+
+ A dict `{name: output}` where:
+
+ * `name` is An arbitrary name for this output.
+ * `output` is an `ExportOutput` object such as `ClassificationOutput`,
+ `RegressionOutput`, or `PredictOutput`.
+
+ Single-headed models only need to specify one entry in this dictionary.
+ Multi-headed models should specify one entry for each head, one of
+ which must be named using
+ `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. If no entry is
+ provided, a default `PredictOutput` mapping to `predictions` will be
+ created."""
+ ),
+ ('training_chief_hooks',
+ 'Iterable of `tf.train.SessionRunHook` objects to run on the chief worker during training.'
+ ),
+ ('training_hooks',
+ 'Iterable of `tf.train.SessionRunHook` objects to run on all workers during training.'
+ ),
+ ('scaffold',
+ 'A `tf.train.Scaffold` object that can be used to set initialization, saver, and more to be used in training.'
+ ),
+ ('evaluation_hooks',
+ 'Iterable of `tf.train.SessionRunHook` objects to run during evaluation.'
+ ),
+ ('prediction_hooks',
+ 'Iterable of `tf.train.SessionRunHook` objects to run during predictions.'
+ ),
+])
+
+# pylint: enable=line-too-long
+
@estimator_export('estimator.EstimatorSpec')
-class EstimatorSpec(
- collections.namedtuple('EstimatorSpec', [
- 'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops',
- 'export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold',
- 'evaluation_hooks', 'prediction_hooks'
- ])):
+class EstimatorSpec(_EstimatorSpecNamedTuple):
"""Ops and objects returned from a `model_fn` and passed to an `Estimator`.
`EstimatorSpec` fully defines the model to be run by an `Estimator`.
@@ -156,23 +208,22 @@ class EstimatorSpec(
A dict `{name: output}` where:
* name: An arbitrary name for this output.
* output: an `ExportOutput` object such as `ClassificationOutput`,
- `RegressionOutput`, or `PredictOutput`.
- Single-headed models only need to specify one entry in this dictionary.
- Multi-headed models should specify one entry for each head, one of
- which must be named using
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.
- If no entry is provided, a default `PredictOutput` mapping to
- `predictions` will be created.
- training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to
- run on the chief worker during training.
- training_hooks: Iterable of `tf.train.SessionRunHook` objects to run
- on all workers during training.
+ `RegressionOutput`, or `PredictOutput`. Single-headed models only need
+ to specify one entry in this dictionary. Multi-headed models should
+ specify one entry for each head, one of which must be named using
+ `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. If no entry
+ is provided, a default `PredictOutput` mapping to `predictions` will
+ be created.
+ training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to run
+ on the chief worker during training.
+ training_hooks: Iterable of `tf.train.SessionRunHook` objects to run on
+ all workers during training.
scaffold: A `tf.train.Scaffold` object that can be used to set
initialization, saver, and more to be used in training.
- evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to
- run during evaluation.
- prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to
- run during predictions.
+ evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to run
+ during evaluation.
+ prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to run
+ during predictions.
Returns:
A validated `EstimatorSpec` object.