diff options
Diffstat (limited to 'tensorflow/contrib/learn/python/learn/utils/export_test.py')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/utils/export_test.py | 40 |
1 files changed, 22 insertions, 18 deletions
diff --git a/tensorflow/contrib/learn/python/learn/utils/export_test.py b/tensorflow/contrib/learn/python/learn/utils/export_test.py index ce1d73256a..95070ada3b 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/export_test.py @@ -31,6 +31,7 @@ from tensorflow.contrib.session_bundle import exporter from tensorflow.contrib.session_bundle import manifest_pb2 from tensorflow.python.client import session from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import gfile @@ -49,9 +50,8 @@ def _training_input_fn(): class ExportTest(test.TestCase): - def _get_default_signature(self, export_meta_filename): - """Gets the default signature from the export.meta file.""" + """ Gets the default signature from the export.meta file. """ with session.Session(): save = saver.import_meta_graph(export_meta_filename) meta_graph_def = save.export_meta_graph() @@ -68,18 +68,19 @@ class ExportTest(test.TestCase): self.assertTrue(gfile.Exists(export_dir)) # Only the written checkpoints are exported. self.assertTrue( - saver.checkpoint_exists(export_dir + '00000001/export'), + saver.checkpoint_exists(os.path.join(export_dir, '00000001', 'export')), 'Exported checkpoint expected but not found: %s' % - (export_dir + '00000001/export')) + os.path.join(export_dir, '00000001', 'export')) self.assertTrue( - saver.checkpoint_exists(export_dir + '00000010/export'), + saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')), 'Exported checkpoint expected but not found: %s' % - (export_dir + '00000010/export')) + os.path.join(export_dir, '00000010', 'export')) self.assertEquals( six.b(os.path.join(export_dir, '00000010')), export_monitor.last_export_dir) # Validate the signature - signature = self._get_default_signature(export_dir + '00000010/export.meta') + signature = self._get_default_signature( + os.path.join(export_dir, '00000010', 'export.meta')) self.assertTrue(signature.HasField(expected_signature)) def testExportMonitor_EstimatorProvidesSignature(self): @@ -88,7 +89,7 @@ class ExportTest(test.TestCase): y = 2 * x + 3 cont_features = [feature_column.real_valued_column('', dimension=1)] regressor = learn.LinearRegressor(feature_columns=cont_features) - export_dir = tempfile.mkdtemp() + 'export/' + export_dir = os.path.join(tempfile.mkdtemp(), 'export') export_monitor = learn.monitors.ExportMonitor( every_n_steps=1, export_dir=export_dir, exports_to_keep=2) regressor.fit(x, y, steps=10, monitors=[export_monitor]) @@ -99,7 +100,7 @@ class ExportTest(test.TestCase): x = np.random.rand(1000) y = 2 * x + 3 cont_features = [feature_column.real_valued_column('', dimension=1)] - export_dir = tempfile.mkdtemp() + 'export/' + export_dir = os.path.join(tempfile.mkdtemp(), 'export') export_monitor = learn.monitors.ExportMonitor( every_n_steps=1, export_dir=export_dir, @@ -122,7 +123,7 @@ class ExportTest(test.TestCase): input_feature_key = 'my_example_key' monitor = learn.monitors.ExportMonitor( every_n_steps=1, - export_dir=tempfile.mkdtemp() + 'export/', + export_dir=os.path.join(tempfile.mkdtemp(), 'export'), input_fn=_serving_input_fn, input_feature_key=input_feature_key, exports_to_keep=2, @@ -140,7 +141,7 @@ class ExportTest(test.TestCase): monitor = learn.monitors.ExportMonitor( every_n_steps=1, - export_dir=tempfile.mkdtemp() + 'export/', + export_dir=os.path.join(tempfile.mkdtemp(), 'export'), input_fn=_serving_input_fn, input_feature_key=input_feature_key, exports_to_keep=2, @@ -165,7 +166,7 @@ class ExportTest(test.TestCase): monitor = learn.monitors.ExportMonitor( every_n_steps=1, - export_dir=tempfile.mkdtemp() + 'export/', + export_dir=os.path.join(tempfile.mkdtemp(), 'export'), input_fn=_serving_input_fn, input_feature_key=input_feature_key, exports_to_keep=2, @@ -187,7 +188,7 @@ class ExportTest(test.TestCase): monitor = learn.monitors.ExportMonitor( every_n_steps=1, - export_dir=tempfile.mkdtemp() + 'export/', + export_dir=os.path.join(tempfile.mkdtemp(), 'export'), input_fn=_serving_input_fn, input_feature_key=input_feature_key, exports_to_keep=2, @@ -210,7 +211,7 @@ class ExportTest(test.TestCase): shape=(1,), minval=0.0, maxval=1000.0) }, None - export_dir = tempfile.mkdtemp() + 'export/' + export_dir = os.path.join(tempfile.mkdtemp(), 'export') monitor = learn.monitors.ExportMonitor( every_n_steps=1, export_dir=export_dir, @@ -235,7 +236,7 @@ class ExportTest(test.TestCase): y = 2 * x + 3 cont_features = [feature_column.real_valued_column('', dimension=1)] regressor = learn.LinearRegressor(feature_columns=cont_features) - export_dir = tempfile.mkdtemp() + 'export/' + export_dir = os.path.join(tempfile.mkdtemp(), 'export') export_monitor = learn.monitors.ExportMonitor( every_n_steps=1, export_dir=export_dir, @@ -244,10 +245,13 @@ class ExportTest(test.TestCase): regressor.fit(x, y, steps=10, monitors=[export_monitor]) self.assertTrue(gfile.Exists(export_dir)) - self.assertFalse(saver.checkpoint_exists(export_dir + '00000000/export')) - self.assertTrue(saver.checkpoint_exists(export_dir + '00000010/export')) + with self.assertRaises(errors.NotFoundError): + saver.checkpoint_exists(os.path.join(export_dir, '00000000', 'export')) + self.assertTrue( + saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export'))) # Validate the signature - signature = self._get_default_signature(export_dir + '00000010/export.meta') + signature = self._get_default_signature( + os.path.join(export_dir, '00000010', 'export.meta')) self.assertTrue(signature.HasField('regression_signature')) |