author	Yan Facai (颜发才) <facai.yan@gmail.com>	2018-08-31 17:13:04 +0800
committer	Yan Facai (颜发才) <facai.yan@gmail.com>	2018-08-31 17:13:04 +0800
commitb3114e5b1e930c4dd1a1fdfaac721a219677d611 (patch)
tree7bfb4c4e8a9a158e86752a73117559df3d0386c1 /tensorflow/contrib/estimator
parentf8ee9799e6a72d4fe24f9fad76d6e6b1b3a01af1 (diff)
parent9357b2558adc13c479c8edb66c5002c5c6ec3664 (diff)
Merge remote-tracking branch 'upstream/master' into ENH/feature_importances_for_boosted_tree
Diffstat (limited to 'tensorflow/contrib/estimator')
tensorflow/contrib/estimator/BUILD | 31
tensorflow/contrib/estimator/__init__.py | 1
tensorflow/contrib/estimator/python/estimator/baseline_test.py | 6
tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py | 2
tensorflow/contrib/estimator/python/estimator/export.py | 4
tensorflow/contrib/estimator/python/estimator/exporter.py | 280
tensorflow/contrib/estimator/python/estimator/exporter_test.py | 206
tensorflow/contrib/estimator/python/estimator/extenders.py | 39
tensorflow/contrib/estimator/python/estimator/extenders_test.py | 129
tensorflow/contrib/estimator/python/estimator/head_test.py | 58
tensorflow/contrib/estimator/python/estimator/hooks.py | 75
tensorflow/contrib/estimator/python/estimator/hooks_test.py | 83
tensorflow/contrib/estimator/python/estimator/linear.py | 2
tensorflow/contrib/estimator/python/estimator/multi_head_test.py | 22
tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py | 92
15 files changed, 894 insertions, 136 deletions
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 349f48f7f7..77f62df99d 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -20,6 +20,7 @@ py_library(
":dnn_linear_combined",
":early_stopping",
":export",
+ ":exporter",
":extenders",
":head",
":hooks",
@@ -220,6 +221,33 @@ py_test(
)
py_library(
+ name = "exporter",
+ srcs = [
+ "python/estimator/exporter.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:summary",
+ "//tensorflow/python/estimator:exporter",
+ ],
+)
+
+py_test(
+ name = "exporter_test",
+ size = "medium",
+ srcs = ["python/estimator/exporter_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":exporter",
+ "//tensorflow/python:platform",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:exporter",
+ ],
+)
+
+py_library(
name = "head",
srcs = [
"python/estimator/head.py",
@@ -487,6 +515,9 @@ py_test(
size = "medium",
srcs = ["python/estimator/saved_model_estimator_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "notsan",
+ ],
deps = [
":export",
":saved_model_estimator",
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index e1453ae1d0..258860f263 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -45,6 +45,7 @@ _allowed_symbols = [
'clip_gradients_by_norm',
'forward_features',
'InMemoryEvaluatorHook',
+ 'make_stop_at_checkpoint_step_hook',
'logistic_regression_head',
'multi_class_head',
'multi_head',
diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
index 505c94e971..513feb03b6 100644
--- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
@@ -37,13 +37,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer
from tensorflow.python.training import saver
@@ -339,7 +339,7 @@ class BaselineEstimatorTrainingTest(test.TestCase):
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
@@ -347,7 +347,7 @@ class BaselineEstimatorTrainingTest(test.TestCase):
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
mock_optimizer = test.mock.NonCallableMock(
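
The two hunks above swap `distribute_lib.increment_var` for `state_ops.assign_add(global_step, 1).op`. A minimal sketch of the replacement pattern, assuming a TF 1.x graph-mode session (illustrative only, not part of the diff):

```python
import tensorflow as tf

# Create the global step variable and the increment op used by the test.
global_step = tf.train.get_or_create_global_step()
increment_global_step = tf.assign_add(global_step, 1).op

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(increment_global_step)
  print(sess.run(global_step))  # 1
```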
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
index 2eef60c39f..724bc2c82f 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
@@ -147,7 +147,7 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
if a categorical column is multivalent. One of "mean", "sqrtn", and
"sum" -- these are effectively different ways to do example-level
normalization, which can be useful for bag-of-words features. For more
- details, see @{tf.feature_column.linear_model$linear_model}.
+ details, see `tf.feature_column.linear_model`.
Raises:
ValueError: If both linear_feature_columns and dnn_features_columns are
diff --git a/tensorflow/contrib/estimator/python/estimator/export.py b/tensorflow/contrib/estimator/python/estimator/export.py
index 03cf6f107c..b0deb9b494 100644
--- a/tensorflow/contrib/estimator/python/estimator/export.py
+++ b/tensorflow/contrib/estimator/python/estimator/export.py
@@ -31,8 +31,8 @@ def export_saved_model_for_mode(
# pylint: disable=line-too-long
"""Exports a single train/eval/predict graph as a SavedModel.
- For a detailed guide, see
- @{$saved_model#using_savedmodel_with_estimators$Using SavedModel with Estimators}.
+ For a detailed guide, see [Using SavedModel with Estimators](
+ https://tensorflow.org/guide/saved_model#using_savedmodel_with_estimators).
Sample usage:
```python
diff --git a/tensorflow/contrib/estimator/python/estimator/exporter.py b/tensorflow/contrib/estimator/python/estimator/exporter.py
new file mode 100644
index 0000000000..09d7440605
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/exporter.py
@@ -0,0 +1,280 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implements StepsExporter to export the model in user specified steps."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.estimator import exporter
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.summary import summary_iterator
+
+DEFAULT_GLOBAL_STEP_KEY = ops.GraphKeys.GLOBAL_STEP
+
+
+class StepsExporter(exporter.Exporter):
+ """This class exports the model in user specified steps.
+
+ This class exports the model at the steps given by the `steps_to_keep`
+ argument. Each number in the list is treated as a lower bound for model
+ exports, to handle the case when evaluation is performed at different steps.
+
+ Consider this example:
+
+ ```
+ steps_to_keep = [1, 2, 3, 6, 7, 10, 12, 25]
+ ```
+
+ The model is evaluated at step increments of 5: `[5, 10, 15, 20, 25, 30]`.
+ The `StepsExporter` will export the model when it has reached steps
+ `[5, 10, 15, 25]`.
+
+ This example illustrates the two cases when the model is exported:
+
+ 1. Model is evaluated on a step defined in the list `steps_to_keep`.
+
+ In the example, the model is exported on step `10` and `25`.
+
+ 2. Model is evaluated on a step not defined in the list `steps_to_keep`, but
+ is still exported because a step in `steps_to_keep` was missed.
+
+ In the example, when the model reaches step `5`, the model is exported even
+ though `steps_to_keep` does not contain `5`. Step `5` is exported to make
+ up for step `3`, which was missed. Steps `1` and `2` in `steps_to_keep` are
+ skipped completely (e.g. even if the model is next evaluated at step `6`, it
+ will **not** be exported again to make up for step `2`).
+
+ Using the `steps_to_keep` list as a lower bound allows users to define
+ approximate step boundaries for exporting their models, and avoid frustrating
+ off-by-one calculation errors.
+
+ Sample Use Cases:
+ There are specific points during training when having a saved version of
+ the model would be useful. One example is the end of each training phase,
+ when the set of frozen weights changes.
+ Another good use case is saving the model at the end of each epoch for
+ visualization or retraining.
+ """
+
+ def __init__(self,
+ steps_to_keep,
+ name='steps_exporter',
+ serving_input_receiver_fn=None,
+ event_file_pattern='eval/*.tfevents.*',
+ assets_extra=None,
+ as_text=False):
+ """Create an `StepsExporter` to use with `tf.estimator.EvalSpec`.
+
+ Example of creating a StepsExporter for training and evaluation:
+
+ ```python
+ categorical_feature_a = categorical_column_with_hash_bucket(...)
+ categorical_feature_b = categorical_column_with_hash_bucket(...)
+
+ categorical_feature_a_emb = embedding_column(
+ categorical_column=categorical_feature_a, ...)
+ categorical_feature_b_emb = embedding_column(
+ categorical_column=categorical_feature_b, ...)
+
+ estimator = tf.estimator.DNNClassifier(
+ feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
+ hidden_units=[1024, 512, 256])
+
+ # Input pipeline for train and evaluate.
+ def train_input_fn(): # returns x, y
+ # please shuffle the data.
+ pass
+ def eval_input_fn(): # returns x, y
+ pass
+
+ exporter = tf.contrib.estimator.exporter.StepsExporter(
+ name="steps_exporter",
+ serving_input_receiver_fn=serving_input_receiver_fn,
+ event_file_pattern='eval/*.tfevents.*',
+ steps_to_keep=[...])
+
+ train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)
+
+ eval_spec = [tf.estimator.EvalSpec(
+ input_fn=eval_input_fn,
+ steps=1,
+ exporters=exporter,
+ start_delay_secs=0,
+ throttle_secs=5)]
+
+ tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
+
+ # Models will be exported to estimator.model_dir in timestamped directories,
+ # which can be used for serving, analysis with TFMA, or directly loaded in.
+ # For example:
+ export_dir = os.path.join(estimator.model_dir,
+ <timestamped directory name>)
+
+ with ops.Graph().as_default() as graph:
+ with session.Session(graph=graph) as sess:
+ tf.saved_model.loader.load(
+ sess, [tf.saved_model.tag_constants.SERVING], export_dir)
+
+ ```
+
+ Args:
+ steps_to_keep: Non-empty list of positive integers containing
+ the step numbers at which the model should be exported. All the exports
+ will be kept, so there is no garbage collection.
+ name: Unique name of this `Exporter` that is going to be used in the
+ export path.
+ serving_input_receiver_fn: A function that takes no arguments and returns
+ a `ServingInputReceiver`.
+ event_file_pattern: Event file name pattern relative to model_dir. If
+ None, the exporter is not preemption-safe; specify event_file_pattern
+ to make the exporter preemption-safe.
+ assets_extra: An optional dict specifying how to populate the assets.extra
+ directory within the exported SavedModel. Each key should give the
+ destination path (including the filename) relative to the assets.extra
+ directory. The corresponding value gives the full path of the source
+ file to be copied. For example, the simple case of copying a single
+ file without renaming it is specified as `{'my_asset_file.txt':
+ '/path/to/my_asset_file.txt'}`.
+ as_text: Whether to write the SavedModel proto in text format. Defaults to
+ `False`.
+
+ Raises:
+ ValueError: If any argument is invalid.
+ """
+ # pylint: disable=protected-access
+ self._saved_model_exporter = exporter._SavedModelExporter(
+ name, serving_input_receiver_fn, assets_extra, as_text)
+ # pylint: enable=protected-access
+
+ self._event_file_pattern = event_file_pattern
+ self._model_dir = None
+
+ self._input_steps_to_keep = steps_to_keep
+ steps_to_keep = [step for step in steps_to_keep if isinstance(step, int)]
+ steps_to_keep = [step for step in steps_to_keep if step > 0]
+ if not steps_to_keep:
+ raise ValueError(
+ '`steps_to_keep` list must have at least one positive integer')
+ elif self._input_steps_to_keep != steps_to_keep:
+ tf_logging.warn('Changed `steps_to_keep` by omitting non-integer elements'
+ ' and elements less than 1; kept [%s]',
+ ', '.join(str(step) for step in steps_to_keep))
+ self._steps_to_keep = sorted(steps_to_keep)
+ self._steps_kept = []
+
+ @property
+ def name(self):
+ return self._saved_model_exporter.name
+
+ def export(self, estimator, export_path, checkpoint_path, eval_result,
+ is_the_final_export):
+ """Exports the given Estimator to a specific format.
+
+ Args:
+ estimator: A `tf.estimator.Estimator` instance to export.
+ export_path: A string containing a directory where to write the export.
+ checkpoint_path: The checkpoint path to export.
+ eval_result: The output of Estimator.evaluate on this checkpoint.
+ is_the_final_export: This boolean is True when this is an export at the
+ end of training. It is False for intermediate exports during
+ training. When passing an Exporter to tf.estimator.train_and_evaluate,
+ is_the_final_export is always False if TrainSpec.max_steps is None.
+
+ Returns:
+ The string path to the exported directory or None if export is skipped.
+
+ Raises:
+ ValueError: If `eval_result` is None or doesn't have
+ `ops.GraphKeys.GLOBAL_STEP` as a key.
+ """
+ export_result = None
+
+ if not eval_result or DEFAULT_GLOBAL_STEP_KEY not in eval_result:
+ raise ValueError(
+ '`eval_result` is empty, or does not have global step. This'
+ ' should never happen as Estimator always sets the global step in '
+ '`eval_result`. Please file a bug report. Got eval_result: %s'
+ % str(eval_result))
+
+ if self._model_dir != estimator.model_dir and self._event_file_pattern:
+ tf_logging.info('Loading the steps at which the model was already'
+ ' evaluated, from event files.')
+ self._model_dir = estimator.model_dir
+ full_event_file_pattern = os.path.join(self._model_dir,
+ self._event_file_pattern)
+ self._steps_kept = self._get_kept_steps(full_event_file_pattern)
+
+ if self._steps_kept:
+ self._steps_kept = sorted(self._steps_kept)
+ self._steps_to_keep = [step for step in self._steps_to_keep if
+ step > self._steps_kept[-1]]
+ # It is assumed that the model is exported at any evaluated step 'n' if
+ # there is any `steps_missed` lower than 'n'. As a result, all the steps in
+ # `_steps_to_keep` lower than the last evaluated step will be removed.
+ steps_missed = [step for step in self._steps_to_keep
+ if step <= eval_result[DEFAULT_GLOBAL_STEP_KEY]]
+
+ if steps_missed:
+ # Export once to cover all the missed steps, then remove them from
+ # `_steps_to_keep`, since this export accounts for them.
+ export_result = self._saved_model_exporter.export(estimator, export_path,
+ checkpoint_path,
+ eval_result,
+ is_the_final_export)
+ self._steps_to_keep = [step for step in self._steps_to_keep if step
+ not in steps_missed]
+ # `_steps_kept` records every step at which an export has happened.
+ self._steps_kept.append(eval_result[DEFAULT_GLOBAL_STEP_KEY])
+ # Show warning for all the missed steps except the last one
+ if steps_missed[:-1]:
+ tf_logging.warn('Missed steps [%s] for exporting, as no evaluation'
+ ' took place at them.', ', '.join(str(step) for step in
+ steps_missed[:-1]))
+ # Log model export if the last missed step is the same as the current step
+ if steps_missed[-1] == eval_result[DEFAULT_GLOBAL_STEP_KEY]:
+ tf_logging.info('Performing model export at step %d.',
+ eval_result[DEFAULT_GLOBAL_STEP_KEY])
+ # Warn when the model is exported at a step other than the
+ # user-specified one
+ else:
+ tf_logging.warn('Performing model export at step %d instead of %d, as'
+ ' no evaluation took place at step %d.',
+ eval_result[DEFAULT_GLOBAL_STEP_KEY], steps_missed[-1],
+ steps_missed[-1])
+ return export_result
+
+ def _get_kept_steps(self, event_files):
+ """Get the steps that the model was evaluated at, from event files.
+
+ Args:
+ event_files: Absolute pattern of event files.
+
+ Returns:
+ steps_kept: A list of steps in which the model was evaluated.
+ """
+ if not event_files:
+ return None
+
+ steps_kept = []
+ for event_file in gfile.Glob(event_files):
+ for event in summary_iterator.summary_iterator(event_file):
+ if event.step not in steps_kept:
+ steps_kept.append(event.step)
+ return steps_kept
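
A minimal pure-Python sketch of the lower-bound selection rule that `StepsExporter.export` implements above (`should_export` is a hypothetical helper, not part of the diff): every entry of `steps_to_keep` at or below the evaluated step counts as missed and triggers a single export, after which those entries are dropped.

```python
def should_export(steps_to_keep, evaluated_step):
  # Steps at or below the evaluated step are treated as missed.
  steps_missed = [s for s in steps_to_keep if s <= evaluated_step]
  remaining = [s for s in steps_to_keep if s > evaluated_step]
  return bool(steps_missed), remaining

# Reproduces the docstring example: evaluating at multiples of 5 against
# steps_to_keep = [1, 2, 3, 6, 7, 10, 12, 25] exports at [5, 10, 15, 25].
steps, exported = [1, 2, 3, 6, 7, 10, 12, 25], []
for evaluated in [5, 10, 15, 20, 25, 30]:
  do_export, steps = should_export(steps, evaluated)
  if do_export:
    exported.append(evaluated)
assert exported == [5, 10, 15, 25]
```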
diff --git a/tensorflow/contrib/estimator/python/estimator/exporter_test.py b/tensorflow/contrib/estimator/python/estimator/exporter_test.py
new file mode 100644
index 0000000000..0d009b945e
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/exporter_test.py
@@ -0,0 +1,206 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `StepsExporter`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+import tempfile
+
+from tensorflow.contrib.estimator.python.estimator import exporter as exporter_lib
+from tensorflow.python.estimator import estimator as estimator_lib
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+
+
+class StepsExporterTest(test.TestCase):
+
+ def test_error_out_if_steps_to_keep_has_no_positive_integers(self):
+
+ def _serving_input_receiver_fn():
+ pass
+
+ with self.assertRaisesRegexp(ValueError, "positive integer"):
+ exporter = exporter_lib.StepsExporter(
+ name="specified_steps_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ steps_to_keep=[-1, 0, 1.1])
+ self.assertEqual("specified_steps_exporter", exporter.name)
+
+ def test_steps_exporter(self):
+
+ def _serving_input_receiver_fn():
+ pass
+
+ export_dir_base = tempfile.mkdtemp()
+ gfile.MkDir(export_dir_base)
+ gfile.MkDir(export_dir_base + "/export")
+ gfile.MkDir(export_dir_base + "/eval")
+
+ exporter = exporter_lib.StepsExporter(
+ name="steps_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ assets_extra={"from/path": "to/path"},
+ as_text=False,
+ steps_to_keep=[1])
+ estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+ estimator.export_savedmodel.return_value = "export_result_path"
+ estimator.model_dir = export_dir_base
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 1},
+ False)
+
+ self.assertEqual("export_result_path", export_result)
+ estimator.export_savedmodel.assert_called_with(
+ export_dir_base,
+ _serving_input_receiver_fn,
+ assets_extra={"from/path": "to/path"},
+ as_text=False,
+ checkpoint_path="checkpoint_path",
+ strip_default_attrs=True)
+
+ shutil.rmtree(export_dir_base, ignore_errors=True)
+
+ def test_steps_exporter_with_preemption(self):
+
+ def _serving_input_receiver_fn():
+ pass
+
+ export_dir_base = tempfile.mkdtemp()
+ gfile.MkDir(export_dir_base)
+ gfile.MkDir(export_dir_base + "/export")
+ gfile.MkDir(export_dir_base + "/eval")
+
+ eval_dir_base = os.path.join(export_dir_base, "eval_continuous")
+ estimator_lib._write_dict_to_summary(eval_dir_base, {}, 1)
+ estimator_lib._write_dict_to_summary(eval_dir_base, {}, 2)
+
+ exporter = exporter_lib.StepsExporter(
+ name="steps_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ event_file_pattern="eval_continuous/*.tfevents.*",
+ assets_extra={"from/path": "to/path"},
+ as_text=False,
+ steps_to_keep=[1, 2, 6, 8])
+
+ estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+ estimator.model_dir = export_dir_base
+ estimator.export_savedmodel.return_value = "export_result_path"
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 3},
+ False)
+ self.assertEqual(None, export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 6},
+ False)
+ self.assertEqual("export_result_path", export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 7},
+ False)
+ self.assertEqual(None, export_result)
+
+ shutil.rmtree(export_dir_base, ignore_errors=True)
+
+ def test_specified_step_is_saved(self):
+
+ def _serving_input_receiver_fn():
+ pass
+
+ export_dir_base = tempfile.mkdtemp()
+ gfile.MkDir(export_dir_base)
+ gfile.MkDir(export_dir_base + "/export")
+ gfile.MkDir(export_dir_base + "/eval")
+
+ exporter = exporter_lib.StepsExporter(
+ name="steps_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ assets_extra={"from/path": "to/path"},
+ as_text=False,
+ steps_to_keep=[1, 5, 8, 10, 11])
+ estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+ estimator.export_savedmodel.return_value = "export_result_path"
+ estimator.model_dir = export_dir_base
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 1},
+ False)
+
+ self.assertTrue(estimator.export_savedmodel.called)
+ self.assertEqual("export_result_path", export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 2},
+ False)
+ self.assertEqual(None, export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 5},
+ False)
+ self.assertTrue(estimator.export_savedmodel.called)
+ self.assertEqual("export_result_path", export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 10},
+ False)
+ self.assertTrue(estimator.export_savedmodel.called)
+ self.assertEqual("export_result_path", export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 15},
+ False)
+ self.assertTrue(estimator.export_savedmodel.called)
+ self.assertEqual("export_result_path", export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {"global_step": 20},
+ False)
+ self.assertEqual(None, export_result)
+
+ shutil.rmtree(export_dir_base, ignore_errors=True)
+
+ def test_steps_exporter_with_no_global_step_key(self):
+
+ def _serving_input_receiver_fn():
+ pass
+
+ export_dir_base = tempfile.mkdtemp()
+ gfile.MkDir(export_dir_base)
+ gfile.MkDir(export_dir_base + "/export")
+ gfile.MkDir(export_dir_base + "/eval")
+
+ exporter = exporter_lib.StepsExporter(
+ name="steps_exporter",
+ serving_input_receiver_fn=_serving_input_receiver_fn,
+ assets_extra={"from/path": "to/path"},
+ as_text=False,
+ steps_to_keep=[1])
+ estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+ estimator.export_savedmodel.return_value = "export_result_path"
+ estimator.model_dir = export_dir_base
+
+ with self.assertRaisesRegexp(ValueError, "does not have global step"):
+ exporter.export(estimator, export_dir_base, "checkpoint_path", {}, False)
+
+ shutil.rmtree(export_dir_base, ignore_errors=True)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py
index bf08be09e7..e3c44bea66 100644
--- a/tensorflow/contrib/estimator/python/estimator/extenders.py
+++ b/tensorflow/contrib/estimator/python/estimator/extenders.py
@@ -26,6 +26,7 @@ from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.util import function_utils
@@ -34,7 +35,7 @@ _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config'])
def add_metrics(estimator, metric_fn):
- """Creates a new @{tf.estimator.Estimator} which has given metrics.
+ """Creates a new `tf.estimator.Estimator` which has given metrics.
Example:
@@ -61,7 +62,7 @@ def add_metrics(estimator, metric_fn):
```
Args:
- estimator: A @{tf.estimator.Estimator} object.
+ estimator: A `tf.estimator.Estimator` object.
metric_fn: A function which should obey the following signature:
- Args: can only have following four arguments in any order:
* predictions: Predictions `Tensor` or dict of `Tensor` created by given
@@ -79,7 +80,7 @@ def add_metrics(estimator, metric_fn):
function, namely a `(metric_tensor, update_op)` tuple.
Returns:
- A new @{tf.estimator.Estimator} which has a union of original metrics with
+ A new `tf.estimator.Estimator` which has a union of original metrics with
given ones.
"""
_verify_metric_fn_args(metric_fn)
@@ -140,7 +141,7 @@ def clip_gradients_by_norm(optimizer, clip_norm):
name='ClipByNorm' + optimizer.get_name())
-def forward_features(estimator, keys=None):
+def forward_features(estimator, keys=None, sparse_default_values=None):
"""Forward features to predictions dictionary.
In some cases, user wants to see some of the features in estimators prediction
@@ -148,39 +149,36 @@ def forward_features(estimator, keys=None):
runs inference on the users graph and returns the results. Keys are essential
because there is no order guarantee on the outputs so they need to be rejoined
to the inputs via keys or transclusion of the inputs in the outputs.
-
Example:
-
```python
def input_fn():
features, labels = ...
features['unique_example_id'] = ...
features, labels
-
estimator = tf.estimator.LinearClassifier(...)
estimator = tf.contrib.estimator.forward_features(
estimator, 'unique_example_id')
estimator.train(...)
assert 'unique_example_id' in estimator.predict(...)
```
-
Args:
- estimator: A @{tf.estimator.Estimator} object.
- keys: a `string` or a `list` of `string`. If it is `None`, all of the
+ estimator: A `tf.estimator.Estimator` object.
+ keys: A `string` or a `list` of `string`. If it is `None`, all of the
`features` in `dict` is forwarded to the `predictions`. If it is a
`string`, only given key is forwarded. If it is a `list` of strings, all
the given `keys` are forwarded.
+ sparse_default_values: A dict mapping the names of the sparse features
+ to be converted to dense to the default values to use. Only the sparse
+ features listed in the dictionary are converted to dense, using the
+ provided default values.
Returns:
- A new @{tf.estimator.Estimator} which forwards features to predictions.
-
+ A new `tf.estimator.Estimator` which forwards features to predictions.
Raises:
ValueError:
* if `keys` is already part of `predictions`. We don't allow
override.
* if 'keys' does not exist in `features`.
- * if feature key refers to a `SparseTensor`, since we don't support
- `SparseTensor` in `predictions`. `SparseTensor` is common in `features`.
TypeError: if `keys` type is not one of `string` or list/tuple of `string`.
"""
@@ -231,11 +229,18 @@ def forward_features(estimator, keys=None):
for key in get_keys(features):
feature = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
features[key])
+ if sparse_default_values and (key in sparse_default_values):
+ if not isinstance(feature, sparse_tensor_lib.SparseTensor):
+ raise ValueError(
+ 'Feature ({}) is expected to be a `SparseTensor`.'.format(key))
+ feature = sparse_ops.sparse_tensor_to_dense(
+ feature, default_value=sparse_default_values[key])
if not isinstance(feature, ops.Tensor):
raise ValueError(
- 'Forwarded feature ({}) should be a Tensor. Please use keys '
- 'argument of forward_features to filter unwanted features. Type of '
- 'features[{}] is {}.'.format(key, key, type(feature)))
+ 'Feature ({}) should be a Tensor. Please use the `keys` '
+ 'argument of forward_features to filter unwanted features, or '
+ 'add the key to the `sparse_default_values` argument. '
+ 'Type of features[{}] is {}.'.format(key, key, type(feature)))
predictions[key] = feature
spec = spec._replace(predictions=predictions)
if spec.export_outputs:
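
The new `sparse_default_values` path densifies a forwarded `SparseTensor` before adding it to `predictions`. A small sketch of that conversion in isolation, assuming TF 1.x graph mode (illustrative only, not part of the diff):

```python
import tensorflow as tf

# The same sparse feature shape used in the tests below.
sparse_id = tf.SparseTensor(
    values=[1, 2, 3], indices=[[0, 0], [1, 0], [1, 1]], dense_shape=[2, 2])
# forward_features applies this conversion when the key appears in
# sparse_default_values.
dense_id = tf.sparse_tensor_to_dense(sparse_id, default_value=0)

with tf.Session() as sess:
  print(sess.run(dense_id))  # [[1 0]
                             #  [2 3]]
```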
diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py
index 407af2deaf..c8fdaa8791 100644
--- a/tensorflow/contrib/estimator/python/estimator/extenders_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""extenders tests."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -23,6 +24,7 @@ import tempfile
import numpy as np
from tensorflow.contrib.estimator.python.estimator import extenders
+from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.predictor import from_saved_model
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator_lib
@@ -170,19 +172,53 @@ class ClipGradientsByNormTest(test.TestCase):
class ForwardFeaturesTest(test.TestCase):
"""Tests forward_features."""
- def test_forward_single_key(self):
-
- def input_fn():
- return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]]
+ def _export_estimator(self, estimator, serving_input_fn):
+ tmpdir = tempfile.mkdtemp()
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+ return export_dir, tmpdir
+ def make_dummy_input_fn(self):
+ def _input_fn():
+ dataset = dataset_ops.Dataset.from_tensors({
+ 'x': [[3.], [5.]],
+ 'id': [[101], [102]],
+ 'sparse_id': sparse_tensor.SparseTensor(
+ values=[1, 2, 3],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ 'labels': [[1.], [2.]]
+ })
+ def _split(x):
+ labels = x.pop('labels')
+ return x, labels
+ dataset = dataset.map(_split)
+ return dataset
+ return _input_fn
+
+ def test_forward_keys(self):
+
+ input_fn = self.make_dummy_input_fn()
estimator = linear.LinearRegressor([fc.numeric_column('x')])
estimator.train(input_fn=input_fn, steps=1)
- self.assertNotIn('id', next(estimator.predict(input_fn=input_fn)))
- estimator = extenders.forward_features(estimator, 'id')
- predictions = next(estimator.predict(input_fn=input_fn))
- self.assertIn('id', predictions)
- self.assertEqual(101, predictions['id'])
+ forwarded_keys = ['id', 'sparse_id']
+
+ for key in forwarded_keys:
+ self.assertNotIn(key, next(estimator.predict(input_fn=input_fn)))
+
+ estimator = extenders.forward_features(
+ estimator, forwarded_keys, sparse_default_values={'sparse_id': 1})
+
+ expected_results = [101, 2, 102, 5]
+ predictions = estimator.predict(input_fn=input_fn)
+ for _ in range(2):
+ prediction = next(predictions)
+ for key in forwarded_keys:
+ self.assertIn(key, prediction)
+ self.assertEqual(expected_results.pop(0), sum(prediction[key]))
def test_forward_in_exported(self):
@@ -205,11 +241,7 @@ class ForwardFeaturesTest(test.TestCase):
estimator = extenders.forward_features(estimator, 'id')
# export saved model
- tmpdir = tempfile.mkdtemp()
- export_dir_base = os.path.join(
- compat.as_bytes(tmpdir), compat.as_bytes('export'))
- export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn)
- self.assertTrue(gfile.Exists(export_dir))
+ export_dir, tmpdir = self._export_estimator(estimator, serving_input_fn)
# restore model
predict_fn = from_saved_model(export_dir, signature_def_key='predict')
@@ -222,6 +254,47 @@ class ForwardFeaturesTest(test.TestCase):
# Clean up.
gfile.DeleteRecursively(tmpdir)
+ def test_forward_in_exported_sparse(self):
+ features_columns = [fc.indicator_column(
+ fc.categorical_column_with_vocabulary_list('x', range(10)))]
+
+ classifier = linear.LinearClassifier(feature_columns=features_columns)
+
+ def train_input_fn():
+ dataset = dataset_ops.Dataset.from_tensors({
+ 'x': sparse_tensor.SparseTensor(
+ values=[1, 2, 3],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ 'labels': [[0], [1]]
+ })
+ def _split(x):
+ labels = x.pop('labels')
+ return x, labels
+ dataset = dataset.map(_split)
+ return dataset
+
+ classifier.train(train_input_fn, max_steps=1)
+
+ classifier = extenders.forward_features(
+ classifier, keys=['x'], sparse_default_values={'x': 0})
+
+ def serving_input_fn():
+ features_ph = array_ops.placeholder(dtype=dtypes.int32, name='x',
+ shape=[None])
+ features = {'x': layers.dense_to_sparse(features_ph)}
+ return estimator_lib.export.ServingInputReceiver(features,
+ {'x': features_ph})
+ export_dir, tmpdir = self._export_estimator(classifier, serving_input_fn)
+ prediction_fn = from_saved_model(export_dir, signature_def_key='predict')
+
+ features = (0, 2)
+ prediction = prediction_fn({'x': features})
+
+ self.assertIn('x', prediction)
+ self.assertEqual(features, tuple(prediction['x']))
+ gfile.DeleteRecursively(tmpdir)
+
def test_forward_list(self):
def input_fn():
@@ -266,7 +339,6 @@ class ForwardFeaturesTest(test.TestCase):
extenders.forward_features(estimator, ['x', estimator])
def test_key_should_be_in_features(self):
-
def input_fn():
return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]]
@@ -279,27 +351,36 @@ class ForwardFeaturesTest(test.TestCase):
next(estimator.predict(input_fn=input_fn))
def test_forwarded_feature_should_not_be_a_sparse_tensor(self):
-
def input_fn():
return {
'x': [[3.], [5.]],
- 'id':
- sparse_tensor.SparseTensor(
- values=['1', '2'],
- indices=[[0, 0], [1, 0]],
- dense_shape=[2, 1])
- }, [[1.], [2.]]
+ 'id': sparse_tensor.SparseTensor(
+ values=['1', '2'],
+ indices=[[0, 0], [1, 0]],
+ dense_shape=[2, 1])
+ }, [[1.], [2.]]
estimator = linear.LinearRegressor([fc.numeric_column('x')])
estimator.train(input_fn=input_fn, steps=1)
estimator = extenders.forward_features(estimator)
with self.assertRaisesRegexp(ValueError,
- 'Forwarded feature.* should be a Tensor.'):
+ 'Feature .* should be a Tensor.*'):
next(estimator.predict(input_fn=input_fn))
- def test_predictions_should_be_dict(self):
+ def test_forwarded_feature_should_be_a_sparse_tensor(self):
+ input_fn = self.make_dummy_input_fn()
+
+ estimator = linear.LinearRegressor([fc.numeric_column('x')])
+ estimator.train(input_fn=input_fn, steps=1)
+ estimator = extenders.forward_features(
+ estimator, sparse_default_values={'id': 0, 'sparse_id': 0})
+ with self.assertRaisesRegexp(
+ ValueError, 'Feature .* is expected to be a `SparseTensor`.'):
+ next(estimator.predict(input_fn=input_fn))
+
+ def test_predictions_should_be_dict(self):
def input_fn():
return {'x': [[3.], [5.]], 'id': [[101], [102]]}
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index 2d367adb47..c6e75f8d46 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -215,7 +215,7 @@ class MultiLabelHead(test.TestCase):
spec.export_outputs.keys())
# Assert predictions and export_outputs.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -246,7 +246,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.PREDICT,
logits=logits)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertAllEqual(
expected_export_classes,
@@ -271,7 +271,7 @@ class MultiLabelHead(test.TestCase):
logits=logits)
# Assert predictions and export_outputs.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -297,7 +297,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(expected_training_loss,
actual_training_loss.eval())
@@ -321,7 +321,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, actual_training_loss.eval(), atol=1e-4)
@@ -338,7 +338,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels_placeholder)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -375,7 +375,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits_input,
labels=labels_input)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(np.sum(loss) / 2., actual_training_loss.eval())
@@ -394,7 +394,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.EVAL,
logits=logits,
labels=labels)[0]
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -433,7 +433,7 @@ class MultiLabelHead(test.TestCase):
# Assert predictions, loss, and metrics.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -753,7 +753,7 @@ class MultiLabelHead(test.TestCase):
# Assert predictions, loss, and metrics.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -791,7 +791,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), atol=1e-4)
@@ -825,7 +825,7 @@ class MultiLabelHead(test.TestCase):
mode=model_fn.ModeKeys.TRAIN,
logits=logits,
labels=labels)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), atol=1e-4)
@@ -864,7 +864,7 @@ class MultiLabelHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -890,7 +890,7 @@ class MultiLabelHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -919,7 +919,7 @@ class MultiLabelHead(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1011,7 +1011,7 @@ class MultiLabelHead(test.TestCase):
optimizer=_Optimizer())
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -1040,7 +1040,7 @@ class MultiLabelHead(test.TestCase):
labels=np.array([[1, 0], [1, 1]], dtype=np.int64),
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
sess.run(spec.train_op)
w_value, t_value = sess.run([w, t])
@@ -1079,7 +1079,7 @@ class MultiLabelHead(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1127,7 +1127,7 @@ class MultiLabelHead(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -1162,7 +1162,7 @@ class MultiLabelHead(test.TestCase):
logits=logits,
labels=labels)
atol = 1.e-3
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
self.assertAllClose(
expected_training_loss, training_loss.eval(), atol=atol)
@@ -1197,7 +1197,7 @@ class MultiLabelHead(test.TestCase):
train_op_fn=_train_op_fn)
atol = 1.e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, monitored_session.Scaffold())
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss, atol=atol)
@@ -1224,7 +1224,7 @@ class MultiLabelHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -1252,7 +1252,7 @@ class MultiLabelHead(test.TestCase):
logits=logits,
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -1327,7 +1327,7 @@ class PoissonRegressionHead(test.TestCase):
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run([spec.loss, spec.train_op])
self.assertAlmostEqual(expected_loss, loss, delta=atol)
@@ -1352,7 +1352,7 @@ class PoissonRegressionHead(test.TestCase):
self.assertEqual(dtypes.float32, spec.predictions[keys.LOGITS].dtype)
# Assert predictions.
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, spec.scaffold)
self.assertAllClose(
expected_predictions, spec.predictions[keys.PREDICTIONS].eval())
@@ -1395,7 +1395,7 @@ class LogisticRegressionHead(test.TestCase):
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run([spec.loss, spec.train_op])
self.assertAlmostEqual(expected_loss, loss, delta=atol)
@@ -1419,7 +1419,7 @@ class LogisticRegressionHead(test.TestCase):
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -1444,7 +1444,7 @@ class LogisticRegressionHead(test.TestCase):
labels=labels,
train_op_fn=_train_op_fn)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -1471,7 +1471,7 @@ class LogisticRegressionHead(test.TestCase):
self.assertEqual(dtypes.float32, spec.predictions[keys.LOGITS].dtype)
# Assert predictions.
- with self.test_session():
+ with self.cached_session():
_initialize_variables(self, spec.scaffold)
self.assertAllClose(
expected_predictions, spec.predictions[keys.PREDICTIONS].eval())
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py
index caadafdfa6..66c46e66b7 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os
+import time
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.framework import ops
@@ -26,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training
+from tensorflow.python.training import training_util
# pylint: disable=protected-access
@@ -72,8 +74,9 @@ class InMemoryEvaluatorHook(training.SessionRunHook):
estimator: A `tf.estimator.Estimator` instance to call evaluate.
input_fn: Equivalent to the `input_fn` arg to `estimator.evaluate`. A
function that constructs the input data for evaluation.
- See @{$premade_estimators#create_input_functions} for more
- information. The function should construct and return one of
+ See [Creating input functions](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
the following:
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
@@ -210,4 +213,72 @@ class InMemoryEvaluatorHook(training.SessionRunHook):
self._evaluate(session)
+class _StopAtCheckpointStepHook(training.SessionRunHook):
+ """Hook that requests stop at a specified step based on checkpoint.
+
+ Note: We recommend using `make_stop_at_checkpoint_step_hook` to get the proper
+ hook.
+ """
+
+ def __init__(self, model_dir, last_step,
+ wait_after_file_check_secs=30):
+ """Initializes a `StopAtCheckpointStepHook`.
+
+ This hook requests stop after the last step has been reached. It checks the
+ latest checkpoint to verify whether the last step has been written to disk.
+
+ Args:
+ model_dir: Directory to read global step from latest checkpoint.
+ last_step: Step after which to stop.
+ wait_after_file_check_secs: Having many workers read the same file can
+ create I/O issues. To throttle this, we wait the given number of
+ seconds after each read of the file.
+
+ Raises:
+ ValueError: If one of the arguments is invalid.
+ """
+ if last_step is None:
+ raise ValueError('last_step must be specified.')
+ if model_dir is None:
+ raise ValueError('model_dir must be specified.')
+
+ self._model_dir = model_dir
+ self._last_step = last_step
+ self._wait_after_file_check_secs = wait_after_file_check_secs
+
+ def begin(self):
+ self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
+ if self._global_step_tensor is None:
+ raise RuntimeError(
+ 'Global step should be created to use StopAtCheckpointStepHook.')
+
+ def before_run(self, run_context): # pylint: disable=unused-argument
+ return training.SessionRunArgs(self._global_step_tensor)
+
+ def after_run(self, run_context, run_values):
+ global_step = run_values.results + 1
+ if global_step >= self._last_step:
+ # Check latest global step in the checkpoint to ensure that the targeted
+ # last step is written on disk.
+
+ step = estimator_lib._load_global_step_from_checkpoint_dir(
+ self._model_dir)
+ if step >= self._last_step:
+ run_context.request_stop()
+ else:
+ time.sleep(self._wait_after_file_check_secs)
+
+
+def make_stop_at_checkpoint_step_hook(estimator,
+ last_step,
+ wait_after_file_check_secs=30):
+ """Creates a proper StopAtCheckpointStepHook based on chief status."""
+
+ if estimator.config.is_chief:
+ return training.StopAtStepHook(last_step=last_step)
+ return _StopAtCheckpointStepHook(
+ model_dir=estimator.model_dir,
+ last_step=last_step,
+ wait_after_file_check_secs=wait_after_file_check_secs)
+
# pylint: enable=protected-access
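
A short usage sketch for the new factory, assuming a trivial `DNNClassifier` like the one in the tests below (illustrative only, not part of the diff): the chief gets a plain `StopAtStepHook`, while workers get the checkpoint-polling variant.

```python
import tensorflow as tf

estimator = tf.estimator.DNNClassifier(
    feature_columns=[tf.feature_column.numeric_column('x')],
    hidden_units=[3, 1])
# Chief: StopAtStepHook; workers: _StopAtCheckpointStepHook polling model_dir.
stop_hook = tf.contrib.estimator.make_stop_at_checkpoint_step_hook(
    estimator, last_step=300, wait_after_file_check_secs=30)
# estimator.train(input_fn=train_input_fn, hooks=[stop_hook])
# (train_input_fn is a hypothetical input function.)
```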
diff --git a/tensorflow/contrib/estimator/python/estimator/hooks_test.py b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
index ee88d5ecf5..c6c6cad95a 100644
--- a/tensorflow/contrib/estimator/python/estimator/hooks_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/hooks_test.py
@@ -21,8 +21,11 @@ from __future__ import print_function
import glob
import json
import os
+import tempfile
+import time
from tensorflow.contrib.estimator.python.estimator import hooks as hooks_lib
+from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator_lib
from tensorflow.python.estimator import run_config as run_config_lib
@@ -316,5 +319,85 @@ class InMemoryEvaluatorHookTest(test.TestCase):
estimator.train(input_fn, hooks=[evaluator])
+class StopAtCheckpointStepHookTest(test.TestCase):
+
+ def test_do_not_stop_if_checkpoint_is_not_there(self):
+ with ops.Graph().as_default():
+ step = training.create_global_step()
+ assign_ten = step.assign(10)
+ no_op = control_flow_ops.no_op()
+ hook = hooks_lib._StopAtCheckpointStepHook(
+ model_dir=tempfile.mkdtemp(), last_step=10)
+ with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.raw_session().run(assign_ten)
+ with test.mock.patch.object(time, 'sleep') as mock_sleep:
+ mon_sess.run(no_op)
+ self.assertTrue(mock_sleep.called)
+ self.assertFalse(mon_sess.should_stop())
+
+ def test_do_not_stop_if_checkpoint_step_is_smaller(self):
+ model_dir = tempfile.mkdtemp()
+ with ops.Graph().as_default():
+ step = training.create_global_step()
+ assign_nine = step.assign(9)
+ assign_ten = step.assign(10)
+ no_op = control_flow_ops.no_op()
+ hook = hooks_lib._StopAtCheckpointStepHook(
+ model_dir=model_dir, last_step=10)
+ with tf_session.Session() as sess:
+ sess.run(assign_nine)
+ training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
+ with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.raw_session().run(assign_ten)
+ with test.mock.patch.object(time, 'sleep') as mock_sleep:
+ mon_sess.run(no_op)
+ self.assertTrue(mock_sleep.called)
+ self.assertFalse(mon_sess.should_stop())
+
+ def test_stop_if_checkpoint_step_is_laststep(self):
+ model_dir = tempfile.mkdtemp()
+ with ops.Graph().as_default():
+ step = training.create_global_step()
+ assign_ten = step.assign(10)
+ no_op = control_flow_ops.no_op()
+ hook = hooks_lib._StopAtCheckpointStepHook(
+ model_dir=model_dir, last_step=10)
+ with tf_session.Session() as sess:
+ sess.run(assign_ten)
+ training.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
+ with training.SingularMonitoredSession(hooks=[hook]) as mon_sess:
+ mon_sess.raw_session().run(assign_ten)
+ with test.mock.patch.object(time, 'sleep') as mock_sleep:
+ mon_sess.run(no_op)
+ self.assertFalse(mock_sleep.called)
+ self.assertTrue(mon_sess.should_stop())
+
+ def test_creates_regular_stop_at_step_hook_for_chief(self):
+ # by default an estimator is in chief mode
+ dnn = estimator_lib.DNNClassifier(
+ feature_columns=[feature_column_lib.numeric_column('x')],
+ hidden_units=[3, 1])
+ hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300)
+ self.assertIsInstance(hook, training.StopAtStepHook)
+ self.assertEqual(300, hook._last_step)
+
+ def test_creates_checkpoint_hook_for_workers(self):
+
+ class FakeWorkerConfig(estimator_lib.RunConfig):
+
+ @property
+ def is_chief(self):
+ return False
+
+ dnn = estimator_lib.DNNClassifier(
+ feature_columns=[feature_column_lib.numeric_column('x')],
+ hidden_units=[3, 1],
+ config=FakeWorkerConfig())
+ hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300)
+ self.assertIsInstance(hook, hooks_lib._StopAtCheckpointStepHook)
+ self.assertEqual(300, hook._last_step)
+ self.assertEqual(dnn.model_dir, hook._model_dir)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/linear.py b/tensorflow/contrib/estimator/python/estimator/linear.py
index 62a37abefb..2b68f24eb2 100644
--- a/tensorflow/contrib/estimator/python/estimator/linear.py
+++ b/tensorflow/contrib/estimator/python/estimator/linear.py
@@ -121,7 +121,7 @@ class LinearEstimator(estimator.Estimator):
is multivalent. One of "mean", "sqrtn", and "sum" -- these are
effectively different ways to do example-level normalization, which can
be useful for bag-of-words features. for more details, see
- @{tf.feature_column.linear_model$linear_model}.
+ `tf.feature_column.linear_model`.
"""
def _model_fn(features, labels, mode, config):
return linear_lib._linear_model_fn( # pylint: disable=protected-access
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
index 3d6fccb118..2b4d5f5261 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
@@ -132,7 +132,7 @@ class MultiHeadTest(test.TestCase):
spec.export_outputs.keys())
# Assert predictions and export_outputs.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -202,7 +202,7 @@ class MultiHeadTest(test.TestCase):
spec.export_outputs.keys())
# Assert predictions and export_outputs.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -259,7 +259,7 @@ class MultiHeadTest(test.TestCase):
spec.export_outputs.keys())
# Assert predictions and export_outputs.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
predictions = sess.run(spec.predictions)
@@ -336,7 +336,7 @@ class MultiHeadTest(test.TestCase):
# Assert predictions, loss, and metrics.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNone(spec.scaffold.summary_op)
value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
@@ -362,7 +362,7 @@ class MultiHeadTest(test.TestCase):
logits=logits,
labels=labels)[0]
tol = 1e-3
- with self.test_session():
+ with self.cached_session():
# Unreduced loss of the head is [[(10 + 10) / 2], (15 + 0) / 2]
# (averaged over classes, averaged over examples).
self.assertAllClose(8.75, loss.eval(), rtol=tol, atol=tol)
@@ -397,7 +397,7 @@ class MultiHeadTest(test.TestCase):
logits=logits,
labels=labels)
tol = 1e-3
- with self.test_session():
+ with self.cached_session():
# loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]]
# = [10, 7.5]
# training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5
@@ -445,7 +445,7 @@ class MultiHeadTest(test.TestCase):
logits=logits,
labels=labels)
tol = 1e-3
- with self.test_session():
+ with self.cached_session():
# loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]]
# = [10, 7.5]
# training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5
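The arithmetic in these comments comes from per-example weights: with weights [1., 2.] on the two examples, the merged training loss is (1 * 10 + 2 * 7.5) / 2 = 12.5. A hedged sketch of the kind of head stack this test file builds (the weight_column names are illustrative):

import tensorflow as tf

head1 = tf.contrib.estimator.multi_label_head(
    n_classes=2, weight_column='weights1', name='head1')
head2 = tf.contrib.estimator.multi_label_head(
    n_classes=3, weight_column='weights2', name='head2')
# multi_head combines the heads' weighted losses into one training loss.
multi_head = tf.contrib.estimator.multi_head([head1, head2])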
@@ -498,7 +498,7 @@ class MultiHeadTest(test.TestCase):
logits=logits,
labels=labels)[0]
tol = 1e-3
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(
expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
@@ -535,7 +535,7 @@ class MultiHeadTest(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
@@ -579,7 +579,7 @@ class MultiHeadTest(test.TestCase):
optimizer=_Optimizer())
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
loss, train_result = sess.run((spec.loss, spec.train_op))
self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
@@ -634,7 +634,7 @@ class MultiHeadTest(test.TestCase):
# Assert predictions, loss, train_op, and summaries.
tol = 1e-3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_initialize_variables(self, spec.scaffold)
self.assertIsNotNone(spec.scaffold.summary_op)
loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
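The remaining hunks in this file, and nearly all of replicate_model_fn_test.py below, are the same mechanical migration from the deprecated TensorFlowTestCase.test_session() to cached_session(). A minimal sketch of the target pattern, assuming TF 1.x graph-mode tests (cached_session() hands back one session that is cached and reused across calls within the same test method):

import tensorflow as tf

class ExampleTest(tf.test.TestCase):

  def test_double(self):
    x = tf.constant([1.0, 2.0])
    # Entering cached_session() again in this test reuses the same session.
    with self.cached_session() as sess:
      self.assertAllClose([2.0, 4.0], sess.run(x * 2.0))

if __name__ == '__main__':
  tf.test.main()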
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index dd8a3a95f1..65229d67bb 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -209,7 +209,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn,
loss_reduction=losses.Reduction.SUM,
@@ -233,7 +233,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
# Add another trainable variable that doesn't produce a gradient to
# verify that None gradients are supported.
_ = variable_scope.get_variable(
@@ -275,7 +275,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
# for the second.
expected_c = 10.0 - 3.0, 7.0 - 4.0
- with self.test_session() as session, variable_scope.variable_scope(
+ with self.cached_session() as session, variable_scope.variable_scope(
'', reuse=variable_scope.AUTO_REUSE):
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn,
@@ -299,7 +299,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn,
loss_reduction=losses.Reduction.SUM,
@@ -330,7 +330,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
estimator_spec = replicated_model_fn(
@@ -359,7 +359,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/gpu:0', '/gpu:1'])
estimator_spec = replicated_model_fn(
@@ -374,7 +374,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/gpu:0'])
estimator_spec = replicated_model_fn(
@@ -396,7 +396,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/gpu:0'])
estimator_spec = replicated_model_fn(
@@ -424,7 +424,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/gpu:0'])
estimator_spec = replicated_model_fn(
@@ -456,7 +456,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session():
+ with self.cached_session():
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/GPU:0'])
_ = replicated_model_fn(
@@ -470,7 +470,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
features = np.array([[0.01], [0.002]])
labels = np.array([[0.01], [0.02]])
- with self.test_session():
+ with self.cached_session():
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/gpu:0'])
_ = replicated_model_fn(
@@ -521,7 +521,7 @@ class ReplicateAcrossASingleDeviceWithoutTowerOptimizer(
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, devices=['/gpu:0'])
estimator_spec = replicated_model_fn(
@@ -649,7 +649,7 @@ class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn,
loss_reduction=losses.Reduction.SUM,
@@ -746,7 +746,7 @@ class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn,
loss_reduction=losses.Reduction.SUM,
@@ -777,7 +777,7 @@ class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session(), ops_lib.Graph().as_default():
+ with self.cached_session(), ops_lib.Graph().as_default():
with self.assertRaisesRegexp(
ValueError, '.+was.+supposed.+to.+make.+same.+optimizer.+calls.+'):
replicated_model_fn = replicate_model_fn.replicate_model_fn(
@@ -819,7 +819,7 @@ class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase):
features = np.array([[1.0], [2.0]])
labels = np.array([[1.0], [2.0]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError,
'Please.+wrap.+with.+TowerOptimizer'):
replicated_model_fn = replicate_model_fn.replicate_model_fn(
@@ -845,7 +845,7 @@ class GetLossTowersTest(test_util.TensorFlowTestCase):
return model_fn_lib.EstimatorSpec(mode=mode, loss=math_ops.reduce_sum(loss))
def test_gradients_are_computed(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_specs = replicate_model_fn._get_loss_towers(
self.model_fn,
mode=None,
@@ -879,7 +879,7 @@ class GetLossTowersTest(test_util.TensorFlowTestCase):
self.assertEqual(0.25, session.run(c))
def test_gradients_are_computed_with_mean_reduction(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_specs = replicate_model_fn._get_loss_towers(
self.model_fn,
mode=model_fn_lib.ModeKeys.EVAL,
@@ -932,7 +932,7 @@ class GetLossTowersTest(test_util.TensorFlowTestCase):
return model_fn_lib.EstimatorSpec(
mode=mode, loss=math_ops.reduce_sum(loss))
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_specs = replicate_model_fn._get_loss_towers(
model_fn,
mode=None,
@@ -975,7 +975,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual(a.dense_shape, b.dense_shape)
def test_simple_half_split(self):
- with self.test_session():
+ with self.cached_session():
features = [0.0, 1.0, 2.0, 3.0]
labels = [10.0, 11.0, 12.0, 13.0]
feature_shards, label_shards = replicate_model_fn._split_batch(
@@ -988,7 +988,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards)
def test_to_each_their_own(self):
- with self.test_session():
+ with self.cached_session():
features = [0.0, 1.0, 2.0, 3.0]
labels = [10.0, 11.0, 12.0, 13.0]
feature_shards, label_shards = replicate_model_fn._split_batch(
@@ -1001,7 +1001,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards)
def test_one_batch(self):
- with self.test_session():
+ with self.cached_session():
features = [0.0, 1.0, 2.0, 3.0]
labels = [10.0, 11.0, 12.0, 13.0]
feature_shards, label_shards = replicate_model_fn._split_batch(
@@ -1014,7 +1014,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards)
def test_half_split_in_dictionary(self):
- with self.test_session():
+ with self.cached_session():
features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
labels = [10.0, 11.0, 12.0, 13.0]
@@ -1029,7 +1029,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([12.0, 13.0], label_shards[1].eval())
def test_sparse_tensor_can_be_split_unevenly(self):
- with self.test_session():
+ with self.cached_session():
features = {
'x':
sparse_tensor.SparseTensor(
@@ -1054,7 +1054,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[2.0]], label_shards[1].eval())
def test_sparse_tensor_can_be_split_unevenly_repeated_row(self):
- with self.test_session():
+ with self.cached_session():
features = {
'x':
sparse_tensor.SparseTensor(
@@ -1081,7 +1081,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[2.0]], label_shards[1].eval())
def test_one_batch_in_dictionary(self):
- with self.test_session() as session: # pylint: disable=unused-variable
+ with self.cached_session() as session: # pylint: disable=unused-variable
features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
labels = [10.0, 11.0, 12.0, 13.0]
@@ -1095,7 +1095,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([10.0, 11.0, 12.0, 13.0], label_shards[0].eval())
def test_feature_and_label_dictionaries(self):
- with self.test_session() as session: # pylint: disable=unused-variable
+ with self.cached_session() as session: # pylint: disable=unused-variable
features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
labels = {'first': [10.0, 11.0], 'second': [12.0, 13.0]}
@@ -1127,7 +1127,7 @@ class TrainSpecTest(test_util.TensorFlowTestCase):
return constant_op.constant(loss_value, dtype=dtypes.float64)
def test_example(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_losses = list(map(self.create_constant_loss, [2, 4, 6]))
tower_specs = list(map(self.create_estimator_spec, tower_losses))
@@ -1161,7 +1161,7 @@ class EvalSpecTest(test_util.TensorFlowTestCase):
return metrics
def test_example(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_losses = map(self.create_constant_loss, [2, 4, 6])
tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
tower_specs = [
@@ -1187,7 +1187,7 @@ class EvalSpecTest(test_util.TensorFlowTestCase):
self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
def test_handles_single_tower(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_losses = map(self.create_constant_loss, [5])
tower_metrics = map(self.create_eval_metrics, [0.2])
tower_specs = [
@@ -1231,7 +1231,7 @@ class PredictSpecTest(test_util.TensorFlowTestCase):
})
def test_example(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
tower_specs = replicate_model_fn._get_loss_towers(
self.model_fn,
mode=None,
@@ -1273,7 +1273,7 @@ class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
np.array([3.3, 3.5, 3.7]) * (tower_id + 1), 'total')
def test_example(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
for tower_id in range(3):
self.create_tower_metrics(tower_id)
@@ -1303,7 +1303,7 @@ class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
def test_reduce_is_idempotent(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
for tower_id in range(3):
self.create_tower_metrics(tower_id)
@@ -1329,7 +1329,7 @@ class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
def test_handles_single_tower(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
self.create_tower_metrics(0)
session.run(
variables.variables_initializer(
@@ -1346,7 +1346,7 @@ class ReduceMetricVariablesTest(test_util.TensorFlowTestCase):
self.assertAllClose([3.3, 3.5, 3.7], local_metrics[2], 0.01)
def test_doesnt_accept_uneven_number_of_variables(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
for tower_id in range(3):
self.create_tower_metrics(tower_id)
self.create_metric_variable(-1.0, 'oddball')
@@ -1418,7 +1418,7 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase):
return estimator_spec
def test_merge_predict_output(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
estimator_spec = self.replicate_estimator_spec(session)
self.assertAllClose(
{
@@ -1428,7 +1428,7 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase):
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs))
def test_merge_classification_output_scores_classes(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
estimator_spec = self.replicate_estimator_spec(session)
self.assertAllClose(
[0.1, 0.02],
@@ -1440,7 +1440,7 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase):
estimator_spec.export_outputs['classification_output'].classes))
def test_merge_classification_output_scores(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
estimator_spec = self.replicate_estimator_spec(session)
self.assertAllClose(
[0.1, 0.02],
@@ -1450,7 +1450,7 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase):
None, estimator_spec.export_outputs['classification_scores'].classes)
def test_merge_classification_output_classes(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
estimator_spec = self.replicate_estimator_spec(session)
self.assertAllEqual(
[b'split_inputs/split:0', b'split_inputs/split:1'],
@@ -1460,7 +1460,7 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase):
None, estimator_spec.export_outputs['classification_classes'].scores)
def test_merge_regression_output(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
estimator_spec = self.replicate_estimator_spec(session)
self.assertAllClose(
[0.1, 0.02],
@@ -1548,7 +1548,7 @@ class LocalDeviceSetterTest(test_util.TensorFlowTestCase):
class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
def test_vectors(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
total = replicate_model_fn._compute_sum_on_device(
[1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum')
@@ -1557,7 +1557,7 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
self.assertEqual(10.0, session.run(total))
def test_tensors(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
total = replicate_model_fn._compute_sum_on_device(
[[1.0, 2.0], [3.0, 4.0]], device='/device:GPU:0', name='test_sum')
@@ -1566,7 +1566,7 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
self.assertAllEqual([4.0, 6.0], session.run(total))
def test_indexedslices(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
a = ops_lib.IndexedSlices(
constant_op.constant([1.0, 2.0]), [0, 1],
dense_shape=constant_op.constant([2]))
@@ -1580,7 +1580,7 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
session.run(ops_lib.convert_to_tensor(total)))
def test_indexedslices_higher_dimensions(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
a = ops_lib.IndexedSlices(
constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1],
dense_shape=constant_op.constant([2, 4]))
@@ -1595,7 +1595,7 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
session.run(ops_lib.convert_to_tensor(total)))
def test_indexedslices_some_dont_overlap(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
a = ops_lib.IndexedSlices(
constant_op.constant([1.0, 2.0]), [0, 3],
dense_shape=constant_op.constant([4]))
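The hunks above cover _compute_sum_on_device, which these tests use to sum dense tensors and IndexedSlices on one pinned device. A hedged, dense-only stand-in (the real private helper also merges IndexedSlices, which this sketch does not attempt):

import tensorflow as tf

def compute_sum_on_device(values, device, name=None):
  # Pin the reduction to the requested device, e.g. '/device:GPU:0'.
  with tf.device(device):
    return tf.add_n([tf.convert_to_tensor(v) for v in values], name=name)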
@@ -1637,7 +1637,7 @@ class ConcatTensorDictsTest(test_util.TensorFlowTestCase):
},
]
- with self.test_session() as session:
+ with self.cached_session() as session:
self.assertAllClose({
'a': np.array([1.0, 2.0, 3.0]),
'b': np.array([11.0, 12.0, 13.0, 14.0]),