aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/tools/print_selective_registration_header_test.py14
-rw-r--r--tensorflow/python/tools/selective_registration_header_lib.py18
2 files changed, 16 insertions, 16 deletions
diff --git a/tensorflow/python/tools/print_selective_registration_header_test.py b/tensorflow/python/tools/print_selective_registration_header_test.py
index fe20df5924..36978b0860 100644
--- a/tensorflow/python/tools/print_selective_registration_header_test.py
+++ b/tensorflow/python/tools/print_selective_registration_header_test.py
@@ -156,13 +156,6 @@ class PrintOpFilegroupTest(test.TestCase):
expected = '''// This file was autogenerated by %s
#ifndef OPS_TO_REGISTER
#define OPS_TO_REGISTER
-constexpr inline bool ShouldRegisterOp(const char op[]) {
- return false
- || (strcmp(op, "BiasAdd") == 0)
- ;
-}
-#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)
-
namespace {
constexpr const char* skip(const char* x) {
@@ -194,6 +187,13 @@ constexpr inline bool ShouldRegisterOp(const char op[]) {
};
#define SHOULD_REGISTER_OP_KERNEL(clz) (find_in<sizeof(kNecessaryOpKernelClasses) / sizeof(*kNecessaryOpKernelClasses)>::f(clz, kNecessaryOpKernelClasses))
+constexpr inline bool ShouldRegisterOp(const char op[]) {
+ return false
+ || isequal(op, "BiasAdd")
+ ;
+}
+#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)
+
#define SHOULD_REGISTER_OP_GRADIENT false
#endif''' % self.script_name
diff --git a/tensorflow/python/tools/selective_registration_header_lib.py b/tensorflow/python/tools/selective_registration_header_lib.py
index 7be61ca379..7f7470994d 100644
--- a/tensorflow/python/tools/selective_registration_header_lib.py
+++ b/tensorflow/python/tools/selective_registration_header_lib.py
@@ -100,15 +100,6 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
append('#define SHOULD_REGISTER_OP_KERNEL(clz) true')
append('#define SHOULD_REGISTER_OP_GRADIENT true')
else:
- append('constexpr inline bool ShouldRegisterOp(const char op[]) {')
- append(' return false')
- for op in sorted(ops):
- append(' || (strcmp(op, "%s") == 0)' % op)
- append(' ;')
- append('}')
- append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)')
- append('')
-
line = '''
namespace {
constexpr const char* skip(const char* x) {
@@ -147,6 +138,15 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
'kNecessaryOpKernelClasses))')
append('')
+ append('constexpr inline bool ShouldRegisterOp(const char op[]) {')
+ append(' return false')
+ for op in sorted(ops):
+ append(' || isequal(op, "%s")' % op)
+ append(' ;')
+ append('}')
+ append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)')
+ append('')
+
append('#define SHOULD_REGISTER_OP_GRADIENT ' + (
'true' if 'SymbolicGradient' in ops else 'false'))