diff options
author | 2016-06-27 19:17:25 -0800 | |
---|---|---|
committer | 2016-06-27 20:32:26 -0700 | |
commit | 29f7de839dea116674bd91d155fec8fdab2c6c52 (patch) | |
tree | edaec162e68d694114cedbecc054e44f7a2c786e /tensorflow/contrib | |
parent | fdb5b1c18a128a877b82b5a290d73279c8f7bea7 (diff) |
Adding utilities to export Estimators and ExportMonitor for continues export.
Clean up session_bundle/exporter to be importable.
Change: 126029457
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- | tensorflow/contrib/learn/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/monitors.py | 33 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/utils/__init__.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/utils/export.py | 134 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/utils/export_test.py | 50 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/exporter.py | 32 |
7 files changed, 253 insertions, 15 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 6515a5aa1e..31b0a5ee29 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -18,6 +18,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/learn/python/learn/datasets", + "//tensorflow/contrib/session_bundle:exporter", "//tensorflow/python:framework", ], ) @@ -575,6 +576,19 @@ py_test( ) py_test( + name = "export_test", + size = "small", + srcs = ["python/learn/utils/export_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":learn", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + ], +) + +py_test( name = "stability_test", size = "small", srcs = ["python/learn/tests/stability_test.py"], diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index fe1f7110df..f1a02bfddd 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -22,6 +22,7 @@ from __future__ import print_function import numpy as np import six +from tensorflow.contrib.learn.python.learn.utils import export from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver @@ -520,3 +521,35 @@ class GraphDump(BaseMonitor): else: matched.append(key) return matched, non_matched + + +class ExportMonitor(EveryN): + """Monitor that exports Estimator every N steps.""" + + def __init__(self, every_n_steps, export_dir, exports_to_keep=5): + """Initializes ExportMonitor. + + Args: + every_n_steps: Run monitor every N steps. + export_dir: str, fodler to export. + exports_to_keep: int, number of exports to keep. + """ + super(ExportMonitor, self).__init__(every_n_steps=every_n_steps) + self.export_dir = export_dir + self.exports_to_keep = exports_to_keep + + def every_n_step_end(self, step, outputs): + super(ExportMonitor, self).every_n_step_end(step, outputs) + try: + export.export_estimator(self._estimator, self.export_dir, + exports_to_keep=self.exports_to_keep) + except RuntimeError: + # Currently we are not syncronized with saving checkpoints, which leads to + # runtime errors when we are calling export on the same global step. + logging.info("Skipping exporting for the same step. " + "Consider exporting less frequently.") + + def end(self): + super(ExportMonitor, self).end() + export.export_estimator(self._estimator, self.export_dir, + exports_to_keep=self.exports_to_keep) diff --git a/tensorflow/contrib/learn/python/learn/utils/__init__.py b/tensorflow/contrib/learn/python/learn/utils/__init__.py index d8567eebbc..149a4b9772 100644 --- a/tensorflow/contrib/learn/python/learn/utils/__init__.py +++ b/tensorflow/contrib/learn/python/learn/utils/__init__.py @@ -20,3 +20,4 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.learn.python.learn.utils import checkpoints +from tensorflow.contrib.learn.python.learn.utils.export import export_estimator diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py new file mode 100644 index 0000000000..27e3c79146 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -0,0 +1,134 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Export utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.contrib.session_bundle import exporter +from tensorflow.contrib.session_bundle import gc +from tensorflow.python.client import session as tf_session +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.training import saver as tf_saver + + +def _get_first_op_from_collection(collection_name): + """Get first element from the collection.""" + elements = ops.get_collection(collection_name) + if elements is not None: + if elements: + return elements[0] + return None + + +def _get_saver(): + """Lazy init and return saver.""" + saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS) + if saver is not None: + if saver: + saver = saver[0] + else: + saver = None + if saver is None and variables.all_variables(): + saver = tf_saver.Saver() + ops.add_to_collection(ops.GraphKeys.SAVERS, saver) + return saver + + +def _export_graph(graph, saver, checkpoint_path, export_dir, + default_graph_signature, named_graph_signatures, + exports_to_keep): + """Exports graph via session_bundle, by creating a Session.""" + with graph.as_default(): + with tf_session.Session('') as session: + session.run(variables.initialize_local_variables()) + saver.restore(session, checkpoint_path) + export = exporter.Exporter(saver) + export.init(session.graph.as_graph_def(), + default_graph_signature=default_graph_signature, + named_graph_signatures=named_graph_signatures) + export.export(export_dir, contrib_variables.get_global_step(), session, + exports_to_keep=exports_to_keep) + + +def _generic_signature_fn(examples, unused_features, predictions): + """Creates generic signature from given examples and predictions. + + This is neeed for backward compatibility with default behaviour of + export_estimator. + + Args: + examples: `Tensor`. + unused_features: `dict` of `Tensor`s. + predictions: `dict` of `Tensor`s. + + Returns: + Tuple of default signature and named signature. + """ + tensors = {'inputs': examples} + if not isinstance(predictions, dict): + predictions = {'outputs': predictions} + tensors.update(predictions) + default_signature = exporter.generic_signature(tensors) + return default_signature, {} + + +# pylint: disable=protected-access +def _default_input_fn(estimator, examples): + """Creates default input parsing using Estimator's feature signatures.""" + return estimator._get_feature_ops_from_example(examples) + + +def export_estimator(estimator, export_dir, input_fn=_default_input_fn, + signature_fn=_generic_signature_fn, default_batch_size=1, + exports_to_keep=None): + """Exports inference graph into given dir. + + Args: + estimator: Estimator to export + export_dir: A string containing a directory to write the exported graph + and checkpoints. + input_fn: Function that given `Tensor` of `Example` strings, parses it into + features that are then passed to the model. + signature_fn: Function that given `Tensor` of `Example` strings, + `dict` of `Tensor`s for features and `dict` of `Tensor`s for predictions + and returns default and named exporting signautres. + default_batch_size: Default batch size of the `Example` placeholder. + exports_to_keep: Number of exports to keep. + """ + checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir) + with ops.Graph().as_default() as g: + contrib_variables.create_global_step(g) + examples = array_ops.placeholder(dtype=dtypes.string, + shape=[default_batch_size], + name='input_example_tensor') + features = input_fn(estimator, examples) + predictions = estimator._get_predict_ops(features) + default_signature, named_graph_signatures = signature_fn( + examples, features, predictions) + if exports_to_keep is not None: + exports_to_keep = gc.largest_export_versions(exports_to_keep) + _export_graph(g, _get_saver(), checkpoint_path, export_dir, + default_graph_signature=default_signature, + named_graph_signatures=named_graph_signatures, + exports_to_keep=exports_to_keep) +# pylint: enable=protected-access + diff --git a/tensorflow/contrib/learn/python/learn/utils/export_test.py b/tensorflow/contrib/learn/python/learn/utils/export_test.py new file mode 100644 index 0000000000..3db8cbc1a7 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/utils/export_test.py @@ -0,0 +1,50 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for export tools.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import tempfile + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib import learn + + +class ExportTest(tf.test.TestCase): + + def testExportMonitor(self): + random.seed(42) + x = np.random.rand(1000) + y = 2 * x + 3 + regressor = learn.LinearRegressor() + export_dir = tempfile.mkdtemp() + 'export/' + export_monitor = learn.monitors.ExportMonitor(every_n_steps=1, + export_dir=export_dir, + exports_to_keep=1) + regressor.fit(x, y, steps=10, + monitors=[export_monitor]) + self.assertTrue(tf.gfile.Exists(export_dir)) + self.assertFalse(tf.gfile.Exists(export_dir + '00000000/export')) + self.assertTrue(tf.gfile.Exists(export_dir + '00000010/export')) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD index 18c4225824..b6f7e33aa1 100644 --- a/tensorflow/contrib/session_bundle/BUILD +++ b/tensorflow/contrib/session_bundle/BUILD @@ -32,7 +32,7 @@ py_library( deps = [ ":gc", ":manifest_proto_py", - "//tensorflow:tensorflow_py", + "//tensorflow/python:framework", ], ) @@ -57,7 +57,7 @@ py_library( srcs = ["gc.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow:tensorflow_py", + "//tensorflow/python:framework", ], ) diff --git a/tensorflow/contrib/session_bundle/exporter.py b/tensorflow/contrib/session_bundle/exporter.py index fc7458db62..ac3bd8890b 100644 --- a/tensorflow/contrib/session_bundle/exporter.py +++ b/tensorflow/contrib/session_bundle/exporter.py @@ -25,13 +25,15 @@ import os import re import six -import tensorflow as tf from google.protobuf.any_pb2 import Any from tensorflow.contrib.session_bundle import gc from tensorflow.contrib.session_bundle import manifest_pb2 +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.platform import gfile +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training_util from tensorflow.python.util import compat @@ -62,20 +64,20 @@ def gfile_copy_callback(files_to_copy, export_dir_path): basename in the export directory. export_dir_path: Directory to copy the files to. """ - tf.logging.info("Write assest into: %s using gfile_copy.", export_dir_path) + logging.info("Write assest into: %s using gfile_copy.", export_dir_path) gfile.MakeDirs(export_dir_path) for source_filepath, basename in files_to_copy.items(): new_path = os.path.join( compat.as_bytes(export_dir_path), compat.as_bytes(basename)) - tf.logging.info("Copying asset %s to path %s.", source_filepath, new_path) + logging.info("Copying asset %s to path %s.", source_filepath, new_path) if gfile.Exists(new_path): # Guard against being restarted while copying assets, and the file # existing and being in an unknown state. # TODO(b/28676216): Do some file checks before deleting. - tf.logging.info("Removing file %s.", new_path) + logging.info("Removing file %s.", new_path) gfile.Remove(new_path) - tf.gfile.Copy(source_filepath, new_path) + gfile.Copy(source_filepath, new_path) def regression_signature(input_tensor, output_tensor): @@ -188,22 +190,22 @@ class Exporter(object): self._has_init = True if graph_def or clear_devices: - copy = tf.GraphDef() + copy = graph_pb2.GraphDef() if graph_def: copy.CopyFrom(graph_def) else: - copy.CopyFrom(tf.get_default_graph().as_graph_def()) + copy.CopyFrom(ops.get_default_graph().as_graph_def()) if clear_devices: for node in copy.node: node.device = "" graph_any_buf = Any() graph_any_buf.Pack(copy) - tf.add_to_collection(GRAPH_KEY, graph_any_buf) + ops.add_to_collection(GRAPH_KEY, graph_any_buf) if init_op: if not isinstance(init_op, ops.Operation): raise TypeError("init_op needs to be an Operation: %s" % init_op) - tf.add_to_collection(INIT_OP_KEY, init_op) + ops.add_to_collection(INIT_OP_KEY, init_op) signatures_proto = manifest_pb2.Signatures() if default_graph_signature: @@ -212,7 +214,7 @@ class Exporter(object): signatures_proto.named_signatures[signature_name].CopyFrom(signature) signatures_any_buf = Any() signatures_any_buf.Pack(signatures_proto) - tf.add_to_collection(SIGNATURES_KEY, signatures_any_buf) + ops.add_to_collection(SIGNATURES_KEY, signatures_any_buf) for filename, tensor in assets: asset = manifest_pb2.AssetFile() @@ -220,7 +222,7 @@ class Exporter(object): asset.tensor_binding.tensor_name = tensor.name asset_any_buf = Any() asset_any_buf.Pack(asset) - tf.add_to_collection(ASSETS_KEY, asset_any_buf) + ops.add_to_collection(ASSETS_KEY, asset_any_buf) self._assets_callback = assets_callback @@ -250,6 +252,10 @@ class Exporter(object): if not self._has_init: raise RuntimeError("init must be called first") + # Export dir must not end with / or it will break exports to keep. Strip /. + if export_dir_base.endswith("/"): + export_dir_base = export_dir_base[:-1] + global_step = training_util.global_step(sess, global_step_tensor) export_dir = os.path.join( compat.as_bytes(export_dir_base), @@ -299,11 +305,11 @@ class Exporter(object): def _file_path_value(self, path_tensor): """Returns the filepath value stored in constant `path_tensor`.""" - if not isinstance(path_tensor, tf.Tensor): + if not isinstance(path_tensor, ops.Tensor): raise TypeError("tensor is not a Tensor") if path_tensor.op.type != "Const": raise TypeError("Only constants tensor are supported") - if path_tensor.dtype != tf.string: + if path_tensor.dtype != dtypes.string: raise TypeError("File paths should be string") str_value = path_tensor.op.get_attr("value").string_val if len(str_value) != 1: |