diff options
author | Mark Daoust <markdaoust@google.com> | 2018-09-13 09:47:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-13 09:51:36 -0700 |
commit | 56d4fc8ff67f48294ae5cb0a7f9ff3d954463aa3 (patch) | |
tree | 70ef23b6614992758aa3ab525b74e162bfd3f7e5 /tensorflow/python/estimator | |
parent | 5ae1c93473ae690d4a7b9389b1219179cb2504a3 (diff) |
Add a `namedtuple` factory that accepts doc-strings.
PiperOrigin-RevId: 212828094
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/model_fn.py | 93 |
1 files changed, 72 insertions, 21 deletions
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index 439cc2e3a4..728de65559 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -33,6 +33,7 @@ from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import monitored_session from tensorflow.python.training import session_run_hook from tensorflow.python.util import nest +from tensorflow.python.util.collections import tf_namedtuple from tensorflow.python.util.tf_export import estimator_export @@ -62,14 +63,65 @@ EXPORT_TAG_MAP = { ModeKeys.EVAL: [tag_constants.EVAL], } +# pylint: disable=line-too-long + +_EstimatorSpecNamedTuple = tf_namedtuple('EstimatorSpec', [ # pylint: disable=invalid-name + ('mode', + 'A `ModeKeys`. Specifies if this is training, evaluation or prediction.' + ), + ('predictions', 'Predictions `Tensor` or dict of `Tensor`.'), + ('loss', + 'Training loss `Tensor`. Must be either scalar, or with shape `[1]`.'), + ('train_op', 'Op to run one training step.'), + ('eval_metric_ops', + """Dict of metric results keyed by name. + + The values of the dict are the results of calling a metric function, + namely a `(metric_tensor, update_op)` tuple. + + `metric_tensor` should be evaluated without any impact on state + (typically is a pure computation results based on variables.). + For example, it should not trigger the `update_op` or requires any + input fetching.""" + ), + ('export_outputs', + """Describes the output signatures to be exported to `SavedModel`. + + A dict `{name: output}` where: + + * `name` is An arbitrary name for this output. + * `output` is an `ExportOutput` object such as `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + + Single-headed models only need to specify one entry in this dictionary. + Multi-headed models should specify one entry for each head, one of + which must be named using + `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. If no entry is + provided, a default `PredictOutput` mapping to `predictions` will be + created.""" + ), + ('training_chief_hooks', + 'Iterable of `tf.train.SessionRunHook` objects to run on the chief worker during training.' + ), + ('training_hooks', + 'Iterable of `tf.train.SessionRunHook` objects to run on all workers during training.' + ), + ('scaffold', + 'A `tf.train.Scaffold` object that can be used to set initialization, saver, and more to be used in training.' + ), + ('evaluation_hooks', + 'Iterable of `tf.train.SessionRunHook` objects to run during evaluation.' + ), + ('prediction_hooks', + 'Iterable of `tf.train.SessionRunHook` objects to run during predictions.' + ), +]) + +# pylint: enable=line-too-long + @estimator_export('estimator.EstimatorSpec') -class EstimatorSpec( - collections.namedtuple('EstimatorSpec', [ - 'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops', - 'export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold', - 'evaluation_hooks', 'prediction_hooks' - ])): +class EstimatorSpec(_EstimatorSpecNamedTuple): """Ops and objects returned from a `model_fn` and passed to an `Estimator`. `EstimatorSpec` fully defines the model to be run by an `Estimator`. @@ -156,23 +208,22 @@ class EstimatorSpec( A dict `{name: output}` where: * name: An arbitrary name for this output. * output: an `ExportOutput` object such as `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - Single-headed models only need to specify one entry in this dictionary. - Multi-headed models should specify one entry for each head, one of - which must be named using - signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY. - If no entry is provided, a default `PredictOutput` mapping to - `predictions` will be created. - training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to - run on the chief worker during training. - training_hooks: Iterable of `tf.train.SessionRunHook` objects to run - on all workers during training. + `RegressionOutput`, or `PredictOutput`. Single-headed models only need + to specify one entry in this dictionary. Multi-headed models should + specify one entry for each head, one of which must be named using + `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. If no entry + is provided, a default `PredictOutput` mapping to `predictions` will + be created. + training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to run + on the chief worker during training. + training_hooks: Iterable of `tf.train.SessionRunHook` objects to run on + all workers during training. scaffold: A `tf.train.Scaffold` object that can be used to set initialization, saver, and more to be used in training. - evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to - run during evaluation. - prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to - run during predictions. + evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to run + during evaluation. + prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to run + during predictions. Returns: A validated `EstimatorSpec` object. |