diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-04-14 19:15:20 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-04-14 20:22:35 -0700 |
commit | df15baa9b10a0b2d194181dff7ee14bff70d9b8f (patch) | |
tree | a66e1fcb309a7404a2749140d8bac991a6524ac0 /tensorflow/tools/proto_text | |
parent | 104fe2822b419c4154d11c401ffd4a3a6e8f24c6 (diff) |
Add tools/proto_text for generating ProtoDebugString,
ProtoShortDebugString, and ProtoParseFromString methods from protos. This will
allow changing code used on mobile to use the proto LITE_RUNTIME, to reduce
code size.
This change is only for the tool itself. A future change will add a better
genrule and use it the generated code in tensorflow.
Change: 119919087
Diffstat (limited to 'tensorflow/tools/proto_text')
-rw-r--r-- | tensorflow/tools/proto_text/BUILD | 90 | ||||
-rw-r--r-- | tensorflow/tools/proto_text/gen_proto_text_functions.cc | 117 | ||||
-rw-r--r-- | tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc | 779 | ||||
-rw-r--r-- | tensorflow/tools/proto_text/gen_proto_text_functions_lib.h | 42 | ||||
-rw-r--r-- | tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc | 500 | ||||
-rw-r--r-- | tensorflow/tools/proto_text/test.proto | 107 |
6 files changed, 1635 insertions, 0 deletions
diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD new file mode 100644 index 0000000000..c677721672 --- /dev/null +++ b/tensorflow/tools/proto_text/BUILD @@ -0,0 +1,90 @@ +# Description: +# This package provides build-time generation of proto3 text format functions +# (ProtoDebugString, ProtoShortDebugString, and ProtoParseFromString) which +# provide equivalent functionality as proto.DebugString, proto.ShortDebugString, +# and TextFormat parsing, but can be used with protos generated with +# LITE_RUNTIME. +# +# Note that proto3 well-known types (e.g. Any) are not handled in a special way +# by the generated code. + +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +# For platform specific build config +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_cc", +) + +cc_binary( + name = "gen_proto_text_functions", + srcs = ["gen_proto_text_functions.cc"], + visibility = ["//tensorflow:internal"], + deps = [ + ":gen_proto_text_functions_lib", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "gen_proto_text_functions_lib", + srcs = ["gen_proto_text_functions_lib.cc"], + hdrs = ["gen_proto_text_functions_lib.h"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_proto_library_cc( + name = "test_proto", + srcs = ["test.proto"], +) + +genrule( + name = "test_proto_text_srcs", + srcs = ["test.proto"], + outs = [ + "test.pb_text-impl.h", + "test.pb_text.h", + "test.pb_text.cc", + ], + cmd = "$(location :gen_proto_text_functions) " + + "$(@D) $(location test.proto)", + tools = [":gen_proto_text_functions"], +) + +cc_test( + name = "gen_proto_text_functions_lib_test", + size = "small", + srcs = [ + "gen_proto_text_functions_lib_test.cc", + ":test_proto_text_srcs", + ], + deps = [ + ":gen_proto_text_functions_lib", + ":test_proto_cc", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# ----------------------------------------------------------------------------- +# Google-internal targets. These must be at the end for syncrepo. + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions.cc b/tensorflow/tools/proto_text/gen_proto_text_functions.cc new file mode 100644 index 0000000000..604ae432cf --- /dev/null +++ b/tensorflow/tools/proto_text/gen_proto_text_functions.cc @@ -0,0 +1,117 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <stdio.h> +#include <set> + +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h" + +namespace tensorflow { + +namespace { +class CrashOnErrorCollector + : public tensorflow::protobuf::compiler::MultiFileErrorCollector { + public: + ~CrashOnErrorCollector() override {} + + void AddError(const string& filename, int line, int column, + const string& message) override { + LOG(FATAL) << "Unexpected error at " << filename << "@" << line << ":" + << column << " - " << message; + } +}; +} // namespace + +static const char kTensorflowHeaderPrefix[] = ""; + +// Main program to take input protos and write output pb_text source files that +// contain generated proto text input and output functions. +// +// Main expects the first argument to give the output path. This is followed by +// pairs of arguments: <proto_name_relative_to_root, proto_file_path>. +// +// Note that this code doesn't use tensorflow's command line parsing, because of +// circular dependencies between libraries if that were done. +// +// This is meant to be invoked by a genrule. See BUILD for more information. +int MainImpl(int argc, char** argv) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + + if (argc < 3) { + LOG(ERROR) << "Pass output path and at least proto file"; + return -1; + } + + const string output_root = argv[1]; + + tensorflow::protobuf::compiler::DiskSourceTree source_tree; + + // This requires all protos to be relative to the directory from which the + // genrule is invoked. If protos are generated in some other directory, + // then they may not be found. + source_tree.MapPath("", "."); + CrashOnErrorCollector crash_on_error; + tensorflow::protobuf::compiler::Importer importer(&source_tree, + &crash_on_error); + + for (int i = 2; i < argc; i++) { + const string proto_path = argv[i]; + const tensorflow::protobuf::FileDescriptor* fd = + importer.Import(proto_path); + + string proto_name = proto_path; + int index = proto_name.find_last_of("/"); + if (index != string::npos) proto_name = proto_name.substr(index + 1); + index = proto_name.find_last_of("."); + if (index != string::npos) proto_name = proto_name.substr(0, index); + + const auto code = + tensorflow::GetProtoTextFunctionCode(*fd, kTensorflowHeaderPrefix); + + // Three passes, one for each output file. + for (int pass = 0; pass < 3; ++pass) { + string suffix; + string data; + if (pass == 0) { + suffix = ".pb_text.h"; + data = code.header; + } else if (pass == 1) { + suffix = ".pb_text-impl.h"; + data = code.header_impl; + } else { + suffix = ".pb_text.cc"; + data = code.cc; + } + + const string path = output_root + "/" + proto_name + suffix; + FILE* f = fopen(path.c_str(), "w"); + if (fwrite(data.c_str(), 1, data.size(), f) != data.size()) { + return -1; + } + if (fclose(f) != 0) { + return -1; + } + } + } + return 0; +} + +} // namespace tensorflow + +int main(int argc, char** argv) { return tensorflow::MainImpl(argc, argv); } diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc new file mode 100644 index 0000000000..a8e3437296 --- /dev/null +++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc @@ -0,0 +1,779 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h" + +#include <algorithm> +#include <set> + +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +using ::tensorflow::protobuf::Descriptor; +using ::tensorflow::protobuf::EnumDescriptor; +using ::tensorflow::protobuf::FieldDescriptor; +using ::tensorflow::protobuf::FieldOptions; +using ::tensorflow::protobuf::FileDescriptor; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +namespace tensorflow { + +namespace { + +// Class used to generate the code for proto text functions. One of these should +// be created for each FileDescriptor whose code should be generated. +// +// This class has a notion of the current output Section. The Print, Nested, +// and Unnest functions apply their operations to the current output section, +// which can be toggled with SetOutput. +// +// Note that on the generated code, various pieces are not optimized - for +// example: map input and output, Cord input and output, comparisons against +// the field names (it's a loop over all names), and tracking of has_seen. +// +// The generated API has, for enums and messages defined in the proto file: +// 1. For each message: +// * ProtoDebugString(m): same as msg.DebugString() +// * ProtoShortDebugString(m): same as msg.ShorDebugString() +// * ProtoParseFromString(s, m): same as TextFormat.ParseFromString(s, &m); +// 2. For each enum: +// * EnumName_<EnumTypeName>(enum_value): same as <EnumTypeName>(enum_value) +// in proto. +class Generator { + public: + Generator(const string& tf_header_prefix) + : tf_header_prefix_(tf_header_prefix), + header_(&code_.header), + header_impl_(&code_.header_impl), + cc_(&code_.cc) {} + + void Generate(const FileDescriptor& fd); + + // The generated code; valid after Generate has been called. + ProtoTextFunctionCode code() const { return code_; } + + private: + struct Section { + Section(string* str) : str(str) {} + string* str; + string indent; + }; + + // Switches the currently active section to <section>. + Generator& SetOutput(Section* section) { + cur_ = section; + return *this; + } + + // Increases indent level. Returns <*this>, to allow chaining. + Generator& Nest() { + StrAppend(&cur_->indent, " "); + return *this; + } + + // Decreases indent level. Returns <*this>, to allow chaining. + Generator& Unnest() { + cur_->indent = cur_->indent.substr(0, cur_->indent.size() - 2); + return *this; + } + + // Appends the concatenated args, with a trailing newline. Returns <*this>, to + // allow chaining. + template <typename... Args> + Generator& Print(Args... args) { + StrAppend(cur_->str, cur_->indent, args..., "\n"); + return *this; + } + + // Appends the print code for a single field's value. + // If <omit_default> is true, then the emitted code will not print zero-valued + // values. + // <field_expr> is code that when emitted yields the field's value. + void AppendFieldValueAppend(const FieldDescriptor& field, + const bool omit_default, + const string& field_expr); + + // Appends the print code for as single field. + void AppendFieldAppend(const FieldDescriptor& field); + + // Appends the print code for a message. May change which section is currently + // active. + void AppendDebugStringFunctions(const Descriptor& md); + + // Appends the print and parse functions for an enum. May change which + // section is currently active. + void AppendEnumFunctions(const EnumDescriptor& enum_d); + + // Appends the parse functions for a message. May change which section is + // currently active. + void AppendParseMessageFunction(const Descriptor& md); + + // Appends all functions for a message and its nested message and enum types. + // May change which section is currently active. + void AppendMessageFunctions(const Descriptor& md); + + // Appends lines to open or close namespace declarations. + void AddNamespaceToCurrentSection(const string& package, bool open); + + // Appends the given headers as sorted #include lines. + void AddHeadersToCurrentSection(const std::vector<string>& headers); + + // When adding #includes for tensorflow headers, prefix them with this. + const string tf_header_prefix_; + ProtoTextFunctionCode code_; + Section* cur_ = nullptr; + Section header_; + Section header_impl_; + Section cc_; + + TF_DISALLOW_COPY_AND_ASSIGN(Generator); +}; + +// Returns the prefix needed to reference objects defined in <fd>. E.g. +// "::tensorflow::test". +string GetPackageReferencePrefix(const FileDescriptor* fd) { + return StrCat("::", str_util::Join(str_util::Split(fd->package(), '.'), "::"), + "::"); +} + +// Returns the name of the class generated by proto to represent <d>. +string GetClassName(const Descriptor& d) { + if (d.containing_type() == nullptr) return d.name(); + return StrCat(GetClassName(*d.containing_type()), "_", d.name()); +} + +// Returns the name of the class generated by proto to represent <ed>. +string GetClassName(const EnumDescriptor& ed) { + if (ed.containing_type() == nullptr) return ed.name(); + return StrCat(GetClassName(*ed.containing_type()), "_", ed.name()); +} + +// Returns the qualified name that refers to the class generated by proto to +// represent <d>. +string GetQualifiedName(const Descriptor& d) { + return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d)); +} + +// Returns the qualified name that refers to the class generated by proto to +// represent <ed>. +string GetQualifiedName(const EnumDescriptor& d) { + return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d)); +} + +// Returns the qualified name that refers to the generated +// AppendProtoDebugString function for <d>. +string GetQualifiedAppendFn(const Descriptor& d) { + return StrCat(GetPackageReferencePrefix(d.file()), + "internal::AppendProtoDebugString"); +} + +// Returns the name of the generated function that returns an enum value's +// string value. +string GetEnumNameFn(const EnumDescriptor& enum_d) { + return StrCat("EnumName_", GetClassName(enum_d)); +} + +// Returns the qualified name of the function returned by GetEnumNameFn(). +string GetQualifiedEnumNameFn(const EnumDescriptor& enum_d) { + return StrCat(GetPackageReferencePrefix(enum_d.file()), + GetEnumNameFn(enum_d)); +} + +// Returns the name of a generated header file, either the public api (if impl +// is false) or the internal implementation header (if impl is true). +string GetProtoTextHeaderName(const FileDescriptor& fd, bool impl) { + const int dot_index = fd.name().find_last_of('.'); + return fd.name().substr(0, dot_index) + + (impl ? ".pb_text-impl.h" : ".pb_text.h"); +} + +// Returns the name of the header generated by the proto library for <fd>. +string GetProtoHeaderName(const FileDescriptor& fd) { + const int dot_index = fd.name().find_last_of('.'); + return fd.name().substr(0, dot_index) + ".pb.h"; +} + +// Returns the C++ class name for the given proto field. +string GetCppClass(const FieldDescriptor& d) { + return d.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE + ? GetQualifiedName(*d.message_type()) + : d.cpp_type_name(); +} + +// Returns the string that can be used for a header guard for the generated +// headers for <fd>, either for the public api (if impl is false) or the +// internal implementation header (if impl is true). +string GetHeaderGuard(const FileDescriptor& fd, bool impl) { + string s = fd.name(); + std::replace(s.begin(), s.end(), '/', '_'); + std::replace(s.begin(), s.end(), '.', '_'); + return s + (impl ? "_IMPL_H_" : "_H_"); +} + +void Generator::AppendFieldValueAppend(const FieldDescriptor& field, + const bool omit_default, + const string& field_expr) { + SetOutput(&cc_); + switch (field.cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + case FieldDescriptor::CPPTYPE_INT64: + case FieldDescriptor::CPPTYPE_UINT32: + case FieldDescriptor::CPPTYPE_UINT64: + case FieldDescriptor::CPPTYPE_DOUBLE: + case FieldDescriptor::CPPTYPE_FLOAT: + Print("o->", omit_default ? "AppendNumericIfNotZero" : "AppendNumeric", + "(\"", field.name(), "\", ", field_expr, ");"); + break; + case FieldDescriptor::CPPTYPE_BOOL: + Print("o->", omit_default ? "AppendBoolIfTrue" : "AppendBool", "(\"", + field.name(), "\", ", field_expr, ");"); + break; + case FieldDescriptor::CPPTYPE_STRING: { + const auto ctype = field.options().ctype(); + CHECK(ctype == FieldOptions::CORD || ctype == FieldOptions::STRING) + << "Unsupported ctype " << ctype; + + string str_expr = field_expr; + if (ctype == FieldOptions::CORD) { + str_expr = StrCat("(", field_expr, ").ToString()"); + } + Print("o->", omit_default ? "AppendStringIfNotEmpty" : "AppendString", + "(\"", field.name(), "\", ", str_expr, ");"); + break; + } + case FieldDescriptor::CPPTYPE_ENUM: + if (omit_default) { + Print("if (", field_expr, " != 0) {").Nest(); + } + Print("o->AppendEnumName(\"", field.name(), "\", ", + GetQualifiedEnumNameFn(*field.enum_type()), "(", field_expr, "));"); + if (omit_default) { + Unnest().Print("}"); + } + break; + case FieldDescriptor::CPPTYPE_MESSAGE: + CHECK(!field.message_type()->options().map_entry()); + if (omit_default) { + Print("if (msg.has_", field.name(), "()) {").Nest(); + } + Print("o->OpenNestedMessage(\"", field.name(), "\");"); + Print(GetQualifiedAppendFn(*field.message_type()), "(o, ", field_expr, + ");"); + Print("o->CloseNestedMessage();"); + if (omit_default) { + Unnest().Print("}"); + } + break; + } +} + +void Generator::AppendFieldAppend(const FieldDescriptor& field) { + const string name = field.name(); + + if (field.is_map()) { + Print("{").Nest(); + const auto& key_type = *field.message_type()->FindFieldByName("key"); + const auto& value_type = *field.message_type()->FindFieldByName("value"); + + Print("std::vector<", key_type.cpp_type_name(), "> keys;"); + Print("for (const auto& e : msg.", name, "()) keys.push_back(e.first);"); + Print("std::stable_sort(keys.begin(), keys.end());"); + Print("for (const auto& key : keys) {").Nest(); + Print("o->OpenNestedMessage(\"", name, "\");"); + AppendFieldValueAppend(key_type, false /* omit_default */, "key"); + AppendFieldValueAppend(value_type, false /* omit_default */, + StrCat("msg.", name, "().at(key)")); + Print("o->CloseNestedMessage();"); + Unnest().Print("}"); + + Unnest().Print("}"); + } else if (field.is_repeated()) { + Print("for (int i = 0; i < msg.", name, "_size(); ++i) {"); + Nest(); + AppendFieldValueAppend(field, false /* omit_default */, + "msg." + name + "(i)"); + Unnest().Print("}"); + } else { + AppendFieldValueAppend(field, true /* omit_default */, + "msg." + name + "()"); + } +} + +void Generator::AppendEnumFunctions(const EnumDescriptor& enum_d) { + const string sig = StrCat("const char* ", GetEnumNameFn(enum_d), "(\n ", + GetQualifiedName(enum_d), " value)"); + SetOutput(&header_); + Print().Print("// Enum text output for ", string(enum_d.full_name())); + Print(sig, ";"); + + SetOutput(&cc_); + Print().Print(sig, " {"); + Nest().Print("switch (value) {").Nest(); + for (int i = 0; i < enum_d.value_count(); ++i) { + const auto& value = *enum_d.value(i); + Print("case ", value.number(), ": return \"", value.name(), "\";"); + } + Print("default: return \"\";"); + Unnest().Print("}"); + Unnest().Print("}"); +} + +void Generator::AppendParseMessageFunction(const Descriptor& md) { + const bool map_append = (md.options().map_entry()); + string sig; + if (!map_append) { + sig = StrCat("bool ProtoParseFromString(\n const string& s,\n ", + GetQualifiedName(md), "* msg)"); + SetOutput(&header_).Print(sig, "\n TF_MUST_USE_RESULT;"); + + SetOutput(&cc_); + Print().Print(sig, " {").Nest(); + Print("msg->Clear();"); + Print("Scanner scanner(s);"); + Print("if (!internal::ProtoParseFromScanner(", + "&scanner, false, false, msg)) return false;"); + Print("scanner.Eos();"); + Print("return scanner.GetResult();"); + Unnest().Print("}"); + } + + // Parse from scanner - the real work here. + sig = StrCat(map_append ? "inline " : "", "bool ProtoParseFromScanner(", + "\n ::tensorflow::strings::Scanner* scanner, bool nested, " + "bool close_curly,\n "); + const FieldDescriptor* key_type = nullptr; + const FieldDescriptor* value_type = nullptr; + if (map_append) { + key_type = md.FindFieldByName("key"); + value_type = md.FindFieldByName("value"); + StrAppend(&sig, "::tensorflow::protobuf::Map<", GetCppClass(*key_type), + ", ", GetCppClass(*value_type), ">* map)"); + } else { + StrAppend(&sig, GetQualifiedName(md), "* msg)"); + } + SetOutput(&header_impl_).Print(sig, ";"); + + SetOutput(&cc_); + Print().Print("namespace internal {"); + Print().Print(sig, " {").Nest(); + if (map_append) { + Print(GetCppClass(*key_type), " map_key;"); + Print("bool set_map_key = false;"); + Print(GetCppClass(*value_type), " map_value;"); + Print("bool set_map_value = false;"); + } + Print("std::vector<bool> has_seen(", md.field_count(), ", false);"); + Print("while(true) {").Nest(); + Print("ProtoSpaceAndComments(scanner);"); + + // Emit success case + Print("if (nested && (scanner->Peek() == (close_curly ? '}' : '>'))) {") + .Nest(); + Print("scanner->One(Scanner::ALL);"); + Print("ProtoSpaceAndComments(scanner);"); + if (map_append) { + Print("if (!set_map_key || !set_map_value) return false;"); + Print("(*map)[map_key] = map_value;"); + } + Print("return true;"); + Unnest().Print("}"); + + Print("if (!nested && scanner->empty()) { return true; }"); + Print("scanner->RestartCapture()"); + Print(" .Many(Scanner::LETTER_DIGIT_UNDERSCORE)"); + Print(" .StopCapture();"); + Print("StringPiece identifier;"); + Print("if (!scanner->GetResult(nullptr, &identifier)) return false;"); + Print("bool parsed_colon = false;"); + Print("ProtoSpaceAndComments(scanner);"); + Print("if (scanner->Peek() == ':') {"); + Nest().Print("parsed_colon = true;"); + Print("scanner->One(Scanner::ALL);"); + Print("ProtoSpaceAndComments(scanner);"); + Unnest().Print("}"); + for (int i = 0; i < md.field_count(); ++i) { + const FieldDescriptor* field = md.field(i); + const string field_name = field->name(); + string mutable_value_expr; + string set_value_prefix; + if (map_append) { + mutable_value_expr = StrCat("&map_", field_name); + set_value_prefix = StrCat("map_", field_name, " = "); + } else if (field->is_repeated()) { + if (field->is_map()) { + mutable_value_expr = StrCat("msg->mutable_", field_name, "()"); + set_value_prefix = + "UNREACHABLE"; // generator will never use this value. + } else { + mutable_value_expr = StrCat("msg->add_", field_name, "()"); + set_value_prefix = StrCat("msg->add_", field_name); + } + } else { + mutable_value_expr = StrCat("msg->mutable_", field_name, "()"); + set_value_prefix = StrCat("msg->set_", field_name); + } + + Print(i == 0 ? "" : "else ", "if (identifier == \"", field_name, "\") {"); + Nest(); + + if (field->is_repeated()) { + CHECK(!map_append); + + // Check to see if this is an array assignment, like a: [1, 2, 3] + Print("const bool is_list = (scanner->Peek() == '[');"); + Print("do {"); + // [ or , // skip + Nest().Print("if (is_list) {"); + Nest().Print("scanner->One(Scanner::ALL);"); + Print("ProtoSpaceAndComments(scanner);"); + Unnest().Print("}"); + } else if (field->containing_oneof() != nullptr) { + CHECK(!map_append); + + // Detect duplicate oneof value. + const string oneof_name = field->containing_oneof()->name(); + Print("if (msg->", oneof_name, "_case() != 0) return false;"); + } + + if (!field->is_repeated() && !map_append) { + // Detect duplicate nested repeated message. + Print("if (has_seen[", i, "]) return false;"); + Print("has_seen[", i, "] = true;"); + } + if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + Print("const char open_char = scanner->Peek();"); + Print("if (open_char != '{' && open_char != '<') return false;"); + Print("scanner->One(Scanner::ALL);"); + Print("ProtoSpaceAndComments(scanner);"); + Print("if (!", GetPackageReferencePrefix(field->message_type()->file()), + "internal::ProtoParseFromScanner("); + Print(" scanner, true, open_char == '{', ", mutable_value_expr, + ")) return false;"); + } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_STRING) { + if (field->options().ctype() == FieldOptions::CORD) { + Print("string str_value;"); + Print( + "if (!parsed_colon || " + "!::tensorflow::strings::ProtoParseStringLiteralFromScanner("); + Print(" scanner, &str_value)) return false;"); + Print(mutable_value_expr, "->CopyFrom(str_value);"); + } else { + Print( + "if (!parsed_colon || " + "!::tensorflow::strings::ProtoParseStringLiteralFromScanner("); + Print(" scanner, ", mutable_value_expr, ")) return false;"); + } + } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { + Print("StringPiece value;"); + Print( + "if (!parsed_colon || " + "!scanner->RestartCapture().Many(" + "Scanner::LETTER_DIGIT_DASH_UNDERSCORE)." + "GetResult(nullptr, &value)) return false;"); + const auto* enum_d = field->enum_type(); + string value_prefix; + if (enum_d->containing_type() == nullptr) { + value_prefix = GetPackageReferencePrefix(enum_d->file()); + } else { + value_prefix = StrCat(GetQualifiedName(*enum_d), "_"); + } + + for (int enum_i = 0; enum_i < enum_d->value_count(); ++enum_i) { + const auto* value_d = enum_d->value(enum_i); + const string value_name = value_d->name(); + string condition = StrCat("value == \"", value_name, + "\" || value == \"", value_d->number(), "\""); + if (value_d->number() == 0) { + StrAppend(&condition, " || value == \"-0\""); + } + + Print(enum_i == 0 ? "" : "} else ", "if (", condition, ") {"); + Nest(); + Print(set_value_prefix, "(", value_prefix, value_name, ");"); + Unnest(); + } + Print("} else {").Nest().Print("return false;").Unnest().Print("}"); + } else { + Print(field->cpp_type_name(), " value;"); + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + case FieldDescriptor::CPPTYPE_INT64: + case FieldDescriptor::CPPTYPE_UINT32: + case FieldDescriptor::CPPTYPE_UINT64: + case FieldDescriptor::CPPTYPE_DOUBLE: + case FieldDescriptor::CPPTYPE_FLOAT: + Print( + "if (!parsed_colon || " + "!::tensorflow::strings::ProtoParseNumericFromScanner(", + "scanner, &value)) return false;"); + break; + case FieldDescriptor::CPPTYPE_BOOL: + Print( + "if (!parsed_colon || " + "!::tensorflow::strings::ProtoParseBoolFromScanner(", + "scanner, &value)) return false;"); + break; + default: + LOG(FATAL) << "handled earlier"; + } + Print(set_value_prefix, "(value);"); + } + + if (field->is_repeated()) { + Unnest().Print("} while (is_list && scanner->Peek() == ',');"); + Print( + "if (is_list && " + "!scanner->OneLiteral(\"]\").GetResult()) return false;"); + } + if (map_append) { + Print("set_map_", field_name, " = true;"); + } + Unnest().Print("}"); + } + Unnest().Print("}"); + Unnest().Print("}"); + Unnest().Print().Print("} // namespace internal"); +} + +void Generator::AppendDebugStringFunctions(const Descriptor& md) { + SetOutput(&header_impl_).Print(); + SetOutput(&header_).Print().Print("// Message-text conversion for ", + string(md.full_name())); + + // Append the two debug string functions for <md>. + for (int short_pass = 0; short_pass < 2; ++short_pass) { + const bool short_debug = (short_pass == 1); + + // Make the Get functions. + const string sig = StrCat( + "string ", short_debug ? "ProtoShortDebugString" : "ProtoDebugString", + "(\n const ", GetQualifiedName(md), "& msg)"); + SetOutput(&header_).Print(sig, ";"); + + SetOutput(&cc_); + Print().Print(sig, " {").Nest(); + Print("string s;"); + Print("::tensorflow::strings::ProtoTextOutput o(&s, ", + short_debug ? "true" : "false", ");"); + Print("internal::AppendProtoDebugString(&o, msg);"); + Print("o.CloseTopMessage();"); + Print("return s;"); + Unnest().Print("}"); + } + + // Make the Append function. + const string sig = + StrCat("void AppendProtoDebugString(\n", + " ::tensorflow::strings::ProtoTextOutput* o,\n const ", + GetQualifiedName(md), "& msg)"); + SetOutput(&header_impl_).Print(sig, ";"); + SetOutput(&cc_); + Print().Print("namespace internal {").Print(); + Print(sig, " {").Nest(); + for (int i = 0; i < md.field_count(); ++i) { + SetOutput(&cc_); + AppendFieldAppend(*md.field(i)); + } + Unnest().Print("}").Print().Print("} // namespace internal"); +} + +void Generator::AppendMessageFunctions(const Descriptor& md) { + if (md.options().map_entry()) { + // The 'map entry' Message is not a user-visible message type. Only its + // parse function is created (and that actually parsed the whole Map, not + // just the map entry). Printing of a map is done in the code generated for + // the containing message. + AppendParseMessageFunction(md); + return; + } + + AppendDebugStringFunctions(md); + AppendParseMessageFunction(md); + for (int i = 0; i < md.enum_type_count(); ++i) { + AppendEnumFunctions(*md.enum_type(i)); + } + for (int i = 0; i < md.nested_type_count(); ++i) { + AppendMessageFunctions(*md.nested_type(i)); + } +} + +void Generator::AddNamespaceToCurrentSection(const string& package, bool open) { + Print(); + const std::vector<string> parts = str_util::Split(package, '.'); + if (open) { + for (const auto& p : parts) { + Print("namespace ", p, " {"); + } + } else { + for (auto it = parts.rbegin(); it != parts.rend(); ++it) { + Print("} // namespace ", *it); + } + } +} + +void Generator::AddHeadersToCurrentSection(const std::vector<string>& headers) { + std::vector<string> sorted = headers; + std::sort(sorted.begin(), sorted.end()); + for (const auto& h : sorted) { + Print("#include \"", h, "\""); + } +} + +// Adds to <all_fd> and <all_d> with all descriptors recursively +// reachable from the given descriptor. +void GetAllFileDescriptorsFromFile(const FileDescriptor* fd, + std::set<const FileDescriptor*>* all_fd, + std::set<const Descriptor*>* all_d); + +// Adds to <all_fd> and <all_d> with all descriptors recursively +// reachable from the given descriptor. +void GetAllFileDescriptorsFromMessage(const Descriptor* d, + std::set<const FileDescriptor*>* all_fd, + std::set<const Descriptor*>* all_d) { + if (!all_d->insert(d).second) return; + GetAllFileDescriptorsFromFile(d->file(), all_fd, all_d); + for (int i = 0; i < d->field_count(); ++i) { + auto* f = d->field(i); + switch (f->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + case FieldDescriptor::CPPTYPE_INT64: + case FieldDescriptor::CPPTYPE_UINT32: + case FieldDescriptor::CPPTYPE_UINT64: + case FieldDescriptor::CPPTYPE_DOUBLE: + case FieldDescriptor::CPPTYPE_FLOAT: + case FieldDescriptor::CPPTYPE_BOOL: + case FieldDescriptor::CPPTYPE_STRING: + break; + case FieldDescriptor::CPPTYPE_MESSAGE: + GetAllFileDescriptorsFromMessage(f->message_type(), all_fd, all_d); + break; + case FieldDescriptor::CPPTYPE_ENUM: + GetAllFileDescriptorsFromFile(f->enum_type()->file(), all_fd, all_d); + break; + } + } + for (int i = 0; i < d->nested_type_count(); ++i) { + GetAllFileDescriptorsFromMessage(d->nested_type(i), all_fd, all_d); + } +} + +void GetAllFileDescriptorsFromFile(const FileDescriptor* fd, + std::set<const FileDescriptor*>* all_fd, + std::set<const Descriptor*>* all_d) { + if (!all_fd->insert(fd).second) return; + for (int i = 0; i < fd->message_type_count(); ++i) { + GetAllFileDescriptorsFromMessage(fd->message_type(i), all_fd, all_d); + } +} + +void Generator::Generate(const FileDescriptor& fd) { + // This does not emit code with proper proto2 semantics (e.g. it doesn't check + // 'has' fields on non-messages), so check that only proto3 is passed. + CHECK_EQ(fd.syntax(), FileDescriptor::SYNTAX_PROTO3) << fd.name(); + + const string package = fd.package(); + std::set<const FileDescriptor*> all_fd; + std::set<const Descriptor*> all_d; + GetAllFileDescriptorsFromFile(&fd, &all_fd, &all_d); + + std::vector<string> headers; + + // Add header to header file. + SetOutput(&header_); + Print("// GENERATED FILE - DO NOT MODIFY"); + Print("#ifndef ", GetHeaderGuard(fd, false /* impl */)); + Print("#define ", GetHeaderGuard(fd, false /* impl */)); + Print(); + headers = { + GetProtoHeaderName(fd), + StrCat(tf_header_prefix_, "tensorflow/core/platform/macros.h"), + StrCat(tf_header_prefix_, "tensorflow/core/platform/protobuf.h"), + StrCat(tf_header_prefix_, "tensorflow/core/platform/types.h"), + }; + for (const auto& h : headers) { + Print("#include \"", h, "\""); + } + AddNamespaceToCurrentSection(package, true /* is_open */); + + // Add header to impl file. + SetOutput(&header_impl_); + Print("// GENERATED FILE - DO NOT MODIFY"); + Print("#ifndef ", GetHeaderGuard(fd, true /* impl */)); + Print("#define ", GetHeaderGuard(fd, true /* impl */)); + Print(); + headers = { + GetProtoTextHeaderName(fd, false /* impl */), + StrCat(tf_header_prefix_, + "tensorflow/core/lib/strings/proto_text_util.h"), + StrCat(tf_header_prefix_, "tensorflow/core/lib/strings/scanner.h"), + }; + for (const FileDescriptor* d : all_fd) { + if (d != &fd) { + headers.push_back(GetProtoTextHeaderName(*d, true /* impl */)); + } + headers.push_back(GetProtoHeaderName(*d)); + } + for (const auto& h : headers) { + Print("#include \"", h, "\""); + } + AddNamespaceToCurrentSection(package, true /* is_open */); + SetOutput(&header_impl_).Print().Print("namespace internal {"); + + // Add header to cc file. + SetOutput(&cc_); + Print("// GENERATED FILE - DO NOT MODIFY"); + headers = {GetProtoTextHeaderName(fd, true /* impl */)}; + AddHeadersToCurrentSection(headers); + Print(); + Print("using ::tensorflow::strings::Scanner;"); + Print("using ::tensorflow::strings::StrCat;"); + AddNamespaceToCurrentSection(package, true /* is_open */); + + // Add declarations and definitions. + for (int i = 0; i < fd.enum_type_count(); ++i) { + AppendEnumFunctions(*fd.enum_type(i)); + } + for (int i = 0; i < fd.message_type_count(); ++i) { + AppendMessageFunctions(*fd.message_type(i)); + } + + // Add footer to header file. + SetOutput(&header_); + AddNamespaceToCurrentSection(package, false /* is_open */); + Print().Print("#endif // ", GetHeaderGuard(fd, false /* impl */)); + + // Add footer to header impl file. + SetOutput(&header_impl_).Print().Print("} // namespace internal"); + AddNamespaceToCurrentSection(package, false /* is_open */); + Print().Print("#endif // ", GetHeaderGuard(fd, true /* impl */)); + + // Add footer to cc file. + SetOutput(&cc_); + AddNamespaceToCurrentSection(package, false /* is_open */); +} + +} // namespace + +ProtoTextFunctionCode GetProtoTextFunctionCode(const FileDescriptor& fd, + const string& tf_header_prefix) { + Generator gen(tf_header_prefix); + gen.Generate(fd); + return gen.code(); +} + +} // namespace tensorflow diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h new file mode 100644 index 0000000000..f538275ec8 --- /dev/null +++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib.h @@ -0,0 +1,42 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_ + +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +struct ProtoTextFunctionCode { + string header; // for a file named proto_name + ".pb_text.h" + string header_impl; // for a file named proto_name + ".pb_text-impl.h" + string cc; // for a file named proto_name + ".pb_text.cc" +}; + +// Returns the generated source code for a proto file descriptor. +// +// <tf_header_prefix> is used as the prefix for #include paths, when including +// tensorflow library headers. +// +// Only works for proto3 messages. +ProtoTextFunctionCode GetProtoTextFunctionCode( + const tensorflow::protobuf::FileDescriptor& fd, + const string& tf_header_prefix); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_UTIL_CREATE_PROTO_DEBUG_STRING_LIB_H_ diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc b/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc new file mode 100644 index 0000000000..b491295967 --- /dev/null +++ b/tensorflow/tools/proto_text/gen_proto_text_functions_lib_test.cc @@ -0,0 +1,500 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h" + +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/tools/proto_text/test.pb.h" +#include "tensorflow/tools/proto_text/test.pb_text.h" + +namespace tensorflow { +namespace test { +namespace { + +// Convert <input> to text depending on <short_debug>, then parse that into a +// new message using the generated parse function. Return the new message. +template <typename T> +T RoundtripParseProtoOrDie(const T& input, bool short_debug) { + const string s = short_debug ? input.ShortDebugString() : input.DebugString(); + T t; + EXPECT_TRUE(ProtoParseFromString(s, &t)) << "Failed to parse " << s; + return t; +} + +// Macro that takes <proto> and verifies the proto text string output +// matches DebugString calls on the proto, and verifies parsing the +// DebugString output works. It does this for regular and short +// debug strings. +#define EXPECT_TEXT_TRANSFORMS_MATCH() \ + EXPECT_EQ(proto.DebugString(), ProtoDebugString(proto)); \ + EXPECT_EQ(proto.ShortDebugString(), ProtoShortDebugString(proto)); \ + EXPECT_EQ(proto.DebugString(), \ + RoundtripParseProtoOrDie(proto, true).DebugString()); \ + EXPECT_EQ(proto.DebugString(), \ + RoundtripParseProtoOrDie(proto, false).DebugString()); + +// Macro for failure cases. Verifies both protobuf and proto_text to +// make sure they match. +#define EXPECT_PARSE_FAILURE(str) \ + EXPECT_FALSE(ProtoParseFromString(str, &proto)); \ + EXPECT_FALSE(protobuf::TextFormat::ParseFromString(str, &proto)) + +// Macro for success cases parsing from a string. Verifies protobuf and +// proto_text cases match. +#define EXPECT_PARSE_SUCCESS(expected, str) \ + do { \ + EXPECT_TRUE(ProtoParseFromString(str, &proto)); \ + string proto_text_str = ProtoShortDebugString(proto); \ + EXPECT_TRUE(protobuf::TextFormat::ParseFromString(str, &proto)); \ + string protobuf_str = ProtoShortDebugString(proto); \ + EXPECT_EQ(proto_text_str, protobuf_str); \ + EXPECT_EQ(expected, proto_text_str); \ + } while (false) + +// Test different cases of numeric values, including repeated values. +TEST(CreateProtoDebugStringLibTest, ValidSimpleTypes) { + TestAllTypes proto; + + proto.Clear(); + proto.set_optional_int32(-1); + proto.set_optional_int64(-2); + proto.set_optional_uint32(3); + proto.set_optional_uint64(4); + proto.set_optional_sint32(-5); + proto.set_optional_sint64(-6); + proto.set_optional_fixed32(-7); + proto.set_optional_fixed64(-8); + proto.set_optional_sfixed32(-9); + proto.set_optional_sfixed64(-10); + proto.set_optional_float(-12.34); + proto.set_optional_double(-5.678); + proto.set_optional_bool(true); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // Max numeric values. + proto.Clear(); + proto.set_optional_int32(std::numeric_limits<int32>::max()); + proto.set_optional_int64(std::numeric_limits<int64>::max()); + proto.set_optional_uint32(std::numeric_limits<uint32>::max()); + proto.set_optional_uint64(std::numeric_limits<uint64>::max()); + proto.set_optional_float(std::numeric_limits<float>::max()); + proto.set_optional_double(std::numeric_limits<double>::max()); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // Least positive numeric values. + proto.Clear(); + proto.set_optional_float(std::numeric_limits<float>::min()); + proto.set_optional_double(std::numeric_limits<double>::min()); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // Lowest numeric values. + proto.Clear(); + proto.set_optional_int32(std::numeric_limits<int32>::lowest()); + proto.set_optional_int64(std::numeric_limits<int64>::lowest()); + proto.set_optional_float(std::numeric_limits<float>::lowest()); + proto.set_optional_double(std::numeric_limits<double>::lowest()); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // inf and -inf for float and double. + proto.Clear(); + proto.set_optional_double(std::numeric_limits<double>::infinity()); + proto.set_optional_float(std::numeric_limits<float>::infinity()); + EXPECT_TEXT_TRANSFORMS_MATCH(); + proto.set_optional_double(-1 * std::numeric_limits<double>::infinity()); + proto.set_optional_float(-1 * std::numeric_limits<float>::infinity()); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // String and bytes values. + proto.Clear(); + for (int i = 0; i < 256; ++i) { + proto.mutable_optional_string()->push_back(static_cast<char>(i)); + proto.mutable_optional_bytes()->push_back(static_cast<char>(i)); + } + strings::StrAppend(proto.mutable_optional_string(), "¢€𐍈"); + proto.set_optional_cord(proto.optional_string()); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // Repeated values. Include zero values to show they are retained in + // repeateds. + proto.Clear(); + proto.add_repeated_int32(-1); + proto.add_repeated_int32(0); + proto.add_repeated_int64(0); + proto.add_repeated_int64(1); + proto.add_repeated_uint32(-10); + proto.add_repeated_uint32(0); + proto.add_repeated_uint32(10); + proto.add_repeated_uint64(-20); + proto.add_repeated_uint64(0); + proto.add_repeated_uint64(20); + proto.add_repeated_sint32(-30); + proto.add_repeated_sint32(0); + proto.add_repeated_sint32(30); + proto.add_repeated_sint64(-40); + proto.add_repeated_sint64(0); + proto.add_repeated_sint64(40); + proto.add_repeated_fixed32(-50); + proto.add_repeated_fixed32(0); + proto.add_repeated_fixed32(50); + proto.add_repeated_fixed64(-60); + proto.add_repeated_fixed64(0); + proto.add_repeated_fixed64(60); + proto.add_repeated_sfixed32(-70); + proto.add_repeated_sfixed32(0); + proto.add_repeated_sfixed32(70); + proto.add_repeated_sfixed64(-80); + proto.add_repeated_sfixed64(0); + proto.add_repeated_sfixed64(80); + proto.add_repeated_float(-1.2345); + proto.add_repeated_float(0); + proto.add_repeated_float(-2.3456); + proto.add_repeated_double(-10.2345); + proto.add_repeated_double(0); + proto.add_repeated_double(-20.3456); + proto.add_repeated_bool(false); + proto.add_repeated_bool(true); + proto.add_repeated_bool(false); + proto.add_repeated_string("abc"); + proto.add_repeated_string(""); + proto.add_repeated_string("def"); + proto.add_repeated_cord("abc"); + proto.add_repeated_cord(""); + proto.add_repeated_cord("def"); + proto.add_packed_repeated_int64(-1000); + proto.add_packed_repeated_int64(0); + proto.add_packed_repeated_int64(1000); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // Proto supports [] for list values as well. + EXPECT_PARSE_SUCCESS("repeated_int32: 1 repeated_int32: 2 repeated_int32: 3", + "repeated_int32: [1, 2 , 3]"); + + // Test [] and also interesting bool values. + EXPECT_PARSE_SUCCESS(("repeated_bool: false repeated_bool: false " + "repeated_bool: true repeated_bool: true " + "repeated_bool: false repeated_bool: true"), + "repeated_bool: [false, 0, 1, true, False, True]"); + + EXPECT_PARSE_SUCCESS(("repeated_string: \"a,b\" " + "repeated_string: \"cdef\""), + "repeated_string: [ 'a,b', 'cdef' ] "); + + // Proto supports ' as quote character. + EXPECT_PARSE_SUCCESS("optional_string: \"123\\\" \\'xyz\"", + "optional_string: '123\\\" \\'xyz' "); + + EXPECT_PARSE_SUCCESS("optional_double: 10000", "optional_double: 1e4"); + + // Error cases. + EXPECT_PARSE_FAILURE("optional_string: '1' optional_string: '2'"); + EXPECT_PARSE_FAILURE("optional_double: 123 optional_double: 456"); + EXPECT_PARSE_FAILURE("optional_double: 0001"); + EXPECT_PARSE_FAILURE("optional_double: 000.1"); + EXPECT_PARSE_FAILURE("optional_double: a"); + EXPECT_PARSE_FAILURE("optional_double: x123"); + EXPECT_PARSE_FAILURE("optional_double: '123'"); + EXPECT_PARSE_FAILURE("optional_double: --111"); + EXPECT_PARSE_FAILURE("optional_string: 'abc\""); + EXPECT_PARSE_FAILURE("optional_bool: truE"); + EXPECT_PARSE_FAILURE("optional_bool: FALSE"); +} + +TEST(CreateProtoDebugStringLibTest, NestedMessages) { + TestAllTypes proto; + + proto.Clear(); + // Test empty message. + EXPECT_TEXT_TRANSFORMS_MATCH(); + + proto.Clear(); + proto.mutable_optional_nested_message(); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + proto.Clear(); + proto.mutable_optional_foreign_message(); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // Empty messages. + proto.Clear(); + proto.mutable_optional_nested_message(); + proto.mutable_optional_foreign_message(); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + proto.Clear(); + proto.mutable_optional_nested_message()->set_optional_int32(1); + proto.mutable_optional_foreign_message()->set_c(-1234); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + proto.Clear(); + proto.mutable_optional_nested_message()->set_optional_int32(1234); + proto.mutable_optional_nested_message() + ->mutable_msg(); // empty double-nested + EXPECT_TEXT_TRANSFORMS_MATCH(); + + proto.Clear(); + proto.mutable_optional_nested_message()->set_optional_int32(1234); + proto.mutable_optional_nested_message()->mutable_msg()->set_optional_string( + "abc"); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + proto.Clear(); + proto.mutable_optional_nested_message()->mutable_msg()->set_optional_string( + "abc"); + proto.mutable_optional_nested_message()->set_optional_int64(1234); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + proto.Clear(); + auto* nested = proto.add_repeated_nested_message(); + nested = proto.add_repeated_nested_message(); + nested->set_optional_int32(123); + nested->mutable_msg(); + nested = proto.add_repeated_nested_message(); + nested->mutable_msg(); + nested->mutable_msg()->set_optional_string("abc"); + nested->set_optional_int64(1234); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // text format allows use of <> for messages. + EXPECT_PARSE_SUCCESS("optional_nested_message { optional_int32: 123 }", + "optional_nested_message: < optional_int32: 123 >"); + + // <> and {} must use same style for closing. + EXPECT_PARSE_FAILURE("optional_nested_message: < optional_int32: 123 }"); + EXPECT_PARSE_FAILURE("optional_nested_message: { optional_int32: 123 >"); + + // colon after identifier is optional for messages. + EXPECT_PARSE_SUCCESS("optional_nested_message { optional_int32: 123 }", + "optional_nested_message < optional_int32: 123 >"); + + EXPECT_PARSE_SUCCESS("optional_nested_message { optional_int32: 123 }", + "optional_nested_message{ optional_int32: 123 } "); + + // Proto supports [] for list values as well. + EXPECT_PARSE_SUCCESS( + ("repeated_nested_message { } " + "repeated_nested_message { optional_int32: 123 }"), + "repeated_nested_message: [ { }, { optional_int32: 123 } ]"); + + // Colon after repeated_nested_message is optional. + EXPECT_PARSE_SUCCESS( + ("repeated_nested_message { } " + "repeated_nested_message { optional_int32: 123 }"), + "repeated_nested_message [ { }, { optional_int32: 123 } ]"); + + // Using the list format a:[..] twice, like a:[..] a:[..] joins the two + // arrays. + EXPECT_PARSE_SUCCESS( + ("repeated_nested_message { } " + "repeated_nested_message { optional_int32: 123 } " + "repeated_nested_message { optional_int32: 456 }"), + ("repeated_nested_message [ { }, { optional_int32: 123 } ]" + "repeated_nested_message [ { optional_int32: 456 } ]")); + + // Parse errors on nested messages. + EXPECT_PARSE_FAILURE("optional_nested_message: {optional_int32: 'abc' }"); + + // Optional_nested_message appearing twice is an error. + EXPECT_PARSE_FAILURE( + ("optional_nested_message { optional_int32: 123 } " + "optional_nested_message { optional_int64: 456 }")); +} + +TEST(CreateProtoDebugStringLibTest, RecursiveMessage) { + NestedTestAllTypes proto; + + NestedTestAllTypes* cur = &proto; + for (int depth = 0; depth < 20; ++depth) { + cur->mutable_payload()->set_optional_int32(1000 + depth); + cur = cur->mutable_child(); + } + EXPECT_TEXT_TRANSFORMS_MATCH(); +} + +template <typename T> +T ParseProto(const string& value_text_proto) { + T value; + EXPECT_TRUE(protobuf::TextFormat::ParseFromString(value_text_proto, &value)) + << value_text_proto; + return value; +} + +TestAllTypes::NestedMessage ParseNestedMessage(const string& value_text_proto) { + return ParseProto<TestAllTypes::NestedMessage>(value_text_proto); +} + +TEST(CreateProtoDebugStringLibTest, Map) { + TestAllTypes proto; + + std::vector<TestAllTypes::NestedMessage> msg_values; + msg_values.push_back(ParseNestedMessage("optional_int32: 345")); + msg_values.push_back(ParseNestedMessage("optional_int32: 123")); + msg_values.push_back(ParseNestedMessage("optional_int32: 234")); + msg_values.push_back(ParseNestedMessage("optional_int32: 0")); + + // string->message map + proto.Clear(); + { + auto& map = *proto.mutable_map_string_to_message(); + map["def"] = msg_values[0]; + map["abc"] = msg_values[1]; + map["cde"] = msg_values[2]; + map[""] = msg_values[3]; + } + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // int32->message map. + proto.Clear(); + { + auto& map = *proto.mutable_map_int32_to_message(); + map[20] = msg_values[0]; + map[10] = msg_values[1]; + map[15] = msg_values[2]; + map[0] = msg_values[3]; + } + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // int64->message map. + proto.Clear(); + { + auto& map = *proto.mutable_map_int64_to_message(); + map[20] = msg_values[0]; + map[10] = msg_values[1]; + map[15] = msg_values[2]; + map[0] = msg_values[3]; + } + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // bool->message map. + proto.Clear(); + { + auto& map = *proto.mutable_map_int64_to_message(); + map[true] = msg_values[0]; + map[false] = msg_values[1]; + } + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // string->int64 map. + proto.Clear(); + { + auto& map = *proto.mutable_map_string_to_int64(); + map["def"] = 0; + map["abc"] = std::numeric_limits<int64>::max(); + map[""] = 20; + } + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // int64->string map. + proto.Clear(); + { + auto& map = *proto.mutable_map_int64_to_string(); + map[0] = "def"; + map[std::numeric_limits<int64>::max()] = ""; + map[20] = "abc"; + } + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // Test a map with the same key multiple times. + EXPECT_PARSE_SUCCESS(("map_string_to_int64 { key: \"abc\" value: 5 } " + "map_string_to_int64 { key: \"def\" value: 2 } " + "map_string_to_int64 { key: \"ghi\" value: 4 }"), + ("map_string_to_int64: { key: 'abc' value: 1 } " + "map_string_to_int64: { key: 'def' value: 2 } " + "map_string_to_int64: { key: 'ghi' value: 3 } " + "map_string_to_int64: { key: 'ghi' value: 4 } " + "map_string_to_int64: { key: 'abc' value: 5 } ")); +} + +TEST(CreateProtoDebugStringLibTest, Enums) { + TestAllTypes proto; + + proto.Clear(); + proto.set_optional_nested_enum(TestAllTypes::ZERO); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + proto.Clear(); + proto.set_optional_nested_enum(TestAllTypes::FOO); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + proto.Clear(); + proto.add_repeated_nested_enum(TestAllTypes::FOO); + proto.add_repeated_nested_enum(TestAllTypes::ZERO); + proto.add_repeated_nested_enum(TestAllTypes::BAR); + proto.add_repeated_nested_enum(TestAllTypes::NEG); + proto.add_repeated_nested_enum(TestAllTypes::ZERO); + proto.set_optional_foreign_enum(ForeignEnum::FOREIGN_BAR); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // Parsing from numbers works as well. + EXPECT_PARSE_SUCCESS( + "optional_nested_enum: BAR " // 2 + "repeated_nested_enum: BAR " // 2 + "repeated_nested_enum: ZERO " // 0 + "repeated_nested_enum: FOO", // 1 + ("repeated_nested_enum: 2 " + "repeated_nested_enum: 0 " + "optional_nested_enum: 2 " + "repeated_nested_enum: 1")); + + EXPECT_PARSE_SUCCESS("", "optional_nested_enum: -0"); + EXPECT_PARSE_FAILURE("optional_nested_enum: 6"); + EXPECT_PARSE_FAILURE("optional_nested_enum: BARNONE"); + EXPECT_PARSE_FAILURE("optional_nested_enum: 'BAR'"); + EXPECT_PARSE_FAILURE("optional_nested_enum: \"BAR\" "); + + EXPECT_EQ(string("BAR"), + string(EnumName_TestAllTypes_NestedEnum(TestAllTypes::BAR))); + // out of range - returns empty string (see NameOfEnum in proto library). + EXPECT_EQ(string(""), string(EnumName_TestAllTypes_NestedEnum( + static_cast<TestAllTypes_NestedEnum>(123)))); +} + +TEST(CreateProtoDebugStringLibTest, Oneof) { + TestAllTypes proto; + + proto.Clear(); + proto.set_oneof_string("abc"); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + proto.Clear(); + proto.set_oneof_string("abc"); + proto.set_oneof_uint32(123); + EXPECT_TEXT_TRANSFORMS_MATCH(); + + // Parse a text format that lists multiple members of the oneof. + EXPECT_PARSE_FAILURE("oneof_string: \"abc\" oneof_uint32: 13 "); + EXPECT_PARSE_FAILURE("oneof_string: \"abc\" oneof_string: \"def\" "); +} + +TEST(CreateProtoDebugStringLibTest, Comments) { + TestAllTypes proto; + + EXPECT_PARSE_SUCCESS("optional_int64: 123 optional_string: \"#text\"", + ("#leading comment \n" + "optional_int64# comment\n" + ":# comment\n" + "123# comment\n" + "optional_string # comment\n" + ": # comment\n" + "\"#text\"#comment####\n")); + + EXPECT_PARSE_FAILURE("optional_int64:// not a valid comment\n123"); + EXPECT_PARSE_FAILURE("optional_int64:/* not a valid comment */\n123"); +} + +} // namespace +} // namespace test +} // namespace tensorflow diff --git a/tensorflow/tools/proto_text/test.proto b/tensorflow/tools/proto_text/test.proto new file mode 100644 index 0000000000..9dd83e6393 --- /dev/null +++ b/tensorflow/tools/proto_text/test.proto @@ -0,0 +1,107 @@ +syntax = "proto3"; + +package tensorflow.test; + +message TestAllTypes { + message NestedMessage { + message DoubleNestedMessage { + string optional_string = 1; + } + + int32 optional_int32 = 1; + repeated int32 repeated_int32 = 2; + DoubleNestedMessage msg = 3; + int64 optional_int64 = 4; + } + + enum NestedEnum { + ZERO = 0; + FOO = 1; + BAR = 2; + BAZ = 3; + NEG = -1; // Intentionally negative. + } + + // Singular + int32 optional_int32 = 1; + int64 optional_int64 = 2; + uint32 optional_uint32 = 3; + uint64 optional_uint64 = 4; + sint32 optional_sint32 = 5; + sint64 optional_sint64 = 6; + fixed32 optional_fixed32 = 7; + fixed64 optional_fixed64 = 8; + sfixed32 optional_sfixed32 = 9; + sfixed64 optional_sfixed64 = 10; + float optional_float = 11; + double optional_double = 12; + bool optional_bool = 13; + string optional_string = 14; + bytes optional_bytes = 15; + + NestedMessage optional_nested_message = 18; + ForeignMessage optional_foreign_message = 19; + + NestedEnum optional_nested_enum = 21; + ForeignEnum optional_foreign_enum = 22; + + string optional_cord = 25; + + // Repeated + repeated int32 repeated_int32 = 31; + repeated int64 repeated_int64 = 32; + repeated uint32 repeated_uint32 = 33; + repeated uint64 repeated_uint64 = 34; + repeated sint32 repeated_sint32 = 35; + repeated sint64 repeated_sint64 = 36; + repeated fixed32 repeated_fixed32 = 37; + repeated fixed64 repeated_fixed64 = 38; + repeated sfixed32 repeated_sfixed32 = 39; + repeated sfixed64 repeated_sfixed64 = 40; + repeated float repeated_float = 41; + repeated double repeated_double = 42; + repeated bool repeated_bool = 43; + repeated string repeated_string = 44; + repeated bytes repeated_bytes = 45; + + repeated NestedMessage repeated_nested_message = 48; + repeated NestedEnum repeated_nested_enum = 51; + + repeated string repeated_cord = 55; + + oneof oneof_field { + uint32 oneof_uint32 = 111; + NestedMessage oneof_nested_message = 112; + string oneof_string = 113; + bytes oneof_bytes = 114; + } + + map<string, NestedMessage> map_string_to_message = 58; + map<int32, NestedMessage> map_int32_to_message = 59; + map<int64, NestedMessage> map_int64_to_message = 60; + map<bool, NestedMessage> map_bool_to_message = 61; + map<string, int64> map_string_to_int64 = 62; + map<int64, string> map_int64_to_string = 63; + + repeated int64 packed_repeated_int64 = 64 [packed = true]; +} + +// A recursive message. +message NestedTestAllTypes { + NestedTestAllTypes child = 1; + TestAllTypes payload = 2; +} + +message ForeignMessage { + int32 c = 1; +} + +enum ForeignEnum { + FOREIGN_ZERO = 0; + FOREIGN_FOO = 4; + FOREIGN_BAR = 5; + FOREIGN_BAZ = 6; +} + +message TestEmptyMessage { +} |