diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-05 14:40:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-05 14:47:01 -0700 |
commit | 75390d4c3568358ea81a072b0ccc94071022c38d (patch) | |
tree | ee0b5c314e726471299f0018d59e504fc91457cf /tensorflow/python/tools | |
parent | a3c1ccd1da64040eeb139a0c6c1fc34ae46d7290 (diff) |
Special-case the AccumulateNV2 op in print_selective_registration_header
AccumulateNV2 doesn't have or need a kernel. It gets rewritten to other ops by
accumulate_n_optimizer.cc. This change allows it to be mentioned in the output
of print_selective_registration_header, rather than being ignored with a
warning. Behavior for other ops is preserved.
PiperOrigin-RevId: 211701878
Diffstat (limited to 'tensorflow/python/tools')
-rw-r--r-- | tensorflow/python/tools/print_selective_registration_header_test.py | 12 | ||||
-rw-r--r-- | tensorflow/python/tools/selective_registration_header_lib.py | 17 |
2 files changed, 27 insertions, 2 deletions
diff --git a/tensorflow/python/tools/print_selective_registration_header_test.py b/tensorflow/python/tools/print_selective_registration_header_test.py index 4b3d98242c..cce8060fb9 100644 --- a/tensorflow/python/tools/print_selective_registration_header_test.py +++ b/tensorflow/python/tools/print_selective_registration_header_test.py @@ -59,6 +59,9 @@ GRAPH_DEF_TXT = """ } """ +# AccumulateNV2 is included because it should be included in the header despite +# lacking a kernel (it's rewritten by AccumulateNV2RemovePass; see +# core/common_runtime/accumulate_n_optimizer.cc. GRAPH_DEF_TXT_2 = """ node: { name: "node_4" @@ -67,6 +70,12 @@ GRAPH_DEF_TXT_2 = """ device: "/cpu:0" attr: { key: "T" value: { type: DT_FLOAT } } } + node: { + name: "node_5" + op: "AccumulateNV2" + attr: { key: "T" value: { type: DT_INT32 } } + attr: { key : "N" value: { i: 3 } } + } """ @@ -100,6 +109,7 @@ class PrintOpFilegroupTest(test.TestCase): self.assertListEqual( [ + ('AccumulateNV2', None), # ('BiasAdd', 'BiasOp<CPUDevice, float>'), # ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, double, false >'), # @@ -117,6 +127,7 @@ class PrintOpFilegroupTest(test.TestCase): 'rawproto', self.WriteGraphFiles(graphs), default_ops) self.assertListEqual( [ + ('AccumulateNV2', None), # ('BiasAdd', 'BiasOp<CPUDevice, float>'), # ('MatMul', matmul_prefix + 'MatMulOp<CPUDevice, double, false >'), # @@ -196,6 +207,7 @@ class PrintOpFilegroupTest(test.TestCase): constexpr inline bool ShouldRegisterOp(const char op[]) { return false + || isequal(op, "AccumulateNV2") || isequal(op, "BiasAdd") ; } diff --git a/tensorflow/python/tools/selective_registration_header_lib.py b/tensorflow/python/tools/selective_registration_header_lib.py index dc0612bb3f..b99c632c3e 100644 --- a/tensorflow/python/tools/selective_registration_header_lib.py +++ b/tensorflow/python/tools/selective_registration_header_lib.py @@ -32,6 +32,16 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging +# Usually, we use each graph node to induce registration of an op and +# corresponding kernel; nodes without a corresponding kernel (perhaps due to +# attr types) generate a warning but are otherwise ignored. Ops in this set are +# registered even if there's no corresponding kernel. +OPS_WITHOUT_KERNEL_WHITELIST = frozenset([ + # AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see + # core/common_runtime/accumulate_n_optimizer.cc. + 'AccumulateNV2' +]) + def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str): """Gets the ops and kernels needed from the model files.""" @@ -53,8 +63,10 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str): node_def.device = '/cpu:0' kernel_class = pywrap_tensorflow.TryFindKernelClass( node_def.SerializeToString()) - if kernel_class: - op_and_kernel = (str(node_def.op), str(kernel_class.decode('utf-8'))) + op = str(node_def.op) + if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST: + op_and_kernel = (op, str(kernel_class.decode('utf-8')) + if kernel_class else None) if op_and_kernel not in ops: ops.add(op_and_kernel) else: @@ -129,6 +141,7 @@ def get_header_from_ops_and_kernels(ops_and_kernels, ''' line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n' for _, kernel_class in ops_and_kernels: + if kernel_class is None: continue line += '"%s",\n' % kernel_class line += '};' append(line) |