aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-31 17:13:04 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-31 17:13:04 +0800
commitb3114e5b1e930c4dd1a1fdfaac721a219677d611 (patch)
tree7bfb4c4e8a9a158e86752a73117559df3d0386c1 /tensorflow/python/estimator
parentf8ee9799e6a72d4fe24f9fad76d6e6b1b3a01af1 (diff)
parent9357b2558adc13c479c8edb66c5002c5c6ec3664 (diff)
Merge remote-tracking branch 'upstream/master' into ENH/feature_importances_for_boosted_tree
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/BUILD3
-rw-r--r--tensorflow/python/estimator/canned/baseline_test.py10
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py329
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py20
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined.py17
-rw-r--r--tensorflow/python/estimator/canned/dnn_testing_utils.py6
-rw-r--r--tensorflow/python/estimator/canned/head.py4
-rw-r--r--tensorflow/python/estimator/canned/linear.py4
-rw-r--r--tensorflow/python/estimator/canned/linear_testing_utils.py10
-rw-r--r--tensorflow/python/estimator/canned/prediction_keys.py1
-rw-r--r--tensorflow/python/estimator/estimator.py879
-rw-r--r--tensorflow/python/estimator/estimator_test.py209
-rw-r--r--tensorflow/python/estimator/export/export.py50
-rw-r--r--tensorflow/python/estimator/export/export_output.py24
-rw-r--r--tensorflow/python/estimator/export/export_output_test.py89
-rw-r--r--tensorflow/python/estimator/export/export_test.py45
-rw-r--r--tensorflow/python/estimator/exporter_test.py37
-rw-r--r--tensorflow/python/estimator/gc.py8
-rw-r--r--tensorflow/python/estimator/gc_test.py11
-rw-r--r--tensorflow/python/estimator/inputs/numpy_io_test.py162
-rw-r--r--tensorflow/python/estimator/keras.py319
-rw-r--r--tensorflow/python/estimator/keras_test.py10
-rw-r--r--tensorflow/python/estimator/model_fn.py72
-rw-r--r--tensorflow/python/estimator/model_fn_test.py155
-rw-r--r--tensorflow/python/estimator/run_config.py29
-rw-r--r--tensorflow/python/estimator/training.py41
-rw-r--r--tensorflow/python/estimator/training_test.py33
-rw-r--r--tensorflow/python/estimator/util.py8
-rw-r--r--tensorflow/python/estimator/util_test.py4
29 files changed, 1636 insertions, 953 deletions
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 817c8e6848..9fce172bee 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -211,6 +211,9 @@ py_test(
shard_count = 2,
srcs_version = "PY2AND3",
tags = [
+ "manual",
+ "no_oss",
+ "notap",
"optonly",
],
deps = [
diff --git a/tensorflow/python/estimator/canned/baseline_test.py b/tensorflow/python/estimator/canned/baseline_test.py
index e46a3a156d..1df7216ba6 100644
--- a/tensorflow/python/estimator/canned/baseline_test.py
+++ b/tensorflow/python/estimator/canned/baseline_test.py
@@ -42,13 +42,13 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import optimizer
from tensorflow.python.training import queue_runner
@@ -490,7 +490,7 @@ class BaselineRegressorTrainingTest(test.TestCase):
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
@@ -498,7 +498,7 @@ class BaselineRegressorTrainingTest(test.TestCase):
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
mock_optimizer = test.mock.NonCallableMock(
@@ -693,13 +693,13 @@ class BaselineClassifierTrainingTest(test.TestCase):
# Verify loss. We can't check the value directly, so we add an assert op.
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
loss,
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
mock_optimizer = test.mock.NonCallableMock(
spec=optimizer.Optimizer,
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 66784fad0c..1c7e2189c2 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -42,7 +42,6 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import estimator_export
@@ -445,18 +444,21 @@ class _EnsembleGrower(object):
training_ops.append(grow_op)
"""
- def __init__(self, tree_ensemble, tree_hparams):
+ def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):
"""Initializes a grower object.
Args:
tree_ensemble: A TreeEnsemble variable.
tree_hparams: TODO. collections.namedtuple for hyper parameters.
+ feature_ids_list: a list of lists of feature ids for each bucket size.
+
Raises:
ValueError: when pruning mode is invalid or pruning is used and no tree
complexity is set.
"""
self._tree_ensemble = tree_ensemble
self._tree_hparams = tree_hparams
+ self._feature_ids_list = feature_ids_list
# pylint: disable=protected-access
self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
tree_hparams.pruning_mode)
@@ -481,14 +483,12 @@ class _EnsembleGrower(object):
"""
@abc.abstractmethod
- def grow_tree(self, stats_summaries_list, feature_ids_list,
- last_layer_nodes_range):
+ def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
"""Grows a tree, if ready, based on provided statistics.
Args:
stats_summaries_list: List of stats summary tensors, representing sums of
gradients and hessians for each feature bucket.
- feature_ids_list: a list of lists of feature ids for each bucket size.
last_layer_nodes_range: A tensor representing ids of the nodes in the
current layer, to be split.
@@ -496,6 +496,10 @@ class _EnsembleGrower(object):
An op for growing a tree.
"""
+ def chief_init_op(self):
+ """Ops that chief needs to run to initialize the state."""
+ return control_flow_ops.no_op()
+
# ============= Helper methods ===========
def _center_bias_fn(self, center_bias_var, mean_gradients, mean_hessians):
@@ -509,7 +513,7 @@ class _EnsembleGrower(object):
return center_bias_var.assign(continue_centering)
def _grow_tree_from_stats_summaries(self, stats_summaries_list,
- feature_ids_list, last_layer_nodes_range):
+ last_layer_nodes_range):
"""Updates ensemble based on the best gains from stats summaries."""
node_ids_per_feature = []
gains_list = []
@@ -517,11 +521,11 @@ class _EnsembleGrower(object):
left_node_contribs_list = []
right_node_contribs_list = []
all_feature_ids = []
- assert len(stats_summaries_list) == len(feature_ids_list)
+ assert len(stats_summaries_list) == len(self._feature_ids_list)
max_splits = _get_max_splits(self._tree_hparams)
- for i, feature_ids in enumerate(feature_ids_list):
+ for i, feature_ids in enumerate(self._feature_ids_list):
(numeric_node_ids_per_feature, numeric_gains_list,
numeric_thresholds_list, numeric_left_node_contribs_list,
numeric_right_node_contribs_list) = (
@@ -557,12 +561,13 @@ class _EnsembleGrower(object):
class _InMemoryEnsembleGrower(_EnsembleGrower):
- """A base class for ensemble growers."""
+ """An in-memory ensemble grower."""
- def __init__(self, tree_ensemble, tree_hparams):
+ def __init__(self, tree_ensemble, tree_hparams, feature_ids_list):
super(_InMemoryEnsembleGrower, self).__init__(
- tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+ tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
+ feature_ids_list=feature_ids_list)
def center_bias(self, center_bias_var, gradients, hessians):
# For in memory, we already have a full batch of gradients and hessians,
@@ -572,83 +577,98 @@ class _InMemoryEnsembleGrower(_EnsembleGrower):
mean_heassians = array_ops.expand_dims(math_ops.reduce_mean(hessians, 0), 0)
return self._center_bias_fn(center_bias_var, mean_gradients, mean_heassians)
- def grow_tree(self, stats_summaries_list, feature_ids_list,
- last_layer_nodes_range):
+ def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
# For in memory, we already have full data in one batch, so we can grow the
# tree immediately.
return self._grow_tree_from_stats_summaries(
- stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+ stats_summaries_list, last_layer_nodes_range)
class _AccumulatorEnsembleGrower(_EnsembleGrower):
- """A base class for ensemble growers."""
+ """An accumulator based ensemble grower."""
def __init__(self, tree_ensemble, tree_hparams, stamp_token,
- n_batches_per_layer, bucket_size_list, is_chief):
+ n_batches_per_layer, bucket_size_list, is_chief, center_bias,
+ feature_ids_list):
super(_AccumulatorEnsembleGrower, self).__init__(
- tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+ tree_ensemble=tree_ensemble, tree_hparams=tree_hparams,
+ feature_ids_list=feature_ids_list)
self._stamp_token = stamp_token
self._n_batches_per_layer = n_batches_per_layer
self._bucket_size_list = bucket_size_list
self._is_chief = is_chief
+ self._growing_accumulators = []
+ self._chief_init_ops = []
+ max_splits = _get_max_splits(self._tree_hparams)
+ for i, feature_ids in enumerate(self._feature_ids_list):
+ accumulator = data_flow_ops.ConditionalAccumulator(
+ dtype=dtypes.float32,
+ # The stats consist of grads and hessians (the last dimension).
+ shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
+ shared_name='numeric_stats_summary_accumulator_' + str(i))
+ self._chief_init_ops.append(
+ accumulator.set_global_step(self._stamp_token))
+ self._growing_accumulators.append(accumulator)
+ self._center_bias = center_bias
+ if center_bias:
+ self._bias_accumulator = data_flow_ops.ConditionalAccumulator(
+ dtype=dtypes.float32,
+ # The stats consist of grads and hessians means only.
+ # TODO(nponomareva): this will change for a multiclass
+ shape=[2, 1],
+ shared_name='bias_accumulator')
+ self._chief_init_ops.append(
+ self._bias_accumulator.set_global_step(self._stamp_token))
def center_bias(self, center_bias_var, gradients, hessians):
# For not in memory situation, we need to accumulate enough of batches first
# before proceeding with centering bias.
# Create an accumulator.
+ if not self._center_bias:
+ raise RuntimeError('center_bias called but bias centering is disabled.')
bias_dependencies = []
- bias_accumulator = data_flow_ops.ConditionalAccumulator(
- dtype=dtypes.float32,
- # The stats consist of grads and hessians means only.
- # TODO(nponomareva): this will change for a multiclass
- shape=[2, 1],
- shared_name='bias_accumulator')
-
grads_and_hess = array_ops.stack([gradients, hessians], axis=0)
grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1)
- apply_grad = bias_accumulator.apply_grad(grads_and_hess, self._stamp_token)
+ apply_grad = self._bias_accumulator.apply_grad(
+ grads_and_hess, self._stamp_token)
bias_dependencies.append(apply_grad)
# Center bias if enough batches were processed.
with ops.control_dependencies(bias_dependencies):
if not self._is_chief:
return control_flow_ops.no_op()
+ def _set_accumulators_stamp():
+ return control_flow_ops.group(
+ [acc.set_global_step(self._stamp_token + 1) for acc in
+ self._growing_accumulators])
def center_bias_from_accumulator():
- accumulated = array_ops.unstack(bias_accumulator.take_grad(1), axis=0)
- return self._center_bias_fn(center_bias_var,
- array_ops.expand_dims(accumulated[0], 0),
- array_ops.expand_dims(accumulated[1], 0))
+ accumulated = array_ops.unstack(self._bias_accumulator.take_grad(1),
+ axis=0)
+ center_bias_op = self._center_bias_fn(
+ center_bias_var,
+ array_ops.expand_dims(accumulated[0], 0),
+ array_ops.expand_dims(accumulated[1], 0))
+ with ops.control_dependencies([center_bias_op]):
+ return control_flow_ops.cond(center_bias_var,
+ control_flow_ops.no_op,
+ _set_accumulators_stamp)
center_bias_op = control_flow_ops.cond(
- math_ops.greater_equal(bias_accumulator.num_accumulated(),
+ math_ops.greater_equal(self._bias_accumulator.num_accumulated(),
self._n_batches_per_layer),
center_bias_from_accumulator,
control_flow_ops.no_op,
name='wait_until_n_batches_for_bias_accumulated')
return center_bias_op
- def grow_tree(self, stats_summaries_list, feature_ids_list,
- last_layer_nodes_range):
- # For not in memory situation, we need to accumulate enough of batches first
- # before proceeding with building a tree layer.
- max_splits = _get_max_splits(self._tree_hparams)
-
- # Prepare accumulators.
- accumulators = []
+ def grow_tree(self, stats_summaries_list, last_layer_nodes_range):
dependencies = []
- for i, feature_ids in enumerate(feature_ids_list):
+ for i in range(len(self._feature_ids_list)):
stats_summaries = stats_summaries_list[i]
- accumulator = data_flow_ops.ConditionalAccumulator(
- dtype=dtypes.float32,
- # The stats consist of grads and hessians (the last dimension).
- shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
- shared_name='numeric_stats_summary_accumulator_' + str(i))
- accumulators.append(accumulator)
-
- apply_grad = accumulator.apply_grad(
+ apply_grad = self._growing_accumulators[i].apply_grad(
array_ops.stack(stats_summaries, axis=0), self._stamp_token)
dependencies.append(apply_grad)
@@ -658,7 +678,8 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
return control_flow_ops.no_op()
min_accumulated = math_ops.reduce_min(
- array_ops.stack([acc.num_accumulated() for acc in accumulators]))
+ array_ops.stack([acc.num_accumulated() for acc in
+ self._growing_accumulators]))
def grow_tree_from_accumulated_summaries_fn():
"""Updates tree with the best layer from accumulated summaries."""
@@ -666,10 +687,11 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
stats_summaries_list = []
stats_summaries_list = [
array_ops.unstack(accumulator.take_grad(1), axis=0)
- for accumulator in accumulators
+ for accumulator in self._growing_accumulators
]
grow_op = self._grow_tree_from_stats_summaries(
- stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+ stats_summaries_list, last_layer_nodes_range
+ )
return grow_op
grow_model = control_flow_ops.cond(
@@ -679,6 +701,10 @@ class _AccumulatorEnsembleGrower(_EnsembleGrower):
name='wait_until_n_batches_accumulated')
return grow_model
+ def chief_init_op(self):
+ """Ops that chief needs to run to initialize the state."""
+ return control_flow_ops.group(self._chief_init_ops)
+
def _bt_model_fn(
features,
@@ -724,29 +750,50 @@ def _bt_model_fn(
Raises:
ValueError: mode or params are invalid, or features has the wrong type.
"""
- is_single_machine = (config.num_worker_replicas <= 1)
sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
- center_bias = tree_hparams.center_bias
-
- if train_in_memory:
- assert n_batches_per_layer == 1, (
- 'When train_in_memory is enabled, input_fn should return the entire '
- 'dataset as a single batch, and n_batches_per_layer should be set as '
- '1.')
- if (not config.is_chief or config.num_worker_replicas > 1 or
- config.num_ps_replicas > 0):
- raise ValueError('train_in_memory is supported only for '
- 'non-distributed training.')
- worker_device = control_flow_ops.no_op().device
- train_op = []
with ops.name_scope(name) as name:
# Prepare.
global_step = training_util.get_or_create_global_step()
bucket_size_list, feature_ids_list = _group_features_by_num_buckets(
sorted_feature_columns)
+ # Create Ensemble resources.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+
+ # Create logits.
+ if mode != model_fn.ModeKeys.TRAIN:
+ input_feature_list = _get_transformed_features(features,
+ sorted_feature_columns)
+ logits = boosted_trees_ops.predict(
+ # For non-TRAIN mode, ensemble doesn't change after initialization,
+ # so no local copy is needed; using tree_ensemble directly.
+ tree_ensemble_handle=tree_ensemble.resource_handle,
+ bucketized_features=input_feature_list,
+ logits_dimension=head.logits_dimension)
+ return head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=control_flow_ops.no_op,
+ logits=logits)
+
+ # ============== Training graph ==============
+ center_bias = tree_hparams.center_bias
+ is_single_machine = (config.num_worker_replicas <= 1)
+
+ if train_in_memory:
+ assert n_batches_per_layer == 1, (
+ 'When train_in_memory is enabled, input_fn should return the entire '
+ 'dataset as a single batch, and n_batches_per_layer should be set as '
+ '1.')
+ if (not config.is_chief or config.num_worker_replicas > 1 or
+ config.num_ps_replicas > 0):
+ raise ValueError('train_in_memory is supported only for '
+ 'non-distributed training.')
+ worker_device = control_flow_ops.no_op().device
+ train_op = []
# Extract input features and set up cache for training.
training_state_cache = None
- if mode == model_fn.ModeKeys.TRAIN and train_in_memory:
+ if train_in_memory:
# cache transformed features as well for in-memory training.
batch_size = array_ops.shape(labels)[0]
input_feature_list, input_cache_op = (
@@ -758,65 +805,62 @@ def _bt_model_fn(
else:
input_feature_list = _get_transformed_features(features,
sorted_feature_columns)
- if mode == model_fn.ModeKeys.TRAIN and example_id_column_name:
+ if example_id_column_name:
example_ids = features[example_id_column_name]
training_state_cache = _CacheTrainingStatesUsingHashTable(
example_ids, head.logits_dimension)
+ if training_state_cache:
+ cached_tree_ids, cached_node_ids, cached_logits = (
+ training_state_cache.lookup())
+ else:
+ # Always start from the beginning when no cache is set up.
+ batch_size = array_ops.shape(labels)[0]
+ cached_tree_ids, cached_node_ids, cached_logits = (
+ array_ops.zeros([batch_size], dtype=dtypes.int32),
+ _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
+ array_ops.zeros(
+ [batch_size, head.logits_dimension], dtype=dtypes.float32))
- # Create Ensemble resources.
- tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
- # Variable that determines whether bias centering is needed.
- center_bias_var = variable_scope.variable(
- initial_value=center_bias, name='center_bias_needed', trainable=False)
- # Create logits.
- if mode != model_fn.ModeKeys.TRAIN:
- logits = boosted_trees_ops.predict(
- # For non-TRAIN mode, ensemble doesn't change after initialization,
- # so no local copy is needed; using tree_ensemble directly.
- tree_ensemble_handle=tree_ensemble.resource_handle,
+ if is_single_machine:
+ local_tree_ensemble = tree_ensemble
+ ensemble_reload = control_flow_ops.no_op()
+ else:
+ # Have a local copy of ensemble for the distributed setting.
+ with ops.device(worker_device):
+ local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
+ name=name + '_local', is_local=True)
+ # TODO(soroush): Do partial updates if this becomes a bottleneck.
+ ensemble_reload = local_tree_ensemble.deserialize(
+ *tree_ensemble.serialize())
+ with ops.control_dependencies([ensemble_reload]):
+ (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
+ last_layer_nodes_range) = local_tree_ensemble.get_states()
+ partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
+ tree_ensemble_handle=local_tree_ensemble.resource_handle,
+ cached_tree_ids=cached_tree_ids,
+ cached_node_ids=cached_node_ids,
bucketized_features=input_feature_list,
logits_dimension=head.logits_dimension)
+ logits = cached_logits + partial_logits
+
+ if train_in_memory:
+ grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams,
+ feature_ids_list=feature_ids_list)
else:
- if is_single_machine:
- local_tree_ensemble = tree_ensemble
- ensemble_reload = control_flow_ops.no_op()
- else:
- # Have a local copy of ensemble for the distributed setting.
- with ops.device(worker_device):
- local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
- name=name + '_local', is_local=True)
- # TODO(soroush): Do partial updates if this becomes a bottleneck.
- ensemble_reload = local_tree_ensemble.deserialize(
- *tree_ensemble.serialize())
+ grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
+ stamp_token, n_batches_per_layer,
+ bucket_size_list, config.is_chief,
+ center_bias=center_bias,
+ feature_ids_list=feature_ids_list)
- if training_state_cache:
- cached_tree_ids, cached_node_ids, cached_logits = (
- training_state_cache.lookup())
- else:
- # Always start from the beginning when no cache is set up.
- batch_size = array_ops.shape(labels)[0]
- cached_tree_ids, cached_node_ids, cached_logits = (
- array_ops.zeros([batch_size], dtype=dtypes.int32),
- _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
- array_ops.zeros(
- [batch_size, head.logits_dimension], dtype=dtypes.float32))
-
- with ops.control_dependencies([ensemble_reload]):
- (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
- last_layer_nodes_range) = local_tree_ensemble.get_states()
- summary.scalar('ensemble/num_trees', num_trees)
- summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
- summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
-
- partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
- tree_ensemble_handle=local_tree_ensemble.resource_handle,
- cached_tree_ids=cached_tree_ids,
- cached_node_ids=cached_node_ids,
- bucketized_features=input_feature_list,
- logits_dimension=head.logits_dimension)
-
- logits = cached_logits + partial_logits
+ summary.scalar('ensemble/num_trees', num_trees)
+ summary.scalar('ensemble/num_finalized_trees', num_finalized_trees)
+ summary.scalar('ensemble/num_attempted_layers', num_attempted_layers)
+ # Variable that determines whether bias centering is needed.
+ center_bias_var = variable_scope.variable(
+ initial_value=center_bias, name='center_bias_needed', trainable=False,
+ use_resource=True)
# Create training graph.
def _train_op_fn(loss):
"""Run one training iteration."""
@@ -855,28 +899,24 @@ def _bt_model_fn(
axis=0) for f in feature_ids
]
stats_summaries_list.append(summaries)
-
- if train_in_memory and is_single_machine:
- grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams)
+ if center_bias:
+ update_model = control_flow_ops.cond(
+ center_bias_var,
+ functools.partial(
+ grower.center_bias,
+ center_bias_var,
+ gradients,
+ hessians,
+ ),
+ functools.partial(grower.grow_tree, stats_summaries_list,
+ last_layer_nodes_range))
else:
- grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
- stamp_token, n_batches_per_layer,
- bucket_size_list, config.is_chief)
-
- update_model = control_flow_ops.cond(
- center_bias_var,
- functools.partial(
- grower.center_bias,
- center_bias_var,
- gradients,
- hessians,
- ),
- functools.partial(grower.grow_tree, stats_summaries_list,
- feature_ids_list, last_layer_nodes_range))
+ update_model = grower.grow_tree(stats_summaries_list,
+ last_layer_nodes_range)
train_op.append(update_model)
with ops.control_dependencies([update_model]):
- increment_global = distribute_lib.increment_var(global_step)
+ increment_global = state_ops.assign_add(global_step, 1).op
train_op.append(increment_global)
return control_flow_ops.group(train_op, name='train_op')
@@ -887,15 +927,26 @@ def _bt_model_fn(
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
- if mode == model_fn.ModeKeys.TRAIN:
- # Add an early stop hook.
- estimator_spec = estimator_spec._replace(
- training_hooks=estimator_spec.training_hooks +
- (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
- tree_hparams.n_trees, tree_hparams.max_depth),))
+ # Add an early stop hook.
+ estimator_spec = estimator_spec._replace(
+ training_hooks=estimator_spec.training_hooks +
+ (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
+ tree_hparams.n_trees, tree_hparams.max_depth),),
+ training_chief_hooks=[GrowerInitializationHook(grower.chief_init_op())] +
+ list(estimator_spec.training_chief_hooks))
return estimator_spec
+class GrowerInitializationHook(session_run_hook.SessionRunHook):
+ """A SessionRunHook handles initialization of `_EnsembleGrower`."""
+
+ def __init__(self, init_op):
+ self._init_op = init_op
+
+ def after_create_session(self, session, coord):
+ session.run(self._init_op)
+
+
def _create_classification_head(n_classes,
weight_column=None,
label_vocabulary=None):
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 14c05e024d..a176b4941f 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -179,6 +179,26 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
eval_res = est.evaluate(input_fn=input_fn, steps=1)
self.assertAllClose(eval_res['accuracy'], 1.0)
+ def testTrainTwiceAndEvaluateBinaryClassifier(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=5,
+ max_depth=10)
+
+ num_steps = 2
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+ est.train(input_fn, steps=num_steps)
+
+ self._assert_checkpoint(
+ est.model_dir, global_step=num_steps * 2,
+ finalized_trees=0, attempted_layers=4)
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['accuracy'], 1.0)
+
def testInferBinaryClassifier(self):
train_input_fn = _make_train_input_fn(is_classification=True)
predict_input_fn = numpy_io.numpy_input_fn(
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index efa7812452..9799cf9e98 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -31,10 +31,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import estimator_export
@@ -161,8 +161,8 @@ def _dnn_linear_combined_model_fn(features,
with variable_scope.variable_scope(
dnn_parent_scope,
values=tuple(six.itervalues(features)),
- partitioner=dnn_partitioner):
-
+ partitioner=dnn_partitioner) as scope:
+ dnn_absolute_scope = scope.name
dnn_logit_fn = dnn._dnn_logit_fn_builder( # pylint: disable=protected-access
units=head.logits_dimension,
hidden_units=dnn_hidden_units,
@@ -186,6 +186,7 @@ def _dnn_linear_combined_model_fn(features,
linear_parent_scope,
values=tuple(six.itervalues(features)),
partitioner=input_layer_partitioner) as scope:
+ linear_absolute_scope = scope.name
logit_fn = linear._linear_logit_fn_builder( # pylint: disable=protected-access
units=head.logits_dimension,
feature_columns=linear_feature_columns,
@@ -211,18 +212,18 @@ def _dnn_linear_combined_model_fn(features,
loss,
var_list=ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES,
- scope=dnn_parent_scope)))
+ scope=dnn_absolute_scope)))
if linear_logits is not None:
train_ops.append(
linear_optimizer.minimize(
loss,
var_list=ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES,
- scope=linear_parent_scope)))
+ scope=linear_absolute_scope)))
train_op = control_flow_ops.group(*train_ops)
with ops.control_dependencies([train_op]):
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return head.create_estimator_spec(
features=features,
@@ -388,7 +389,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
if a categorical column is multivalent. One of "mean", "sqrtn", and
"sum" -- these are effectively different ways to do example-level
normalization, which can be useful for bag-of-words features. For more
- details, see @{tf.feature_column.linear_model$linear_model}.
+ details, see `tf.feature_column.linear_model`.
Raises:
ValueError: If both linear_feature_columns and dnn_features_columns are
@@ -586,7 +587,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
if a categorical column is multivalent. One of "mean", "sqrtn", and
"sum" -- these are effectively different ways to do example-level
normalization, which can be useful for bag-of-words features. For more
- details, see @{tf.feature_column.linear_model$linear_model}.
+ details, see `tf.feature_column.linear_model`.
Raises:
ValueError: If both linear_feature_columns and dnn_features_columns are
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index de226ed0ef..11f1e93630 100644
--- a/tensorflow/python/estimator/canned/dnn_testing_utils.py
+++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py
@@ -44,13 +44,13 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.summary import summary as summary_lib
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
from tensorflow.python.training import optimizer as optimizer_lib
@@ -222,7 +222,7 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None):
testcase.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
@@ -230,7 +230,7 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None):
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
optimizer_mock = test.mock.NonCallableMagicMock(
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index da9a64c2bc..06593f9520 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -335,8 +335,8 @@ def _check_dense_labels_match_logits_and_reshape(
'Expected labels dimension=%s. Received %s. '
'Suggested Fix:'
'If your classifier expects one-hot encoding label,'
- 'check your n_classes argument to the estimator'
- 'and/or the shape of your label.'
+ 'check your n_classes argument to the estimator '
+ 'and/or the shape of your label. '
'Otherwise, check the shape of your label.' %
(expected_labels_dimension, dim1))
expected_labels_shape = array_ops.concat(
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index 58a7160348..115dd18518 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -306,7 +306,7 @@ class LinearClassifier(estimator.Estimator):
is multivalent. One of "mean", "sqrtn", and "sum" -- these are
effectively different ways to do example-level normalization, which can
be useful for bag-of-words features. for more details, see
- @{tf.feature_column.linear_model$linear_model}.
+ `tf.feature_column.linear_model`.
Returns:
A `LinearClassifier` estimator.
@@ -472,7 +472,7 @@ class LinearRegressor(estimator.Estimator):
is multivalent. One of "mean", "sqrtn", and "sum" -- these are
effectively different ways to do example-level normalization, which can
be useful for bag-of-words features. for more details, see
- @{tf.feature_column.linear_model$linear_model}.
+ `tf.feature_column.linear_model`.
"""
head = head_lib._regression_head( # pylint: disable=protected-access
label_dimension=label_dimension, weight_column=weight_column,
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index c3934c7a80..65cdd50061 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -48,13 +48,13 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import optimizer as optimizer_lib
@@ -756,7 +756,7 @@ class BaseLinearRegressorTrainingTest(object):
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
@@ -764,7 +764,7 @@ class BaseLinearRegressorTrainingTest(object):
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
mock_optimizer = test.mock.NonCallableMock(
@@ -979,13 +979,13 @@ class BaseLinearClassifierTrainingTest(object):
# Verify loss. We can't check the value directly, so we add an assert op.
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
loss,
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
mock_optimizer = test.mock.NonCallableMock(
spec=optimizer_lib.Optimizer,
diff --git a/tensorflow/python/estimator/canned/prediction_keys.py b/tensorflow/python/estimator/canned/prediction_keys.py
index 16890ec09a..daa275b46b 100644
--- a/tensorflow/python/estimator/canned/prediction_keys.py
+++ b/tensorflow/python/estimator/canned/prediction_keys.py
@@ -32,3 +32,4 @@ class PredictionKeys(object):
LOGITS = 'logits'
PREDICTIONS = 'predictions'
PROBABILITIES = 'probabilities'
+ TOP_K = 'top_k'
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 2fe44bc6ce..44a60495d8 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -42,6 +42,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras import metrics
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import metrics as metrics_lib
@@ -50,9 +51,10 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
-from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import device_setter
from tensorflow.python.training import distribute as distribute_lib
@@ -85,14 +87,15 @@ class Estimator(object):
subdirectory thereof. If `model_dir` is not set, a temporary directory is
used.
- The `config` argument can be passed `RunConfig` object containing information
- about the execution environment. It is passed on to the `model_fn`, if the
- `model_fn` has a parameter named "config" (and input functions in the same
- manner). If the `config` parameter is not passed, it is instantiated by the
- `Estimator`. Not passing config means that defaults useful for local execution
- are used. `Estimator` makes config available to the model (for instance, to
- allow specialization based on the number of workers available), and also uses
- some of its fields to control internals, especially regarding checkpointing.
+ The `config` argument can be passed `tf.estimator.RunConfig` object containing
+ information about the execution environment. It is passed on to the
+ `model_fn`, if the `model_fn` has a parameter named "config" (and input
+ functions in the same manner). If the `config` parameter is not passed, it is
+ instantiated by the `Estimator`. Not passing config means that defaults useful
+ for local execution are used. `Estimator` makes config available to the model
+ (for instance, to allow specialization based on the number of workers
+ available), and also uses some of its fields to control internals, especially
+ regarding checkpointing.
The `params` argument contains hyperparameters. It is passed to the
`model_fn`, if the `model_fn` has a parameter named "params", and to the input
@@ -118,7 +121,10 @@ class Estimator(object):
warm_start_from=None):
"""Constructs an `Estimator` instance.
- See @{$estimators} for more information. To warm-start an `Estimator`:
+ See [estimators](https://tensorflow.org/guide/estimators) for more
+ information.
+
+ To warm-start an `Estimator`:
```python
estimator = tf.estimator.DNNClassifier(
@@ -128,7 +134,7 @@ class Estimator(object):
```
For more details on warm-start configuration, see
- @{tf.estimator.WarmStartSettings$WarmStartSettings}.
+ `tf.estimator.WarmStartSettings`.
Args:
model_fn: Model function. Follows the signature:
@@ -137,41 +143,43 @@ class Estimator(object):
* `features`: This is the first item returned from the `input_fn`
passed to `train`, `evaluate`, and `predict`. This should be a
- single `Tensor` or `dict` of same.
+ single `tf.Tensor` or `dict` of same.
* `labels`: This is the second item returned from the `input_fn`
passed to `train`, `evaluate`, and `predict`. This should be a
- single `Tensor` or `dict` of same (for multi-head models). If
- mode is `ModeKeys.PREDICT`, `labels=None` will be passed. If
- the `model_fn`'s signature does not accept `mode`, the
- `model_fn` must still be able to handle `labels=None`.
+ single `tf.Tensor` or `dict` of same (for multi-head models).
+ If mode is @{tf.estimator.ModeKeys.PREDICT}, `labels=None` will
+ be passed. If the `model_fn`'s signature does not accept
+ `mode`, the `model_fn` must still be able to handle
+ `labels=None`.
* `mode`: Optional. Specifies if this training, evaluation or
- prediction. See `ModeKeys`.
+ prediction. See `tf.estimator.ModeKeys`.
* `params`: Optional `dict` of hyperparameters. Will receive what
is passed to Estimator in `params` parameter. This allows
to configure Estimators from hyper parameter tuning.
- * `config`: Optional configuration object. Will receive what is passed
- to Estimator in `config` parameter, or the default `config`.
- Allows updating things in your `model_fn` based on
+ * `config`: Optional `estimator.RunConfig` object. Will receive what
+ is passed to Estimator as its `config` parameter, or a default
+ value. Allows setting up things in your `model_fn` based on
configuration such as `num_ps_replicas`, or `model_dir`.
* Returns:
- `EstimatorSpec`
+ `tf.estimator.EstimatorSpec`
model_dir: Directory to save model parameters, graph and etc. This can
- also be used to load checkpoints from the directory into a estimator to
+ also be used to load checkpoints from the directory into an estimator to
continue training a previously saved model. If `PathLike` object, the
path will be resolved. If `None`, the model_dir in `config` will be used
if set. If both are set, they must be same. If both are `None`, a
temporary directory will be used.
- config: Configuration object.
+ config: `estimator.RunConfig` configuration object.
params: `dict` of hyper parameters that will be passed into `model_fn`.
Keys are names of parameters, values are basic python types.
warm_start_from: Optional string filepath to a checkpoint or SavedModel to
warm-start from, or a `tf.estimator.WarmStartSettings`
object to fully configure warm-starting. If the string
- filepath is provided instead of a `WarmStartSettings`,
- then all variables are warm-started, and it is assumed
- that vocabularies and Tensor names are unchanged.
+ filepath is provided instead of a
+ `tf.estimator.WarmStartSettings`, then all variables are
+ warm-started, and it is assumed that vocabularies
+ and `tf.Tensor` names are unchanged.
Raises:
ValueError: parameters of `model_fn` don't match `params`.
@@ -180,8 +188,8 @@ class Estimator(object):
"""
Estimator._assert_members_are_not_overridden(self)
- config = maybe_overwrite_model_dir_and_session_config(config, model_dir)
- self._config = config
+ self._config = maybe_overwrite_model_dir_and_session_config(config,
+ model_dir)
# The distribute field contains an instance of DistributionStrategy.
self._train_distribution = self._config.train_distribute
@@ -219,10 +227,10 @@ class Estimator(object):
@property
def model_fn(self):
- """Returns the model_fn which is bound to self.params.
+ """Returns the `model_fn` which is bound to `self.params`.
Returns:
- The model_fn with following signature:
+ The `model_fn` with following signature:
`def model_fn(features, labels, mode, config)`
"""
@@ -242,7 +250,7 @@ class Estimator(object):
Numpy array - value of the tensor.
Raises:
- ValueError: If the Estimator has not produced a checkpoint yet.
+ ValueError: If the `Estimator` has not produced a checkpoint yet.
"""
_check_checkpoint_available(self.model_dir)
with context.graph_mode():
@@ -255,14 +263,14 @@ class Estimator(object):
List of names.
Raises:
- ValueError: If the Estimator has not produced a checkpoint yet.
+ ValueError: If the `Estimator` has not produced a checkpoint yet.
"""
_check_checkpoint_available(self.model_dir)
with context.graph_mode():
return [name for name, _ in training.list_variables(self.model_dir)]
def latest_checkpoint(self):
- """Finds the filename of latest saved checkpoint file in `model_dir`.
+ """Finds the filename of the latest saved checkpoint file in `model_dir`.
Returns:
The full path to the latest checkpoint or `None` if no checkpoint was
@@ -277,40 +285,38 @@ class Estimator(object):
steps=None,
max_steps=None,
saving_listeners=None):
- """Trains a model given training data input_fn.
+ """Trains a model given training data `input_fn`.
Args:
input_fn: A function that provides input data for training as minibatches.
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
- the following:
-
- * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
- tuple (features, labels) with same constraints as below.
- * A tuple (features, labels): Where `features` is a `Tensor` or a
- dictionary of string feature name to `Tensor` and `labels` is a
- `Tensor` or a dictionary of string label name to `Tensor`. Both
- `features` and `labels` are consumed by `model_fn`. They should
- satisfy the expectation of `model_fn` from inputs.
-
- hooks: List of `SessionRunHook` subclass instances. Used for callbacks
- inside the training loop.
- steps: Number of steps for which to train model. If `None`, train forever
- or train until input_fn generates the `OutOfRange` error or
- `StopIteration` exception. 'steps' works incrementally. If you call two
- times train(steps=10) then training occurs in total 20 steps. If
- `OutOfRange` or `StopIteration` occurs in the middle, training stops
+ See [Premade Estimators](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
+ the following: * A
+ `tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple
+ `(features, labels)` with same constraints as below. * A tuple
+ `(features, labels)`: Where `features` is a `tf.Tensor` or a dictionary
+ of string feature name to `Tensor` and `labels` is a `Tensor` or a
+ dictionary of string label name to `Tensor`. Both `features` and
+ `labels` are consumed by `model_fn`. They should satisfy the expectation
+ of `model_fn` from inputs.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the training loop.
+ steps: Number of steps for which to train the model. If `None`, train
+ forever or train until `input_fn` generates the `tf.errors.OutOfRange`
+ error or `StopIteration` exception. `steps` works incrementally. If you
+ call two times `train(steps=10)` then training occurs in total 20 steps.
+ If `OutOfRange` or `StopIteration` occurs in the middle, training stops
before 20 steps. If you don't want to have incremental behavior please
set `max_steps` instead. If set, `max_steps` must be `None`.
max_steps: Number of total steps for which to train model. If `None`,
- train forever or train until input_fn generates the `OutOfRange` error
- or `StopIteration` exception. If set, `steps` must be `None`. If
- `OutOfRange` or `StopIteration` occurs in the middle, training stops
- before `max_steps` steps.
- Two calls to `train(steps=100)` means 200 training
- iterations. On the other hand, two calls to `train(max_steps=100)` means
- that the second call will not do any iteration since first call did
- all 100 steps.
+ train forever or train until `input_fn` generates the
+ `tf.errors.OutOfRange` error or `StopIteration` exception. If set,
+ `steps` must be `None`. If `OutOfRange` or `StopIteration` occurs in the
+ middle, training stops before `max_steps` steps. Two calls to
+ `train(steps=100)` means 200 training iterations. On the other hand, two
+ calls to `train(max_steps=100)` means that the second call will not do
+ any iteration since first call did all 100 steps.
saving_listeners: list of `CheckpointSaverListener` objects. Used for
callbacks that run immediately before or after checkpoint savings.
@@ -319,8 +325,16 @@ class Estimator(object):
Raises:
ValueError: If both `steps` and `max_steps` are not `None`.
- ValueError: If either `steps` or `max_steps` is <= 0.
+ ValueError: If either `steps` or `max_steps <= 0`.
"""
+ if self.config.task_type in (run_config.TaskType.EVALUATOR,
+ run_config.TaskType.PS):
+ raise ValueError(
+ 'Train has been called wrong configuration. Please use '
+ 'tf.estimator.train_and_evaluate which calls propper API according '
+ 'to given configuration. Current configuration: {}.'.format(
+ self.config))
+
with context.graph_mode():
if (steps is not None) and (max_steps is not None):
raise ValueError('Can not provide both steps and max_steps.')
@@ -345,13 +359,29 @@ class Estimator(object):
return self
def _convert_train_steps_to_hooks(self, steps, max_steps):
+ """Create hooks to run correct number of steps in training.
+
+ Args:
+ steps: number of steps to run during training.
+ max_steps: maximum number of steps to be run during training. It'll be
+ the maximum number of steps the model will train to after restoring
+ from checkpoint even across multiple estimator.train calls.
+
+ Returns:
+ List of hooks to be passed to the estimator.
+ """
if steps is not None or max_steps is not None:
+ if self._train_distribution:
+ steps_per_run = getattr(self._train_distribution, 'steps_per_run', 1)
+ if steps_per_run > 1:
+ return [basic_session_run_hooks._MultiStepStopAtStepHook( # pylint: disable=protected-access
+ steps, max_steps, steps_per_run)]
return [training.StopAtStepHook(steps, max_steps)]
else:
return []
def eval_dir(self, name=None):
- """Shows directory name where evaluation metrics are dumped.
+ """Shows the directory name where evaluation metrics are dumped.
Args:
name: Name of the evaluation if user needs to run multiple evaluations on
@@ -367,36 +397,36 @@ class Estimator(object):
def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
name=None):
- """Evaluates the model given evaluation data input_fn.
+ """Evaluates the model given evaluation data `input_fn`.
For each step, calls `input_fn`, which returns one batch of data.
Evaluates until:
- `steps` batches are processed, or
- - `input_fn` raises an end-of-input exception (`OutOfRangeError` or
+ - `input_fn` raises an end-of-input exception (`tf.errors.OutOfRangeError`
+ or
`StopIteration`).
Args:
- input_fn: A function that constructs the input data for evaluation.
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
- the following:
-
- * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
- tuple (features, labels) with same constraints as below.
- * A tuple (features, labels): Where `features` is a `Tensor` or a
- dictionary of string feature name to `Tensor` and `labels` is a
- `Tensor` or a dictionary of string label name to `Tensor`. Both
- `features` and `labels` are consumed by `model_fn`. They should
- satisfy the expectation of `model_fn` from inputs.
-
+ input_fn: A function that constructs the input data for evaluation. See
+ [Premade Estimators](
+ https://tensorflow.org/guide/premade#create_input_functions)
+ for more information. The
+ function should construct and return one of the following: * A
+ `tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple
+ `(features, labels)` with same constraints as below. * A tuple
+ `(features, labels)`: Where `features` is a `tf.Tensor` or a dictionary
+ of string feature name to `Tensor` and `labels` is a `Tensor` or a
+ dictionary of string label name to `Tensor`. Both `features` and
+ `labels` are consumed by `model_fn`. They should satisfy the expectation
+ of `model_fn` from inputs.
steps: Number of steps for which to evaluate model. If `None`, evaluates
until `input_fn` raises an end-of-input exception.
- hooks: List of `SessionRunHook` subclass instances. Used for callbacks
- inside the evaluation call.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the evaluation call.
checkpoint_path: Path of a specific checkpoint to evaluate. If `None`, the
latest checkpoint in `model_dir` is used. If there are no checkpoints
in `model_dir`, evaluation is run with newly initialized `Variables`
- instead of restored from checkpoint.
+ instead of ones restored from checkpoint.
name: Name of the evaluation if user needs to run multiple evaluations on
different data sets, such as on training data vs test data. Metrics for
different evaluations are saved in separate folders, and appear
@@ -405,7 +435,11 @@ class Estimator(object):
Returns:
A dict containing the evaluation metrics specified in `model_fn` keyed by
name, as well as an entry `global_step` which contains the value of the
- global step for which this evaluation was performed.
+ global step for which this evaluation was performed. For canned
+ estimators, the dict contains the `loss` (mean loss per mini-batch) and
+ the `average_loss` (mean loss per sample). Canned classifiers also return
+ the `accuracy`. Canned regressors also return the `label/mean` and the
+ `prediction/mean`.
Raises:
ValueError: If `steps <= 0`.
@@ -436,9 +470,7 @@ class Estimator(object):
output_dir=self.eval_dir(name))
with ops.Graph().as_default():
- # TODO(priyag): Support distributed eval on TPUs.
- if (self._eval_distribution
- and self._eval_distribution.__class__.__name__ != 'TPUStrategy'):
+ if self._eval_distribution:
with self._eval_distribution.scope():
return _evaluate()
else:
@@ -462,33 +494,34 @@ class Estimator(object):
Args:
input_fn: A function that constructs the features. Prediction continues
- until `input_fn` raises an end-of-input exception (`OutOfRangeError` or
- `StopIteration`).
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
+ until `input_fn` raises an end-of-input exception
+ (`tf.errors.OutOfRangeError` or `StopIteration`).
+ See [Premade Estimators](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
the following:
- * A 'tf.data.Dataset' object: Outputs of `Dataset` object must have
+ * A `tf.data.Dataset` object: Outputs of `Dataset` object must have
same constraints as below.
- * features: A `Tensor` or a dictionary of string feature name to
+ * features: A `tf.Tensor` or a dictionary of string feature name to
`Tensor`. features are consumed by `model_fn`. They should satisfy
the expectation of `model_fn` from inputs.
* A tuple, in which case the first item is extracted as features.
predict_keys: list of `str`, name of the keys to predict. It is used if
- the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used
- then rest of the predictions will be filtered from the dictionary. If
- `None`, returns all.
- hooks: List of `SessionRunHook` subclass instances. Used for callbacks
- inside the prediction call.
+ the `tf.estimator.EstimatorSpec.predictions` is a `dict`. If
+ `predict_keys` is used then rest of the predictions will be filtered
+ from the dictionary. If `None`, returns all.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the prediction call.
checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
latest checkpoint in `model_dir` is used. If there are no checkpoints
in `model_dir`, prediction is run with newly initialized `Variables`
- instead of restored from checkpoint.
- yield_single_examples: If False, yield the whole batch as returned by the
- `model_fn` instead of decomposing the batch into individual elements.
- This is useful if `model_fn` returns some tensors whose first dimension
- is not equal to the batch size.
+ instead of ones restored from checkpoint.
+ yield_single_examples: If `False`, yields the whole batch as returned by
+ the `model_fn` instead of decomposing the batch into individual
+ elements. This is useful if `model_fn` returns some tensors whose first
+ dimension is not equal to the batch size.
Yields:
Evaluated values of `predictions` tensors.
@@ -496,10 +529,10 @@ class Estimator(object):
Raises:
ValueError: Could not find a trained model in `model_dir`.
ValueError: If batch length of predictions is not the same and
- `yield_single_examples` is True.
+ `yield_single_examples` is `True`.
ValueError: If there is a conflict between `predict_keys` and
`predictions`. For example if `predict_keys` is not `None` but
- `EstimatorSpec.predictions` is not a `dict`.
+ `tf.estimator.EstimatorSpec.predictions` is not a `dict`.
"""
with context.graph_mode():
hooks = _check_hooks_type(hooks)
@@ -554,14 +587,10 @@ class Estimator(object):
return
allowed_overrides = set([
- '_call_input_fn', '_call_model_fn',
- '_convert_train_steps_to_hooks', '_convert_eval_steps_to_hooks',
- '_create_global_step', '_create_and_assert_global_step',
+ '_create_and_assert_global_step',
'_tf_api_names', '_tf_api_names_v1', '_estimator_api_names',
'_estimator_api_names_v1', '_estimator_api_constants',
'_estimator_api_constants_v1',
- '_validate_features_in_predict_input',
- '_add_meta_graph_for_mode'
])
estimator_members = set([m for m in Estimator.__dict__.keys()
if not m.startswith('__')])
@@ -581,31 +610,66 @@ class Estimator(object):
as_text=False,
checkpoint_path=None,
strip_default_attrs=False):
+ # pylint: disable=line-too-long,g-doc-args,g-doc-return-or-yield
+ """Exports inference graph as a `SavedModel` into the given dir.
+
+ Note that `export_to_savedmodel` will be renamed to `export_to_saved_model`
+ in TensorFlow 2.0. At that time, `export_to_savedmodel` without the
+ additional underscore will be available only through tf.compat.v1.
+
+ Please see `tf.estimator.Estimator.export_saved_model` for more information.
+
+ There is one additional arg versus the new method:
+ strip_default_attrs: This parameter is going away in TF 2.0, and
+ the new behavior will automatically strip all default attributes.
+ Boolean. If `True`, default-valued attributes will be
+ removed from the `NodeDef`s. For a detailed guide, see [Stripping
+ Default-Valued Attributes](
+ https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ """
+ # pylint: enable=line-too-long,g-doc-args,g-doc-return-or-yield
+ return self._export_saved_model_for_mode(
+ export_dir_base,
+ serving_input_receiver_fn,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ checkpoint_path=checkpoint_path,
+ strip_default_attrs=strip_default_attrs,
+ mode=model_fn_lib.ModeKeys.PREDICT)
+
+ def export_saved_model(
+ self, export_dir_base, serving_input_receiver_fn,
+ assets_extra=None,
+ as_text=False,
+ checkpoint_path=None):
# pylint: disable=line-too-long
- """Exports inference graph as a SavedModel into given dir.
+ """Exports inference graph as a `SavedModel` into the given dir.
For a detailed guide, see
- @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}.
+ [Using SavedModel with Estimators](https://tensorflow.org/guide/saved_model#using_savedmodel_with_estimators).
This method builds a new graph by first calling the
- serving_input_receiver_fn to obtain feature `Tensor`s, and then calling
- this `Estimator`'s model_fn to generate the model graph based on those
+ `serving_input_receiver_fn` to obtain feature `Tensor`s, and then calling
+ this `Estimator`'s `model_fn` to generate the model graph based on those
features. It restores the given checkpoint (or, lacking that, the most
recent checkpoint) into this graph in a fresh session. Finally it creates
- a timestamped export directory below the given export_dir_base, and writes
- a `SavedModel` into it containing a single `MetaGraphDef` saved from this
+ a timestamped export directory below the given `export_dir_base`, and writes
+ a `SavedModel` into it containing a single `tf.MetaGraphDef` saved from this
session.
The exported `MetaGraphDef` will provide one `SignatureDef` for each
- element of the export_outputs dict returned from the model_fn, named using
+ element of the `export_outputs` dict returned from the `model_fn`, named
+ using
the same keys. One of these keys is always
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
+ `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`,
+ indicating which
signature will be served when a serving request does not specify one.
For each signature, the outputs are provided by the corresponding
- `ExportOutput`s, and the inputs are always the input receivers provided by
- the serving_input_receiver_fn.
+ `tf.estimator.export.ExportOutput`s, and the inputs are always the input
+ receivers provided by
+ the `serving_input_receiver_fn`.
- Extra assets may be written into the SavedModel via the assets_extra
+ Extra assets may be written into the `SavedModel` via the `assets_extra`
argument. This should be a dict, where each key gives a destination path
(including the filename) relative to the assets.extra directory. The
corresponding value gives the full path of the source file to be copied.
@@ -614,34 +678,35 @@ class Estimator(object):
Args:
export_dir_base: A string containing a directory in which to create
- timestamped subdirectories containing exported SavedModels.
- serving_input_receiver_fn: A function that takes no argument and
- returns a `ServingInputReceiver` or `TensorServingInputReceiver`.
+ timestamped subdirectories containing exported `SavedModel`s.
+ serving_input_receiver_fn: A function that takes no argument and returns a
+ `tf.estimator.export.ServingInputReceiver` or
+ `tf.estimator.export.TensorServingInputReceiver`.
assets_extra: A dict specifying how to populate the assets.extra directory
- within the exported SavedModel, or `None` if no extra assets are needed.
- as_text: whether to write the SavedModel proto in text format.
+ within the exported `SavedModel`, or `None` if no extra assets are
+ needed.
+ as_text: whether to write the `SavedModel` proto in text format.
checkpoint_path: The checkpoint path to export. If `None` (the default),
the most recent checkpoint found within the model directory is chosen.
- strip_default_attrs: Boolean. If `True`, default-valued attributes will be
- removed from the NodeDefs. For a detailed guide, see
- [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
The string path to the exported directory.
Raises:
- ValueError: if no serving_input_receiver_fn is provided, no export_outputs
- are provided, or no checkpoint can be found.
+ ValueError: if no `serving_input_receiver_fn` is provided, no
+ `export_outputs` are provided, or no checkpoint can be found.
"""
# pylint: enable=line-too-long
- return self._export_saved_model_for_mode(
+ # TODO(b/111442174): `export_to_savedmodel` will be renamed to
+ # `export_to_saved_model` in TensorFlow 2.0. This function is a wrapper
+ # while staging the new version; do not add any logic here.
+ return self.export_savedmodel(
export_dir_base,
serving_input_receiver_fn,
assets_extra=assets_extra,
as_text=as_text,
checkpoint_path=checkpoint_path,
- strip_default_attrs=strip_default_attrs,
- mode=model_fn_lib.ModeKeys.PREDICT)
+ strip_default_attrs=True)
def _export_saved_model_for_mode(
self, export_dir_base, input_receiver_fn,
@@ -651,35 +716,37 @@ class Estimator(object):
strip_default_attrs=False,
mode=model_fn_lib.ModeKeys.PREDICT):
# pylint: disable=line-too-long
- """Exports a single train/eval/predict graph as a SavedModel.
+ """Exports a single train/eval/predict graph as a `SavedModel`.
- This method is a wrapper for _export_all_saved_models, and wraps a raw
- input_receiver_fn in a dictionary to pass in to that function.
- See _export_all_saved_models for full docs.
+ This method is a wrapper for `_export_all_saved_models`, and wraps a raw
+ `input_receiver_fn` in a dictionary to pass in to that function.
+ See `_export_all_saved_models` for full docs.
- See tf.contrib.estimator.export_saved_model_for_mode for the currently
+ See `tf.contrib.estimator.export_saved_model_for_mode` for the currently
exposed version of this function.
Args:
export_dir_base: A string containing a directory in which to create
- timestamped subdirectories containing exported SavedModels.
- input_receiver_fn: a function that takes no argument and
- returns the appropriate subclass of `InputReceiver`.
+ timestamped subdirectories containing exported `SavedModel`s.
+ input_receiver_fn: a function that takes no argument and returns the
+ appropriate subclass of `InputReceiver`.
assets_extra: A dict specifying how to populate the assets.extra directory
- within the exported SavedModel, or `None` if no extra assets are needed.
- as_text: whether to write the SavedModel proto in text format.
+ within the exported `SavedModel`, or `None` if no extra assets are
+ needed.
+ as_text: whether to write the `SavedModel` proto in text format.
checkpoint_path: The checkpoint path to export. If `None` (the default),
the most recent checkpoint found within the model directory is chosen.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
- removed from the NodeDefs. For a detailed guide, see
- [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
- mode: tf.estimator.ModeKeys value indicating with mode will be exported.
+ removed from the `NodeDef`s. For a detailed guide, see [Stripping
+ Default-Valued
+ Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ mode: `tf.estimator.ModeKeys` value indicating with mode will be exported.
Returns:
The string path to the exported directory.
Raises:
- ValueError: if input_receiver_fn is None, no export_outputs
+ ValueError: if `input_receiver_fn` is `None`, no `export_outputs`
are provided, or no checkpoint can be found.
"""
# pylint: enable=line-too-long
@@ -703,40 +770,46 @@ class Estimator(object):
checkpoint_path=None,
strip_default_attrs=False):
# pylint: disable=line-too-long
- """Exports a SavedModel containing MetaGraphDefs for each requested mode.
+ """Exports a `SavedModel` containing `tf.MetaGraphDefs` for each requested mode.
- See tf.contrib.estimator.export_all_saved_models for the currently
+ See `tf.contrib.estimator.export_all_saved_models` for the currently
exposed version of this function.
- For each mode passed in via the input_receiver_fn_map,
- this method builds a new graph by calling the input_receiver_fn to obtain
+ For each mode passed in via the `input_receiver_fn_map`,
+ this method builds a new graph by calling the `input_receiver_fn` to obtain
feature and label `Tensor`s. Next, this method calls the `Estimator`'s
- model_fn in the passed mode to generate the model graph based on
+ `model_fn` in the passed mode to generate the model graph based on
those features and labels, and restores the given checkpoint
(or, lacking that, the most recent checkpoint) into the graph.
- Only one of the modes is used for saving variables to the SavedModel
- (order of preference: TRAIN, EVAL, then PREDICT), such that up to three
- MetaGraphDefs are saved with a single set of variables in a single
- SavedModel directory.
-
- For the variables and MetaGraphDefs, a timestamped export directory below
- export_dir_base, and writes a `SavedModel` into it containing
- the `MetaGraphDef` for the given mode and its associated signatures.
+ Only one of the modes is used for saving variables to the `SavedModel`
+ (order of preference: @{tf.estimator.ModeKeys#TRAIN$TRAIN},
+ @{tf.estimator.ModeKeys#EVAL$EVAL}, then
+ @{tf.estimator.ModeKeys#PREDICT$PREDICT}), such that up to three
+ `tf.MetaGraphDefs` are saved with a single set of variables in a single
+ `SavedModel` directory.
+
+ For the variables and `tf.MetaGraphDefs`, a timestamped export directory
+ below
+ `export_dir_base`, and writes a `SavedModel` into it containing
+ the `tf.MetaGraphDef` for the given mode and its associated signatures.
For prediction, the exported `MetaGraphDef` will provide one `SignatureDef`
- for each element of the export_outputs dict returned from the model_fn,
+ for each element of the `export_outputs` dict returned from the `model_fn`,
named using the same keys. One of these keys is always
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, indicating which
+ `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`,
+ indicating which
signature will be served when a serving request does not specify one.
For each signature, the outputs are provided by the corresponding
- `ExportOutput`s, and the inputs are always the input receivers provided by
- the serving_input_receiver_fn.
+ `tf.estimator.export.ExportOutput`s, and the inputs are always the input
+ receivers provided by
+ the `serving_input_receiver_fn`.
- For training and evaluation, the train_op is stored in an extra collection,
- and loss, metrics, and predictions are included in a SignatureDef for the
+ For training and evaluation, the `train_op` is stored in an extra
+ collection,
+ and loss, metrics, and predictions are included in a `SignatureDef` for the
mode in question.
- Extra assets may be written into the SavedModel via the assets_extra
+ Extra assets may be written into the `SavedModel` via the `assets_extra`
argument. This should be a dict, where each key gives a destination path
(including the filename) relative to the assets.extra directory. The
corresponding value gives the full path of the source file to be copied.
@@ -745,25 +818,28 @@ class Estimator(object):
Args:
export_dir_base: A string containing a directory in which to create
- timestamped subdirectories containing exported SavedModels.
- input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn
- mappings, where the input_receiver_fn is a function that takes no
- argument and returns the appropriate subclass of `InputReceiver`.
+ timestamped subdirectories containing exported `SavedModel`s.
+ input_receiver_fn_map: dict of `tf.estimator.ModeKeys` to
+ `input_receiver_fn` mappings, where the `input_receiver_fn` is a
+ function that takes no arguments and returns the appropriate subclass of
+ `InputReceiver`.
assets_extra: A dict specifying how to populate the assets.extra directory
- within the exported SavedModel, or `None` if no extra assets are needed.
- as_text: whether to write the SavedModel proto in text format.
+ within the exported `SavedModel`, or `None` if no extra assets are
+ needed.
+ as_text: whether to write the `SavedModel` proto in text format.
checkpoint_path: The checkpoint path to export. If `None` (the default),
the most recent checkpoint found within the model directory is chosen.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
- removed from the NodeDefs. For a detailed guide, see
- [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ removed from the `NodeDef`s. For a detailed guide, see [Stripping
+ Default-Valued
+ Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
- A dict of tf.estimator.ModeKeys value to string path for each exported
+ A dict of `tf.estimator.ModeKeys` value to string path for each exported
directory.
Raises:
- ValueError: if any input_receiver_fn is None, no export_outputs
+ ValueError: if any `input_receiver_fn` is `None`, no `export_outputs`
are provided, or no checkpoint can be found.
"""
# pylint: enable=line-too-long
@@ -836,25 +912,29 @@ class Estimator(object):
export_tags=None,
check_variables=True):
# pylint: disable=line-too-long
- """Loads variables and adds them along with a MetaGraphDef for saving.
+ """Loads variables and adds them along with a `tf.MetaGraphDef` for saving.
Args:
- builder: instance of SavedModelBuilder that will be used for saving.
- input_receiver_fn_map: dict of tf.estimator.ModeKeys to input_receiver_fn
- mappings, where the input_receiver_fn is a function that takes no
- argument and returns the appropriate subclass of `InputReceiver`.
+ builder: instance of `tf.saved_modle.builder.SavedModelBuilder` that will
+ be used for saving.
+ input_receiver_fn_map: dict of `tf.estimator.ModeKeys` to
+ `input_receiver_fn` mappings, where the `input_receiver_fn` is a
+ function that takes no argument and returns the appropriate subclass of
+ `InputReceiver`.
checkpoint_path: The checkpoint path to export. If `None` (the default),
the most recent checkpoint found within the model directory is chosen.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
- removed from the NodeDefs. For a detailed guide, see
- [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
- save_variables: bool, whether variables should be saved. If False, just
- the MetaGraphDef will be saved. Note that save_variables should only be
- True for the first call to this function, and the SavedModelBuilder will
- raise an error if that is not the case.
- mode: tf.estimator.ModeKeys value indicating which mode will be exported.
- export_tags: The set of tags with which to save `MetaGraphDef`. If None,
- a default set will be selected to matched the passed mode.
+ removed from the `NodeDef`s. For a detailed guide, see [Stripping
+ Default-Valued
+ Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ save_variables: bool, whether variables should be saved. If `False`, just
+ the `tf.MetaGraphDef` will be saved. Note that `save_variables` should
+ only be `True` for the first call to this function, and the
+ `SavedModelBuilder` will raise an error if that is not the case.
+ mode: `tf.estimator.ModeKeys` value indicating which mode will be
+ exported.
+ export_tags: The set of tags with which to save `tf.MetaGraphDef`. If
+ `None`, a default set will be selected to matched the passed mode.
check_variables: bool, whether to check the checkpoint has all variables.
Raises:
@@ -936,21 +1016,23 @@ class Estimator(object):
builder.add_meta_graph(**meta_graph_kwargs)
def _get_export_outputs_for_spec(self, estimator_spec):
- """Given an EstimatorSpec, determine what our export outputs should be.
+ """Given an `EstimatorSpec`, determine what our export outputs should be.
- EstimatorSpecs contain export_outputs that are used for serving, but for
+ `EstimatorSpecs` contains `export_outputs` that are used for serving, but
+ for
training and eval graphs, we must wrap the tensors of interest in
- appropriate ExportOutput objects.
+ appropriate `tf.estimator.export.ExportOutput` objects.
Args:
- estimator_spec: EstimatorSpec object that will be exported.
+ estimator_spec: `tf.estimator.EstimatorSpec` object that will be exported.
Returns:
- a dict mapping export_output_name to ExportOutput object.
+ a dict mapping `export_output_name` to `tf.estimator.export.ExportOutput`
+ object.
Raises:
- ValueError: if an appropriate ExportOutput cannot be found for the
- passed EstimatorSpec.mode
+ ValueError: if an appropriate `ExportOutput` cannot be found for the
+ passed `EstimatorSpec.mode`
"""
mode = estimator_spec.mode
if mode == model_fn_lib.ModeKeys.PREDICT:
@@ -985,16 +1067,21 @@ class Estimator(object):
'QueueRunner. That means predict yields forever. '
'This is probably a mistake.')
- def _get_features_and_labels_from_input_fn(self, input_fn, mode,
- distribution=None):
- """Extracts the `features` and labels from return values of `input_fn`."""
- if distribution is not None and mode == model_fn_lib.ModeKeys.TRAIN:
+ def _get_iterator_from_input_fn(self, input_fn, mode, distribution=None):
+ if distribution is not None:
result = distribution.distribute_dataset(
lambda: self._call_input_fn(input_fn, mode))
else:
result = self._call_input_fn(input_fn, mode)
- return estimator_util.parse_input_fn_result(result)
+ iterator = result.make_initializable_iterator()
+ input_hooks = [estimator_util._DatasetInitializerHook(iterator)] # pylint: disable=protected-access
+ return iterator, input_hooks
+
+ def _get_features_and_labels_from_input_fn(self, input_fn, mode):
+ """Extracts the `features` and labels from return values of `input_fn`."""
+ return estimator_util.parse_input_fn_result(
+ self._call_input_fn(input_fn, mode))
def _extract_batch_length(self, preds_evaluated):
"""Extracts batch length of predictions."""
@@ -1027,13 +1114,13 @@ class Estimator(object):
"""Creates the global step tensor in graph.
The global step tensor must be an integer type with name 'global_step' and
- be added to the collection @{tf.GraphKeys.GLOBAL_STEP}.
+ be added to the collection @{tf.GraphKeys#GLOBAL_STEP$GLOBAL_STEP}.
Args:
graph: The graph in which to create the global step tensor.
Returns:
- The global step `Tensor`.
+ The global step `tf.Tensor`.
"""
return training.create_global_step(graph)
@@ -1044,7 +1131,7 @@ class Estimator(object):
graph: The graph in which to create the global step tensor.
Returns:
- The global step `Tensor`.
+ The global step `tf.Tensor`.
"""
step = self._create_global_step(graph)
assert step == training.get_global_step()
@@ -1056,21 +1143,21 @@ class Estimator(object):
Args:
input_fn: The input function.
- mode: ModeKeys
+ mode: `tf.estimator.ModeKeys`
Returns:
- The return value of the passed input_fn, which should be one of:
+ The return value of the passed `input_fn`, which should be one of:
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
- tuple (features, labels) with same constraints as below.
- * A tuple (features, labels): Where `features` is a `Tensor` or a
+ tuple `(features, labels)` with same constraints as below.
+ * A tuple `(features, labels)`: Where `features` is a `Tensor` or a
dictionary of string feature name to `Tensor` and `labels` is a
`Tensor` or a dictionary of string label name to `Tensor`. Both
`features` and `labels` are consumed by `model_fn`. They should
satisfy the expectation of `model_fn` from inputs.
Raises:
- ValueError: if input_fn takes invalid arguments.
+ ValueError: if `input_fn` takes invalid arguments.
"""
input_fn_args = function_utils.fn_args(input_fn)
kwargs = {}
@@ -1089,14 +1176,14 @@ class Estimator(object):
Args:
features: features dict.
labels: labels dict.
- mode: ModeKeys
- config: RunConfig
+ mode: `tf.estimator.ModeKeys`
+ config: `tf.estimator.RunConfig`
Returns:
- An `EstimatorSpec` object.
+ An `tf.estimator.EstimatorSpec` object.
Raises:
- ValueError: if model_fn returns invalid objects.
+ ValueError: if `model_fn` returns invalid objects.
"""
model_fn_args = function_utils.fn_args(self._model_fn)
kwargs = {}
@@ -1129,14 +1216,14 @@ class Estimator(object):
return self._train_model_default(input_fn, hooks, saving_listeners)
def _train_model_default(self, input_fn, hooks, saving_listeners):
- """Initiate training with input_fn, without DistributionStrategies.
+ """Initiate training with `input_fn`, without `DistributionStrategies`.
Args:
input_fn: A function that provides input data for training as minibatches.
- hooks: List of `SessionRunHook` subclass instances. Used for callbacks
- inside the training loop.
- saving_listeners: list of `CheckpointSaverListener` objects. Used for
- callbacks that run immediately before or after checkpoint savings.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the training loop.
+ saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
+ for callbacks that run immediately before or after checkpoint savings.
Returns:
Loss from training
@@ -1163,14 +1250,14 @@ class Estimator(object):
saving_listeners)
def _train_model_distributed(self, input_fn, hooks, saving_listeners):
- """Initiate training with input_fn, using DistributionStrategies.
+ """Initiate training with `input_fn`, using `DistributionStrategies`.
Args:
input_fn: A function that provides input data for training as minibatches.
- hooks: List of `SessionRunHook` subclass instances. Used for callbacks
- inside the training loop.
- saving_listeners: list of `CheckpointSaverListener` objects. Used for
- callbacks that run immediately before or after checkpoint savings.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the training loop.
+ saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used
+ for callbacks that run immediately before or after checkpoint savings.
Returns:
Loss from training
@@ -1184,101 +1271,87 @@ class Estimator(object):
worker_hooks = []
with ops.Graph().as_default() as g:
+ # We want to create the iterations variable outside the distribution scope
+ # as that is just stored on the host and mainly used to drive the loop
+ # and doesn't need to be a Mirrored/Device variable.
+ if is_tpu_strategy:
+ steps_per_run_variable = training.get_or_create_steps_per_run_variable()
with self._train_distribution.scope():
random_seed.set_random_seed(self._config.tf_random_seed)
+ iterator, input_hooks = self._get_iterator_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.TRAIN, self._train_distribution)
+ worker_hooks.extend(input_hooks)
+ global_step_tensor = self._create_and_assert_global_step(g)
+ # we want to add to the global collection in the main thread not the
+ # tower threads.
+ ops.add_to_collection(
+ training_util.GLOBAL_STEP_READ_KEY,
+ self._train_distribution.read_var(global_step_tensor))
if is_tpu_strategy:
- # Create the iterator for run_on_dataset function
- # TODO(sourabhbajaj): refactor this out to call a function on the
- # strategy
- dataset = self._train_distribution.distribute_dataset(
- lambda: self._call_input_fn(input_fn, # pylint: disable=g-long-lambda
- model_fn_lib.ModeKeys.TRAIN))
- iterator = dataset.make_initializable_iterator()
- worker_hooks.append(
- estimator_util._DatasetInitializerHook(iterator)) # pylint: disable=protected-access
-
- global_step_tensor = self._create_and_assert_global_step(g)
- # we want to add to the global collection in the main thread not the
- # tower threads.
- ops.add_to_collection(
- training_util.GLOBAL_STEP_READ_KEY,
- self._train_distribution.read_var(global_step_tensor))
-
# Create a step_fn from the train_op of grouped_estimator_spec
- def step_fn(ctx, inputs):
+ def step_fn(ctx, features, labels=None):
"""A single step that is passed to run_on_dataset."""
- features, labels = inputs
estimator_spec = self._train_distribution.call_for_each_tower(
self._call_model_fn,
features,
labels,
model_fn_lib.ModeKeys.TRAIN,
self.config)
- ctx.last_step_outputs = estimator_spec.loss
- ctx.non_tensor_outputs = {'estimator_spec': estimator_spec}
- with ops.control_dependencies([estimator_spec.train_op]):
- return array_ops.identity(estimator_spec.loss)
+ ctx.set_last_step_output(
+ name='loss',
+ output=estimator_spec.loss,
+ aggregation=distribute_lib.get_loss_reduction())
+ ctx.set_non_tensor_output(
+ name='estimator_spec', output=estimator_spec)
+ return estimator_spec.train_op
# Create new train_op post graph rewrites
- # TODO(sourabhbajaj): Make sure train_steps and tpu_iterations
- # work correctly. Currently hardcoded at 2
initial_training_loss = constant_op.constant(1e7)
- distributed_train_op, tpu_result, ctx = \
- self._train_distribution._run_steps_on_dataset( # pylint: disable=protected-access
- step_fn, iterator, iterations=2,
- initial_loop_values=initial_training_loss)
+ ctx = self._train_distribution.run_steps_on_dataset(
+ step_fn, iterator, iterations=steps_per_run_variable,
+ initial_loop_values={'loss': initial_training_loss})
+ distributed_train_op = ctx.run_op
+ loss = ctx.last_step_outputs['loss']
grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
else:
- features, labels, input_hooks = (
- self._get_features_and_labels_from_input_fn(
- input_fn, model_fn_lib.ModeKeys.TRAIN,
- self._train_distribution))
- worker_hooks.extend(input_hooks)
- global_step_tensor = self._create_and_assert_global_step(g)
- # we want to add to the global collection in the main thread not the
- # tower threads.
- ops.add_to_collection(
- training_util.GLOBAL_STEP_READ_KEY,
- self._train_distribution.read_var(global_step_tensor))
+ features, labels = estimator_util.parse_iterator_result(
+ iterator.get_next())
grouped_estimator_spec = self._train_distribution.call_for_each_tower(
self._call_model_fn,
features,
labels, # although this will be None it seems
model_fn_lib.ModeKeys.TRAIN,
self.config)
+ loss = self._train_distribution.unwrap(
+ self._train_distribution.reduce(
+ distribute_lib.get_loss_reduction(),
+ grouped_estimator_spec.loss,
+ destinations='/device:CPU:0'))[0]
+ distributed_train_op = grouped_estimator_spec.train_op
scaffold = _combine_distributed_scaffold(
grouped_estimator_spec.scaffold, self._train_distribution)
+ # TODO(yuefengz): add a test for unwrapping per_device_hooks.
def get_hooks_from_the_first_device(per_device_hooks):
- hooks_list = self._train_distribution.unwrap(per_device_hooks)
- assert hooks_list
- return hooks_list[0]
+ return [
+ self._distribution.unwrap(per_device_hook)[0]
+ for per_device_hook in per_device_hooks
+ ]
training_hooks = get_hooks_from_the_first_device(
grouped_estimator_spec.training_hooks)
training_chief_hooks = get_hooks_from_the_first_device(
grouped_estimator_spec.training_chief_hooks)
-
- # TODO(sourabhbajaj): Merge the two code paths and clean up the code
- if is_tpu_strategy:
- distributed_loss = tpu_result
- worker_hooks.append(
- estimator_util.StrategyInitFinalizeHook(
- self._train_distribution.get_initialization_ops,
- self._train_distribution.get_finalize_ops))
- else:
- distributed_loss = grouped_estimator_spec.loss
- distributed_train_op = grouped_estimator_spec.train_op
+ worker_hooks.append(
+ estimator_util.StrategyInitFinalizeHook(
+ self._train_distribution.initialize,
+ self._train_distribution.finalize))
estimator_spec = model_fn_lib.EstimatorSpec(
mode=grouped_estimator_spec.mode,
- loss=self._train_distribution.unwrap(
- self._train_distribution.reduce(
- distribute_lib.get_loss_reduction(),
- distributed_loss,
- destinations='/device:CPU:0'))[0],
+ loss=loss,
train_op=self._train_distribution.group(distributed_train_op),
training_hooks=training_hooks,
training_chief_hooks=training_chief_hooks,
@@ -1375,31 +1448,18 @@ class Estimator(object):
"""Builds the graph and related hooks to run evaluation."""
random_seed.set_random_seed(self._config.tf_random_seed)
self._create_and_assert_global_step(ops.get_default_graph())
- features, labels, input_hooks = (
- self._get_features_and_labels_from_input_fn(
- input_fn, model_fn_lib.ModeKeys.EVAL, self._eval_distribution))
if self._eval_distribution:
- (loss_metric, scaffold, evaluation_hooks, eval_metric_ops) = (
- self._call_model_fn_eval_distributed(features, labels, self.config))
+ (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = (
+ self._call_model_fn_eval_distributed(input_fn, self.config))
else:
- (loss_metric, scaffold, evaluation_hooks, eval_metric_ops) = (
- self._call_model_fn_eval(features, labels, self.config))
+ (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = (
+ self._call_model_fn_eval(input_fn, self.config))
global_step_tensor = training_util.get_global_step(ops.get_default_graph())
# Call to warm_start has to be after model_fn is called.
self._maybe_warm_start(checkpoint_path)
- if model_fn_lib.LOSS_METRIC_KEY in eval_metric_ops:
- raise ValueError(
- 'Metric with name "%s" is not allowed, because Estimator ' %
- (model_fn_lib.LOSS_METRIC_KEY) +
- 'already defines a default metric with the same name.')
- eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric
-
- update_op, eval_dict = _extract_metric_update_ops(eval_metric_ops,
- self._eval_distribution)
-
if ops.GraphKeys.GLOBAL_STEP in eval_dict:
raise ValueError(
'Metric with name `global_step` is not allowed, because Estimator '
@@ -1424,26 +1484,71 @@ class Estimator(object):
return scaffold, update_op, eval_dict, all_hooks
- def _call_model_fn_eval(self, features, labels, config):
+ def _call_model_fn_eval(self, input_fn, config):
+ """Call model_fn for evaluation and handle return values."""
+ features, labels, input_hooks = self._get_features_and_labels_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.EVAL)
+
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.EVAL, config)
- loss_metric = metrics_lib.mean(estimator_spec.loss)
- return (loss_metric, estimator_spec.scaffold,
- estimator_spec.evaluation_hooks, estimator_spec.eval_metric_ops)
+ eval_metric_ops = _verify_and_create_loss_metric(
+ estimator_spec.eval_metric_ops, estimator_spec.loss)
+ update_op, eval_dict = _extract_metric_update_ops(eval_metric_ops)
+ return (estimator_spec.scaffold, estimator_spec.evaluation_hooks,
+ input_hooks, update_op, eval_dict)
- def _call_model_fn_eval_distributed(self, features, labels, config):
+ def _call_model_fn_eval_distributed(self, input_fn, config):
"""Call model_fn in distribution mode and handle return values."""
- grouped_estimator_spec = self._eval_distribution.call_for_each_tower(
- self._call_model_fn, features, labels,
- model_fn_lib.ModeKeys.EVAL, config)
+
+ iterator, input_hooks = self._get_iterator_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.EVAL, self._eval_distribution)
+
+ is_tpu_strategy = (
+ self._eval_distribution.__class__.__name__ == 'TPUStrategy')
+
+ if is_tpu_strategy:
+ def step_fn(ctx, features, labels=None):
+ """Runs one step of the eval computation and captures outputs."""
+ estimator_spec = self._eval_distribution.call_for_each_tower(
+ self._call_model_fn, features, labels, model_fn_lib.ModeKeys.EVAL,
+ config)
+ eval_metric_ops = _verify_and_create_loss_metric(
+ estimator_spec.eval_metric_ops, estimator_spec.loss,
+ self._eval_distribution)
+ update_op, eval_dict = _extract_metric_update_ops(
+ eval_metric_ops, self._eval_distribution)
+ ctx.set_non_tensor_output(name='estimator_spec', output=estimator_spec)
+ ctx.set_non_tensor_output(name='eval_dict', output=eval_dict)
+ return update_op
+
+ # TODO(priyag): Fix eval step hook to account for steps_per_run.
+ ctx = self._eval_distribution.run_steps_on_dataset(
+ step_fn, iterator, iterations=self._eval_distribution.steps_per_run)
+ update_op = ctx.run_op
+ eval_dict = ctx.non_tensor_outputs['eval_dict']
+ grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
+ else:
+ features, labels = estimator_util.parse_iterator_result(
+ iterator.get_next())
+ grouped_estimator_spec = self._eval_distribution.call_for_each_tower(
+ self._call_model_fn, features, labels,
+ model_fn_lib.ModeKeys.EVAL, config)
+ eval_metric_ops = _verify_and_create_loss_metric(
+ grouped_estimator_spec.eval_metric_ops, grouped_estimator_spec.loss,
+ self._eval_distribution)
+ update_op, eval_dict = _extract_metric_update_ops(
+ eval_metric_ops, self._eval_distribution)
+
scaffold = _combine_distributed_scaffold(
grouped_estimator_spec.scaffold, self._eval_distribution)
evaluation_hooks = self._eval_distribution.unwrap(
grouped_estimator_spec.evaluation_hooks)[0]
- loss_metric = self._eval_distribution.call_for_each_tower(
- metrics_lib.mean, grouped_estimator_spec.loss)
- return (loss_metric, scaffold,
- evaluation_hooks, grouped_estimator_spec.eval_metric_ops)
+ evaluation_hooks = evaluation_hooks + (
+ estimator_util.StrategyInitFinalizeHook(
+ self._eval_distribution.initialize,
+ self._eval_distribution.finalize),)
+
+ return (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict)
def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict,
all_hooks, output_dir):
@@ -1479,6 +1584,23 @@ class Estimator(object):
warm_starting_util.warm_start(*self._warm_start_settings)
+def _verify_and_create_loss_metric(eval_metric_ops, loss, distribution=None):
+ """Creates a metric for loss and throws an error if one already exists."""
+ if model_fn_lib.LOSS_METRIC_KEY in eval_metric_ops:
+ raise ValueError(
+ 'Metric with name "%s" is not allowed, because Estimator ' %
+ (model_fn_lib.LOSS_METRIC_KEY) +
+ 'already defines a default metric with the same name.')
+
+ if distribution is None:
+ loss_metric = metrics_lib.mean(loss)
+ else:
+ loss_metric = distribution.call_for_each_tower(
+ metrics_lib.mean, loss)
+ eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric
+ return eval_metric_ops
+
+
def maybe_overwrite_model_dir_and_session_config(config, model_dir):
"""Overwrite estimator config by `model_dir` and `session_config` if needed.
@@ -1512,9 +1634,9 @@ def maybe_overwrite_model_dir_and_session_config(config, model_dir):
"`model_dir` are set both in constructor and `RunConfig`, but with "
"different values. In constructor: '{}', in `RunConfig`: "
"'{}' ".format(model_dir, config.model_dir))
- if model_dir:
- config = run_config.RunConfig.replace(config, model_dir=model_dir)
- if getattr(config, 'model_dir', None) is None:
+ if model_dir:
+ config = run_config.RunConfig.replace(config, model_dir=model_dir)
+ elif getattr(config, 'model_dir', None) is None:
model_dir = tempfile.mkdtemp()
logging.warning('Using temporary folder as model directory: %s', model_dir)
config = run_config.RunConfig.replace(config, model_dir=model_dir)
@@ -1523,7 +1645,7 @@ def maybe_overwrite_model_dir_and_session_config(config, model_dir):
def create_per_tower_ready_op(scaffold):
- """Create a Scaffold.ready_op inside a tower."""
+ """Create a `tf.train.Scaffold.ready_op` inside a tower."""
if scaffold.ready_op:
return scaffold.ready_op
@@ -1538,7 +1660,7 @@ def create_per_tower_ready_op(scaffold):
def create_per_tower_ready_for_local_init_op(scaffold):
- """Create a Scaffold.ready_for_local_init_op inside a tower."""
+ """Create a `tf.train.Scaffold.ready_for_local_init_op` inside a tower."""
if scaffold.ready_for_local_init_op:
return scaffold.ready_for_local_init_op
@@ -1636,7 +1758,7 @@ def _check_checkpoint_available(model_dir):
def _check_hooks_type(hooks):
- """Returns hooks if all are SessionRunHook, raises TypeError otherwise."""
+ """Returns hooks if all are `SessionRunHook`, raises TypeError otherwise."""
hooks = list(hooks or [])
for h in hooks:
if not isinstance(h, training.SessionRunHook):
@@ -1656,17 +1778,18 @@ def _check_listeners_type(saving_listeners):
def _get_replica_device_setter(config):
- """Creates a replica device setter if required as a default device_fn.
+ """Creates a replica device setter if required as a default `device_fn`.
- `Estimator` uses ReplicaDeviceSetter as a default device placer. It sets the
- distributed related arguments such as number of ps_replicas based on given
- config.
+ `Estimator` uses `tf.train.ReplicaDeviceSetter` as a default device placer. It
+ sets the
+ distributed related arguments such as number of `ps_replicas` based on given
+ `config`.
Args:
- config: A `RunConfig` instance.
+ config: A `tf.estimator.RunConfig` instance.
Returns:
- A replica device setter, or None.
+ A replica device setter, or `None`.
"""
if config.task_type:
worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)
@@ -1685,7 +1808,7 @@ def _get_replica_device_setter(config):
def _verify_model_fn_args(model_fn, params):
- """Verifies model fn arguments."""
+ """Verifies `model_fn` arguments."""
args = set(function_utils.fn_args(model_fn))
if 'features' not in args:
raise ValueError('model_fn (%s) must include features argument.' % model_fn)
@@ -1717,19 +1840,21 @@ def _extract_metric_update_ops(eval_dict, distribution=None):
update_ops = []
value_ops = {}
# Sort metrics lexicographically so graph is identical every time.
- for name, metric_ops in sorted(six.iteritems(eval_dict)):
- value_ops[name] = metric_ops[0]
- if distribution:
- update_op = distribution.group(metric_ops[1])
+ for name, value in sorted(six.iteritems(eval_dict)):
+ if isinstance(value, metrics.Metric):
+ metric_result = value.result()
+ # We expect only one update op for every metric when there is no
+ # distribution strategy.
+ metric_update = value.updates if distribution else value.updates[0]
else:
- update_op = metric_ops[1]
- update_ops.append(update_op)
+ metric_result = value[0]
+ metric_update = value[1]
- if update_ops:
- update_op = control_flow_ops.group(*update_ops)
- else:
- update_op = None
+ value_ops[name] = metric_result
+ update_ops.append(
+ distribution.group(metric_update) if distribution else metric_update)
+ update_op = control_flow_ops.group(*update_ops) if update_ops else None
return update_op, value_ops
@@ -1783,10 +1908,24 @@ def _write_dict_to_summary(output_dir,
logging.warn('Skipping summary for %s, cannot parse string to Summary.',
key)
continue
+ elif isinstance(dictionary[key], np.ndarray):
+ value = summary_proto.value.add()
+ value.tag = key
+ value.node_name = key
+ tensor_proto = tensor_util.make_tensor_proto(dictionary[key])
+ value.tensor.CopyFrom(tensor_proto)
+ # pylint: disable=line-too-long
+ logging.info(
+ 'Summary for np.ndarray is not visible in Tensorboard by default. '
+ 'Consider using a Tensorboard plugin for visualization (see '
+ 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md'
+ ' for more information).')
+ # pylint: enable=line-too-long
else:
logging.warn(
'Skipping summary for %s, must be a float, np.float32, np.int64, '
- 'np.int32 or int or a serialized string of Summary.', key)
+ 'np.int32 or int or np.ndarray or a serialized string of Summary.',
+ key)
summary_writer.add_summary(summary_proto, current_global_step)
summary_writer.flush()
@@ -1816,7 +1955,7 @@ def _write_checkpoint_path_to_summary(output_dir, checkpoint_path,
def _has_dataset_or_queue_runner(maybe_tensor):
- """Returns True if TF dataset or QueueRunner has been used."""
+ """Returns `True` if `Dataset` or `QueueRunner` has been used."""
# Check TF dataset first. Here, we use a simple algorithm to check the top
# level Tensors only, which should be sufficient for most users.
tensors = [x for x in nest.flatten(maybe_tensor) if isinstance(x, ops.Tensor)]
@@ -1839,9 +1978,9 @@ class WarmStartSettings(
'var_name_to_vocab_info',
'var_name_to_prev_var_name',
])):
- """Settings for warm-starting in Estimators.
+ """Settings for warm-starting in `tf.estimator.Estimators`.
- Example Use with canned `DNNEstimator`:
+ Example Use with canned `tf.estimator.DNNEstimator`:
```
emb_vocab_file = tf.feature_column.embedding_column(
@@ -1958,23 +2097,19 @@ class WarmStartSettings(
ckpt_to_initialize_from: [Required] A string specifying the directory with
checkpoint file(s) or path to checkpoint from which to warm-start the
model parameters.
- vars_to_warm_start: [Optional] One of the following:
-
- - A regular expression (string) that captures which variables to
- warm-start (see tf.get_collection). This expression will only consider
- variables in the TRAINABLE_VARIABLES collection.
- - A list of Variables to warm-start.
- - A list of strings, each representing a full variable name to warm-start.
- - `None`, in which case only variables specified in
- `var_name_to_vocab_info` will be warm-started.
-
- Defaults to `'.*'`, which warm-starts all variables in the
- TRAINABLE_VARIABLES collection. Note that this excludes variables such as
- accumulators and moving statistics from batch norm.
+ vars_to_warm_start: [Optional] One of the following: - A regular expression
+ (string) that captures which variables to warm-start (see
+ `tf.get_collection`). This expression will only consider variables in the
+ `TRAINABLE_VARIABLES` collection. - A list of Variables to warm-start. - A
+ list of strings, each representing a full variable name to warm-start. -
+ `None`, in which case only variables specified in `var_name_to_vocab_info`
+ will be warm-started. Defaults to `'.*'`, which warm-starts all variables
+ in the `TRAINABLE_VARIABLES` collection. Note that this excludes
+ variables such as accumulators and moving statistics from batch norm.
var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
- VocabInfo. The variable names should be "full" variables, not the names
- of the partitions. If not explicitly provided, the variable is assumed to
- have no vocabulary.
+ `tf.estimator.VocabInfo`. The variable names should be "full" variables,
+ not the names of the partitions. If not explicitly provided, the variable
+ is assumed to have no vocabulary.
var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
name of the previously-trained variable in `ckpt_to_initialize_from`. If
not explicitly provided, the name of the variable is assumed to be same
@@ -1999,43 +2134,45 @@ class WarmStartSettings(
def _get_saved_model_ckpt(saved_model_dir):
- """Return path to variables checkpoint in a SavedModel directory."""
+ """Return path to variables checkpoint in a `SavedModel` directory."""
if not gfile.Exists(
- os.path.join(compat.as_bytes(saved_model_dir),
- compat.as_bytes('variables/variables.index'))):
+ os.path.join(saved_model_utils.get_variables_dir(saved_model_dir),
+ compat.as_text('variables.index'))):
raise ValueError('Directory provided has an invalid SavedModel format: %s'
% saved_model_dir)
- return os.path.join(
- compat.as_bytes(saved_model_dir),
- compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY,
- constants.VARIABLES_FILENAME)))
+ return saved_model_utils.get_variables_path(saved_model_dir)
def _get_default_warm_start_settings(warm_start_from):
- """Returns default WarmStartSettings.
+ """Returns default `tf.estimator.WarmStartSettings`.
Args:
warm_start_from: Either a string representing the filepath of a checkpoint
- or SavedModel to initialize from, or an instance of WarmStartSettings.
+ or `SavedModel` to initialize from, or an instance of
+ `tf.estimator.WarmStartSettings`.
Returns:
- Either None or an instance of WarmStartSettings.
+ Either None or an instance of `WarmStartSettings`.
Raises:
- ValueError: If warm_start_from is not None but is neither a string nor an
- instance of WarmStartSettings.
+ ValueError: If `warm_start_from` is not `None` but is neither a string nor
+ an
+ instance of `WarmStartSettings`.
"""
if warm_start_from is None:
return None
if isinstance(warm_start_from, (six.string_types, six.binary_type)):
# Infer that this is a SavedModel if export_path +
# 'variables/variables.index' exists, and if so, construct the
- # WarmStartSettings pointing to export_path + 'variables/variables'.
- if gfile.Exists(os.path.join(compat.as_bytes(warm_start_from),
- compat.as_bytes('variables/variables.index'))):
+ # WarmStartSettings pointing to the variables path
+ # (export_path + 'variables/variables').
+ if gfile.Exists(os.path.join(
+ saved_model_utils.get_variables_dir(warm_start_from),
+ compat.as_text('variables.index'))):
logging.info('Warm-starting from a SavedModel')
return WarmStartSettings(
- ckpt_to_initialize_from=_get_saved_model_ckpt(warm_start_from))
+ ckpt_to_initialize_from=saved_model_utils.get_variables_path(
+ warm_start_from))
return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
elif isinstance(warm_start_from, WarmStartSettings):
return warm_start_from
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index e8552092e0..1ed5e30b0e 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -43,6 +43,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_util
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.layers import layers
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
@@ -58,6 +59,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
+from tensorflow.python.ops.random_ops import random_uniform
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
@@ -158,16 +160,7 @@ class EstimatorInheritanceConstraintTest(test.TestCase):
def __init__(self):
super(_Estimator, self).__init__(model_fn=dummy_model_fn)
- def _call_input_fn(self, input_fn, mode):
- return input_fn()
-
- def _create_global_step(self, graph):
- pass
-
- def _convert_train_steps_to_hooks(self, steps, max_steps):
- pass
-
- def _convert_eval_steps_to_hooks(self, steps):
+ def _tf_api_names(self):
pass
_Estimator()
@@ -473,6 +466,29 @@ class EstimatorTrainTest(test.TestCase):
est.train(InputFn(), steps=1)
self.assertEqual(1, input_fn_call_count[0])
+ def test_nested_input_fn(self):
+ expected_params = {'batch_size': 10}
+
+ def _input_fn():
+ dataset_features = dataset_ops.Dataset.from_tensor_slices(
+ (random_uniform([4]),
+ random_uniform([4, 100], maxval=100, dtype=dtypes.int32)))
+ dataset_labels = dataset_ops.Dataset.from_tensor_slices(
+ random_uniform([4, 10]))
+ dataset = dataset_ops.Dataset.zip((dataset_features, dataset_labels))
+ dataset = dataset.repeat(-1)
+ iterator = dataset.make_initializable_iterator()
+ return iterator.get_next()
+
+ def _model_fn(features, labels, mode, params, config):
+ del params, config
+ return model_fn_global_step_incrementer(features, labels, mode)
+
+ expected_config = run_config.RunConfig().replace(tf_random_seed=4321)
+ est = estimator.Estimator(
+ model_fn=_model_fn, params=expected_params, config=expected_config)
+ est.train(_input_fn, steps=4)
+
def test_input_fn_args(self):
expected_mode = model_fn_lib.ModeKeys.TRAIN
expected_params = {'batch_size': 10}
@@ -940,22 +956,44 @@ class EstimatorTrainTest(test.TestCase):
est = estimator.Estimator(model_fn=_model_fn)
est.train(dummy_input_fn, steps=1)
+ def test_config_should_not_be_evaluator_or_ps(self):
+
+ class FakeEvaluatorConfig(run_config.RunConfig):
+
+ @property
+ def task_type(self):
+ return run_config.TaskType.EVALUATOR
+
+ est = estimator.Estimator(
+ model_fn=dummy_model_fn, config=FakeEvaluatorConfig())
+ with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'):
+ est.train(dummy_input_fn, steps=1)
+
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
_, _ = features, labels
- metric_name = params.get('metric_name') or 'metric'
- metric_value = params.get('metric_value') or 2.
global_step = training.get_global_step()
loss = constant_op.constant(1.)
+ metric_name_1 = params.get('metric_name') or 'metric'
+ metric_value_1 = params.get('metric_value') or 2.
+ metric_name_2 = params.get('metric_name_2') or 'metric2'
+ metric_value_2 = params.get('metric_value_2') or 2.
+
metric_update_op = loss.op
metric_tensor = control_flow_ops.with_dependencies(
- [metric_update_op], constant_op.constant(metric_value))
+ [metric_update_op], constant_op.constant(metric_value_1))
+
+ mean = metrics_module.Mean()
+ mean.update_state(metric_value_2)
return model_fn_lib.EstimatorSpec(
mode,
loss=loss,
predictions={'predictions': constant_op.constant(1.)},
train_op=state_ops.assign_add(global_step, 1),
- eval_metric_ops={metric_name: (metric_tensor, metric_update_op)})
+ eval_metric_ops={
+ metric_name_1: (metric_tensor, metric_update_op),
+ metric_name_2: mean,
+ })
class _StepCounterHook(session_run_hook.SessionRunHook):
@@ -1139,16 +1177,22 @@ class EstimatorEvaluateTest(test.TestCase):
def test_no_checkpoint_uses_init(self):
def _model_fn(features, labels, mode, params):
del features, labels, params
+ mean = metrics_module.Mean()
+ mean.update_state(variables.Variable(2.) + 1)
return model_fn_lib.EstimatorSpec(
mode,
loss=constant_op.constant(1.),
- eval_metric_ops={'metric': metrics_lib.mean(
- variables.Variable(2.) + 1)})
+ eval_metric_ops={
+ 'mean1': mean,
+ 'mean2': metrics_lib.mean(variables.Variable(2.) + 1)
+ })
+
est = estimator.Estimator(model_fn=_model_fn)
- metrics = est.evaluate(dummy_input_fn, steps=1)
+ scores = est.evaluate(dummy_input_fn, steps=1)
# Metric value here is set to 1 + the value of the Variable that is newly
# initialized (since there is no checkpoint).
- self.assertEqual(3., metrics['metric'])
+ self.assertEqual(3., scores['mean1'])
+ self.assertEqual(3., scores['mean2'])
def test_no_checkpoint_uses_init_with_warm_starting(self):
def _make_model_fn(x):
@@ -1156,14 +1200,24 @@ class EstimatorEvaluateTest(test.TestCase):
_, _ = features, labels
x_var = variable_scope.get_variable('x', initializer=x)
global_step = training.get_global_step()
+ mean = metrics_module.Mean()
+ mean.update_state(x_var + 1)
return model_fn_lib.EstimatorSpec(
mode,
predictions={'y': constant_op.constant(1.0)},
loss=constant_op.constant(1.),
- eval_metric_ops={'metric': metrics_lib.mean(x_var + 1)},
+ eval_metric_ops={
+ 'mean1': mean,
+ 'mean2': metrics_lib.mean(x_var + 1)
+ },
train_op=state_ops.assign_add(global_step, 1),
- export_outputs={'test': export_output.ClassificationOutput(
- constant_op.constant([4.2]), constant_op.constant(['label']))})
+ export_outputs={
+ 'test':
+ export_output.ClassificationOutput(
+ constant_op.constant([4.2]),
+ constant_op.constant(['label']))
+ })
+
return _variable_creating_and_export_model_fn
first_est = estimator.Estimator(model_fn=_make_model_fn(42.))
@@ -1182,30 +1236,37 @@ class EstimatorEvaluateTest(test.TestCase):
# or an exported SavedModel.
est = estimator.Estimator(model_fn=_make_model_fn(52.),
warm_start_from=exported_path)
- metrics = est.evaluate(dummy_input_fn, steps=1)
+ eval_metrics = est.evaluate(dummy_input_fn, steps=1)
# Metric value here is set to 1 + the value of the Variable that is
# warm-started from the SavedModel of the first model (42.), as opposed to
# the initialization in the new model_fn (52.).
- self.assertEqual(43., metrics['metric'])
+ self.assertEqual(43., eval_metrics['mean1'])
+ self.assertEqual(43., eval_metrics['mean2'])
est = estimator.Estimator(model_fn=_make_model_fn(62.),
warm_start_from=first_est.model_dir)
- metrics = est.evaluate(dummy_input_fn, steps=1)
+ eval_metrics = est.evaluate(dummy_input_fn, steps=1)
# Metric value here is set to 1 + the value of the Variable that is
# warm-started from a checkpoint of the first model (42.), as opposed to
# the initialization in the new model_fn (52.).
- self.assertEqual(43., metrics['metric'])
+ self.assertEqual(43., eval_metrics['mean1'])
+ self.assertEqual(43., eval_metrics['mean2'])
def test_scores(self):
est = estimator.Estimator(
model_fn=_model_fn_with_eval_metric_ops,
params={
'metric_name': 'metric',
- 'metric_value': 2.})
+ 'metric_value': 2.,
+ 'metric_name_2': 'metric2',
+ 'metric_value_2': 3.,
+ })
est.train(dummy_input_fn, steps=5)
scores = est.evaluate(dummy_input_fn, steps=1)
self.assertIn('metric', scores)
self.assertAlmostEqual(2., scores['metric'])
+ self.assertIn('metric2', scores)
+ self.assertAlmostEqual(3., scores['metric2'])
def test_tuple_metrics(self):
def _model_fn(features, labels, mode):
@@ -1256,8 +1317,12 @@ class EstimatorEvaluateTest(test.TestCase):
def test_global_step_is_reported(self):
est = estimator.Estimator(
model_fn=_model_fn_with_eval_metric_ops,
- params={'metric_name': 'metric',
- 'metric_value': 2.})
+ params={
+ 'metric_name': 'metric',
+ 'metric_value': 2.,
+ 'metric_name_2': 'metric2',
+ 'metric_value_2': 3.,
+ })
est.train(dummy_input_fn, steps=5)
scores = est.evaluate(dummy_input_fn, steps=1)
self.assertIn('global_step', scores)
@@ -1300,7 +1365,10 @@ class EstimatorEvaluateTest(test.TestCase):
def test_evaluate_from_checkpoint(self):
params = {
'metric_name': 'metric',
- 'metric_value': 2.}
+ 'metric_value': 2.,
+ 'metric_name_2': 'metric2',
+ 'metric_value_2': 3.,
+ }
est1 = estimator.Estimator(
model_fn=_model_fn_with_eval_metric_ops,
params=params)
@@ -1458,6 +1526,48 @@ class EstimatorEvaluateTest(test.TestCase):
self.assertProtoEquals(expected_tensor_proto,
next(summaries).value[0].tensor)
+ def test_summary_writing_with_tensor(self):
+
+ def model_fn_with_prediction_mean_tensor_eval_metric_ops(
+ features, labels, mode, params):
+ _, _ = features, labels
+ global_step = training.get_global_step()
+
+ metric_name = params.get('metric_name') or 'metric'
+ predictions = constant_op.constant([1., .5, 0.])
+ eval_metric_ops = {metric_name: metrics_lib.mean_tensor(predictions)}
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=constant_op.constant(1.),
+ predictions={'predictions': predictions},
+ train_op=state_ops.assign_add(global_step, 1),
+ eval_metric_ops=eval_metric_ops)
+
+ metric_key = 'PMT'
+ params = {
+ 'metric_name': metric_key,
+ }
+ est = estimator.Estimator(
+ model_fn=model_fn_with_prediction_mean_tensor_eval_metric_ops,
+ params=params,
+ config=run_config.RunConfig(save_summary_steps=1))
+ est.train(input_fn=dummy_input_fn, steps=10)
+ est.evaluate(
+ input_fn=dummy_input_fn,
+ steps=10,
+ )
+
+ writer_cache.FileWriterCache.clear()
+
+ self.assertTrue(
+ check_eventfile_for_keyword(metric_key, est.eval_dir()),
+ '{} should be part of reported summaries.'.format(metric_key))
+
+ summaries = summaries_with_matching_keyword(metric_key, est.eval_dir())
+ for value in next(summaries).value:
+ if value.tag == metric_key:
+ self.assertTrue(value.HasField('tensor'))
+
class EstimatorPredictTest(test.TestCase):
@@ -1957,8 +2067,15 @@ def _model_fn_with_x_y(features, labels, mode):
multiplied = math_ops.multiply(
features['x'], features['y'], name='{}multiplied'.format(prefix))
- metrics = {'mean': metrics_lib.mean(features['x'] - features['y'],
- name='{}mean'.format(prefix))}
+ mean = metrics_module.Mean(name='{}mean'.format(prefix))
+ mean.update_state(features['x'] - features['y'])
+ eval_metrics = {
+ 'mean1':
+ mean,
+ 'mean2':
+ metrics_lib.mean(
+ features['x'] - features['y'], name='{}mean'.format(prefix))
+ }
variables.Variable(1., name='later_var')
variables.Variable(3., name='name_collision')
return model_fn_lib.EstimatorSpec(
@@ -1966,7 +2083,7 @@ def _model_fn_with_x_y(features, labels, mode):
predictions=multiplied,
loss=constant_op.constant(1.),
train_op=state_ops.assign_add(training.get_global_step(), 1),
- eval_metric_ops=metrics)
+ eval_metric_ops=eval_metrics)
def _model_fn_with_saveables_for_export_tests(features, labels, mode):
@@ -2325,14 +2442,18 @@ class EstimatorExportTest(test.TestCase):
def _model_fn(features, labels, mode):
del features, labels # Unused
- metrics = {'metrics': (constant_op.constant([0]),
- control_flow_ops.no_op())}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ eval_metrics = {
+ 'metrics1': (constant_op.constant([0]), control_flow_ops.no_op()),
+ 'metrics2': metric_obj,
+ }
return model_fn_lib.EstimatorSpec(
mode,
predictions=constant_op.constant(10.),
loss=constant_op.constant(1.),
train_op=state_ops.assign_add(training.get_global_step(), 1),
- eval_metric_ops=metrics)
+ eval_metric_ops=eval_metrics)
tmpdir = tempfile.mkdtemp()
est = estimator.Estimator(model_fn=_model_fn)
@@ -2354,8 +2475,10 @@ class EstimatorExportTest(test.TestCase):
meta_graph = loader.load(sess, [tag_constants.EVAL], export_dir)
sig_outputs = meta_graph.signature_def[
model_fn_lib.ModeKeys.EVAL].outputs
- self.assertEqual(
- sig_outputs['metrics/update_op'].name, 'metric_op_wrapper:0')
+ self.assertTrue(sig_outputs['metrics1/update_op'].name.startswith(
+ 'metric_op_wrapper'))
+ self.assertTrue(sig_outputs['metrics2/update_op'].name.startswith(
+ 'metric_op_wrapper'))
def test_export_savedmodel_with_saveables_proto_roundtrip(self):
tmpdir = tempfile.mkdtemp()
@@ -2641,6 +2764,7 @@ class EstimatorExportTest(test.TestCase):
_, _ = features, labels
my_int = variables.Variable(1, name='my_int',
collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ _ = training.get_or_create_steps_per_run_variable()
scores = constant_op.constant([3.])
with ops.control_dependencies([
variables.local_variables_initializer(),
@@ -3009,9 +3133,13 @@ class EstimatorIntegrationTest(test.TestCase):
loss = losses.mean_squared_error(labels, predictions)
train_op = training.GradientDescentOptimizer(learning_rate=0.5).minimize(
loss, training.get_global_step())
+ mean = metrics_module.Mean()
+ mean.update_state(loss)
eval_metric_ops = {
- 'absolute_error': metrics_lib.mean_absolute_error(
- labels, predictions)
+ 'absolute_error':
+ metrics_lib.mean_absolute_error(labels, predictions),
+ 'mean':
+ mean,
}
return model_fn_lib.EstimatorSpec(
@@ -3031,12 +3159,13 @@ class EstimatorIntegrationTest(test.TestCase):
x={'x': data}, y=data, batch_size=50, num_epochs=None, shuffle=True)
est.train(train_input_fn, steps=200)
- # EVALUTE
+ # EVALUATE
eval_input_fn = numpy_io.numpy_input_fn(
x={'x': data}, y=data, batch_size=50, num_epochs=1, shuffle=True)
scores = est.evaluate(eval_input_fn)
self.assertEqual(200, scores['global_step'])
self.assertGreater(0.1, scores['absolute_error'])
+ self.assertAlmostEqual(4.4e-14, scores['mean'], places=2)
# PREDICT
predict_input_fn = numpy_io.numpy_input_fn(
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index 529e7a8b87..55aace5fa9 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -217,6 +217,29 @@ class TensorServingInputReceiver(
receiver_tensors_alternatives=receiver.receiver_tensors_alternatives)
+class UnsupervisedInputReceiver(ServingInputReceiver):
+ """A return type for a training_input_receiver_fn or eval_input_receiver_fn.
+
+ This differs from SupervisedInputReceiver in that it does not require a set
+ of labels.
+
+ The expected return values are:
+ features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
+ `SparseTensor`, specifying the features to be passed to the model.
+ receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor`
+ or `SparseTensor`, specifying input nodes where this receiver expects to
+ be fed by default. Typically, this is a single placeholder expecting
+ serialized `tf.Example` protos.
+ """
+
+ def __new__(cls, features, receiver_tensors):
+ return super(UnsupervisedInputReceiver, cls).__new__(
+ cls,
+ features=features,
+ receiver_tensors=receiver_tensors,
+ receiver_tensors_alternatives=None)
+
+
class SupervisedInputReceiver(
collections.namedtuple('SupervisedInputReceiver',
['features', 'labels', 'receiver_tensors'])):
@@ -288,14 +311,33 @@ def build_parsing_serving_input_receiver_fn(feature_spec,
def _placeholder_from_tensor(t, default_batch_size=None):
- shape_list = t.get_shape().as_list()
- shape_list[0] = default_batch_size
- shape = tensor_shape.TensorShape(shape_list)
+ """Creates a placeholder that matches the dtype and shape of passed tensor.
+
+ Args:
+ t: Tensor or EagerTensor
+ default_batch_size: the number of query examples expected per batch.
+ Leave unset for variable batch size (recommended).
+
+ Returns:
+ Placeholder that matches the passed tensor.
+ """
+ batch_shape = tensor_shape.TensorShape([default_batch_size])
+ shape = batch_shape.concatenate(t.get_shape()[1:])
# Reuse the feature tensor's op name (t.op.name) for the placeholder,
# excluding the index from the tensor's name (t.name):
# t.name = "%s:%d" % (t.op.name, t._value_index)
- return array_ops.placeholder(dtype=t.dtype, shape=shape, name=t.op.name)
+ try:
+ name = t.op.name
+ except AttributeError:
+ # In Eager mode, tensors don't have ops or names, and while they do have
+ # IDs, those are not maintained across runs. The name here is used
+ # primarily for debugging, and is not critical to the placeholder.
+ # So, in order to make this Eager-compatible, continue with an empty
+ # name if none is available.
+ name = None
+
+ return array_ops.placeholder(dtype=t.dtype, shape=shape, name=name)
def _placeholders_from_receiver_tensors_dict(input_vals,
diff --git a/tensorflow/python/estimator/export/export_output.py b/tensorflow/python/estimator/export/export_output.py
index 20382a58d8..c17fc08f21 100644
--- a/tensorflow/python/estimator/export/export_output.py
+++ b/tensorflow/python/estimator/export/export_output.py
@@ -26,6 +26,7 @@ import six
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.util.tf_export import estimator_export
@@ -259,7 +260,10 @@ class _SupervisedOutput(ExportOutput):
loss: dict of Tensors or single Tensor representing calculated loss.
predictions: dict of Tensors or single Tensor representing model
predictions.
- metrics: dict of (metric_value, update_op) tuples, or a single tuple.
+ metrics: Dict of metric results keyed by name.
+ The values of the dict can be one of the following:
+ (1) instance of `Metric` class.
+ (2) (metric_value, update_op) tuples, or a single tuple.
metric_value must be a Tensor, and update_op must be a Tensor or Op.
Raises:
@@ -311,7 +315,11 @@ class _SupervisedOutput(ExportOutput):
Here, we separate out the tuples and create a dict with names to tensors.
Args:
- metrics: dict of (metric_value, update_op) tuples, or a single tuple.
+ metrics: Dict of metric results keyed by name.
+ The values of the dict can be one of the following:
+ (1) instance of `Metric` class.
+ (2) (metric_value, update_op) tuples, or a single tuple.
+ metric_value must be a Tensor, and update_op must be a Tensor or Op.
Returns:
dict of output_names to tensors
@@ -324,7 +332,13 @@ class _SupervisedOutput(ExportOutput):
metrics = {self.METRICS_NAME: metrics}
outputs = {}
- for key, (metric_val, metric_op) in metrics.items():
+ for key, value in metrics.items():
+ if isinstance(value, metrics_module.Metric):
+ metric_val = value.result()
+ assert len(value.updates) == 1 # We expect only one update op.
+ metric_op = value.updates[0]
+ else:
+ metric_val, metric_op = value
key = self._check_output_key(key, self.METRICS_NAME)
key = self._prefix_key(key, self.METRICS_NAME)
@@ -397,7 +411,3 @@ class EvalOutput(_SupervisedOutput):
def _get_signature_def_fn(self):
return signature_def_utils.supervised_eval_signature_def
-
-
-
-
diff --git a/tensorflow/python/estimator/export/export_output_test.py b/tensorflow/python/estimator/export/export_output_test.py
index d94c764fd7..96ce0e580d 100644
--- a/tensorflow/python/estimator/export/export_output_test.py
+++ b/tensorflow/python/estimator/export/export_output_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
@@ -240,16 +241,19 @@ class SupervisedOutputTest(test.TestCase):
"""Tests that no errors are raised when provided outputs are valid."""
loss = {"my_loss": constant_op.constant([0])}
predictions = {u"output1": constant_op.constant(["foo"])}
- metrics = {"metrics": (constant_op.constant([0]),
- constant_op.constant([10])),
- "metrics2": (constant_op.constant([0]),
- constant_op.constant([10]))}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ "metrics": metric_obj,
+ "metrics2": (constant_op.constant([0]), constant_op.constant([10]))
+ }
outputter = MockSupervisedOutput(loss, predictions, metrics)
self.assertEqual(outputter.loss["loss/my_loss"], loss["my_loss"])
self.assertEqual(
outputter.predictions["predictions/output1"], predictions["output1"])
- self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0])
+ self.assertEqual(outputter.metrics["metrics/update_op"].name,
+ "metric_op_wrapper:0")
self.assertEqual(
outputter.metrics["metrics2/update_op"], metrics["metrics2"][1])
@@ -259,7 +263,8 @@ class SupervisedOutputTest(test.TestCase):
self.assertEqual(outputter.loss, {"loss": loss["my_loss"]})
self.assertEqual(
outputter.predictions, {"predictions": predictions["output1"]})
- self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0])
+ self.assertEqual(outputter.metrics["metrics/update_op"].name,
+ "metric_op_wrapper_1:0")
def test_supervised_outputs_none(self):
outputter = MockSupervisedOutput(
@@ -282,34 +287,56 @@ class SupervisedOutputTest(test.TestCase):
"""Tests that no errors are raised when provided outputs are valid."""
loss = {("my", "loss"): constant_op.constant([0])}
predictions = {(u"output1", "2"): constant_op.constant(["foo"])}
- metrics = {("metrics", "twice"): (constant_op.constant([0]),
- constant_op.constant([10]))}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ ("metrics", "1"):
+ metric_obj,
+ ("metrics", "2"): (constant_op.constant([0]),
+ constant_op.constant([10]))
+ }
outputter = MockSupervisedOutput(loss, predictions, metrics)
self.assertEqual(set(outputter.loss.keys()), set(["loss/my/loss"]))
self.assertEqual(set(outputter.predictions.keys()),
set(["predictions/output1/2"]))
- self.assertEqual(set(outputter.metrics.keys()),
- set(["metrics/twice/value", "metrics/twice/update_op"]))
+ self.assertEqual(
+ set(outputter.metrics.keys()),
+ set([
+ "metrics/1/value", "metrics/1/update_op", "metrics/2/value",
+ "metrics/2/update_op"
+ ]))
def test_supervised_outputs_no_prepend(self):
"""Tests that no errors are raised when provided outputs are valid."""
loss = {"loss": constant_op.constant([0])}
predictions = {u"predictions": constant_op.constant(["foo"])}
- metrics = {u"metrics": (constant_op.constant([0]),
- constant_op.constant([10]))}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ "metrics_1": metric_obj,
+ "metrics_2": (constant_op.constant([0]), constant_op.constant([10]))
+ }
outputter = MockSupervisedOutput(loss, predictions, metrics)
self.assertEqual(set(outputter.loss.keys()), set(["loss"]))
self.assertEqual(set(outputter.predictions.keys()), set(["predictions"]))
- self.assertEqual(set(outputter.metrics.keys()),
- set(["metrics/value", "metrics/update_op"]))
+ self.assertEqual(
+ set(outputter.metrics.keys()),
+ set([
+ "metrics_1/value", "metrics_1/update_op", "metrics_2/update_op",
+ "metrics_2/value"
+ ]))
def test_train_signature_def(self):
loss = {"my_loss": constant_op.constant([0])}
predictions = {u"output1": constant_op.constant(["foo"])}
- metrics = {"metrics": (constant_op.constant([0]),
- constant_op.constant([10]))}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ "metrics_1": metric_obj,
+ "metrics_2": (constant_op.constant([0]), constant_op.constant([10]))
+ }
outputter = export_output_lib.TrainOutput(loss, predictions, metrics)
@@ -318,7 +345,8 @@ class SupervisedOutputTest(test.TestCase):
sig_def = outputter.as_signature_def(receiver)
self.assertTrue("loss/my_loss" in sig_def.outputs)
- self.assertTrue("metrics/value" in sig_def.outputs)
+ self.assertTrue("metrics_1/value" in sig_def.outputs)
+ self.assertTrue("metrics_2/value" in sig_def.outputs)
self.assertTrue("predictions/output1" in sig_def.outputs)
self.assertTrue("features" in sig_def.inputs)
@@ -337,18 +365,33 @@ class SupervisedOutputTest(test.TestCase):
self.assertTrue("predictions/output1" in sig_def.outputs)
self.assertTrue("features" in sig_def.inputs)
- def test_metric_op_is_operation(self):
+ def test_metric_op_is_tensor(self):
"""Tests that ops.Operation is wrapped by a tensor for metric_ops."""
loss = {"my_loss": constant_op.constant([0])}
predictions = {u"output1": constant_op.constant(["foo"])}
- metrics = {"metrics": (constant_op.constant([0]), control_flow_ops.no_op())}
+ metric_obj = metrics_module.Mean()
+ metric_obj.update_state(constant_op.constant([0]))
+ metrics = {
+ "metrics_1": metric_obj,
+ "metrics_2": (constant_op.constant([0]), control_flow_ops.no_op())
+ }
outputter = MockSupervisedOutput(loss, predictions, metrics)
- self.assertEqual(outputter.metrics["metrics/value"], metrics["metrics"][0])
- self.assertEqual(
- outputter.metrics["metrics/update_op"].name, "metric_op_wrapper:0")
+
+ self.assertTrue(outputter.metrics["metrics_1/update_op"].name.startswith(
+ "metric_op_wrapper"))
+ self.assertTrue(
+ isinstance(outputter.metrics["metrics_1/update_op"], ops.Tensor))
self.assertTrue(
- isinstance(outputter.metrics["metrics/update_op"], ops.Tensor))
+ isinstance(outputter.metrics["metrics_1/value"], ops.Tensor))
+
+ self.assertEqual(outputter.metrics["metrics_2/value"],
+ metrics["metrics_2"][0])
+ self.assertTrue(outputter.metrics["metrics_2/update_op"].name.startswith(
+ "metric_op_wrapper"))
+ self.assertTrue(
+ isinstance(outputter.metrics["metrics_2/update_op"], ops.Tensor))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index d2ac7f0b3b..3eed1ab163 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
@@ -162,6 +163,29 @@ class ServingInputReceiverTest(test_util.TensorFlowTestCase):
_ = export.ServingInputReceiver(feature, receiver_tensor)
+class UnsupervisedInputReceiverTest(test_util.TensorFlowTestCase):
+
+ # Since this is basically a wrapper around ServingInputReceiver, we only
+ # have a simple sanity check to ensure that it works.
+
+ def test_unsupervised_input_receiver_constructor(self):
+ """Tests that no errors are raised when input is expected."""
+ features = {
+ "feature0":
+ constant_op.constant([0]),
+ u"feature1":
+ constant_op.constant([1]),
+ "feature2":
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ }
+ receiver_tensors = {
+ "example0": array_ops.placeholder(dtypes.string, name="example0"),
+ u"example1": array_ops.placeholder(dtypes.string, name="example1"),
+ }
+ export.UnsupervisedInputReceiver(features, receiver_tensors)
+
+
class SupervisedInputReceiverTest(test_util.TensorFlowTestCase):
def test_input_receiver_constructor(self):
@@ -378,6 +402,21 @@ class ExportTest(test_util.TensorFlowTestCase):
v = serving_input_receiver_fn()
self.assertTrue(isinstance(v, export.ServingInputReceiver))
+ def test_build_raw_serving_input_receiver_fn_without_shape(self):
+ """Test case for issue #21178."""
+ f = {"feature_1": array_ops.placeholder(dtypes.float32),
+ "feature_2": array_ops.placeholder(dtypes.int32)}
+ serving_input_receiver_fn = export.build_raw_serving_input_receiver_fn(f)
+ v = serving_input_receiver_fn()
+ self.assertTrue(isinstance(v, export.ServingInputReceiver))
+ self.assertEqual(
+ tensor_shape.unknown_shape(),
+ v.receiver_tensors["feature_1"].shape)
+ self.assertEqual(
+ tensor_shape.unknown_shape(),
+ v.receiver_tensors["feature_2"].shape)
+
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_serving_input_receiver_fn(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -396,6 +435,7 @@ class ExportTest(test_util.TensorFlowTestCase):
dtypes.int32,
serving_input_receiver.receiver_tensors["feature_2"].dtype)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -416,6 +456,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual(
dtypes.int32, input_receiver.receiver_tensors["feature_2"].dtype)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn_raw_tensors(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -439,6 +480,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual(set(["input", "label"]),
set(input_receiver.receiver_tensors.keys()))
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn_batch_size(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -451,6 +493,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual([10], input_receiver.receiver_tensors["feature_1"].shape)
self.assertEqual([10], input_receiver.features["feature_1"].shape)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn_overlapping_keys(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -459,6 +502,7 @@ class ExportTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
export.build_raw_supervised_input_receiver_fn(features, labels)
+ @test_util.run_in_graph_and_eager_modes
def test_build_supervised_input_receiver_fn_from_input_fn(self):
def dummy_input_fn():
return ({"x": constant_op.constant([[1], [1]]),
@@ -476,6 +520,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual(set(["x", "y", "label"]),
set(input_receiver.receiver_tensors.keys()))
+ @test_util.run_in_graph_and_eager_modes
def test_build_supervised_input_receiver_fn_from_input_fn_args(self):
def dummy_input_fn(feature_key="x"):
return ({feature_key: constant_op.constant([[1], [1]]),
diff --git a/tensorflow/python/estimator/exporter_test.py b/tensorflow/python/estimator/exporter_test.py
index c4b006955c..fcccfbde7a 100644
--- a/tensorflow/python/estimator/exporter_test.py
+++ b/tensorflow/python/estimator/exporter_test.py
@@ -323,6 +323,43 @@ class LatestExporterTest(test.TestCase):
self.assertTrue(gfile.Exists(export_dir_3))
self.assertTrue(gfile.Exists(export_dir_4))
+ def test_garbage_collect_exports_with_trailing_delimiter(self):
+ export_dir_base = tempfile.mkdtemp() + "export/"
+ gfile.MkDir(export_dir_base)
+ export_dir_1 = _create_test_export_dir(export_dir_base)
+ export_dir_2 = _create_test_export_dir(export_dir_base)
+ export_dir_3 = _create_test_export_dir(export_dir_base)
+ export_dir_4 = _create_test_export_dir(export_dir_base)
+
+ self.assertTrue(gfile.Exists(export_dir_1))
+ self.assertTrue(gfile.Exists(export_dir_2))
+ self.assertTrue(gfile.Exists(export_dir_3))
+ self.assertTrue(gfile.Exists(export_dir_4))
+
+ def _serving_input_receiver_fn():
+ return array_ops.constant([1]), None
+
+ exporter = exporter_lib.LatestExporter(
+ name="latest_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ exports_to_keep=1)
+ estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+ # Garbage collect all but the most recent 2 exports,
+ # where recency is determined based on the timestamp directory names.
+ with test.mock.patch.object(gfile, "ListDirectory") as mock_list_directory:
+ mock_list_directory.return_value = [
+ os.path.basename(export_dir_1) + b"/",
+ os.path.basename(export_dir_2) + b"/",
+ os.path.basename(export_dir_3) + b"/",
+ os.path.basename(export_dir_4) + b"/",
+ ]
+ exporter.export(estimator, export_dir_base, None, None, False)
+
+ self.assertFalse(gfile.Exists(export_dir_1))
+ self.assertFalse(gfile.Exists(export_dir_2))
+ self.assertFalse(gfile.Exists(export_dir_3))
+ self.assertTrue(gfile.Exists(export_dir_4))
+
def _create_test_export_dir(export_dir_base):
export_dir = _get_timestamped_export_dir(export_dir_base)
diff --git a/tensorflow/python/estimator/gc.py b/tensorflow/python/estimator/gc.py
index 9f8a463ec1..03ad33dd6b 100644
--- a/tensorflow/python/estimator/gc.py
+++ b/tensorflow/python/estimator/gc.py
@@ -201,9 +201,11 @@ def _get_paths(base_dir, parser):
raw_paths = gfile.ListDirectory(base_dir)
paths = []
for r in raw_paths:
- p = parser(Path(os.path.join(compat.as_str_any(base_dir),
- compat.as_str_any(r)),
- None))
+ # ListDirectory() return paths with "/" at the last if base_dir was GCS URL
+ r = compat.as_str_any(r)
+ if r[-1] == '/':
+ r = r[0:len(r)-1]
+ p = parser(Path(os.path.join(compat.as_str_any(base_dir), r), None))
if p:
paths.append(p)
return sorted(paths)
diff --git a/tensorflow/python/estimator/gc_test.py b/tensorflow/python/estimator/gc_test.py
index 2cbdd511d1..53c3d4ca2a 100644
--- a/tensorflow/python/estimator/gc_test.py
+++ b/tensorflow/python/estimator/gc_test.py
@@ -140,6 +140,17 @@ class GcTest(test_util.TensorFlowTestCase):
gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42"))
gc._get_paths(base_dir, _create_parser(base_dir))
+ def testGcsDirWithSeparator(self):
+ base_dir = "gs://bucket/foo"
+ with test.mock.patch.object(gfile, "ListDirectory") as mock_list_directory:
+ # gfile.ListDirectory returns directory names with separator '/'
+ mock_list_directory.return_value = ["0/", "1/"]
+ self.assertEqual(
+ gc._get_paths(base_dir, _create_parser(base_dir)),
+ [
+ gc.Path(os.path.join(base_dir, "0"), 0),
+ gc.Path(os.path.join(base_dir, "1"), 1)
+ ])
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/estimator/inputs/numpy_io_test.py b/tensorflow/python/estimator/inputs/numpy_io_test.py
index 81b201cc5c..4e7b00b307 100644
--- a/tensorflow/python/estimator/inputs/numpy_io_test.py
+++ b/tensorflow/python/estimator/inputs/numpy_io_test.py
@@ -19,9 +19,15 @@ from __future__ import division
from __future__ import print_function
import numpy as np
-
+from tensorflow.python.client import session as session_lib
from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.feature_column import feature_column_lib as fc
+from tensorflow.python.feature_column.feature_column import _LinearModel
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
@@ -456,5 +462,159 @@ class NumpyIoTest(test.TestCase):
self.assertAllEqual(res_arr[1], res_dict[1])
+class FeatureColumnIntegrationTest(test.TestCase):
+
+ def _initialized_session(self, config=None):
+ sess = session_lib.Session(config=config)
+ sess.run(variables_lib.global_variables_initializer())
+ sess.run(lookup_ops.tables_initializer())
+ return sess
+
+ def _get_linear_model_bias(self, name='linear_model'):
+ with variable_scope.variable_scope(name, reuse=True):
+ return variable_scope.get_variable('bias_weights')
+
+ def _get_linear_model_column_var(self, column, name='linear_model'):
+ return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
+ name + '/' + column.name)[0]
+
+ def _get_keras_linear_model_predictions(
+ self,
+ features,
+ feature_columns,
+ units=1,
+ sparse_combiner='sum',
+ weight_collections=None,
+ trainable=True,
+ cols_to_vars=None):
+ keras_linear_model = _LinearModel(
+ feature_columns,
+ units,
+ sparse_combiner,
+ weight_collections,
+ trainable,
+ name='linear_model')
+ retval = keras_linear_model(features) # pylint: disable=not-callable
+ if cols_to_vars is not None:
+ cols_to_vars.update(keras_linear_model.cols_to_vars())
+ return retval
+
+ def test_linear_model_numpy_input_fn(self):
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(price, boundaries=[0., 10., 100.,])
+ body_style = fc.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'price': np.array([-1., 2., 13., 104.]),
+ 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
+ },
+ batch_size=2,
+ shuffle=False)
+ features = input_fn()
+ net = fc.linear_model(features, [price_buckets, body_style])
+ # self.assertEqual(1 + 3 + 5, net.shape[1])
+ with self._initialized_session() as sess:
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
+
+ bias = self._get_linear_model_bias()
+ price_buckets_var = self._get_linear_model_column_var(price_buckets)
+ body_style_var = self._get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
+
+ coord.request_stop()
+ coord.join(threads)
+
+ def test_linear_model_impl_numpy_input_fn(self):
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
+ price, boundaries=[
+ 0.,
+ 10.,
+ 100.,
+ ])
+ body_style = fc.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'price': np.array([-1., 2., 13., 104.]),
+ 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
+ },
+ batch_size=2,
+ shuffle=False)
+ features = input_fn()
+ net = self._get_keras_linear_model_predictions(
+ features, [price_buckets, body_style])
+ # self.assertEqual(1 + 3 + 5, net.shape[1])
+ with self._initialized_session() as sess:
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
+
+ bias = self._get_linear_model_bias()
+ price_buckets_var = self._get_linear_model_column_var(price_buckets)
+ body_style_var = self._get_linear_model_column_var(body_style)
+
+ sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
+ sess.run(bias.assign([5.]))
+
+ self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
+
+ coord.request_stop()
+ coord.join(threads)
+
+ def test_functional_input_layer_with_numpy_input_fn(self):
+ embedding_values = (
+ (1., 2., 3., 4., 5.), # id 0
+ (6., 7., 8., 9., 10.), # id 1
+ (11., 12., 13., 14., 15.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ del shape, dtype, partition_info
+ return embedding_values
+
+ # price has 1 dimension in input_layer
+ price = fc.numeric_column('price')
+ body_style = fc.categorical_column_with_vocabulary_list(
+ 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
+ # one_hot_body_style has 3 dims in input_layer.
+ one_hot_body_style = fc.indicator_column(body_style)
+ # embedded_body_style has 5 dims in input_layer.
+ embedded_body_style = fc.embedding_column(body_style, dimension=5,
+ initializer=_initializer)
+
+ input_fn = numpy_io.numpy_input_fn(
+ x={
+ 'price': np.array([11., 12., 13., 14.]),
+ 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
+ },
+ batch_size=2,
+ shuffle=False)
+ features = input_fn()
+ net = fc.input_layer(features,
+ [price, one_hot_body_style, embedded_body_style])
+ self.assertEqual(1 + 3 + 5, net.shape[1])
+ with self._initialized_session() as sess:
+ coord = coordinator.Coordinator()
+ threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
+
+ # Each row is formed by concatenating `embedded_body_style`,
+ # `one_hot_body_style`, and `price` in order.
+ self.assertAllEqual(
+ [[11., 12., 13., 14., 15., 0., 0., 1., 11.],
+ [1., 2., 3., 4., 5., 1., 0., 0., 12]],
+ sess.run(net))
+
+ coord.request_stop()
+ coord.join(threads)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index c91204a35f..6361c6acc1 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -33,9 +33,6 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers
-from tensorflow.python.keras.engine.base_layer import Layer
-from tensorflow.python.keras.engine.network import Network
-from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
@@ -43,12 +40,10 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import data_structures
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -92,184 +87,78 @@ def _any_weight_initialized(keras_model):
return False
-def _create_ordered_io(keras_model, estimator_io, is_input=True):
- """Create a list of tensors from IO dictionary based on Keras IO order.
+def _convert_estimator_io_to_keras(keras_model, features, labels):
+ """Converts estimator features and labels to keras input and target tensors.
Args:
- keras_model: An instance of compiled keras model.
- estimator_io: The features or labels (dict or plain array) from model_fn.
- is_input: True if dictionary is for inputs.
+ keras_model: a compiled `tf.keras.Model` instance, used to determine the
+ order of the returned lists.
+ features: Dict of tensors or `None`.
+ labels: Dict of tensors, a single tensor, or `None`.
Returns:
- A list of tensors based on Keras IO order.
-
- Raises:
- ValueError: if dictionary keys cannot be found in Keras model input_names
- or output_names.
- """
- if isinstance(estimator_io, (list, tuple)):
- # Case currently not supported by most built-in input_fn,
- # but it's good to have for sanity
- return [_convert_tensor(x) for x in estimator_io]
- elif isinstance(estimator_io, dict):
- if is_input:
- if keras_model._is_graph_network:
- keras_io_names = keras_model.input_names
- else:
- keras_io_names = [
- 'input_%d' % i for i in range(1, len(estimator_io) + 1)]
- else:
- if keras_model._is_graph_network:
- keras_io_names = keras_model.output_names
- else:
- keras_io_names = [
- 'output_%d' % i for i in range(1, len(estimator_io) + 1)]
-
- for key in estimator_io:
- if key not in keras_io_names:
- raise ValueError(
- 'Cannot find %s with name "%s" in Keras Model. '
- 'It needs to match one '
- 'of the following: %s' % ('input' if is_input else 'output', key,
- ', '.join(keras_io_names)))
- tensors = [_convert_tensor(estimator_io[io_name])
- for io_name in keras_io_names]
- return tensors
- else:
- # Plain array.
- return _convert_tensor(estimator_io)
-
-
-def _in_place_subclassed_model_reset(model):
- """Substitute for model cloning that works for subclassed models.
-
- Subclassed models cannot be cloned because their topology is not serializable.
- To "instantiate" an identical model in a new TF graph, we reuse the original
- model object, but we clear its state.
-
- After calling this function on a model instance, you can use the model
- instance as if it were a model clone (in particular you can use it in a new
- graph).
-
- This method clears the state of the input model. It is thus destructive.
- However the original state can be restored fully by calling
- `_in_place_subclassed_model_state_restoration`.
-
- Args:
- model: Instance of a Keras model created via subclassing.
-
- Raises:
- ValueError: In case the model uses a subclassed model as inner layer.
+ Tuple of (
+ list of input tensors or `None`,
+ list of target tensors or `None`)
+ The order of tensors is determined by the order set in the keras model.
"""
- assert not model._is_graph_network # Only makes sense for subclassed networks
- # Retrieve all layers tracked by the model as well as their attribute names
- attributes_cache = {}
- for name in dir(model):
- try:
- value = getattr(model, name)
- except (AttributeError, ValueError, TypeError):
- continue
- if isinstance(value, Layer):
- attributes_cache[name] = value
- assert value in model._layers
- elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'):
- # Handle case: list/tuple of layers (also tracked by the Network API).
- if value and all(isinstance(val, Layer) for val in value):
- raise ValueError('We do not support the use of list-of-layers '
- 'attributes in subclassed models used with '
- '`model_to_estimator` at this time. Found list '
- 'model: %s' % name)
-
- # Replace layers on the model with fresh layers
- layers_to_names = {value: key for key, value in attributes_cache.items()}
- original_layers = model._layers[:]
- model._layers = data_structures.NoDependency([])
- for layer in original_layers: # We preserve layer order.
- config = layer.get_config()
- # This will not work for nested subclassed models used as layers.
- # This would be theoretically possible to support, but would add complexity.
- # Only do it if users complain.
- if isinstance(layer, Network) and not layer._is_graph_network:
- raise ValueError('We do not support the use of nested subclassed models '
- 'in `model_to_estimator` at this time. Found nested '
- 'model: %s' % layer)
- fresh_layer = layer.__class__.from_config(config)
- name = layers_to_names[layer]
- setattr(model, name, fresh_layer)
-
- # Cache original model build attributes (in addition to layers)
- if (not hasattr(model, '_original_attributes_cache') or
- model._original_attributes_cache is None):
- if model.built:
- attributes_to_cache = [
- 'inputs',
- 'outputs',
- '_feed_outputs',
- '_feed_output_names',
- '_feed_output_shapes',
- '_feed_loss_fns',
- 'loss_weights_list',
- 'targets',
- '_feed_targets',
- 'sample_weight_modes',
- 'weighted_metrics',
- 'metrics_names',
- 'metrics_tensors',
- 'metrics_updates',
- 'stateful_metric_names',
- 'total_loss',
- 'sample_weights',
- '_feed_sample_weights',
- 'train_function',
- 'test_function',
- 'predict_function',
- '_collected_trainable_weights',
- '_feed_inputs',
- '_feed_input_names',
- '_feed_input_shapes',
- 'optimizer',
- ]
- for name in attributes_to_cache:
- attributes_cache[name] = getattr(model, name)
- model._original_attributes_cache = data_structures.NoDependency(
- attributes_cache)
- # Reset built state
- model.built = False
- model.inputs = None
- model.outputs = None
-
-
-def _in_place_subclassed_model_state_restoration(model):
- """Restores the original state of a model after it was "reset".
-
- This undoes this action of `_in_place_subclassed_model_reset`.
- Args:
- model: Instance of a Keras model created via subclassing, on which
- `_in_place_subclassed_model_reset` was previously called.
- """
- assert not model._is_graph_network
- # Restore layers and build attributes
- if (hasattr(model, '_original_attributes_cache') and
- model._original_attributes_cache is not None):
- # Models have sticky attribute assignment, so we want to be careful to add
- # back the previous attributes and track Layers by their original names
- # without adding dependencies on "utility" attributes which Models exempt
- # when they're constructed.
- model._layers = data_structures.NoDependency([])
- for name, value in model._original_attributes_cache.items():
- if not isinstance(value, checkpointable.CheckpointableBase):
- # If this value is not already checkpointable, it's probably that way
- # for a reason; we don't want to start tracking data structures that the
- # original Model didn't.
- value = data_structures.NoDependency(value)
- setattr(model, name, value)
- model._original_attributes_cache = None
- else:
- # Restore to the state of a never-called model.
- model.built = False
- model.inputs = None
- model.outputs = None
+ def _to_ordered_tensor_list(obj, key_order, obj_name, order_name):
+ """Convert obj to an ordered list of tensors.
+
+ Args:
+ obj: List, dict, or single tensor. May be `None`.
+ key_order: List of strings with the order to return (used if obj is a
+ dict).
+ obj_name: String name of object (e.g. "features" or "labels")
+ order_name: String name of the key order (e.g. "inputs" or "outputs")
+
+ Returns:
+ List of tensors, or `None`
+
+ Raises:
+ KeyError: If obj has invalid keys.
+ """
+ if obj is None:
+ return None
+ elif isinstance(obj, (list, tuple)):
+ return [_convert_tensor(x) for x in obj]
+ elif isinstance(obj, dict):
+ # Ensure that the obj keys and keys in key_order are exactly the same.
+ different_keys = set(obj.keys()) ^ set(key_order)
+
+ if different_keys:
+ raise KeyError(
+ 'The dictionary passed into {obj_name} does not have the expected '
+ '{order_name} keys defined in the keras model.'
+ '\n\tExpected keys: {order_keys}'
+ '\n\t{obj_name} keys: {obj_keys}'
+ '\n\tDifference: {different_keys}'.format(
+ order_name=order_name, order_keys=set(key_order),
+ obj_name=obj_name, obj_keys=set(obj.keys()),
+ different_keys=different_keys))
+
+ return [_convert_tensor(obj[key]) for key in key_order]
+ else: # Assume obj is a tensor.
+ return [_convert_tensor(obj)]
+
+ input_names = None
+ output_names = None
+ if isinstance(features, dict):
+ input_names = (
+ keras_model.input_names if keras_model._is_graph_network else
+ ['input_%d' % i for i in range(1, len(features) + 1)])
+ if isinstance(labels, dict):
+ output_names = (
+ keras_model.output_names if keras_model._is_graph_network else
+ ['output_%d' % i for i in range(1, len(labels) + 1)])
+
+ input_tensors = _to_ordered_tensor_list(
+ features, input_names, 'features', 'inputs')
+ target_tensors = _to_ordered_tensor_list(
+ labels, output_names, 'labels', 'outputs')
+
+ return input_tensors, target_tensors
def _clone_and_build_model(mode,
@@ -289,61 +178,14 @@ def _clone_and_build_model(mode,
Returns:
The newly built model.
"""
- # Set to True during training, False for inference.
+ # Set to True during training, False for inference or testing.
K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)
-
- # Get list of inputs.
- if features is None:
- input_tensors = None
- else:
- input_tensors = _create_ordered_io(keras_model,
- estimator_io=features,
- is_input=True)
- # Get list of outputs.
- if labels is None:
- target_tensors = None
- elif isinstance(labels, dict):
- target_tensors = _create_ordered_io(keras_model,
- estimator_io=labels,
- is_input=False)
- else:
- target_tensors = [
- _convert_tensor(labels)
- ]
-
- if keras_model._is_graph_network:
- if custom_objects:
- with CustomObjectScope(custom_objects):
- model = models.clone_model(keras_model, input_tensors=input_tensors)
- else:
- model = models.clone_model(keras_model, input_tensors=input_tensors)
- else:
- model = keras_model
- _in_place_subclassed_model_reset(model)
- if input_tensors is not None:
- model._set_inputs(input_tensors)
-
- # Compile/Build model
- if mode is model_fn_lib.ModeKeys.PREDICT:
- if isinstance(model, models.Sequential):
- model.build()
- else:
- if isinstance(keras_model.optimizer, optimizers.TFOptimizer):
- optimizer = keras_model.optimizer
- else:
- optimizer_config = keras_model.optimizer.get_config()
- optimizer = keras_model.optimizer.__class__.from_config(optimizer_config)
- optimizer.iterations = training_util.get_or_create_global_step()
-
- model.compile(
- optimizer,
- keras_model.loss,
- metrics=keras_model.metrics,
- loss_weights=keras_model.loss_weights,
- sample_weight_mode=keras_model.sample_weight_mode,
- weighted_metrics=keras_model.weighted_metrics,
- target_tensors=target_tensors)
- return model
+ input_tensors, target_tensors = _convert_estimator_io_to_keras(
+ keras_model, features, labels)
+ return models.clone_and_build_model(
+ keras_model, input_tensors, target_tensors, custom_objects,
+ compile_clone=(mode != model_fn_lib.ModeKeys.PREDICT),
+ in_place_reset=(not keras_model._is_graph_network))
def _create_keras_model_fn(keras_model, custom_objects=None):
@@ -361,7 +203,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
"""model_fn for keras Estimator."""
# Raise an error when users use DistributionStrategy with native Keras
# optimizers. Currently we only support native TensorFlow optimizers.
- if distribute_lib.has_distribution_strategy() and \
+ if distribution_strategy_context.has_distribution_strategy() and \
not isinstance(keras_model.optimizer,
(tf_optimizer_module.Optimizer, optimizers.TFOptimizer)):
raise ValueError('Only TensorFlow native optimizers are supported with '
@@ -373,7 +215,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
# We need to make sure that the output names of the last layer in the model
# is the same for each of the cloned models. This is required for mirrored
# strategy when we call regroup.
- if distribute_lib.has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
for name in model.output_names:
name = re.compile(r'_\d$').sub('', name)
model_output_names.append(name)
@@ -396,7 +238,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
loss = model.total_loss
if model.metrics:
- # TODO(fchollet): support stateful metrics
+ # TODO(psv/fchollet): support stateful metrics
eval_metric_ops = {}
# When each metric maps to an output
if isinstance(model.metrics, dict):
@@ -423,7 +265,7 @@ def _create_keras_model_fn(keras_model, custom_objects=None):
if not model._is_graph_network:
# Reset model state to original state,
# to avoid `model_fn` being destructive for the initial model argument.
- _in_place_subclassed_model_state_restoration(keras_model)
+ models.in_place_subclassed_model_state_restoration(keras_model)
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions=predictions,
@@ -487,8 +329,9 @@ def model_to_estimator(keras_model=None,
config=None):
"""Constructs an `Estimator` instance from given keras model.
- For usage example, please see
- @{$guide/estimators$creating_estimators_from_keras_models}.
+ For usage example, please see:
+ [Creating estimators from Keras
+ Models](https://tensorflow.org/guide/estimators#model_to_estimator).
Args:
keras_model: A compiled Keras model object. This argument is mutually
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 332e385726..290c4604ce 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -184,12 +184,14 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
gfile.MakeDirs(self._base_dir)
self._config = run_config_lib.RunConfig(
tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir)
+ super(TestKerasEstimator, self).setUp()
def tearDown(self):
# Make sure nothing is stuck in limbo.
writer_cache.FileWriterCache.clear()
if os.path.isdir(self._base_dir):
gfile.DeleteRecursively(self._base_dir)
+ super(TestKerasEstimator, self).tearDown()
def test_train(self):
for model_type in ['sequential', 'functional']:
@@ -511,19 +513,19 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
input_dict = {'input_1': x_train}
output_dict = {'invalid_output_name': y_train}
return input_dict, output_dict
-
model = simple_functional_model()
model.compile(
loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
with self.test_session():
est_keras = keras_lib.model_to_estimator(
keras_model=model, config=self._config)
-
with self.test_session():
- with self.assertRaises(ValueError):
+ with self.assertRaisesRegexp(KeyError,
+ 'Difference: .*invalid_input_name'):
est_keras.train(input_fn=invald_input_name_input_fn, steps=100)
- with self.assertRaises(ValueError):
+ with self.assertRaisesRegexp(KeyError,
+ 'Difference: .*invalid_output_name'):
est_keras.train(input_fn=invald_output_name_input_fn, steps=100)
def test_custom_objects(self):
diff --git a/tensorflow/python/estimator/model_fn.py b/tensorflow/python/estimator/model_fn.py
index 9db9ccd01d..fd2787aeaf 100644
--- a/tensorflow/python/estimator/model_fn.py
+++ b/tensorflow/python/estimator/model_fn.py
@@ -26,6 +26,7 @@ import six
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras.metrics import Metric
from tensorflow.python.ops import array_ops
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
@@ -141,13 +142,15 @@ class EstimatorSpec(
prediction.
predictions: Predictions `Tensor` or dict of `Tensor`.
loss: Training loss `Tensor`. Must be either scalar, or with shape `[1]`.
- train_op: Op for the training step.
- eval_metric_ops: Dict of metric results keyed by name. The values of the
- dict are the results of calling a metric function, namely a
- `(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated
- without any impact on state (typically is a pure computation results
- based on variables.). For example, it should not trigger the `update_op`
- or requires any input fetching.
+ train_op: Op to run one training step.
+ eval_metric_ops: Dict of metric results keyed by name.
+ The values of the dict can be one of the following:
+ (1) instance of `Metric` class.
+ (2) Results of calling a metric function, namely a
+ `(metric_tensor, update_op)` tuple. `metric_tensor` should be
+ evaluated without any impact on state (typically is a pure computation
+ results based on variables.). For example, it should not trigger the
+ `update_op` or requires any input fetching.
export_outputs: Describes the output signatures to be exported to
`SavedModel` and used during serving.
A dict `{name: output}` where:
@@ -218,21 +221,27 @@ class EstimatorSpec(
if not isinstance(eval_metric_ops, dict):
raise TypeError(
'eval_metric_ops must be a dict, given: {}'.format(eval_metric_ops))
- for key, metric_value_and_update in six.iteritems(eval_metric_ops):
- if (not isinstance(metric_value_and_update, tuple) or
- len(metric_value_and_update) != 2):
- raise TypeError(
- 'Values of eval_metric_ops must be (metric_value, update_op) '
- 'tuples, given: {} for key: {}'.format(
- metric_value_and_update, key))
- metric_value, metric_update = metric_value_and_update
- for metric_value_member in nest.flatten(metric_value):
- # Allow (possibly nested) tuples for metric values, but require that
- # each of them be Tensors or Operations.
- _check_is_tensor_or_operation(metric_value_member,
+ for key, value in six.iteritems(eval_metric_ops):
+ # TODO(psv): When we deprecate the old metrics, throw an error here if
+ # the value is not an instance of `Metric` class.
+ if isinstance(value, Metric):
+ if not value.updates: # Check if metrics updates are available.
+ raise ValueError(
+ 'Please call update_state(...) on the "{metric_name}" metric'
+ .format(metric_name=value.name))
+ else:
+ if not isinstance(value, tuple) or len(value) != 2:
+ raise TypeError(
+ 'Values of eval_metric_ops must be (metric_value, update_op) '
+ 'tuples, given: {} for key: {}'.format(value, key))
+ metric_value, metric_update = value
+ for metric_value_member in nest.flatten(metric_value):
+ # Allow (possibly nested) tuples for metric values, but require that
+ # each of them be Tensors or Operations.
+ _check_is_tensor_or_operation(metric_value_member,
+ 'eval_metric_ops[{}]'.format(key))
+ _check_is_tensor_or_operation(metric_update,
'eval_metric_ops[{}]'.format(key))
- _check_is_tensor_or_operation(metric_update,
- 'eval_metric_ops[{}]'.format(key))
# Validate the passed export outputs, or generate defaults.
if mode == ModeKeys.PREDICT:
@@ -267,8 +276,12 @@ class EstimatorSpec(
if train_op is not None and train_op.graph is not default_graph:
raise ValueError(error_message_template.format('train_op', train_op.name))
for key, value in list(six.iteritems(eval_metric_ops)):
- values = nest.flatten(value)
- for val in values:
+ if isinstance(value, Metric):
+ values_to_check = value.updates[:]
+ values_to_check.append(value.result())
+ else:
+ values_to_check = nest.flatten(value)
+ for val in values_to_check:
if val.graph is not default_graph:
raise ValueError(error_message_template.format(
'eval_metric_ops',
@@ -287,6 +300,19 @@ class EstimatorSpec(
'All hooks must be SessionRunHook instances, given: {}'.format(
hook))
+ # Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables
+ # are by default not added to any collections. We are doing this here, so
+ # that metric variables get initialized.
+ local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
+ vars_to_add = set()
+ for key, value in six.iteritems(eval_metric_ops):
+ if isinstance(value, Metric):
+ vars_to_add.update(value.variables)
+ # Remove variables that are in the local variables collection already.
+ vars_to_add = vars_to_add.difference(local_vars)
+ for v in vars_to_add:
+ ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v)
+
scaffold = scaffold or monitored_session.Scaffold()
# Validate scaffold.
if not isinstance(scaffold, monitored_session.Scaffold):
diff --git a/tensorflow/python/estimator/model_fn_test.py b/tensorflow/python/estimator/model_fn_test.py
index 08e41fd414..8a3a9f3f51 100644
--- a/tensorflow/python/estimator/model_fn_test.py
+++ b/tensorflow/python/estimator/model_fn_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.keras import metrics
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
@@ -48,7 +49,7 @@ class EstimatorSpecTrainTest(test.TestCase):
def testRequiredArgumentsSet(self):
"""Tests that no errors are raised when all required arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
loss=constant_op.constant(1.),
@@ -56,16 +57,21 @@ class EstimatorSpecTrainTest(test.TestCase):
def testAllArgumentsSet(self):
"""Tests that no errors are raised when all arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
predictions = {'loss': loss}
classes = constant_op.constant('hello')
+ metric_obj = metrics.Mean()
+ metric_obj.update_state(loss)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
predictions=predictions,
loss=loss,
train_op=control_flow_ops.no_op(),
- eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
+ eval_metric_ops={
+ 'loss': (control_flow_ops.no_op(), loss),
+ 'mean': metric_obj,
+ },
export_outputs={
'head_name': export_output.ClassificationOutput(classes=classes)
},
@@ -77,7 +83,7 @@ class EstimatorSpecTrainTest(test.TestCase):
def testLossNumber(self):
"""Tests that error is raised when loss is a number (not Tensor)."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
@@ -86,20 +92,20 @@ class EstimatorSpecTrainTest(test.TestCase):
def testLoss1DTensor(self):
"""Tests that no errors are raised when loss is 1D tensor."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
loss=constant_op.constant([1.]),
train_op=control_flow_ops.no_op())
def testLossMissing(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Missing loss'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN, train_op=control_flow_ops.no_op())
def testLossNotScalar(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
@@ -107,7 +113,7 @@ class EstimatorSpecTrainTest(test.TestCase):
train_op=control_flow_ops.no_op())
def testLossSparseTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = sparse_tensor.SparseTensor(
indices=[[0]],
values=[0.],
@@ -121,7 +127,7 @@ class EstimatorSpecTrainTest(test.TestCase):
def testLossFromDifferentGraph(self):
with ops.Graph().as_default():
loss = constant_op.constant(1.)
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
@@ -130,13 +136,13 @@ class EstimatorSpecTrainTest(test.TestCase):
train_op=control_flow_ops.no_op())
def testTrainOpMissing(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Missing train_op'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN, loss=constant_op.constant(1.))
def testTrainOpNotOperationAndTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(TypeError,
'train_op must be Operation or Tensor'):
model_fn.EstimatorSpec(
@@ -147,7 +153,7 @@ class EstimatorSpecTrainTest(test.TestCase):
def testTrainOpFromDifferentGraph(self):
with ops.Graph().as_default():
train_op = control_flow_ops.no_op()
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
@@ -156,7 +162,7 @@ class EstimatorSpecTrainTest(test.TestCase):
train_op=train_op)
def testTrainingChiefHookInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
@@ -166,7 +172,7 @@ class EstimatorSpecTrainTest(test.TestCase):
training_chief_hooks=[_InvalidHook()])
def testTrainingHookInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
@@ -176,7 +182,7 @@ class EstimatorSpecTrainTest(test.TestCase):
training_hooks=[_InvalidHook()])
def testScaffoldInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, r'scaffold must be tf\.train\.Scaffold'):
model_fn.EstimatorSpec(
@@ -186,7 +192,7 @@ class EstimatorSpecTrainTest(test.TestCase):
scaffold=_InvalidScaffold())
def testReturnDefaultScaffold(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
estimator_spec = model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.TRAIN,
loss=constant_op.constant(1.),
@@ -199,7 +205,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testRequiredArgumentsSet(self):
"""Tests that no errors are raised when all required arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -208,16 +214,21 @@ class EstimatorSpecEvalTest(test.TestCase):
def testAllArgumentsSet(self):
"""Tests that no errors are raised when all arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
predictions = {'loss': loss}
classes = constant_op.constant('hello')
+ metric_obj = metrics.Mean()
+ metric_obj.update_state(loss)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions=predictions,
loss=loss,
train_op=control_flow_ops.no_op(),
- eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
+ eval_metric_ops={
+ 'loss': (control_flow_ops.no_op(), loss),
+ 'mean': metric_obj,
+ },
export_outputs={
'head_name': export_output.ClassificationOutput(classes=classes)
},
@@ -227,7 +238,7 @@ class EstimatorSpecEvalTest(test.TestCase):
evaluation_hooks=[_FakeHook()])
def testEvaluationHookInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
@@ -237,7 +248,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testTupleMetric(self):
"""Tests that no errors are raised when a metric is tuple-valued."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -248,7 +259,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testLoss1DTensor(self):
"""Tests that no errors are raised when loss is 1D tensor."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant([1.])
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -257,7 +268,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testLossNumber(self):
"""Tests that error is raised when loss is a number (not Tensor)."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -265,14 +276,14 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=1.)
def testLossMissing(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Missing loss'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
predictions={'loss': constant_op.constant(1.)})
def testLossNotScalar(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant([1., 2.])
with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'):
model_fn.EstimatorSpec(
@@ -281,7 +292,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=loss)
def testLossSparseTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = sparse_tensor.SparseTensor(
indices=[[0]],
values=[0.],
@@ -296,7 +307,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testLossFromDifferentGraph(self):
with ops.Graph().as_default():
loss = constant_op.constant(1.)
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
@@ -305,7 +316,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=loss)
def testReplaceRaisesConstructorChecks(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
spec = model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
@@ -313,7 +324,7 @@ class EstimatorSpecEvalTest(test.TestCase):
spec._replace(loss=constant_op.constant([1., 2.]))
def testReplaceDoesReplace(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
spec = model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
@@ -321,7 +332,7 @@ class EstimatorSpecEvalTest(test.TestCase):
self.assertEqual(['m'], list(new_spec.predictions.keys()))
def testReplaceNotAllowModeChange(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
spec = model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)
@@ -331,13 +342,13 @@ class EstimatorSpecEvalTest(test.TestCase):
spec._replace(mode=model_fn.ModeKeys.TRAIN)
def testPredictionsMissingIsOkay(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL, loss=constant_op.constant(1.))
def testPredictionsTensor(self):
"""Tests that no error is raised when predictions is Tensor (not dict)."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.EVAL,
@@ -345,7 +356,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=loss)
def testPredictionsNumber(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, r'predictions\[number\] must be Tensor'):
model_fn.EstimatorSpec(
@@ -354,7 +365,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=constant_op.constant(1.))
def testPredictionsSparseTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {
'sparse': sparse_tensor.SparseTensor(
indices=[[0]],
@@ -370,7 +381,7 @@ class EstimatorSpecEvalTest(test.TestCase):
def testPredictionsFromDifferentGraph(self):
with ops.Graph().as_default():
predictions = {'loss': constant_op.constant(1.)}
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
model_fn.EstimatorSpec(
@@ -379,7 +390,7 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=constant_op.constant(1.))
def testEvalMetricOpsNoDict(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(
TypeError, 'eval_metric_ops must be a dict'):
@@ -390,7 +401,7 @@ class EstimatorSpecEvalTest(test.TestCase):
eval_metric_ops=loss)
def testEvalMetricOpsNoTuple(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(
TypeError,
@@ -403,7 +414,7 @@ class EstimatorSpecEvalTest(test.TestCase):
eval_metric_ops={'loss': loss})
def testEvalMetricOpsNoTensorOrOperation(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(TypeError, 'must be Operation or Tensor'):
model_fn.EstimatorSpec(
@@ -413,7 +424,7 @@ class EstimatorSpecEvalTest(test.TestCase):
eval_metric_ops={'loss': ('NonTensor', loss)})
def testEvalMetricNestedNoTensorOrOperation(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(TypeError, 'must be Operation or Tensor'):
model_fn.EstimatorSpec(
@@ -423,11 +434,26 @@ class EstimatorSpecEvalTest(test.TestCase):
eval_metric_ops={'loss': ((('NonTensor',),),
control_flow_ops.no_op())})
- def testEvalMetricOpsFromDifferentGraph(self):
+ def testEvalMetricOpsFromDifferentGraphWithMetricTuple(self):
with ops.Graph().as_default():
eval_metric_ops = {
'loss': (control_flow_ops.no_op(), constant_op.constant(1.))}
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
+ loss = constant_op.constant(1.)
+ with self.assertRaisesRegexp(
+ ValueError, 'must be from the default graph'):
+ model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL,
+ predictions={'loss': loss},
+ loss=loss,
+ eval_metric_ops=eval_metric_ops)
+
+ def testEvalMetricOpsFromDifferentGraphWithMetricObject(self):
+ with ops.Graph().as_default():
+ metric_obj = metrics.Mean()
+ metric_obj.update_state(constant_op.constant(1.))
+ eval_metric_ops = {'metric': metric_obj}
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
with self.assertRaisesRegexp(
ValueError, 'must be from the default graph'):
@@ -437,29 +463,46 @@ class EstimatorSpecEvalTest(test.TestCase):
loss=loss,
eval_metric_ops=eval_metric_ops)
+ def testEvalMetricOpsWithoutUpdates(self):
+ with ops.Graph().as_default():
+ eval_metric_ops = {'mean': metrics.Mean()}
+ with ops.Graph().as_default(), self.cached_session():
+ loss = constant_op.constant(1.)
+ with self.assertRaisesRegexp(ValueError, 'Please call update_state(...)'):
+ model_fn.EstimatorSpec(
+ mode=model_fn.ModeKeys.EVAL,
+ predictions={'loss': loss},
+ loss=loss,
+ eval_metric_ops=eval_metric_ops)
+
class EstimatorSpecInferTest(test.TestCase):
"""Tests EstimatorSpec in infer mode."""
def testRequiredArgumentsSet(self):
"""Tests that no errors are raised when all required arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT,
predictions={'loss': constant_op.constant(1.)})
def testAllArgumentsSet(self):
"""Tests that no errors are raised when all arguments are set."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
loss = constant_op.constant(1.)
predictions = {'loss': loss}
classes = constant_op.constant('hello')
+ metric_obj = metrics.Mean()
+ metric_obj.update_state(loss)
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
loss=loss,
train_op=control_flow_ops.no_op(),
- eval_metric_ops={'loss': (control_flow_ops.no_op(), loss)},
+ eval_metric_ops={
+ 'loss': (control_flow_ops.no_op(), loss),
+ 'mean': metric_obj,
+ },
export_outputs={
'head_name': export_output.ClassificationOutput(classes=classes)
},
@@ -470,7 +513,7 @@ class EstimatorSpecInferTest(test.TestCase):
prediction_hooks=[_FakeHook()])
def testPredictionHookInvalid(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, 'All hooks must be SessionRunHook instances'):
model_fn.EstimatorSpec(
@@ -479,25 +522,25 @@ class EstimatorSpecInferTest(test.TestCase):
prediction_hooks=[_InvalidHook()])
def testPredictionsMissing(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(ValueError, 'Missing predictions'):
model_fn.EstimatorSpec(mode=model_fn.ModeKeys.PREDICT)
def testPredictionsTensor(self):
"""Tests that no error is raised when predictions is Tensor (not dict)."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT, predictions=constant_op.constant(1.))
def testPredictionsNumber(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
with self.assertRaisesRegexp(
TypeError, r'predictions\[number\] must be Tensor'):
model_fn.EstimatorSpec(
mode=model_fn.ModeKeys.PREDICT, predictions={'number': 1.})
def testPredictionsSparseTensor(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {
'sparse': sparse_tensor.SparseTensor(
indices=[[0]],
@@ -509,7 +552,7 @@ class EstimatorSpecInferTest(test.TestCase):
mode=model_fn.ModeKeys.PREDICT, predictions=predictions)
def testExportOutputsNoDict(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
classes = constant_op.constant('hello')
with self.assertRaisesRegexp(
@@ -520,7 +563,7 @@ class EstimatorSpecInferTest(test.TestCase):
export_outputs=export_output.ClassificationOutput(classes=classes))
def testExportOutputsValueNotExportOutput(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
with self.assertRaisesRegexp(
TypeError,
@@ -533,7 +576,7 @@ class EstimatorSpecInferTest(test.TestCase):
export_outputs={'head_name': predictions})
def testExportOutputsSingleheadMissingDefault(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
output_1 = constant_op.constant([1.])
regression_output = export_output.RegressionOutput(value=output_1)
@@ -552,7 +595,7 @@ class EstimatorSpecInferTest(test.TestCase):
self.assertEqual(expected_export_outputs, estimator_spec.export_outputs)
def testExportOutputsMultiheadWithDefault(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
output_1 = constant_op.constant([1.])
output_2 = constant_op.constant(['2'])
@@ -571,7 +614,7 @@ class EstimatorSpecInferTest(test.TestCase):
self.assertEqual(export_outputs, estimator_spec.export_outputs)
def testExportOutputsMultiheadMissingDefault(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.)}
output_1 = constant_op.constant([1.])
output_2 = constant_op.constant(['2'])
@@ -594,13 +637,13 @@ class EstimatorSpecInferTest(test.TestCase):
def testDefaultExportOutputCreated(self):
"""Ensure that a default PredictOutput is created for export."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = constant_op.constant(1.)
self._assertDefaultExportOutputForPredictions(predictions)
def testDefaultExportOutputCreatedDict(self):
"""Ensure that a default PredictOutput is created for export for dicts."""
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
predictions = {'loss': constant_op.constant(1.),
'score': constant_op.constant(10.)}
self._assertDefaultExportOutputForPredictions(predictions)
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index 220c3e58ca..b1ca207b62 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -26,6 +26,7 @@ import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.distribute import estimator_training as distribute_coordinator_training
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat_internal
@@ -51,6 +52,7 @@ _DEFAULT_REPLACEABLE_LIST = [
'device_fn',
'protocol',
'eval_distribute',
+ 'experimental_distribute',
]
_SAVE_CKPT_ERR = (
@@ -331,7 +333,8 @@ class RunConfig(object):
train_distribute=None,
device_fn=None,
protocol=None,
- eval_distribute=None):
+ eval_distribute=None,
+ experimental_distribute=None):
"""Constructs a RunConfig.
All distributed training related properties `cluster_spec`, `is_chief`,
@@ -458,7 +461,8 @@ class RunConfig(object):
train_distribute: An optional instance of
`tf.contrib.distribute.DistributionStrategy`. If specified,
then Estimator will distribute the user's model during training,
- according to the policy specified by that strategy.
+ according to the policy specified by that strategy. Setting
+ `experimental_distribute.train_distribute` is preferred.
device_fn: A callable invoked for every `Operation` that takes the
`Operation` and returns the device string. If `None`, defaults to
the device function returned by `tf.train.replica_device_setter`
@@ -468,7 +472,13 @@ class RunConfig(object):
eval_distribute: An optional instance of
`tf.contrib.distribute.DistributionStrategy`. If specified,
then Estimator will distribute the user's model during evaluation,
- according to the policy specified by that strategy.
+ according to the policy specified by that strategy. Setting
+ `experimental_distribute.eval_distribute` is preferred.
+ experimental_distribute: an optional
+ `tf.contrib.distribute.DistributeConfig` object specifying
+ DistributionStrategy-related configuration. The `train_distribute` and
+ `eval_distribute` can be passed as parameters to `RunConfig` or set in
+ `experimental_distribute` but not both.
Raises:
ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
@@ -508,11 +518,15 @@ class RunConfig(object):
train_distribute=train_distribute,
device_fn=device_fn,
protocol=protocol,
- eval_distribute=eval_distribute)
+ eval_distribute=eval_distribute,
+ experimental_distribute=experimental_distribute)
- self._init_distributed_setting_from_environment_var(tf_config)
-
- self._maybe_overwrite_session_config_for_distributed_training()
+ if train_distribute or eval_distribute or experimental_distribute:
+ logging.info('Initializing RunConfig with distribution strategies.')
+ distribute_coordinator_training.init_run_config(self, tf_config)
+ else:
+ self._init_distributed_setting_from_environment_var(tf_config)
+ self._maybe_overwrite_session_config_for_distributed_training()
def _maybe_overwrite_session_config_for_distributed_training(self):
"""Overwrites the session_config for distributed training.
@@ -810,6 +824,7 @@ class RunConfig(object):
- `device_fn`,
- `protocol`.
- `eval_distribute`,
+ - `experimental_distribute`,
In addition, either `save_checkpoints_steps` or `save_checkpoints_secs`
can be set (should not be both).
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index a01b2300dd..240be5dabe 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -26,6 +26,7 @@ import time
import six
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.distribute import estimator_training as distribute_coordinator_training
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import exporter as exporter_lib
from tensorflow.python.estimator import run_config as run_config_lib
@@ -129,8 +130,8 @@ class TrainSpec(
Args:
input_fn: A function that provides input data for training as minibatches.
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
+ See [Premade Estimators](https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
the following:
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
tuple (features, labels) with same constraints as below.
@@ -193,8 +194,8 @@ class EvalSpec(
Args:
input_fn: A function that constructs the input data for evaluation.
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
+ See [Premade Estimators](https://tensorflow.org/api_guides/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
the following:
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
tuple (features, labels) with same constraints as below.
@@ -274,8 +275,10 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
evaluation `input_fn`, steps, etc.
This utility function provides consistent behavior for both local
- (non-distributed) and distributed configurations. Currently, the only
- supported distributed training configuration is between-graph replication.
+ (non-distributed) and distributed configurations. The default distribution
+ configuration is parameter server-based between-graph replication. For other
+ types of distribution configurations such as all-reduce training, please use
+ [DistributionStrategies](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute). # pylint: disable=line-too-long
Overfitting: In order to avoid overfitting, it is recommended to set up the
training `input_fn` to shuffle the training data properly.
@@ -323,6 +326,10 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
```
+ Note that in current implementation `estimator.evaluate` will be called
+ multiple times. This means that evaluation graph (including eval_input_fn)
+ will be re-created for each `evaluate` call. `estimator.train` will be called
+ only once.
Example of distributed training:
@@ -422,6 +429,11 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
}'
```
+ When `distribute` or `experimental_distribute.train_distribute` and
+ `experimental_distribute.remote_cluster` is set, this method will start a
+ client running on the current host which connects to the `remote_cluster` for
+ training and evaluation.
+
Args:
estimator: An `Estimator` instance to train and evaluate.
train_spec: A `TrainSpec` instance to specify the training specification.
@@ -440,8 +452,16 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
executor = _TrainingExecutor(
estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
-
config = estimator.config
+
+ # If `distribute_coordinator_mode` is set and running in distributed
+ # environment, we run `train_and_evaluate` via distribute coordinator.
+ if distribute_coordinator_training.should_run_distribute_coordinator(config):
+ logging.info('Running `train_and_evaluate` with Distribute Coordinator.')
+ distribute_coordinator_training.train_and_evaluate(
+ estimator, train_spec, eval_spec, _TrainingExecutor)
+ return
+
if (config.task_type == run_config_lib.TaskType.EVALUATOR and
config.task_id > 0):
raise ValueError(
@@ -833,6 +853,13 @@ class _TrainingExecutor(object):
if difference > 0:
logging.info('Waiting %f secs before starting next eval run.', difference)
time.sleep(difference)
+ elif (throttle_secs == 0 and
+ eval_result.status != _EvalStatus.EVALUATED):
+ # Prints a user-actionable warning to avoid unnecessary load on evaluator.
+ logging.warning(
+ 'EvalSpec.throttle_secs is set as 0. This might overload the job '
+ 'before finding (next) new checkpoint. Please consider to increase '
+ 'it.')
return (eval_result, should_early_stop)
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index dc106c7d3b..7d46917a6f 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -83,6 +83,9 @@ _INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`'
_INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG'
_INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`'
_INVALID_TASK_TYPE = '`estimator.config` must have task_type set.'
+_INPROPER_THROTTL_SECS = (
+ 'EvalSpec.throttle_secs is set as 0.*Please consider to increase')
+
# The message should NOT have 'local' word as part of it. As (?!word) is looking
# ahead, so, the $ (ending) check is required; otherwise, it will match
# partially and return successuful.
@@ -1281,7 +1284,7 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
]
eval_spec = training.EvalSpec(
- input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
+ input_fn=lambda: 1, start_delay_secs=0, throttle_secs=2)
executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
with test.mock.patch.object(logging, 'warning') as mock_log:
@@ -1295,6 +1298,34 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
# successuful evaluation)
self.assertEqual(2, mock_log.call_count)
+ def test_warning_if_throttle_secs_is_zero(self):
+ training_max_step = 200
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.evaluate.side_effect = [
+ {_GLOBAL_STEP_KEY: training_max_step}
+ ]
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_train_spec.max_steps = training_max_step
+
+ self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)
+
+ # We need to make the first one invalid, so it will check the
+ # throttle_secs=0.
+ mock_est.latest_checkpoint.side_effect = [None, 'path']
+
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
+
+ executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
+ with test.mock.patch.object(logging, 'warning') as mock_log:
+ executor.run_evaluator()
+
+ # First ckpt is invalid.
+ self.assertEqual(2, mock_est.latest_checkpoint.call_count)
+ self.assertEqual(1, mock_est.evaluate.call_count)
+
+ self.assertRegexpMatches(str(mock_log.call_args), _INPROPER_THROTTL_SECS)
+
def test_continuous_eval_listener_eval_result(self):
training_max_step = 200
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index d4a75478d5..31e4778e72 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -109,13 +109,17 @@ def parse_input_fn_result(result):
else:
input_hooks.append(_DatasetInitializerHook(iterator))
result = iterator.get_next()
+ return parse_iterator_result(result) + (input_hooks,)
+
+def parse_iterator_result(result):
+ """Gets features, labels from result."""
if isinstance(result, (list, tuple)):
if len(result) != 2:
raise ValueError(
'input_fn should return (features, labels) as a len 2 tuple.')
- return result[0], result[1], input_hooks
- return result, None, input_hooks
+ return result[0], result[1]
+ return result, None
class _DatasetInitializerHook(training.SessionRunHook):
diff --git a/tensorflow/python/estimator/util_test.py b/tensorflow/python/estimator/util_test.py
index d7e0610779..d440c454dc 100644
--- a/tensorflow/python/estimator/util_test.py
+++ b/tensorflow/python/estimator/util_test.py
@@ -39,7 +39,7 @@ class UtilTest(test.TestCase):
features, labels, hooks = util.parse_input_fn_result(_input_fn())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
vals = sess.run([features, labels])
self.assertAllEqual(vals[0], np.arange(100))
@@ -67,7 +67,7 @@ class UtilTest(test.TestCase):
features, labels, hooks = util.parse_input_fn_result(_input_fn())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
vals = sess.run([features])
self.assertAllEqual(vals[0], np.arange(100))