diff options
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/BUILD | 16 | ||||
-rw-r--r-- | tensorflow/python/estimator/api/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees.py | 439 | ||||
-rw-r--r-- | tensorflow/python/estimator/canned/metric_keys.py | 5 | ||||
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 16 | ||||
-rw-r--r-- | tensorflow/python/estimator/estimator_test.py | 42 | ||||
-rw-r--r-- | tensorflow/python/estimator/export/export_output.py | 11 | ||||
-rw-r--r-- | tensorflow/python/estimator/export/export_output_test.py | 15 | ||||
-rw-r--r-- | tensorflow/python/estimator/keras.py | 117 | ||||
-rw-r--r-- | tensorflow/python/estimator/keras_test.py | 172 | ||||
-rw-r--r-- | tensorflow/python/estimator/run_config.py | 40 | ||||
-rw-r--r-- | tensorflow/python/estimator/training.py | 3 | ||||
-rw-r--r-- | tensorflow/python/estimator/training_test.py | 4 |
13 files changed, 637 insertions, 247 deletions
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 8ee38d35cc..fd46163050 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -40,9 +40,9 @@ py_library( srcs_version = "PY2AND3", deps = [ ":gc", + ":metric_keys", + ":util", "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator:metric_keys", - "//tensorflow/python/estimator:util", ], ) @@ -683,9 +683,9 @@ py_test( ], deps = [ ":keras", + ":numpy_io", + ":run_config", "//tensorflow:tensorflow_py_no_contrib", - "//tensorflow/python/estimator:numpy_io", - "//tensorflow/python/estimator:run_config", "//third_party/py/numpy", ], ) @@ -707,6 +707,14 @@ py_library( ) py_library( + name = "expect_h5py_installed", + # This is a dummy rule used as a numpy dependency in open-source. + # We expect h5py to already be installed on the system, e.g. via + # `pip install h5py' + visibility = ["//visibility:public"], +) + +py_library( name = "expect_six_installed", # This is a dummy rule used as a numpy dependency in open-source. # We expect six to already be installed on the system, e.g. via diff --git a/tensorflow/python/estimator/api/BUILD b/tensorflow/python/estimator/api/BUILD index ceb9baef4d..a75fa7d0ae 100644 --- a/tensorflow/python/estimator/api/BUILD +++ b/tensorflow/python/estimator/api/BUILD @@ -6,8 +6,8 @@ package( licenses(["notice"]) # Apache 2.0 -load("//tensorflow/tools/api/generator:api_gen.bzl", "gen_api_init_files") -load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES") +load("//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files") +load("//tensorflow/python/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES") gen_api_init_files( name = "estimator_python_api_gen", diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index 3c832c7569..3292e2724d 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections import functools @@ -384,6 +385,249 @@ class _StopAtAttemptsHook(session_run_hook.SessionRunHook): run_context.request_stop() +def _get_max_splits(tree_hparams): + """Calculates the max possible number of splits based on tree params.""" + # maximum number of splits possible in the whole tree =2^(D-1)-1 + max_splits = (1 << tree_hparams.max_depth) - 1 + return max_splits + + +class _EnsembleGrower(object): + """Abstract base class for different types of ensemble growers. + + Use it to receive training ops for growing and centering bias, depending + on the implementation (for example, in memory or accumulator-based + distributed): + grower = ...create subclass grower(tree_ensemble, tree_hparams) + grow_op = grower.grow_tree(stats_summaries_list, feature_ids_list, + last_layer_nodes_range) + training_ops.append(grow_op) + """ + + def __init__(self, tree_ensemble, tree_hparams): + """Initializes a grower object. + + Args: + tree_ensemble: A TreeEnsemble variable. + tree_hparams: TODO. collections.namedtuple for hyper parameters. + """ + self._tree_ensemble = tree_ensemble + self._tree_hparams = tree_hparams + + @abc.abstractmethod + def center_bias(self, center_bias_var, gradients, hessians): + """Centers bias, if ready, based on statistics. + + Args: + center_bias_var: A variable that will be updated when bias centering + finished. + gradients: A rank 2 tensor of gradients. + hessians: A rank 2 tensor of hessians. + + Returns: + An operation for centering bias. + """ + + @abc.abstractmethod + def grow_tree(self, stats_summaries_list, feature_ids_list, + last_layer_nodes_range): + """Grows a tree, if ready, based on provided statistics. + + Args: + stats_summaries_list: List of stats summary tensors, representing sums of + gradients and hessians for each feature bucket. + feature_ids_list: a list of lists of feature ids for each bucket size. + last_layer_nodes_range: A tensor representing ids of the nodes in the + current layer, to be split. + + Returns: + An op for growing a tree. + """ + + # ============= Helper methods =========== + + def _center_bias_fn(self, center_bias_var, mean_gradients, mean_hessians): + """Updates the ensembles and cache (if needed) with logits prior.""" + continue_centering = boosted_trees_ops.center_bias( + self._tree_ensemble.resource_handle, + mean_gradients=mean_gradients, + mean_hessians=mean_hessians, + l1=self._tree_hparams.l1, + l2=self._tree_hparams.l2) + return center_bias_var.assign(continue_centering) + + def _grow_tree_from_stats_summaries(self, stats_summaries_list, + feature_ids_list, last_layer_nodes_range): + """Updates ensemble based on the best gains from stats summaries.""" + node_ids_per_feature = [] + gains_list = [] + thresholds_list = [] + left_node_contribs_list = [] + right_node_contribs_list = [] + all_feature_ids = [] + assert len(stats_summaries_list) == len(feature_ids_list) + + max_splits = _get_max_splits(self._tree_hparams) + + for i, feature_ids in enumerate(feature_ids_list): + (numeric_node_ids_per_feature, numeric_gains_list, + numeric_thresholds_list, numeric_left_node_contribs_list, + numeric_right_node_contribs_list) = ( + boosted_trees_ops.calculate_best_gains_per_feature( + node_id_range=last_layer_nodes_range, + stats_summary_list=stats_summaries_list[i], + l1=self._tree_hparams.l1, + l2=self._tree_hparams.l2, + tree_complexity=self._tree_hparams.tree_complexity, + min_node_weight=self._tree_hparams.min_node_weight, + max_splits=max_splits)) + + all_feature_ids += feature_ids + node_ids_per_feature += numeric_node_ids_per_feature + gains_list += numeric_gains_list + thresholds_list += numeric_thresholds_list + left_node_contribs_list += numeric_left_node_contribs_list + right_node_contribs_list += numeric_right_node_contribs_list + + grow_op = boosted_trees_ops.update_ensemble( + # Confirm if local_tree_ensemble or tree_ensemble should be used. + self._tree_ensemble.resource_handle, + feature_ids=all_feature_ids, + node_ids=node_ids_per_feature, + gains=gains_list, + thresholds=thresholds_list, + left_node_contribs=left_node_contribs_list, + right_node_contribs=right_node_contribs_list, + learning_rate=self._tree_hparams.learning_rate, + max_depth=self._tree_hparams.max_depth, + pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING) + return grow_op + + +class _InMemoryEnsembleGrower(_EnsembleGrower): + """A base class for ensemble growers.""" + + def __init__(self, tree_ensemble, tree_hparams): + + super(_InMemoryEnsembleGrower, self).__init__( + tree_ensemble=tree_ensemble, tree_hparams=tree_hparams) + + def center_bias(self, center_bias_var, gradients, hessians): + # For in memory, we already have a full batch of gradients and hessians, + # so just take a mean and proceed with centering. + mean_gradients = array_ops.expand_dims( + math_ops.reduce_mean(gradients, 0), 0) + mean_heassians = array_ops.expand_dims(math_ops.reduce_mean(hessians, 0), 0) + return self._center_bias_fn(center_bias_var, mean_gradients, mean_heassians) + + def grow_tree(self, stats_summaries_list, feature_ids_list, + last_layer_nodes_range): + # For in memory, we already have full data in one batch, so we can grow the + # tree immediately. + return self._grow_tree_from_stats_summaries( + stats_summaries_list, feature_ids_list, last_layer_nodes_range) + + +class _AccumulatorEnsembleGrower(_EnsembleGrower): + """A base class for ensemble growers.""" + + def __init__(self, tree_ensemble, tree_hparams, stamp_token, + n_batches_per_layer, bucket_size_list, is_chief): + super(_AccumulatorEnsembleGrower, self).__init__( + tree_ensemble=tree_ensemble, tree_hparams=tree_hparams) + self._stamp_token = stamp_token + self._n_batches_per_layer = n_batches_per_layer + self._bucket_size_list = bucket_size_list + self._is_chief = is_chief + + def center_bias(self, center_bias_var, gradients, hessians): + # For not in memory situation, we need to accumulate enough of batches first + # before proceeding with centering bias. + + # Create an accumulator. + bias_dependencies = [] + bias_accumulator = data_flow_ops.ConditionalAccumulator( + dtype=dtypes.float32, + # The stats consist of grads and hessians means only. + # TODO(nponomareva): this will change for a multiclass + shape=[2, 1], + shared_name='bias_accumulator') + + grads_and_hess = array_ops.stack([gradients, hessians], axis=0) + grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1) + + apply_grad = bias_accumulator.apply_grad(grads_and_hess, self._stamp_token) + bias_dependencies.append(apply_grad) + + # Center bias if enough batches were processed. + with ops.control_dependencies(bias_dependencies): + if not self._is_chief: + return control_flow_ops.no_op() + + def center_bias_from_accumulator(): + accumulated = array_ops.unstack(bias_accumulator.take_grad(1), axis=0) + return self._center_bias_fn(center_bias_var, + array_ops.expand_dims(accumulated[0], 0), + array_ops.expand_dims(accumulated[1], 0)) + + center_bias_op = control_flow_ops.cond( + math_ops.greater_equal(bias_accumulator.num_accumulated(), + self._n_batches_per_layer), + center_bias_from_accumulator, + control_flow_ops.no_op, + name='wait_until_n_batches_for_bias_accumulated') + return center_bias_op + + def grow_tree(self, stats_summaries_list, feature_ids_list, + last_layer_nodes_range): + # For not in memory situation, we need to accumulate enough of batches first + # before proceeding with building a tree layer. + max_splits = _get_max_splits(self._tree_hparams) + + # Prepare accumulators. + accumulators = [] + dependencies = [] + for i, feature_ids in enumerate(feature_ids_list): + stats_summaries = stats_summaries_list[i] + accumulator = data_flow_ops.ConditionalAccumulator( + dtype=dtypes.float32, + # The stats consist of grads and hessians (the last dimension). + shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2], + shared_name='numeric_stats_summary_accumulator_' + str(i)) + accumulators.append(accumulator) + + apply_grad = accumulator.apply_grad( + array_ops.stack(stats_summaries, axis=0), self._stamp_token) + dependencies.append(apply_grad) + + # Grow the tree if enough batches is accumulated. + with ops.control_dependencies(dependencies): + if not self._is_chief: + return control_flow_ops.no_op() + + min_accumulated = math_ops.reduce_min( + array_ops.stack([acc.num_accumulated() for acc in accumulators])) + + def grow_tree_from_accumulated_summaries_fn(): + """Updates tree with the best layer from accumulated summaries.""" + # Take out the accumulated summaries from the accumulator and grow. + stats_summaries_list = [] + stats_summaries_list = [ + array_ops.unstack(accumulator.take_grad(1), axis=0) + for accumulator in accumulators + ] + grow_op = self._grow_tree_from_stats_summaries( + stats_summaries_list, feature_ids_list, last_layer_nodes_range) + return grow_op + + grow_model = control_flow_ops.cond( + math_ops.greater_equal(min_accumulated, self._n_batches_per_layer), + grow_tree_from_accumulated_summaries_fn, + control_flow_ops.no_op, + name='wait_until_n_batches_accumulated') + return grow_model + + def _bt_model_fn( features, labels, @@ -441,11 +685,6 @@ def _bt_model_fn( raise ValueError('train_in_memory is supported only for ' 'non-distributed training.') worker_device = control_flow_ops.no_op().device - # maximum number of splits possible in the whole tree =2^(D-1)-1 - # TODO(youngheek): perhaps storage could be optimized by storing stats with - # the dimension max_splits_per_layer, instead of max_splits (for the entire - # tree). - max_splits = (1 << tree_hparams.max_depth) - 1 train_op = [] with ops.name_scope(name) as name: # Prepare. @@ -543,6 +782,11 @@ def _bt_model_fn( hessians = gradients_impl.gradients( gradients, logits, name='Hessians')[0] + # TODO(youngheek): perhaps storage could be optimized by storing stats + # with the dimension max_splits_per_layer, instead of max_splits (for the + # entire tree). + max_splits = _get_max_splits(tree_hparams) + stats_summaries_list = [] for i, feature_ids in enumerate(feature_ids_list): num_buckets = bucket_size_list[i] @@ -559,173 +803,28 @@ def _bt_model_fn( ] stats_summaries_list.append(summaries) - # ========= Helper methods for both in and not in memory. ============== - def grow_tree_from_stats_summaries(stats_summaries_list, - feature_ids_list): - """Updates ensemble based on the best gains from stats summaries.""" - node_ids_per_feature = [] - gains_list = [] - thresholds_list = [] - left_node_contribs_list = [] - right_node_contribs_list = [] - all_feature_ids = [] - - assert len(stats_summaries_list) == len(feature_ids_list) - - for i, feature_ids in enumerate(feature_ids_list): - (numeric_node_ids_per_feature, numeric_gains_list, - numeric_thresholds_list, numeric_left_node_contribs_list, - numeric_right_node_contribs_list) = ( - boosted_trees_ops.calculate_best_gains_per_feature( - node_id_range=last_layer_nodes_range, - stats_summary_list=stats_summaries_list[i], - l1=tree_hparams.l1, - l2=tree_hparams.l2, - tree_complexity=tree_hparams.tree_complexity, - min_node_weight=tree_hparams.min_node_weight, - max_splits=max_splits)) - - all_feature_ids += feature_ids - node_ids_per_feature += numeric_node_ids_per_feature - gains_list += numeric_gains_list - thresholds_list += numeric_thresholds_list - left_node_contribs_list += numeric_left_node_contribs_list - right_node_contribs_list += numeric_right_node_contribs_list - - grow_op = boosted_trees_ops.update_ensemble( - # Confirm if local_tree_ensemble or tree_ensemble should be used. - tree_ensemble.resource_handle, - feature_ids=all_feature_ids, - node_ids=node_ids_per_feature, - gains=gains_list, - thresholds=thresholds_list, - left_node_contribs=left_node_contribs_list, - right_node_contribs=right_node_contribs_list, - learning_rate=tree_hparams.learning_rate, - max_depth=tree_hparams.max_depth, - pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING) - return grow_op - - def _center_bias_fn(mean_gradients, mean_hessians): - """Updates the ensembles and cache (if needed) with logits prior.""" - continue_centering = boosted_trees_ops.center_bias( - tree_ensemble.resource_handle, - mean_gradients=mean_gradients, - mean_hessians=mean_hessians, - l1=tree_hparams.l1, - l2=tree_hparams.l2 - ) - return center_bias_var.assign(continue_centering) - - # ========= End of helper methods. ============== - if train_in_memory and is_single_machine: - train_op.append(distribute_lib.increment_var(global_step)) - - mean_gradients = array_ops.expand_dims( - math_ops.reduce_mean(gradients, 0), 0) - mean_heassians = array_ops.expand_dims( - math_ops.reduce_mean(hessians, 0), 0) - - train_op.append( - control_flow_ops.cond( - center_bias_var, - lambda: _center_bias_fn(mean_gradients, mean_heassians), - functools.partial(grow_tree_from_stats_summaries, - stats_summaries_list, feature_ids_list))) + grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams) else: - - def center_bias_not_in_mem(): - """Accumulates the data and updates the logits bias, when ready.""" - bias_dependencies = [] - - bias_accumulator = data_flow_ops.ConditionalAccumulator( - dtype=dtypes.float32, - # The stats consist of grads and hessians means only. - # TODO(nponomareva): this will change for a multiclass - shape=[2, 1], - shared_name='bias_accumulator') - - grads_and_hess = array_ops.stack([gradients, hessians], axis=0) - grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1) - - apply_grad = bias_accumulator.apply_grad(grads_and_hess, stamp_token) - bias_dependencies.append(apply_grad) - - def center_bias_from_accumulator(): - accumulated = array_ops.unstack( - bias_accumulator.take_grad(1), axis=0) - return _center_bias_fn( - array_ops.expand_dims(accumulated[0], 0), - array_ops.expand_dims(accumulated[1], 0)) - - with ops.control_dependencies(bias_dependencies): - if config.is_chief: - center_bias_op = control_flow_ops.cond( - math_ops.greater_equal(bias_accumulator.num_accumulated(), - n_batches_per_layer), - center_bias_from_accumulator, - control_flow_ops.no_op, - name='wait_until_n_batches_for_bias_accumulated') - - return center_bias_op - else: - return control_flow_ops.no_op() - - def grow_not_in_mem(): - """Accumulates the data and grows a layer when ready.""" - - accumulators = [] - dependencies = [] - for i, feature_ids in enumerate(feature_ids_list): - stats_summaries = stats_summaries_list[i] - accumulator = data_flow_ops.ConditionalAccumulator( - dtype=dtypes.float32, - # The stats consist of grads and hessians (the last dimension). - shape=[len(feature_ids), max_splits, bucket_size_list[i], 2], - shared_name='numeric_stats_summary_accumulator_' + str(i)) - accumulators.append(accumulator) - - apply_grad = accumulator.apply_grad( - array_ops.stack(stats_summaries, axis=0), stamp_token) - dependencies.append(apply_grad) - - def grow_tree_from_accumulated_summaries_fn(): - """Updates tree with the best layer from accumulated summaries.""" - # Take out the accumulated summaries from the accumulator and grow. - stats_summaries_list = [] - - stats_summaries_list = [ - array_ops.unstack(accumulator.take_grad(1), axis=0) - for accumulator in accumulators - ] - - grow_op = grow_tree_from_stats_summaries(stats_summaries_list, - feature_ids_list) - return grow_op - - with ops.control_dependencies(dependencies): - if config.is_chief: - min_accumulated = math_ops.reduce_min( - array_ops.stack( - [acc.num_accumulated() for acc in accumulators])) - - grow_model = control_flow_ops.cond( - math_ops.greater_equal(min_accumulated, n_batches_per_layer), - grow_tree_from_accumulated_summaries_fn, - control_flow_ops.no_op, - name='wait_until_n_batches_accumulated') - - return grow_model - else: - return control_flow_ops.no_op() - - update_model = control_flow_ops.cond( - center_bias_var, center_bias_not_in_mem, grow_not_in_mem) - train_op.append(update_model) - with ops.control_dependencies([update_model]): - increment_global = distribute_lib.increment_var(global_step) - train_op.append(increment_global) + grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams, + stamp_token, n_batches_per_layer, + bucket_size_list, config.is_chief) + + update_model = control_flow_ops.cond( + center_bias_var, + functools.partial( + grower.center_bias, + center_bias_var, + gradients, + hessians, + ), + functools.partial(grower.grow_tree, stats_summaries_list, + feature_ids_list, last_layer_nodes_range)) + train_op.append(update_model) + + with ops.control_dependencies([update_model]): + increment_global = distribute_lib.increment_var(global_step) + train_op.append(increment_global) return control_flow_ops.group(train_op, name='train_op') diff --git a/tensorflow/python/estimator/canned/metric_keys.py b/tensorflow/python/estimator/canned/metric_keys.py index 4f7c849ba4..9d49240fea 100644 --- a/tensorflow/python/estimator/canned/metric_keys.py +++ b/tensorflow/python/estimator/canned/metric_keys.py @@ -47,3 +47,8 @@ class MetricKeys(object): PROBABILITY_MEAN_AT_CLASS = 'probability_mean/class%d' AUC_AT_CLASS = 'auc/class%d' AUC_PR_AT_CLASS = 'auc_precision_recall/class%d' + + # The following require a class name applied. + PROBABILITY_MEAN_AT_NAME = 'probability_mean/%s' + AUC_AT_NAME = 'auc/%s' + AUC_PR_AT_NAME = 'auc_precision_recall/%s' diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 350a95eea1..915ceeb98b 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -29,8 +29,6 @@ import six from google.protobuf import message from tensorflow.core.framework import summary_pb2 -from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session as tf_session from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib @@ -216,11 +214,7 @@ class Estimator(object): logging.info('Using config: %s', str(vars(self._config))) if self._config.session_config is None: - rewrite_opts = rewriter_config_pb2.RewriterConfig( - meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE) - graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts) - self._session_config = config_pb2.ConfigProto( - allow_soft_placement=True, graph_options=graph_opts) + self._session_config = run_config.get_default_session_config() else: self._session_config = self._config.session_config @@ -573,10 +567,16 @@ class Estimator(object): def _assert_members_are_not_overridden(self): """Asserts members of `Estimator` are not overridden.""" + # TPUEstimator is special cased (owned by TF). + if self.__class__.__name__ == 'TPUEstimator': + return + allowed_overrides = set([ '_call_input_fn', '_create_global_step', '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks', - '_tf_api_names', '_estimator_api_names', '_estimator_api_constants', + '_tf_api_names', '_tf_api_names_v1', '_estimator_api_names', + '_estimator_api_names_v1', '_estimator_api_constants', + '_estimator_api_constants_v1', '_validate_features_in_predict_input', '_call_model_fn', '_add_meta_graph_for_mode' ]) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 2a0e4e7617..8bc410ba0b 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -28,6 +28,7 @@ import six from google.protobuf import text_format +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator @@ -203,6 +204,10 @@ class EstimatorConstructorTest(test.TestCase): est = estimator.Estimator(model_fn=model_fn) self.assertTrue(isinstance(est.config, run_config.RunConfig)) + self.assertTrue(est._session_config.allow_soft_placement) + rewrite_options = est._session_config.graph_options.rewrite_options + self.assertEqual(rewrite_options.meta_optimizer_iterations, + rewriter_config_pb2.RewriterConfig.ONE) def test_default_model_dir(self): @@ -2304,6 +2309,43 @@ class EstimatorExportTest(test.TestCase): with self.assertRaisesRegexp(ValueError, err_regex): est._export_all_saved_models(export_dir_base, input_receiver_fn_map) + def test_export_all_saved_models_metric_operation(self): + """Ensures metrics ops.Operations can be expoerted (b/109740581).""" + + def _model_fn(features, labels, mode): + del features, labels # Unused + metrics = {'metrics': (constant_op.constant([0]), + control_flow_ops.no_op())} + return model_fn_lib.EstimatorSpec( + mode, + predictions=constant_op.constant(10.), + loss=constant_op.constant(1.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + eval_metric_ops=metrics) + + tmpdir = tempfile.mkdtemp() + est = estimator.Estimator(model_fn=_model_fn) + est.train(input_fn=dummy_input_fn, steps=1) + + # Perform the export. + export_dir_base = os.path.join( + compat.as_bytes(tmpdir), compat.as_bytes('metric_operation_export')) + + input_receiver_fn_map = { + model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()} + + export_dir = est._export_all_saved_models( + export_dir_base, input_receiver_fn_map) + + # Restore, to validate that the export was well-formed. + with ops.Graph().as_default() as graph: + with session.Session(graph=graph) as sess: + meta_graph = loader.load(sess, [tag_constants.EVAL], export_dir) + sig_outputs = meta_graph.signature_def[ + model_fn_lib.ModeKeys.EVAL].outputs + self.assertEqual( + sig_outputs['metrics/update_op'].name, 'metric_op_wrapper:0') + def test_export_savedmodel_with_saveables_proto_roundtrip(self): tmpdir = tempfile.mkdtemp() est = estimator.Estimator( diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py index 6c26d29985..20382a58d8 100644 --- a/tensorflow/python/estimator/export/export_output.py +++ b/tensorflow/python/estimator/export/export_output.py @@ -23,6 +23,7 @@ import abc import six +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.saved_model import signature_def_utils @@ -338,8 +339,16 @@ class _SupervisedOutput(ExportOutput): raise ValueError( '{} update_op must be a Tensor or Operation; got {}.'.format( key, metric_op)) + + # We must wrap any ops in a Tensor before export, as the SignatureDef + # proto expects tensors only. See b/109740581 + metric_op_tensor = metric_op + if isinstance(metric_op, ops.Operation): + with ops.control_dependencies([metric_op]): + metric_op_tensor = constant_op.constant([], name='metric_op_wrapper') + outputs[val_name] = metric_val - outputs[op_name] = metric_op + outputs[op_name] = metric_op_tensor return outputs diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py index b21ba91b0f..d94c764fd7 100644 --- a/tensorflow/python/estimator/export/export_output_test.py +++ b/tensorflow/python/estimator/export/export_output_test.py @@ -24,8 +24,10 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -335,5 +337,18 @@ class SupervisedOutputTest(test.TestCase): self.assertTrue("predictions/output1" in sig_def.outputs) self.assertTrue("features" in sig_def.inputs) + def test_metric_op_is_operation(self): + """Tests that ops.Operation is wrapped by a tensor for metric_ops.""" + loss = {"my_loss": constant_op.constant([0])} + predictions = {u"output1": constant_op.constant(["foo"])} + metrics = {"metrics": (constant_op.constant([0]), control_flow_ops.no_op())} + + outputter = MockSupervisedOutput(loss, predictions, metrics) + self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0]) + self.assertEqual( + outputter.metrics["metrics/update_op"].name, "metric_op_wrapper:0") + self.assertTrue( + isinstance(outputter.metrics["metrics/update_op"], ops.Tensor)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py index 076359b503..70517ae278 100644 --- a/tensorflow/python/estimator/keras.py +++ b/tensorflow/python/estimator/keras.py @@ -21,11 +21,14 @@ from __future__ import print_function import os import re +import tempfile + from tensorflow.python.client import session from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import export as export_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.estimator.run_config import RunConfig from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib @@ -39,6 +42,7 @@ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_module +from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import distribute as distribute_lib @@ -180,7 +184,7 @@ def _in_place_subclassed_model_reset(model): # Replace layers on the model with fresh layers layers_to_names = {value: key for key, value in attributes_cache.items()} original_layers = model._layers[:] - model._layers = [] + model._layers = data_structures.NoDependency([]) for layer in original_layers: # We preserve layer order. config = layer.get_config() # This will not work for nested subclassed models used as layers. @@ -228,7 +232,8 @@ def _in_place_subclassed_model_reset(model): ] for name in attributes_to_cache: attributes_cache[name] = getattr(model, name) - model._original_attributes_cache = attributes_cache + model._original_attributes_cache = data_structures.NoDependency( + attributes_cache) # Reset built state model.built = False model.inputs = None @@ -426,29 +431,34 @@ def _create_keras_model_fn(keras_model, custom_objects=None): return model_fn -def _save_first_checkpoint(keras_model, estimator, custom_objects, - keras_weights): +def _save_first_checkpoint(keras_model, custom_objects, config): """Save first checkpoint for the keras Estimator. Args: keras_model: an instance of compiled keras model. - estimator: keras estimator. custom_objects: Dictionary for custom objects. - keras_weights: A flat list of Numpy arrays for weights of given keras_model. + config: Estimator config. Returns: - The model_fn for a keras Estimator. + The path where keras model checkpoint is saved. """ + # save checkpoint into subdirectory to allow warm start + keras_model_dir = os.path.join(config.model_dir, 'keras') # Load weights and save to checkpoint if there is no checkpoint - latest_path = saver_lib.latest_checkpoint(estimator.model_dir) + latest_path = saver_lib.latest_checkpoint(keras_model_dir) if not latest_path: + keras_weights = None + if _any_weight_initialized(keras_model): + keras_weights = keras_model.get_weights() + if not gfile.IsDirectory(keras_model_dir): + gfile.MakeDirs(keras_model_dir) with ops.Graph().as_default(): - random_seed.set_random_seed(estimator.config.tf_random_seed) + random_seed.set_random_seed(config.tf_random_seed) training_util.create_global_step() model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model, custom_objects) # save to checkpoint - with session.Session(config=estimator._session_config) as sess: + with session.Session(config=config.session_config) as sess: if keras_weights: model.set_weights(keras_weights) # Make update ops and initialize all variables. @@ -458,7 +468,46 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects, K._initialize_variables(sess) # pylint: enable=protected-access saver = saver_lib.Saver() - saver.save(sess, os.path.join(estimator.model_dir, 'keras_model.ckpt')) + latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt') + saver.save(sess, latest_path) + return latest_path + + +def _maybe_overwrite_model_dir_and_session_config(config, model_dir): + """Overwrite estimator config by `model_dir` and `session_config` if needed. + + Args: + config: Original estimator config. + model_dir: Estimator model checkpoint directory. + + Returns: + Overwritten estimator config. + + Raises: + ValueError: Model directory inconsistent between `model_dir` and `config`. + """ + + default_session_config = run_config_lib.get_default_session_config() + if isinstance(config, dict): + config = RunConfig(**config) + elif config is None: + config = RunConfig(session_config=default_session_config) + if config.session_config is None: + config = RunConfig.replace(config, session_config=default_session_config) + + if model_dir is not None: + if (getattr(config, 'model_dir', None) is not None and + config.model_dir != model_dir): + raise ValueError( + "`model_dir` are set both in constructor and `RunConfig`, but with " + "different values. In constructor: '{}', in `RunConfig`: " + "'{}' ".format(model_dir, config.model_dir)) + config = RunConfig.replace(config, model_dir=model_dir) + elif getattr(config, 'model_dir', None) is None: + model_dir = tempfile.mkdtemp() + config = RunConfig.replace(config, model_dir=model_dir) + + return config def model_to_estimator(keras_model=None, @@ -517,45 +566,39 @@ def model_to_estimator(keras_model=None, 'Please compile the model with `model.compile()` ' 'before calling `model_to_estimator()`.') - if isinstance(config, dict): - config = run_config_lib.RunConfig(**config) + config = _maybe_overwrite_model_dir_and_session_config(config, model_dir) keras_model_fn = _create_keras_model_fn(keras_model, custom_objects) - estimator = estimator_lib.Estimator( - keras_model_fn, model_dir=model_dir, config=config) - - # Check if we need to call get_weights: if _any_weight_initialized(keras_model): - keras_weights = keras_model.get_weights() # Warn if config passed to estimator tries to update GPUOptions. If a # session has already been created, the GPUOptions passed to the first # session sticks. - if estimator._session_config.HasField('gpu_options'): + if config.session_config.HasField('gpu_options'): logging.warning( 'The Keras backend session has already been set. ' 'The _session_config passed to model_to_estimator will not be used.') else: # Pass the config into keras backend's default session. - sess = session.Session(config=estimator._session_config) + sess = session.Session(config=config.session_config) K.set_session(sess) - keras_weights = None + warm_start_path = None if keras_model._is_graph_network: - # TODO(yifeif): move checkpoint initialization to scaffold.init_fn - _save_first_checkpoint(keras_model, - estimator, - custom_objects, - keras_weights) + warm_start_path = _save_first_checkpoint(keras_model, custom_objects, + config) elif keras_model.built: - logging.warning('You are creating an Estimator from a Keras model ' - 'manually subclassed from `Model`, that was ' - 'already called on some inputs (and thus already had ' - 'weights). We are currently unable to preserve ' - 'the model\'s state (its weights) ' - 'as part of the estimator ' - 'in this case. Be warned that the estimator ' - 'has been created using ' - 'a freshly initialized version of your model.\n' - 'Note that this doesn\'t affect the state of the ' - 'model instance you passed as `keras_model` argument.') + logging.warning('You are creating an Estimator from a Keras model manually ' + 'subclassed from `Model`, that was already called on some ' + 'inputs (and thus already had weights). We are currently ' + 'unable to preserve the model\'s state (its weights) as ' + 'part of the estimator in this case. Be warned that the ' + 'estimator has been created using a freshly initialized ' + 'version of your model.\n' + 'Note that this doesn\'t affect the state of the model ' + 'instance you passed as `keras_model` argument.') + + estimator = estimator_lib.Estimator(keras_model_fn, + config=config, + warm_start_from=warm_start_path) + return estimator diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py index 7a4457f5a4..cf4ec7f4da 100644 --- a/tensorflow/python/estimator/keras_test.py +++ b/tensorflow/python/estimator/keras_test.py @@ -32,13 +32,14 @@ from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras import testing_utils -from tensorflow.python.keras.applications import mobilenet from tensorflow.python.keras.optimizers import SGD +from tensorflow.python.ops import variable_scope from tensorflow.python.ops.parsing_ops import gen_parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import rmsprop +from tensorflow.python.training import session_run_hook try: @@ -51,6 +52,8 @@ _TRAIN_SIZE = 200 _INPUT_SIZE = (10,) _NUM_CLASS = 2 +_TMP_DIR = '/tmp' + def simple_sequential_model(): model = keras.models.Sequential() @@ -60,9 +63,9 @@ def simple_sequential_model(): return model -def simple_functional_model(): +def simple_functional_model(activation='relu'): a = keras.layers.Input(shape=_INPUT_SIZE) - b = keras.layers.Dense(16, activation='relu')(a) + b = keras.layers.Dense(16, activation=activation)(a) b = keras.layers.Dropout(0.1)(b) b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b) model = keras.models.Model(inputs=[a], outputs=[b]) @@ -168,6 +171,12 @@ def multi_inputs_multi_outputs_model(): return model +class MyHook(session_run_hook.SessionRunHook): + + def begin(self): + _ = variable_scope.get_variable('temp', [1]) + + class TestKerasEstimator(test_util.TensorFlowTestCase): def setUp(self): @@ -204,6 +213,54 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): writer_cache.FileWriterCache.clear() gfile.DeleteRecursively(self._config.model_dir) + # see b/109935364 + @test_util.run_in_graph_and_eager_modes + def test_train_with_hooks(self): + for model_type in ['sequential', 'functional']: + keras_model, (_, _), ( + _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model( + model_type=model_type, is_evaluate=True) + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(1e-3), + metrics=['mse', keras.metrics.categorical_accuracy]) + + my_hook = MyHook() + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=self._config) + before_eval_results = est_keras.evaluate( + input_fn=eval_input_fn, steps=1) + est_keras.train(input_fn=train_input_fn, hooks=[my_hook], + steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + @test_util.run_in_graph_and_eager_modes + def test_train_with_model_fit_and_hooks(self): + keras_model, (x_train, y_train), _, \ + train_input_fn, eval_input_fn = get_resource_for_simple_model( + model_type='sequential', is_evaluate=True) + + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(1e-3), + metrics=['mse', keras.metrics.categorical_accuracy]) + my_hook = MyHook() + with self.test_session(): + keras_model.fit(x_train, y_train, epochs=1) + + keras_est = keras_lib.model_to_estimator( + keras_model=keras_model, config=self._config) + before_eval_results = keras_est.evaluate(input_fn=eval_input_fn) + keras_est.train(input_fn=train_input_fn, hooks=[my_hook], + steps=_TRAIN_SIZE / 16) + after_eval_results = keras_est.evaluate(input_fn=eval_input_fn, steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + @test_util.run_in_graph_and_eager_modes def test_train_with_tf_optimizer(self): for model_type in ['sequential', 'functional']: @@ -474,23 +531,43 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): est_keras.train(input_fn=invald_output_name_input_fn, steps=100) def test_custom_objects(self): - keras_mobile = mobilenet.MobileNet(weights=None) - keras_mobile.compile(loss='categorical_crossentropy', optimizer='adam') + + def relu6(x): + return keras.backend.relu(x, max_value=6) + + keras_model = simple_functional_model(activation=relu6) + keras_model.compile(loss='categorical_crossentropy', optimizer='adam') custom_objects = { - 'relu6': mobilenet.relu6, - 'DepthwiseConv2D': mobilenet.DepthwiseConv2D + 'relu6': relu6 } + + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=(10,), + num_classes=2) + y_train = keras.utils.to_categorical(y_train, 2) + input_name = keras_model.input_names[0] + output_name = keras_model.output_names[0] + train_input_fn = numpy_io.numpy_input_fn( + x=randomize_io_type(x_train, input_name), + y=randomize_io_type(y_train, output_name), + shuffle=False, + num_epochs=None, + batch_size=16) with self.assertRaisesRegexp(ValueError, 'relu6'): with self.test_session(): - keras_lib.model_to_estimator( - keras_model=keras_mobile, + est = keras_lib.model_to_estimator( + keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir)) + est.train(input_fn=train_input_fn, steps=1) with self.test_session(): - keras_lib.model_to_estimator( - keras_model=keras_mobile, + est = keras_lib.model_to_estimator( + keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir), custom_objects=custom_objects) + est.train(input_fn=train_input_fn, steps=1) def test_tf_config(self): keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model() @@ -527,12 +604,73 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3) sess_config = config_pb2.ConfigProto(gpu_options=gpu_options) self._config._session_config = sess_config - keras_lib.model_to_estimator( - keras_model=keras_model, config=self._config) - self.assertEqual( - keras.backend.get_session() - ._config.gpu_options.per_process_gpu_memory_fraction, - gpu_options.per_process_gpu_memory_fraction) + with self.test_session(): + keras_lib.model_to_estimator( + keras_model=keras_model, config=self._config) + self.assertEqual( + keras.backend.get_session() + ._config.gpu_options.per_process_gpu_memory_fraction, + gpu_options.per_process_gpu_memory_fraction) + + def test_with_empty_config(self): + keras_model, _, _, _, _ = get_resource_for_simple_model( + model_type='sequential', is_evaluate=True) + keras_model.compile( + loss='categorical_crossentropy', + optimizer='rmsprop', + metrics=['mse', keras.metrics.categorical_accuracy]) + + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, model_dir=self._base_dir, + config=run_config_lib.RunConfig()) + self.assertEqual(run_config_lib.get_default_session_config(), + est_keras._session_config) + self.assertEqual(est_keras._session_config, + est_keras._config.session_config) + self.assertEqual(self._base_dir, est_keras._config.model_dir) + self.assertEqual(self._base_dir, est_keras._model_dir) + + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, model_dir=self._base_dir, + config=None) + self.assertEqual(run_config_lib.get_default_session_config(), + est_keras._session_config) + self.assertEqual(est_keras._session_config, + est_keras._config.session_config) + self.assertEqual(self._base_dir, est_keras._config.model_dir) + self.assertEqual(self._base_dir, est_keras._model_dir) + + def test_with_empty_config_and_empty_model_dir(self): + keras_model, _, _, _, _ = get_resource_for_simple_model( + model_type='sequential', is_evaluate=True) + keras_model.compile( + loss='categorical_crossentropy', + optimizer='rmsprop', + metrics=['mse', keras.metrics.categorical_accuracy]) + + with self.test_session(): + with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, + config=run_config_lib.RunConfig()) + self.assertEqual(est_keras._model_dir, _TMP_DIR) + + def test_with_conflicting_model_dir_and_config(self): + keras_model, _, _, _, _ = get_resource_for_simple_model( + model_type='sequential', is_evaluate=True) + keras_model.compile( + loss='categorical_crossentropy', + optimizer='rmsprop', + metrics=['mse', keras.metrics.categorical_accuracy]) + + with self.test_session(): + with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in ' + 'constructor and `RunConfig`'): + keras_lib.model_to_estimator( + keras_model=keras_model, model_dir=self._base_dir, + config=run_config_lib.RunConfig(model_dir=_TMP_DIR)) def test_pretrained_weights(self): keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model() diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index aa594af2e4..6c1de166a4 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -48,7 +48,8 @@ _DEFAULT_REPLACEABLE_LIST = [ 'keep_checkpoint_every_n_hours', 'log_step_count_steps', 'train_distribute', - 'device_fn' + 'device_fn', + 'protocol' ] _SAVE_CKPT_ERR = ( @@ -288,6 +289,21 @@ def _validate_properties(run_config): message='device_fn must be callable with exactly' ' one argument "op".') + _validate('protocol', + lambda protocol: protocol in (None, "grpc", "grpc+verbs"), + message='protocol should be grpc or grpc+verbs') + + +def get_default_session_config(): + """Returns tf.ConfigProto instance.""" + + rewrite_opts = rewriter_config_pb2.RewriterConfig( + meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE) + graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts) + + return config_pb2.ConfigProto(allow_soft_placement=True, + graph_options=graph_opts) + class TaskType(object): MASTER = 'master' @@ -312,7 +328,8 @@ class RunConfig(object): keep_checkpoint_every_n_hours=10000, log_step_count_steps=100, train_distribute=None, - device_fn=None): + device_fn=None, + protocol=None): """Constructs a RunConfig. All distributed training related properties `cluster_spec`, `is_chief`, @@ -436,7 +453,7 @@ class RunConfig(object): the feature. log_step_count_steps: The frequency, in number of global steps, that the global step/sec and the loss will be logged during training. - train_distribute: an optional instance of + train_distribute: An optional instance of `tf.contrib.distribute.DistributionStrategy`. If specified, then Estimator will distribute the user's model during training, according to the policy specified by that strategy. @@ -444,6 +461,8 @@ class RunConfig(object): `Operation` and returns the device string. If `None`, defaults to the device function returned by `tf.train.replica_device_setter` with round-robin strategy. + protocol: An optional argument which specifies the protocol used when + starting server. None means default to grpc. Raises: ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs` @@ -481,7 +500,8 @@ class RunConfig(object): keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours, log_step_count_steps=log_step_count_steps, train_distribute=train_distribute, - device_fn=device_fn) + device_fn=device_fn, + protocol=protocol) self._init_distributed_setting_from_environment_var(tf_config) @@ -499,9 +519,9 @@ class RunConfig(object): RunConfig._replace( self, allowed_properties_list=_DEFAULT_REPLACEABLE_LIST, - session_config=self._get_default_session_config()) + session_config=self._get_default_session_config_distributed()) - def _get_default_session_config(self): + def _get_default_session_config_distributed(self): """Returns None or tf.ConfigProto instance with default device_filters set. Device filters are set such that chief/master and worker communicates with @@ -754,6 +774,11 @@ class RunConfig(object): """ return self._train_distribute + @property + def protocol(self): + """Returns the optional protocol value.""" + return self._protocol + def replace(self, **kwargs): """Returns a new instance of `RunConfig` replacing specified properties. @@ -769,7 +794,8 @@ class RunConfig(object): - `keep_checkpoint_every_n_hours`, - `log_step_count_steps`, - `train_distribute`, - - `device_fn`. + - `device_fn`, + - `protocol`. In addition, either `save_checkpoints_steps` or `save_checkpoints_secs` can be set (should not be both). diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index f5ac79ced2..a01b2300dd 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -732,7 +732,8 @@ class _TrainingExecutor(object): job_name=config.task_type, task_index=config.task_id, config=session_config, - start=False) + start=False, + protocol=config.protocol) server.start() return server diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py index 6bee7cbe83..dc106c7d3b 100644 --- a/tensorflow/python/estimator/training_test.py +++ b/tensorflow/python/estimator/training_test.py @@ -472,6 +472,7 @@ class _TrainingExecutorTrainingTest(object): job_name=mock_est.config.task_type, task_index=mock_est.config.task_id, config=test.mock.ANY, + protocol=None, start=False) self.assertTrue(mock_server_instance.start.called) @@ -502,6 +503,7 @@ class _TrainingExecutorTrainingTest(object): job_name=mock_est.config.task_type, task_index=mock_est.config.task_id, config=test.mock.ANY, + protocol=None, start=False) self.assertTrue(mock_server_instance.start.called) @@ -729,6 +731,7 @@ class TrainingExecutorRunMasterTest(test.TestCase): job_name=mock_est.config.task_type, task_index=mock_est.config.task_id, config=test.mock.ANY, + protocol=None, start=False) self.assertTrue(mock_server_instance.start.called) @@ -1481,6 +1484,7 @@ class TrainingExecutorRunPsTest(test.TestCase): job_name=mock_est.config.task_type, task_index=mock_est.config.task_id, config=test.mock.ANY, + protocol=None, start=False) self.assertTrue(mock_server_instance.start.called) |