aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/proto_text
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-04-14 19:15:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-14 20:22:35 -0700
commitdf15baa9b10a0b2d194181dff7ee14bff70d9b8f (patch)
treea66e1fcb309a7404a2749140d8bac991a6524ac0 /tensorflow/tools/proto_text
parent104fe2822b419c4154d11c401ffd4a3a6e8f24c6 (diff)
Add tools/proto_text for generating ProtoDebugString,
ProtoShortDebugString, and ProtoParseFromString methods from protos. This will allow changing code used on mobile to use the proto LITE_RUNTIME, to reduce code size. This change is only for the tool itself. A future change will add a better genrule and use it the generated code in tensorflow. Change: 119919087
Diffstat (limited to 'tensorflow/tools/proto_text')
-rw-r--r--tensorflow/tools/proto_text/BUILD90
-rw-r--r--tensorflow/tools/proto_text/gen_proto_text_functions.cc117
-rw-r--r--tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc779
-rw-r--r--tensorflow/tools/proto_text/gen_proto_text_functions_lib.h42
-rw-r--r--tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc500
-rw-r--r--tensorflow/tools/proto_text/test.proto107
6 files changed, 1635 insertions, 0 deletions
diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD
new file mode 100644
index 0000000000..c677721672
--- /dev/null
+++ b/tensorflow/tools/proto_text/BUILD
@@ -0,0 +1,90 @@
+# Description:
+# This package provides build-time generation of proto3 text format functions
+# (ProtoDebugString, ProtoShortDebugString, and ProtoParseFromString) which
+# provide equivalent functionality as proto.DebugString, proto.ShortDebugString,
+# and TextFormat parsing, but can be used with protos generated with
+# LITE_RUNTIME.
+#
+# Note that proto3 well-known types (e.g. Any) are not handled in a special way
+# by the generated code.
+
+package(default_visibility = ["//visibility:private"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+# For platform specific build config
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_proto_library_cc",
+)
+
+cc_binary(
+ name = "gen_proto_text_functions",
+ srcs = ["gen_proto_text_functions.cc"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":gen_proto_text_functions_lib",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "gen_proto_text_functions_lib",
+ srcs = ["gen_proto_text_functions_lib.cc"],
+ hdrs = ["gen_proto_text_functions_lib.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_proto_library_cc(
+ name = "test_proto",
+ srcs = ["test.proto"],
+)
+
+genrule(
+ name = "test_proto_text_srcs",
+ srcs = ["test.proto"],
+ outs = [
+ "test.pb_text-impl.h",
+ "test.pb_text.h",
+ "test.pb_text.cc",
+ ],
+ cmd = "$(location :gen_proto_text_functions) " +
+ "$(@D) $(location test.proto)",
+ tools = [":gen_proto_text_functions"],
+)
+
+cc_test(
+ name = "gen_proto_text_functions_lib_test",
+ size = "small",
+ srcs = [
+ "gen_proto_text_functions_lib_test.cc",
+ ":test_proto_text_srcs",
+ ],
+ deps = [
+ ":gen_proto_text_functions_lib",
+ ":test_proto_cc",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+# Google-internal targets. These must be at the end for syncrepo.
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions.cc b/tensorflow/tools/proto_text/gen_proto_text_functions.cc
new file mode 100644
index 0000000000..604ae432cf
--- /dev/null
+++ b/tensorflow/tools/proto_text/gen_proto_text_functions.cc
@@ -0,0 +1,117 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdio.h>
+#include <set>
+
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h"
+
+namespace tensorflow {
+
+namespace {
+class CrashOnErrorCollector
+ : public tensorflow::protobuf::compiler::MultiFileErrorCollector {
+ public:
+ ~CrashOnErrorCollector() override {}
+
+ void AddError(const string& filename, int line, int column,
+ const string& message) override {
+ LOG(FATAL) << "Unexpected error at " << filename << "@" << line << ":"
+ << column << " - " << message;
+ }
+};
+} // namespace
+
+static const char kTensorflowHeaderPrefix[] = "";
+
+// Main program to take input protos and write output pb_text source files that
+// contain generated proto text input and output functions.
+//
+// Main expects the first argument to give the output path. This is followed by
+// pairs of arguments: <proto_name_relative_to_root, proto_file_path>.
+//
+// Note that this code doesn't use tensorflow's command line parsing, because of
+// circular dependencies between libraries if that were done.
+//
+// This is meant to be invoked by a genrule. See BUILD for more information.
+int MainImpl(int argc, char** argv) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ if (argc < 3) {
+ LOG(ERROR) << "Pass output path and at least proto file";
+ return -1;
+ }
+
+ const string output_root = argv[1];
+
+ tensorflow::protobuf::compiler::DiskSourceTree source_tree;
+
+ // This requires all protos to be relative to the directory from which the
+ // genrule is invoked. If protos are generated in some other directory,
+ // then they may not be found.
+ source_tree.MapPath("", ".");
+ CrashOnErrorCollector crash_on_error;
+ tensorflow::protobuf::compiler::Importer importer(&source_tree,
+ &crash_on_error);
+
+ for (int i = 2; i < argc; i++) {
+ const string proto_path = argv[i];
+ const tensorflow::protobuf::FileDescriptor* fd =
+ importer.Import(proto_path);
+
+ string proto_name = proto_path;
+ int index = proto_name.find_last_of("/");
+ if (index != string::npos) proto_name = proto_name.substr(index + 1);
+ index = proto_name.find_last_of(".");
+ if (index != string::npos) proto_name = proto_name.substr(0, index);
+
+ const auto code =
+ tensorflow::GetProtoTextFunctionCode(*fd, kTensorflowHeaderPrefix);
+
+ // Three passes, one for each output file.
+ for (int pass = 0; pass < 3; ++pass) {
+ string suffix;
+ string data;
+ if (pass == 0) {
+ suffix = ".pb_text.h";
+ data = code.header;
+ } else if (pass == 1) {
+ suffix = ".pb_text-impl.h";
+ data = code.header_impl;
+ } else {
+ suffix = ".pb_text.cc";
+ data = code.cc;
+ }
+
+ const string path = output_root + "/" + proto_name + suffix;
+ FILE* f = fopen(path.c_str(), "w");
+ if (fwrite(data.c_str(), 1, data.size(), f) != data.size()) {
+ return -1;
+ }
+ if (fclose(f) != 0) {
+ return -1;
+ }
+ }
+ }
+ return 0;
+}
+
+} // namespace tensorflow
+
+int main(int argc, char** argv) { return tensorflow::MainImpl(argc, argv); }
diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc
new file mode 100644
index 0000000000..a8e3437296
--- /dev/null
+++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc
@@ -0,0 +1,779 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h"
+
+#include <algorithm>
+#include <set>
+
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+using ::tensorflow::protobuf::Descriptor;
+using ::tensorflow::protobuf::EnumDescriptor;
+using ::tensorflow::protobuf::FieldDescriptor;
+using ::tensorflow::protobuf::FieldOptions;
+using ::tensorflow::protobuf::FileDescriptor;
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::strings::StrCat;
+
+namespace tensorflow {
+
+namespace {
+
+// Class used to generate the code for proto text functions. One of these should
+// be created for each FileDescriptor whose code should be generated.
+//
+// This class has a notion of the current output Section. The Print, Nested,
+// and Unnest functions apply their operations to the current output section,
+// which can be toggled with SetOutput.
+//
+// Note that on the generated code, various pieces are not optimized - for
+// example: map input and output, Cord input and output, comparisons against
+// the field names (it's a loop over all names), and tracking of has_seen.
+//
+// The generated API has, for enums and messages defined in the proto file:
+// 1. For each message:
+// * ProtoDebugString(m): same as msg.DebugString()
+// * ProtoShortDebugString(m): same as msg.ShorDebugString()
+// * ProtoParseFromString(s, m): same as TextFormat.ParseFromString(s, &m);
+// 2. For each enum:
+// * EnumName_<EnumTypeName>(enum_value): same as <EnumTypeName>(enum_value)
+// in proto.
+class Generator {
+ public:
+ Generator(const string& tf_header_prefix)
+ : tf_header_prefix_(tf_header_prefix),
+ header_(&code_.header),
+ header_impl_(&code_.header_impl),
+ cc_(&code_.cc) {}
+
+ void Generate(const FileDescriptor& fd);
+
+ // The generated code; valid after Generate has been called.
+ ProtoTextFunctionCode code() const { return code_; }
+
+ private:
+ struct Section {
+ Section(string* str) : str(str) {}
+ string* str;
+ string indent;
+ };
+
+ // Switches the currently active section to <section>.
+ Generator& SetOutput(Section* section) {
+ cur_ = section;
+ return *this;
+ }
+
+ // Increases indent level. Returns <*this>, to allow chaining.
+ Generator& Nest() {
+ StrAppend(&cur_->indent, " ");
+ return *this;
+ }
+
+ // Decreases indent level. Returns <*this>, to allow chaining.
+ Generator& Unnest() {
+ cur_->indent = cur_->indent.substr(0, cur_->indent.size() - 2);
+ return *this;
+ }
+
+ // Appends the concatenated args, with a trailing newline. Returns <*this>, to
+ // allow chaining.
+ template <typename... Args>
+ Generator& Print(Args... args) {
+ StrAppend(cur_->str, cur_->indent, args..., "\n");
+ return *this;
+ }
+
+ // Appends the print code for a single field's value.
+ // If <omit_default> is true, then the emitted code will not print zero-valued
+ // values.
+ // <field_expr> is code that when emitted yields the field's value.
+ void AppendFieldValueAppend(const FieldDescriptor& field,
+ const bool omit_default,
+ const string& field_expr);
+
+ // Appends the print code for as single field.
+ void AppendFieldAppend(const FieldDescriptor& field);
+
+ // Appends the print code for a message. May change which section is currently
+ // active.
+ void AppendDebugStringFunctions(const Descriptor& md);
+
+ // Appends the print and parse functions for an enum. May change which
+ // section is currently active.
+ void AppendEnumFunctions(const EnumDescriptor& enum_d);
+
+ // Appends the parse functions for a message. May change which section is
+ // currently active.
+ void AppendParseMessageFunction(const Descriptor& md);
+
+ // Appends all functions for a message and its nested message and enum types.
+ // May change which section is currently active.
+ void AppendMessageFunctions(const Descriptor& md);
+
+ // Appends lines to open or close namespace declarations.
+ void AddNamespaceToCurrentSection(const string& package, bool open);
+
+ // Appends the given headers as sorted #include lines.
+ void AddHeadersToCurrentSection(const std::vector<string>& headers);
+
+ // When adding #includes for tensorflow headers, prefix them with this.
+ const string tf_header_prefix_;
+ ProtoTextFunctionCode code_;
+ Section* cur_ = nullptr;
+ Section header_;
+ Section header_impl_;
+ Section cc_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Generator);
+};
+
+// Returns the prefix needed to reference objects defined in <fd>. E.g.
+// "::tensorflow::test".
+string GetPackageReferencePrefix(const FileDescriptor* fd) {
+ return StrCat("::", str_util::Join(str_util::Split(fd->package(), '.'), "::"),
+ "::");
+}
+
+// Returns the name of the class generated by proto to represent <d>.
+string GetClassName(const Descriptor& d) {
+ if (d.containing_type() == nullptr) return d.name();
+ return StrCat(GetClassName(*d.containing_type()), "_", d.name());
+}
+
+// Returns the name of the class generated by proto to represent <ed>.
+string GetClassName(const EnumDescriptor& ed) {
+ if (ed.containing_type() == nullptr) return ed.name();
+ return StrCat(GetClassName(*ed.containing_type()), "_", ed.name());
+}
+
+// Returns the qualified name that refers to the class generated by proto to
+// represent <d>.
+string GetQualifiedName(const Descriptor& d) {
+ return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d));
+}
+
+// Returns the qualified name that refers to the class generated by proto to
+// represent <ed>.
+string GetQualifiedName(const EnumDescriptor& d) {
+ return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d));
+}
+
+// Returns the qualified name that refers to the generated
+// AppendProtoDebugString function for <d>.
+string GetQualifiedAppendFn(const Descriptor& d) {
+ return StrCat(GetPackageReferencePrefix(d.file()),
+ "internal::AppendProtoDebugString");
+}
+
+// Returns the name of the generated function that returns an enum value's
+// string value.
+string GetEnumNameFn(const EnumDescriptor& enum_d) {
+ return StrCat("EnumName_", GetClassName(enum_d));
+}
+
+// Returns the qualified name of the function returned by GetEnumNameFn().
+string GetQualifiedEnumNameFn(const EnumDescriptor& enum_d) {
+ return StrCat(GetPackageReferencePrefix(enum_d.file()),
+ GetEnumNameFn(enum_d));
+}
+
+// Returns the name of a generated header file, either the public api (if impl
+// is false) or the internal implementation header (if impl is true).
+string GetProtoTextHeaderName(const FileDescriptor& fd, bool impl) {
+ const int dot_index = fd.name().find_last_of('.');
+ return fd.name().substr(0, dot_index) +
+ (impl ? ".pb_text-impl.h" : ".pb_text.h");
+}
+
+// Returns the name of the header generated by the proto library for <fd>.
+string GetProtoHeaderName(const FileDescriptor& fd) {
+ const int dot_index = fd.name().find_last_of('.');
+ return fd.name().substr(0, dot_index) + ".pb.h";
+}
+
+// Returns the C++ class name for the given proto field.
+string GetCppClass(const FieldDescriptor& d) {
+ return d.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
+ ? GetQualifiedName(*d.message_type())
+ : d.cpp_type_name();
+}
+
+// Returns the string that can be used for a header guard for the generated
+// headers for <fd>, either for the public api (if impl is false) or the
+// internal implementation header (if impl is true).
+string GetHeaderGuard(const FileDescriptor& fd, bool impl) {
+ string s = fd.name();
+ std::replace(s.begin(), s.end(), '/', '_');
+ std::replace(s.begin(), s.end(), '.', '_');
+ return s + (impl ? "_IMPL_H_" : "_H_");
+}
+
+void Generator::AppendFieldValueAppend(const FieldDescriptor& field,
+ const bool omit_default,
+ const string& field_expr) {
+ SetOutput(&cc_);
+ switch (field.cpp_type()) {
+ case FieldDescriptor::CPPTYPE_INT32:
+ case FieldDescriptor::CPPTYPE_INT64:
+ case FieldDescriptor::CPPTYPE_UINT32:
+ case FieldDescriptor::CPPTYPE_UINT64:
+ case FieldDescriptor::CPPTYPE_DOUBLE:
+ case FieldDescriptor::CPPTYPE_FLOAT:
+ Print("o->", omit_default ? "AppendNumericIfNotZero" : "AppendNumeric",
+ "(\"", field.name(), "\", ", field_expr, ");");
+ break;
+ case FieldDescriptor::CPPTYPE_BOOL:
+ Print("o->", omit_default ? "AppendBoolIfTrue" : "AppendBool", "(\"",
+ field.name(), "\", ", field_expr, ");");
+ break;
+ case FieldDescriptor::CPPTYPE_STRING: {
+ const auto ctype = field.options().ctype();
+ CHECK(ctype == FieldOptions::CORD || ctype == FieldOptions::STRING)
+ << "Unsupported ctype " << ctype;
+
+ string str_expr = field_expr;
+ if (ctype == FieldOptions::CORD) {
+ str_expr = StrCat("(", field_expr, ").ToString()");
+ }
+ Print("o->", omit_default ? "AppendStringIfNotEmpty" : "AppendString",
+ "(\"", field.name(), "\", ", str_expr, ");");
+ break;
+ }
+ case FieldDescriptor::CPPTYPE_ENUM:
+ if (omit_default) {
+ Print("if (", field_expr, " != 0) {").Nest();
+ }
+ Print("o->AppendEnumName(\"", field.name(), "\", ",
+ GetQualifiedEnumNameFn(*field.enum_type()), "(", field_expr, "));");
+ if (omit_default) {
+ Unnest().Print("}");
+ }
+ break;
+ case FieldDescriptor::CPPTYPE_MESSAGE:
+ CHECK(!field.message_type()->options().map_entry());
+ if (omit_default) {
+ Print("if (msg.has_", field.name(), "()) {").Nest();
+ }
+ Print("o->OpenNestedMessage(\"", field.name(), "\");");
+ Print(GetQualifiedAppendFn(*field.message_type()), "(o, ", field_expr,
+ ");");
+ Print("o->CloseNestedMessage();");
+ if (omit_default) {
+ Unnest().Print("}");
+ }
+ break;
+ }
+}
+
+void Generator::AppendFieldAppend(const FieldDescriptor& field) {
+ const string name = field.name();
+
+ if (field.is_map()) {
+ Print("{").Nest();
+ const auto& key_type = *field.message_type()->FindFieldByName("key");
+ const auto& value_type = *field.message_type()->FindFieldByName("value");
+
+ Print("std::vector<", key_type.cpp_type_name(), "> keys;");
+ Print("for (const auto& e : msg.", name, "()) keys.push_back(e.first);");
+ Print("std::stable_sort(keys.begin(), keys.end());");
+ Print("for (const auto& key : keys) {").Nest();
+ Print("o->OpenNestedMessage(\"", name, "\");");
+ AppendFieldValueAppend(key_type, false /* omit_default */, "key");
+ AppendFieldValueAppend(value_type, false /* omit_default */,
+ StrCat("msg.", name, "().at(key)"));
+ Print("o->CloseNestedMessage();");
+ Unnest().Print("}");
+
+ Unnest().Print("}");
+ } else if (field.is_repeated()) {
+ Print("for (int i = 0; i < msg.", name, "_size(); ++i) {");
+ Nest();
+ AppendFieldValueAppend(field, false /* omit_default */,
+ "msg." + name + "(i)");
+ Unnest().Print("}");
+ } else {
+ AppendFieldValueAppend(field, true /* omit_default */,
+ "msg." + name + "()");
+ }
+}
+
+void Generator::AppendEnumFunctions(const EnumDescriptor& enum_d) {
+ const string sig = StrCat("const char* ", GetEnumNameFn(enum_d), "(\n ",
+ GetQualifiedName(enum_d), " value)");
+ SetOutput(&header_);
+ Print().Print("// Enum text output for ", string(enum_d.full_name()));
+ Print(sig, ";");
+
+ SetOutput(&cc_);
+ Print().Print(sig, " {");
+ Nest().Print("switch (value) {").Nest();
+ for (int i = 0; i < enum_d.value_count(); ++i) {
+ const auto& value = *enum_d.value(i);
+ Print("case ", value.number(), ": return \"", value.name(), "\";");
+ }
+ Print("default: return \"\";");
+ Unnest().Print("}");
+ Unnest().Print("}");
+}
+
+void Generator::AppendParseMessageFunction(const Descriptor& md) {
+ const bool map_append = (md.options().map_entry());
+ string sig;
+ if (!map_append) {
+ sig = StrCat("bool ProtoParseFromString(\n const string& s,\n ",
+ GetQualifiedName(md), "* msg)");
+ SetOutput(&header_).Print(sig, "\n TF_MUST_USE_RESULT;");
+
+ SetOutput(&cc_);
+ Print().Print(sig, " {").Nest();
+ Print("msg->Clear();");
+ Print("Scanner scanner(s);");
+ Print("if (!internal::ProtoParseFromScanner(",
+ "&scanner, false, false, msg)) return false;");
+ Print("scanner.Eos();");
+ Print("return scanner.GetResult();");
+ Unnest().Print("}");
+ }
+
+ // Parse from scanner - the real work here.
+ sig = StrCat(map_append ? "inline " : "", "bool ProtoParseFromScanner(",
+ "\n ::tensorflow::strings::Scanner* scanner, bool nested, "
+ "bool close_curly,\n ");
+ const FieldDescriptor* key_type = nullptr;
+ const FieldDescriptor* value_type = nullptr;
+ if (map_append) {
+ key_type = md.FindFieldByName("key");
+ value_type = md.FindFieldByName("value");
+ StrAppend(&sig, "::tensorflow::protobuf::Map<", GetCppClass(*key_type),
+ ", ", GetCppClass(*value_type), ">* map)");
+ } else {
+ StrAppend(&sig, GetQualifiedName(md), "* msg)");
+ }
+ SetOutput(&header_impl_).Print(sig, ";");
+
+ SetOutput(&cc_);
+ Print().Print("namespace internal {");
+ Print().Print(sig, " {").Nest();
+ if (map_append) {
+ Print(GetCppClass(*key_type), " map_key;");
+ Print("bool set_map_key = false;");
+ Print(GetCppClass(*value_type), " map_value;");
+ Print("bool set_map_value = false;");
+ }
+ Print("std::vector<bool> has_seen(", md.field_count(), ", false);");
+ Print("while(true) {").Nest();
+ Print("ProtoSpaceAndComments(scanner);");
+
+ // Emit success case
+ Print("if (nested && (scanner->Peek() == (close_curly ? '}' : '>'))) {")
+ .Nest();
+ Print("scanner->One(Scanner::ALL);");
+ Print("ProtoSpaceAndComments(scanner);");
+ if (map_append) {
+ Print("if (!set_map_key || !set_map_value) return false;");
+ Print("(*map)[map_key] = map_value;");
+ }
+ Print("return true;");
+ Unnest().Print("}");
+
+ Print("if (!nested && scanner->empty()) { return true; }");
+ Print("scanner->RestartCapture()");
+ Print(" .Many(Scanner::LETTER_DIGIT_UNDERSCORE)");
+ Print(" .StopCapture();");
+ Print("StringPiece identifier;");
+ Print("if (!scanner->GetResult(nullptr, &identifier)) return false;");
+ Print("bool parsed_colon = false;");
+ Print("ProtoSpaceAndComments(scanner);");
+ Print("if (scanner->Peek() == ':') {");
+ Nest().Print("parsed_colon = true;");
+ Print("scanner->One(Scanner::ALL);");
+ Print("ProtoSpaceAndComments(scanner);");
+ Unnest().Print("}");
+ for (int i = 0; i < md.field_count(); ++i) {
+ const FieldDescriptor* field = md.field(i);
+ const string field_name = field->name();
+ string mutable_value_expr;
+ string set_value_prefix;
+ if (map_append) {
+ mutable_value_expr = StrCat("&map_", field_name);
+ set_value_prefix = StrCat("map_", field_name, " = ");
+ } else if (field->is_repeated()) {
+ if (field->is_map()) {
+ mutable_value_expr = StrCat("msg->mutable_", field_name, "()");
+ set_value_prefix =
+ "UNREACHABLE"; // generator will never use this value.
+ } else {
+ mutable_value_expr = StrCat("msg->add_", field_name, "()");
+ set_value_prefix = StrCat("msg->add_", field_name);
+ }
+ } else {
+ mutable_value_expr = StrCat("msg->mutable_", field_name, "()");
+ set_value_prefix = StrCat("msg->set_", field_name);
+ }
+
+ Print(i == 0 ? "" : "else ", "if (identifier == \"", field_name, "\") {");
+ Nest();
+
+ if (field->is_repeated()) {
+ CHECK(!map_append);
+
+ // Check to see if this is an array assignment, like a: [1, 2, 3]
+ Print("const bool is_list = (scanner->Peek() == '[');");
+ Print("do {");
+ // [ or , // skip
+ Nest().Print("if (is_list) {");
+ Nest().Print("scanner->One(Scanner::ALL);");
+ Print("ProtoSpaceAndComments(scanner);");
+ Unnest().Print("}");
+ } else if (field->containing_oneof() != nullptr) {
+ CHECK(!map_append);
+
+ // Detect duplicate oneof value.
+ const string oneof_name = field->containing_oneof()->name();
+ Print("if (msg->", oneof_name, "_case() != 0) return false;");
+ }
+
+ if (!field->is_repeated() && !map_append) {
+ // Detect duplicate nested repeated message.
+ Print("if (has_seen[", i, "]) return false;");
+ Print("has_seen[", i, "] = true;");
+ }
+ if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ Print("const char open_char = scanner->Peek();");
+ Print("if (open_char != '{' && open_char != '<') return false;");
+ Print("scanner->One(Scanner::ALL);");
+ Print("ProtoSpaceAndComments(scanner);");
+ Print("if (!", GetPackageReferencePrefix(field->message_type()->file()),
+ "internal::ProtoParseFromScanner(");
+ Print(" scanner, true, open_char == '{', ", mutable_value_expr,
+ ")) return false;");
+ } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_STRING) {
+ if (field->options().ctype() == FieldOptions::CORD) {
+ Print("string str_value;");
+ Print(
+ "if (!parsed_colon || "
+ "!::tensorflow::strings::ProtoParseStringLiteralFromScanner(");
+ Print(" scanner, &str_value)) return false;");
+ Print(mutable_value_expr, "->CopyFrom(str_value);");
+ } else {
+ Print(
+ "if (!parsed_colon || "
+ "!::tensorflow::strings::ProtoParseStringLiteralFromScanner(");
+ Print(" scanner, ", mutable_value_expr, ")) return false;");
+ }
+ } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
+ Print("StringPiece value;");
+ Print(
+ "if (!parsed_colon || "
+ "!scanner->RestartCapture().Many("
+ "Scanner::LETTER_DIGIT_DASH_UNDERSCORE)."
+ "GetResult(nullptr, &value)) return false;");
+ const auto* enum_d = field->enum_type();
+ string value_prefix;
+ if (enum_d->containing_type() == nullptr) {
+ value_prefix = GetPackageReferencePrefix(enum_d->file());
+ } else {
+ value_prefix = StrCat(GetQualifiedName(*enum_d), "_");
+ }
+
+ for (int enum_i = 0; enum_i < enum_d->value_count(); ++enum_i) {
+ const auto* value_d = enum_d->value(enum_i);
+ const string value_name = value_d->name();
+ string condition = StrCat("value == \"", value_name,
+ "\" || value == \"", value_d->number(), "\"");
+ if (value_d->number() == 0) {
+ StrAppend(&condition, " || value == \"-0\"");
+ }
+
+ Print(enum_i == 0 ? "" : "} else ", "if (", condition, ") {");
+ Nest();
+ Print(set_value_prefix, "(", value_prefix, value_name, ");");
+ Unnest();
+ }
+ Print("} else {").Nest().Print("return false;").Unnest().Print("}");
+ } else {
+ Print(field->cpp_type_name(), " value;");
+ switch (field->cpp_type()) {
+ case FieldDescriptor::CPPTYPE_INT32:
+ case FieldDescriptor::CPPTYPE_INT64:
+ case FieldDescriptor::CPPTYPE_UINT32:
+ case FieldDescriptor::CPPTYPE_UINT64:
+ case FieldDescriptor::CPPTYPE_DOUBLE:
+ case FieldDescriptor::CPPTYPE_FLOAT:
+ Print(
+ "if (!parsed_colon || "
+ "!::tensorflow::strings::ProtoParseNumericFromScanner(",
+ "scanner, &value)) return false;");
+ break;
+ case FieldDescriptor::CPPTYPE_BOOL:
+ Print(
+ "if (!parsed_colon || "
+ "!::tensorflow::strings::ProtoParseBoolFromScanner(",
+ "scanner, &value)) return false;");
+ break;
+ default:
+ LOG(FATAL) << "handled earlier";
+ }
+ Print(set_value_prefix, "(value);");
+ }
+
+ if (field->is_repeated()) {
+ Unnest().Print("} while (is_list && scanner->Peek() == ',');");
+ Print(
+ "if (is_list && "
+ "!scanner->OneLiteral(\"]\").GetResult()) return false;");
+ }
+ if (map_append) {
+ Print("set_map_", field_name, " = true;");
+ }
+ Unnest().Print("}");
+ }
+ Unnest().Print("}");
+ Unnest().Print("}");
+ Unnest().Print().Print("} // namespace internal");
+}
+
+void Generator::AppendDebugStringFunctions(const Descriptor& md) {
+ SetOutput(&header_impl_).Print();
+ SetOutput(&header_).Print().Print("// Message-text conversion for ",
+ string(md.full_name()));
+
+ // Append the two debug string functions for <md>.
+ for (int short_pass = 0; short_pass < 2; ++short_pass) {
+ const bool short_debug = (short_pass == 1);
+
+ // Make the Get functions.
+ const string sig = StrCat(
+ "string ", short_debug ? "ProtoShortDebugString" : "ProtoDebugString",
+ "(\n const ", GetQualifiedName(md), "& msg)");
+ SetOutput(&header_).Print(sig, ";");
+
+ SetOutput(&cc_);
+ Print().Print(sig, " {").Nest();
+ Print("string s;");
+ Print("::tensorflow::strings::ProtoTextOutput o(&s, ",
+ short_debug ? "true" : "false", ");");
+ Print("internal::AppendProtoDebugString(&o, msg);");
+ Print("o.CloseTopMessage();");
+ Print("return s;");
+ Unnest().Print("}");
+ }
+
+ // Make the Append function.
+ const string sig =
+ StrCat("void AppendProtoDebugString(\n",
+ " ::tensorflow::strings::ProtoTextOutput* o,\n const ",
+ GetQualifiedName(md), "& msg)");
+ SetOutput(&header_impl_).Print(sig, ";");
+ SetOutput(&cc_);
+ Print().Print("namespace internal {").Print();
+ Print(sig, " {").Nest();
+ for (int i = 0; i < md.field_count(); ++i) {
+ SetOutput(&cc_);
+ AppendFieldAppend(*md.field(i));
+ }
+ Unnest().Print("}").Print().Print("} // namespace internal");
+}
+
+void Generator::AppendMessageFunctions(const Descriptor& md) {
+ if (md.options().map_entry()) {
+ // The 'map entry' Message is not a user-visible message type. Only its
+ // parse function is created (and that actually parsed the whole Map, not
+ // just the map entry). Printing of a map is done in the code generated for
+ // the containing message.
+ AppendParseMessageFunction(md);
+ return;
+ }
+
+ AppendDebugStringFunctions(md);
+ AppendParseMessageFunction(md);
+ for (int i = 0; i < md.enum_type_count(); ++i) {
+ AppendEnumFunctions(*md.enum_type(i));
+ }
+ for (int i = 0; i < md.nested_type_count(); ++i) {
+ AppendMessageFunctions(*md.nested_type(i));
+ }
+}
+
+void Generator::AddNamespaceToCurrentSection(const string& package, bool open) {
+ Print();
+ const std::vector<string> parts = str_util::Split(package, '.');
+ if (open) {
+ for (const auto& p : parts) {
+ Print("namespace ", p, " {");
+ }
+ } else {
+ for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
+ Print("} // namespace ", *it);
+ }
+ }
+}
+
+void Generator::AddHeadersToCurrentSection(const std::vector<string>& headers) {
+ std::vector<string> sorted = headers;
+ std::sort(sorted.begin(), sorted.end());
+ for (const auto& h : sorted) {
+ Print("#include \"", h, "\"");
+ }
+}
+
+// Adds to <all_fd> and <all_d> with all descriptors recursively
+// reachable from the given descriptor.
+void GetAllFileDescriptorsFromFile(const FileDescriptor* fd,
+ std::set<const FileDescriptor*>* all_fd,
+ std::set<const Descriptor*>* all_d);
+
+// Adds to <all_fd> and <all_d> with all descriptors recursively
+// reachable from the given descriptor.
+void GetAllFileDescriptorsFromMessage(const Descriptor* d,
+ std::set<const FileDescriptor*>* all_fd,
+ std::set<const Descriptor*>* all_d) {
+ if (!all_d->insert(d).second) return;
+ GetAllFileDescriptorsFromFile(d->file(), all_fd, all_d);
+ for (int i = 0; i < d->field_count(); ++i) {
+ auto* f = d->field(i);
+ switch (f->cpp_type()) {
+ case FieldDescriptor::CPPTYPE_INT32:
+ case FieldDescriptor::CPPTYPE_INT64:
+ case FieldDescriptor::CPPTYPE_UINT32:
+ case FieldDescriptor::CPPTYPE_UINT64:
+ case FieldDescriptor::CPPTYPE_DOUBLE:
+ case FieldDescriptor::CPPTYPE_FLOAT:
+ case FieldDescriptor::CPPTYPE_BOOL:
+ case FieldDescriptor::CPPTYPE_STRING:
+ break;
+ case FieldDescriptor::CPPTYPE_MESSAGE:
+ GetAllFileDescriptorsFromMessage(f->message_type(), all_fd, all_d);
+ break;
+ case FieldDescriptor::CPPTYPE_ENUM:
+ GetAllFileDescriptorsFromFile(f->enum_type()->file(), all_fd, all_d);
+ break;
+ }
+ }
+ for (int i = 0; i < d->nested_type_count(); ++i) {
+ GetAllFileDescriptorsFromMessage(d->nested_type(i), all_fd, all_d);
+ }
+}
+
+void GetAllFileDescriptorsFromFile(const FileDescriptor* fd,
+ std::set<const FileDescriptor*>* all_fd,
+ std::set<const Descriptor*>* all_d) {
+ if (!all_fd->insert(fd).second) return;
+ for (int i = 0; i < fd->message_type_count(); ++i) {
+ GetAllFileDescriptorsFromMessage(fd->message_type(i), all_fd, all_d);
+ }
+}
+
+void Generator::Generate(const FileDescriptor& fd) {
+ // This does not emit code with proper proto2 semantics (e.g. it doesn't check
+ // 'has' fields on non-messages), so check that only proto3 is passed.
+ CHECK_EQ(fd.syntax(), FileDescriptor::SYNTAX_PROTO3) << fd.name();
+
+ const string package = fd.package();
+ std::set<const FileDescriptor*> all_fd;
+ std::set<const Descriptor*> all_d;
+ GetAllFileDescriptorsFromFile(&fd, &all_fd, &all_d);
+
+ std::vector<string> headers;
+
+ // Add header to header file.
+ SetOutput(&header_);
+ Print("// GENERATED FILE - DO NOT MODIFY");
+ Print("#ifndef ", GetHeaderGuard(fd, false /* impl */));
+ Print("#define ", GetHeaderGuard(fd, false /* impl */));
+ Print();
+ headers = {
+ GetProtoHeaderName(fd),
+ StrCat(tf_header_prefix_, "tensorflow/core/platform/macros.h"),
+ StrCat(tf_header_prefix_, "tensorflow/core/platform/protobuf.h"),
+ StrCat(tf_header_prefix_, "tensorflow/core/platform/types.h"),
+ };
+ for (const auto& h : headers) {
+ Print("#include \"", h, "\"");
+ }
+ AddNamespaceToCurrentSection(package, true /* is_open */);
+
+ // Add header to impl file.
+ SetOutput(&header_impl_);
+ Print("// GENERATED FILE - DO NOT MODIFY");
+ Print("#ifndef ", GetHeaderGuard(fd, true /* impl */));
+ Print("#define ", GetHeaderGuard(fd, true /* impl */));
+ Print();
+ headers = {
+ GetProtoTextHeaderName(fd, false /* impl */),
+ StrCat(tf_header_prefix_,
+ "tensorflow/core/lib/strings/proto_text_util.h"),
+ StrCat(tf_header_prefix_, "tensorflow/core/lib/strings/scanner.h"),
+ };
+ for (const FileDescriptor* d : all_fd) {
+ if (d != &fd) {
+ headers.push_back(GetProtoTextHeaderName(*d, true /* impl */));
+ }
+ headers.push_back(GetProtoHeaderName(*d));
+ }
+ for (const auto& h : headers) {
+ Print("#include \"", h, "\"");
+ }
+ AddNamespaceToCurrentSection(package, true /* is_open */);
+ SetOutput(&header_impl_).Print().Print("namespace internal {");
+
+ // Add header to cc file.
+ SetOutput(&cc_);
+ Print("// GENERATED FILE - DO NOT MODIFY");
+ headers = {GetProtoTextHeaderName(fd, true /* impl */)};
+ AddHeadersToCurrentSection(headers);
+ Print();
+ Print("using ::tensorflow::strings::Scanner;");
+ Print("using ::tensorflow::strings::StrCat;");
+ AddNamespaceToCurrentSection(package, true /* is_open */);
+
+ // Add declarations and definitions.
+ for (int i = 0; i < fd.enum_type_count(); ++i) {
+ AppendEnumFunctions(*fd.enum_type(i));
+ }
+ for (int i = 0; i < fd.message_type_count(); ++i) {
+ AppendMessageFunctions(*fd.message_type(i));
+ }
+
+ // Add footer to header file.
+ SetOutput(&header_);
+ AddNamespaceToCurrentSection(package, false /* is_open */);
+ Print().Print("#endif // ", GetHeaderGuard(fd, false /* impl */));
+
+ // Add footer to header impl file.
+ SetOutput(&header_impl_).Print().Print("} // namespace internal");
+ AddNamespaceToCurrentSection(package, false /* is_open */);
+ Print().Print("#endif // ", GetHeaderGuard(fd, true /* impl */));
+
+ // Add footer to cc file.
+ SetOutput(&cc_);
+ AddNamespaceToCurrentSection(package, false /* is_open */);
+}
+
+} // namespace
+
+ProtoTextFunctionCode GetProtoTextFunctionCode(const FileDescriptor& fd,
+ const string& tf_header_prefix) {
+ Generator gen(tf_header_prefix);
+ gen.Generate(fd);
+ return gen.code();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h
new file mode 100644
index 0000000000..f538275ec8
--- /dev/null
+++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h
@@ -0,0 +1,42 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_
+
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+struct ProtoTextFunctionCode {
+ string header; // for a file named proto_name + ".pb_text.h"
+ string header_impl; // for a file named proto_name + ".pb_text-impl.h"
+ string cc; // for a file named proto_name + ".pb_text.cc"
+};
+
+// Returns the generated source code for a proto file descriptor.
+//
+// <tf_header_prefix> is used as the prefix for #include paths, when including
+// tensorflow library headers.
+//
+// Only works for proto3 messages.
+ProtoTextFunctionCode GetProtoTextFunctionCode(
+ const tensorflow::protobuf::FileDescriptor& fd,
+ const string& tf_header_prefix);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_
diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc b/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc
new file mode 100644
index 0000000000..b491295967
--- /dev/null
+++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc
@@ -0,0 +1,500 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h"
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/proto_text/test.pb.h"
+#include "tensorflow/tools/proto_text/test.pb_text.h"
+
+namespace tensorflow {
+namespace test {
+namespace {
+
+// Convert <input> to text depending on <short_debug>, then parse that into a
+// new message using the generated parse function. Return the new message.
+template <typename T>
+T RoundtripParseProtoOrDie(const T& input, bool short_debug) {
+ const string s = short_debug ? input.ShortDebugString() : input.DebugString();
+ T t;
+ EXPECT_TRUE(ProtoParseFromString(s, &t)) << "Failed to parse " << s;
+ return t;
+}
+
+// Macro that takes <proto> and verifies the proto text string output
+// matches DebugString calls on the proto, and verifies parsing the
+// DebugString output works. It does this for regular and short
+// debug strings.
+#define EXPECT_TEXT_TRANSFORMS_MATCH() \
+ EXPECT_EQ(proto.DebugString(), ProtoDebugString(proto)); \
+ EXPECT_EQ(proto.ShortDebugString(), ProtoShortDebugString(proto)); \
+ EXPECT_EQ(proto.DebugString(), \
+ RoundtripParseProtoOrDie(proto, true).DebugString()); \
+ EXPECT_EQ(proto.DebugString(), \
+ RoundtripParseProtoOrDie(proto, false).DebugString());
+
+// Macro for failure cases. Verifies both protobuf and proto_text to
+// make sure they match.
+#define EXPECT_PARSE_FAILURE(str) \
+ EXPECT_FALSE(ProtoParseFromString(str, &proto)); \
+ EXPECT_FALSE(protobuf::TextFormat::ParseFromString(str, &proto))
+
+// Macro for success cases parsing from a string. Verifies protobuf and
+// proto_text cases match.
+#define EXPECT_PARSE_SUCCESS(expected, str) \
+ do { \
+ EXPECT_TRUE(ProtoParseFromString(str, &proto)); \
+ string proto_text_str = ProtoShortDebugString(proto); \
+ EXPECT_TRUE(protobuf::TextFormat::ParseFromString(str, &proto)); \
+ string protobuf_str = ProtoShortDebugString(proto); \
+ EXPECT_EQ(proto_text_str, protobuf_str); \
+ EXPECT_EQ(expected, proto_text_str); \
+ } while (false)
+
+// Test different cases of numeric values, including repeated values.
+TEST(CreateProtoDebugStringLibTest, ValidSimpleTypes) {
+ TestAllTypes proto;
+
+ proto.Clear();
+ proto.set_optional_int32(-1);
+ proto.set_optional_int64(-2);
+ proto.set_optional_uint32(3);
+ proto.set_optional_uint64(4);
+ proto.set_optional_sint32(-5);
+ proto.set_optional_sint64(-6);
+ proto.set_optional_fixed32(-7);
+ proto.set_optional_fixed64(-8);
+ proto.set_optional_sfixed32(-9);
+ proto.set_optional_sfixed64(-10);
+ proto.set_optional_float(-12.34);
+ proto.set_optional_double(-5.678);
+ proto.set_optional_bool(true);
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // Max numeric values.
+ proto.Clear();
+ proto.set_optional_int32(std::numeric_limits<int32>::max());
+ proto.set_optional_int64(std::numeric_limits<int64>::max());
+ proto.set_optional_uint32(std::numeric_limits<uint32>::max());
+ proto.set_optional_uint64(std::numeric_limits<uint64>::max());
+ proto.set_optional_float(std::numeric_limits<float>::max());
+ proto.set_optional_double(std::numeric_limits<double>::max());
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // Least positive numeric values.
+ proto.Clear();
+ proto.set_optional_float(std::numeric_limits<float>::min());
+ proto.set_optional_double(std::numeric_limits<double>::min());
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // Lowest numeric values.
+ proto.Clear();
+ proto.set_optional_int32(std::numeric_limits<int32>::lowest());
+ proto.set_optional_int64(std::numeric_limits<int64>::lowest());
+ proto.set_optional_float(std::numeric_limits<float>::lowest());
+ proto.set_optional_double(std::numeric_limits<double>::lowest());
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // inf and -inf for float and double.
+ proto.Clear();
+ proto.set_optional_double(std::numeric_limits<double>::infinity());
+ proto.set_optional_float(std::numeric_limits<float>::infinity());
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+ proto.set_optional_double(-1 * std::numeric_limits<double>::infinity());
+ proto.set_optional_float(-1 * std::numeric_limits<float>::infinity());
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // String and bytes values.
+ proto.Clear();
+ for (int i = 0; i < 256; ++i) {
+ proto.mutable_optional_string()->push_back(static_cast<char>(i));
+ proto.mutable_optional_bytes()->push_back(static_cast<char>(i));
+ }
+ strings::StrAppend(proto.mutable_optional_string(), "¢€𐍈");
+ proto.set_optional_cord(proto.optional_string());
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // Repeated values. Include zero values to show they are retained in
+ // repeateds.
+ proto.Clear();
+ proto.add_repeated_int32(-1);
+ proto.add_repeated_int32(0);
+ proto.add_repeated_int64(0);
+ proto.add_repeated_int64(1);
+ proto.add_repeated_uint32(-10);
+ proto.add_repeated_uint32(0);
+ proto.add_repeated_uint32(10);
+ proto.add_repeated_uint64(-20);
+ proto.add_repeated_uint64(0);
+ proto.add_repeated_uint64(20);
+ proto.add_repeated_sint32(-30);
+ proto.add_repeated_sint32(0);
+ proto.add_repeated_sint32(30);
+ proto.add_repeated_sint64(-40);
+ proto.add_repeated_sint64(0);
+ proto.add_repeated_sint64(40);
+ proto.add_repeated_fixed32(-50);
+ proto.add_repeated_fixed32(0);
+ proto.add_repeated_fixed32(50);
+ proto.add_repeated_fixed64(-60);
+ proto.add_repeated_fixed64(0);
+ proto.add_repeated_fixed64(60);
+ proto.add_repeated_sfixed32(-70);
+ proto.add_repeated_sfixed32(0);
+ proto.add_repeated_sfixed32(70);
+ proto.add_repeated_sfixed64(-80);
+ proto.add_repeated_sfixed64(0);
+ proto.add_repeated_sfixed64(80);
+ proto.add_repeated_float(-1.2345);
+ proto.add_repeated_float(0);
+ proto.add_repeated_float(-2.3456);
+ proto.add_repeated_double(-10.2345);
+ proto.add_repeated_double(0);
+ proto.add_repeated_double(-20.3456);
+ proto.add_repeated_bool(false);
+ proto.add_repeated_bool(true);
+ proto.add_repeated_bool(false);
+ proto.add_repeated_string("abc");
+ proto.add_repeated_string("");
+ proto.add_repeated_string("def");
+ proto.add_repeated_cord("abc");
+ proto.add_repeated_cord("");
+ proto.add_repeated_cord("def");
+ proto.add_packed_repeated_int64(-1000);
+ proto.add_packed_repeated_int64(0);
+ proto.add_packed_repeated_int64(1000);
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // Proto supports [] for list values as well.
+ EXPECT_PARSE_SUCCESS("repeated_int32: 1 repeated_int32: 2 repeated_int32: 3",
+ "repeated_int32: [1, 2 , 3]");
+
+ // Test [] and also interesting bool values.
+ EXPECT_PARSE_SUCCESS(("repeated_bool: false repeated_bool: false "
+ "repeated_bool: true repeated_bool: true "
+ "repeated_bool: false repeated_bool: true"),
+ "repeated_bool: [false, 0, 1, true, False, True]");
+
+ EXPECT_PARSE_SUCCESS(("repeated_string: \"a,b\" "
+ "repeated_string: \"cdef\""),
+ "repeated_string: [ 'a,b', 'cdef' ] ");
+
+ // Proto supports ' as quote character.
+ EXPECT_PARSE_SUCCESS("optional_string: \"123\\\" \\'xyz\"",
+ "optional_string: '123\\\" \\'xyz' ");
+
+ EXPECT_PARSE_SUCCESS("optional_double: 10000", "optional_double: 1e4");
+
+ // Error cases.
+ EXPECT_PARSE_FAILURE("optional_string: '1' optional_string: '2'");
+ EXPECT_PARSE_FAILURE("optional_double: 123 optional_double: 456");
+ EXPECT_PARSE_FAILURE("optional_double: 0001");
+ EXPECT_PARSE_FAILURE("optional_double: 000.1");
+ EXPECT_PARSE_FAILURE("optional_double: a");
+ EXPECT_PARSE_FAILURE("optional_double: x123");
+ EXPECT_PARSE_FAILURE("optional_double: '123'");
+ EXPECT_PARSE_FAILURE("optional_double: --111");
+ EXPECT_PARSE_FAILURE("optional_string: 'abc\"");
+ EXPECT_PARSE_FAILURE("optional_bool: truE");
+ EXPECT_PARSE_FAILURE("optional_bool: FALSE");
+}
+
+TEST(CreateProtoDebugStringLibTest, NestedMessages) {
+ TestAllTypes proto;
+
+ proto.Clear();
+ // Test empty message.
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ proto.Clear();
+ proto.mutable_optional_nested_message();
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ proto.Clear();
+ proto.mutable_optional_foreign_message();
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // Empty messages.
+ proto.Clear();
+ proto.mutable_optional_nested_message();
+ proto.mutable_optional_foreign_message();
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ proto.Clear();
+ proto.mutable_optional_nested_message()->set_optional_int32(1);
+ proto.mutable_optional_foreign_message()->set_c(-1234);
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ proto.Clear();
+ proto.mutable_optional_nested_message()->set_optional_int32(1234);
+ proto.mutable_optional_nested_message()
+ ->mutable_msg(); // empty double-nested
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ proto.Clear();
+ proto.mutable_optional_nested_message()->set_optional_int32(1234);
+ proto.mutable_optional_nested_message()->mutable_msg()->set_optional_string(
+ "abc");
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ proto.Clear();
+ proto.mutable_optional_nested_message()->mutable_msg()->set_optional_string(
+ "abc");
+ proto.mutable_optional_nested_message()->set_optional_int64(1234);
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ proto.Clear();
+ auto* nested = proto.add_repeated_nested_message();
+ nested = proto.add_repeated_nested_message();
+ nested->set_optional_int32(123);
+ nested->mutable_msg();
+ nested = proto.add_repeated_nested_message();
+ nested->mutable_msg();
+ nested->mutable_msg()->set_optional_string("abc");
+ nested->set_optional_int64(1234);
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // text format allows use of <> for messages.
+ EXPECT_PARSE_SUCCESS("optional_nested_message { optional_int32: 123 }",
+ "optional_nested_message: < optional_int32: 123 >");
+
+ // <> and {} must use same style for closing.
+ EXPECT_PARSE_FAILURE("optional_nested_message: < optional_int32: 123 }");
+ EXPECT_PARSE_FAILURE("optional_nested_message: { optional_int32: 123 >");
+
+ // colon after identifier is optional for messages.
+ EXPECT_PARSE_SUCCESS("optional_nested_message { optional_int32: 123 }",
+ "optional_nested_message < optional_int32: 123 >");
+
+ EXPECT_PARSE_SUCCESS("optional_nested_message { optional_int32: 123 }",
+ "optional_nested_message{ optional_int32: 123 } ");
+
+ // Proto supports [] for list values as well.
+ EXPECT_PARSE_SUCCESS(
+ ("repeated_nested_message { } "
+ "repeated_nested_message { optional_int32: 123 }"),
+ "repeated_nested_message: [ { }, { optional_int32: 123 } ]");
+
+ // Colon after repeated_nested_message is optional.
+ EXPECT_PARSE_SUCCESS(
+ ("repeated_nested_message { } "
+ "repeated_nested_message { optional_int32: 123 }"),
+ "repeated_nested_message [ { }, { optional_int32: 123 } ]");
+
+ // Using the list format a:[..] twice, like a:[..] a:[..] joins the two
+ // arrays.
+ EXPECT_PARSE_SUCCESS(
+ ("repeated_nested_message { } "
+ "repeated_nested_message { optional_int32: 123 } "
+ "repeated_nested_message { optional_int32: 456 }"),
+ ("repeated_nested_message [ { }, { optional_int32: 123 } ]"
+ "repeated_nested_message [ { optional_int32: 456 } ]"));
+
+ // Parse errors on nested messages.
+ EXPECT_PARSE_FAILURE("optional_nested_message: {optional_int32: 'abc' }");
+
+ // Optional_nested_message appearing twice is an error.
+ EXPECT_PARSE_FAILURE(
+ ("optional_nested_message { optional_int32: 123 } "
+ "optional_nested_message { optional_int64: 456 }"));
+}
+
+TEST(CreateProtoDebugStringLibTest, RecursiveMessage) {
+ NestedTestAllTypes proto;
+
+ NestedTestAllTypes* cur = &proto;
+ for (int depth = 0; depth < 20; ++depth) {
+ cur->mutable_payload()->set_optional_int32(1000 + depth);
+ cur = cur->mutable_child();
+ }
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+}
+
+template <typename T>
+T ParseProto(const string& value_text_proto) {
+ T value;
+ EXPECT_TRUE(protobuf::TextFormat::ParseFromString(value_text_proto, &value))
+ << value_text_proto;
+ return value;
+}
+
+TestAllTypes::NestedMessage ParseNestedMessage(const string& value_text_proto) {
+ return ParseProto<TestAllTypes::NestedMessage>(value_text_proto);
+}
+
+TEST(CreateProtoDebugStringLibTest, Map) {
+ TestAllTypes proto;
+
+ std::vector<TestAllTypes::NestedMessage> msg_values;
+ msg_values.push_back(ParseNestedMessage("optional_int32: 345"));
+ msg_values.push_back(ParseNestedMessage("optional_int32: 123"));
+ msg_values.push_back(ParseNestedMessage("optional_int32: 234"));
+ msg_values.push_back(ParseNestedMessage("optional_int32: 0"));
+
+ // string->message map
+ proto.Clear();
+ {
+ auto& map = *proto.mutable_map_string_to_message();
+ map["def"] = msg_values[0];
+ map["abc"] = msg_values[1];
+ map["cde"] = msg_values[2];
+ map[""] = msg_values[3];
+ }
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // int32->message map.
+ proto.Clear();
+ {
+ auto& map = *proto.mutable_map_int32_to_message();
+ map[20] = msg_values[0];
+ map[10] = msg_values[1];
+ map[15] = msg_values[2];
+ map[0] = msg_values[3];
+ }
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // int64->message map.
+ proto.Clear();
+ {
+ auto& map = *proto.mutable_map_int64_to_message();
+ map[20] = msg_values[0];
+ map[10] = msg_values[1];
+ map[15] = msg_values[2];
+ map[0] = msg_values[3];
+ }
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // bool->message map.
+ proto.Clear();
+ {
+ auto& map = *proto.mutable_map_int64_to_message();
+ map[true] = msg_values[0];
+ map[false] = msg_values[1];
+ }
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // string->int64 map.
+ proto.Clear();
+ {
+ auto& map = *proto.mutable_map_string_to_int64();
+ map["def"] = 0;
+ map["abc"] = std::numeric_limits<int64>::max();
+ map[""] = 20;
+ }
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // int64->string map.
+ proto.Clear();
+ {
+ auto& map = *proto.mutable_map_int64_to_string();
+ map[0] = "def";
+ map[std::numeric_limits<int64>::max()] = "";
+ map[20] = "abc";
+ }
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // Test a map with the same key multiple times.
+ EXPECT_PARSE_SUCCESS(("map_string_to_int64 { key: \"abc\" value: 5 } "
+ "map_string_to_int64 { key: \"def\" value: 2 } "
+ "map_string_to_int64 { key: \"ghi\" value: 4 }"),
+ ("map_string_to_int64: { key: 'abc' value: 1 } "
+ "map_string_to_int64: { key: 'def' value: 2 } "
+ "map_string_to_int64: { key: 'ghi' value: 3 } "
+ "map_string_to_int64: { key: 'ghi' value: 4 } "
+ "map_string_to_int64: { key: 'abc' value: 5 } "));
+}
+
+TEST(CreateProtoDebugStringLibTest, Enums) {
+ TestAllTypes proto;
+
+ proto.Clear();
+ proto.set_optional_nested_enum(TestAllTypes::ZERO);
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ proto.Clear();
+ proto.set_optional_nested_enum(TestAllTypes::FOO);
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ proto.Clear();
+ proto.add_repeated_nested_enum(TestAllTypes::FOO);
+ proto.add_repeated_nested_enum(TestAllTypes::ZERO);
+ proto.add_repeated_nested_enum(TestAllTypes::BAR);
+ proto.add_repeated_nested_enum(TestAllTypes::NEG);
+ proto.add_repeated_nested_enum(TestAllTypes::ZERO);
+ proto.set_optional_foreign_enum(ForeignEnum::FOREIGN_BAR);
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // Parsing from numbers works as well.
+ EXPECT_PARSE_SUCCESS(
+ "optional_nested_enum: BAR " // 2
+ "repeated_nested_enum: BAR " // 2
+ "repeated_nested_enum: ZERO " // 0
+ "repeated_nested_enum: FOO", // 1
+ ("repeated_nested_enum: 2 "
+ "repeated_nested_enum: 0 "
+ "optional_nested_enum: 2 "
+ "repeated_nested_enum: 1"));
+
+ EXPECT_PARSE_SUCCESS("", "optional_nested_enum: -0");
+ EXPECT_PARSE_FAILURE("optional_nested_enum: 6");
+ EXPECT_PARSE_FAILURE("optional_nested_enum: BARNONE");
+ EXPECT_PARSE_FAILURE("optional_nested_enum: 'BAR'");
+ EXPECT_PARSE_FAILURE("optional_nested_enum: \"BAR\" ");
+
+ EXPECT_EQ(string("BAR"),
+ string(EnumName_TestAllTypes_NestedEnum(TestAllTypes::BAR)));
+ // out of range - returns empty string (see NameOfEnum in proto library).
+ EXPECT_EQ(string(""), string(EnumName_TestAllTypes_NestedEnum(
+ static_cast<TestAllTypes_NestedEnum>(123))));
+}
+
+TEST(CreateProtoDebugStringLibTest, Oneof) {
+ TestAllTypes proto;
+
+ proto.Clear();
+ proto.set_oneof_string("abc");
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ proto.Clear();
+ proto.set_oneof_string("abc");
+ proto.set_oneof_uint32(123);
+ EXPECT_TEXT_TRANSFORMS_MATCH();
+
+ // Parse a text format that lists multiple members of the oneof.
+ EXPECT_PARSE_FAILURE("oneof_string: \"abc\" oneof_uint32: 13 ");
+ EXPECT_PARSE_FAILURE("oneof_string: \"abc\" oneof_string: \"def\" ");
+}
+
+TEST(CreateProtoDebugStringLibTest, Comments) {
+ TestAllTypes proto;
+
+ EXPECT_PARSE_SUCCESS("optional_int64: 123 optional_string: \"#text\"",
+ ("#leading comment \n"
+ "optional_int64# comment\n"
+ ":# comment\n"
+ "123# comment\n"
+ "optional_string # comment\n"
+ ": # comment\n"
+ "\"#text\"#comment####\n"));
+
+ EXPECT_PARSE_FAILURE("optional_int64:// not a valid comment\n123");
+ EXPECT_PARSE_FAILURE("optional_int64:/* not a valid comment */\n123");
+}
+
+} // namespace
+} // namespace test
+} // namespace tensorflow
diff --git a/tensorflow/tools/proto_text/test.proto b/tensorflow/tools/proto_text/test.proto
new file mode 100644
index 0000000000..9dd83e6393
--- /dev/null
+++ b/tensorflow/tools/proto_text/test.proto
@@ -0,0 +1,107 @@
+syntax = "proto3";
+
+package tensorflow.test;
+
+message TestAllTypes {
+ message NestedMessage {
+ message DoubleNestedMessage {
+ string optional_string = 1;
+ }
+
+ int32 optional_int32 = 1;
+ repeated int32 repeated_int32 = 2;
+ DoubleNestedMessage msg = 3;
+ int64 optional_int64 = 4;
+ }
+
+ enum NestedEnum {
+ ZERO = 0;
+ FOO = 1;
+ BAR = 2;
+ BAZ = 3;
+ NEG = -1; // Intentionally negative.
+ }
+
+ // Singular
+ int32 optional_int32 = 1;
+ int64 optional_int64 = 2;
+ uint32 optional_uint32 = 3;
+ uint64 optional_uint64 = 4;
+ sint32 optional_sint32 = 5;
+ sint64 optional_sint64 = 6;
+ fixed32 optional_fixed32 = 7;
+ fixed64 optional_fixed64 = 8;
+ sfixed32 optional_sfixed32 = 9;
+ sfixed64 optional_sfixed64 = 10;
+ float optional_float = 11;
+ double optional_double = 12;
+ bool optional_bool = 13;
+ string optional_string = 14;
+ bytes optional_bytes = 15;
+
+ NestedMessage optional_nested_message = 18;
+ ForeignMessage optional_foreign_message = 19;
+
+ NestedEnum optional_nested_enum = 21;
+ ForeignEnum optional_foreign_enum = 22;
+
+ string optional_cord = 25;
+
+ // Repeated
+ repeated int32 repeated_int32 = 31;
+ repeated int64 repeated_int64 = 32;
+ repeated uint32 repeated_uint32 = 33;
+ repeated uint64 repeated_uint64 = 34;
+ repeated sint32 repeated_sint32 = 35;
+ repeated sint64 repeated_sint64 = 36;
+ repeated fixed32 repeated_fixed32 = 37;
+ repeated fixed64 repeated_fixed64 = 38;
+ repeated sfixed32 repeated_sfixed32 = 39;
+ repeated sfixed64 repeated_sfixed64 = 40;
+ repeated float repeated_float = 41;
+ repeated double repeated_double = 42;
+ repeated bool repeated_bool = 43;
+ repeated string repeated_string = 44;
+ repeated bytes repeated_bytes = 45;
+
+ repeated NestedMessage repeated_nested_message = 48;
+ repeated NestedEnum repeated_nested_enum = 51;
+
+ repeated string repeated_cord = 55;
+
+ oneof oneof_field {
+ uint32 oneof_uint32 = 111;
+ NestedMessage oneof_nested_message = 112;
+ string oneof_string = 113;
+ bytes oneof_bytes = 114;
+ }
+
+ map<string, NestedMessage> map_string_to_message = 58;
+ map<int32, NestedMessage> map_int32_to_message = 59;
+ map<int64, NestedMessage> map_int64_to_message = 60;
+ map<bool, NestedMessage> map_bool_to_message = 61;
+ map<string, int64> map_string_to_int64 = 62;
+ map<int64, string> map_int64_to_string = 63;
+
+ repeated int64 packed_repeated_int64 = 64 [packed = true];
+}
+
+// A recursive message.
+message NestedTestAllTypes {
+ NestedTestAllTypes child = 1;
+ TestAllTypes payload = 2;
+}
+
+message ForeignMessage {
+ int32 c = 1;
+}
+
+enum ForeignEnum {
+ FOREIGN_ZERO = 0;
+ FOREIGN_FOO = 4;
+ FOREIGN_BAR = 5;
+ FOREIGN_BAZ = 6;
+}
+
+message TestEmptyMessage {
+}