diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-09-29 06:28:56 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-29 07:35:19 -0700 |
commit | 783e8388db67f3fefc6c714d479615821a9dc5e7 (patch) | |
tree | a25e65ce1552ca593f7cf72d61b0f35bfd18613f | |
parent | 7a0f6252d7f93b8cc78e8cc61c88f45d98642b30 (diff) |
Add python/tools/print_selective_registration_header utility for printing a
header for use with 'selective registration' (see
core/framework/selective_registration.h), given an input model.
Change: 134659033
-rw-r--r-- | tensorflow/core/framework/selective_registration.h | 3 | ||||
-rw-r--r-- | tensorflow/python/BUILD | 15 | ||||
-rw-r--r-- | tensorflow/python/tensorflow.i | 1 | ||||
-rw-r--r-- | tensorflow/python/tools/BUILD | 18 | ||||
-rw-r--r-- | tensorflow/python/tools/print_selective_registration_header.py | 129 | ||||
-rw-r--r-- | tensorflow/python/tools/print_selective_registration_header_test.py | 114 | ||||
-rw-r--r-- | tensorflow/python/util/kernel_registry.cc | 57 | ||||
-rw-r--r-- | tensorflow/python/util/kernel_registry.h | 34 | ||||
-rw-r--r-- | tensorflow/python/util/kernel_registry.i | 28 |
9 files changed, 397 insertions, 2 deletions
diff --git a/tensorflow/core/framework/selective_registration.h b/tensorflow/core/framework/selective_registration.h index 1227500782..751e2cde84 100644 --- a/tensorflow/core/framework/selective_registration.h +++ b/tensorflow/core/framework/selective_registration.h @@ -31,7 +31,8 @@ limitations under the License. // functions should be defined as valid constexpr functions, so that they are // evaluated at compile time: this is needed to make symbols referenced by // un-registered objects unused, and therefore allow the linker to strip them -// out. +// out. See tools/print_required_ops/print_selective_registration_header.py +// for a tool that can be used to generate ops_to_register.h. #include "ops_to_register.h" // Op kernel classes for which ShouldRegisterOpKernel returns false will not be diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index fe2d56eec3..48c8dc18c8 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -115,6 +115,17 @@ cc_library( ) cc_library( + name = "kernel_registry", + srcs = ["util/kernel_registry.cc"], + hdrs = ["util/kernel_registry.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow", + ], +) + +cc_library( name = "py_func_lib", srcs = ["lib/core/py_func.cc"], hdrs = [ @@ -1794,15 +1805,17 @@ tf_py_wrap_cc( "lib/io/py_record_writer.i", "platform/base.i", "training/server_lib.i", + "util/kernel_registry.i", "util/port.i", "util/py_checkpoint_reader.i", ], deps = [ + ":cpp_shape_inference", + ":kernel_registry", ":numpy_lib", ":py_func_lib", ":py_record_reader_lib", ":py_record_writer_lib", - ":cpp_shape_inference", ":python_op_gen", ":tf_session_helper", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i index 7a8fbf7201..9115ec891a 100644 --- a/tensorflow/python/tensorflow.i +++ b/tensorflow/python/tensorflow.i @@ -37,3 +37,4 @@ limitations under the License. %include "tensorflow/python/framework/python_op_gen.i" %include "tensorflow/python/framework/cpp_shape_inference.i" +%include "tensorflow/python/util/kernel_registry.i" diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index 9eec25d393..cf94442308 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -158,6 +158,24 @@ py_test( ], ) +py_binary( + name = "print_selective_registration_header", + srcs = ["print_selective_registration_header.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = ["//tensorflow:tensorflow_py"], +) + +py_test( + name = "print_selective_registration_header_test", + srcs = ["print_selective_registration_header_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":print_selective_registration_header", + "//tensorflow:tensorflow_py", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/python/tools/print_selective_registration_header.py b/tensorflow/python/tools/print_selective_registration_header.py new file mode 100644 index 0000000000..629c58a17f --- /dev/null +++ b/tensorflow/python/tools/print_selective_registration_header.py @@ -0,0 +1,129 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Prints a header file to be used with SELECTIVE_REGISTRATION. + +Example usage: + print_selective_registration_header \ + --graphs=path/to/graph.pb > ops_to_register.h + + Then when compiling tensorflow, include ops_to_register.h in the include + search path and pass -DSELECTIVE_REGISTRATION - see + core/framework/selective_registration.h for more details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import tensorflow as tf +from google.protobuf import text_format +from tensorflow.core.framework import graph_pb2 +from tensorflow.python import pywrap_tensorflow + +FLAGS = tf.app.flags.FLAGS + +tf.app.flags.DEFINE_string('proto_fileformat', 'rawproto', + 'Format of proto file, either textproto or rawproto') + +tf.app.flags.DEFINE_string( + 'graphs', '', + 'Comma-separated list of paths to model files to be analyzed.') + +tf.app.flags.DEFINE_string('default_ops', 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp', + 'Default operator:kernel pairs to always include ' + 'implementation for') + + +def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str): + """Gets the ops and kernels needed from the model files.""" + ops = set() + + for proto_file in proto_files: + tf.logging.info('Loading proto file %s', proto_file) + # Load GraphDef. + file_data = tf.gfile.GFile(proto_file).read() + if proto_fileformat == 'rawproto': + graph_def = graph_pb2.GraphDef.FromString(file_data) + else: + assert proto_fileformat == 'textproto' + graph_def = text_format.Parse(file_data, graph_pb2.GraphDef()) + + # Find all ops and kernels used by the graph. + for node_def in graph_def.node: + if not node_def.device: + node_def.device = '/cpu:0' + kernel_class = pywrap_tensorflow.TryFindKernelClass( + node_def.SerializeToString()) + if kernel_class: + op_and_kernel = (str(node_def.op), kernel_class.decode('utf-8')) + if op_and_kernel not in ops: + ops.add(op_and_kernel) + else: + print( + 'Warning: no kernel found for op %s' % node_def.op, file=sys.stderr) + + # Add default ops. + for s in default_ops_str.split(','): + op, kernel = s.split(':') + op_and_kernel = (op, kernel) + if op_and_kernel not in ops: + ops.add(op_and_kernel) + + return list(sorted(ops)) + + +def print_header(ops_and_kernels, ops): + """Prints a header for use with tensorflow SELECTIVE_REGISTRATION.""" + print('#ifndef OPS_TO_REGISTER') + print('#define OPS_TO_REGISTER') + + print('constexpr inline bool ShouldRegisterOp(const char op[]) {') + print(' return false') + for op in sorted(ops): + print(' || (strcmp(op, "%s") == 0)' % op) + print(' ;') + print('}') + + line = 'const char kNecessaryOpKernelClasses[] = ","\n' + for _, kernel_class in ops_and_kernels: + line += '"%s,"\n' % kernel_class + line += ';' + print(line) + + print('const bool kRequiresSymbolicGradients = %s;' % + ('true' if 'SymbolicGradient' in ops else 'false')) + + print('#endif') + + +def main(unused_argv): + if not FLAGS.graphs: + print('--graphs is required') + return 1 + graphs = FLAGS.graphs.split(',') + ops_and_kernels = get_ops_and_kernels(FLAGS.proto_fileformat, graphs, + FLAGS.default_ops) + ops = set([op for op, _ in ops_and_kernels]) + if not ops: + print('Error reading graph!') + return 1 + + print_header(ops_and_kernels, ops) + + +if __name__ == '__main__': + tf.app.run() diff --git a/tensorflow/python/tools/print_selective_registration_header_test.py b/tensorflow/python/tools/print_selective_registration_header_test.py new file mode 100644 index 0000000000..e3aac9cf18 --- /dev/null +++ b/tensorflow/python/tools/print_selective_registration_header_test.py @@ -0,0 +1,114 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for print_selective_registration_header.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import tensorflow as tf + +from google.protobuf import text_format +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.tools import print_selective_registration_header + +# Note that this graph def is not valid to be loaded - its inputs are not +# assigned correctly in all cases. +GRAPH_DEF_TXT = """ + node: { + name: "node_1" + op: "Reshape" + input: [ "none", "none" ] + device: "/cpu:0" + attr: { key: "T" value: { type: DT_FLOAT } } + } + node: { + name: "node_2" + op: "MatMul" + input: [ "none", "none" ] + device: "/cpu:0" + attr: { key: "T" value: { type: DT_FLOAT } } + attr: { key: "transpose_a" value: { b: false } } + attr: { key: "transpose_b" value: { b: false } } + } + node: { + name: "node_3" + op: "MatMul" + input: [ "none", "none" ] + device: "/cpu:0" + attr: { key: "T" value: { type: DT_DOUBLE } } + attr: { key: "transpose_a" value: { b: false } } + attr: { key: "transpose_b" value: { b: false } } + } +""" + +GRAPH_DEF_TXT_2 = """ + node: { + name: "node_4" + op: "BiasAdd" + input: [ "none", "none" ] + device: "/cpu:0" + attr: { key: "T" value: { type: DT_FLOAT } } + } + +""" + + +class PrintOpFilegroupTest(tf.test.TestCase): + + def WriteGraphFiles(self, graphs): + fnames = [] + for i, graph in enumerate(graphs): + fname = os.path.join(self.get_temp_dir(), 'graph%s.pb' % i) + with tf.gfile.GFile(fname, 'wb') as f: + f.write(graph.SerializeToString()) + fnames.append(fname) + return fnames + + def testGetOps(self): + default_ops = 'NoOp:NoOp,_Recv:RecvOp,_Send:SendOp' + graphs = [text_format.Parse(d, graph_pb2.GraphDef()) + for d in [GRAPH_DEF_TXT, GRAPH_DEF_TXT_2]] + + ops_and_kernels = print_selective_registration_header.get_ops_and_kernels( + 'rawproto', self.WriteGraphFiles(graphs), default_ops) + self.assertListEqual([('BiasAdd', 'BiasOp<CPUDevice, float>'), # + ('MatMul', 'MatMulOp<CPUDevice, double, false >'), # + ('MatMul', 'MatMulOp<CPUDevice, float, false >'), # + ('NoOp', 'NoOp'), # + ('Reshape', 'ReshapeOp'), # + ('_Recv', 'RecvOp'), # + ('_Send', 'SendOp'), # + ], + ops_and_kernels) + + graphs[0].node[0].ClearField('device') + graphs[0].node[2].ClearField('device') + ops_and_kernels = print_selective_registration_header.get_ops_and_kernels( + 'rawproto', self.WriteGraphFiles(graphs), default_ops) + self.assertListEqual([('BiasAdd', 'BiasOp<CPUDevice, float>'), # + ('MatMul', 'MatMulOp<CPUDevice, double, false >'), # + ('MatMul', 'MatMulOp<CPUDevice, float, false >'), # + ('NoOp', 'NoOp'), # + ('Reshape', 'ReshapeOp'), # + ('_Recv', 'RecvOp'), # + ('_Send', 'SendOp'), # + ], + ops_and_kernels) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/python/util/kernel_registry.cc b/tensorflow/python/util/kernel_registry.cc new file mode 100644 index 0000000000..b05c2ef04b --- /dev/null +++ b/tensorflow/python/util/kernel_registry.cc @@ -0,0 +1,57 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/python/util/kernel_registry.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace swig { + +string TryFindKernelClass(const string& serialized_node_def) { + tensorflow::NodeDef node_def; + if (!node_def.ParseFromString(serialized_node_def)) { + LOG(WARNING) << "Error parsing node_def"; + return ""; + } + + const tensorflow::OpRegistrationData* op_reg_data; + auto status = + tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data); + if (!status.ok()) { + LOG(WARNING) << "Op " << node_def.op() << " not found: " << status; + return ""; + } + AddDefaultsToNodeDef(op_reg_data->op_def, &node_def); + + tensorflow::DeviceNameUtils::ParsedName parsed_name; + if (!tensorflow::DeviceNameUtils::ParseFullName(node_def.device(), + &parsed_name)) { + LOG(WARNING) << "Failed to parse device from node_def: " + << node_def.ShortDebugString(); + return ""; + } + string class_name = ""; + tensorflow::FindKernelDef(tensorflow::DeviceType(parsed_name.type.c_str()), + node_def, nullptr /* kernel_def */, &class_name); + return class_name; +} + +} // namespace swig +} // namespace tensorflow diff --git a/tensorflow/python/util/kernel_registry.h b/tensorflow/python/util/kernel_registry.h new file mode 100644 index 0000000000..c00b60d91b --- /dev/null +++ b/tensorflow/python/util/kernel_registry.h @@ -0,0 +1,34 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Functions for getting information about kernels registered in the binary. +#ifndef THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_ +#define THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace swig { + +// Returns the kernel class name required to execute <node_def> on the device +// type of <node_def.device>, or an empty string if the kernel class is not +// found or the device name is invalid. +string TryFindKernelClass(const string& serialized_node_def); + +} // namespace swig +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_PYTHON_UTIL_KERNEL_REGISTRY_H_ diff --git a/tensorflow/python/util/kernel_registry.i b/tensorflow/python/util/kernel_registry.i new file mode 100644 index 0000000000..0c2e0df37d --- /dev/null +++ b/tensorflow/python/util/kernel_registry.i @@ -0,0 +1,28 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +%include "tensorflow/python/platform/base.i" + +%{ +#include "tensorflow/python/util/kernel_registry.h" +%} + +%ignoreall + +%unignore tensorflow; +%unignore tensorflow::swig; +%unignore tensorflow::swig::TryFindKernelClass; +%include "tensorflow/python/util/kernel_registry.h" + +%unignoreall |