aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn/python/learn/utils/export_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/utils/export_test.py')
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/export_test.py40
1 files changed, 22 insertions, 18 deletions
diff --git a/tensorflow/contrib/learn/python/learn/utils/export_test.py b/tensorflow/contrib/learn/python/learn/utils/export_test.py
index ce1d73256a..95070ada3b 100644
--- a/tensorflow/contrib/learn/python/learn/utils/export_test.py
+++ b/tensorflow/contrib/learn/python/learn/utils/export_test.py
@@ -31,6 +31,7 @@ from tensorflow.contrib.session_bundle import exporter
from tensorflow.contrib.session_bundle import manifest_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import gfile
@@ -49,9 +50,8 @@ def _training_input_fn():
class ExportTest(test.TestCase):
-
def _get_default_signature(self, export_meta_filename):
- """Gets the default signature from the export.meta file."""
+ """ Gets the default signature from the export.meta file. """
with session.Session():
save = saver.import_meta_graph(export_meta_filename)
meta_graph_def = save.export_meta_graph()
@@ -68,18 +68,19 @@ class ExportTest(test.TestCase):
self.assertTrue(gfile.Exists(export_dir))
# Only the written checkpoints are exported.
self.assertTrue(
- saver.checkpoint_exists(export_dir + '00000001/export'),
+ saver.checkpoint_exists(os.path.join(export_dir, '00000001', 'export')),
'Exported checkpoint expected but not found: %s' %
- (export_dir + '00000001/export'))
+ os.path.join(export_dir, '00000001', 'export'))
self.assertTrue(
- saver.checkpoint_exists(export_dir + '00000010/export'),
+ saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')),
'Exported checkpoint expected but not found: %s' %
- (export_dir + '00000010/export'))
+ os.path.join(export_dir, '00000010', 'export'))
self.assertEquals(
six.b(os.path.join(export_dir, '00000010')),
export_monitor.last_export_dir)
# Validate the signature
- signature = self._get_default_signature(export_dir + '00000010/export.meta')
+ signature = self._get_default_signature(
+ os.path.join(export_dir, '00000010', 'export.meta'))
self.assertTrue(signature.HasField(expected_signature))
def testExportMonitor_EstimatorProvidesSignature(self):
@@ -88,7 +89,7 @@ class ExportTest(test.TestCase):
y = 2 * x + 3
cont_features = [feature_column.real_valued_column('', dimension=1)]
regressor = learn.LinearRegressor(feature_columns=cont_features)
- export_dir = tempfile.mkdtemp() + 'export/'
+ export_dir = os.path.join(tempfile.mkdtemp(), 'export')
export_monitor = learn.monitors.ExportMonitor(
every_n_steps=1, export_dir=export_dir, exports_to_keep=2)
regressor.fit(x, y, steps=10, monitors=[export_monitor])
@@ -99,7 +100,7 @@ class ExportTest(test.TestCase):
x = np.random.rand(1000)
y = 2 * x + 3
cont_features = [feature_column.real_valued_column('', dimension=1)]
- export_dir = tempfile.mkdtemp() + 'export/'
+ export_dir = os.path.join(tempfile.mkdtemp(), 'export')
export_monitor = learn.monitors.ExportMonitor(
every_n_steps=1,
export_dir=export_dir,
@@ -122,7 +123,7 @@ class ExportTest(test.TestCase):
input_feature_key = 'my_example_key'
monitor = learn.monitors.ExportMonitor(
every_n_steps=1,
- export_dir=tempfile.mkdtemp() + 'export/',
+ export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
input_fn=_serving_input_fn,
input_feature_key=input_feature_key,
exports_to_keep=2,
@@ -140,7 +141,7 @@ class ExportTest(test.TestCase):
monitor = learn.monitors.ExportMonitor(
every_n_steps=1,
- export_dir=tempfile.mkdtemp() + 'export/',
+ export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
input_fn=_serving_input_fn,
input_feature_key=input_feature_key,
exports_to_keep=2,
@@ -165,7 +166,7 @@ class ExportTest(test.TestCase):
monitor = learn.monitors.ExportMonitor(
every_n_steps=1,
- export_dir=tempfile.mkdtemp() + 'export/',
+ export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
input_fn=_serving_input_fn,
input_feature_key=input_feature_key,
exports_to_keep=2,
@@ -187,7 +188,7 @@ class ExportTest(test.TestCase):
monitor = learn.monitors.ExportMonitor(
every_n_steps=1,
- export_dir=tempfile.mkdtemp() + 'export/',
+ export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
input_fn=_serving_input_fn,
input_feature_key=input_feature_key,
exports_to_keep=2,
@@ -210,7 +211,7 @@ class ExportTest(test.TestCase):
shape=(1,), minval=0.0, maxval=1000.0)
}, None
- export_dir = tempfile.mkdtemp() + 'export/'
+ export_dir = os.path.join(tempfile.mkdtemp(), 'export')
monitor = learn.monitors.ExportMonitor(
every_n_steps=1,
export_dir=export_dir,
@@ -235,7 +236,7 @@ class ExportTest(test.TestCase):
y = 2 * x + 3
cont_features = [feature_column.real_valued_column('', dimension=1)]
regressor = learn.LinearRegressor(feature_columns=cont_features)
- export_dir = tempfile.mkdtemp() + 'export/'
+ export_dir = os.path.join(tempfile.mkdtemp(), 'export')
export_monitor = learn.monitors.ExportMonitor(
every_n_steps=1,
export_dir=export_dir,
@@ -244,10 +245,13 @@ class ExportTest(test.TestCase):
regressor.fit(x, y, steps=10, monitors=[export_monitor])
self.assertTrue(gfile.Exists(export_dir))
- self.assertFalse(saver.checkpoint_exists(export_dir + '00000000/export'))
- self.assertTrue(saver.checkpoint_exists(export_dir + '00000010/export'))
+ with self.assertRaises(errors.NotFoundError):
+ saver.checkpoint_exists(os.path.join(export_dir, '00000000', 'export'))
+ self.assertTrue(
+ saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')))
# Validate the signature
- signature = self._get_default_signature(export_dir + '00000010/export.meta')
+ signature = self._get_default_signature(
+ os.path.join(export_dir, '00000010', 'export.meta'))
self.assertTrue(signature.HasField('regression_signature'))