aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-07-15 21:34:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-15 22:47:49 -0700
commit2c4cba87fa0c8b3f003fb84544d2c68140583a0e (patch)
treec830afda8d1061c456e3b92207dca0a3b7aa4f40
parent88659f2ce750c913f8911de1474a19e8fa9d9d82 (diff)
Automated rollback of change 127590512
Change: 127608996
-rw-r--r--tensorflow/contrib/session_bundle/BUILD43
-rw-r--r--tensorflow/contrib/session_bundle/constants.py35
-rw-r--r--tensorflow/contrib/session_bundle/exporter.py31
-rw-r--r--tensorflow/contrib/session_bundle/exporter_test.py31
-rw-r--r--tensorflow/contrib/session_bundle/session_bundle.py118
-rw-r--r--tensorflow/contrib/session_bundle/session_bundle_test.py75
-rw-r--r--tensorflow/python/platform/googletest.py14
-rw-r--r--tensorflow/python/platform/test.py13
8 files changed, 37 insertions, 323 deletions
diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD
index 47e4314c4f..3862963f23 100644
--- a/tensorflow/contrib/session_bundle/BUILD
+++ b/tensorflow/contrib/session_bundle/BUILD
@@ -23,18 +23,13 @@ filegroup(
),
)
-py_library(
- name = "constants",
- srcs = ["constants.py"],
- srcs_version = "PY2AND3",
-)
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
py_library(
name = "exporter",
srcs = ["exporter.py"],
srcs_version = "PY2AND3",
deps = [
- ":constants",
":gc",
":manifest_proto_py",
"//tensorflow/python:framework",
@@ -50,7 +45,6 @@ py_test(
srcs_version = "PY2AND3",
visibility = ["//visibility:private"],
deps = [
- ":constants",
":exporter",
":gc",
":manifest_proto_py",
@@ -151,39 +145,6 @@ cc_test(
],
)
-py_library(
- name = "session_bundle_py",
- srcs = ["session_bundle.py"],
- srcs_version = "PY2AND3",
- visibility = ["//visibility:public"],
- deps = [
- ":constants",
- ":exporter",
- ":manifest_proto_py",
- "//tensorflow:tensorflow_py",
- "//tensorflow/core:protos_all_py",
- ],
-)
-
-py_test(
- name = "session_bundle_py_test",
- size = "small",
- srcs = [
- "session_bundle_test.py",
- ],
- data = [
- "//tensorflow/contrib/session_bundle/example:half_plus_two",
- ],
- main = "session_bundle_test.py",
- srcs_version = "PY2AND3",
- deps = [
- ":manifest_proto_py",
- ":session_bundle_py",
- "//tensorflow:tensorflow_py",
- "//tensorflow/core:protos_all_py",
- ],
-)
-
cc_library(
name = "signature",
srcs = ["signature.cc"],
@@ -246,8 +207,6 @@ cc_library(
],
)
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
-
tf_proto_library(
name = "manifest_proto",
srcs = ["manifest.proto"],
diff --git a/tensorflow/contrib/session_bundle/constants.py b/tensorflow/contrib/session_bundle/constants.py
deleted file mode 100644
index d05032d849..0000000000
--- a/tensorflow/contrib/session_bundle/constants.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Constants for export/import.
-
-See: go/tf-exporter for these constants and directory structure.
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-VERSION_FORMAT_SPECIFIER = "%08d"
-ASSETS_DIRECTORY = "assets"
-EXPORT_BASE_NAME = "export"
-EXPORT_SUFFIX_NAME = "meta"
-META_GRAPH_DEF_FILENAME = EXPORT_BASE_NAME + "." + EXPORT_SUFFIX_NAME
-VARIABLES_FILENAME = EXPORT_BASE_NAME
-VARIABLES_FILENAME_PATTERN = VARIABLES_FILENAME + "-?????-of-?????"
-INIT_OP_KEY = "serving_init_op"
-SIGNATURES_KEY = "serving_signatures"
-ASSETS_KEY = "serving_assets"
-GRAPH_KEY = "serving_graph"
diff --git a/tensorflow/contrib/session_bundle/exporter.py b/tensorflow/contrib/session_bundle/exporter.py
index 4709b73114..ac3bd8890b 100644
--- a/tensorflow/contrib/session_bundle/exporter.py
+++ b/tensorflow/contrib/session_bundle/exporter.py
@@ -27,7 +27,6 @@ import six
from google.protobuf.any_pb2 import Any
-from tensorflow.contrib.session_bundle import constants
from tensorflow.contrib.session_bundle import gc
from tensorflow.contrib.session_bundle import manifest_pb2
from tensorflow.core.framework import graph_pb2
@@ -38,6 +37,19 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.util import compat
+# See: go/tf-exporter for these constants and directory structure.
+VERSION_FORMAT_SPECIFIER = "%08d"
+ASSETS_DIRECTORY = "assets"
+EXPORT_BASE_NAME = "export"
+EXPORT_SUFFIX_NAME = "meta"
+META_GRAPH_DEF_FILENAME = EXPORT_BASE_NAME + "." + EXPORT_SUFFIX_NAME
+VARIABLES_FILENAME = EXPORT_BASE_NAME
+VARIABLES_FILENAME_PATTERN = VARIABLES_FILENAME + "-?????-of-?????"
+INIT_OP_KEY = "serving_init_op"
+SIGNATURES_KEY = "serving_signatures"
+ASSETS_KEY = "serving_assets"
+GRAPH_KEY = "serving_graph"
+
def gfile_copy_callback(files_to_copy, export_dir_path):
"""Callback to copy files using `gfile.Copy` to an export directory.
@@ -188,12 +200,12 @@ class Exporter(object):
node.device = ""
graph_any_buf = Any()
graph_any_buf.Pack(copy)
- ops.add_to_collection(constants.GRAPH_KEY, graph_any_buf)
+ ops.add_to_collection(GRAPH_KEY, graph_any_buf)
if init_op:
if not isinstance(init_op, ops.Operation):
raise TypeError("init_op needs to be an Operation: %s" % init_op)
- ops.add_to_collection(constants.INIT_OP_KEY, init_op)
+ ops.add_to_collection(INIT_OP_KEY, init_op)
signatures_proto = manifest_pb2.Signatures()
if default_graph_signature:
@@ -202,7 +214,7 @@ class Exporter(object):
signatures_proto.named_signatures[signature_name].CopyFrom(signature)
signatures_any_buf = Any()
signatures_any_buf.Pack(signatures_proto)
- ops.add_to_collection(constants.SIGNATURES_KEY, signatures_any_buf)
+ ops.add_to_collection(SIGNATURES_KEY, signatures_any_buf)
for filename, tensor in assets:
asset = manifest_pb2.AssetFile()
@@ -210,7 +222,7 @@ class Exporter(object):
asset.tensor_binding.tensor_name = tensor.name
asset_any_buf = Any()
asset_any_buf.Pack(asset)
- ops.add_to_collection(constants.ASSETS_KEY, asset_any_buf)
+ ops.add_to_collection(ASSETS_KEY, asset_any_buf)
self._assets_callback = assets_callback
@@ -247,7 +259,7 @@ class Exporter(object):
global_step = training_util.global_step(sess, global_step_tensor)
export_dir = os.path.join(
compat.as_bytes(export_dir_base),
- compat.as_bytes(constants.VERSION_FORMAT_SPECIFIER % global_step))
+ compat.as_bytes(VERSION_FORMAT_SPECIFIER % global_step))
# Prevent overwriting on existing exports which could lead to bad/corrupt
# storage and loading of models. This is an important check that must be
@@ -264,14 +276,13 @@ class Exporter(object):
self._saver.save(sess,
os.path.join(
compat.as_text(tmp_export_dir),
- compat.as_text(constants.EXPORT_BASE_NAME)),
- meta_graph_suffix=constants.EXPORT_SUFFIX_NAME)
+ compat.as_text(EXPORT_BASE_NAME)),
+ meta_graph_suffix=EXPORT_SUFFIX_NAME)
# Run the asset callback.
if self._assets_callback and self._assets_to_copy:
assets_dir = os.path.join(
- compat.as_bytes(tmp_export_dir),
- compat.as_bytes(constants.ASSETS_DIRECTORY))
+ compat.as_bytes(tmp_export_dir), compat.as_bytes(ASSETS_DIRECTORY))
gfile.MakeDirs(assets_dir)
self._assets_callback(self._assets_to_copy, assets_dir)
diff --git a/tensorflow/contrib/session_bundle/exporter_test.py b/tensorflow/contrib/session_bundle/exporter_test.py
index bed1c8ff45..8b54933fc9 100644
--- a/tensorflow/contrib/session_bundle/exporter_test.py
+++ b/tensorflow/contrib/session_bundle/exporter_test.py
@@ -23,7 +23,6 @@ import os.path
import tensorflow as tf
-from tensorflow.contrib.session_bundle import constants
from tensorflow.contrib.session_bundle import exporter
from tensorflow.contrib.session_bundle import gc
from tensorflow.contrib.session_bundle import manifest_pb2
@@ -114,14 +113,14 @@ class SaveRestoreShardedTest(tf.test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
save = tf.train.import_meta_graph(
- os.path.join(export_path, constants.VERSION_FORMAT_SPECIFIER %
- global_step, constants.META_GRAPH_DEF_FILENAME))
+ os.path.join(export_path, exporter.VERSION_FORMAT_SPECIFIER %
+ global_step, exporter.META_GRAPH_DEF_FILENAME))
self.assertIsNotNone(save)
meta_graph_def = save.export_meta_graph()
collection_def = meta_graph_def.collection_def
# Validate custom graph_def.
- graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
+ graph_def_any = collection_def[exporter.GRAPH_KEY].any_list.value
self.assertEquals(len(graph_def_any), 1)
graph_def = tf.GraphDef()
graph_def_any[0].Unpack(graph_def)
@@ -131,12 +130,12 @@ class SaveRestoreShardedTest(tf.test.TestCase):
self.assertProtoEquals(compare_def, graph_def)
# Validate init_op.
- init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
+ init_ops = collection_def[exporter.INIT_OP_KEY].node_list.value
self.assertEquals(len(init_ops), 1)
self.assertEquals(init_ops[0], "init_op")
# Validate signatures.
- signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
+ signatures_any = collection_def[exporter.SIGNATURES_KEY].any_list.value
self.assertEquals(len(signatures_any), 1)
signatures = manifest_pb2.Signatures()
signatures_any[0].Unpack(signatures)
@@ -152,21 +151,21 @@ class SaveRestoreShardedTest(tf.test.TestCase):
self.assertEquals(read_foo_signature.output.tensor_name, "v1:0")
# Validate the assets.
- assets_any = collection_def[constants.ASSETS_KEY].any_list.value
+ assets_any = collection_def[exporter.ASSETS_KEY].any_list.value
self.assertEquals(len(assets_any), 1)
asset = manifest_pb2.AssetFile()
assets_any[0].Unpack(asset)
assets_path = os.path.join(export_path,
- constants.VERSION_FORMAT_SPECIFIER %
- global_step, constants.ASSETS_DIRECTORY,
+ exporter.VERSION_FORMAT_SPECIFIER %
+ global_step, exporter.ASSETS_DIRECTORY,
"hello42.txt")
asset_contents = gfile.GFile(assets_path).read()
self.assertEqual(asset_contents, "your data here")
self.assertEquals("hello42.txt", asset.filename)
self.assertEquals("filename42:0", asset.tensor_binding.tensor_name)
ignored_asset_path = os.path.join(export_path,
- constants.VERSION_FORMAT_SPECIFIER %
- global_step, constants.ASSETS_DIRECTORY,
+ exporter.VERSION_FORMAT_SPECIFIER %
+ global_step, exporter.ASSETS_DIRECTORY,
"ignored.txt")
self.assertFalse(gfile.Exists(ignored_asset_path))
@@ -174,16 +173,16 @@ class SaveRestoreShardedTest(tf.test.TestCase):
if sharded:
save.restore(sess,
os.path.join(
- export_path, constants.VERSION_FORMAT_SPECIFIER %
- global_step, constants.VARIABLES_FILENAME_PATTERN))
+ export_path, exporter.VERSION_FORMAT_SPECIFIER %
+ global_step, exporter.VARIABLES_FILENAME_PATTERN))
else:
save.restore(sess,
os.path.join(
- export_path, constants.VERSION_FORMAT_SPECIFIER %
- global_step, constants.VARIABLES_FILENAME))
+ export_path, exporter.VERSION_FORMAT_SPECIFIER %
+ global_step, exporter.VARIABLES_FILENAME))
self.assertEqual(10, tf.get_collection("v")[0].eval())
self.assertEqual(20, tf.get_collection("v")[1].eval())
- tf.get_collection(constants.INIT_OP_KEY)[0].run()
+ tf.get_collection(exporter.INIT_OP_KEY)[0].run()
self.assertEqual(30, tf.get_collection("v")[2].eval())
def testDuplicateExportRaisesError(self):
diff --git a/tensorflow/contrib/session_bundle/session_bundle.py b/tensorflow/contrib/session_bundle/session_bundle.py
deleted file mode 100644
index 1479db5712..0000000000
--- a/tensorflow/contrib/session_bundle/session_bundle.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Importer for an exported TensorFlow model.
-
-This module provides a function to create a SessionBundle containing both the
-Session and MetaGraph.
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-import tensorflow as tf
-
-from tensorflow.contrib.session_bundle import constants
-from tensorflow.contrib.session_bundle import manifest_pb2
-from tensorflow.core.protobuf import meta_graph_pb2
-from tensorflow.python.lib.io import file_io
-from tensorflow.python.platform import gfile
-
-
-def LoadSessionBundleFromPath(export_dir, target="", config=None):
- """Load session bundle from the given path.
-
- The function reads input from the export_dir, constructs the graph data to the
- default graph and restores the parameters for the session created.
-
- Args:
- export_dir: the directory that contains files exported by exporter.
- target: The execution engine to connect to. See target in tf.Session()
- config: A ConfigProto proto with configuration options. See config in
- tf.Session()
-
- Returns:
- session: a tensorflow session created from the variable files.
- meta_graph: a meta graph proto saved in the exporter directory.
-
- Raises:
- RuntimeError: if the required files are missing or contain unrecognizable
- fields, i.e. the exported model is invalid.
- """
- meta_graph_filename = os.path.join(export_dir,
- constants.META_GRAPH_DEF_FILENAME)
- if not gfile.Exists(meta_graph_filename):
- raise RuntimeError("Expected meta graph file missing %s" %
- meta_graph_filename)
- variables_filename = os.path.join(export_dir,
- constants.VARIABLES_FILENAME)
- if not gfile.Exists(variables_filename):
- variables_filename = os.path.join(
- export_dir, constants.VARIABLES_FILENAME_PATTERN)
- if not gfile.Glob(variables_filename):
- raise RuntimeError("Expected variables file missing %s" %
- variables_filename)
- assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)
-
- # Reads meta graph file.
- meta_graph_def = meta_graph_pb2.MetaGraphDef()
- meta_graph_def.ParseFromString(file_io.read_file_to_string(
- meta_graph_filename))
-
- collection_def = meta_graph_def.collection_def
- graph_def = tf.GraphDef()
- if constants.GRAPH_KEY in collection_def:
- # Use serving graph_def in MetaGraphDef collection_def if exists
- graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
- if len(graph_def_any) != 1:
- raise RuntimeError(
- "Expected exactly one serving GraphDef in : %s" % meta_graph_def)
- else:
- graph_def_any[0].Unpack(graph_def)
- # Replace the graph def in meta graph proto.
- meta_graph_def.graph_def.CopyFrom(graph_def)
-
- tf.reset_default_graph()
- sess = tf.Session(target, graph=None, config=config)
- # Import the graph.
- saver = tf.train.import_meta_graph(meta_graph_def)
- # Restore the session.
- saver.restore(sess, variables_filename)
-
- init_op_tensor = None
- if constants.INIT_OP_KEY in collection_def:
- init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
- if len(init_ops) != 1:
- raise RuntimeError(
- "Expected exactly one serving init op in : %s" % meta_graph_def)
- init_op_tensor = tf.get_collection(constants.INIT_OP_KEY)[0]
-
- # Create asset input tensor list.
- asset_tensor_dict = {}
- if constants.ASSETS_KEY in collection_def:
- assets_any = collection_def[constants.ASSETS_KEY].any_list.value
- for asset in assets_any:
- asset_pb = manifest_pb2.AssetFile()
- asset.Unpack(asset_pb)
- asset_tensor_dict[asset_pb.tensor_binding.tensor_name] = os.path.join(
- assets_dir, asset_pb.filename)
-
- if init_op_tensor:
- # Run the init op.
- sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict)
-
- return sess, meta_graph_def
diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.py b/tensorflow/contrib/session_bundle/session_bundle_test.py
deleted file mode 100644
index a9e157eb19..0000000000
--- a/tensorflow/contrib/session_bundle/session_bundle_test.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-"""Tests for session_bundle.py."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os.path
-import numpy as np
-import tensorflow as tf
-from tensorflow.contrib.session_bundle import constants
-from tensorflow.contrib.session_bundle import manifest_pb2
-from tensorflow.contrib.session_bundle import session_bundle
-from tensorflow.python.util import compat
-
-
-class SessionBundleLoadTest(tf.test.TestCase):
-
- def testBasic(self):
- base_path = tf.test.test_src_dir_path(
- "contrib/session_bundle/example/half_plus_two/00000123")
- tf.reset_default_graph()
- sess, meta_graph_def = session_bundle.LoadSessionBundleFromPath(
- base_path, target="", config=tf.ConfigProto(device_count={"CPU": 2}))
-
- self.assertTrue(sess)
- asset_path = os.path.join(base_path, constants.ASSETS_DIRECTORY)
- with sess.as_default():
- path1, path2 = sess.run(["filename1:0", "filename2:0"])
- self.assertEqual(
- compat.as_bytes(os.path.join(asset_path, "hello1.txt")), path1)
- self.assertEqual(
- compat.as_bytes(os.path.join(asset_path, "hello2.txt")), path2)
-
- collection_def = meta_graph_def.collection_def
-
- signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
- self.assertEquals(len(signatures_any), 1)
-
- signatures = manifest_pb2.Signatures()
- signatures_any[0].Unpack(signatures)
- default_signature = signatures.default_signature
- input_name = default_signature.regression_signature.input.tensor_name
- output_name = default_signature.regression_signature.output.tensor_name
- y = sess.run([output_name], {input_name: np.array([[0], [1], [2], [3]])})
- # The operation is y = 0.5 * x + 2
- self.assertEqual(y[0][0], 2)
- self.assertEqual(y[0][1], 2.5)
- self.assertEqual(y[0][2], 3)
- self.assertEqual(y[0][3], 3.5)
-
- def testBadPath(self):
- base_path = tf.test.test_src_dir_path("/no/such/a/dir")
- tf.reset_default_graph()
- with self.assertRaises(RuntimeError) as cm:
- _, _ = session_bundle.LoadSessionBundleFromPath(
- base_path, target="local",
- config=tf.ConfigProto(device_count={"CPU": 2}))
- self.assertTrue("Expected meta graph file missing" in str(cm.exception))
-
-if __name__ == "__main__":
- tf.test.main()
diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py
index a41ba73a32..04909a40a9 100644
--- a/tensorflow/python/platform/googletest.py
+++ b/tensorflow/python/platform/googletest.py
@@ -94,20 +94,6 @@ def GetTempDir():
return temp_dir
-def test_src_dir_path(relative_path):
- """Creates an absolute test srcdir path given a relative path.
-
- Args:
- relative_path: a path relative to tensorflow root.
- e.g. "contrib/session_bundle/example".
-
- Returns:
- An absolute path to the linked in runfiles.
- """
- return os.path.join(os.environ['TEST_SRCDIR'],
- "org_tensorflow/tensorflow", relative_path)
-
-
def StatefulSessionAvailable():
return False
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index ac700f43c4..25b0fea80f 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -98,19 +98,6 @@ def get_temp_dir():
return googletest.GetTempDir()
-def test_src_dir_path(relative_path):
- """Creates an absolute test srcdir path given a relative path.
-
- Args:
- relative_path: a path relative to tensorflow root.
- e.g. "core/platform".
-
- Returns:
- An absolute path to the linked in runfiles.
- """
- return googletest.test_src_dir_path(relative_path)
-
-
def is_built_with_cuda():
"""Returns whether TensorFlow was built with CUDA (GPU) support."""
return test_util.IsGoogleCudaEnabled()