Diffstat (limited to 'tensorflow/python/estimator')
 tensorflow/python/estimator/BUILD                        |  16
 tensorflow/python/estimator/api/BUILD                    |   4
 tensorflow/python/estimator/canned/boosted_trees.py      | 439
 tensorflow/python/estimator/canned/metric_keys.py        |   5
 tensorflow/python/estimator/estimator.py                 |  16
 tensorflow/python/estimator/estimator_test.py            |  42
 tensorflow/python/estimator/export/export_output.py      |  11
 tensorflow/python/estimator/export/export_output_test.py |  15
 tensorflow/python/estimator/keras.py                     | 117
 tensorflow/python/estimator/keras_test.py                | 172
 tensorflow/python/estimator/run_config.py                |  40
 tensorflow/python/estimator/training.py                  |   3
 tensorflow/python/estimator/training_test.py             |   4
13 files changed, 637 insertions(+), 247 deletions(-)
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 8ee38d35cc..fd46163050 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -40,9 +40,9 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":gc",
+ ":metric_keys",
+ ":util",
"//tensorflow:tensorflow_py_no_contrib",
- "//tensorflow/python/estimator:metric_keys",
- "//tensorflow/python/estimator:util",
],
)
@@ -683,9 +683,9 @@ py_test(
],
deps = [
":keras",
+ ":numpy_io",
+ ":run_config",
"//tensorflow:tensorflow_py_no_contrib",
- "//tensorflow/python/estimator:numpy_io",
- "//tensorflow/python/estimator:run_config",
"//third_party/py/numpy",
],
)
@@ -707,6 +707,14 @@ py_library(
)
py_library(
+ name = "expect_h5py_installed",
+    # This is a dummy rule used as an h5py dependency in open-source.
+    # We expect h5py to already be installed on the system, e.g. via
+    # `pip install h5py`.
+ visibility = ["//visibility:public"],
+)
+
+py_library(
name = "expect_six_installed",
# This is a dummy rule used as a six dependency in open-source.
# We expect six to already be installed on the system, e.g. via
diff --git a/tensorflow/python/estimator/api/BUILD b/tensorflow/python/estimator/api/BUILD
index ceb9baef4d..a75fa7d0ae 100644
--- a/tensorflow/python/estimator/api/BUILD
+++ b/tensorflow/python/estimator/api/BUILD
@@ -6,8 +6,8 @@ package(
licenses(["notice"]) # Apache 2.0
-load("//tensorflow/tools/api/generator:api_gen.bzl", "gen_api_init_files")
-load("//tensorflow/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
+load("//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files")
+load("//tensorflow/python/tools/api/generator:api_gen.bzl", "ESTIMATOR_API_INIT_FILES")
gen_api_init_files(
name = "estimator_python_api_gen",
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 3c832c7569..3292e2724d 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import collections
import functools
@@ -384,6 +385,249 @@ class _StopAtAttemptsHook(session_run_hook.SessionRunHook):
run_context.request_stop()
+def _get_max_splits(tree_hparams):
+ """Calculates the max possible number of splits based on tree params."""
+  # Maximum number of splits possible in the whole tree = 2**max_depth - 1.
+ max_splits = (1 << tree_hparams.max_depth) - 1
+ return max_splits
+
+
+class _EnsembleGrower(object):
+ """Abstract base class for different types of ensemble growers.
+
+  Use it to receive training ops for growing the ensemble and centering the
+  bias, depending on the implementation (for example, in-memory or
+  accumulator-based distributed training):
+ grower = ...create subclass grower(tree_ensemble, tree_hparams)
+ grow_op = grower.grow_tree(stats_summaries_list, feature_ids_list,
+ last_layer_nodes_range)
+ training_ops.append(grow_op)
+ """
+
+ def __init__(self, tree_ensemble, tree_hparams):
+ """Initializes a grower object.
+
+ Args:
+ tree_ensemble: A TreeEnsemble variable.
+      tree_hparams: TODO. A collections.namedtuple of hyperparameters.
+ """
+ self._tree_ensemble = tree_ensemble
+ self._tree_hparams = tree_hparams
+
+ @abc.abstractmethod
+ def center_bias(self, center_bias_var, gradients, hessians):
+ """Centers bias, if ready, based on statistics.
+
+ Args:
+ center_bias_var: A variable that will be updated when bias centering
+ finished.
+ gradients: A rank 2 tensor of gradients.
+ hessians: A rank 2 tensor of hessians.
+
+ Returns:
+ An operation for centering bias.
+ """
+
+ @abc.abstractmethod
+ def grow_tree(self, stats_summaries_list, feature_ids_list,
+ last_layer_nodes_range):
+ """Grows a tree, if ready, based on provided statistics.
+
+ Args:
+ stats_summaries_list: List of stats summary tensors, representing sums of
+ gradients and hessians for each feature bucket.
+      feature_ids_list: A list of lists of feature ids for each bucket size.
+ last_layer_nodes_range: A tensor representing ids of the nodes in the
+ current layer, to be split.
+
+ Returns:
+ An op for growing a tree.
+ """
+
+ # ============= Helper methods ===========
+
+ def _center_bias_fn(self, center_bias_var, mean_gradients, mean_hessians):
+ """Updates the ensembles and cache (if needed) with logits prior."""
+ continue_centering = boosted_trees_ops.center_bias(
+ self._tree_ensemble.resource_handle,
+ mean_gradients=mean_gradients,
+ mean_hessians=mean_hessians,
+ l1=self._tree_hparams.l1,
+ l2=self._tree_hparams.l2)
+ return center_bias_var.assign(continue_centering)
+
+ def _grow_tree_from_stats_summaries(self, stats_summaries_list,
+ feature_ids_list, last_layer_nodes_range):
+ """Updates ensemble based on the best gains from stats summaries."""
+ node_ids_per_feature = []
+ gains_list = []
+ thresholds_list = []
+ left_node_contribs_list = []
+ right_node_contribs_list = []
+ all_feature_ids = []
+ assert len(stats_summaries_list) == len(feature_ids_list)
+
+ max_splits = _get_max_splits(self._tree_hparams)
+
+ for i, feature_ids in enumerate(feature_ids_list):
+ (numeric_node_ids_per_feature, numeric_gains_list,
+ numeric_thresholds_list, numeric_left_node_contribs_list,
+ numeric_right_node_contribs_list) = (
+ boosted_trees_ops.calculate_best_gains_per_feature(
+ node_id_range=last_layer_nodes_range,
+ stats_summary_list=stats_summaries_list[i],
+ l1=self._tree_hparams.l1,
+ l2=self._tree_hparams.l2,
+ tree_complexity=self._tree_hparams.tree_complexity,
+ min_node_weight=self._tree_hparams.min_node_weight,
+ max_splits=max_splits))
+
+ all_feature_ids += feature_ids
+ node_ids_per_feature += numeric_node_ids_per_feature
+ gains_list += numeric_gains_list
+ thresholds_list += numeric_thresholds_list
+ left_node_contribs_list += numeric_left_node_contribs_list
+ right_node_contribs_list += numeric_right_node_contribs_list
+
+ grow_op = boosted_trees_ops.update_ensemble(
+ # Confirm if local_tree_ensemble or tree_ensemble should be used.
+ self._tree_ensemble.resource_handle,
+ feature_ids=all_feature_ids,
+ node_ids=node_ids_per_feature,
+ gains=gains_list,
+ thresholds=thresholds_list,
+ left_node_contribs=left_node_contribs_list,
+ right_node_contribs=right_node_contribs_list,
+ learning_rate=self._tree_hparams.learning_rate,
+ max_depth=self._tree_hparams.max_depth,
+ pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
+ return grow_op
+
+
+class _InMemoryEnsembleGrower(_EnsembleGrower):
+  """Grows the tree in memory, using statistics from the full batch."""
+
+ def __init__(self, tree_ensemble, tree_hparams):
+
+ super(_InMemoryEnsembleGrower, self).__init__(
+ tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+
+ def center_bias(self, center_bias_var, gradients, hessians):
+ # For in memory, we already have a full batch of gradients and hessians,
+ # so just take a mean and proceed with centering.
+ mean_gradients = array_ops.expand_dims(
+ math_ops.reduce_mean(gradients, 0), 0)
+    mean_hessians = array_ops.expand_dims(math_ops.reduce_mean(hessians, 0), 0)
+    return self._center_bias_fn(center_bias_var, mean_gradients, mean_hessians)
+
+ def grow_tree(self, stats_summaries_list, feature_ids_list,
+ last_layer_nodes_range):
+ # For in memory, we already have full data in one batch, so we can grow the
+ # tree immediately.
+ return self._grow_tree_from_stats_summaries(
+ stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+
+
+class _AccumulatorEnsembleGrower(_EnsembleGrower):
+  """Grows the tree with accumulators, for distributed training."""
+
+ def __init__(self, tree_ensemble, tree_hparams, stamp_token,
+ n_batches_per_layer, bucket_size_list, is_chief):
+ super(_AccumulatorEnsembleGrower, self).__init__(
+ tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+ self._stamp_token = stamp_token
+ self._n_batches_per_layer = n_batches_per_layer
+ self._bucket_size_list = bucket_size_list
+ self._is_chief = is_chief
+
+ def center_bias(self, center_bias_var, gradients, hessians):
+    # When not training in memory, we need to accumulate enough batches
+    # before proceeding with centering the bias.
+
+ # Create an accumulator.
+ bias_dependencies = []
+ bias_accumulator = data_flow_ops.ConditionalAccumulator(
+ dtype=dtypes.float32,
+ # The stats consist of grads and hessians means only.
+        # TODO(nponomareva): this will change for multiclass.
+ shape=[2, 1],
+ shared_name='bias_accumulator')
+
+ grads_and_hess = array_ops.stack([gradients, hessians], axis=0)
+ grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1)
+
+ apply_grad = bias_accumulator.apply_grad(grads_and_hess, self._stamp_token)
+ bias_dependencies.append(apply_grad)
+
+ # Center bias if enough batches were processed.
+ with ops.control_dependencies(bias_dependencies):
+ if not self._is_chief:
+ return control_flow_ops.no_op()
+
+ def center_bias_from_accumulator():
+ accumulated = array_ops.unstack(bias_accumulator.take_grad(1), axis=0)
+ return self._center_bias_fn(center_bias_var,
+ array_ops.expand_dims(accumulated[0], 0),
+ array_ops.expand_dims(accumulated[1], 0))
+
+ center_bias_op = control_flow_ops.cond(
+ math_ops.greater_equal(bias_accumulator.num_accumulated(),
+ self._n_batches_per_layer),
+ center_bias_from_accumulator,
+ control_flow_ops.no_op,
+ name='wait_until_n_batches_for_bias_accumulated')
+ return center_bias_op
+
+ def grow_tree(self, stats_summaries_list, feature_ids_list,
+ last_layer_nodes_range):
+    # When not training in memory, we need to accumulate enough batches
+    # before proceeding with building a tree layer.
+ max_splits = _get_max_splits(self._tree_hparams)
+
+ # Prepare accumulators.
+ accumulators = []
+ dependencies = []
+ for i, feature_ids in enumerate(feature_ids_list):
+ stats_summaries = stats_summaries_list[i]
+ accumulator = data_flow_ops.ConditionalAccumulator(
+ dtype=dtypes.float32,
+ # The stats consist of grads and hessians (the last dimension).
+ shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
+ shared_name='numeric_stats_summary_accumulator_' + str(i))
+ accumulators.append(accumulator)
+
+ apply_grad = accumulator.apply_grad(
+ array_ops.stack(stats_summaries, axis=0), self._stamp_token)
+ dependencies.append(apply_grad)
+
+    # Grow the tree once enough batches are accumulated.
+ with ops.control_dependencies(dependencies):
+ if not self._is_chief:
+ return control_flow_ops.no_op()
+
+ min_accumulated = math_ops.reduce_min(
+ array_ops.stack([acc.num_accumulated() for acc in accumulators]))
+
+ def grow_tree_from_accumulated_summaries_fn():
+ """Updates tree with the best layer from accumulated summaries."""
+ # Take out the accumulated summaries from the accumulator and grow.
+      stats_summaries_list = [
+          array_ops.unstack(accumulator.take_grad(1), axis=0)
+          for accumulator in accumulators
+      ]
+ grow_op = self._grow_tree_from_stats_summaries(
+ stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+ return grow_op
+
+ grow_model = control_flow_ops.cond(
+ math_ops.greater_equal(min_accumulated, self._n_batches_per_layer),
+ grow_tree_from_accumulated_summaries_fn,
+ control_flow_ops.no_op,
+ name='wait_until_n_batches_accumulated')
+ return grow_model
+
+
def _bt_model_fn(
features,
labels,
@@ -441,11 +685,6 @@ def _bt_model_fn(
raise ValueError('train_in_memory is supported only for '
'non-distributed training.')
worker_device = control_flow_ops.no_op().device
- # maximum number of splits possible in the whole tree =2^(D-1)-1
- # TODO(youngheek): perhaps storage could be optimized by storing stats with
- # the dimension max_splits_per_layer, instead of max_splits (for the entire
- # tree).
- max_splits = (1 << tree_hparams.max_depth) - 1
train_op = []
with ops.name_scope(name) as name:
# Prepare.
@@ -543,6 +782,11 @@ def _bt_model_fn(
hessians = gradients_impl.gradients(
gradients, logits, name='Hessians')[0]
+ # TODO(youngheek): perhaps storage could be optimized by storing stats
+ # with the dimension max_splits_per_layer, instead of max_splits (for the
+ # entire tree).
+ max_splits = _get_max_splits(tree_hparams)
+
stats_summaries_list = []
for i, feature_ids in enumerate(feature_ids_list):
num_buckets = bucket_size_list[i]
@@ -559,173 +803,28 @@ def _bt_model_fn(
]
stats_summaries_list.append(summaries)
- # ========= Helper methods for both in and not in memory. ==============
- def grow_tree_from_stats_summaries(stats_summaries_list,
- feature_ids_list):
- """Updates ensemble based on the best gains from stats summaries."""
- node_ids_per_feature = []
- gains_list = []
- thresholds_list = []
- left_node_contribs_list = []
- right_node_contribs_list = []
- all_feature_ids = []
-
- assert len(stats_summaries_list) == len(feature_ids_list)
-
- for i, feature_ids in enumerate(feature_ids_list):
- (numeric_node_ids_per_feature, numeric_gains_list,
- numeric_thresholds_list, numeric_left_node_contribs_list,
- numeric_right_node_contribs_list) = (
- boosted_trees_ops.calculate_best_gains_per_feature(
- node_id_range=last_layer_nodes_range,
- stats_summary_list=stats_summaries_list[i],
- l1=tree_hparams.l1,
- l2=tree_hparams.l2,
- tree_complexity=tree_hparams.tree_complexity,
- min_node_weight=tree_hparams.min_node_weight,
- max_splits=max_splits))
-
- all_feature_ids += feature_ids
- node_ids_per_feature += numeric_node_ids_per_feature
- gains_list += numeric_gains_list
- thresholds_list += numeric_thresholds_list
- left_node_contribs_list += numeric_left_node_contribs_list
- right_node_contribs_list += numeric_right_node_contribs_list
-
- grow_op = boosted_trees_ops.update_ensemble(
- # Confirm if local_tree_ensemble or tree_ensemble should be used.
- tree_ensemble.resource_handle,
- feature_ids=all_feature_ids,
- node_ids=node_ids_per_feature,
- gains=gains_list,
- thresholds=thresholds_list,
- left_node_contribs=left_node_contribs_list,
- right_node_contribs=right_node_contribs_list,
- learning_rate=tree_hparams.learning_rate,
- max_depth=tree_hparams.max_depth,
- pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
- return grow_op
-
- def _center_bias_fn(mean_gradients, mean_hessians):
- """Updates the ensembles and cache (if needed) with logits prior."""
- continue_centering = boosted_trees_ops.center_bias(
- tree_ensemble.resource_handle,
- mean_gradients=mean_gradients,
- mean_hessians=mean_hessians,
- l1=tree_hparams.l1,
- l2=tree_hparams.l2
- )
- return center_bias_var.assign(continue_centering)
-
- # ========= End of helper methods. ==============
-
if train_in_memory and is_single_machine:
- train_op.append(distribute_lib.increment_var(global_step))
-
- mean_gradients = array_ops.expand_dims(
- math_ops.reduce_mean(gradients, 0), 0)
- mean_heassians = array_ops.expand_dims(
- math_ops.reduce_mean(hessians, 0), 0)
-
- train_op.append(
- control_flow_ops.cond(
- center_bias_var,
- lambda: _center_bias_fn(mean_gradients, mean_heassians),
- functools.partial(grow_tree_from_stats_summaries,
- stats_summaries_list, feature_ids_list)))
+ grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams)
else:
-
- def center_bias_not_in_mem():
- """Accumulates the data and updates the logits bias, when ready."""
- bias_dependencies = []
-
- bias_accumulator = data_flow_ops.ConditionalAccumulator(
- dtype=dtypes.float32,
- # The stats consist of grads and hessians means only.
- # TODO(nponomareva): this will change for a multiclass
- shape=[2, 1],
- shared_name='bias_accumulator')
-
- grads_and_hess = array_ops.stack([gradients, hessians], axis=0)
- grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1)
-
- apply_grad = bias_accumulator.apply_grad(grads_and_hess, stamp_token)
- bias_dependencies.append(apply_grad)
-
- def center_bias_from_accumulator():
- accumulated = array_ops.unstack(
- bias_accumulator.take_grad(1), axis=0)
- return _center_bias_fn(
- array_ops.expand_dims(accumulated[0], 0),
- array_ops.expand_dims(accumulated[1], 0))
-
- with ops.control_dependencies(bias_dependencies):
- if config.is_chief:
- center_bias_op = control_flow_ops.cond(
- math_ops.greater_equal(bias_accumulator.num_accumulated(),
- n_batches_per_layer),
- center_bias_from_accumulator,
- control_flow_ops.no_op,
- name='wait_until_n_batches_for_bias_accumulated')
-
- return center_bias_op
- else:
- return control_flow_ops.no_op()
-
- def grow_not_in_mem():
- """Accumulates the data and grows a layer when ready."""
-
- accumulators = []
- dependencies = []
- for i, feature_ids in enumerate(feature_ids_list):
- stats_summaries = stats_summaries_list[i]
- accumulator = data_flow_ops.ConditionalAccumulator(
- dtype=dtypes.float32,
- # The stats consist of grads and hessians (the last dimension).
- shape=[len(feature_ids), max_splits, bucket_size_list[i], 2],
- shared_name='numeric_stats_summary_accumulator_' + str(i))
- accumulators.append(accumulator)
-
- apply_grad = accumulator.apply_grad(
- array_ops.stack(stats_summaries, axis=0), stamp_token)
- dependencies.append(apply_grad)
-
- def grow_tree_from_accumulated_summaries_fn():
- """Updates tree with the best layer from accumulated summaries."""
- # Take out the accumulated summaries from the accumulator and grow.
- stats_summaries_list = []
-
- stats_summaries_list = [
- array_ops.unstack(accumulator.take_grad(1), axis=0)
- for accumulator in accumulators
- ]
-
- grow_op = grow_tree_from_stats_summaries(stats_summaries_list,
- feature_ids_list)
- return grow_op
-
- with ops.control_dependencies(dependencies):
- if config.is_chief:
- min_accumulated = math_ops.reduce_min(
- array_ops.stack(
- [acc.num_accumulated() for acc in accumulators]))
-
- grow_model = control_flow_ops.cond(
- math_ops.greater_equal(min_accumulated, n_batches_per_layer),
- grow_tree_from_accumulated_summaries_fn,
- control_flow_ops.no_op,
- name='wait_until_n_batches_accumulated')
-
- return grow_model
- else:
- return control_flow_ops.no_op()
-
- update_model = control_flow_ops.cond(
- center_bias_var, center_bias_not_in_mem, grow_not_in_mem)
- train_op.append(update_model)
- with ops.control_dependencies([update_model]):
- increment_global = distribute_lib.increment_var(global_step)
- train_op.append(increment_global)
+ grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
+ stamp_token, n_batches_per_layer,
+ bucket_size_list, config.is_chief)
+
+ update_model = control_flow_ops.cond(
+ center_bias_var,
+ functools.partial(
+ grower.center_bias,
+ center_bias_var,
+ gradients,
+ hessians,
+ ),
+ functools.partial(grower.grow_tree, stats_summaries_list,
+ feature_ids_list, last_layer_nodes_range))
+ train_op.append(update_model)
+
+ with ops.control_dependencies([update_model]):
+ increment_global = distribute_lib.increment_var(global_step)
+ train_op.append(increment_global)
return control_flow_ops.group(train_op, name='train_op')
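
This refactoring replaces the in-line `grow_not_in_mem` and
`center_bias_not_in_mem` closures with the `_EnsembleGrower` strategy
hierarchy, so `_bt_model_fn` only selects a grower and routes both paths
through a single `cond`. A minimal sketch of the same dispatch pattern, using
hypothetical stand-in names rather than the private TensorFlow classes:

    import abc

    class EnsembleGrower(object):
      """Strategy interface: one subclass per training topology."""

      @abc.abstractmethod
      def grow_tree(self, stats):
        """Returns a thunk that grows one tree layer."""

    class InMemoryGrower(EnsembleGrower):

      def grow_tree(self, stats):
        # The full batch is already available, so grow immediately.
        return lambda: ('grow', stats)

    class AccumulatorGrower(EnsembleGrower):

      def __init__(self, n_batches_per_layer):
        self._n = n_batches_per_layer

      def grow_tree(self, stats):
        # Accumulate stats across batches; grow once enough have arrived.
        return lambda: ('accumulate_then_grow', stats, self._n)

    def make_grower(train_in_memory, is_single_machine, n_batches_per_layer):
      # Mirrors the selection made in _bt_model_fn above.
      if train_in_memory and is_single_machine:
        return InMemoryGrower()
      return AccumulatorGrower(n_batches_per_layer)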
diff --git a/tensorflow/python/estimator/canned/metric_keys.py b/tensorflow/python/estimator/canned/metric_keys.py
index 4f7c849ba4..9d49240fea 100644
--- a/tensorflow/python/estimator/canned/metric_keys.py
+++ b/tensorflow/python/estimator/canned/metric_keys.py
@@ -47,3 +47,8 @@ class MetricKeys(object):
PROBABILITY_MEAN_AT_CLASS = 'probability_mean/class%d'
AUC_AT_CLASS = 'auc/class%d'
AUC_PR_AT_CLASS = 'auc_precision_recall/class%d'
+
+ # The following require a class name applied.
+ PROBABILITY_MEAN_AT_NAME = 'probability_mean/%s'
+ AUC_AT_NAME = 'auc/%s'
+ AUC_PR_AT_NAME = 'auc_precision_recall/%s'
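
The new `*_AT_NAME` keys are `%s` templates keyed by a class name rather than
a class index; they are formatted the same way as the `%d` variants above. A
quick standalone illustration, with the template values copied from the hunk:

    PROBABILITY_MEAN_AT_NAME = 'probability_mean/%s'
    AUC_AT_NAME = 'auc/%s'

    assert AUC_AT_NAME % 'cat' == 'auc/cat'
    assert PROBABILITY_MEAN_AT_NAME % 'dog' == 'probability_mean/dog'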
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 350a95eea1..915ceeb98b 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -29,8 +29,6 @@ import six
from google.protobuf import message
from tensorflow.core.framework import summary_pb2
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
@@ -216,11 +214,7 @@ class Estimator(object):
logging.info('Using config: %s', str(vars(self._config)))
if self._config.session_config is None:
- rewrite_opts = rewriter_config_pb2.RewriterConfig(
- meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE)
- graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts)
- self._session_config = config_pb2.ConfigProto(
- allow_soft_placement=True, graph_options=graph_opts)
+ self._session_config = run_config.get_default_session_config()
else:
self._session_config = self._config.session_config
@@ -573,10 +567,16 @@ class Estimator(object):
def _assert_members_are_not_overridden(self):
"""Asserts members of `Estimator` are not overridden."""
+ # TPUEstimator is special cased (owned by TF).
+ if self.__class__.__name__ == 'TPUEstimator':
+ return
+
allowed_overrides = set([
'_call_input_fn', '_create_global_step',
'_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
- '_tf_api_names', '_estimator_api_names', '_estimator_api_constants',
+ '_tf_api_names', '_tf_api_names_v1', '_estimator_api_names',
+ '_estimator_api_names_v1', '_estimator_api_constants',
+ '_estimator_api_constants_v1',
'_validate_features_in_predict_input',
'_call_model_fn', '_add_meta_graph_for_mode'
])
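
The session-config construction removed here is not dropped; it moves into
`run_config.get_default_session_config()` (see the run_config.py hunk later in
this diff). The default it produces is equivalent to:

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.core.protobuf import rewriter_config_pb2

    rewrite_opts = rewriter_config_pb2.RewriterConfig(
        meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE)
    graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts)
    session_config = config_pb2.ConfigProto(
        allow_soft_placement=True, graph_options=graph_opts)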
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 2a0e4e7617..8bc410ba0b 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -28,6 +28,7 @@ import six
from google.protobuf import text_format
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
@@ -203,6 +204,10 @@ class EstimatorConstructorTest(test.TestCase):
est = estimator.Estimator(model_fn=model_fn)
self.assertTrue(isinstance(est.config, run_config.RunConfig))
+ self.assertTrue(est._session_config.allow_soft_placement)
+ rewrite_options = est._session_config.graph_options.rewrite_options
+ self.assertEqual(rewrite_options.meta_optimizer_iterations,
+ rewriter_config_pb2.RewriterConfig.ONE)
def test_default_model_dir(self):
@@ -2304,6 +2309,43 @@ class EstimatorExportTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, err_regex):
est._export_all_saved_models(export_dir_base, input_receiver_fn_map)
+ def test_export_all_saved_models_metric_operation(self):
+    """Ensures metric ops.Operations can be exported (b/109740581)."""
+
+ def _model_fn(features, labels, mode):
+ del features, labels # Unused
+ metrics = {'metrics': (constant_op.constant([0]),
+ control_flow_ops.no_op())}
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=constant_op.constant(10.),
+ loss=constant_op.constant(1.),
+ train_op=state_ops.assign_add(training.get_global_step(), 1),
+ eval_metric_ops=metrics)
+
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(model_fn=_model_fn)
+ est.train(input_fn=dummy_input_fn, steps=1)
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('metric_operation_export'))
+
+ input_receiver_fn_map = {
+ model_fn_lib.ModeKeys.EVAL: _get_supervised_input_receiver_fn()}
+
+ export_dir = est._export_all_saved_models(
+ export_dir_base, input_receiver_fn_map)
+
+ # Restore, to validate that the export was well-formed.
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ meta_graph = loader.load(sess, [tag_constants.EVAL], export_dir)
+ sig_outputs = meta_graph.signature_def[
+ model_fn_lib.ModeKeys.EVAL].outputs
+ self.assertEqual(
+ sig_outputs['metrics/update_op'].name, 'metric_op_wrapper:0')
+
def test_export_savedmodel_with_saveables_proto_roundtrip(self):
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(
diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py
index 6c26d29985..20382a58d8 100644
--- a/tensorflow/python/estimator/export/export_output.py
+++ b/tensorflow/python/estimator/export/export_output.py
@@ -23,6 +23,7 @@ import abc
import six
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_def_utils
@@ -338,8 +339,16 @@ class _SupervisedOutput(ExportOutput):
raise ValueError(
'{} update_op must be a Tensor or Operation; got {}.'.format(
key, metric_op))
+
+ # We must wrap any ops in a Tensor before export, as the SignatureDef
+ # proto expects tensors only. See b/109740581
+ metric_op_tensor = metric_op
+ if isinstance(metric_op, ops.Operation):
+ with ops.control_dependencies([metric_op]):
+ metric_op_tensor = constant_op.constant([], name='metric_op_wrapper')
+
outputs[val_name] = metric_val
- outputs[op_name] = metric_op
+ outputs[op_name] = metric_op_tensor
return outputs
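
A SignatureDef may only reference tensors, so an update op that is a bare
`ops.Operation` must be wrapped before export. A standalone sketch of the same
trick with the public TF 1.x API (names here are illustrative):

    import tensorflow as tf

    update_op = tf.no_op(name='my_update')  # a bare ops.Operation
    with tf.control_dependencies([update_op]):
      # An empty constant that runs update_op whenever it is fetched.
      update_tensor = tf.constant([], name='metric_op_wrapper')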
diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py
index b21ba91b0f..d94c764fd7 100644
--- a/tensorflow/python/estimator/export/export_output_test.py
+++ b/tensorflow/python/estimator/export/export_output_test.py
@@ -24,8 +24,10 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
@@ -335,5 +337,18 @@ class SupervisedOutputTest(test.TestCase):
self.assertTrue("predictions/output1" in sig_def.outputs)
self.assertTrue("features" in sig_def.inputs)
+ def test_metric_op_is_operation(self):
+ """Tests that ops.Operation is wrapped by a tensor for metric_ops."""
+ loss = {"my_loss": constant_op.constant([0])}
+ predictions = {u"output1": constant_op.constant(["foo"])}
+ metrics = {"metrics": (constant_op.constant([0]), control_flow_ops.no_op())}
+
+ outputter = MockSupervisedOutput(loss, predictions, metrics)
+ self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0])
+ self.assertEqual(
+ outputter.metrics["metrics/update_op"].name, "metric_op_wrapper:0")
+ self.assertTrue(
+ isinstance(outputter.metrics["metrics/update_op"], ops.Tensor))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 076359b503..70517ae278 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -21,11 +21,14 @@ from __future__ import print_function
import os
import re
+import tempfile
+
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import export as export_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config as run_config_lib
+from tensorflow.python.estimator.run_config import RunConfig
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
@@ -39,6 +42,7 @@ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import distribute as distribute_lib
@@ -180,7 +184,7 @@ def _in_place_subclassed_model_reset(model):
# Replace layers on the model with fresh layers
layers_to_names = {value: key for key, value in attributes_cache.items()}
original_layers = model._layers[:]
- model._layers = []
+ model._layers = data_structures.NoDependency([])
for layer in original_layers: # We preserve layer order.
config = layer.get_config()
# This will not work for nested subclassed models used as layers.
@@ -228,7 +232,8 @@ def _in_place_subclassed_model_reset(model):
]
for name in attributes_to_cache:
attributes_cache[name] = getattr(model, name)
- model._original_attributes_cache = attributes_cache
+ model._original_attributes_cache = data_structures.NoDependency(
+ attributes_cache)
# Reset built state
model.built = False
model.inputs = None
@@ -426,29 +431,34 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
return model_fn
-def _save_first_checkpoint(keras_model, estimator, custom_objects,
- keras_weights):
+def _save_first_checkpoint(keras_model, custom_objects, config):
"""Save first checkpoint for the keras Estimator.
Args:
keras_model: an instance of compiled keras model.
- estimator: keras estimator.
custom_objects: Dictionary for custom objects.
- keras_weights: A flat list of Numpy arrays for weights of given keras_model.
+ config: Estimator config.
Returns:
- The model_fn for a keras Estimator.
+ The path where keras model checkpoint is saved.
"""
+  # Save the checkpoint into a subdirectory to allow warm starting.
+ keras_model_dir = os.path.join(config.model_dir, 'keras')
# Load weights and save to checkpoint if there is no checkpoint
- latest_path = saver_lib.latest_checkpoint(estimator.model_dir)
+ latest_path = saver_lib.latest_checkpoint(keras_model_dir)
if not latest_path:
+ keras_weights = None
+ if _any_weight_initialized(keras_model):
+ keras_weights = keras_model.get_weights()
+ if not gfile.IsDirectory(keras_model_dir):
+ gfile.MakeDirs(keras_model_dir)
with ops.Graph().as_default():
- random_seed.set_random_seed(estimator.config.tf_random_seed)
+ random_seed.set_random_seed(config.tf_random_seed)
training_util.create_global_step()
model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
custom_objects)
# save to checkpoint
- with session.Session(config=estimator._session_config) as sess:
+ with session.Session(config=config.session_config) as sess:
if keras_weights:
model.set_weights(keras_weights)
# Make update ops and initialize all variables.
@@ -458,7 +468,46 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects,
K._initialize_variables(sess)
# pylint: enable=protected-access
saver = saver_lib.Saver()
- saver.save(sess, os.path.join(estimator.model_dir, 'keras_model.ckpt'))
+ latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt')
+ saver.save(sess, latest_path)
+ return latest_path
+
+
+def _maybe_overwrite_model_dir_and_session_config(config, model_dir):
+  """Overwrites estimator config with `model_dir`/`session_config` if needed.
+
+ Args:
+ config: Original estimator config.
+ model_dir: Estimator model checkpoint directory.
+
+ Returns:
+ Overwritten estimator config.
+
+ Raises:
+ ValueError: Model directory inconsistent between `model_dir` and `config`.
+ """
+
+ default_session_config = run_config_lib.get_default_session_config()
+ if isinstance(config, dict):
+ config = RunConfig(**config)
+ elif config is None:
+ config = RunConfig(session_config=default_session_config)
+ if config.session_config is None:
+ config = RunConfig.replace(config, session_config=default_session_config)
+
+ if model_dir is not None:
+ if (getattr(config, 'model_dir', None) is not None and
+ config.model_dir != model_dir):
+ raise ValueError(
+          "`model_dir` is set both in the constructor and `RunConfig`, but "
+          "with different values. In constructor: '{}', in `RunConfig`: "
+          "'{}'.".format(model_dir, config.model_dir))
+ config = RunConfig.replace(config, model_dir=model_dir)
+ elif getattr(config, 'model_dir', None) is None:
+ model_dir = tempfile.mkdtemp()
+ config = RunConfig.replace(config, model_dir=model_dir)
+
+ return config
def model_to_estimator(keras_model=None,
@@ -517,45 +566,39 @@ def model_to_estimator(keras_model=None,
'Please compile the model with `model.compile()` '
'before calling `model_to_estimator()`.')
- if isinstance(config, dict):
- config = run_config_lib.RunConfig(**config)
+ config = _maybe_overwrite_model_dir_and_session_config(config, model_dir)
keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
- estimator = estimator_lib.Estimator(
- keras_model_fn, model_dir=model_dir, config=config)
-
- # Check if we need to call get_weights:
if _any_weight_initialized(keras_model):
- keras_weights = keras_model.get_weights()
# Warn if config passed to estimator tries to update GPUOptions. If a
# session has already been created, the GPUOptions passed to the first
# session sticks.
- if estimator._session_config.HasField('gpu_options'):
+ if config.session_config.HasField('gpu_options'):
logging.warning(
'The Keras backend session has already been set. '
'The _session_config passed to model_to_estimator will not be used.')
else:
# Pass the config into keras backend's default session.
- sess = session.Session(config=estimator._session_config)
+ sess = session.Session(config=config.session_config)
K.set_session(sess)
- keras_weights = None
+ warm_start_path = None
if keras_model._is_graph_network:
- # TODO(yifeif): move checkpoint initialization to scaffold.init_fn
- _save_first_checkpoint(keras_model,
- estimator,
- custom_objects,
- keras_weights)
+ warm_start_path = _save_first_checkpoint(keras_model, custom_objects,
+ config)
elif keras_model.built:
- logging.warning('You are creating an Estimator from a Keras model '
- 'manually subclassed from `Model`, that was '
- 'already called on some inputs (and thus already had '
- 'weights). We are currently unable to preserve '
- 'the model\'s state (its weights) '
- 'as part of the estimator '
- 'in this case. Be warned that the estimator '
- 'has been created using '
- 'a freshly initialized version of your model.\n'
- 'Note that this doesn\'t affect the state of the '
- 'model instance you passed as `keras_model` argument.')
+ logging.warning('You are creating an Estimator from a Keras model manually '
+ 'subclassed from `Model`, that was already called on some '
+ 'inputs (and thus already had weights). We are currently '
+ 'unable to preserve the model\'s state (its weights) as '
+ 'part of the estimator in this case. Be warned that the '
+ 'estimator has been created using a freshly initialized '
+ 'version of your model.\n'
+ 'Note that this doesn\'t affect the state of the model '
+ 'instance you passed as `keras_model` argument.')
+
+ estimator = estimator_lib.Estimator(keras_model_fn,
+ config=config,
+ warm_start_from=warm_start_path)
+
return estimator
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 7a4457f5a4..cf4ec7f4da 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -32,13 +32,14 @@ from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
-from tensorflow.python.keras.applications import mobilenet
from tensorflow.python.keras.optimizers import SGD
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import rmsprop
+from tensorflow.python.training import session_run_hook
try:
@@ -51,6 +52,8 @@ _TRAIN_SIZE = 200
_INPUT_SIZE = (10,)
_NUM_CLASS = 2
+_TMP_DIR = '/tmp'
+
def simple_sequential_model():
model = keras.models.Sequential()
@@ -60,9 +63,9 @@ def simple_sequential_model():
return model
-def simple_functional_model():
+def simple_functional_model(activation='relu'):
a = keras.layers.Input(shape=_INPUT_SIZE)
- b = keras.layers.Dense(16, activation='relu')(a)
+ b = keras.layers.Dense(16, activation=activation)(a)
b = keras.layers.Dropout(0.1)(b)
b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b)
model = keras.models.Model(inputs=[a], outputs=[b])
@@ -168,6 +171,12 @@ def multi_inputs_multi_outputs_model():
return model
+class MyHook(session_run_hook.SessionRunHook):
+
+ def begin(self):
+ _ = variable_scope.get_variable('temp', [1])
+
+
class TestKerasEstimator(test_util.TensorFlowTestCase):
def setUp(self):
@@ -204,6 +213,54 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
writer_cache.FileWriterCache.clear()
gfile.DeleteRecursively(self._config.model_dir)
+ # see b/109935364
+ @test_util.run_in_graph_and_eager_modes
+ def test_train_with_hooks(self):
+ for model_type in ['sequential', 'functional']:
+ keras_model, (_, _), (
+ _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(
+ model_type=model_type, is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ metrics=['mse', keras.metrics.categorical_accuracy])
+
+ my_hook = MyHook()
+ with self.test_session():
+ est_keras = keras_lib.model_to_estimator(
+ keras_model=keras_model, config=self._config)
+ before_eval_results = est_keras.evaluate(
+ input_fn=eval_input_fn, steps=1)
+ est_keras.train(input_fn=train_input_fn, hooks=[my_hook],
+ steps=_TRAIN_SIZE / 16)
+ after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)
+ self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+
+ writer_cache.FileWriterCache.clear()
+ gfile.DeleteRecursively(self._config.model_dir)
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_train_with_model_fit_and_hooks(self):
+ keras_model, (x_train, y_train), _, \
+ train_input_fn, eval_input_fn = get_resource_for_simple_model(
+ model_type='sequential', is_evaluate=True)
+
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer=rmsprop.RMSPropOptimizer(1e-3),
+ metrics=['mse', keras.metrics.categorical_accuracy])
+ my_hook = MyHook()
+ with self.test_session():
+ keras_model.fit(x_train, y_train, epochs=1)
+
+ keras_est = keras_lib.model_to_estimator(
+ keras_model=keras_model, config=self._config)
+ before_eval_results = keras_est.evaluate(input_fn=eval_input_fn)
+ keras_est.train(input_fn=train_input_fn, hooks=[my_hook],
+ steps=_TRAIN_SIZE / 16)
+ after_eval_results = keras_est.evaluate(input_fn=eval_input_fn, steps=1)
+ self.assertLess(after_eval_results['loss'], before_eval_results['loss'])
+
@test_util.run_in_graph_and_eager_modes
def test_train_with_tf_optimizer(self):
for model_type in ['sequential', 'functional']:
@@ -474,23 +531,43 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
est_keras.train(input_fn=invald_output_name_input_fn, steps=100)
def test_custom_objects(self):
- keras_mobile = mobilenet.MobileNet(weights=None)
- keras_mobile.compile(loss='categorical_crossentropy', optimizer='adam')
+
+ def relu6(x):
+ return keras.backend.relu(x, max_value=6)
+
+ keras_model = simple_functional_model(activation=relu6)
+ keras_model.compile(loss='categorical_crossentropy', optimizer='adam')
custom_objects = {
- 'relu6': mobilenet.relu6,
- 'DepthwiseConv2D': mobilenet.DepthwiseConv2D
+ 'relu6': relu6
}
+
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=_TRAIN_SIZE,
+ test_samples=50,
+ input_shape=(10,),
+ num_classes=2)
+ y_train = keras.utils.to_categorical(y_train, 2)
+ input_name = keras_model.input_names[0]
+ output_name = keras_model.output_names[0]
+ train_input_fn = numpy_io.numpy_input_fn(
+ x=randomize_io_type(x_train, input_name),
+ y=randomize_io_type(y_train, output_name),
+ shuffle=False,
+ num_epochs=None,
+ batch_size=16)
with self.assertRaisesRegexp(ValueError, 'relu6'):
with self.test_session():
- keras_lib.model_to_estimator(
- keras_model=keras_mobile,
+ est = keras_lib.model_to_estimator(
+ keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir))
+ est.train(input_fn=train_input_fn, steps=1)
with self.test_session():
- keras_lib.model_to_estimator(
- keras_model=keras_mobile,
+ est = keras_lib.model_to_estimator(
+ keras_model=keras_model,
model_dir=tempfile.mkdtemp(dir=self._base_dir),
custom_objects=custom_objects)
+ est.train(input_fn=train_input_fn, steps=1)
def test_tf_config(self):
keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
@@ -527,12 +604,73 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
self._config._session_config = sess_config
- keras_lib.model_to_estimator(
- keras_model=keras_model, config=self._config)
- self.assertEqual(
- keras.backend.get_session()
- ._config.gpu_options.per_process_gpu_memory_fraction,
- gpu_options.per_process_gpu_memory_fraction)
+ with self.test_session():
+ keras_lib.model_to_estimator(
+ keras_model=keras_model, config=self._config)
+ self.assertEqual(
+ keras.backend.get_session()
+ ._config.gpu_options.per_process_gpu_memory_fraction,
+ gpu_options.per_process_gpu_memory_fraction)
+
+ def test_with_empty_config(self):
+ keras_model, _, _, _, _ = get_resource_for_simple_model(
+ model_type='sequential', is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer='rmsprop',
+ metrics=['mse', keras.metrics.categorical_accuracy])
+
+ with self.test_session():
+ est_keras = keras_lib.model_to_estimator(
+ keras_model=keras_model, model_dir=self._base_dir,
+ config=run_config_lib.RunConfig())
+ self.assertEqual(run_config_lib.get_default_session_config(),
+ est_keras._session_config)
+ self.assertEqual(est_keras._session_config,
+ est_keras._config.session_config)
+ self.assertEqual(self._base_dir, est_keras._config.model_dir)
+ self.assertEqual(self._base_dir, est_keras._model_dir)
+
+ with self.test_session():
+ est_keras = keras_lib.model_to_estimator(
+ keras_model=keras_model, model_dir=self._base_dir,
+ config=None)
+ self.assertEqual(run_config_lib.get_default_session_config(),
+ est_keras._session_config)
+ self.assertEqual(est_keras._session_config,
+ est_keras._config.session_config)
+ self.assertEqual(self._base_dir, est_keras._config.model_dir)
+ self.assertEqual(self._base_dir, est_keras._model_dir)
+
+ def test_with_empty_config_and_empty_model_dir(self):
+ keras_model, _, _, _, _ = get_resource_for_simple_model(
+ model_type='sequential', is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer='rmsprop',
+ metrics=['mse', keras.metrics.categorical_accuracy])
+
+ with self.test_session():
+ with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
+ est_keras = keras_lib.model_to_estimator(
+ keras_model=keras_model,
+ config=run_config_lib.RunConfig())
+ self.assertEqual(est_keras._model_dir, _TMP_DIR)
+
+ def test_with_conflicting_model_dir_and_config(self):
+ keras_model, _, _, _, _ = get_resource_for_simple_model(
+ model_type='sequential', is_evaluate=True)
+ keras_model.compile(
+ loss='categorical_crossentropy',
+ optimizer='rmsprop',
+ metrics=['mse', keras.metrics.categorical_accuracy])
+
+ with self.test_session():
+      with self.assertRaisesRegexp(ValueError, '`model_dir` is set both in '
+                                   'the constructor and `RunConfig`'):
+ keras_lib.model_to_estimator(
+ keras_model=keras_model, model_dir=self._base_dir,
+ config=run_config_lib.RunConfig(model_dir=_TMP_DIR))
def test_pretrained_weights(self):
keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index aa594af2e4..6c1de166a4 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -48,7 +48,8 @@ _DEFAULT_REPLACEABLE_LIST = [
'keep_checkpoint_every_n_hours',
'log_step_count_steps',
'train_distribute',
- 'device_fn'
+ 'device_fn',
+ 'protocol'
]
_SAVE_CKPT_ERR = (
@@ -288,6 +289,21 @@ def _validate_properties(run_config):
message='device_fn must be callable with exactly'
' one argument "op".')
+ _validate('protocol',
+ lambda protocol: protocol in (None, "grpc", "grpc+verbs"),
+ message='protocol should be grpc or grpc+verbs')
+
+
+def get_default_session_config():
+  """Returns a default tf.ConfigProto for Estimator sessions."""
+
+ rewrite_opts = rewriter_config_pb2.RewriterConfig(
+ meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE)
+ graph_opts = config_pb2.GraphOptions(rewrite_options=rewrite_opts)
+
+ return config_pb2.ConfigProto(allow_soft_placement=True,
+ graph_options=graph_opts)
+
class TaskType(object):
MASTER = 'master'
@@ -312,7 +328,8 @@ class RunConfig(object):
keep_checkpoint_every_n_hours=10000,
log_step_count_steps=100,
train_distribute=None,
- device_fn=None):
+ device_fn=None,
+ protocol=None):
"""Constructs a RunConfig.
All distributed training related properties `cluster_spec`, `is_chief`,
@@ -436,7 +453,7 @@ class RunConfig(object):
the feature.
log_step_count_steps: The frequency, in number of global steps, that the
global step/sec and the loss will be logged during training.
- train_distribute: an optional instance of
+ train_distribute: An optional instance of
`tf.contrib.distribute.DistributionStrategy`. If specified,
then Estimator will distribute the user's model during training,
according to the policy specified by that strategy.
@@ -444,6 +461,8 @@ class RunConfig(object):
`Operation` and returns the device string. If `None`, defaults to
the device function returned by `tf.train.replica_device_setter`
with round-robin strategy.
+      protocol: An optional argument which specifies the protocol used when
+        starting the server. If `None`, defaults to grpc.
Raises:
ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
@@ -481,7 +500,8 @@ class RunConfig(object):
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
log_step_count_steps=log_step_count_steps,
train_distribute=train_distribute,
- device_fn=device_fn)
+ device_fn=device_fn,
+ protocol=protocol)
self._init_distributed_setting_from_environment_var(tf_config)
@@ -499,9 +519,9 @@ class RunConfig(object):
RunConfig._replace(
self,
allowed_properties_list=_DEFAULT_REPLACEABLE_LIST,
- session_config=self._get_default_session_config())
+ session_config=self._get_default_session_config_distributed())
- def _get_default_session_config(self):
+ def _get_default_session_config_distributed(self):
"""Returns None or tf.ConfigProto instance with default device_filters set.
Device filters are set such that chief/master and worker communicates with
@@ -754,6 +774,11 @@ class RunConfig(object):
"""
return self._train_distribute
+ @property
+ def protocol(self):
+ """Returns the optional protocol value."""
+ return self._protocol
+
def replace(self, **kwargs):
"""Returns a new instance of `RunConfig` replacing specified properties.
@@ -769,7 +794,8 @@ class RunConfig(object):
- `keep_checkpoint_every_n_hours`,
- `log_step_count_steps`,
- `train_distribute`,
- - `device_fn`.
+ - `device_fn`,
+ - `protocol`.
In addition, either `save_checkpoints_steps` or `save_checkpoints_secs`
can be set (should not be both).
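
Because `protocol` joins `_DEFAULT_REPLACEABLE_LIST`, it behaves like any
other replaceable property. A usage sketch (the values are validated above to
be None, 'grpc', or 'grpc+verbs'):

    from tensorflow.python.estimator import run_config as run_config_lib

    config = run_config_lib.RunConfig(protocol='grpc+verbs')
    assert config.protocol == 'grpc+verbs'

    # Like other replaceable properties, it can be swapped after construction.
    config = config.replace(protocol='grpc')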
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index f5ac79ced2..a01b2300dd 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -732,7 +732,8 @@ class _TrainingExecutor(object):
job_name=config.task_type,
task_index=config.task_id,
config=session_config,
- start=False)
+ start=False,
+ protocol=config.protocol)
server.start()
return server
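
The executor now forwards `config.protocol` straight into the `Server`
constructor. The equivalent public-API call, sketched with a hypothetical
single-worker cluster spec:

    import tensorflow as tf

    cluster = tf.train.ClusterSpec({'worker': ['localhost:2222']})
    server = tf.train.Server(
        cluster,
        job_name='worker',
        task_index=0,
        protocol='grpc',  # None also defaults to grpc
        start=False)
    server.start()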
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index 6bee7cbe83..dc106c7d3b 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -472,6 +472,7 @@ class _TrainingExecutorTrainingTest(object):
job_name=mock_est.config.task_type,
task_index=mock_est.config.task_id,
config=test.mock.ANY,
+ protocol=None,
start=False)
self.assertTrue(mock_server_instance.start.called)
@@ -502,6 +503,7 @@ class _TrainingExecutorTrainingTest(object):
job_name=mock_est.config.task_type,
task_index=mock_est.config.task_id,
config=test.mock.ANY,
+ protocol=None,
start=False)
self.assertTrue(mock_server_instance.start.called)
@@ -729,6 +731,7 @@ class TrainingExecutorRunMasterTest(test.TestCase):
job_name=mock_est.config.task_type,
task_index=mock_est.config.task_id,
config=test.mock.ANY,
+ protocol=None,
start=False)
self.assertTrue(mock_server_instance.start.called)
@@ -1481,6 +1484,7 @@ class TrainingExecutorRunPsTest(test.TestCase):
job_name=mock_est.config.task_type,
task_index=mock_est.config.task_id,
config=test.mock.ANY,
+ protocol=None,
start=False)
self.assertTrue(mock_server_instance.start.called)