aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util/kernel_registry.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-29 06:28:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-29 07:35:19 -0700
commit783e8388db67f3fefc6c714d479615821a9dc5e7 (patch)
treea25e65ce1552ca593f7cf72d61b0f35bfd18613f /tensorflow/python/util/kernel_registry.cc
parent7a0f6252d7f93b8cc78e8cc61c88f45d98642b30 (diff)
Add python/tools/print_selective_registration_header utility for printing a
header for use with 'selective registration' (see core/framework/selective_registration.h), given an input model. Change: 134659033
Diffstat (limited to 'tensorflow/python/util/kernel_registry.cc')
-rw-r--r--tensorflow/python/util/kernel_registry.cc57
1 files changed, 57 insertions, 0 deletions
diff --git a/tensorflow/python/util/kernel_registry.cc b/tensorflow/python/util/kernel_registry.cc
new file mode 100644
index 0000000000..b05c2ef04b
--- /dev/null
+++ b/tensorflow/python/util/kernel_registry.cc
@@ -0,0 +1,57 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/python/util/kernel_registry.h"
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace swig {
+
+string TryFindKernelClass(const string& serialized_node_def) {
+ tensorflow::NodeDef node_def;
+ if (!node_def.ParseFromString(serialized_node_def)) {
+ LOG(WARNING) << "Error parsing node_def";
+ return "";
+ }
+
+ const tensorflow::OpRegistrationData* op_reg_data;
+ auto status =
+ tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data);
+ if (!status.ok()) {
+ LOG(WARNING) << "Op " << node_def.op() << " not found: " << status;
+ return "";
+ }
+ AddDefaultsToNodeDef(op_reg_data->op_def, &node_def);
+
+ tensorflow::DeviceNameUtils::ParsedName parsed_name;
+ if (!tensorflow::DeviceNameUtils::ParseFullName(node_def.device(),
+ &parsed_name)) {
+ LOG(WARNING) << "Failed to parse device from node_def: "
+ << node_def.ShortDebugString();
+ return "";
+ }
+ string class_name = "";
+ tensorflow::FindKernelDef(tensorflow::DeviceType(parsed_name.type.c_str()),
+ node_def, nullptr /* kernel_def */, &class_name);
+ return class_name;
+}
+
+} // namespace swig
+} // namespace tensorflow