aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/mutable_op_resolver.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/mutable_op_resolver.h')
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver.h79
1 files changed, 79 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/mutable_op_resolver.h b/tensorflow/contrib/lite/mutable_op_resolver.h
new file mode 100644
index 0000000000..c319041e9b
--- /dev/null
+++ b/tensorflow/contrib/lite/mutable_op_resolver.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
+
+#include <unordered_map>
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/util.h"
+
+namespace tflite {
+
+// Some versions of gcc doesn't support partial specialization in class scope,
+// so these are defined in a namescope.
+namespace op_resolver_hasher {
+template <typename V>
+struct ValueHasher {
+ size_t operator()(const V& v) const { return std::hash<V>()(v); }
+};
+
+template <>
+struct ValueHasher<tflite::BuiltinOperator> {
+ size_t operator()(const tflite::BuiltinOperator& v) const {
+ return std::hash<int>()(static_cast<int>(v));
+ }
+};
+
+template <typename T>
+struct OperatorKeyHasher {
+ size_t operator()(const T& x) const {
+ size_t a = ValueHasher<typename T::first_type>()(x.first);
+ size_t b = ValueHasher<typename T::second_type>()(x.second);
+ return CombineHashes({a, b});
+ }
+};
+} // namespace op_resolver_hasher
+
+// An OpResolver that is mutable, also used as the op in gen_op_registration.
+// A typical usage:
+// MutableOpResolver resolver;
+// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
+// InterpreterBuilder(model, resolver)(&interpreter);
+class MutableOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const override;
+ const TfLiteRegistration* FindOp(const char* op, int version) const override;
+ void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+ void AddCustom(const char* name, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+
+ private:
+ typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
+ typedef std::pair<std::string, int> CustomOperatorKey;
+
+ std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
+ op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
+ builtins_;
+ std::unordered_map<CustomOperatorKey, TfLiteRegistration,
+ op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
+ custom_ops_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_