aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn
diff options
context:
space:
mode:
authorGravatar Martin Wicke <577277+martinwicke@users.noreply.github.com>2018-09-22 09:45:11 -0700
committerGravatar GitHub <noreply@github.com>2018-09-22 09:45:11 -0700
commit413ac36f33deb0c354dd687963d2410eab048970 (patch)
treefd4dc4e9fc5a76efd62c78c213b0e34983359256 /tensorflow/contrib/learn
parentc22d996c3d6a16db292bd3464b2ef7b91adae676 (diff)
parente692dda4c8b199555e2fa32132a7784e0893c870 (diff)
Merge branch 'master' into fix_expand_dims
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r--tensorflow/contrib/learn/BUILD33
-rw-r--r--tensorflow/contrib/learn/__init__.py3
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py7
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py10
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/kmeans.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py1
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear_test.py144
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/run_config.py7
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/stability_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py37
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment_test.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions_test.py49
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py26
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py28
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py18
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors.py8
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors_test.py58
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/losses_ops.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/ops_test.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/export.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py16
28 files changed, 332 insertions, 169 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index d665fc9335..61185f65a9 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -79,16 +79,7 @@ py_library(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python:weights_broadcast_ops",
- "//tensorflow/python/estimator",
"//tensorflow/python/estimator:estimator_py",
- "//tensorflow/python/estimator:export_export",
- "//tensorflow/python/estimator:export_output",
- "//tensorflow/python/estimator:inputs",
- "//tensorflow/python/estimator:inputs_queues",
- "//tensorflow/python/estimator:model_fn",
- "//tensorflow/python/estimator:numpy_io",
- "//tensorflow/python/estimator:pandas_io",
- "//tensorflow/python/estimator:run_config",
"//tensorflow/python/feature_column",
"//tensorflow/python/feature_column:feature_column_py",
"//tensorflow/python/ops/losses",
@@ -117,7 +108,6 @@ py_test(
size = "small",
srcs = ["python/learn/learn_io/data_feeder_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":learn",
"//tensorflow/python:client_testlib",
@@ -171,9 +161,8 @@ tf_py_test(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variables",
- "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:estimator_py",
],
- tags = ["no_windows"], # TODO: needs investigation on Windows
)
py_test(
@@ -220,7 +209,7 @@ py_test(
"//tensorflow/contrib/training:training_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform",
- "//tensorflow/python/estimator:run_config",
+ "//tensorflow/python/estimator:estimator_py",
],
)
@@ -245,7 +234,7 @@ py_test(
"//tensorflow/python:summary",
"//tensorflow/python:training",
"//tensorflow/python:variables",
- "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:estimator_py",
],
)
@@ -259,7 +248,7 @@ py_test(
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
- "//tensorflow/python/estimator:run_config",
+ "//tensorflow/python/estimator:estimator_py",
],
)
@@ -281,7 +270,11 @@ py_test(
size = "medium",
srcs = ["python/learn/estimators/estimator_test.py"],
srcs_version = "PY2AND3",
- tags = ["manual"],
+ tags = [
+ "manual",
+ "noasan", # times out
+ "optonly", # test is flaky without optimization.
+ ],
deps = [
":learn",
"//tensorflow/contrib/framework:framework_py",
@@ -410,6 +403,7 @@ py_test(
srcs = ["python/learn/estimators/dnn_test.py"],
shard_count = 4,
srcs_version = "PY2AND3",
+ tags = ["notap"],
deps = [
":learn",
"//tensorflow/contrib/layers:layers_py",
@@ -431,6 +425,7 @@ py_test(
name = "kmeans_test",
size = "medium",
srcs = ["python/learn/estimators/kmeans_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = [
"noasan", # b/73741358
@@ -482,6 +477,7 @@ py_test(
name = "state_saving_rnn_estimator_test",
size = "medium",
srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = ["noasan"],
deps = [
@@ -594,7 +590,6 @@ py_test(
size = "small",
srcs = ["python/learn/learn_io/io_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_windows"], # TODO: needs investigation on Windows
deps = [
":learn",
"//tensorflow/contrib/learn/python/learn/datasets",
@@ -615,7 +610,7 @@ py_test(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:session",
"//tensorflow/python:training",
- "//tensorflow/python/estimator:export_output",
+ "//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/saved_model:signature_constants",
"@six_archive//:six",
],
@@ -741,7 +736,7 @@ py_test(
tf_py_test(
name = "graph_io_test",
- size = "small",
+ size = "medium",
srcs = ["python/learn/learn_io/graph_io_test.py"],
additional_deps = [
":learn",
diff --git a/tensorflow/contrib/learn/__init__.py b/tensorflow/contrib/learn/__init__.py
index 79bd73faaf..28a6f5aed9 100644
--- a/tensorflow/contrib/learn/__init__.py
+++ b/tensorflow/contrib/learn/__init__.py
@@ -19,7 +19,8 @@ This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
-See the @{$python/contrib.learn} guide.
+See the [Contrib Learn](https://tensorflow.org/api_guides/python/contrib.learn)
+guide.
@@BaseEstimator
@@Estimator
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
index c9a11f27f1..1d8a59281a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
@@ -155,7 +155,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
sequence_input = dynamic_rnn_estimator.build_sequence_input(
self.GetColumnsToTensors(), self.sequence_feature_columns,
self.context_feature_columns)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
sequence_input_val = sess.run(sequence_input)
@@ -330,7 +330,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
actual_state = dynamic_rnn_estimator.dict_to_state_tuple(state_dict, cell)
flattened_state = dynamic_rnn_estimator.state_tuple_to_dict(actual_state)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
(state_dict_val, actual_state_val, flattened_state_val) = sess.run(
[state_dict, actual_state, flattened_state])
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 7a026a15e4..c1de42782e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -72,6 +72,7 @@ from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary as core_summary
from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import device_setter
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
@@ -891,7 +892,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
# Check that model has been trained (if nothing has been set explicitly).
if not checkpoint_path:
- latest_path = saver.latest_checkpoint(self._model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not latest_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
@@ -956,7 +957,7 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
as_iterable=True,
iterate_batches=False):
# Check that model has been trained.
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
@@ -1364,7 +1365,7 @@ class Estimator(BaseEstimator):
if not checkpoint_path:
# Locate the latest checkpoint
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise NotFittedError(
"Couldn't find trained model at %s." % self._model_dir)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index d81a534b79..9e5aaf3118 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -715,7 +715,9 @@ class EstimatorTest(test.TestCase):
ckpt = checkpoint_state_pb2.CheckpointState()
text_format.Merge(checkpoint_file_content, ckpt)
self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
- self.assertAllEqual(['model.ckpt-1', 'model.ckpt-5'],
+ # TODO(b/78461127): Please modify tests to not directly rely on names of
+ # checkpoints.
+ self.assertAllEqual(['model.ckpt-0', 'model.ckpt-5'],
ckpt.all_model_checkpoint_paths)
def test_train_save_copy_reload(self):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 06f4173170..c6f79e00d5 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -777,7 +777,7 @@ class _RegressionHead(_SingleHead):
key = prediction_key.PredictionKey.SCORES
with ops.name_scope(None, "predictions", (logits,)):
if self.logits_dimension == 1:
- logits = array_ops.squeeze(logits, squeeze_dims=(1,), name=key)
+ logits = array_ops.squeeze(logits, axis=(1,), name=key)
return {key: self._link_fn(logits)}
def _metrics(self, eval_loss, predictions, labels, weights):
@@ -974,7 +974,7 @@ def _softmax_cross_entropy_loss(labels, logits, weights=None):
is_squeezed_labels = False
# TODO(ptucker): This will break for dynamic shapes.
if len(labels.get_shape()) == 2:
- labels = array_ops.squeeze(labels, squeeze_dims=(1,))
+ labels = array_ops.squeeze(labels, axis=(1,))
is_squeezed_labels = True
loss = nn.sparse_softmax_cross_entropy_with_logits(
@@ -1862,12 +1862,12 @@ def _get_arguments(func):
if hasattr(func, "__code__"):
# Regular function.
return tf_inspect.getargspec(func)
- elif hasattr(func, "__call__"):
- # Callable object.
- return _get_arguments(func.__call__)
elif hasattr(func, "func"):
# Partial function.
return _get_arguments(func.func)
+ elif hasattr(func, "__call__"):
+ # Callable object.
+ return _get_arguments(func.__call__)
def _verify_loss_fn_args(loss_fn):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py
index 66ebcfd1d8..21f7dcc5e4 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py
@@ -15,9 +15,9 @@
"""Implementation of k-means clustering on top of `Estimator` API (deprecated).
This module is deprecated. Please use
-@{tf.contrib.factorization.KMeansClustering} instead of
-@{tf.contrib.learn.KMeansClustering}. It has a similar interface, but uses the
-@{tf.estimator.Estimator} API instead of @{tf.contrib.learn.Estimator}.
+`tf.contrib.factorization.KMeansClustering` instead of
+`tf.contrib.learn.KMeansClustering`. It has a similar interface, but uses the
+`tf.estimator.Estimator` API instead of `tf.contrib.learn.Estimator`.
"""
from __future__ import absolute_import
diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py
index b28835a809..584556992a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py
@@ -36,7 +36,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import random_ops
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index 70b70af98c..e100bc7a1e 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -31,7 +31,6 @@ import six
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
-from tensorflow.python.training import training_util
from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
@@ -51,6 +50,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training as train
+from tensorflow.python.training import training_util
# The default learning rate of 0.2 is a historical artifact of the initial
@@ -244,7 +244,9 @@ def sdca_model_fn(features, labels, mode, params):
parent_scope = "linear"
with variable_scope.variable_scope(
- values=features.values(), name_or_scope=parent_scope) as scope:
+ values=features.values(),
+ name_or_scope=parent_scope,
+ partitioner=optimizer.partitioner) as scope:
features = features.copy()
features.update(layers.transform_features(features, feature_columns))
logits, columns_to_variables, bias = (
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
index d3bb0fda57..597ca4e86d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
@@ -43,6 +43,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.platform import test
from tensorflow.python.training import ftrl
from tensorflow.python.training import input as input_lib
@@ -863,6 +864,38 @@ class LinearClassifierTest(test.TestCase):
scores = classifier.evaluate(input_fn=input_fn, steps=1)
self.assertGreater(scores['accuracy'], 0.9)
+ def testSdcaOptimizerWeightedSparseFeaturesOOVWithNoOOVBuckets(self):
+ """LinearClassifier with SDCAOptimizer with OOV features (-1 IDs)."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[2., 3., 1.],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ dense_shape=[3, 5]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ # 'GB' is out of the vocabulary.
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ dense_shape=[3, 5])
+ }, constant_op.constant([[1], [0], [1]])
+
+ country = feature_column_lib.sparse_column_with_keys(
+ 'country', keys=['US', 'CA', 'MK', 'IT', 'CN'])
+ country_weighted_by_price = feature_column_lib.weighted_sparse_column(
+ country, 'price')
+ sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer(
+ example_id_column='example_id')
+ classifier = linear.LinearClassifier(
+ feature_columns=[country_weighted_by_price], optimizer=sdca_optimizer)
+ classifier.fit(input_fn=input_fn, steps=50)
+ scores = classifier.evaluate(input_fn=input_fn, steps=1)
+ self.assertGreater(scores['accuracy'], 0.9)
+
def testSdcaOptimizerCrossedFeatures(self):
"""Tests LinearClassifier with SDCAOptimizer and crossed features."""
@@ -934,6 +967,63 @@ class LinearClassifierTest(test.TestCase):
scores = classifier.evaluate(input_fn=input_fn, steps=1)
self.assertGreater(scores['accuracy'], 0.9)
+ def testSdcaOptimizerPartitionedVariables(self):
+ """Tests LinearClassifier with SDCAOptimizer with partitioned variables."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([[0.6], [0.8], [0.3]]),
+ 'sq_footage':
+ constant_op.constant([[900.0], [700.0], [600.0]]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[3.0], [1.0], [1.0]])
+ }, constant_op.constant([[1], [0], [1]])
+
+ price = feature_column_lib.real_valued_column('price')
+ sq_footage_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('sq_footage'),
+ boundaries=[650.0, 800.0])
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ sq_footage_country = feature_column_lib.crossed_column(
+ [sq_footage_bucket, country], hash_bucket_size=10)
+
+ sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer(
+ example_id_column='example_id',
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0))
+
+ tf_config = {
+ 'cluster': {
+ run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1']
+ }
+ }
+ with test.mock.patch.dict('os.environ',
+ {'TF_CONFIG': json.dumps(tf_config)}):
+ config = run_config.RunConfig()
+ # Because we did not start a distributed cluster, we need to pass an
+ # empty ClusterSpec, otherwise the device_setter will look for
+ # distributed jobs, such as "/job:ps" which are not present.
+ config._cluster_spec = server_lib.ClusterSpec({})
+
+ classifier = linear.LinearClassifier(
+ feature_columns=[price, sq_footage_bucket, country, sq_footage_country],
+ weight_column_name='weights',
+ optimizer=sdca_optimizer,
+ config=config)
+ classifier.fit(input_fn=input_fn, steps=50)
+ scores = classifier.evaluate(input_fn=input_fn, steps=1)
+ print('all scores = {}'.format(scores))
+ self.assertGreater(scores['accuracy'], 0.9)
+
def testEval(self):
"""Tests that eval produces correct metrics.
"""
@@ -1508,6 +1598,60 @@ class LinearRegressorTest(test.TestCase):
loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
self.assertLess(loss, 0.05)
+ def testSdcaOptimizerPartitionedVariables(self):
+ """Tests LinearRegressor with SDCAOptimizer with partitioned variables."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ constant_op.constant([0.6, 0.8, 0.3]),
+ 'sq_footage':
+ constant_op.constant([[900.0], [700.0], [600.0]]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 3], [2, 1]],
+ dense_shape=[3, 5]),
+ 'weights':
+ constant_op.constant([[3.0], [5.0], [7.0]])
+ }, constant_op.constant([[1.55], [-1.25], [-3.0]])
+
+ price = feature_column_lib.real_valued_column('price')
+ sq_footage_bucket = feature_column_lib.bucketized_column(
+ feature_column_lib.real_valued_column('sq_footage'),
+ boundaries=[650.0, 800.0])
+ country = feature_column_lib.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ sq_footage_country = feature_column_lib.crossed_column(
+ [sq_footage_bucket, country], hash_bucket_size=10)
+ sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer(
+ example_id_column='example_id', symmetric_l2_regularization=1.0,
+ partitioner=partitioned_variables.fixed_size_partitioner(
+ num_shards=2, axis=0))
+ tf_config = {
+ 'cluster': {
+ run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1']
+ }
+ }
+ with test.mock.patch.dict('os.environ',
+ {'TF_CONFIG': json.dumps(tf_config)}):
+ config = run_config.RunConfig()
+ # Because we did not start a distributed cluster, we need to pass an
+ # empty ClusterSpec, otherwise the device_setter will look for
+ # distributed jobs, such as "/job:ps" which are not present.
+ config._cluster_spec = server_lib.ClusterSpec({})
+
+ regressor = linear.LinearRegressor(
+ feature_columns=[price, sq_footage_bucket, country, sq_footage_country],
+ weight_column_name='weights',
+ optimizer=sdca_optimizer,
+ config=config)
+ regressor.fit(input_fn=input_fn, steps=20)
+ loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']
+ self.assertLess(loss, 0.05)
+
def testSdcaOptimizerSparseFeaturesWithL1Reg(self):
"""Tests LinearClassifier with SDCAOptimizer and sparse features."""
diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py
index 82563141cc..ebf5f5617d 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common_test.py
@@ -44,7 +44,7 @@ class RnnCommonTest(test.TestCase):
constant_op.constant(labels, dtype=dtypes.int32),
constant_op.constant(sequence_length, dtype=dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
activations_masked, labels_masked = sess.run(
[activations_masked_t, labels_masked_t])
diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
index 14ee2ba609..08f23aa223 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py
@@ -221,7 +221,7 @@ class ClusterConfig(object):
class RunConfig(ClusterConfig, core_run_config.RunConfig):
"""This class specifies the configurations for an `Estimator` run.
- This class is a deprecated implementation of @{tf.estimator.RunConfig}
+ This class is a deprecated implementation of `tf.estimator.RunConfig`
interface.
"""
_USE_DEFAULT = 0
@@ -240,6 +240,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
keep_checkpoint_max=5,
keep_checkpoint_every_n_hours=10000,
log_step_count_steps=100,
+ protocol=None,
evaluation_master='',
model_dir=None,
session_config=None):
@@ -289,6 +290,8 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
session_config: a ConfigProto used to set session parameters, or None.
Note - using this argument, it is easy to provide settings which break
otherwise perfectly good models. Use with care.
+ protocol: An optional argument which specifies the protocol used when
+ starting server. None means default to grpc.
"""
# Neither parent class calls super().__init__(), so here we have to
# manually call their __init__() methods.
@@ -299,6 +302,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
# so instead of breaking compatibility with that assumption, we
# just manually initialize this field:
self._train_distribute = None
+ self._eval_distribute = None
self._device_fn = None
gpu_options = config_pb2.GPUOptions(
@@ -313,6 +317,7 @@ class RunConfig(ClusterConfig, core_run_config.RunConfig):
self._save_summary_steps = save_summary_steps
self._save_checkpoints_secs = save_checkpoints_secs
self._log_step_count_steps = log_step_count_steps
+ self._protocol = protocol
self._session_config = session_config
if save_checkpoints_secs == RunConfig._USE_DEFAULT:
if save_checkpoints_steps is None:
diff --git a/tensorflow/contrib/learn/python/learn/estimators/stability_test.py b/tensorflow/contrib/learn/python/learn/estimators/stability_test.py
index 6d04543819..81376c0e2a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/stability_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/stability_test.py
@@ -68,12 +68,12 @@ class StabilityTest(test.TestCase):
minval = -0.3333
maxval = 0.3333
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
g.seed = my_seed
x = random_ops.random_uniform([10, 10], minval=minval, maxval=maxval)
val1 = session.run(x)
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
g.seed = my_seed
x = random_ops.random_uniform([10, 10], minval=minval, maxval=maxval)
val2 = session.run(x)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
index 442247409d..06c61554fa 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator_test.py
@@ -53,7 +53,7 @@ class PrepareInputsForRnnTest(test.TestCase):
sequence_feature_columns,
num_unroll)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
features_val = sess.run(features_by_time)
@@ -314,7 +314,7 @@ class StateSavingRnnEstimatorTest(test.TestCase):
else:
self.assertAllEqual(v, got[k])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
actual_sequence, actual_context = sess.run(
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index 3744abd860..4e64efdd95 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -38,19 +38,19 @@ from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.python.estimator import estimator as core_estimator
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
__all__ = ["Experiment"]
def _get_standardized_predicate_fn(predicate_fn):
- pred_fn_args = estimator_util.fn_args(predicate_fn)
+ pred_fn_args = function_utils.fn_args(predicate_fn)
if "checkpoint_path" not in pred_fn_args:
# pylint: disable=unused-argument
def _pred_fn_wrapper(eval_results, checkpoint_path):
@@ -95,7 +95,7 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener):
# Load and cache the path of the most recent checkpoint to avoid duplicate
# searches on GCS.
logging.info("Checking for checkpoint in %s", self._model_dir)
- latest_path = saver.latest_checkpoint(self._model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
if not latest_path:
logging.warning("Skipping evaluation and export since model has not been "
@@ -162,16 +162,16 @@ class Experiment(object):
Args:
estimator: Object implementing Estimator interface, which could be a
- combination of @{tf.contrib.learn.Trainable} and
- @{tf.contrib.learn.Evaluable} (deprecated), or
- @{tf.estimator.Estimator}.
+ combination of `tf.contrib.learn.Trainable` and
+ `tf.contrib.learn.Evaluable` (deprecated), or
+ `tf.estimator.Estimator`.
train_input_fn: function, returns features and labels for training.
eval_input_fn: function, returns features and labels for evaluation. If
`eval_steps` is `None`, this should be configured only to produce for a
finite number of batches (generally, 1 epoch over the evaluation data).
eval_metrics: `dict` of string, metric function. If `None`, default set
is used. This should be `None` if the `estimator` is
- @{tf.estimator.Estimator}. If metrics are provided they will be
+ `tf.estimator.Estimator`. If metrics are provided they will be
*appended* to the default set.
train_steps: Perform this many steps of training. `None`, the default,
means train forever.
@@ -468,10 +468,15 @@ class Experiment(object):
on which that evaluation was based.
At the beginning of evaluation, the passed `eval_results` will be None
so it's expected that the predicate function handles that gracefully.
- When `predicate_fn` is not specified, continuous eval will run in an
- infinite loop (if `train_steps` is None). or exit once global step
- reaches `train_steps`.
-
+ Continuous eval behavior under different conditions:
+ * When `predicate_fn` is specified:
+ + if `train_steps` is None, run until `predicate_fn` returns False.
+ + if `train_steps` is specified, run until either global step
+ reaches `train_steps` or `predicate_fn` returns False.
+ * When `predicate_fn` is not specified:
+ + if `train_steps` is None, run in an infinite loop.
+ + if `train_steps` is specified, run until global step reaches
+ `train_steps`.
export: Whether to export from this step. Default is 'True'.
Raises:
@@ -500,7 +505,7 @@ class Experiment(object):
eval_result = None
last_warning_time = 0
while (not predicate_fn or predicate_fn(
- eval_result, checkpoint_path=previous_path if eval_result else None)):
+ eval_result, checkpoint_path=previous_path)):
# Exit if we have already reached number of steps to train.
if self._has_training_stopped(eval_result):
logging.info("Exiting continuous eval, global_step=%s >= "
@@ -511,7 +516,8 @@ class Experiment(object):
start = time.time()
error_msg = None
- latest_path = saver.latest_checkpoint(self._estimator.model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
if not latest_path:
error_msg = ("Estimator is not fitted yet. "
"Will start an evaluation when a checkpoint is ready.")
@@ -773,7 +779,8 @@ class Experiment(object):
saving_listeners=self._saving_listeners)
logging.info("Evaluating model now.")
- latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir)
+ latest_checkpoint = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
eval_result = self._call_evaluate(
input_fn=self._eval_input_fn,
steps=self._eval_steps,
diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py
index d10927a0cd..fb16c94c29 100644
--- a/tensorflow/contrib/learn/python/learn/experiment_test.py
+++ b/tensorflow/contrib/learn/python/learn/experiment_test.py
@@ -500,7 +500,7 @@ class ExperimentTest(test.TestCase):
noop_hook = _NoopHook()
def _predicate_fn(eval_result, checkpoint_path):
- self.assertEqual(not eval_result,
+ self.assertEqual(eval_result is None,
checkpoint_path is None)
return est.eval_count < 3 # pylint: disable=cell-var-from-loop
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
index 0d039d593b..33180b778a 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
@@ -124,7 +125,7 @@ class GraphActionsTest(test.TestCase):
# TODO(ptucker): Test number and contents of checkpoint files.
def _assert_ckpt(self, output_dir, expected=True):
- ckpt_state = saver_lib.get_checkpoint_state(output_dir)
+ ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
if expected:
pattern = '%s/model.ckpt-.*' % output_dir
primary_ckpt_path = ckpt_state.model_checkpoint_path
@@ -174,7 +175,7 @@ class GraphActionsTest(test.TestCase):
return in0, in1, out
def test_infer(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._assert_ckpt(self._output_dir, False)
in0, in1, out = self._build_inference_graph()
self.assertEqual({
@@ -192,7 +193,7 @@ class GraphActionsTest(test.TestCase):
side_effect=learn.graph_actions.coordinator.Coordinator.request_stop,
autospec=True)
def test_coordinator_request_stop_called(self, request_stop):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
in0, in1, out = self._build_inference_graph()
learn.graph_actions.infer(None, {'a': in0, 'b': in1, 'c': out})
self.assertTrue(request_stop.called)
@@ -203,7 +204,7 @@ class GraphActionsTest(test.TestCase):
side_effect=learn.graph_actions.coordinator.Coordinator.request_stop,
autospec=True)
def test_run_feeds_iter_cleanup_with_exceptions(self, request_stop):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
in0, in1, out = self._build_inference_graph()
try:
for _ in learn.graph_actions.run_feeds_iter({
@@ -233,7 +234,7 @@ class GraphActionsTest(test.TestCase):
self.assertTrue(test_ops.resource_initialized_op(handle).eval())
def test_infer_different_default_graph(self):
- with self.test_session():
+ with self.cached_session():
self._assert_ckpt(self._output_dir, False)
with ops.Graph().as_default():
in0, in1, out = self._build_inference_graph()
@@ -248,7 +249,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_infer_invalid_feed(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._assert_ckpt(self._output_dir, False)
in0, _, _ = self._build_inference_graph()
with self.assertRaisesRegexp(TypeError, 'Can not convert a NoneType'):
@@ -256,7 +257,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_infer_feed(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._assert_ckpt(self._output_dir, False)
in0, _, out = self._build_inference_graph()
self.assertEqual(
@@ -270,7 +271,7 @@ class GraphActionsTest(test.TestCase):
# TODO(ptucker): Test eval for 1 epoch.
def test_evaluate_invalid_args(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._assert_ckpt(self._output_dir, False)
with self.assertRaisesRegexp(ValueError, 'utput directory'):
learn.graph_actions.evaluate(
@@ -287,7 +288,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_evaluate(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
_, _, out = self._build_inference_graph()
writer = learn.graph_actions.get_summary_writer(self._output_dir)
self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
@@ -309,7 +310,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_evaluate_ready_for_local_init(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
variables_lib.create_global_step()
v = variables.Variable(1.0)
variables.Variable(
@@ -326,7 +327,7 @@ class GraphActionsTest(test.TestCase):
max_steps=1)
def test_evaluate_feed_fn(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
in0, _, out = self._build_inference_graph()
writer = learn.graph_actions.get_summary_writer(self._output_dir)
self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
@@ -351,7 +352,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_evaluate_feed_fn_with_exhaustion(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
in0, _, out = self._build_inference_graph()
writer = learn.graph_actions.get_summary_writer(self._output_dir)
self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
@@ -374,7 +375,7 @@ class GraphActionsTest(test.TestCase):
expected_session_logs=[])
def test_evaluate_with_saver(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
_, _, out = self._build_inference_graph()
ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
writer = learn.graph_actions.get_summary_writer(self._output_dir)
@@ -434,7 +435,7 @@ class GraphActionsTrainTest(test.TestCase):
# TODO(ptucker): Test number and contents of checkpoint files.
def _assert_ckpt(self, output_dir, expected=True):
- ckpt_state = saver_lib.get_checkpoint_state(output_dir)
+ ckpt_state = checkpoint_management.get_checkpoint_state(output_dir)
if expected:
pattern = '%s/model.ckpt-.*' % output_dir
primary_ckpt_path = ckpt_state.model_checkpoint_path
@@ -468,7 +469,7 @@ class GraphActionsTrainTest(test.TestCase):
return in0, in1, out
def test_train_invalid_args(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
train_op = constant_op.constant(1.0)
loss_op = constant_op.constant(2.0)
with self.assertRaisesRegexp(ValueError, 'utput directory'):
@@ -502,7 +503,7 @@ class GraphActionsTrainTest(test.TestCase):
# TODO(ptucker): Mock supervisor, and assert all interactions.
def test_train(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
self._assert_summaries(self._output_dir)
@@ -521,7 +522,7 @@ class GraphActionsTrainTest(test.TestCase):
self._assert_ckpt(self._output_dir, True)
def test_train_steps_is_incremental(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions.train(
@@ -534,7 +535,7 @@ class GraphActionsTrainTest(test.TestCase):
self._output_dir, variables_lib.get_global_step().name)
self.assertEqual(10, step)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions.train(
@@ -548,7 +549,7 @@ class GraphActionsTrainTest(test.TestCase):
self.assertEqual(25, step)
def test_train_max_steps_is_not_incremental(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions.train(
@@ -561,7 +562,7 @@ class GraphActionsTrainTest(test.TestCase):
self._output_dir, variables_lib.get_global_step().name)
self.assertEqual(10, step)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions.train(
@@ -575,7 +576,7 @@ class GraphActionsTrainTest(test.TestCase):
self.assertEqual(15, step)
def test_train_loss(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
variables_lib.create_global_step()
loss_var = variables_lib.local_variable(10.0)
train_op = control_flow_ops.group(
@@ -597,7 +598,7 @@ class GraphActionsTrainTest(test.TestCase):
self._assert_ckpt(self._output_dir, True)
def test_train_summaries(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
loss_op = constant_op.constant(2.0)
@@ -623,7 +624,7 @@ class GraphActionsTrainTest(test.TestCase):
self._assert_ckpt(self._output_dir, True)
def test_train_chief_monitor(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
loss_op = constant_op.constant(2.0)
@@ -662,7 +663,7 @@ class GraphActionsTrainTest(test.TestCase):
# and the other chief exclusive.
chief_exclusive_monitor = _BaseMonitorWrapper(False)
all_workers_monitor = _BaseMonitorWrapper(True)
- with self.test_session(g):
+ with self.session(g):
loss = learn.graph_actions.train(
g,
output_dir=self._output_dir,
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
index 1f439965da..284a4f45f6 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
@@ -58,7 +58,7 @@ class DataFeederTest(test.TestCase):
self.assertEqual(expected_np_dtype, v)
else:
self.assertEqual(expected_np_dtype, feeder.input_dtype)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
inp, _ = feeder.input_builder()
if isinstance(inp, dict):
for v in list(inp.values()):
@@ -147,7 +147,7 @@ class DataFeederTest(test.TestCase):
def test_unsupervised(self):
def func(feeder):
- with self.test_session():
+ with self.cached_session():
inp, _ = feeder.input_builder()
feed_dict_fn = feeder.get_feed_dict_fn()
feed_dict = feed_dict_fn()
@@ -181,7 +181,7 @@ class DataFeederTest(test.TestCase):
def test_epoch(self):
def func(feeder):
- with self.test_session():
+ with self.cached_session():
feeder.input_builder()
epoch = feeder.make_epoch_variable()
feed_dict_fn = feeder.get_feed_dict_fn()
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
index 7e81f2b7d9..5e90d1fa20 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
@@ -38,7 +38,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key='label',
@@ -68,7 +68,7 @@ class GeneratorIoTest(test.TestCase):
for index in range(2):
yield {'a': np.ones(1) * index}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
@@ -97,7 +97,7 @@ class GeneratorIoTest(test.TestCase):
'label2': np.ones(1) * index - 64,
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key=['label', 'label2'],
@@ -134,7 +134,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones((3, 3)) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key='label',
@@ -162,7 +162,7 @@ class GeneratorIoTest(test.TestCase):
def testGeneratorInputFnWithXAsNonGeneratorFunction(self):
x = np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x must be generator function'):
failing_input_fn = generator_io.generator_input_fn(
x, batch_size=2, shuffle=False, num_epochs=1)
@@ -173,7 +173,7 @@ class GeneratorIoTest(test.TestCase):
def generator():
return np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x\(\) must be generator'):
failing_input_fn = generator_io.generator_input_fn(
generator, batch_size=2, shuffle=False, num_epochs=1)
@@ -184,7 +184,7 @@ class GeneratorIoTest(test.TestCase):
def generator():
yield np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x\(\) must yield dict'):
failing_input_fn = generator_io.generator_input_fn(
generator, batch_size=2, shuffle=False, num_epochs=1)
@@ -201,7 +201,7 @@ class GeneratorIoTest(test.TestCase):
}
y = np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
' Container of str'):
failing_input_fn = generator_io.generator_input_fn(
@@ -219,7 +219,7 @@ class GeneratorIoTest(test.TestCase):
}
y = ['label', np.arange(10)]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
' Container of str'):
failing_input_fn = generator_io.generator_input_fn(
@@ -237,7 +237,7 @@ class GeneratorIoTest(test.TestCase):
}
y = ['label', 'target']
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(KeyError, 'target_key not in yielded dict'):
failing_input_fn = generator_io.generator_input_fn(
generator, target_key=y, batch_size=2, shuffle=False, num_epochs=1)
@@ -253,7 +253,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
@@ -283,7 +283,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=4, shuffle=False, num_epochs=1)
features = input_fn()
@@ -319,7 +319,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
index e11e8b698a..8e68a17e47 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
@@ -207,7 +207,7 @@ class GraphIOTest(test.TestCase):
parsing_ops.FixedLenFeature(shape=shape, dtype=dtypes_lib.float32)
}
- with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
+ with ops.Graph().as_default() as g, self.session(graph=g) as sess:
features = graph_io.read_batch_record_features(
_VALID_FILE_PATTERN,
batch_size,
@@ -242,7 +242,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 1234
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
+ with ops.Graph().as_default() as g, self.session(graph=g) as sess:
inputs = graph_io.read_batch_examples(
_VALID_FILE_PATTERN,
batch_size,
@@ -276,7 +276,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 1234
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
+ with ops.Graph().as_default() as g, self.session(graph=g) as sess:
inputs = graph_io.read_batch_examples(
[_VALID_FILE_PATTERN, _VALID_FILE_PATTERN_2],
batch_size,
@@ -325,7 +325,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
inputs = graph_io.read_batch_examples(
filename,
batch_size,
@@ -374,7 +374,7 @@ class GraphIOTest(test.TestCase):
features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)}
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
keys, result = graph_io.read_keyed_batch_features(
filename,
batch_size,
@@ -429,7 +429,7 @@ class GraphIOTest(test.TestCase):
features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)}
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
result = graph_io.read_batch_features(
filename,
batch_size,
@@ -475,7 +475,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
inputs = graph_io.read_batch_examples(
filenames,
batch_size,
@@ -519,7 +519,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
keys, inputs = graph_io.read_keyed_batch_examples_shared_queue(
filenames,
batch_size,
@@ -640,7 +640,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 10
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
inputs = graph_io.read_batch_examples(
[filename],
batch_size,
@@ -672,7 +672,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
keys, inputs = graph_io.read_keyed_batch_examples(
filename,
batch_size,
@@ -714,7 +714,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
dtypes = {"age": parsing_ops.FixedLenFeature([1], dtypes_lib.int64)}
parse_fn = lambda example: parsing_ops.parse_single_example( # pylint: disable=g-long-lambda
parsing_ops.decode_json_example(example), dtypes)
@@ -773,7 +773,7 @@ class GraphIOTest(test.TestCase):
examples = parsing_ops.parse_example(serialized, features)
return math_ops.less(examples["age"], 2)
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
keys, inputs = graph_io._read_keyed_batch_examples_helper(
filename,
batch_size,
@@ -812,7 +812,7 @@ class GraphIOTest(test.TestCase):
coord.join(threads)
def test_queue_parsed_features_single_tensor(self):
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
features = {"test": constant_op.constant([1, 2, 3])}
_, queued_features = graph_io.queue_parsed_features(features)
coord = coordinator.Coordinator()
@@ -833,7 +833,7 @@ class GraphIOTest(test.TestCase):
_, queued_feature = graph_io.read_keyed_batch_features_shared_queue(
_VALID_FILE_PATTERN, batch_size, feature, reader)
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
features_result = graph_io.read_batch_features(
_VALID_FILE_PATTERN, batch_size, feature, reader)
session.run(variables.local_variables_initializer())
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
index c738f0e8f3..396539a76a 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
@@ -65,7 +65,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesExpectedOutputs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -79,7 +79,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 102)
a = np.arange(2)
b = np.arange(32, 34)
@@ -107,7 +107,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 105)
a = np.arange(5)
b = np.arange(32, 37)
@@ -146,7 +146,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_OnlyX(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, _ = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y=None, batch_size=2, shuffle=False, num_epochs=1)
@@ -159,7 +159,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ExcludesIndex(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -182,7 +182,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_NoShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=False, num_epochs=1)
@@ -192,7 +192,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=True, num_epochs=1)
@@ -202,7 +202,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)
@@ -213,7 +213,7 @@ class PandasIoTest(test.TestCase):
if not HAS_PANDAS:
return
x, y = self.makeTestDataFrame()
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=3, shuffle=False, num_epochs=1)
diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py
index 77f7c73d54..3d691d4340 100644
--- a/tensorflow/contrib/learn/python/learn/monitors.py
+++ b/tensorflow/contrib/learn/python/learn/monitors.py
@@ -51,7 +51,7 @@ from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as core_summary
-from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
@@ -735,7 +735,8 @@ class ValidationMonitor(EveryN):
return False
self._last_checkpoint_check_time = current_time
# Check that we are not running evaluation on the same checkpoint.
- latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
if latest_path is None:
logging.debug("Skipping evaluation since model has not been saved yet "
"at step %d.", step)
@@ -1059,7 +1060,8 @@ class ExportMonitor(EveryN):
def end(self, session=None):
super(ExportMonitor, self).end(session=session)
- latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
+ latest_path = checkpoint_management.latest_checkpoint(
+ self._estimator.model_dir)
if latest_path is None:
logging.info("Skipping export at the end since model has not been saved "
"yet.")
diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py
index 5c34d0ddb0..83e48a36e7 100644
--- a/tensorflow/contrib/learn/python/learn/monitors_test.py
+++ b/tensorflow/contrib/learn/python/learn/monitors_test.py
@@ -39,9 +39,9 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
-from tensorflow.python.training import saver
from tensorflow.python.training import training_util
@@ -127,12 +127,12 @@ class MonitorsTest(test.TestCase):
monitor.end()
def test_base_monitor(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(learn.monitors.BaseMonitor())
def test_every_0(self):
monitor = _MyEveryN(every_n_steps=0, first_n_steps=-1)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
expected_steps = list(range(30))
self.assertAllEqual(expected_steps, monitor.steps_begun)
@@ -141,7 +141,7 @@ class MonitorsTest(test.TestCase):
def test_every_1(self):
monitor = _MyEveryN(every_n_steps=1, first_n_steps=-1)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
expected_steps = list(range(1, 30))
self.assertEqual(expected_steps, monitor.steps_begun)
@@ -150,7 +150,7 @@ class MonitorsTest(test.TestCase):
def test_every_2(self):
monitor = _MyEveryN(every_n_steps=2, first_n_steps=-1)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
expected_steps = list(range(2, 29, 2)) + [29]
self.assertEqual(expected_steps, monitor.steps_begun)
@@ -159,7 +159,7 @@ class MonitorsTest(test.TestCase):
def test_every_8(self):
monitor = _MyEveryN(every_n_steps=8, first_n_steps=2)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
expected_steps = [0, 1, 2, 10, 18, 26, 29]
self.assertEqual(expected_steps, monitor.steps_begun)
@@ -168,7 +168,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_no_max_steps(self):
monitor = _MyEveryN(every_n_steps=8, first_n_steps=2)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(
monitor, num_epochs=3, num_steps_per_epoch=10, pass_max_steps=False)
begin_end_steps = [0, 1, 2, 10, 18, 26]
@@ -179,7 +179,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_recovered_after_step_begin(self):
monitor = _MyEveryN(every_n_steps=8)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
for step in [8, 16]:
monitor.step_begin(step)
monitor.step_begin(step)
@@ -192,7 +192,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_recovered_after_step_end(self):
monitor = _MyEveryN(every_n_steps=8)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
for step in [8, 16]:
monitor.step_begin(step)
monitor.step_end(step, output=None)
@@ -207,7 +207,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_call_post_step_at_the_end(self):
monitor = _MyEveryN(every_n_steps=8)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin()
for step in [8, 16]:
monitor.step_begin(step)
@@ -224,7 +224,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_call_post_step_should_not_be_called_twice(self):
monitor = _MyEveryN(every_n_steps=8)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin()
for step in [8, 16]:
monitor.step_begin(step)
@@ -240,13 +240,13 @@ class MonitorsTest(test.TestCase):
self.assertEqual([8, 16], monitor.post_steps)
def test_print(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
t = constant_op.constant(42.0, name='foo')
self._run_monitor(learn.monitors.PrintTensor(tensor_names=[t.name]))
self.assertRegexpMatches(str(self.logged_message), t.name)
def test_logging_trainable(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
var = variables.Variable(constant_op.constant(42.0), name='foo')
var.initializer.run()
cof = constant_op.constant(1.0)
@@ -258,7 +258,7 @@ class MonitorsTest(test.TestCase):
self.assertRegexpMatches(str(self.logged_message), var.name)
def test_summary_saver(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
log_dir = 'log/dir'
summary_writer = testing.FakeSummaryWriter(log_dir, g)
var = variables.Variable(0.0)
@@ -312,12 +312,12 @@ class MonitorsTest(test.TestCase):
monitor = learn.monitors.ValidationMonitor(
x=constant_op.constant(2.0), every_n_steps=0)
self._assert_validation_monitor(monitor)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with self.assertRaisesRegexp(ValueError, 'set_estimator'):
self._run_monitor(monitor)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_no_ckpt(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@@ -330,13 +330,13 @@ class MonitorsTest(test.TestCase):
x=constant_op.constant(2.0), every_n_steps=0)
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor)
self._assert_validation_monitor(monitor)
mock_latest_checkpoint.assert_called_with(model_dir)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_no_early_stopping_rounds(self,
mock_latest_checkpoint,
mock_estimator_class):
@@ -351,12 +351,12 @@ class MonitorsTest(test.TestCase):
x=constant_op.constant(2.0), every_n_steps=0)
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor)
self._assert_validation_monitor(monitor)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_invalid_metric(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@@ -370,12 +370,12 @@ class MonitorsTest(test.TestCase):
x=constant_op.constant(2.0), every_n_steps=0, early_stopping_rounds=1)
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with self.assertRaisesRegexp(ValueError, 'missing from outputs'):
self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1)
@test.mock.patch.object(estimators, 'Estimator', autospec=True)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor(self, mock_latest_checkpoint,
mock_estimator_class):
estimator = mock_estimator_class()
@@ -392,7 +392,7 @@ class MonitorsTest(test.TestCase):
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin(max_steps=100)
monitor.epoch_begin(epoch=0)
self.assertEqual(0, estimator.evaluate.call_count)
@@ -464,7 +464,7 @@ class MonitorsTest(test.TestCase):
monitor.epoch_end(epoch=0)
monitor.end()
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_with_core_estimator(self, mock_latest_checkpoint):
estimator = test.mock.Mock(spec=core_estimator.Estimator)
model_dir = 'model/dir'
@@ -477,7 +477,7 @@ class MonitorsTest(test.TestCase):
every_n_steps=0, early_stopping_rounds=2)
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin(max_steps=100)
monitor.epoch_begin(epoch=0)
self.assertEqual(0, estimator.evaluate.call_count)
@@ -495,7 +495,7 @@ class MonitorsTest(test.TestCase):
expected_best_metrics={'loss': 42.0, 'auc': 0.5})
monitor.post_step(step=step, session=None)
- @test.mock.patch.object(saver, 'latest_checkpoint')
+ @test.mock.patch.object(checkpoint_management, 'latest_checkpoint')
def test_validation_monitor_fail_with_core_estimator_and_metrics(
self, mock_latest_checkpoint):
estimator = test.mock.Mock(spec=core_estimator.Estimator)
@@ -509,7 +509,7 @@ class MonitorsTest(test.TestCase):
metrics=constant_op.constant(2.0),
every_n_steps=0, early_stopping_rounds=2)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin(max_steps=100)
monitor.epoch_begin(epoch=0)
@@ -525,7 +525,7 @@ class MonitorsTest(test.TestCase):
def test_graph_dump(self):
monitor0 = learn.monitors.GraphDump()
monitor1 = learn.monitors.GraphDump()
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
const_var = variables.Variable(42.0, name='my_const')
counter_var = variables.Variable(0.0, name='my_counter')
assign_add = state_ops.assign_add(counter_var, 1.0, name='my_assign_add')
@@ -568,7 +568,7 @@ class MonitorsTest(test.TestCase):
def test_capture_variable(self):
monitor = learn.monitors.CaptureVariable(
var_name='my_assign_add:0', every_n=8, first_n=2)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
var = variables.Variable(0.0, name='my_var')
var.initializer.run()
state_ops.assign_add(var, 1.0, name='my_assign_add')
diff --git a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py
index 92976d1539..9f2cadb017 100644
--- a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py
+++ b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py
@@ -40,7 +40,7 @@ def mean_squared_error_regressor(tensor_in, labels, weights, biases, name=None):
[tensor_in, labels]):
predictions = nn.xw_plus_b(tensor_in, weights, biases)
if len(labels.get_shape()) == 1 and len(predictions.get_shape()) == 2:
- predictions = array_ops_.squeeze(predictions, squeeze_dims=[1])
+ predictions = array_ops_.squeeze(predictions, axis=[1])
return predictions, losses.mean_squared_error(labels, predictions)
diff --git a/tensorflow/contrib/learn/python/learn/ops/ops_test.py b/tensorflow/contrib/learn/python/learn/ops/ops_test.py
index 80d4923db3..ff190110c1 100644
--- a/tensorflow/contrib/learn/python/learn/ops/ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/ops_test.py
@@ -33,7 +33,7 @@ class OpsTest(test.TestCase):
"""Ops tests."""
def test_softmax_classifier(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
features = array_ops.placeholder(dtypes.float32, [None, 3])
labels = array_ops.placeholder(dtypes.float32, [None, 2])
weights = constant_op.constant([[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]])
@@ -52,7 +52,7 @@ class OpsTest(test.TestCase):
ids_shape = (2, 3, 4)
embeds = np.random.randn(n_embed, d_embed)
ids = np.random.randint(0, n_embed, ids_shape)
- with self.test_session():
+ with self.cached_session():
embed_np = embeds[ids]
embed_tf = ops.embedding_lookup(embeds, ids).eval()
self.assertEqual(embed_np.shape, embed_tf.shape)
@@ -60,7 +60,7 @@ class OpsTest(test.TestCase):
def test_categorical_variable(self):
random_seed.set_random_seed(42)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cat_var_idx = array_ops.placeholder(dtypes.int64, [2, 2])
embeddings = ops.categorical_variable(
cat_var_idx, n_classes=5, embedding_size=10, name="my_cat_var")
diff --git a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
index 95aec61955..5a7e4ebfea 100644
--- a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
@@ -31,7 +31,7 @@ class Seq2SeqOpsTest(test.TestCase):
"""Sequence-to-sequence tests."""
def test_sequence_classifier(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
decoding = [
array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
]
@@ -60,7 +60,7 @@ class Seq2SeqOpsTest(test.TestCase):
def test_seq2seq_inputs(self):
inp = np.array([[[1, 0], [0, 1], [1, 0]], [[0, 1], [1, 0], [0, 1]]])
out = np.array([[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [0, 1, 0]]])
- with self.test_session() as session:
+ with self.cached_session() as session:
x = array_ops.placeholder(dtypes.float32, [2, 3, 2])
y = array_ops.placeholder(dtypes.float32, [2, 2, 3])
in_x, in_y, out_y = ops.seq2seq_inputs(x, y, 3, 2)
@@ -77,7 +77,7 @@ class Seq2SeqOpsTest(test.TestCase):
[[0, 0, 0], [0, 0, 0]]])
def test_rnn_decoder(self):
- with self.test_session():
+ with self.cached_session():
decoder_inputs = [
array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
]
diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py
index 3eacac7a3d..0144b93814 100644
--- a/tensorflow/contrib/learn/python/learn/utils/export.py
+++ b/tensorflow/contrib/learn/python/learn/utils/export.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import training_util
@@ -298,7 +299,8 @@ def _export_estimator(estimator,
# If checkpoint_path is specified, use the specified checkpoint path.
checkpoint_path = (checkpoint_path or
- tf_saver.latest_checkpoint(estimator._model_dir))
+ checkpoint_management.latest_checkpoint(
+ estimator._model_dir))
with ops.Graph().as_default() as g:
training_util.create_global_step(g)
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
index c7cdb41312..4f22054af3 100644
--- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
@@ -55,7 +55,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.summary import summary_iterator
-from tensorflow.python.training import saver
+from tensorflow.python.training import checkpoint_management
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated
@@ -343,7 +343,8 @@ def get_temp_export_dir(timestamped_export_dir):
"""
(dirname, basename) = os.path.split(timestamped_export_dir)
temp_export_dir = os.path.join(
- compat.as_bytes(dirname), compat.as_bytes('temp-{}'.format(basename)))
+ compat.as_bytes(dirname),
+ compat.as_bytes('temp-{}'.format(compat.as_text(basename))))
return temp_export_dir
@@ -414,7 +415,7 @@ def make_export_strategy(serving_input_fn,
`InputFnOps`.
default_output_alternative_key: the name of the head to serve when an
incoming serving request does not explicitly request a specific head.
- Must be `None` if the estimator inherits from @{tf.estimator.Estimator}
+ Must be `None` if the estimator inherits from `tf.estimator.Estimator`
or for single-headed models.
assets_extra: A dict specifying how to populate the assets.extra directory
within the exported SavedModel. Each key should give the destination
@@ -452,7 +453,7 @@ def make_export_strategy(serving_input_fn,
The string path to the exported directory.
Raises:
- ValueError: If `estimator` is a @{tf.estimator.Estimator} instance
+ ValueError: If `estimator` is a `tf.estimator.Estimator` instance
and `default_output_alternative_key` was specified.
"""
if isinstance(estimator, core_estimator.Estimator):
@@ -503,7 +504,7 @@ def make_parsing_export_strategy(feature_columns,
that must be provided at serving time (excluding labels!).
default_output_alternative_key: the name of the head to serve when an
incoming serving request does not explicitly request a specific head.
- Must be `None` if the estimator inherits from @{tf.estimator.Estimator}
+ Must be `None` if the estimator inherits from `tf.estimator.Estimator`
or for single-headed models.
assets_extra: A dict specifying how to populate the assets.extra directory
within the exported SavedModel. Each key should give the destination
@@ -713,7 +714,8 @@ def make_best_model_export_strategy(
# as soon as contrib is cleaned up and we can thus be sure that
# estimator is a tf.estimator.Estimator and not a
# tf.contrib.learn.Estimator
- checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ estimator.model_dir)
export_checkpoint_path, export_eval_result = best_model_selector.update(
checkpoint_path, eval_result)
@@ -765,7 +767,7 @@ def extend_export_strategy(base_export_strategy,
The string path to the SavedModel indicated by post_export_fn.
Raises:
- ValueError: If `estimator` is a @{tf.estimator.Estimator} instance
+ ValueError: If `estimator` is a `tf.estimator.Estimator` instance
and `default_output_alternative_key` was specified or if post_export_fn
does not return a valid directory.
RuntimeError: If unable to create temporary or final export directory.