aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools/selective_registration_header_lib.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/tools/selective_registration_header_lib.py')
-rw-r--r--tensorflow/python/tools/selective_registration_header_lib.py17
1 files changed, 15 insertions, 2 deletions
diff --git a/tensorflow/python/tools/selective_registration_header_lib.py b/tensorflow/python/tools/selective_registration_header_lib.py
index dc0612bb3f..b99c632c3e 100644
--- a/tensorflow/python/tools/selective_registration_header_lib.py
+++ b/tensorflow/python/tools/selective_registration_header_lib.py
@@ -32,6 +32,16 @@ from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging
+# Usually, we use each graph node to induce registration of an op and
+# corresponding kernel; nodes without a corresponding kernel (perhaps due to
+# attr types) generate a warning but are otherwise ignored. Ops in this set are
+# registered even if there's no corresponding kernel.
+OPS_WITHOUT_KERNEL_WHITELIST = frozenset([
+ # AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see
+ # core/common_runtime/accumulate_n_optimizer.cc.
+ 'AccumulateNV2'
+])
+
def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
"""Gets the ops and kernels needed from the model files."""
@@ -53,8 +63,10 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str):
node_def.device = '/cpu:0'
kernel_class = pywrap_tensorflow.TryFindKernelClass(
node_def.SerializeToString())
- if kernel_class:
- op_and_kernel = (str(node_def.op), str(kernel_class.decode('utf-8')))
+ op = str(node_def.op)
+ if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST:
+ op_and_kernel = (op, str(kernel_class.decode('utf-8'))
+ if kernel_class else None)
if op_and_kernel not in ops:
ops.add(op_and_kernel)
else:
@@ -129,6 +141,7 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
'''
line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n'
for _, kernel_class in ops_and_kernels:
+ if kernel_class is None: continue
line += '"%s",\n' % kernel_class
line += '};'
append(line)