diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-09-29 06:28:56 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-29 07:35:19 -0700 |
commit | 783e8388db67f3fefc6c714d479615821a9dc5e7 (patch) | |
tree | a25e65ce1552ca593f7cf72d61b0f35bfd18613f /tensorflow/python/util/kernel_registry.cc | |
parent | 7a0f6252d7f93b8cc78e8cc61c88f45d98642b30 (diff) |
Add python/tools/print_selective_registration_header utility for printing a
header for use with 'selective registration' (see
core/framework/selective_registration.h), given an input model.
Change: 134659033
Diffstat (limited to 'tensorflow/python/util/kernel_registry.cc')
-rw-r--r-- | tensorflow/python/util/kernel_registry.cc | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/tensorflow/python/util/kernel_registry.cc b/tensorflow/python/util/kernel_registry.cc new file mode 100644 index 0000000000..b05c2ef04b --- /dev/null +++ b/tensorflow/python/util/kernel_registry.cc @@ -0,0 +1,57 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/python/util/kernel_registry.h" + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace swig { + +string TryFindKernelClass(const string& serialized_node_def) { + tensorflow::NodeDef node_def; + if (!node_def.ParseFromString(serialized_node_def)) { + LOG(WARNING) << "Error parsing node_def"; + return ""; + } + + const tensorflow::OpRegistrationData* op_reg_data; + auto status = + tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data); + if (!status.ok()) { + LOG(WARNING) << "Op " << node_def.op() << " not found: " << status; + return ""; + } + AddDefaultsToNodeDef(op_reg_data->op_def, &node_def); + + tensorflow::DeviceNameUtils::ParsedName parsed_name; + if (!tensorflow::DeviceNameUtils::ParseFullName(node_def.device(), + &parsed_name)) { + LOG(WARNING) << "Failed to parse device from node_def: " + << node_def.ShortDebugString(); + return ""; + } + string class_name = ""; + tensorflow::FindKernelDef(tensorflow::DeviceType(parsed_name.type.c_str()), + node_def, nullptr /* kernel_def */, &class_name); + return class_name; +} + +} // namespace swig +} // namespace tensorflow |