diff options
author | 2017-09-29 12:21:37 -0700 | |
---|---|---|
committer | 2017-09-29 12:33:00 -0700 | |
commit | 0fb83965a209eb03c1c090e3e540fd7c2c7d1025 (patch) | |
tree | c0a7de3131acbcfa1a2dd34531cea34d5f059cb5 /tensorflow/python/estimator/model_fn.py | |
parent | 76db7553ab2998116a62d6c242aa39373a362993 (diff) |
Users can call EstimatorSpec._replace since it's a namedtuple. Calling _replace does not run validations. Here we provide a new 'replace' which does the validations.
PiperOrigin-RevId: 170516477
Diffstat (limited to 'tensorflow/python/estimator/model_fn.py')
-rw-r--r-- | tensorflow/python/estimator/model_fn.py | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py index cfa4be5c7d..d58e03f6ef 100644 --- a/tensorflow/python/estimator/model_fn.py +++ b/tensorflow/python/estimator/model_fn.py @@ -54,9 +54,9 @@ AVERAGE_LOSS_METRIC_KEY = 'average_loss' class EstimatorSpec( collections.namedtuple('EstimatorSpec', [ - 'predictions', 'loss', 'train_op', 'eval_metric_ops', - 'export_outputs', 'training_chief_hooks', 'training_hooks', - 'scaffold', 'evaluation_hooks' + 'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops', + 'export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold', + 'evaluation_hooks' ])): """Ops and objects returned from a `model_fn` and passed to an `Estimator`. @@ -295,6 +295,7 @@ class EstimatorSpec( return super(EstimatorSpec, cls).__new__( cls, + mode=mode, predictions=predictions, loss=loss, train_op=train_op, @@ -305,6 +306,14 @@ class EstimatorSpec( scaffold=scaffold, evaluation_hooks=evaluation_hooks) + def _replace(self, **kwds): + """Return a new EstimatorSpec replacing specified fields with new values.""" + if 'mode' in kwds: + if self.mode != kwds['mode']: + raise ValueError('mode of EstimatorSpec cannot be changed.') + new_fields = map(kwds.pop, self._fields, list(self)) + return EstimatorSpec(*new_fields) + def _check_is_tensor_or_operation(x, name): if not (isinstance(x, ops.Operation) or isinstance(x, ops.Tensor)): |