diff options
Diffstat (limited to 'tensorflow/python/tools/selective_registration_header_lib.py')
-rw-r--r-- | tensorflow/python/tools/selective_registration_header_lib.py | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/tensorflow/python/tools/selective_registration_header_lib.py b/tensorflow/python/tools/selective_registration_header_lib.py index dc0612bb3f..b99c632c3e 100644 --- a/tensorflow/python/tools/selective_registration_header_lib.py +++ b/tensorflow/python/tools/selective_registration_header_lib.py @@ -32,6 +32,16 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging +# Usually, we use each graph node to induce registration of an op and +# corresponding kernel; nodes without a corresponding kernel (perhaps due to +# attr types) generate a warning but are otherwise ignored. Ops in this set are +# registered even if there's no corresponding kernel. +OPS_WITHOUT_KERNEL_WHITELIST = frozenset([ + # AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see + # core/common_runtime/accumulate_n_optimizer.cc. + 'AccumulateNV2' +]) + def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str): """Gets the ops and kernels needed from the model files.""" @@ -53,8 +63,10 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str): node_def.device = '/cpu:0' kernel_class = pywrap_tensorflow.TryFindKernelClass( node_def.SerializeToString()) - if kernel_class: - op_and_kernel = (str(node_def.op), str(kernel_class.decode('utf-8'))) + op = str(node_def.op) + if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST: + op_and_kernel = (op, str(kernel_class.decode('utf-8')) + if kernel_class else None) if op_and_kernel not in ops: ops.add(op_and_kernel) else: @@ -129,6 +141,7 @@ def get_header_from_ops_and_kernels(ops_and_kernels, ''' line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n' for _, kernel_class in ops_and_kernels: + if kernel_class is None: continue line += '"%s",\n' % kernel_class line += '};' append(line) |