diff options
author | Vijay Vasudevan <vrv@google.com> | 2016-07-15 21:34:58 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-07-15 22:47:49 -0700 |
commit | 2c4cba87fa0c8b3f003fb84544d2c68140583a0e (patch) | |
tree | c830afda8d1061c456e3b92207dca0a3b7aa4f40 | |
parent | 88659f2ce750c913f8911de1474a19e8fa9d9d82 (diff) |
Automated rollback of change 127590512
Change: 127608996
-rw-r--r-- | tensorflow/contrib/session_bundle/BUILD | 43 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/constants.py | 35 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/exporter.py | 31 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/exporter_test.py | 31 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/session_bundle.py | 118 | ||||
-rw-r--r-- | tensorflow/contrib/session_bundle/session_bundle_test.py | 75 | ||||
-rw-r--r-- | tensorflow/python/platform/googletest.py | 14 | ||||
-rw-r--r-- | tensorflow/python/platform/test.py | 13 |
8 files changed, 37 insertions, 323 deletions
diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD index 47e4314c4f..3862963f23 100644 --- a/tensorflow/contrib/session_bundle/BUILD +++ b/tensorflow/contrib/session_bundle/BUILD @@ -23,18 +23,13 @@ filegroup( ), ) -py_library( - name = "constants", - srcs = ["constants.py"], - srcs_version = "PY2AND3", -) +load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") py_library( name = "exporter", srcs = ["exporter.py"], srcs_version = "PY2AND3", deps = [ - ":constants", ":gc", ":manifest_proto_py", "//tensorflow/python:framework", @@ -50,7 +45,6 @@ py_test( srcs_version = "PY2AND3", visibility = ["//visibility:private"], deps = [ - ":constants", ":exporter", ":gc", ":manifest_proto_py", @@ -151,39 +145,6 @@ cc_test( ], ) -py_library( - name = "session_bundle_py", - srcs = ["session_bundle.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - ":constants", - ":exporter", - ":manifest_proto_py", - "//tensorflow:tensorflow_py", - "//tensorflow/core:protos_all_py", - ], -) - -py_test( - name = "session_bundle_py_test", - size = "small", - srcs = [ - "session_bundle_test.py", - ], - data = [ - "//tensorflow/contrib/session_bundle/example:half_plus_two", - ], - main = "session_bundle_test.py", - srcs_version = "PY2AND3", - deps = [ - ":manifest_proto_py", - ":session_bundle_py", - "//tensorflow:tensorflow_py", - "//tensorflow/core:protos_all_py", - ], -) - cc_library( name = "signature", srcs = ["signature.cc"], @@ -246,8 +207,6 @@ cc_library( ], ) -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") - tf_proto_library( name = "manifest_proto", srcs = ["manifest.proto"], diff --git a/tensorflow/contrib/session_bundle/constants.py b/tensorflow/contrib/session_bundle/constants.py deleted file mode 100644 index d05032d849..0000000000 --- a/tensorflow/contrib/session_bundle/constants.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2016 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Constants for export/import. - -See: go/tf-exporter for these constants and directory structure. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -VERSION_FORMAT_SPECIFIER = "%08d" -ASSETS_DIRECTORY = "assets" -EXPORT_BASE_NAME = "export" -EXPORT_SUFFIX_NAME = "meta" -META_GRAPH_DEF_FILENAME = EXPORT_BASE_NAME + "." + EXPORT_SUFFIX_NAME -VARIABLES_FILENAME = EXPORT_BASE_NAME -VARIABLES_FILENAME_PATTERN = VARIABLES_FILENAME + "-?????-of-?????" -INIT_OP_KEY = "serving_init_op" -SIGNATURES_KEY = "serving_signatures" -ASSETS_KEY = "serving_assets" -GRAPH_KEY = "serving_graph" diff --git a/tensorflow/contrib/session_bundle/exporter.py b/tensorflow/contrib/session_bundle/exporter.py index 4709b73114..ac3bd8890b 100644 --- a/tensorflow/contrib/session_bundle/exporter.py +++ b/tensorflow/contrib/session_bundle/exporter.py @@ -27,7 +27,6 @@ import six from google.protobuf.any_pb2 import Any -from tensorflow.contrib.session_bundle import constants from tensorflow.contrib.session_bundle import gc from tensorflow.contrib.session_bundle import manifest_pb2 from tensorflow.core.framework import graph_pb2 @@ -38,6 +37,19 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training_util from tensorflow.python.util import compat +# See: go/tf-exporter for these constants and directory structure. +VERSION_FORMAT_SPECIFIER = "%08d" +ASSETS_DIRECTORY = "assets" +EXPORT_BASE_NAME = "export" +EXPORT_SUFFIX_NAME = "meta" +META_GRAPH_DEF_FILENAME = EXPORT_BASE_NAME + "." + EXPORT_SUFFIX_NAME +VARIABLES_FILENAME = EXPORT_BASE_NAME +VARIABLES_FILENAME_PATTERN = VARIABLES_FILENAME + "-?????-of-?????" +INIT_OP_KEY = "serving_init_op" +SIGNATURES_KEY = "serving_signatures" +ASSETS_KEY = "serving_assets" +GRAPH_KEY = "serving_graph" + def gfile_copy_callback(files_to_copy, export_dir_path): """Callback to copy files using `gfile.Copy` to an export directory. @@ -188,12 +200,12 @@ class Exporter(object): node.device = "" graph_any_buf = Any() graph_any_buf.Pack(copy) - ops.add_to_collection(constants.GRAPH_KEY, graph_any_buf) + ops.add_to_collection(GRAPH_KEY, graph_any_buf) if init_op: if not isinstance(init_op, ops.Operation): raise TypeError("init_op needs to be an Operation: %s" % init_op) - ops.add_to_collection(constants.INIT_OP_KEY, init_op) + ops.add_to_collection(INIT_OP_KEY, init_op) signatures_proto = manifest_pb2.Signatures() if default_graph_signature: @@ -202,7 +214,7 @@ class Exporter(object): signatures_proto.named_signatures[signature_name].CopyFrom(signature) signatures_any_buf = Any() signatures_any_buf.Pack(signatures_proto) - ops.add_to_collection(constants.SIGNATURES_KEY, signatures_any_buf) + ops.add_to_collection(SIGNATURES_KEY, signatures_any_buf) for filename, tensor in assets: asset = manifest_pb2.AssetFile() @@ -210,7 +222,7 @@ class Exporter(object): asset.tensor_binding.tensor_name = tensor.name asset_any_buf = Any() asset_any_buf.Pack(asset) - ops.add_to_collection(constants.ASSETS_KEY, asset_any_buf) + ops.add_to_collection(ASSETS_KEY, asset_any_buf) self._assets_callback = assets_callback @@ -247,7 +259,7 @@ class Exporter(object): global_step = training_util.global_step(sess, global_step_tensor) export_dir = os.path.join( compat.as_bytes(export_dir_base), - compat.as_bytes(constants.VERSION_FORMAT_SPECIFIER % global_step)) + compat.as_bytes(VERSION_FORMAT_SPECIFIER % global_step)) # Prevent overwriting on existing exports which could lead to bad/corrupt # storage and loading of models. This is an important check that must be @@ -264,14 +276,13 @@ class Exporter(object): self._saver.save(sess, os.path.join( compat.as_text(tmp_export_dir), - compat.as_text(constants.EXPORT_BASE_NAME)), - meta_graph_suffix=constants.EXPORT_SUFFIX_NAME) + compat.as_text(EXPORT_BASE_NAME)), + meta_graph_suffix=EXPORT_SUFFIX_NAME) # Run the asset callback. if self._assets_callback and self._assets_to_copy: assets_dir = os.path.join( - compat.as_bytes(tmp_export_dir), - compat.as_bytes(constants.ASSETS_DIRECTORY)) + compat.as_bytes(tmp_export_dir), compat.as_bytes(ASSETS_DIRECTORY)) gfile.MakeDirs(assets_dir) self._assets_callback(self._assets_to_copy, assets_dir) diff --git a/tensorflow/contrib/session_bundle/exporter_test.py b/tensorflow/contrib/session_bundle/exporter_test.py index bed1c8ff45..8b54933fc9 100644 --- a/tensorflow/contrib/session_bundle/exporter_test.py +++ b/tensorflow/contrib/session_bundle/exporter_test.py @@ -23,7 +23,6 @@ import os.path import tensorflow as tf -from tensorflow.contrib.session_bundle import constants from tensorflow.contrib.session_bundle import exporter from tensorflow.contrib.session_bundle import gc from tensorflow.contrib.session_bundle import manifest_pb2 @@ -114,14 +113,14 @@ class SaveRestoreShardedTest(tf.test.TestCase): target="", config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: save = tf.train.import_meta_graph( - os.path.join(export_path, constants.VERSION_FORMAT_SPECIFIER % - global_step, constants.META_GRAPH_DEF_FILENAME)) + os.path.join(export_path, exporter.VERSION_FORMAT_SPECIFIER % + global_step, exporter.META_GRAPH_DEF_FILENAME)) self.assertIsNotNone(save) meta_graph_def = save.export_meta_graph() collection_def = meta_graph_def.collection_def # Validate custom graph_def. - graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value + graph_def_any = collection_def[exporter.GRAPH_KEY].any_list.value self.assertEquals(len(graph_def_any), 1) graph_def = tf.GraphDef() graph_def_any[0].Unpack(graph_def) @@ -131,12 +130,12 @@ class SaveRestoreShardedTest(tf.test.TestCase): self.assertProtoEquals(compare_def, graph_def) # Validate init_op. - init_ops = collection_def[constants.INIT_OP_KEY].node_list.value + init_ops = collection_def[exporter.INIT_OP_KEY].node_list.value self.assertEquals(len(init_ops), 1) self.assertEquals(init_ops[0], "init_op") # Validate signatures. - signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value + signatures_any = collection_def[exporter.SIGNATURES_KEY].any_list.value self.assertEquals(len(signatures_any), 1) signatures = manifest_pb2.Signatures() signatures_any[0].Unpack(signatures) @@ -152,21 +151,21 @@ class SaveRestoreShardedTest(tf.test.TestCase): self.assertEquals(read_foo_signature.output.tensor_name, "v1:0") # Validate the assets. - assets_any = collection_def[constants.ASSETS_KEY].any_list.value + assets_any = collection_def[exporter.ASSETS_KEY].any_list.value self.assertEquals(len(assets_any), 1) asset = manifest_pb2.AssetFile() assets_any[0].Unpack(asset) assets_path = os.path.join(export_path, - constants.VERSION_FORMAT_SPECIFIER % - global_step, constants.ASSETS_DIRECTORY, + exporter.VERSION_FORMAT_SPECIFIER % + global_step, exporter.ASSETS_DIRECTORY, "hello42.txt") asset_contents = gfile.GFile(assets_path).read() self.assertEqual(asset_contents, "your data here") self.assertEquals("hello42.txt", asset.filename) self.assertEquals("filename42:0", asset.tensor_binding.tensor_name) ignored_asset_path = os.path.join(export_path, - constants.VERSION_FORMAT_SPECIFIER % - global_step, constants.ASSETS_DIRECTORY, + exporter.VERSION_FORMAT_SPECIFIER % + global_step, exporter.ASSETS_DIRECTORY, "ignored.txt") self.assertFalse(gfile.Exists(ignored_asset_path)) @@ -174,16 +173,16 @@ class SaveRestoreShardedTest(tf.test.TestCase): if sharded: save.restore(sess, os.path.join( - export_path, constants.VERSION_FORMAT_SPECIFIER % - global_step, constants.VARIABLES_FILENAME_PATTERN)) + export_path, exporter.VERSION_FORMAT_SPECIFIER % + global_step, exporter.VARIABLES_FILENAME_PATTERN)) else: save.restore(sess, os.path.join( - export_path, constants.VERSION_FORMAT_SPECIFIER % - global_step, constants.VARIABLES_FILENAME)) + export_path, exporter.VERSION_FORMAT_SPECIFIER % + global_step, exporter.VARIABLES_FILENAME)) self.assertEqual(10, tf.get_collection("v")[0].eval()) self.assertEqual(20, tf.get_collection("v")[1].eval()) - tf.get_collection(constants.INIT_OP_KEY)[0].run() + tf.get_collection(exporter.INIT_OP_KEY)[0].run() self.assertEqual(30, tf.get_collection("v")[2].eval()) def testDuplicateExportRaisesError(self): diff --git a/tensorflow/contrib/session_bundle/session_bundle.py b/tensorflow/contrib/session_bundle/session_bundle.py deleted file mode 100644 index 1479db5712..0000000000 --- a/tensorflow/contrib/session_bundle/session_bundle.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2016 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Importer for an exported TensorFlow model. - -This module provides a function to create a SessionBundle containing both the -Session and MetaGraph. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -import tensorflow as tf - -from tensorflow.contrib.session_bundle import constants -from tensorflow.contrib.session_bundle import manifest_pb2 -from tensorflow.core.protobuf import meta_graph_pb2 -from tensorflow.python.lib.io import file_io -from tensorflow.python.platform import gfile - - -def LoadSessionBundleFromPath(export_dir, target="", config=None): - """Load session bundle from the given path. - - The function reads input from the export_dir, constructs the graph data to the - default graph and restores the parameters for the session created. - - Args: - export_dir: the directory that contains files exported by exporter. - target: The execution engine to connect to. See target in tf.Session() - config: A ConfigProto proto with configuration options. See config in - tf.Session() - - Returns: - session: a tensorflow session created from the variable files. - meta_graph: a meta graph proto saved in the exporter directory. - - Raises: - RuntimeError: if the required files are missing or contain unrecognizable - fields, i.e. the exported model is invalid. - """ - meta_graph_filename = os.path.join(export_dir, - constants.META_GRAPH_DEF_FILENAME) - if not gfile.Exists(meta_graph_filename): - raise RuntimeError("Expected meta graph file missing %s" % - meta_graph_filename) - variables_filename = os.path.join(export_dir, - constants.VARIABLES_FILENAME) - if not gfile.Exists(variables_filename): - variables_filename = os.path.join( - export_dir, constants.VARIABLES_FILENAME_PATTERN) - if not gfile.Glob(variables_filename): - raise RuntimeError("Expected variables file missing %s" % - variables_filename) - assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY) - - # Reads meta graph file. - meta_graph_def = meta_graph_pb2.MetaGraphDef() - meta_graph_def.ParseFromString(file_io.read_file_to_string( - meta_graph_filename)) - - collection_def = meta_graph_def.collection_def - graph_def = tf.GraphDef() - if constants.GRAPH_KEY in collection_def: - # Use serving graph_def in MetaGraphDef collection_def if exists - graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value - if len(graph_def_any) != 1: - raise RuntimeError( - "Expected exactly one serving GraphDef in : %s" % meta_graph_def) - else: - graph_def_any[0].Unpack(graph_def) - # Replace the graph def in meta graph proto. - meta_graph_def.graph_def.CopyFrom(graph_def) - - tf.reset_default_graph() - sess = tf.Session(target, graph=None, config=config) - # Import the graph. - saver = tf.train.import_meta_graph(meta_graph_def) - # Restore the session. - saver.restore(sess, variables_filename) - - init_op_tensor = None - if constants.INIT_OP_KEY in collection_def: - init_ops = collection_def[constants.INIT_OP_KEY].node_list.value - if len(init_ops) != 1: - raise RuntimeError( - "Expected exactly one serving init op in : %s" % meta_graph_def) - init_op_tensor = tf.get_collection(constants.INIT_OP_KEY)[0] - - # Create asset input tensor list. - asset_tensor_dict = {} - if constants.ASSETS_KEY in collection_def: - assets_any = collection_def[constants.ASSETS_KEY].any_list.value - for asset in assets_any: - asset_pb = manifest_pb2.AssetFile() - asset.Unpack(asset_pb) - asset_tensor_dict[asset_pb.tensor_binding.tensor_name] = os.path.join( - assets_dir, asset_pb.filename) - - if init_op_tensor: - # Run the init op. - sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict) - - return sess, meta_graph_def diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.py b/tensorflow/contrib/session_bundle/session_bundle_test.py deleted file mode 100644 index a9e157eb19..0000000000 --- a/tensorflow/contrib/session_bundle/session_bundle_test.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2016 Google Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Tests for session_bundle.py.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os.path -import numpy as np -import tensorflow as tf -from tensorflow.contrib.session_bundle import constants -from tensorflow.contrib.session_bundle import manifest_pb2 -from tensorflow.contrib.session_bundle import session_bundle -from tensorflow.python.util import compat - - -class SessionBundleLoadTest(tf.test.TestCase): - - def testBasic(self): - base_path = tf.test.test_src_dir_path( - "contrib/session_bundle/example/half_plus_two/00000123") - tf.reset_default_graph() - sess, meta_graph_def = session_bundle.LoadSessionBundleFromPath( - base_path, target="", config=tf.ConfigProto(device_count={"CPU": 2})) - - self.assertTrue(sess) - asset_path = os.path.join(base_path, constants.ASSETS_DIRECTORY) - with sess.as_default(): - path1, path2 = sess.run(["filename1:0", "filename2:0"]) - self.assertEqual( - compat.as_bytes(os.path.join(asset_path, "hello1.txt")), path1) - self.assertEqual( - compat.as_bytes(os.path.join(asset_path, "hello2.txt")), path2) - - collection_def = meta_graph_def.collection_def - - signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value - self.assertEquals(len(signatures_any), 1) - - signatures = manifest_pb2.Signatures() - signatures_any[0].Unpack(signatures) - default_signature = signatures.default_signature - input_name = default_signature.regression_signature.input.tensor_name - output_name = default_signature.regression_signature.output.tensor_name - y = sess.run([output_name], {input_name: np.array([[0], [1], [2], [3]])}) - # The operation is y = 0.5 * x + 2 - self.assertEqual(y[0][0], 2) - self.assertEqual(y[0][1], 2.5) - self.assertEqual(y[0][2], 3) - self.assertEqual(y[0][3], 3.5) - - def testBadPath(self): - base_path = tf.test.test_src_dir_path("/no/such/a/dir") - tf.reset_default_graph() - with self.assertRaises(RuntimeError) as cm: - _, _ = session_bundle.LoadSessionBundleFromPath( - base_path, target="local", - config=tf.ConfigProto(device_count={"CPU": 2})) - self.assertTrue("Expected meta graph file missing" in str(cm.exception)) - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py index a41ba73a32..04909a40a9 100644 --- a/tensorflow/python/platform/googletest.py +++ b/tensorflow/python/platform/googletest.py @@ -94,20 +94,6 @@ def GetTempDir(): return temp_dir -def test_src_dir_path(relative_path): - """Creates an absolute test srcdir path given a relative path. - - Args: - relative_path: a path relative to tensorflow root. - e.g. "contrib/session_bundle/example". - - Returns: - An absolute path to the linked in runfiles. - """ - return os.path.join(os.environ['TEST_SRCDIR'], - "org_tensorflow/tensorflow", relative_path) - - def StatefulSessionAvailable(): return False diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py index ac700f43c4..25b0fea80f 100644 --- a/tensorflow/python/platform/test.py +++ b/tensorflow/python/platform/test.py @@ -98,19 +98,6 @@ def get_temp_dir(): return googletest.GetTempDir() -def test_src_dir_path(relative_path): - """Creates an absolute test srcdir path given a relative path. - - Args: - relative_path: a path relative to tensorflow root. - e.g. "core/platform". - - Returns: - An absolute path to the linked in runfiles. - """ - return googletest.test_src_dir_path(relative_path) - - def is_built_with_cuda(): """Returns whether TensorFlow was built with CUDA (GPU) support.""" return test_util.IsGoogleCudaEnabled() |