aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/model_fn.py
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2017-09-29 12:21:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-29 12:33:00 -0700
commit0fb83965a209eb03c1c090e3e540fd7c2c7d1025 (patch)
treec0a7de3131acbcfa1a2dd34531cea34d5f059cb5 /tensorflow/python/estimator/model_fn.py
parent76db7553ab2998116a62d6c242aa39373a362993 (diff)
Users can call EstimatorSpec._replace since it's a namedtuple. Calling _replace does not run validations. Here we provide a new 'replace' which does the validations.
PiperOrigin-RevId: 170516477
Diffstat (limited to 'tensorflow/python/estimator/model_fn.py')
-rw-r--r--tensorflow/python/estimator/model_fn.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index cfa4be5c7d..d58e03f6ef 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -54,9 +54,9 @@ AVERAGE_LOSS_METRIC_KEY = 'average_loss'
class EstimatorSpec(
collections.namedtuple('EstimatorSpec', [
- 'predictions', 'loss', 'train_op', 'eval_metric_ops',
- 'export_outputs', 'training_chief_hooks', 'training_hooks',
- 'scaffold', 'evaluation_hooks'
+ 'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops',
+ 'export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold',
+ 'evaluation_hooks'
])):
"""Ops and objects returned from a `model_fn` and passed to an `Estimator`.
@@ -295,6 +295,7 @@ class EstimatorSpec(
return super(EstimatorSpec, cls).__new__(
cls,
+ mode=mode,
predictions=predictions,
loss=loss,
train_op=train_op,
@@ -305,6 +306,14 @@ class EstimatorSpec(
scaffold=scaffold,
evaluation_hooks=evaluation_hooks)
+ def _replace(self, **kwds):
+ """Return a new EstimatorSpec replacing specified fields with new values."""
+ if 'mode' in kwds:
+ if self.mode != kwds['mode']:
+ raise ValueError('mode of EstimatorSpec cannot be changed.')
+ new_fields = map(kwds.pop, self._fields, list(self))
+ return EstimatorSpec(*new_fields)
+
def _check_is_tensor_or_operation(x, name):
if not (isinstance(x, ops.Operation) or isinstance(x, ops.Tensor)):