aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@google.com>2017-02-28 17:28:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-28 17:46:55 -0800
commit0e7edf20ee67c4e2ece258649147f8e9f2027f6b (patch)
tree21a5294a0b7e4991cb1f9db9a5e30b5fd564af54
parent3e6c638727c3274908a7c9c6bbf4474c014511fe (diff)
Make selective registration handle spaces in kernel name.
Change: 148837641
-rw-r--r--tensorflow/core/framework/op_kernel.h4
-rw-r--r--tensorflow/python/tools/print_selective_registration_header_test.py63
-rw-r--r--tensorflow/python/tools/selective_registration_header_lib.py37
3 files changed, 80 insertions, 24 deletions
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index b6e302c492..eb1ca88938 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -1140,9 +1140,11 @@ class Name : public KernelDefBuilder {
REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
+ constexpr bool should_register_##ctr##__flag = \
+ SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__); \
static ::tensorflow::kernel_factory::OpKernelRegistrar \
registrar__body__##ctr##__object( \
- SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) \
+ should_register_##ctr##__flag \
? ::tensorflow::register_kernel::kernel_builder.Build() \
: nullptr, \
#__VA_ARGS__, \
diff --git a/tensorflow/python/tools/print_selective_registration_header_test.py b/tensorflow/python/tools/print_selective_registration_header_test.py
index 6947428831..08b146f970 100644
--- a/tensorflow/python/tools/print_selective_registration_header_test.py
+++ b/tensorflow/python/tools/print_selective_registration_header_test.py
@@ -148,28 +148,53 @@ class PrintOpFilegroupTest(test.TestCase):
default_ops = ''
graphs = [text_format.Parse(GRAPH_DEF_TXT_2, graph_pb2.GraphDef())]
+ expected = '''#ifndef OPS_TO_REGISTER
+#define OPS_TO_REGISTER
+constexpr inline bool ShouldRegisterOp(const char op[]) {
+ return false
+ || (strcmp(op, "BiasAdd") == 0)
+ ;
+}
+#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)
+
+
+ namespace {
+ constexpr const char* skip(const char* x) {
+ return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x;
+ }
+
+ constexpr bool isequal(const char* x, const char* y) {
+ return (*skip(x) && *skip(y))
+ ? (*skip(x) == *skip(y) && isequal(skip(x) + 1, skip(y) + 1))
+ : (!*skip(x) && !*skip(y));
+ }
+
+ template<int N>
+ struct find_in {
+ static constexpr bool f(const char* x, const char* const y[N]) {
+ return isequal(x, y[0]) || find_in<N - 1>::f(x, y + 1);
+ }
+ };
+
+ template<>
+ struct find_in<0> {
+ static constexpr bool f(const char* x, const char* const y[]) {
+ return false;
+ }
+ };
+ } // end namespace
+ constexpr const char* kNecessaryOpKernelClasses[] = {
+"BiasOp<CPUDevice, float>",
+};
+#define SHOULD_REGISTER_OP_KERNEL(clz) (find_in<sizeof(kNecessaryOpKernelClasses) / sizeof(*kNecessaryOpKernelClasses)>::f(clz, kNecessaryOpKernelClasses))
+
+#define SHOULD_REGISTER_OP_GRADIENT false
+#endif'''
+
header = selective_registration_header_lib.get_header(
self.WriteGraphFiles(graphs), 'rawproto', default_ops)
print(header)
- self.assertListEqual([
- '#ifndef OPS_TO_REGISTER',
- '#define OPS_TO_REGISTER',
- 'constexpr inline bool ShouldRegisterOp(const char op[]) {',
- ' return false',
- ' || (strcmp(op, "BiasAdd") == 0)',
- ' ;',
- '}',
- '#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)',
- '',
- 'const char kNecessaryOpKernelClasses[] = ","',
- '"BiasOp<CPUDevice, float>,"',
- ';',
- '#define SHOULD_REGISTER_OP_KERNEL(clz)'
- ' (strstr(kNecessaryOpKernelClasses, "," clz ",") != nullptr)',
- '',
- '#define SHOULD_REGISTER_OP_GRADIENT false',
- '#endif',
- ], header.split('\n'))
+ self.assertListEqual(expected.split('\n'), header.split('\n'))
if __name__ == '__main__':
diff --git a/tensorflow/python/tools/selective_registration_header_lib.py b/tensorflow/python/tools/selective_registration_header_lib.py
index 1229ea2532..b297721aff 100644
--- a/tensorflow/python/tools/selective_registration_header_lib.py
+++ b/tensorflow/python/tools/selective_registration_header_lib.py
@@ -106,13 +106,42 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)')
append('')
- line = 'const char kNecessaryOpKernelClasses[] = ","\n'
+ line = '''
+ namespace {
+ constexpr const char* skip(const char* x) {
+ return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x;
+ }
+
+ constexpr bool isequal(const char* x, const char* y) {
+ return (*skip(x) && *skip(y))
+ ? (*skip(x) == *skip(y) && isequal(skip(x) + 1, skip(y) + 1))
+ : (!*skip(x) && !*skip(y));
+ }
+
+ template<int N>
+ struct find_in {
+ static constexpr bool f(const char* x, const char* const y[N]) {
+ return isequal(x, y[0]) || find_in<N - 1>::f(x, y + 1);
+ }
+ };
+
+ template<>
+ struct find_in<0> {
+ static constexpr bool f(const char* x, const char* const y[]) {
+ return false;
+ }
+ };
+ } // end namespace
+ '''
+ line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n'
for _, kernel_class in ops_and_kernels:
- line += '"%s,"\n' % kernel_class
- line += ';'
+ line += '"%s",\n' % kernel_class
+ line += '};'
append(line)
append('#define SHOULD_REGISTER_OP_KERNEL(clz) '
- '(strstr(kNecessaryOpKernelClasses, "," clz ",") != nullptr)')
+ '(find_in<sizeof(kNecessaryOpKernelClasses) '
+ '/ sizeof(*kNecessaryOpKernelClasses)>::f(clz, '
+ 'kNecessaryOpKernelClasses))')
append('')
append('#define SHOULD_REGISTER_OP_GRADIENT ' + (