aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar Illia Polosukhin <ipolosukhin@google.com>2016-06-27 19:17:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-27 20:32:26 -0700
commit29f7de839dea116674bd91d155fec8fdab2c6c52 (patch)
treeedaec162e68d694114cedbecc054e44f7a2c786e /tensorflow/contrib
parentfdb5b1c18a128a877b82b5a290d73279c8f7bea7 (diff)
Adding utilities to export Estimators and ExportMonitor for continues export.
Clean up session_bundle/exporter to be importable. Change: 126029457
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/learn/BUILD14
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors.py33
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/__init__.py1
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/export.py134
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/export_test.py50
-rw-r--r--tensorflow/contrib/session_bundle/BUILD4
-rw-r--r--tensorflow/contrib/session_bundle/exporter.py32
7 files changed, 253 insertions, 15 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 6515a5aa1e..31b0a5ee29 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -18,6 +18,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/learn/python/learn/datasets",
+ "//tensorflow/contrib/session_bundle:exporter",
"//tensorflow/python:framework",
],
)
@@ -575,6 +576,19 @@ py_test(
)
py_test(
+ name = "export_test",
+ size = "small",
+ srcs = ["python/learn/utils/export_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":learn",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+py_test(
name = "stability_test",
size = "small",
srcs = ["python/learn/tests/stability_test.py"],
diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py
index fe1f7110df..f1a02bfddd 100644
--- a/tensorflow/contrib/learn/python/learn/monitors.py
+++ b/tensorflow/contrib/learn/python/learn/monitors.py
@@ -22,6 +22,7 @@ from __future__ import print_function
import numpy as np
import six
+from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver
@@ -520,3 +521,35 @@ class GraphDump(BaseMonitor):
else:
matched.append(key)
return matched, non_matched
+
+
+class ExportMonitor(EveryN):
+ """Monitor that exports Estimator every N steps."""
+
+ def __init__(self, every_n_steps, export_dir, exports_to_keep=5):
+ """Initializes ExportMonitor.
+
+ Args:
+ every_n_steps: Run monitor every N steps.
+ export_dir: str, fodler to export.
+ exports_to_keep: int, number of exports to keep.
+ """
+ super(ExportMonitor, self).__init__(every_n_steps=every_n_steps)
+ self.export_dir = export_dir
+ self.exports_to_keep = exports_to_keep
+
+ def every_n_step_end(self, step, outputs):
+ super(ExportMonitor, self).every_n_step_end(step, outputs)
+ try:
+ export.export_estimator(self._estimator, self.export_dir,
+ exports_to_keep=self.exports_to_keep)
+ except RuntimeError:
+ # Currently we are not syncronized with saving checkpoints, which leads to
+ # runtime errors when we are calling export on the same global step.
+ logging.info("Skipping exporting for the same step. "
+ "Consider exporting less frequently.")
+
+ def end(self):
+ super(ExportMonitor, self).end()
+ export.export_estimator(self._estimator, self.export_dir,
+ exports_to_keep=self.exports_to_keep)
diff --git a/tensorflow/contrib/learn/python/learn/utils/__init__.py b/tensorflow/contrib/learn/python/learn/utils/__init__.py
index d8567eebbc..149a4b9772 100644
--- a/tensorflow/contrib/learn/python/learn/utils/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/utils/__init__.py
@@ -20,3 +20,4 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.learn.python.learn.utils import checkpoints
+from tensorflow.contrib.learn.python.learn.utils.export import export_estimator
diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py
new file mode 100644
index 0000000000..27e3c79146
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/utils/export.py
@@ -0,0 +1,134 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Export utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.framework.python.ops import variables as contrib_variables
+from tensorflow.contrib.session_bundle import exporter
+from tensorflow.contrib.session_bundle import gc
+from tensorflow.python.client import session as tf_session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import saver as tf_saver
+
+
+def _get_first_op_from_collection(collection_name):
+ """Get first element from the collection."""
+ elements = ops.get_collection(collection_name)
+ if elements is not None:
+ if elements:
+ return elements[0]
+ return None
+
+
+def _get_saver():
+ """Lazy init and return saver."""
+ saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
+ if saver is not None:
+ if saver:
+ saver = saver[0]
+ else:
+ saver = None
+ if saver is None and variables.all_variables():
+ saver = tf_saver.Saver()
+ ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
+ return saver
+
+
+def _export_graph(graph, saver, checkpoint_path, export_dir,
+ default_graph_signature, named_graph_signatures,
+ exports_to_keep):
+ """Exports graph via session_bundle, by creating a Session."""
+ with graph.as_default():
+ with tf_session.Session('') as session:
+ session.run(variables.initialize_local_variables())
+ saver.restore(session, checkpoint_path)
+ export = exporter.Exporter(saver)
+ export.init(session.graph.as_graph_def(),
+ default_graph_signature=default_graph_signature,
+ named_graph_signatures=named_graph_signatures)
+ export.export(export_dir, contrib_variables.get_global_step(), session,
+ exports_to_keep=exports_to_keep)
+
+
+def _generic_signature_fn(examples, unused_features, predictions):
+ """Creates generic signature from given examples and predictions.
+
+ This is neeed for backward compatibility with default behaviour of
+ export_estimator.
+
+ Args:
+ examples: `Tensor`.
+ unused_features: `dict` of `Tensor`s.
+ predictions: `dict` of `Tensor`s.
+
+ Returns:
+ Tuple of default signature and named signature.
+ """
+ tensors = {'inputs': examples}
+ if not isinstance(predictions, dict):
+ predictions = {'outputs': predictions}
+ tensors.update(predictions)
+ default_signature = exporter.generic_signature(tensors)
+ return default_signature, {}
+
+
+# pylint: disable=protected-access
+def _default_input_fn(estimator, examples):
+ """Creates default input parsing using Estimator's feature signatures."""
+ return estimator._get_feature_ops_from_example(examples)
+
+
+def export_estimator(estimator, export_dir, input_fn=_default_input_fn,
+ signature_fn=_generic_signature_fn, default_batch_size=1,
+ exports_to_keep=None):
+ """Exports inference graph into given dir.
+
+ Args:
+ estimator: Estimator to export
+ export_dir: A string containing a directory to write the exported graph
+ and checkpoints.
+ input_fn: Function that given `Tensor` of `Example` strings, parses it into
+ features that are then passed to the model.
+ signature_fn: Function that given `Tensor` of `Example` strings,
+ `dict` of `Tensor`s for features and `dict` of `Tensor`s for predictions
+ and returns default and named exporting signautres.
+ default_batch_size: Default batch size of the `Example` placeholder.
+ exports_to_keep: Number of exports to keep.
+ """
+ checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
+ with ops.Graph().as_default() as g:
+ contrib_variables.create_global_step(g)
+ examples = array_ops.placeholder(dtype=dtypes.string,
+ shape=[default_batch_size],
+ name='input_example_tensor')
+ features = input_fn(estimator, examples)
+ predictions = estimator._get_predict_ops(features)
+ default_signature, named_graph_signatures = signature_fn(
+ examples, features, predictions)
+ if exports_to_keep is not None:
+ exports_to_keep = gc.largest_export_versions(exports_to_keep)
+ _export_graph(g, _get_saver(), checkpoint_path, export_dir,
+ default_graph_signature=default_signature,
+ named_graph_signatures=named_graph_signatures,
+ exports_to_keep=exports_to_keep)
+# pylint: enable=protected-access
+
diff --git a/tensorflow/contrib/learn/python/learn/utils/export_test.py b/tensorflow/contrib/learn/python/learn/utils/export_test.py
new file mode 100644
index 0000000000..3db8cbc1a7
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/utils/export_test.py
@@ -0,0 +1,50 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for export tools."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import tempfile
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib import learn
+
+
+class ExportTest(tf.test.TestCase):
+
+ def testExportMonitor(self):
+ random.seed(42)
+ x = np.random.rand(1000)
+ y = 2 * x + 3
+ regressor = learn.LinearRegressor()
+ export_dir = tempfile.mkdtemp() + 'export/'
+ export_monitor = learn.monitors.ExportMonitor(every_n_steps=1,
+ export_dir=export_dir,
+ exports_to_keep=1)
+ regressor.fit(x, y, steps=10,
+ monitors=[export_monitor])
+ self.assertTrue(tf.gfile.Exists(export_dir))
+ self.assertFalse(tf.gfile.Exists(export_dir + '00000000/export'))
+ self.assertTrue(tf.gfile.Exists(export_dir + '00000010/export'))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD
index 18c4225824..b6f7e33aa1 100644
--- a/tensorflow/contrib/session_bundle/BUILD
+++ b/tensorflow/contrib/session_bundle/BUILD
@@ -32,7 +32,7 @@ py_library(
deps = [
":gc",
":manifest_proto_py",
- "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework",
],
)
@@ -57,7 +57,7 @@ py_library(
srcs = ["gc.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework",
],
)
diff --git a/tensorflow/contrib/session_bundle/exporter.py b/tensorflow/contrib/session_bundle/exporter.py
index fc7458db62..ac3bd8890b 100644
--- a/tensorflow/contrib/session_bundle/exporter.py
+++ b/tensorflow/contrib/session_bundle/exporter.py
@@ -25,13 +25,15 @@ import os
import re
import six
-import tensorflow as tf
from google.protobuf.any_pb2 import Any
from tensorflow.contrib.session_bundle import gc
from tensorflow.contrib.session_bundle import manifest_pb2
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.util import compat
@@ -62,20 +64,20 @@ def gfile_copy_callback(files_to_copy, export_dir_path):
basename in the export directory.
export_dir_path: Directory to copy the files to.
"""
- tf.logging.info("Write assest into: %s using gfile_copy.", export_dir_path)
+ logging.info("Write assest into: %s using gfile_copy.", export_dir_path)
gfile.MakeDirs(export_dir_path)
for source_filepath, basename in files_to_copy.items():
new_path = os.path.join(
compat.as_bytes(export_dir_path), compat.as_bytes(basename))
- tf.logging.info("Copying asset %s to path %s.", source_filepath, new_path)
+ logging.info("Copying asset %s to path %s.", source_filepath, new_path)
if gfile.Exists(new_path):
# Guard against being restarted while copying assets, and the file
# existing and being in an unknown state.
# TODO(b/28676216): Do some file checks before deleting.
- tf.logging.info("Removing file %s.", new_path)
+ logging.info("Removing file %s.", new_path)
gfile.Remove(new_path)
- tf.gfile.Copy(source_filepath, new_path)
+ gfile.Copy(source_filepath, new_path)
def regression_signature(input_tensor, output_tensor):
@@ -188,22 +190,22 @@ class Exporter(object):
self._has_init = True
if graph_def or clear_devices:
- copy = tf.GraphDef()
+ copy = graph_pb2.GraphDef()
if graph_def:
copy.CopyFrom(graph_def)
else:
- copy.CopyFrom(tf.get_default_graph().as_graph_def())
+ copy.CopyFrom(ops.get_default_graph().as_graph_def())
if clear_devices:
for node in copy.node:
node.device = ""
graph_any_buf = Any()
graph_any_buf.Pack(copy)
- tf.add_to_collection(GRAPH_KEY, graph_any_buf)
+ ops.add_to_collection(GRAPH_KEY, graph_any_buf)
if init_op:
if not isinstance(init_op, ops.Operation):
raise TypeError("init_op needs to be an Operation: %s" % init_op)
- tf.add_to_collection(INIT_OP_KEY, init_op)
+ ops.add_to_collection(INIT_OP_KEY, init_op)
signatures_proto = manifest_pb2.Signatures()
if default_graph_signature:
@@ -212,7 +214,7 @@ class Exporter(object):
signatures_proto.named_signatures[signature_name].CopyFrom(signature)
signatures_any_buf = Any()
signatures_any_buf.Pack(signatures_proto)
- tf.add_to_collection(SIGNATURES_KEY, signatures_any_buf)
+ ops.add_to_collection(SIGNATURES_KEY, signatures_any_buf)
for filename, tensor in assets:
asset = manifest_pb2.AssetFile()
@@ -220,7 +222,7 @@ class Exporter(object):
asset.tensor_binding.tensor_name = tensor.name
asset_any_buf = Any()
asset_any_buf.Pack(asset)
- tf.add_to_collection(ASSETS_KEY, asset_any_buf)
+ ops.add_to_collection(ASSETS_KEY, asset_any_buf)
self._assets_callback = assets_callback
@@ -250,6 +252,10 @@ class Exporter(object):
if not self._has_init:
raise RuntimeError("init must be called first")
+ # Export dir must not end with / or it will break exports to keep. Strip /.
+ if export_dir_base.endswith("/"):
+ export_dir_base = export_dir_base[:-1]
+
global_step = training_util.global_step(sess, global_step_tensor)
export_dir = os.path.join(
compat.as_bytes(export_dir_base),
@@ -299,11 +305,11 @@ class Exporter(object):
def _file_path_value(self, path_tensor):
"""Returns the filepath value stored in constant `path_tensor`."""
- if not isinstance(path_tensor, tf.Tensor):
+ if not isinstance(path_tensor, ops.Tensor):
raise TypeError("tensor is not a Tensor")
if path_tensor.op.type != "Const":
raise TypeError("Only constants tensor are supported")
- if path_tensor.dtype != tf.string:
+ if path_tensor.dtype != dtypes.string:
raise TypeError("File paths should be string")
str_value = path_tensor.op.get_attr("value").string_val
if len(str_value) != 1: