diff options
author | 2017-02-28 17:28:23 -0800 | |
---|---|---|
committer | 2017-02-28 17:46:55 -0800 | |
commit | 0e7edf20ee67c4e2ece258649147f8e9f2027f6b (patch) | |
tree | 21a5294a0b7e4991cb1f9db9a5e30b5fd564af54 | |
parent | 3e6c638727c3274908a7c9c6bbf4474c014511fe (diff) |
Make selective registration handle spaces in kernel name.
Change: 148837641
3 files changed, 80 insertions, 24 deletions
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index b6e302c492..eb1ca88938 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -1140,9 +1140,11 @@ class Name : public KernelDefBuilder { REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__) #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \ + constexpr bool should_register_##ctr##__flag = \ + SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__); \ static ::tensorflow::kernel_factory::OpKernelRegistrar \ registrar__body__##ctr##__object( \ - SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) \ + should_register_##ctr##__flag \ ? ::tensorflow::register_kernel::kernel_builder.Build() \ : nullptr, \ #__VA_ARGS__, \ diff --git a/tensorflow/python/tools/print_selective_registration_header_test.py b/tensorflow/python/tools/print_selective_registration_header_test.py index 6947428831..08b146f970 100644 --- a/tensorflow/python/tools/print_selective_registration_header_test.py +++ b/tensorflow/python/tools/print_selective_registration_header_test.py @@ -148,28 +148,53 @@ class PrintOpFilegroupTest(test.TestCase): default_ops = '' graphs = [text_format.Parse(GRAPH_DEF_TXT_2, graph_pb2.GraphDef())] + expected = '''#ifndef OPS_TO_REGISTER +#define OPS_TO_REGISTER +constexpr inline bool ShouldRegisterOp(const char op[]) { + return false + || (strcmp(op, "BiasAdd") == 0) + ; +} +#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op) + + + namespace { + constexpr const char* skip(const char* x) { + return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x; + } + + constexpr bool isequal(const char* x, const char* y) { + return (*skip(x) && *skip(y)) + ? (*skip(x) == *skip(y) && isequal(skip(x) + 1, skip(y) + 1)) + : (!*skip(x) && !*skip(y)); + } + + template<int N> + struct find_in { + static constexpr bool f(const char* x, const char* const y[N]) { + return isequal(x, y[0]) || find_in<N - 1>::f(x, y + 1); + } + }; + + template<> + struct find_in<0> { + static constexpr bool f(const char* x, const char* const y[]) { + return false; + } + }; + } // end namespace + constexpr const char* kNecessaryOpKernelClasses[] = { +"BiasOp<CPUDevice, float>", +}; +#define SHOULD_REGISTER_OP_KERNEL(clz) (find_in<sizeof(kNecessaryOpKernelClasses) / sizeof(*kNecessaryOpKernelClasses)>::f(clz, kNecessaryOpKernelClasses)) + +#define SHOULD_REGISTER_OP_GRADIENT false +#endif''' + header = selective_registration_header_lib.get_header( self.WriteGraphFiles(graphs), 'rawproto', default_ops) print(header) - self.assertListEqual([ - '#ifndef OPS_TO_REGISTER', - '#define OPS_TO_REGISTER', - 'constexpr inline bool ShouldRegisterOp(const char op[]) {', - ' return false', - ' || (strcmp(op, "BiasAdd") == 0)', - ' ;', - '}', - '#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)', - '', - 'const char kNecessaryOpKernelClasses[] = ","', - '"BiasOp<CPUDevice, float>,"', - ';', - '#define SHOULD_REGISTER_OP_KERNEL(clz)' - ' (strstr(kNecessaryOpKernelClasses, "," clz ",") != nullptr)', - '', - '#define SHOULD_REGISTER_OP_GRADIENT false', - '#endif', - ], header.split('\n')) + self.assertListEqual(expected.split('\n'), header.split('\n')) if __name__ == '__main__': diff --git a/tensorflow/python/tools/selective_registration_header_lib.py b/tensorflow/python/tools/selective_registration_header_lib.py index 1229ea2532..b297721aff 100644 --- a/tensorflow/python/tools/selective_registration_header_lib.py +++ b/tensorflow/python/tools/selective_registration_header_lib.py @@ -106,13 +106,42 @@ def get_header_from_ops_and_kernels(ops_and_kernels, append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)') append('') - line = 'const char kNecessaryOpKernelClasses[] = ","\n' + line = ''' + namespace { + constexpr const char* skip(const char* x) { + return (*x) ? (*x == ' ' ? skip(x + 1) : x) : x; + } + + constexpr bool isequal(const char* x, const char* y) { + return (*skip(x) && *skip(y)) + ? (*skip(x) == *skip(y) && isequal(skip(x) + 1, skip(y) + 1)) + : (!*skip(x) && !*skip(y)); + } + + template<int N> + struct find_in { + static constexpr bool f(const char* x, const char* const y[N]) { + return isequal(x, y[0]) || find_in<N - 1>::f(x, y + 1); + } + }; + + template<> + struct find_in<0> { + static constexpr bool f(const char* x, const char* const y[]) { + return false; + } + }; + } // end namespace + ''' + line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n' for _, kernel_class in ops_and_kernels: - line += '"%s,"\n' % kernel_class - line += ';' + line += '"%s",\n' % kernel_class + line += '};' append(line) append('#define SHOULD_REGISTER_OP_KERNEL(clz) ' - '(strstr(kNecessaryOpKernelClasses, "," clz ",") != nullptr)') + '(find_in<sizeof(kNecessaryOpKernelClasses) ' + '/ sizeof(*kNecessaryOpKernelClasses)>::f(clz, ' + 'kNecessaryOpKernelClasses))') append('') append('#define SHOULD_REGISTER_OP_GRADIENT ' + ( |