aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 18:22:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 18:25:59 -0700
commit708b30f4cb82271bb28cb70a1e0c89a1933f5b64 (patch)
tree22470a9314f7f4225b6d08170a3d7ea91b0216a1 /tensorflow/contrib/learn
parentd0cac47a767dd972516f75ce57f0d6185e3b6514 (diff)
Move from deprecated self.test_session() to self.session() when a graph is set.
self.test_session() has been deprecated in cl/208545396 as its behavior confuses readers of the test. Moving to self.session() instead. PiperOrigin-RevId: 209696110
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/stability_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/graph_actions_test.py42
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py2
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py28
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors_test.py44
5 files changed, 60 insertions, 60 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/stability_test.py b/tensorflow/contrib/learn/python/learn/estimators/stability_test.py
index 6d04543819..81376c0e2a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/stability_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/stability_test.py
@@ -68,12 +68,12 @@ class StabilityTest(test.TestCase):
minval = -0.3333
maxval = 0.3333
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
g.seed = my_seed
x = random_ops.random_uniform([10, 10], minval=minval, maxval=maxval)
val1 = session.run(x)
with ops.Graph().as_default() as g:
- with self.test_session(graph=g) as session:
+ with self.session(graph=g) as session:
g.seed = my_seed
x = random_ops.random_uniform([10, 10], minval=minval, maxval=maxval)
val2 = session.run(x)
diff --git a/tensorflow/contrib/learn/python/learn/graph_actions_test.py b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
index df156da3f4..d5c02124ac 100644
--- a/tensorflow/contrib/learn/python/learn/graph_actions_test.py
+++ b/tensorflow/contrib/learn/python/learn/graph_actions_test.py
@@ -175,7 +175,7 @@ class GraphActionsTest(test.TestCase):
return in0, in1, out
def test_infer(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._assert_ckpt(self._output_dir, False)
in0, in1, out = self._build_inference_graph()
self.assertEqual({
@@ -193,7 +193,7 @@ class GraphActionsTest(test.TestCase):
side_effect=learn.graph_actions.coordinator.Coordinator.request_stop,
autospec=True)
def test_coordinator_request_stop_called(self, request_stop):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
in0, in1, out = self._build_inference_graph()
learn.graph_actions.infer(None, {'a': in0, 'b': in1, 'c': out})
self.assertTrue(request_stop.called)
@@ -204,7 +204,7 @@ class GraphActionsTest(test.TestCase):
side_effect=learn.graph_actions.coordinator.Coordinator.request_stop,
autospec=True)
def test_run_feeds_iter_cleanup_with_exceptions(self, request_stop):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
in0, in1, out = self._build_inference_graph()
try:
for _ in learn.graph_actions.run_feeds_iter({
@@ -249,7 +249,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_infer_invalid_feed(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._assert_ckpt(self._output_dir, False)
in0, _, _ = self._build_inference_graph()
with self.assertRaisesRegexp(TypeError, 'Can not convert a NoneType'):
@@ -257,7 +257,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_infer_feed(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._assert_ckpt(self._output_dir, False)
in0, _, out = self._build_inference_graph()
self.assertEqual(
@@ -271,7 +271,7 @@ class GraphActionsTest(test.TestCase):
# TODO(ptucker): Test eval for 1 epoch.
def test_evaluate_invalid_args(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._assert_ckpt(self._output_dir, False)
with self.assertRaisesRegexp(ValueError, 'utput directory'):
learn.graph_actions.evaluate(
@@ -288,7 +288,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_evaluate(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
_, _, out = self._build_inference_graph()
writer = learn.graph_actions.get_summary_writer(self._output_dir)
self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
@@ -310,7 +310,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_evaluate_ready_for_local_init(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
variables_lib.create_global_step()
v = variables.Variable(1.0)
variables.Variable(
@@ -327,7 +327,7 @@ class GraphActionsTest(test.TestCase):
max_steps=1)
def test_evaluate_feed_fn(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
in0, _, out = self._build_inference_graph()
writer = learn.graph_actions.get_summary_writer(self._output_dir)
self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
@@ -352,7 +352,7 @@ class GraphActionsTest(test.TestCase):
self._assert_ckpt(self._output_dir, False)
def test_evaluate_feed_fn_with_exhaustion(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
in0, _, out = self._build_inference_graph()
writer = learn.graph_actions.get_summary_writer(self._output_dir)
self._assert_summaries(self._output_dir, writer, expected_session_logs=[])
@@ -375,7 +375,7 @@ class GraphActionsTest(test.TestCase):
expected_session_logs=[])
def test_evaluate_with_saver(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
_, _, out = self._build_inference_graph()
ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
writer = learn.graph_actions.get_summary_writer(self._output_dir)
@@ -469,7 +469,7 @@ class GraphActionsTrainTest(test.TestCase):
return in0, in1, out
def test_train_invalid_args(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
train_op = constant_op.constant(1.0)
loss_op = constant_op.constant(2.0)
with self.assertRaisesRegexp(ValueError, 'utput directory'):
@@ -503,7 +503,7 @@ class GraphActionsTrainTest(test.TestCase):
# TODO(ptucker): Mock supervisor, and assert all interactions.
def test_train(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
self._assert_summaries(self._output_dir)
@@ -522,7 +522,7 @@ class GraphActionsTrainTest(test.TestCase):
self._assert_ckpt(self._output_dir, True)
def test_train_steps_is_incremental(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions.train(
@@ -535,7 +535,7 @@ class GraphActionsTrainTest(test.TestCase):
self._output_dir, variables_lib.get_global_step().name)
self.assertEqual(10, step)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions.train(
@@ -549,7 +549,7 @@ class GraphActionsTrainTest(test.TestCase):
self.assertEqual(25, step)
def test_train_max_steps_is_not_incremental(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions.train(
@@ -562,7 +562,7 @@ class GraphActionsTrainTest(test.TestCase):
self._output_dir, variables_lib.get_global_step().name)
self.assertEqual(10, step)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions.train(
@@ -576,7 +576,7 @@ class GraphActionsTrainTest(test.TestCase):
self.assertEqual(15, step)
def test_train_loss(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
variables_lib.create_global_step()
loss_var = variables_lib.local_variable(10.0)
train_op = control_flow_ops.group(
@@ -598,7 +598,7 @@ class GraphActionsTrainTest(test.TestCase):
self._assert_ckpt(self._output_dir, True)
def test_train_summaries(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
loss_op = constant_op.constant(2.0)
@@ -624,7 +624,7 @@ class GraphActionsTrainTest(test.TestCase):
self._assert_ckpt(self._output_dir, True)
def test_train_chief_monitor(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
loss_op = constant_op.constant(2.0)
@@ -663,7 +663,7 @@ class GraphActionsTrainTest(test.TestCase):
# and the other chief exclusive.
chief_exclusive_monitor = _BaseMonitorWrapper(False)
all_workers_monitor = _BaseMonitorWrapper(True)
- with self.test_session(g):
+ with self.session(g):
loss = learn.graph_actions.train(
g,
output_dir=self._output_dir,
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
index 1f439965da..5e07b9313f 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
@@ -58,7 +58,7 @@ class DataFeederTest(test.TestCase):
self.assertEqual(expected_np_dtype, v)
else:
self.assertEqual(expected_np_dtype, feeder.input_dtype)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
inp, _ = feeder.input_builder()
if isinstance(inp, dict):
for v in list(inp.values()):
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
index e11e8b698a..8e68a17e47 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py
@@ -207,7 +207,7 @@ class GraphIOTest(test.TestCase):
parsing_ops.FixedLenFeature(shape=shape, dtype=dtypes_lib.float32)
}
- with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
+ with ops.Graph().as_default() as g, self.session(graph=g) as sess:
features = graph_io.read_batch_record_features(
_VALID_FILE_PATTERN,
batch_size,
@@ -242,7 +242,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 1234
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
+ with ops.Graph().as_default() as g, self.session(graph=g) as sess:
inputs = graph_io.read_batch_examples(
_VALID_FILE_PATTERN,
batch_size,
@@ -276,7 +276,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 1234
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as sess:
+ with ops.Graph().as_default() as g, self.session(graph=g) as sess:
inputs = graph_io.read_batch_examples(
[_VALID_FILE_PATTERN, _VALID_FILE_PATTERN_2],
batch_size,
@@ -325,7 +325,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
inputs = graph_io.read_batch_examples(
filename,
batch_size,
@@ -374,7 +374,7 @@ class GraphIOTest(test.TestCase):
features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)}
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
keys, result = graph_io.read_keyed_batch_features(
filename,
batch_size,
@@ -429,7 +429,7 @@ class GraphIOTest(test.TestCase):
features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)}
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
result = graph_io.read_batch_features(
filename,
batch_size,
@@ -475,7 +475,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
inputs = graph_io.read_batch_examples(
filenames,
batch_size,
@@ -519,7 +519,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
keys, inputs = graph_io.read_keyed_batch_examples_shared_queue(
filenames,
batch_size,
@@ -640,7 +640,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 10
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
inputs = graph_io.read_batch_examples(
[filename],
batch_size,
@@ -672,7 +672,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
keys, inputs = graph_io.read_keyed_batch_examples(
filename,
batch_size,
@@ -714,7 +714,7 @@ class GraphIOTest(test.TestCase):
queue_capacity = 5
name = "my_batch"
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
dtypes = {"age": parsing_ops.FixedLenFeature([1], dtypes_lib.int64)}
parse_fn = lambda example: parsing_ops.parse_single_example( # pylint: disable=g-long-lambda
parsing_ops.decode_json_example(example), dtypes)
@@ -773,7 +773,7 @@ class GraphIOTest(test.TestCase):
examples = parsing_ops.parse_example(serialized, features)
return math_ops.less(examples["age"], 2)
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
keys, inputs = graph_io._read_keyed_batch_examples_helper(
filename,
batch_size,
@@ -812,7 +812,7 @@ class GraphIOTest(test.TestCase):
coord.join(threads)
def test_queue_parsed_features_single_tensor(self):
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
features = {"test": constant_op.constant([1, 2, 3])}
_, queued_features = graph_io.queue_parsed_features(features)
coord = coordinator.Coordinator()
@@ -833,7 +833,7 @@ class GraphIOTest(test.TestCase):
_, queued_feature = graph_io.read_keyed_batch_features_shared_queue(
_VALID_FILE_PATTERN, batch_size, feature, reader)
- with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
+ with ops.Graph().as_default() as g, self.session(graph=g) as session:
features_result = graph_io.read_batch_features(
_VALID_FILE_PATTERN, batch_size, feature, reader)
session.run(variables.local_variables_initializer())
diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py
index ff1da32c21..83e48a36e7 100644
--- a/tensorflow/contrib/learn/python/learn/monitors_test.py
+++ b/tensorflow/contrib/learn/python/learn/monitors_test.py
@@ -127,12 +127,12 @@ class MonitorsTest(test.TestCase):
monitor.end()
def test_base_monitor(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(learn.monitors.BaseMonitor())
def test_every_0(self):
monitor = _MyEveryN(every_n_steps=0, first_n_steps=-1)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
expected_steps = list(range(30))
self.assertAllEqual(expected_steps, monitor.steps_begun)
@@ -141,7 +141,7 @@ class MonitorsTest(test.TestCase):
def test_every_1(self):
monitor = _MyEveryN(every_n_steps=1, first_n_steps=-1)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
expected_steps = list(range(1, 30))
self.assertEqual(expected_steps, monitor.steps_begun)
@@ -150,7 +150,7 @@ class MonitorsTest(test.TestCase):
def test_every_2(self):
monitor = _MyEveryN(every_n_steps=2, first_n_steps=-1)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
expected_steps = list(range(2, 29, 2)) + [29]
self.assertEqual(expected_steps, monitor.steps_begun)
@@ -159,7 +159,7 @@ class MonitorsTest(test.TestCase):
def test_every_8(self):
monitor = _MyEveryN(every_n_steps=8, first_n_steps=2)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10)
expected_steps = [0, 1, 2, 10, 18, 26, 29]
self.assertEqual(expected_steps, monitor.steps_begun)
@@ -168,7 +168,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_no_max_steps(self):
monitor = _MyEveryN(every_n_steps=8, first_n_steps=2)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(
monitor, num_epochs=3, num_steps_per_epoch=10, pass_max_steps=False)
begin_end_steps = [0, 1, 2, 10, 18, 26]
@@ -179,7 +179,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_recovered_after_step_begin(self):
monitor = _MyEveryN(every_n_steps=8)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
for step in [8, 16]:
monitor.step_begin(step)
monitor.step_begin(step)
@@ -192,7 +192,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_recovered_after_step_end(self):
monitor = _MyEveryN(every_n_steps=8)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
for step in [8, 16]:
monitor.step_begin(step)
monitor.step_end(step, output=None)
@@ -207,7 +207,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_call_post_step_at_the_end(self):
monitor = _MyEveryN(every_n_steps=8)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin()
for step in [8, 16]:
monitor.step_begin(step)
@@ -224,7 +224,7 @@ class MonitorsTest(test.TestCase):
def test_every_8_call_post_step_should_not_be_called_twice(self):
monitor = _MyEveryN(every_n_steps=8)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin()
for step in [8, 16]:
monitor.step_begin(step)
@@ -240,13 +240,13 @@ class MonitorsTest(test.TestCase):
self.assertEqual([8, 16], monitor.post_steps)
def test_print(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
t = constant_op.constant(42.0, name='foo')
self._run_monitor(learn.monitors.PrintTensor(tensor_names=[t.name]))
self.assertRegexpMatches(str(self.logged_message), t.name)
def test_logging_trainable(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
var = variables.Variable(constant_op.constant(42.0), name='foo')
var.initializer.run()
cof = constant_op.constant(1.0)
@@ -258,7 +258,7 @@ class MonitorsTest(test.TestCase):
self.assertRegexpMatches(str(self.logged_message), var.name)
def test_summary_saver(self):
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
log_dir = 'log/dir'
summary_writer = testing.FakeSummaryWriter(log_dir, g)
var = variables.Variable(0.0)
@@ -312,7 +312,7 @@ class MonitorsTest(test.TestCase):
monitor = learn.monitors.ValidationMonitor(
x=constant_op.constant(2.0), every_n_steps=0)
self._assert_validation_monitor(monitor)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with self.assertRaisesRegexp(ValueError, 'set_estimator'):
self._run_monitor(monitor)
@@ -330,7 +330,7 @@ class MonitorsTest(test.TestCase):
x=constant_op.constant(2.0), every_n_steps=0)
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor)
self._assert_validation_monitor(monitor)
mock_latest_checkpoint.assert_called_with(model_dir)
@@ -351,7 +351,7 @@ class MonitorsTest(test.TestCase):
x=constant_op.constant(2.0), every_n_steps=0)
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
self._run_monitor(monitor)
self._assert_validation_monitor(monitor)
@@ -370,7 +370,7 @@ class MonitorsTest(test.TestCase):
x=constant_op.constant(2.0), every_n_steps=0, early_stopping_rounds=1)
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
with self.assertRaisesRegexp(ValueError, 'missing from outputs'):
self._run_monitor(monitor, num_epochs=1, num_steps_per_epoch=1)
@@ -392,7 +392,7 @@ class MonitorsTest(test.TestCase):
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin(max_steps=100)
monitor.epoch_begin(epoch=0)
self.assertEqual(0, estimator.evaluate.call_count)
@@ -477,7 +477,7 @@ class MonitorsTest(test.TestCase):
every_n_steps=0, early_stopping_rounds=2)
self._assert_validation_monitor(monitor)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin(max_steps=100)
monitor.epoch_begin(epoch=0)
self.assertEqual(0, estimator.evaluate.call_count)
@@ -509,7 +509,7 @@ class MonitorsTest(test.TestCase):
metrics=constant_op.constant(2.0),
every_n_steps=0, early_stopping_rounds=2)
monitor.set_estimator(estimator)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
monitor.begin(max_steps=100)
monitor.epoch_begin(epoch=0)
@@ -525,7 +525,7 @@ class MonitorsTest(test.TestCase):
def test_graph_dump(self):
monitor0 = learn.monitors.GraphDump()
monitor1 = learn.monitors.GraphDump()
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
const_var = variables.Variable(42.0, name='my_const')
counter_var = variables.Variable(0.0, name='my_counter')
assign_add = state_ops.assign_add(counter_var, 1.0, name='my_assign_add')
@@ -568,7 +568,7 @@ class MonitorsTest(test.TestCase):
def test_capture_variable(self):
monitor = learn.monitors.CaptureVariable(
var_name='my_assign_add:0', every_n=8, first_n=2)
- with ops.Graph().as_default() as g, self.test_session(g):
+ with ops.Graph().as_default() as g, self.session(g):
var = variables.Variable(0.0, name='my_var')
var.initializer.run()
state_ops.assign_add(var, 1.0, name='my_assign_add')