aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/lite/experimental/writer/BUILD66
-rw-r--r--tensorflow/contrib/lite/experimental/writer/enum_mapping.h116
-rw-r--r--tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc370
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer.cc41
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib.cc281
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib.h126
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc62
-rw-r--r--tensorflow/contrib/lite/schema/BUILD14
-rw-r--r--third_party/flatbuffers/BUILD.bazel1
-rw-r--r--third_party/flatbuffers/build_defs.bzl19
10 files changed, 1088 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/experimental/writer/BUILD b/tensorflow/contrib/lite/experimental/writer/BUILD
new file mode 100644
index 0000000000..82d39c00ab
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/BUILD
@@ -0,0 +1,66 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+cc_binary(
+ name = "option_writer_generator",
+ srcs = ["option_writer_generator.cc"],
+ deps = [
+ "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
+ "@flatbuffers",
+ ],
+)
+
+cc_library(
+ name = "writer_lib",
+ srcs = [
+ "enum_mapping.h",
+ "writer_lib.cc",
+ ],
+ hdrs = [
+ "writer_lib.h",
+ ],
+ data = [
+ ":option_writer_gen",
+ ],
+ textual_hdrs = ["option_writer_generated.h"],
+ deps = [
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
+ ],
+)
+
+cc_binary(
+ name = "writer",
+ srcs = ["writer.cc"],
+ deps = [
+ ":writer_lib",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ],
+)
+
+cc_test(
+ name = "writer_lib_test",
+ size = "small",
+ srcs = ["writer_lib_test.cc"],
+ deps = [
+ ":writer_lib",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+genrule(
+ name = "option_writer_gen",
+ outs = ["option_writer_generated.h"],
+ cmd = "$(location :option_writer_generator) $(@)",
+ tools = [":option_writer_generator"],
+)
diff --git a/tensorflow/contrib/lite/experimental/writer/enum_mapping.h b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h
new file mode 100644
index 0000000000..8bc464fd71
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h
@@ -0,0 +1,116 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+
+// TODO(aselle): Ideally extract this from the schema.
+
+namespace tflite {
+
+inline ActivationFunctionType TfLiteActivationToSchemaActivation(
+ TfLiteFusedActivation act) {
+ switch (act) {
+ case kTfLiteActNone:
+ return ActivationFunctionType_NONE;
+ case kTfLiteActRelu:
+ return ActivationFunctionType_RELU;
+ case kTfLiteActRelu1:
+ return ActivationFunctionType_RELU_N1_TO_1;
+ case kTfLiteActRelu6:
+ return ActivationFunctionType_RELU6;
+ case kTfLiteActTanh:
+ return ActivationFunctionType_TANH;
+ case kTfLiteActSignBit:
+ return ActivationFunctionType_SIGN_BIT;
+ case kTfLiteActSigmoid:
+ return ActivationFunctionType_NONE; // TODO(aselle): Add to schema
+ }
+ return ActivationFunctionType_NONE;
+}
+
+inline Padding TfLitePaddingToSchemaPadding(TfLitePadding padding) {
+ switch (padding) {
+ case kTfLitePaddingUnknown:
+ return Padding_SAME; // TODO(aselle): Consider an error.
+ case kTfLitePaddingSame:
+ return Padding_SAME;
+ case kTfLitePaddingValid:
+ return Padding_VALID;
+ }
+ return Padding_SAME; // TODO(aselle): Consider an error.
+}
+
+inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
+ switch (type) {
+ // case kTfLiteNoType: return TensorType_NONE;
+ case kTfLiteNoType:
+ return TensorType_FLOAT32; // TODO(aselle): Consider an error.
+ case kTfLiteFloat32:
+ return TensorType_FLOAT32;
+ case kTfLiteInt32:
+ return TensorType_INT32;
+ case kTfLiteUInt8:
+ return TensorType_UINT8;
+ case kTfLiteInt64:
+ return TensorType_INT64;
+ case kTfLiteString:
+ return TensorType_STRING;
+ case kTfLiteBool:
+ return TensorType_BOOL;
+ case kTfLiteInt16:
+ return TensorType_INT16;
+ case kTfLiteComplex64:
+ return TensorType_COMPLEX64;
+ }
+ // TODO(aselle): consider an error
+}
+
+inline FullyConnectedOptionsWeightsFormat
+FullyConnectedOptionsWeightsFormatToSchema(
+ TfLiteFullyConnectedWeightsFormat format) {
+ switch (format) {
+ case kTfLiteFullyConnectedWeightsFormatDefault:
+ return FullyConnectedOptionsWeightsFormat_DEFAULT;
+ case kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8:
+ return FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
+ }
+}
+
+inline LSTMKernelType LSTMKernelTypeToSchema(TfLiteLSTMKernelType type) {
+ switch (type) {
+ case kTfLiteLSTMFullKernel:
+ return LSTMKernelType_FULL;
+ case kTfLiteLSTMBasicKernel:
+ return LSTMKernelType_BASIC;
+ }
+}
+
+inline LSHProjectionType LSHProjectionTypeToSchema(
+ TfLiteLSHProjectionType type) {
+ switch (type) {
+ case kTfLiteLshProjectionUnknown:
+ return LSHProjectionType_UNKNOWN;
+ case kTfLiteLshProjectionSparse:
+ return LSHProjectionType_SPARSE;
+ case kTfLiteLshProjectionDense:
+ return LSHProjectionType_DENSE;
+ }
+}
+
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
new file mode 100644
index 0000000000..e6d5a776b3
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
@@ -0,0 +1,370 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <ctype.h>
+#include <iostream>
+#include <unordered_map>
+#include <unordered_set>
+#include "flatbuffers/minireflect.h" // flatbuffers
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+
+namespace tflite {
+namespace {
+// This is generated by grepping
+// cat third_party/tensorflow/contrib/lite/builtin_op_data.h
+//| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}"
+static const char* param_structs[] = {"TfLiteConvParams",
+ "TfLitePoolParams",
+ "TfLiteDepthwiseConvParams",
+ "TfLiteSVDFParams",
+ "TfLiteRNNParams",
+ "TfLiteSequenceRNNParams",
+ "TfLiteFullyConnectedParams",
+ "TfLiteLSHProjectionParams",
+ "TfLiteSoftmaxParams",
+ "TfLiteConcatenationParams",
+ "TfLiteAddParams",
+ "TfLiteSpaceToBatchNDParams",
+ "TfLiteBatchToSpaceNDParams",
+ "TfLiteMulParams",
+ "TfLiteSubParams",
+ "TfLiteDivParams",
+ "TfLiteL2NormParams",
+ "TfLiteLocalResponseNormParams",
+ "TfLiteLSTMParams",
+ "TfLiteResizeBilinearParams",
+ "TfLitePadParams",
+ "TfLitePadV2Params",
+ "TfLiteReshapeParams",
+ "TfLiteSkipGramParams",
+ "TfLiteSpaceToDepthParams",
+ "TfLiteCastParams",
+ "TfLiteEmbeddingLookupSparseParams",
+ "TfLiteGatherParams",
+ "TfLiteTransposeParams",
+ "TfLiteReducerParams",
+ "TfLiteSplitParams",
+ "TfLiteSqueezeParams",
+ "TfLiteStridedSliceParams",
+ "TfLiteArgMaxParams",
+ "TfLiteArgMinParams",
+ "TfLiteTransposeConvParams",
+ "TfLiteSparseToDenseParams",
+ "TfLiteShapeParams",
+ "TfLiteFakeQuantParams",
+ "TfLitePackParams",
+ "TfLiteOneHotParams",
+ nullptr};
+} // namespace
+
+// Get rid of all underscores and make everything lower case to make name
+// matching work for stuff like 3D vs 3d or RNN vs Rnn.
+std::string ToCollapsed(const std::string& in) {
+ const char* s = in.c_str();
+ bool first = true;
+ std::string out;
+ while (*s != '\0') {
+ if (*s == '_') {
+ first = true;
+ } else if (first) {
+ out.push_back(tolower(*s));
+ first = false;
+ } else {
+ out.push_back(tolower(*s));
+ }
+ s++;
+ }
+ return out;
+}
+
+// A collection of information about builtin ops.
+class OpOptionData {
+ public:
+ OpOptionData() {
+ BuildOpList();
+ BuildOptionToTypeFunctionMap();
+ BuildOpToOptionMap();
+ }
+
+ // A list of builtin operations
+ const std::vector<std::string>& ops() const { return ops_; }
+ // Maps from operation name to option name (i.e. 'ADD' to 'AddOptions')
+ const std::unordered_map<std::string, std::string>& op_to_option() {
+ return op_to_option_;
+ }
+ // Maps from option to to C struct i.e. 'AddOptions' -> 'TfLiteAddOptions'
+ const std::unordered_map<std::string, std::string>& option_to_struct() {
+ return option_to_struct_;
+ }
+ // Maps from option to a flatbuffer type function that describes that option.
+ const std::unordered_map<std::string, flatbuffers::TypeFunction>&
+ option_to_type_function() {
+ return option_to_type_function_;
+ }
+
+ private:
+ void BuildOpList() {
+ for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr;
+ ++curr) {
+ if (strlen(*curr) != 0) ops_.push_back(*curr);
+ }
+ }
+
+ void BuildOptionToTypeFunctionMap() {
+ auto d = tflite::BuiltinOptionsTypeTable();
+ for (int i = 0; i < d->num_elems; i++) {
+ flatbuffers::TypeCode code = d->type_codes[i];
+ if (code.sequence_ref != -1) {
+ option_to_type_function_.insert(
+ std::make_pair(d->names[i], d->type_refs[code.sequence_ref]));
+ }
+ }
+ }
+
+ void BuildOpToOptionMap() {
+ // Manually specified mappings between ops and options
+ op_to_option_["REDUCE_MAX"] = "ReducerOptions";
+ op_to_option_["REDUCE_MIN"] = "ReducerOptions";
+ op_to_option_["REDUCE_ANY"] = "ReducerOptions";
+ op_to_option_["UNPACK"] = "";
+ op_to_option_["SUM"] = "ReducerOptions";
+ op_to_option_["REDUCE_MAX"] = "ReducerOptions";
+ op_to_option_["REDUCE_PROD"] = "ReducerOptions";
+ op_to_option_["MEAN"] = "ReducerOptions";
+ op_to_option_["L2_POOL_2D"] = "Pool2DOptions";
+ op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions";
+ op_to_option_["MAX_POOL_2D"] = "Pool2DOptions";
+ op_to_option_["L2_NORMALIZATION"] = "L2NormOptions";
+ op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
+ op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
+ op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+ op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+ op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+ // Manually specified mappings between ops and options (none)
+ op_to_option_["EMBEDDING_LOOKUP"] =
+ ""; // TODO(aselle): maybe something else.
+ op_to_option_["FLOOR"] = "";
+ op_to_option_["HASHTABLE_LOOKUP"] =
+ ""; // TODO(aselle): maybe something else.
+ op_to_option_["LOGISTIC"] = "";
+ op_to_option_["RELU"] = "";
+ op_to_option_["RELU_N1_TO_1"] = "";
+ op_to_option_["RELU6"] = "";
+ op_to_option_["TANH"] = "";
+ op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else.
+ op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else.
+ op_to_option_["PRELU"] = "";
+ op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
+ op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
+ op_to_option_["SIN"] = "";
+ op_to_option_["LOG"] = "";
+ op_to_option_["SQRT"] = "";
+ op_to_option_["RSQRT"] = "";
+
+ // TODO(aselle): These are undesirable hacks. Consider changing C structs
+ option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
+ option_to_struct_["Conv2DOptions"] = "TfLiteConvParams";
+ option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams";
+ option_to_struct_["LocalResponseNormalizationOptions"] =
+ "TfLiteLocalResponseNormParams";
+ // Now for every op, try to find an option.
+ bool fatal = false;
+ for (auto op_name : ops_) {
+ bool found_option = false;
+ auto d = tflite::BuiltinOptionsTypeTable();
+ std::string collapsed_option_name_guess =
+ ToCollapsed(op_name) + "options";
+ // O(n^2) but not that big of n.
+ for (int i = 0; i < d->num_elems; i++) {
+ std::string option_name = d->names[i];
+ std::string collapsed_option_name = ToCollapsed(option_name);
+ if (collapsed_option_name_guess == collapsed_option_name) {
+ op_to_option_.insert(std::make_pair(op_name, option_name));
+ found_option = true;
+ break;
+ }
+ }
+ auto it = op_to_option_.find(op_name);
+ if (it == op_to_option_.end()) {
+ std::cerr << "Didn't find option for " << op_name << std::endl;
+ fatal = true;
+ } else if (!it->second.empty()) {
+ std::string option_name = it->second;
+
+ if (option_to_struct_.find(option_name) == option_to_struct_.end()) {
+ bool param_struct_found = false;
+ std::string params_guess = std::string("TfLite") + option_name;
+ size_t start = params_guess.find("Options");
+ size_t len = strlen("Options");
+ params_guess.replace(start, len, "Params");
+ for (auto* param = param_structs; *param != nullptr; param++) {
+ if (*param == params_guess) {
+ param_struct_found = true;
+ break;
+ }
+ }
+ if (!param_struct_found) {
+ std::cerr << "Failed to get param struct for option " << option_name
+ << std::endl;
+ fatal = true;
+ } else {
+ option_to_struct_.insert(std::make_pair(option_name, params_guess));
+ }
+ }
+ }
+ }
+ }
+
+ private:
+ std::vector<std::string> ops_;
+ std::unordered_map<std::string, std::string> op_to_option_;
+ std::unordered_map<std::string, std::string> option_to_struct_;
+ std::unordered_map<std::string, flatbuffers::TypeFunction>
+ option_to_type_function_;
+};
+
+void GenerateImportForOp(FILE* fp, const std::string& op_name,
+ const std::string& option_name,
+ const std::string& option_type,
+ const flatbuffers::TypeTable* options,
+ const std::string& struct_name) {
+ // Skip tricky ones for now
+ if (struct_name == "TfLiteResizeBilinearParams") return;
+ if (struct_name == "TfLiteSqueezeParams") return;
+ if (struct_name == "TfLiteEmbeddingLookupSparseParams") return;
+ if (struct_name == "TfLiteReshapeParams") return;
+
+ fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
+ fprintf(fp,
+ " const auto* params = reinterpret_cast<const "
+ "%s*>(builtin_op_data);\n",
+ struct_name.c_str());
+
+ for (size_t i = 0; i < options->num_elems; i++) {
+ std::string elem_name = options->names[i];
+ // TODO(aselle): Irregular naming in builtins
+ if (elem_name == "fused_activation_function")
+ elem_name = "activation";
+ else if (elem_name == "stride_w")
+ elem_name = "stride_width";
+ else if (elem_name == "stride_h")
+ elem_name = "stride_height";
+ else if (elem_name == "dilation_h_factor")
+ elem_name = "dilation_height_factor";
+ else if (elem_name == "dilation_w_factor")
+ elem_name = "dilation_width_factor";
+ else if (elem_name == "new_shape")
+ elem_name = "shape";
+
+ flatbuffers::TypeCode code = options->type_codes[i];
+ auto contained_type = code.sequence_ref != -1
+ ? options->type_refs[code.sequence_ref]
+ : nullptr;
+ std::string mapper = "";
+ if (contained_type == TensorTypeTypeTable) {
+ mapper = "TfLiteTypeToSchemaType";
+ } else if (contained_type == ActivationFunctionTypeTypeTable) {
+ mapper = "TfLiteActivationToSchemaActivation";
+ } else if (contained_type == PaddingTypeTable) {
+ mapper = "TfLitePaddingToSchemaPadding";
+ } else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) {
+ mapper = "FullyConnectedOptionsWeightsFormatToSchema";
+ } else if (contained_type == LSTMKernelTypeTypeTable) {
+ mapper = "LSTMKernelTypeToSchema";
+ } else if (contained_type == LSHProjectionTypeTypeTable) {
+ mapper = "LSHProjectionTypeToSchema";
+ }
+
+ fprintf(fp,
+ " auto val%zu = "
+ "%s(params->%s);\n",
+ i, mapper.c_str(), elem_name.c_str());
+ }
+ fprintf(fp, " auto union_type = Create%s(*fbb", option_name.c_str());
+ for (size_t i = 0; i < options->num_elems; i++) {
+ fprintf(fp, ", val%zu", i);
+ }
+ fprintf(fp, ").Union();\n");
+ fprintf(fp, " return std::make_pair(%s, union_type);\n",
+ option_type.c_str());
+ fprintf(fp, " }\n break;\n");
+}
+
+void GenerateImport(OpOptionData* option, FILE* fp) {
+ std::unordered_set<std::string> ignores;
+ ignores.insert("CONCAT_EMBEDDINGS");
+ ignores.insert("CALL");
+
+ // Allow any op that doesn't have an options struct to be blocked
+ // together
+ for (const auto& op_name : option->ops()) {
+ auto option_it = option->op_to_option().find(op_name);
+ if (!option_it->second.empty() && ignores.find(op_name) == ignores.end())
+ continue;
+ fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
+ }
+ fprintf(fp,
+ " return std::make_pair(BuiltinOptions_NONE, "
+ "flatbuffers::Offset<void>());\n break;\n");
+
+ // Iterate over each ops
+ for (const auto& op_name : option->ops()) {
+ if (ignores.find(op_name) != ignores.end()) continue;
+ // Get to the option and struct names, continuing if not found.
+ auto option_it = option->op_to_option().find(op_name);
+ if (option_it->second.empty()) continue;
+ std::string option_name = option_it->second;
+ std::string option_type = "BuiltinOptions_" + option_name;
+ auto option_func_it = option->option_to_type_function().find(option_name);
+ if (option_func_it == option->option_to_type_function().end()) continue;
+ auto struct_name_it = option->option_to_struct().find(option_name);
+ if (struct_name_it == option->option_to_struct().end()) {
+ // If no C struct, then it better have no arguments.
+ auto type_info = option_func_it->second();
+ if (type_info->num_elems != 0) {
+ // We have non-zero arguments in the schema, this means there
+ // should be a struct.
+ fprintf(stderr,
+ "Op %s uses option struct %s which has no builtin struct\n",
+ op_name.c_str(), option_name.c_str());
+ exit(1);
+ }
+ fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
+ fprintf(fp, " return std::make_pair(%s, Create%s(*fbb).Union());",
+ option_type.c_str(), option_name.c_str());
+ } else {
+ // If C struct, then we need to assign all properties
+ auto struct_name = struct_name_it->second;
+ GenerateImportForOp(fp, op_name, option_name, option_type,
+ option_func_it->second(), struct_name);
+ }
+ }
+ // TODO(aselle): Handle unhandled cases more gracefully.
+ fprintf(fp,
+ "default: return std::make_pair(BuiltinOptions_NONE, "
+ "flatbuffers::Offset<void>());\n break;\n");
+}
+
+} // namespace tflite
+
+int main(int argc, char* argv[]) {
+ tflite::OpOptionData option;
+ if (argc != 2) {
+ fprintf(stderr, "Usage: %s <fname out>\n", argv[0]);
+ return 1;
+ }
+ FILE* fp = fopen(argv[1], "w");
+ tflite::GenerateImport(&option, fp);
+ fclose(fp);
+}
diff --git a/tensorflow/contrib/lite/experimental/writer/writer.cc b/tensorflow/contrib/lite/experimental/writer/writer.cc
new file mode 100644
index 0000000000..20ede214fb
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Just does a read/write loop of tflite file format using the interpreter as
+// an intermediate.
+//
+// Usage:
+// writer <input tflite> <output tflite>
+
+#include <iostream>
+
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+
+int main(int argc, char* argv[]) {
+ if (argc != 3) {
+ fprintf(stderr, "Usage: %s input_file output_file\n", argv[0]);
+ return 1;
+ }
+ std::unique_ptr<tflite::FlatBufferModel> model =
+ tflite::FlatBufferModel::BuildFromFile(argv[1]);
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
+ tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
+ tflite::InterpreterWriter writer(interpreter.get());
+ writer.Write(argv[2]);
+
+ return 0;
+}
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
new file mode 100644
index 0000000000..52b17faf82
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
@@ -0,0 +1,281 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include <cstdlib>
+#include <cstring>
+#include <unordered_map>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+template <class T>
+using Offset = flatbuffers::Offset<T>;
+template <class T>
+using Vector = flatbuffers::Vector<T>;
+using FlatBufferBuilder = flatbuffers::FlatBufferBuilder;
+
+std::pair<BuiltinOptions, Offset<void>> CreateBuiltinUnion(
+ FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) {
+ switch (op) {
+#include "tensorflow/contrib/lite/experimental/writer/option_writer_generated.h"
+ }
+ return std::make_pair(BuiltinOptions_NONE, Offset<void>());
+}
+
+template <class T_OUTPUT, class T_INPUT>
+Offset<Vector<T_OUTPUT>> InterpreterWriter::ExportVector(FlatBufferBuilder* fbb,
+ const T_INPUT& v) {
+ std::vector<T_OUTPUT> inputs(v.begin(), v.end());
+ return fbb->template CreateVector<T_OUTPUT>(inputs);
+}
+
+Offset<Vector<Offset<Operator>>> InterpreterWriter::ExportOperators(
+ FlatBufferBuilder* fbb) {
+ std::vector<Offset<Operator>> operators;
+
+ std::vector<int> operator_to_opcode;
+ // TODO(aselle): Augment this once we put execution plan in schema.
+ operator_to_opcode.resize(interpreter_->nodes_size(), -1);
+ for (int op_index : interpreter_->execution_plan()) {
+ const auto* node_and_registration =
+ interpreter_->node_and_registration(op_index);
+ const TfLiteRegistration* registration = &node_and_registration->second;
+ if (!registration->custom_name) {
+ operator_to_opcode[op_index] =
+ GetOpCodeForBuiltin(registration->builtin_code);
+ } else {
+ operator_to_opcode[op_index] =
+ GetOpCodeForCustom(registration->custom_name);
+ }
+ }
+ // second pass serialize operators
+ for (int op_index : interpreter_->execution_plan()) {
+ const auto* node_and_registration =
+ interpreter_->node_and_registration(op_index);
+ const TfLiteNode& node = node_and_registration->first;
+ const TfLiteRegistration& registration = node_and_registration->second;
+ Offset<void> builtin_options;
+ BuiltinOptions builtin_options_type = BuiltinOptions_NONE;
+ // Custom data
+ // TODO(aselle): Custom options format is not known by default. Just assume
+ // for now.
+ auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
+ Offset<Vector<uint8_t>> custom_options = 0;
+
+ if (!registration.custom_name) {
+ // builtin
+ auto builtin_options_and_type = CreateBuiltinUnion(
+ fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
+ node.builtin_data);
+ builtin_options = builtin_options_and_type.second;
+ builtin_options_type = builtin_options_and_type.first;
+ } else {
+ auto custom_writer = custom_op_to_writer_.find(registration.custom_name);
+ if (custom_writer != custom_op_to_writer_.end() &&
+ custom_writer->second) {
+ // delegate to custom writer if it exists
+ custom_writer->second(fbb, interpreter_, op_index, &custom_options,
+ &custom_options_format);
+ } else {
+ // use the custom data as fact
+ custom_options = fbb->CreateVector(
+ reinterpret_cast<const uint8_t*>(node.custom_initial_data),
+ node.custom_initial_data_size);
+ }
+ }
+
+ int opcode_index = operator_to_opcode[op_index];
+ std::vector<int> written_inputs =
+ RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs));
+ std::vector<int> written_outputs =
+ RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs));
+ auto inputs = ExportVector<int32_t>(fbb, written_inputs);
+ auto outputs = ExportVector<int32_t>(fbb, written_outputs);
+ operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs,
+ builtin_options_type, builtin_options,
+ custom_options, custom_options_format));
+ }
+
+ return fbb->template CreateVector<Offset<Operator>>(operators);
+}
+
+Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors(
+ FlatBufferBuilder* fbb) {
+ tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1);
+
+ std::vector<Offset<Tensor>> tensors;
+
+ // Make a map from tensor index to whether the tensor is a temporary.
+ std::vector<bool> tensor_is_temporary(interpreter_->tensors_size(), false);
+ for (int op_index = 0; op_index < interpreter_->nodes_size(); ++op_index) {
+ const auto* node_and_registration =
+ interpreter_->node_and_registration(op_index);
+ for (auto tensor_index :
+ TfLiteIntArrayView(node_and_registration->first.temporaries))
+ tensor_is_temporary[tensor_index] = true;
+ }
+
+ // Now we need to remap all used tensor indices
+ int curr_output_index = 0;
+ for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
+ tensor_index++) {
+ if (!tensor_is_temporary[tensor_index]) {
+ tensor_to_written_tensor_[tensor_index] = curr_output_index++;
+ }
+ }
+
+ for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
+ ++tensor_index) {
+ // Skip temporaries.
+ if (tensor_is_temporary[tensor_index]) continue;
+
+ if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) {
+ // We only need to convert non temporaries
+ if (tensor->allocation_type != kTfLiteArenaRw &&
+ tensor->allocation_type != kTfLiteMmapRo &&
+ tensor->allocation_type != kTfLiteArenaRwPersistent)
+ continue;
+ // Allocate a buffer index
+ int buffer_index = 0; // This is null
+ if (tensor->allocation_type == kTfLiteMmapRo) {
+ buffer_index = buffers_.size();
+ buffers_.push_back(std::make_pair(
+ reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
+ }
+ // Primitive type.
+ TensorType type = TfLiteTypeToSchemaType(tensor->type);
+ // Handle quantization
+ const Offset<Vector<float>> null_array;
+ Offset<Vector<float>> scale_array;
+ Offset<Vector<int64_t>> zero_point_array;
+ if (tensor->params.scale != 0.f) {
+ // We have quantization, make a single arugment array (multi channel
+ // quant needs updating here).
+ scale_array = fbb->CreateVector<float>({tensor->params.scale});
+ zero_point_array =
+ fbb->CreateVector<int64_t>({tensor->params.zero_point});
+ }
+ Offset<QuantizationParameters> quantization_params =
+ CreateQuantizationParameters(*fbb, null_array, null_array,
+ scale_array, zero_point_array);
+ // Shape
+ TfLiteIntArrayView shape_view(tensor->dims);
+ std::vector<int> shape =
+ std::vector<int>(shape_view.begin(), shape_view.end());
+
+ tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape),
+ type, buffer_index,
+ fbb->CreateString(tensor->name),
+ quantization_params, tensor->is_variable));
+ }
+ }
+ return fbb->template CreateVector<Offset<Tensor>>(tensors);
+}
+
+Offset<Vector<Offset<Buffer>>> InterpreterWriter::ExportBuffers(
+ FlatBufferBuilder* fbb) {
+ std::vector<Offset<Buffer>> buffer_vector;
+ for (auto buffer : buffers_) {
+ auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
+ buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
+ }
+ return fbb->template CreateVector<Offset<Buffer>>(buffer_vector);
+}
+
+Offset<Vector<Offset<OperatorCode>>> InterpreterWriter::CreateOpCodeTable(
+ FlatBufferBuilder* fbb) {
+ std::vector<Offset<OperatorCode>> codes;
+ for (auto it : opcodes_) {
+ const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
+ codes.push_back(CreateOperatorCodeDirect(
+ *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
+ }
+ return fbb->template CreateVector<Offset<OperatorCode>>(codes);
+}
+
+template <class T>
+std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten(
+ const T& input) {
+ std::vector<int> output;
+ output.reserve(input.size());
+ for (int x : input) {
+ output.push_back(tensor_to_written_tensor_[x]);
+ }
+ return output;
+}
+
+TfLiteStatus InterpreterWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
+ size_t* size) {
+ if (!out || !size) return kTfLiteError;
+ FlatBufferBuilder builder(/*initial_size=*/10240);
+
+ std::vector<Offset<SubGraph>> subgraphs_as_vector;
+ { // subgraph specific stuff
+ auto tensors = ExportTensors(&builder);
+ std::vector<int> written_inputs =
+ RemapTensorIndicesToWritten(interpreter_->inputs());
+ std::vector<int> written_outputs =
+ RemapTensorIndicesToWritten(interpreter_->outputs());
+ auto inputs = ExportVector<int32_t>(&builder, written_inputs);
+ auto outputs = ExportVector<int32_t>(&builder, written_outputs);
+
+ auto ops = ExportOperators(&builder);
+ subgraphs_as_vector.push_back(
+ CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0));
+ }
+ Offset<Vector<Offset<Buffer>>> buffers = ExportBuffers(&builder);
+
+ auto description = builder.CreateString("Exported from Interpreter.");
+
+ auto op_codes = CreateOpCodeTable(&builder);
+ auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
+ builder.CreateVector(subgraphs_as_vector),
+ description, buffers);
+ ::tflite::FinishModelBuffer(builder, model);
+ const uint8_t* buffer = builder.GetBufferPointer();
+ *size = builder.GetSize();
+ (*out).reset(new uint8_t[*size]);
+ memcpy(out->get(), buffer, *size);
+ return kTfLiteOk;
+}
+
+TfLiteStatus InterpreterWriter::Write(const std::string& filename) {
+ std::unique_ptr<uint8_t[]> buffer;
+ size_t size;
+ TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
+
+ FILE* fp = fopen(filename.c_str(), "wb");
+ if (!fp) return kTfLiteError;
+
+ if (fwrite(buffer.get(), 1, size, fp) != size) return kTfLiteError;
+ if (fclose(fp)) return kTfLiteError;
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus InterpreterWriter::RegisterCustomWriter(
+ const std::string& custom_name, CustomWriter custom_writer) {
+ if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) {
+ return kTfLiteError;
+ }
+ custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer));
+ return kTfLiteOk;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
new file mode 100644
index 0000000000..a98108b496
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
@@ -0,0 +1,126 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Writes a flatbuffer of a currently loaded TensorFlow Lite interpreter.
+//
+// Usage:
+// From command line:
+// bazel run third_party/tensorflow/contrib/lite/experimental/writer:writer
+// -- foo.tflite foo.out.tflite
+//
+// From C++
+// std::unique_ptr<Interpreter> interpreter;
+// // Build Interpreter however
+// // ... <omitted>
+// InterpreterWriter(interpreter.get()).Write("output.tflite");
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
+#include <iostream>
+#include <unordered_map>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+
+// Handles writing TensorFlow Lite running interpreter to a serialized TF lite
+// file format.
+class InterpreterWriter {
+ public:
+ typedef flatbuffers::Offset<Operator> (*CustomWriter)(
+ flatbuffers::FlatBufferBuilder* fbb, Interpreter* interpreter,
+ int node_index,
+ flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
+ CustomOptionsFormat* custom_options_format);
+
+ // Construct an interpreter writer for the specified `interpreter`. Then,
+ // a uses .Write() or .GetBuffer(...) to extract the data.
+ explicit InterpreterWriter(Interpreter* interpreter)
+ : interpreter_(interpreter) {
+ buffers_.push_back(std::make_pair(nullptr, 0));
+ }
+
+ // Get a buffer and size of a serialized flatbuffer.
+ TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
+ // Write the serialized flatbuffer to the prescribed `filename`.
+ TfLiteStatus Write(const std::string& filename);
+ // Registers a custom writer for a custom op. The customization allows the
+ // caller to change the custom data.
+ TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
+ CustomWriter custom_writer);
+
+ private:
+ template <class T>
+ using Offset = flatbuffers::Offset<T>;
+ template <class T_OUTPUT, class T_INPUT>
+ Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector(
+ flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v);
+ Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors(
+ flatbuffers::FlatBufferBuilder* fbb);
+ Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators(
+ flatbuffers::FlatBufferBuilder* fbb);
+ Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
+ flatbuffers::FlatBufferBuilder* fbb);
+ Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
+ flatbuffers::FlatBufferBuilder* fbb);
+
+ template <class T>
+ std::vector<int> RemapTensorIndicesToWritten(const T& input);
+
+ int GetOpCodeForBuiltin(int builtin_op_index) {
+ // auto it = builtin_op_to_opcode_.find(builtin_op_index);
+ std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result =
+ builtin_op_to_opcode_.insert(
+ std::make_pair(builtin_op_index, opcodes_.size()));
+ if (result.second) {
+ opcodes_.push_back({builtin_op_index, ""});
+ }
+ return result.first->second;
+ }
+
+ int GetOpCodeForCustom(const std::string& custom_name) {
+ std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
+ custom_op_to_opcode_.insert(
+ std::make_pair(custom_name, opcodes_.size()));
+ if (result.second) {
+ opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name});
+ }
+ return result.first->second;
+ }
+
+ // The interpreter we are writing
+ Interpreter* interpreter_;
+ // Keep track of byte buffers
+ std::vector<std::pair<const uint8_t*, size_t>> buffers_;
+ // List of op codes and mappings from builtin or custom op to opcode
+ struct OpCode {
+ int builtin;
+ std::string custom;
+ };
+ // For every tensor index in the interpreter, the index in the written.
+ // This is different due to temporary tensors not being written.
+ std::vector<int> tensor_to_written_tensor_;
+ // List of used opcodes
+ std::vector<OpCode> opcodes_;
+ std::unordered_map<int, int> builtin_op_to_opcode_;
+ std::unordered_map<std::string, int> custom_op_to_opcode_;
+ std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc
new file mode 100644
index 0000000000..49194a76c8
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+// Make an interpreter that has no tensors and no nodes
+// TODO(b/113731921): add more tests.
+TEST(Writer, BasicTest) {
+ Interpreter interpreter;
+ interpreter.AddTensors(3);
+ float foo[] = {1, 2, 3};
+ interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
+ TfLiteQuantizationParams());
+ interpreter.SetTensorParametersReadOnly(
+ 1, kTfLiteFloat32, "b", {3}, TfLiteQuantizationParams(),
+ reinterpret_cast<char*>(foo), sizeof(foo));
+ interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
+ TfLiteQuantizationParams());
+ interpreter.SetInputs({0, 1});
+ interpreter.SetOutputs({2});
+ const char* initial_data = "";
+ tflite::ops::builtin::BuiltinOpResolver resolver;
+ TfLiteAddParams* builtin_data =
+ reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
+ builtin_data->activation = kTfLiteActNone;
+ const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
+ interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
+ reinterpret_cast<void*>(builtin_data), reg);
+
+ InterpreterWriter writer(&interpreter);
+ writer.Write("/tmp/test.tflite");
+ std::unique_ptr<FlatBufferModel> model =
+ FlatBufferModel::BuildFromFile("/tmp/test.tflite");
+ InterpreterBuilder builder(*model, resolver);
+ std::unique_ptr<Interpreter> new_interpreter;
+ builder(&new_interpreter);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD
index 28a7e50003..55bf2c48b9 100644
--- a/tensorflow/contrib/lite/schema/BUILD
+++ b/tensorflow/contrib/lite/schema/BUILD
@@ -56,6 +56,20 @@ flatbuffer_cc_library(
srcs = ["schema.fbs"],
)
+# Generic schema for inference on device (but with reflections makes bigger).
+flatbuffer_cc_library(
+ name = "schema_fbs_with_reflection",
+ srcs = ["schema.fbs"],
+ flatc_args = [
+ "--reflect-types",
+ "--reflect-names",
+ "--no-union-value-namespacing",
+ "--gen-object-api",
+ ],
+ gen_reflections = True,
+ out_prefix = "reflection/",
+)
+
# Schema test to make sure we don't introduce backward incompatible changes
# to schemas.
cc_test(
diff --git a/third_party/flatbuffers/BUILD.bazel b/third_party/flatbuffers/BUILD.bazel
index 9d233a30d6..934c0d9650 100644
--- a/third_party/flatbuffers/BUILD.bazel
+++ b/third_party/flatbuffers/BUILD.bazel
@@ -142,6 +142,7 @@ filegroup(
srcs = [
"include/flatbuffers/base.h",
"include/flatbuffers/flatbuffers.h",
+ "include/flatbuffers/minireflect.h",
"include/flatbuffers/stl_emulation.h",
"include/flatbuffers/util.h",
],
diff --git a/third_party/flatbuffers/build_defs.bzl b/third_party/flatbuffers/build_defs.bzl
index 2f25156668..235b44f7cf 100644
--- a/third_party/flatbuffers/build_defs.bzl
+++ b/third_party/flatbuffers/build_defs.bzl
@@ -92,14 +92,17 @@ def flatbuffer_library_public(
cmd = reflection_genrule_cmd,
message = "Generating flatbuffer reflection binary for %s:" % (name),
)
- native.Fileset(
- name = reflection_name,
- out = "%s_out" % reflection_name,
- entries = [
- native.FilesetEntry(files = reflection_outs),
- ],
- visibility = reflection_visiblity,
- )
+ # TODO(b/114456773): Make bazel rules proper and supported by flatbuffer
+ # Have to comment this since FilesetEntry is not supported in bazel
+ # skylark.
+ # native.Fileset(
+ # name = reflection_name,
+ # out = "%s_out" % reflection_name,
+ # entries = [
+ # native.FilesetEntry(files = reflection_outs),
+ # ],
+ # visibility = reflection_visiblity,
+ # )
def flatbuffer_cc_library(
name,