Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu_estimator.py')
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py  347
1 file changed, 244 insertions, 103 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 5210139336..7c7c97638e 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -22,9 +22,9 @@ import collections
import copy
import os
import signal
+import sys
import threading
import time
-import traceback
import numpy as np
import six
@@ -32,6 +32,7 @@ from six.moves import queue as Queue # pylint: disable=redefined-builtin
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import error_handling
from tensorflow.contrib.tpu.python.tpu import session_support
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_config
@@ -81,12 +82,17 @@ _TPU_ESTIMATOR = 'tpu_estimator'
_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop'
_BATCH_SIZE_KEY = 'batch_size'
_CTX_KEY = 'context'
+_USE_TPU_KEY = 'use_tpu'
_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
_ONE_GIGABYTE = 1024 * 1024 * 1024
_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
_TPU_TRAIN_OP = '_tpu_train_op'
_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference'
+# Ideally _USE_TPU_KEY should be reserved as well. However there are already
+# models that make use of this key, thus it can not be reserved now to prevent
+# breakage. In the long run, we would like to mitigate this by migrating models
+# off of using _USE_TPU_KEY.
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]
@@ -211,8 +217,8 @@ class _SIGNAL(object):
class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access
"""Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.
- See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and
- 'export_outputs`.
+ See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and
+ `export_outputs`.
For evaluation, `eval_metrics` is a tuple of `metric_fn` and `tensors`, where
`metric_fn` runs on CPU to generate metrics and `tensors` represents the
@@ -226,7 +232,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
size is the first dimension. Once all tensors are available at CPU host from
all shards, they are concatenated (on CPU) and passed as positional arguments
to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is
- dict. `metric_fn` takes the `tensors` and returns a dict from metric string
+ a dict. `metric_fn` takes the `tensors` and returns a dict from metric string
name to the result of calling a metric function, namely a `(metric_tensor,
update_op)` tuple. See `TPUEstimator` for MNIST example how to specify the
`eval_metrics`.
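
To make the `eval_metrics` contract above concrete, a minimal eval-mode `model_fn` might look like the sketch below; the feature key 'x' and the toy model are placeholders, not part of this change.

import tensorflow as tf

def metric_fn(labels, logits):
  # Runs on the CPU host after the per-shard outfeed tensors are concatenated.
  predictions = tf.argmax(logits, axis=-1)
  return {'accuracy': tf.metrics.accuracy(labels, predictions)}

def model_fn(features, labels, mode, params):
  logits = tf.layers.dense(features['x'], 10)  # toy model
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  # Dict form of `tensors`: entries are passed to metric_fn as keyword args.
  return tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      eval_metrics=(metric_fn, {'labels': labels, 'logits': logits}))
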
@@ -360,17 +366,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
ctx,
enqueue_ops,
dequeue_ops,
- run_infeed_loop_on_coordinator=True):
+ run_infeed_loop_on_coordinator=True,
+ rendezvous=None):
self._master_job = ctx.master_job
self._enqueue_ops = enqueue_ops
self._dequeue_ops = dequeue_ops
+ self._rendezvous = rendezvous
self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator
self._initial_infeed_sleep_secs = (
ctx.config.tpu_config.initial_infeed_sleep_secs)
- self._session_cancel_timer = None
-
self._feed_error = None
self._finished = False
@@ -387,62 +393,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
for op in summary_writer_init_ops:
self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))
- def _log_error(self, session, error):
- """Log an infeed or outfeed error.
-
- This logs a short error message immediately, and schedules a timer to
- emit the full stack trace and error message after a short period of time.
- If the main session has terminated by the time the timer triggers, we
- assume the real source of the error was from the main session and avoid
- emitting a stack trace for the infeed.
-
- Args:
- session: `tf.Session`, session to be terminated error: exception that
- triggered logging.
- error: the Exception to log.
- """
- logging.warning(
- '\n\n'
- 'Error occurred during infeed/outfeed. This may be due to a compile '
- 'error in the main session. Waiting for a short time for the main '
- 'session to come back.\n\n%s', error)
-
- self._feed_error = traceback.format_exc()
-
- # If we've already encountered a feed error, don't schedule another
- # cancellation op.
- if self._session_cancel_timer:
- return
-
- def _cancel_session():
- """Close the session to avoid the main thread from hanging.
-
- If input pipeline triggers any error, the infeed thread dies but the main
- thread for TPU computation waits for the infeed enqueue forever. Close the
- Session to cancel the main thread Session.run execution.
-
- We sleep for a few seconds before closing to give some time for the TPU
- compilation error, if any, propagating, from TPU to CPU host. Compilation
- errors should be reported by the main thread so that the program can be
- interrupted and users can take action. Due to a race condition, the
- infeed thread might see an error first. Closing the session here
- immediately would result in a session cancellation exception in the main
- thread, instead of the expected compile error. User code that depends on
- having the proper exception type will therefore be confused.
- """
- time.sleep(5)
-
- # If the main session is still running, the infeed/outfeed errors are
- # legitimate, and should be logged.
- if not self._finished and self._feed_error:
- logging.error('Feed error: %s', self._feed_error)
- logging.error('Closing session. A RuntimeError should follow.')
- session.close()
-
- self._session_cancel_timer = threading.Thread(target=_cancel_session)
- self._session_cancel_timer.daemon = True
- self._session_cancel_timer.start()
-
def _run_infeed(self, queue_ctx, session):
logging.info('Starting infeed thread controller.')
if self._initial_infeed_sleep_secs:
@@ -451,7 +401,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
time.sleep(self._initial_infeed_sleep_secs)
logging.info('%s thread starting after sleep', self._name)
- try:
+ with self._rendezvous.catch_errors(source='infeed', session=session):
if self._run_infeed_loop_on_coordinator:
for count, steps in enumerate(queue_ctx.read_iteration_counts()):
for i in xrange(steps):
@@ -461,19 +411,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
for _ in queue_ctx.read_iteration_counts():
session.run(self._enqueue_ops)
logging.info('Infeed thread finished, shutting down.')
- except Exception as e: # pylint: disable=broad-except
- self._log_error(session, e)
def _run_outfeed(self, queue_ctx, session):
logging.info('Starting outfeed thread controller.')
- try:
+ with self._rendezvous.catch_errors(source='outfeed', session=session):
for count, steps in enumerate(queue_ctx.read_iteration_counts()):
for i in xrange(steps):
logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
session.run(self._dequeue_ops)
logging.info('Outfeed thread finished, shutting down.')
- except Exception as e: # pylint: disable=broad-except
- self._log_error(session, e)
def _create_infeed_controller(self, name, target, args):
return _OpQueueContext(name=name, target=target, args=args)
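
The two hunks above route infeed and outfeed failures through `rendezvous.catch_errors(...)` instead of the removed `_log_error` logic. A hedged sketch of what such a context manager does, not the actual `error_handling` implementation:

import contextlib
import sys

@contextlib.contextmanager
def catch_errors_sketch(rendezvous, source, session=None):
  # Record any exception under the named source instead of handling it here;
  # the main thread later re-raises it via rendezvous.raise_errors().
  try:
    yield
  except Exception:  # pylint: disable=broad-except
    rendezvous.record_error(source, sys.exc_info(), session)
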
@@ -492,11 +438,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
def before_run(self, run_context):
self._feed_error = None
- # Wait for the cancellation timer to complete before continuing.
- if self._session_cancel_timer:
- self._session_cancel_timer.join()
- self._session_cancel_timer = None
-
iterations = run_context.session.run(self._iterations_per_loop_var)
logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
@@ -507,16 +448,14 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
self._outfeed_controller.send_next_batch_signal(iterations)
def end(self, session):
- if self._session_cancel_timer:
- logging.warning('Feed error occurred; waiting for message.')
- self._session_cancel_timer.join()
-
self._finished = True
logging.info('Stop infeed thread controller')
self._infeed_controller.join()
+ self._rendezvous.record_done('infeed')
logging.info('Stop output thread controller')
self._outfeed_controller.join()
+ self._rendezvous.record_done('outfeed')
logging.info('Shutdown TPU system.')
session.run(self._finalize_ops)
@@ -524,9 +463,10 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook):
- def __init__(self, ctx, enqueue_ops, dequeue_ops):
+ def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None):
super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__(
- ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False)
+ ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False,
+ rendezvous=rendezvous)
def _create_infeed_controller(self, name, target, args):
return _OpSignalOnceQueueContext(name=name, target=target, args=args)
@@ -696,8 +636,6 @@ def generate_per_core_enqueue_ops_fn_for_host(
infeed_queue = tpu_feed.InfeedQueue(
number_of_tuple_elements=len(per_host_sharded_inputs[0]))
captured_infeed_queue.capture(infeed_queue)
- infeed_queue.set_configuration_from_sharded_input_tensors(
- per_host_sharded_inputs)
per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl)
@@ -832,8 +770,6 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
infeed_queue = tpu_feed.InfeedQueue(
number_of_tuple_elements=len(per_host_sharded_inputs[0]))
captured_infeed_queue.capture(infeed_queue)
- infeed_queue.set_configuration_from_sharded_input_tensors(
- per_host_sharded_inputs)
per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl)
@@ -842,6 +778,66 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
+def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
+ num_hosts):
+ """Generates infeed enqueue ops for one input_fn on all the hosts."""
+ captured_infeed_queue = _CapturedObject()
+ hooks = []
+ device_0 = ctx.tpu_host_placement_function(host_id=0)
+ with ops.device(device_0):
+ user_context = tpu_context.TPUContext(
+ internal_ctx=ctx, input_device=device_0, invocation_index=0)
+ inputs = _Inputs.from_input_fn(input_fn(user_context))
+
+ is_dataset = inputs.is_dataset
+ if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
+ raise TypeError('Mode PREDICT not yet supported in BROADCAST mode.')
+
+ if is_dataset:
+ hooks.append(inputs.dataset_initializer_hook())
+ num_replicas_per_host = ctx.num_of_replicas_per_host
+
+ def tpu_ordinal_function_impl(replica_id):
+ if ctx.device_assignment:
+ return ctx.device_assignment.tpu_ordinal(replica=replica_id)
+ else:
+ return replica_id % num_replicas_per_host
+
+ def device_function_impl(replica_id):
+ return ctx.tpu_host_placement_function(replica_id=replica_id)
+
+ def enqueue_ops_fn():
+ """Generates enqueue ops for all the hosts."""
+ broadcasted_inputs = []
+ flattened_inputs = None # Cache result from input_fn.
+ for host_id in xrange(num_hosts):
+ with ops.device(ctx.tpu_host_placement_function(host_id=host_id)):
+ for _ in xrange(ctx.num_of_replicas_per_host):
+ # Note: input_fn is only called once at host 0 for the first replica.
+ # The features and labels returned from that invocation are
+ # broadcasted to other replicas(including the replicas on other
+ # hosts).
+ if flattened_inputs is None:
+ features, labels = inputs.features_and_labels() # Calls get_next()
+ inputs_structure_recorder.validate_and_record_structure(
+ features, labels)
+ flattened_inputs = (
+ inputs_structure_recorder.flatten_features_and_labels(
+ features, labels))
+ broadcasted_inputs.append(flattened_inputs)
+
+ infeed_queue = tpu_feed.InfeedQueue(
+ number_of_tuple_elements=len(broadcasted_inputs[0]))
+ captured_infeed_queue.capture(infeed_queue)
+ enqueue_ops = infeed_queue.generate_enqueue_ops(
+ broadcasted_inputs,
+ tpu_ordinal_function=tpu_ordinal_function_impl,
+ placement_function=device_function_impl)
+ return enqueue_ops
+
+ return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset
+
+
class _InputPipeline(object):
"""`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.
@@ -1074,6 +1070,22 @@ class _InputPipeline(object):
# Infeed_queue_getter must be called after enqueue_ops_fn is called.
infeed_queues.append(captured_infeed_queue.get())
+ elif self._ctx.is_input_broadcast_with_iterators():
+ # Only calls input_fn in host 0.
+ host_device = tpu_host_placement_fn(host_id=0)
+ enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
+ generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn,
+ self._inputs_structure_recorder,
+ num_hosts))
+ all_hooks.extend(hooks)
+ if is_dataset:
+ run_infeed_loop_on_coordinator = False
+ enqueue_ops.append(
+ _wrap_computation_in_while_loop(
+ device=host_device, op_fn=enqueue_ops_fn))
+ else:
+ enqueue_ops.append(enqueue_ops_fn())
+ infeed_queues.append(captured_infeed_queue.get())
else:
for host_id in range(num_hosts):
host_device = tpu_host_placement_fn(host_id=host_id)
@@ -1260,7 +1272,8 @@ class _ModelFnWrapper(object):
loss = tpu_estimator_spec.loss
captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
to_record = {}
- to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics
+ if tpu_estimator_spec.eval_metrics:
+ to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics
if tpu_estimator_spec.host_call is not None:
# We assume that evaluate won't update global step, so we don't wrap
# this host_call.
@@ -1414,8 +1427,16 @@ class _ModelFnWrapper(object):
if batch_size_for_model_fn is not None:
_add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn)
+ running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode)
+ _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu)
+
+ if not running_on_cpu:
+ user_context = tpu_context.TPUContext(
+ internal_ctx=self._ctx, call_from_input_fn=False)
+ _add_item_to_params(params, _CTX_KEY, user_context)
+
estimator_spec = self._model_fn(features=features, **kwargs)
- if (self._ctx.is_running_on_cpu(is_export_mode) and
+ if (running_on_cpu and
isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access
# The estimator_spec will be passed to `Estimator` directly, which expects
# type `EstimatorSpec`.
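
With the hunk above, the wrapped `model_fn` can check whether it is actually executing on TPU through `params` (the key matches `_USE_TPU_KEY`). A minimal sketch, assuming `params` is the usual dict; the feature key 'x' and the toy model are placeholders.

import tensorflow as tf

def model_fn(features, labels, mode, params):
  logits = tf.layers.dense(features['x'], 10)  # toy model
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  if params.get('use_tpu', False):
    # Aggregate gradients across replicas only when running on TPU.
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)
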
@@ -1555,7 +1576,7 @@ class _OutfeedHostCall(object):
RuntimeError: If outfeed tensor is scalar.
"""
if not self._names:
- return []
+ return {}
ret = {}
# For each i, dequeue_ops[i] is a list containing the tensors from all
@@ -1574,11 +1595,13 @@ class _OutfeedHostCall(object):
# Outfeed ops execute on each replica's first logical core. Note: we must
# constraint it such that we have at most one outfeed dequeue and enqueue
# per replica.
- tpu_device_placement_fn = self._ctx.tpu_device_placement_function
for i in xrange(self._ctx.num_replicas):
- with ops.device(tpu_device_placement_fn(i)):
+ host_device, ordinal_id = self._ctx.device_for_replica(i)
+ with ops.device(host_device):
outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
- dtypes=tensor_dtypes, shapes=tensor_shapes)
+ dtypes=tensor_dtypes,
+ shapes=tensor_shapes,
+ device_ordinal=ordinal_id)
for j, item in enumerate(outfeed_tensors):
dequeue_ops[j].append(item)
@@ -1593,7 +1616,7 @@ class _OutfeedHostCall(object):
# place all ops on tpu host if possible.
#
# TODO(jhseu): Evaluate whether this is right for summaries.
- with ops.device(self._ctx.tpu_host_placement_function(core_id=0)):
+ with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)):
for name in self._names:
dequeue_ops = dequeue_ops_by_name[name]
for i, item in enumerate(dequeue_ops):
@@ -1702,6 +1725,9 @@ class InstallSignalHandlerHook(session_run_hook.SessionRunHook):
class TPUEstimator(estimator_lib.Estimator):
"""Estimator with TPU support.
+ TPUEstimator also supports training on CPU and GPU. You don't need to define
+ a separate `tf.estimator.Estimator`.
+
TPUEstimator handles many of the details of running on TPU devices, such as
replicating inputs and models for each core, and returning to host
periodically to run hooks.
@@ -1739,7 +1765,8 @@ class TPUEstimator(estimator_lib.Estimator):
Current limitations:
--------------------
- 1. TPU evaluation only works on a single host (one TPU worker).
+ 1. TPU evaluation only works on a single host (one TPU worker) except
+ in BROADCAST mode.
2. `input_fn` for evaluation should **NOT** raise an end-of-input exception
(`OutOfRangeError` or `StopIteration`). And all evaluation steps and all
@@ -1978,7 +2005,7 @@ class TPUEstimator(estimator_lib.Estimator):
if (config.tpu_config.per_host_input_for_training is
tpu_config.InputPipelineConfig.PER_SHARD_V1 and
- config.tpu_config.computation_shape):
+ config.tpu_config.num_cores_per_replica):
raise ValueError(
'Model parallelism only supports per host input for training. '
'Please adjust TPURunconfig.per_host_input_for_training.')
@@ -2025,6 +2052,7 @@ class TPUEstimator(estimator_lib.Estimator):
self._export_to_tpu = export_to_tpu
self._is_input_fn_invoked = None
+ self._rendezvous = {}
def _add_meta_graph_for_mode(self,
builder,
@@ -2268,6 +2296,65 @@ class TPUEstimator(estimator_lib.Estimator):
"""
pass
+ def train(self,
+ input_fn,
+ hooks=None,
+ steps=None,
+ max_steps=None,
+ saving_listeners=None):
+ rendezvous = error_handling.ErrorRendezvous(num_sources=3)
+ self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous
+ try:
+ return super(TPUEstimator, self).train(
+ input_fn=input_fn, hooks=hooks, steps=steps, max_steps=max_steps,
+ saving_listeners=saving_listeners
+ )
+ except Exception: # pylint: disable=broad-except
+ rendezvous.record_error('training_loop', sys.exc_info())
+ finally:
+ rendezvous.record_done('training_loop')
+ rendezvous.raise_errors()
+
+ def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
+ name=None):
+ rendezvous = error_handling.ErrorRendezvous(num_sources=3)
+ self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous
+ try:
+ return super(TPUEstimator, self).evaluate(
+ input_fn, steps=steps, hooks=hooks, checkpoint_path=checkpoint_path,
+ name=name
+ )
+ except Exception: # pylint: disable=broad-except
+ rendezvous.record_error('evaluation_loop', sys.exc_info())
+ finally:
+ rendezvous.record_done('evaluation_loop')
+ rendezvous.raise_errors()
+
+ def predict(self,
+ input_fn,
+ predict_keys=None,
+ hooks=None,
+ checkpoint_path=None,
+ yield_single_examples=True):
+ rendezvous = error_handling.ErrorRendezvous(num_sources=3)
+ self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous
+ try:
+ for result in super(TPUEstimator, self).predict(
+ input_fn=input_fn,
+ predict_keys=predict_keys,
+ hooks=hooks,
+ checkpoint_path=checkpoint_path,
+ yield_single_examples=yield_single_examples):
+ yield result
+ except Exception: # pylint: disable=broad-except
+ rendezvous.record_error('prediction_loop', sys.exc_info())
+ finally:
+ rendezvous.record_done('prediction_loop')
+ rendezvous.raise_errors()
+
+ rendezvous.record_done('prediction_loop')
+ rendezvous.raise_errors()
+
def _augment_model_fn(self, model_fn, batch_axis):
"""Returns a new model_fn, which wraps the TPU support."""
@@ -2290,10 +2377,20 @@ class TPUEstimator(estimator_lib.Estimator):
# Clear the bit.
self._is_input_fn_invoked = None
+ # examples_hook is added to training_hooks for both CPU and TPU
+ # execution.
+ examples_hook = ExamplesPerSecondHook(
+ ctx.global_batch_size,
+ output_dir=self.model_dir,
+ every_n_steps=self._log_every_n_steps)
+
if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
logging.info('Running %s on CPU', mode)
- return model_fn_wrapper.call_without_tpu(
+ estimator_spec = model_fn_wrapper.call_without_tpu(
features, labels, is_export_mode=is_export_mode)
+ estimator_spec = estimator_spec._replace(
+ training_hooks=estimator_spec.training_hooks + (examples_hook,))
+ return estimator_spec
assert labels is None, '`labels` passed to `model_fn` must be `None`.'
# TPUEstimator._call_input_fn passes `input_fn` as features to here.
@@ -2352,7 +2449,9 @@ class TPUEstimator(estimator_lib.Estimator):
enqueue_ops,
host_ops,
run_infeed_loop_on_coordinator=(
- run_infeed_loop_on_coordinator)),
+ run_infeed_loop_on_coordinator),
+ rendezvous=self._rendezvous[mode],
+ ),
InstallSignalHandlerHook(),
training.LoggingTensorHook(
{
@@ -2361,10 +2460,6 @@ class TPUEstimator(estimator_lib.Estimator):
},
every_n_iter=logging_hook_frequency)
])
- examples_hook = ExamplesPerSecondHook(
- ctx.global_batch_size,
- output_dir=self.model_dir,
- every_n_steps=self._log_every_n_steps)
examples_hook._set_steps_per_run( # pylint: disable=protected-access
self._config.tpu_config.iterations_per_loop)
hooks.append(examples_hook)
@@ -2424,7 +2519,8 @@ class TPUEstimator(estimator_lib.Estimator):
host_call_ret = host_calls.create_tpu_hostcall()
eval_metric_ops = {}
eval_update_ops = []
- for k, v in host_call_ret['eval_metrics'].items():
+
+ for k, v in host_call_ret.get('eval_metrics', {}).items():
eval_metric_ops[k] = (v[0], dummy_update_op)
eval_update_ops.append(v[1])
@@ -2438,7 +2534,8 @@ class TPUEstimator(estimator_lib.Estimator):
enqueue_ops,
eval_update_ops + host_ops,
run_infeed_loop_on_coordinator=(
- run_infeed_loop_on_coordinator)),
+ run_infeed_loop_on_coordinator),
+ rendezvous=self._rendezvous[mode]),
] + input_hooks
return model_fn_lib.EstimatorSpec(
@@ -2504,8 +2601,8 @@ class TPUEstimator(estimator_lib.Estimator):
hooks = [
_StoppingPredictHook(scalar_stopping_signal),
- TPUInfeedOutfeedSessionHookForPrediction(ctx, enqueue_ops,
- host_ops),
+ TPUInfeedOutfeedSessionHookForPrediction(
+ ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]),
] + input_hooks
return model_fn_lib.EstimatorSpec(
@@ -3155,3 +3252,47 @@ def _add_item_to_params(params, key, value):
else:
# Now params is Python dict.
params[key] = value
+
+
+def export_estimator_savedmodel(estimator,
+ export_dir_base,
+ serving_input_receiver_fn,
+ assets_extra=None,
+ as_text=False,
+ checkpoint_path=None,
+ strip_default_attrs=False):
+ """Export `Estimator` trained model for TPU inference.
+
+ Args:
+ estimator: `Estimator` with which model has been trained.
+ export_dir_base: A string containing a directory in which to create
+ timestamped subdirectories containing exported SavedModels.
+ serving_input_receiver_fn: A function that takes no argument and
+ returns a `ServingInputReceiver` or `TensorServingInputReceiver`.
+ assets_extra: A dict specifying how to populate the assets.extra directory
+ within the exported SavedModel, or `None` if no extra assets are needed.
+ as_text: whether to write the SavedModel proto in text format.
+ checkpoint_path: The checkpoint path to export. If `None` (the default),
+ the most recent checkpoint found within the model directory is chosen.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs.
+
+ Returns:
+ The string path to the exported directory.
+ """
+ # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use
+ # `estimator.config`.
+ config = tpu_config.RunConfig(model_dir=estimator.model_dir)
+ est = TPUEstimator(
+ estimator._model_fn, # pylint: disable=protected-access
+ config=config,
+ params=estimator.params,
+ use_tpu=True,
+ train_batch_size=2048, # Does not matter.
+ eval_batch_size=2048, # Does not matter.
+ )
+ return est.export_savedmodel(export_dir_base, serving_input_receiver_fn,
+ assets_extra,
+ as_text,
+ checkpoint_path,
+ strip_default_attrs)
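
A hedged usage sketch for the new helper, assuming `trained_estimator` is an existing CPU/GPU-trained `tf.estimator.Estimator`; the receiver function and export path below are placeholders.

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_estimator

def serving_input_receiver_fn():
  features = {'x': tf.placeholder(tf.float32, shape=[None, 784], name='x')}
  return tf.estimator.export.ServingInputReceiver(features, features)

export_dir = tpu_estimator.export_estimator_savedmodel(
    trained_estimator,                        # assumed pre-trained Estimator
    export_dir_base='/tmp/tpu_saved_model',   # placeholder path
    serving_input_receiver_fn=serving_input_receiver_fn)
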