Diffstat (limited to 'tensorflow/contrib/tpu/python/tpu/tpu_estimator.py')
-rw-r--r-- | tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 347
1 file changed, 244 insertions, 103 deletions
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 5210139336..7c7c97638e 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -22,9 +22,9 @@ import collections import copy import os import signal +import sys import threading import time -import traceback import numpy as np import six @@ -32,6 +32,7 @@ from six.moves import queue as Queue # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import error_handling from tensorflow.contrib.tpu.python.tpu import session_support from tensorflow.contrib.tpu.python.tpu import tpu from tensorflow.contrib.tpu.python.tpu import tpu_config @@ -81,12 +82,17 @@ _TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CTX_KEY = 'context' +_USE_TPU_KEY = 'use_tpu' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _ONE_GIGABYTE = 1024 * 1024 * 1024 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' _TPU_TRAIN_OP = '_tpu_train_op' _REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' +# Ideally _USE_TPU_KEY should be reserved as well. However there are already +# models that make use of this key, thus it can not be reserved now to prevent +# breakage. In the long run, we would like to mitigate this by migrating models +# off of using _USE_TPU_KEY. _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] @@ -211,8 +217,8 @@ class _SIGNAL(object): class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. - See `EstimatorSpec` for `mode`, 'predictions, 'loss', 'train_op', and - 'export_outputs`. + See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and + `export_outputs`. For evaluation, `eval_metrics `is a tuple of `metric_fn` and `tensors`, where `metric_fn` runs on CPU to generate metrics and `tensors` represents the @@ -226,7 +232,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote size is the first dimension. Once all tensors are available at CPU host from all shards, they are concatenated (on CPU) and passed as positional arguments to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is - dict. `metric_fn` takes the `tensors` and returns a dict from metric string + a dict. `metric_fn` takes the `tensors` and returns a dict from metric string name to the result of calling a metric function, namely a `(metric_tensor, update_op)` tuple. See `TPUEstimator` for MNIST example how to specify the `eval_metrics`. 
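The `eval_metrics` contract documented in the hunk above is easiest to see in code. The following is a minimal sketch (the toy model, the feature name 'x', and the sizes are invented for illustration, not taken from this change) of an evaluation `model_fn` whose `eval_metrics` is a `(metric_fn, tensors)` tuple; because `tensors` is a dict, `metric_fn` is called with keyword arguments once the per-shard tensors have been concatenated on the CPU host.

import tensorflow as tf

def metric_fn(labels, logits):
  # Runs on the CPU host. `labels` and `logits` are the concatenations of the
  # per-shard tensors, with batch size as the first dimension.
  predictions = tf.argmax(logits, axis=-1)
  return {
      'accuracy': tf.metrics.accuracy(labels=labels, predictions=predictions),
  }

def model_fn(features, labels, mode, params):
  logits = tf.layers.dense(features['x'], units=10)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  assert mode == tf.estimator.ModeKeys.EVAL  # This sketch covers eval only.
  return tf.contrib.tpu.TPUEstimatorSpec(
      mode=mode,
      loss=loss,
      # `tensors` is a dict here, so `metric_fn` receives keyword arguments;
      # a list would produce positional arguments instead.
      eval_metrics=(metric_fn, {'labels': labels, 'logits': logits}))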
@@ -360,17 +366,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): ctx, enqueue_ops, dequeue_ops, - run_infeed_loop_on_coordinator=True): + run_infeed_loop_on_coordinator=True, + rendezvous=None): self._master_job = ctx.master_job self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops + self._rendezvous = rendezvous self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator self._initial_infeed_sleep_secs = ( ctx.config.tpu_config.initial_infeed_sleep_secs) - self._session_cancel_timer = None - self._feed_error = None self._finished = False @@ -387,62 +393,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): for op in summary_writer_init_ops: self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) - def _log_error(self, session, error): - """Log an infeed or outfeed error. - - This logs a short error message immediately, and schedules a timer to - emit the full stack trace and error message after a short period of time. - If the main session has terminated by the time the timer triggers, we - assume the real source of the error was from the main session and avoid - emitting a stack trace for the infeed. - - Args: - session: `tf.Session`, session to be terminated error: exception that - triggered logging. - error: the Exception to log. - """ - logging.warning( - '\n\n' - 'Error occurred during infeed/outfeed. This may be due to a compile ' - 'error in the main session. Waiting for a short time for the main ' - 'session to come back.\n\n%s', error) - - self._feed_error = traceback.format_exc() - - # If we've already encountered a feed error, don't schedule another - # cancellation op. - if self._session_cancel_timer: - return - - def _cancel_session(): - """Close the session to avoid the main thread from hanging. - - If input pipeline triggers any error, the infeed thread dies but the main - thread for TPU computation waits for the infeed enqueue forever. Close the - Session to cancel the main thread Session.run execution. - - We sleep for a few seconds before closing to give some time for the TPU - compilation error, if any, propagating, from TPU to CPU host. Compilation - errors should be reported by the main thread so that the program can be - interrupted and users can take action. Due to a race condition, the - infeed thread might see an error first. Closing the session here - immediately would result in a session cancellation exception in the main - thread, instead of the expected compile error. User code that depends on - having the proper exception type will therefore be confused. - """ - time.sleep(5) - - # If the main session is still running, the infeed/outfeed errors are - # legitimate, and should be logged. - if not self._finished and self._feed_error: - logging.error('Feed error: %s', self._feed_error) - logging.error('Closing session. A RuntimeError should follow.') session.close() - - self._session_cancel_timer = threading.Thread(target=_cancel_session) - self._session_cancel_timer.daemon = True - self._session_cancel_timer.start() - def _run_infeed(self, queue_ctx, session): logging.info('Starting infeed thread controller.') if self._initial_infeed_sleep_secs: @@ -451,7 +401,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): time.sleep(self._initial_infeed_sleep_secs) logging.info('%s thread starting after sleep', self._name) - try: + with self._rendezvous.catch_errors(source='infeed', session=session): if self._run_infeed_loop_on_coordinator: for count, steps in enumerate(queue_ctx.read_iteration_counts()): for i in xrange(steps): @@ -461,19 +411,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): for _ in queue_ctx.read_iteration_counts(): session.run(self._enqueue_ops) logging.info('Infeed thread finished, shutting down.') - except Exception as e: # pylint: disable=broad-except - self._log_error(session, e) def _run_outfeed(self, queue_ctx, session): logging.info('Starting outfeed thread controller.') - try: + with self._rendezvous.catch_errors(source='outfeed', session=session): for count, steps in enumerate(queue_ctx.read_iteration_counts()): for i in xrange(steps): logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i) session.run(self._dequeue_ops) logging.info('Outfeed thread finished, shutting down.') - except Exception as e: # pylint: disable=broad-except - self._log_error(session, e) def _create_infeed_controller(self, name, target, args): return _OpQueueContext(name=name, target=target, args=args) @@ -492,11 +438,6 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): def before_run(self, run_context): self._feed_error = None - # Wait for the cancellation timer to complete before continuing.
- if self._session_cancel_timer: - self._session_cancel_timer.join() - self._session_cancel_timer = None - iterations = run_context.session.run(self._iterations_per_loop_var) logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations) @@ -507,16 +448,14 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): self._outfeed_controller.send_next_batch_signal(iterations) def end(self, session): - if self._session_cancel_timer: - logging.warning('Feed error occurred; waiting for message.') - self._session_cancel_timer.join() - self._finished = True logging.info('Stop infeed thread controller') self._infeed_controller.join() + self._rendezvous.record_done('infeed') logging.info('Stop output thread controller') self._outfeed_controller.join() + self._rendezvous.record_done('outfeed') logging.info('Shutdown TPU system.') session.run(self._finalize_ops) @@ -524,9 +463,10 @@ class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): - def __init__(self, ctx, enqueue_ops, dequeue_ops): + def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None): super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( - ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False) + ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False, + rendezvous=rendezvous) def _create_infeed_controller(self, name, target, args): return _OpSignalOnceQueueContext(name=name, target=target, args=args) @@ -696,8 +636,6 @@ def generate_per_core_enqueue_ops_fn_for_host( infeed_queue = tpu_feed.InfeedQueue( number_of_tuple_elements=len(per_host_sharded_inputs[0])) captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_configuration_from_sharded_input_tensors( - per_host_sharded_inputs) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) @@ -832,8 +770,6 @@ def generate_per_host_v2_enqueue_ops_fn_for_host( infeed_queue = tpu_feed.InfeedQueue( number_of_tuple_elements=len(per_host_sharded_inputs[0])) captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_configuration_from_sharded_input_tensors( - per_host_sharded_inputs) per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) @@ -842,6 +778,66 @@ return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset +def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, + num_hosts): + """Generates infeed enqueue ops for one input_fn on all the hosts.""" + captured_infeed_queue = _CapturedObject() + hooks = [] + device_0 = ctx.tpu_host_placement_function(host_id=0) + with ops.device(device_0): + user_context = tpu_context.TPUContext( + internal_ctx=ctx, input_device=device_0, invocation_index=0) + inputs = _Inputs.from_input_fn(input_fn(user_context)) + + is_dataset = inputs.is_dataset + if ctx.mode == model_fn_lib.ModeKeys.PREDICT: + raise TypeError('Mode PREDICT not yet supported in BROADCAST mode.') + + if is_dataset: + hooks.append(inputs.dataset_initializer_hook()) + num_replicas_per_host = ctx.num_of_replicas_per_host + + def tpu_ordinal_function_impl(replica_id): + if ctx.device_assignment: + return ctx.device_assignment.tpu_ordinal(replica=replica_id) + else: + return replica_id % num_replicas_per_host + + def device_function_impl(replica_id): + return ctx.tpu_host_placement_function(replica_id=replica_id) + + def enqueue_ops_fn(): + """Generates enqueue ops for all the hosts.""" + broadcasted_inputs = [] + flattened_inputs = None # Cache result from input_fn. + for host_id in xrange(num_hosts): + with ops.device(ctx.tpu_host_placement_function(host_id=host_id)): + for _ in xrange(ctx.num_of_replicas_per_host): + # Note: input_fn is only called once at host 0 for the first replica. + # The features and labels returned from that invocation are + # broadcasted to other replicas(including the replicas on other + # hosts). + if flattened_inputs is None: + features, labels = inputs.features_and_labels() # Calls get_next() + inputs_structure_recorder.validate_and_record_structure( + features, labels) + flattened_inputs = ( + inputs_structure_recorder.flatten_features_and_labels( + features, labels)) + broadcasted_inputs.append(flattened_inputs) + + infeed_queue = tpu_feed.InfeedQueue( + number_of_tuple_elements=len(broadcasted_inputs[0])) + captured_infeed_queue.capture(infeed_queue) + enqueue_ops = infeed_queue.generate_enqueue_ops( + broadcasted_inputs, + tpu_ordinal_function=tpu_ordinal_function_impl, + placement_function=device_function_impl) + return enqueue_ops + + return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset + + class _InputPipeline(object): """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue. @@ -1074,6 +1070,22 @@ class _InputPipeline(object): # Infeed_queue_getter must be called after enqueue_ops_fn is called. infeed_queues.append(captured_infeed_queue.get()) + elif self._ctx.is_input_broadcast_with_iterators(): + # Only calls input_fn in host 0. + host_device = tpu_host_placement_fn(host_id=0) + enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = ( + generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn, + self._inputs_structure_recorder, + num_hosts)) + all_hooks.extend(hooks) + if is_dataset: + run_infeed_loop_on_coordinator = False + enqueue_ops.append( + _wrap_computation_in_while_loop( + device=host_device, op_fn=enqueue_ops_fn)) + else: + enqueue_ops.append(enqueue_ops_fn()) + infeed_queues.append(captured_infeed_queue.get()) else: for host_id in range(num_hosts): host_device = tpu_host_placement_fn(host_id=host_id) @@ -1260,7 +1272,8 @@ class _ModelFnWrapper(object): loss = tpu_estimator_spec.loss captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) to_record = {} - to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics + if tpu_estimator_spec.eval_metrics: + to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics if tpu_estimator_spec.host_call is not None: # We assume that evaluate won't update global step, so we don't wrap # this host_call. @@ -1414,8 +1427,16 @@ class _ModelFnWrapper(object): if batch_size_for_model_fn is not None: _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn) + running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode) + _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu) + + if not running_on_cpu: + user_context = tpu_context.TPUContext( + internal_ctx=self._ctx, call_from_input_fn=False) + _add_item_to_params(params, _CTX_KEY, user_context) + estimator_spec = self._model_fn(features=features, **kwargs) - if (self._ctx.is_running_on_cpu(is_export_mode) and + if (running_on_cpu and isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access # The estimator_spec will be passed to `Estimator` directly, which expects # type `EstimatorSpec`.
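One user-visible effect of the `_call_model_fn` change above is that `params['use_tpu']` now reflects where the `model_fn` is actually running (it is set to `not running_on_cpu`). Below is a hedged sketch of how a `model_fn` might branch on it, assuming the usual `CrossShardOptimizer` wrapping for TPU training; the toy model and names are illustrative only, not part of this change.

import tensorflow as tf

def model_fn(features, labels, mode, params):
  logits = tf.layers.dense(features['x'], units=2)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  # params['use_tpu'] is False on the CPU/export paths, so the same model_fn
  # can serve TPU training and CPU inference without a second Estimator.
  if params.get('use_tpu', False):
    optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, loss=loss,
                                         train_op=train_op)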
@@ -1555,7 +1576,7 @@ class _OutfeedHostCall(object): RuntimeError: If outfeed tensor is scalar. """ if not self._names: - return [] + return {} ret = {} # For each i, dequeue_ops[i] is a list containing the tensors from all @@ -1574,11 +1595,13 @@ class _OutfeedHostCall(object): # Outfeed ops execute on each replica's first logical core. Note: we must # constraint it such that we have at most one outfeed dequeue and enqueue # per replica. - tpu_device_placement_fn = self._ctx.tpu_device_placement_function for i in xrange(self._ctx.num_replicas): - with ops.device(tpu_device_placement_fn(i)): + host_device, ordinal_id = self._ctx.device_for_replica(i) + with ops.device(host_device): outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( - dtypes=tensor_dtypes, shapes=tensor_shapes) + dtypes=tensor_dtypes, + shapes=tensor_shapes, + device_ordinal=ordinal_id) for j, item in enumerate(outfeed_tensors): dequeue_ops[j].append(item) @@ -1593,7 +1616,7 @@ class _OutfeedHostCall(object): # place all ops on tpu host if possible. # # TODO(jhseu): Evaluate whether this is right for summaries. - with ops.device(self._ctx.tpu_host_placement_function(core_id=0)): + with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)): for name in self._names: dequeue_ops = dequeue_ops_by_name[name] for i, item in enumerate(dequeue_ops): @@ -1702,6 +1725,9 @@ class InstallSignalHandlerHook(session_run_hook.SessionRunHook): class TPUEstimator(estimator_lib.Estimator): """Estimator with TPU support. + TPUEstimator also supports training on CPU and GPU. You don't need to define + a separate `tf.estimator.Estimator`. + TPUEstimator handles many of the details of running on TPU devices, such as replicating inputs and models for each core, and returning to host periodically to run hooks. @@ -1739,7 +1765,8 @@ class TPUEstimator(estimator_lib.Estimator): Current limitations: -------------------- - 1. TPU evaluation only works on a single host (one TPU worker). + 1. TPU evaluation only works on a single host (one TPU worker) except + BROADCAST mode. 2. `input_fn` for evaluation should **NOT** raise an end-of-input exception (`OutOfRangeError` or `StopIteration`). And all evaluation steps and all @@ -1978,7 +2005,7 @@ class TPUEstimator(estimator_lib.Estimator): if (config.tpu_config.per_host_input_for_training is tpu_config.InputPipelineConfig.PER_SHARD_V1 and - config.tpu_config.computation_shape): + config.tpu_config.num_cores_per_replica): raise ValueError( 'Model parallelism only supports per host input for training. ' 'Please adjust TPURunconfig.per_host_input_for_training.') @@ -2025,6 +2052,7 @@ class TPUEstimator(estimator_lib.Estimator): self._export_to_tpu = export_to_tpu self._is_input_fn_invoked = None + self._rendezvous = {} def _add_meta_graph_for_mode(self, builder, @@ -2268,6 +2296,65 @@ class TPUEstimator(estimator_lib.Estimator): """ pass + def train(self, + input_fn, + hooks=None, + steps=None, + max_steps=None, + saving_listeners=None): + rendezvous = error_handling.ErrorRendezvous(num_sources=3) + self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous + try: + return super(TPUEstimator, self).train( + input_fn=input_fn, hooks=hooks, steps=steps, max_steps=max_steps, + saving_listeners=saving_listeners + ) + except Exception: # pylint: disable=broad-except + rendezvous.record_error('training_loop', sys.exc_info()) + finally: + rendezvous.record_done('training_loop') + rendezvous.raise_errors() + + def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, + name=None): + rendezvous = error_handling.ErrorRendezvous(num_sources=3) + self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous + try: + return super(TPUEstimator, self).evaluate( + input_fn, steps=steps, hooks=hooks, checkpoint_path=checkpoint_path, + name=name + ) + except Exception: # pylint: disable=broad-except + rendezvous.record_error('evaluation_loop', sys.exc_info()) + finally: + rendezvous.record_done('evaluation_loop') + rendezvous.raise_errors() + + def predict(self, + input_fn, + predict_keys=None, + hooks=None, + checkpoint_path=None, + yield_single_examples=True): + rendezvous = error_handling.ErrorRendezvous(num_sources=3) + self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous + try: + for result in super(TPUEstimator, self).predict( + input_fn=input_fn, + predict_keys=predict_keys, + hooks=hooks, + checkpoint_path=checkpoint_path, + yield_single_examples=yield_single_examples): + yield result + except Exception: # pylint: disable=broad-except + rendezvous.record_error('prediction_loop', sys.exc_info()) + finally: + rendezvous.record_done('prediction_loop') + rendezvous.raise_errors() + + rendezvous.record_done('prediction_loop') + rendezvous.raise_errors() + def _augment_model_fn(self, model_fn, batch_axis): """Returns a new model_fn, which wraps the TPU support.""" @@ -2290,10 +2377,20 @@ class TPUEstimator(estimator_lib.Estimator): # Clear the bit. self._is_input_fn_invoked = None + # examples_hook is added to training_hooks for both CPU and TPU + # execution. + examples_hook = ExamplesPerSecondHook( + ctx.global_batch_size, + output_dir=self.model_dir, + every_n_steps=self._log_every_n_steps) + if ctx.is_running_on_cpu(is_export_mode=is_export_mode): logging.info('Running %s on CPU', mode) - return model_fn_wrapper.call_without_tpu( + estimator_spec = model_fn_wrapper.call_without_tpu( features, labels, is_export_mode=is_export_mode) + estimator_spec = estimator_spec._replace( + training_hooks=estimator_spec.training_hooks + (examples_hook,)) + return estimator_spec assert labels is None, '`labels` passed to `model_fn` must be `None`.' # TPUEstimator._call_input_fn passes `input_fn` as features to here.
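The new `train`, `evaluate`, and `predict` overrides above all follow one pattern: create an `ErrorRendezvous` with three sources (the main loop plus the infeed and outfeed threads, which report through `catch_errors` in the session hook), funnel any exception into it, and re-raise from the main thread once every source has checked in. The real class lives in the newly imported `error_handling` module; the stand-in below is a simplified illustration of that flow, not the actual implementation (in particular, it omits the timeout wait and error prioritization, and its `catch_errors` takes no `session` argument).

import contextlib
import sys
import threading

import six

class ToyErrorRendezvous(object):
  """Simplified, illustrative stand-in for error_handling.ErrorRendezvous."""

  def __init__(self, num_sources):
    self._num_sources = num_sources
    self._lock = threading.Lock()
    self._errors = {}   # source name -> sys.exc_info() triple
    self._done = set()  # sources that have finished, cleanly or not

  def record_error(self, source, exc_info):
    with self._lock:
      self._errors.setdefault(source, exc_info)

  def record_done(self, source):
    with self._lock:
      self._done.add(source)

  @contextlib.contextmanager
  def catch_errors(self, source):
    # Used by worker threads: capture the exception instead of letting it
    # kill the thread silently, so the main thread can surface it later.
    try:
      yield
    except Exception:  # pylint: disable=broad-except
      self.record_error(source, sys.exc_info())
    finally:
      self.record_done(source)

  def raise_errors(self):
    # Called on the main thread (e.g. after record_done('training_loop'));
    # re-raises the first recorded error, if any.
    with self._lock:
      for source in sorted(self._errors):
        six.reraise(*self._errors[source])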
@@ -2352,7 +2449,9 @@ class TPUEstimator(estimator_lib.Estimator): enqueue_ops, host_ops, run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator)), + run_infeed_loop_on_coordinator), + rendezvous=self._rendezvous[mode], + ), InstallSignalHandlerHook(), training.LoggingTensorHook( { }, every_n_iter=logging_hook_frequency) ]) - examples_hook = ExamplesPerSecondHook( - ctx.global_batch_size, - output_dir=self.model_dir, - every_n_steps=self._log_every_n_steps) examples_hook._set_steps_per_run( # pylint: disable=protected-access self._config.tpu_config.iterations_per_loop) hooks.append(examples_hook) @@ -2424,7 +2519,8 @@ class TPUEstimator(estimator_lib.Estimator): host_call_ret = host_calls.create_tpu_hostcall() eval_metric_ops = {} eval_update_ops = [] - for k, v in host_call_ret['eval_metrics'].items(): + + for k, v in host_call_ret.get('eval_metrics', {}).items(): eval_metric_ops[k] = (v[0], dummy_update_op) eval_update_ops.append(v[1]) @@ -2438,7 +2534,8 @@ class TPUEstimator(estimator_lib.Estimator): enqueue_ops, eval_update_ops + host_ops, run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator)), + run_infeed_loop_on_coordinator), + rendezvous=self._rendezvous[mode]), ] + input_hooks return model_fn_lib.EstimatorSpec( @@ -2504,8 +2601,8 @@ class TPUEstimator(estimator_lib.Estimator): hooks = [ _StoppingPredictHook(scalar_stopping_signal), - TPUInfeedOutfeedSessionHookForPrediction(ctx, enqueue_ops, - host_ops), + TPUInfeedOutfeedSessionHookForPrediction( + ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]), ] + input_hooks return model_fn_lib.EstimatorSpec( @@ -3155,3 +3252,47 @@ def _add_item_to_params(params, key, value): else: # Now params is Python dict. params[key] = value + + +def export_estimator_savedmodel(estimator, + export_dir_base, + serving_input_receiver_fn, + assets_extra=None, + as_text=False, + checkpoint_path=None, + strip_default_attrs=False): + """Export `Estimator` trained model for TPU inference. + + Args: + estimator: `Estimator` with which model has been trained. + export_dir_base: A string containing a directory in which to create + timestamped subdirectories containing exported SavedModels. + serving_input_receiver_fn: A function that takes no argument and + returns a `ServingInputReceiver` or `TensorServingInputReceiver`. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel, or `None` if no extra assets are needed. + as_text: whether to write the SavedModel proto in text format. + checkpoint_path: The checkpoint path to export. If `None` (the default), + the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. + + Returns: + The string path to the exported directory. + """ + # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use + # `estimator.config`. + config = tpu_config.RunConfig(model_dir=estimator.model_dir) + est = TPUEstimator( + estimator._model_fn, # pylint: disable=protected-access + config=config, + params=estimator.params, + use_tpu=True, + train_batch_size=2048, # Does not matter. + eval_batch_size=2048, # Does not matter. + ) + return est.export_savedmodel(export_dir_base, serving_input_receiver_fn, + assets_extra, + as_text,
 checkpoint_path, strip_default_attrs)
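Finally, a hedged usage sketch for the new `export_estimator_savedmodel` helper added at the bottom of the file. `trained_estimator`, the feature name 'x', and the export path are placeholders for illustration; the helper itself only needs an already-trained `Estimator` and a serving input receiver fn.

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_estimator

def serving_input_fn():
  x = tf.placeholder(tf.float32, shape=[None, 4], name='x')
  return tf.estimator.export.ServingInputReceiver(
      features={'x': x}, receiver_tensors={'x': x})

# `trained_estimator` is assumed to be a tf.estimator.Estimator whose
# model_dir already holds a checkpoint.
export_dir = tpu_estimator.export_estimator_savedmodel(
    estimator=trained_estimator,
    export_dir_base='/tmp/tpu_export',
    serving_input_receiver_fn=serving_input_fn)
print('SavedModel written to:', export_dir)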