aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-29 06:28:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-29 07:35:19 -0700
commit783e8388db67f3fefc6c714d479615821a9dc5e7 (patch)
treea25e65ce1552ca593f7cf72d61b0f35bfd18613f
parent7a0f6252d7f93b8cc78e8cc61c88f45d98642b30 (diff)
Add python/tools/print_selective_registration_header utility for printing a
header for use with 'selective registration' (see core/framework/selective_registration.h), given an input model. Change: 134659033
-rw-r--r--tensorflow/core/framework/selective_registration.h3
-rw-r--r--tensorflow/python/BUILD15
-rw-r--r--tensorflow/python/tensorflow.i1
-rw-r--r--tensorflow/python/tools/BUILD18
-rw-r--r--tensorflow/python/tools/print_selective_registration_header.py129
-rw-r--r--tensorflow/python/tools/print_selective_registration_header_test.py114
-rw-r--r--tensorflow/python/util/kernel_registry.cc57
-rw-r--r--tensorflow/python/util/kernel_registry.h34
-rw-r--r--tensorflow/python/util/kernel_registry.i28
9 files changed, 397 insertions, 2 deletions
diff --git a/tensorflow/core/framework/selective_registration.h b/tensorflow/core/framework/selective_registration.h
index 1227500782..751e2cde84 100644
--- a/tensorflow/core/framework/selective_registration.h
+++ b/tensorflow/core/framework/selective_registration.h
@@ -31,7 +31,8 @@ limitations under the License.
// functions should be defined as valid constexpr functions, so that they are
// evaluated at compile time: this is needed to make symbols referenced by
// un-registered objects unused, and therefore allow the linker to strip them
-// out.
+// out. See tools/print_required_ops/print_selective_registration_header.py
+// for a tool that can be used to generate ops_to_register.h.
#include "ops_to_register.h"
// Op kernel classes for which ShouldRegisterOpKernel returns false will not be
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index fe2d56eec3..48c8dc18c8 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -115,6 +115,17 @@ cc_library(
)
cc_library(
+ name = "kernel_registry",
+ srcs = ["util/kernel_registry.cc"],
+ hdrs = ["util/kernel_registry.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensorflow",
+ ],
+)
+
+cc_library(
name = "py_func_lib",
srcs = ["lib/core/py_func.cc"],
hdrs = [
@@ -1794,15 +1805,17 @@ tf_py_wrap_cc(
"lib/io/py_record_writer.i",
"platform/base.i",
"training/server_lib.i",
+ "util/kernel_registry.i",
"util/port.i",
"util/py_checkpoint_reader.i",
],
deps = [
+ ":cpp_shape_inference",
+ ":kernel_registry",
":numpy_lib",
":py_func_lib",
":py_record_reader_lib",
":py_record_writer_lib",
- ":cpp_shape_inference",
":python_op_gen",
":tf_session_helper",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 7a8fbf7201..9115ec891a 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -37,3 +37,4 @@ limitations under the License.
%include "tensorflow/python/framework/python_op_gen.i"
%include "tensorflow/python/framework/cpp_shape_inference.i"
+%include "tensorflow/python/util/kernel_registry.i"
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 9eec25d393..cf94442308 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -158,6 +158,24 @@ py_test(
],
)
+py_binary(
+ name = "print_selective_registration_header",
+ srcs = ["print_selective_registration_header.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = ["//tensorflow:tensorflow_py"],
+)
+
+py_test(
+ name = "print_selective_registration_header_test",
+ srcs = ["print_selective_registration_header_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":print_selective_registration_header",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/python/tools/print_selective_registration_header.py b/tensorflow/python/tools/print_selective_registration_header.py
new file mode 100644
index 0000000000..629c58a17f
--- /dev/null
+++ b/tensorflow/python/tools/print_selective_registration_header.py
@@ -0,0 +1,129 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Prints a header file to be used with SELECTIVE_REGISTRATION.
+
+Example usage:
+ print_selective_registration_header \
+ --graphs=path/to/graph.pb > ops_to_register.h
+
+ Then when compiling tensorflow, include ops_to_register.h in the include
+ search path and pass -DSELECTIVE_REGISTRATION - see
+ core/framework/selective_registration.h for more details.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import tensorflow as tf
+from google.protobuf import text_format
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python import pywrap_tensorflow
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string('proto_fileformat', 'rawproto',
+ 'Format of proto file, either textproto or rawproto')
+
+tf.app.flags.DEFINE_string(
+ 'graphs', '',
+ 'Comma-separated list of paths to model files to be analyzed.')
+
+tf.app.flags.DEFINE_string('default_ops', 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp',
+ 'Default operator:kernel pairs to always include '
+ 'implementation for')
+
+
+def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
+ """Gets the ops and kernels needed from the model files."""
+ ops = set()
+
+ for proto_file in proto_files:
+ tf.logging.info('Loading proto file %s', proto_file)
+ # Load GraphDef.
+ file_data = tf.gfile.GFile(proto_file).read()
+ if proto_fileformat == 'rawproto':
+ graph_def = graph_pb2.GraphDef.FromString(file_data)
+ else:
+ assert proto_fileformat == 'textproto'
+ graph_def = text_format.Parse(file_data, graph_pb2.GraphDef())
+
+ # Find all ops and kernels used by the graph.
+ for node_def in graph_def.node:
+ if not node_def.device:
+ node_def.device = '/cpu:0'
+ kernel_class = pywrap_tensorflow.TryFindKernelClass(
+ node_def.SerializeToString())
+ if kernel_class:
+ op_and_kernel = (str(node_def.op), kernel_class.decode('utf-8'))
+ if op_and_kernel not in ops:
+ ops.add(op_and_kernel)
+ else:
+ print(
+ 'Warning: no kernel found for op %s' % node_def.op, file=sys.stderr)
+
+ # Add default ops.
+ for s in default_ops_str.split(','):
+ op, kernel = s.split(':')
+ op_and_kernel = (op, kernel)
+ if op_and_kernel not in ops:
+ ops.add(op_and_kernel)
+
+ return list(sorted(ops))
+
+
+def print_header(ops_and_kernels, ops):
+ """Prints a header for use with tensorflow SELECTIVE_REGISTRATION."""
+ print('#ifndef OPS_TO_REGISTER')
+ print('#define OPS_TO_REGISTER')
+
+ print('constexpr inline bool ShouldRegisterOp(const char op[]) {')
+ print(' return false')
+ for op in sorted(ops):
+ print(' || (strcmp(op, "%s") == 0)' % op)
+ print(' ;')
+ print('}')
+
+ line = 'const char kNecessaryOpKernelClasses[] = ","\n'
+ for _, kernel_class in ops_and_kernels:
+ line += '"%s,"\n' % kernel_class
+ line += ';'
+ print(line)
+
+ print('const bool kRequiresSymbolicGradients = %s;' %
+ ('true' if 'SymbolicGradient' in ops else 'false'))
+
+ print('#endif')
+
+
+def main(unused_argv):
+ if not FLAGS.graphs:
+ print('--graphs is required')
+ return 1
+ graphs = FLAGS.graphs.split(',')
+ ops_and_kernels = get_ops_and_kernels(FLAGS.proto_fileformat, graphs,
+ FLAGS.default_ops)
+ ops = set([op for op, _ in ops_and_kernels])
+ if not ops:
+ print('Error reading graph!')
+ return 1
+
+ print_header(ops_and_kernels, ops)
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/tensorflow/python/tools/print_selective_registration_header_test.py b/tensorflow/python/tools/print_selective_registration_header_test.py
new file mode 100644
index 0000000000..e3aac9cf18
--- /dev/null
+++ b/tensorflow/python/tools/print_selective_registration_header_test.py
@@ -0,0 +1,114 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for print_selective_registration_header."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import tensorflow as tf
+
+from google.protobuf import text_format
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.tools import print_selective_registration_header
+
+# Note that this graph def is not valid to be loaded - its inputs are not
+# assigned correctly in all cases.
+GRAPH_DEF_TXT = """
+ node: {
+ name: "node_1"
+ op: "Reshape"
+ input: [ "none", "none" ]
+ device: "/cpu:0"
+ attr: { key: "T" value: { type: DT_FLOAT } }
+ }
+ node: {
+ name: "node_2"
+ op: "MatMul"
+ input: [ "none", "none" ]
+ device: "/cpu:0"
+ attr: { key: "T" value: { type: DT_FLOAT } }
+ attr: { key: "transpose_a" value: { b: false } }
+ attr: { key: "transpose_b" value: { b: false } }
+ }
+ node: {
+ name: "node_3"
+ op: "MatMul"
+ input: [ "none", "none" ]
+ device: "/cpu:0"
+ attr: { key: "T" value: { type: DT_DOUBLE } }
+ attr: { key: "transpose_a" value: { b: false } }
+ attr: { key: "transpose_b" value: { b: false } }
+ }
+"""
+
+GRAPH_DEF_TXT_2 = """
+ node: {
+ name: "node_4"
+ op: "BiasAdd"
+ input: [ "none", "none" ]
+ device: "/cpu:0"
+ attr: { key: "T" value: { type: DT_FLOAT } }
+ }
+
+"""
+
+
+class PrintOpFilegroupTest(tf.test.TestCase):
+
+ def WriteGraphFiles(self, graphs):
+ fnames = []
+ for i, graph in enumerate(graphs):
+ fname = os.path.join(self.get_temp_dir(), 'graph%s.pb' % i)
+ with tf.gfile.GFile(fname, 'wb') as f:
+ f.write(graph.SerializeToString())
+ fnames.append(fname)
+ return fnames
+
+ def testGetOps(self):
+ default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp'
+ graphs = [text_format.Parse(d, graph_pb2.GraphDef())
+ for d in [GRAPH_DEF_TXT, GRAPH_DEF_TXT_2]]
+
+ ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
+ 'rawproto', self.WriteGraphFiles(graphs), default_ops)
+ self.assertListEqual([('BiasAdd', 'BiasOp<CPUDevice, float>'), #
+ ('MatMul', 'MatMulOp<CPUDevice, double, false >'), #
+ ('MatMul', 'MatMulOp<CPUDevice, float, false >'), #
+ ('NoOp', 'NoOp'), #
+ ('Reshape', 'ReshapeOp'), #
+ ('_Recv', 'RecvOp'), #
+ ('_Send', 'SendOp'), #
+ ],
+ ops_and_kernels)
+
+ graphs[0].node[0].ClearField('device')
+ graphs[0].node[2].ClearField('device')
+ ops_and_kernels = print_selective_registration_header.get_ops_and_kernels(
+ 'rawproto', self.WriteGraphFiles(graphs), default_ops)
+ self.assertListEqual([('BiasAdd', 'BiasOp<CPUDevice, float>'), #
+ ('MatMul', 'MatMulOp<CPUDevice, double, false >'), #
+ ('MatMul', 'MatMulOp<CPUDevice, float, false >'), #
+ ('NoOp', 'NoOp'), #
+ ('Reshape', 'ReshapeOp'), #
+ ('_Recv', 'RecvOp'), #
+ ('_Send', 'SendOp'), #
+ ],
+ ops_and_kernels)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/python/util/kernel_registry.cc b/tensorflow/python/util/kernel_registry.cc
new file mode 100644
index 0000000000..b05c2ef04b
--- /dev/null
+++ b/tensorflow/python/util/kernel_registry.cc
@@ -0,0 +1,57 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/util/kernel_registry.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace swig {
+
+string TryFindKernelClass(const string& serialized_node_def) {
+ tensorflow::NodeDef node_def;
+ if (!node_def.ParseFromString(serialized_node_def)) {
+ LOG(WARNING) << "Error parsing node_def";
+ return "";
+ }
+
+ const tensorflow::OpRegistrationData* op_reg_data;
+ auto status =
+ tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
+ if (!status.ok()) {
+ LOG(WARNING) << "Op " << node_def.op() << " not found: " << status;
+ return "";
+ }
+ AddDefaultsToNodeDef(op_reg_data->op_def, &node_def);
+
+ tensorflow::DeviceNameUtils::ParsedName parsed_name;
+ if (!tensorflow::DeviceNameUtils::ParseFullName(node_def.device(),
+ &parsed_name)) {
+ LOG(WARNING) << "Failed to parse device from node_def: "
+ << node_def.ShortDebugString();
+ return "";
+ }
+ string class_name = "";
+ tensorflow::FindKernelDef(tensorflow::DeviceType(parsed_name.type.c_str()),
+ node_def, nullptr /* kernel_def */, &class_name);
+ return class_name;
+}
+
+} // namespace swig
+} // namespace tensorflow
diff --git a/tensorflow/python/util/kernel_registry.h b/tensorflow/python/util/kernel_registry.h
new file mode 100644
index 0000000000..c00b60d91b
--- /dev/null
+++ b/tensorflow/python/util/kernel_registry.h
@@ -0,0 +1,34 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Functions for getting information about kernels registered in the binary.
+#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_
+#define THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace swig {
+
+// Returns the kernel class name required to execute <node_def> on the device
+// type of <node_def.device>, or an empty string if the kernel class is not
+// found or the device name is invalid.
+string TryFindKernelClass(const string& serialized_node_def);
+
+} // namespace swig
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_
diff --git a/tensorflow/python/util/kernel_registry.i b/tensorflow/python/util/kernel_registry.i
new file mode 100644
index 0000000000..0c2e0df37d
--- /dev/null
+++ b/tensorflow/python/util/kernel_registry.i
@@ -0,0 +1,28 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+%include "tensorflow/python/platform/base.i"
+
+%{
+#include "tensorflow/python/util/kernel_registry.h"
+%}
+
+%ignoreall
+
+%unignore tensorflow;
+%unignore tensorflow::swig;
+%unignore tensorflow::swig::TryFindKernelClass;
+%include "tensorflow/python/util/kernel_registry.h"
+
+%unignoreall