author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-01 16:32:20 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-01 16:35:09 -0700
commit    cd368924989284864e3df2fcbae72a3892bb7afb (patch)
tree      95e4ee75b1d93324c40d68d5ae6f346a92e8f1c0
parent    b31498a054d55ce328a2820fd403af764c482500 (diff)
Allow the user to opt out of saving a metagraph for TPU in TPUEstimator.export_savedmodel().
PiperOrigin-RevId: 198944144
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py  18
1 file changed, 10 insertions, 8 deletions
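
For context, a minimal usage sketch of the new export_to_tpu flag (not part of the commit). It assumes TensorFlow 1.x with tf.contrib available and runs locally with use_tpu=False; the model_fn, input_fn, serving_input_receiver_fn, and paths below are illustrative placeholders.

import tensorflow as tf

def model_fn(features, labels, mode, params):
  del params  # unused in this sketch
  logits = tf.layers.dense(features['x'], 2)
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, predictions={'logits': logits})
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)

def input_fn(params):
  del params  # batch-size plumbing omitted in this sketch
  features = {'x': tf.random_uniform([8, 4])}
  labels = tf.zeros([8], dtype=tf.int32)
  return tf.data.Dataset.from_tensors((features, labels)).repeat()

def serving_input_receiver_fn():
  x = tf.placeholder(tf.float32, shape=[None, 4], name='x')
  return tf.estimator.export.ServingInputReceiver({'x': x}, {'x': x})

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    config=tf.contrib.tpu.RunConfig(model_dir='/tmp/tpu_estimator_example'),
    train_batch_size=8,
    use_tpu=False,
    export_to_tpu=False)  # skip the TPU metagraph; export only the CPU one

estimator.train(input_fn, max_steps=1)
estimator.export_savedmodel('/tmp/tpu_estimator_export',
                            serving_input_receiver_fn)

As the diff below also shows, the broad try/except that previously wrapped the TPU metagraph export is removed, so when export_to_tpu is left at its default of True a failure now propagates instead of being logged as a warning.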
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 4465833f88..a155de3844 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -1830,6 +1830,7 @@ class TPUEstimator(estimator_lib.Estimator):
predict_batch_size=None,
batch_axis=None,
eval_on_tpu=True,
+ export_to_tpu=True,
warm_start_from=None):
"""Constructs an `TPUEstimator` instance.
@@ -1872,6 +1873,8 @@ class TPUEstimator(estimator_lib.Estimator):
False or `PER_HOST_V2`, batch_axis is ignored.
eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the
model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`.
+ export_to_tpu: If True, `export_savedmodel()` exports a metagraph for
+ serving on TPU besides the one on CPU.
warm_start_from: Optional string filepath to a checkpoint or SavedModel to
warm-start from, or a `tf.estimator.WarmStartSettings`
object to fully configure warm-starting. If the string
@@ -1943,6 +1946,8 @@ class TPUEstimator(estimator_lib.Estimator):
use_tpu,
eval_on_tpu)
+ self._export_to_tpu = export_to_tpu
+
self._is_input_fn_invoked = None
def _add_meta_graph_for_mode(self,
@@ -1965,11 +1970,11 @@ class TPUEstimator(estimator_lib.Estimator):
save_variables,
mode=mode)
- input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE:
- input_receiver_fn_map[mode]}
- export_tags = [tag_constants.SERVING, tag_constants.TPU]
- mode = _REWRITE_FOR_INFERENCE_MODE
- try:
+ if self._export_to_tpu:
+ input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE:
+ input_receiver_fn_map[mode]}
+ export_tags = [tag_constants.SERVING, tag_constants.TPU]
+ mode = _REWRITE_FOR_INFERENCE_MODE
(super(TPUEstimator, self).
_add_meta_graph_for_mode(builder,
input_receiver_fn_map,
@@ -1978,9 +1983,6 @@ class TPUEstimator(estimator_lib.Estimator):
save_variables=False,
mode=mode,
export_tags=export_tags))
- except Exception as error: # pylint: disable=broad-except
- logging.warning('Saving meta graph for TPU failed: {}.'
- .format(str(error)))
def _call_model_fn(self, features, labels, mode, config):
if mode == _REWRITE_FOR_INFERENCE_MODE:
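
The guarded block in _add_meta_graph_for_mode follows the general SavedModelBuilder pattern of writing one metagraph that owns the variables and optionally adding further tagged metagraphs. A standalone sketch of that pattern, using a toy graph and no TPU rewrite (so this is not the TPUEstimator internals):

import tensorflow as tf
from tensorflow.python.saved_model import tag_constants

def export_with_optional_tpu_metagraph(export_dir, export_to_tpu=True):
  x = tf.placeholder(tf.float32, [None, 4], name='x')
  w = tf.get_variable('w', [4, 2])
  y = tf.matmul(x, w, name='y')
  signature = tf.saved_model.signature_def_utils.predict_signature_def(
      inputs={'x': x}, outputs={'y': y})
  builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The CPU serving metagraph is always written and owns the variables.
    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING],
        signature_def_map={'serving_default': signature})
    if export_to_tpu:
      # A second metagraph tagged for TPU serving reuses the saved variables;
      # TPUEstimator would additionally rewrite the graph for TPU inference.
      builder.add_meta_graph(
          [tag_constants.SERVING, tag_constants.TPU],
          signature_def_map={'serving_default': signature})
  builder.save()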