diff options
Diffstat (limited to 'src/google/protobuf/wire_format.cc')
-rw-r--r-- | src/google/protobuf/wire_format.cc | 309 |
1 files changed, 306 insertions, 3 deletions
diff --git a/src/google/protobuf/wire_format.cc b/src/google/protobuf/wire_format.cc index bd5b2489..01704c94 100644 --- a/src/google/protobuf/wire_format.cc +++ b/src/google/protobuf/wire_format.cc @@ -54,9 +54,17 @@ namespace google { +const size_t kMapEntryTagByteSize = 2; + namespace protobuf { namespace internal { +// Forward declare static functions +static size_t MapKeyDataOnlyByteSize(const FieldDescriptor* field, + const MapKey& value); +static size_t MapValueRefDataOnlyByteSize(const FieldDescriptor* field, + const MapValueRef& value); + // =================================================================== bool UnknownFieldSetFieldSkipper::SkipField( @@ -76,6 +84,8 @@ void UnknownFieldSetFieldSkipper::SkipUnknownEnum( bool WireFormat::SkipField(io::CodedInputStream* input, uint32 tag, UnknownFieldSet* unknown_fields) { int number = WireFormatLite::GetTagFieldNumber(tag); + // Field number 0 is illegal. + if (number == 0) return false; switch (WireFormatLite::GetTagWireType(tag)) { case WireFormatLite::WIRETYPE_VARINT: { @@ -795,7 +805,16 @@ void WireFormat::SerializeWithCachedSizes( int expected_endpoint = output->ByteCount() + size; std::vector<const FieldDescriptor*> fields; - message_reflection->ListFields(message, &fields); + + // Fields of map entry should always be serialized. + if (descriptor->options().map_entry()) { + for (int i = 0; i < descriptor->field_count(); i++) { + fields.push_back(descriptor->field(i)); + } + } else { + message_reflection->ListFields(message, &fields); + } + for (int i = 0; i < fields.size(); i++) { SerializeFieldWithCachedSizes(fields[i], message, output); } @@ -814,6 +833,129 @@ void WireFormat::SerializeWithCachedSizes( "during serialization?"; } +static void SerializeMapKeyWithCachedSizes(const FieldDescriptor* field, + const MapKey& value, + io::CodedOutputStream* output) { + switch (field->type()) { + case FieldDescriptor::TYPE_DOUBLE: + case FieldDescriptor::TYPE_FLOAT: + case FieldDescriptor::TYPE_GROUP: + case FieldDescriptor::TYPE_MESSAGE: + case FieldDescriptor::TYPE_BYTES: + case FieldDescriptor::TYPE_ENUM: + GOOGLE_LOG(FATAL) << "Unsupported"; + break; +#define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \ + case FieldDescriptor::TYPE_##FieldType: \ + WireFormatLite::Write##CamelFieldType(1, value.Get##CamelCppType##Value(), \ + output); \ + break; + CASE_TYPE(INT64, Int64, Int64) + CASE_TYPE(UINT64, UInt64, UInt64) + CASE_TYPE(INT32, Int32, Int32) + CASE_TYPE(FIXED64, Fixed64, UInt64) + CASE_TYPE(FIXED32, Fixed32, UInt32) + CASE_TYPE(BOOL, Bool, Bool) + CASE_TYPE(UINT32, UInt32, UInt32) + CASE_TYPE(SFIXED32, SFixed32, Int32) + CASE_TYPE(SFIXED64, SFixed64, Int64) + CASE_TYPE(SINT32, SInt32, Int32) + CASE_TYPE(SINT64, SInt64, Int64) + CASE_TYPE(STRING, String, String) +#undef CASE_TYPE + } +} + +static void SerializeMapValueRefWithCachedSizes(const FieldDescriptor* field, + const MapValueRef& value, + io::CodedOutputStream* output) { + switch (field->type()) { +#define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \ + case FieldDescriptor::TYPE_##FieldType: \ + WireFormatLite::Write##CamelFieldType(2, value.Get##CamelCppType##Value(), \ + output); \ + break; + CASE_TYPE(INT64, Int64, Int64) + CASE_TYPE(UINT64, UInt64, UInt64) + CASE_TYPE(INT32, Int32, Int32) + CASE_TYPE(FIXED64, Fixed64, UInt64) + CASE_TYPE(FIXED32, Fixed32, UInt32) + CASE_TYPE(BOOL, Bool, Bool) + CASE_TYPE(UINT32, UInt32, UInt32) + CASE_TYPE(SFIXED32, SFixed32, Int32) + CASE_TYPE(SFIXED64, SFixed64, Int64) + CASE_TYPE(SINT32, SInt32, Int32) + CASE_TYPE(SINT64, SInt64, Int64) + CASE_TYPE(ENUM, Enum, Enum) + CASE_TYPE(DOUBLE, Double, Double) + CASE_TYPE(FLOAT, Float, Float) + CASE_TYPE(STRING, String, String) + CASE_TYPE(BYTES, Bytes, String) + CASE_TYPE(MESSAGE, Message, Message) + CASE_TYPE(GROUP, Group, Message) +#undef CASE_TYPE + } +} + +class MapKeySorter { + public: + static std::vector<MapKey> SortKey(const Message& message, + const Reflection* reflection, + const FieldDescriptor* field) { + std::vector<MapKey> sorted_key_list; + for (MapIterator it = + reflection->MapBegin(const_cast<Message*>(&message), field); + it != reflection->MapEnd(const_cast<Message*>(&message), field); + ++it) { + sorted_key_list.push_back(it.GetKey()); + } + MapKeyComparator comparator; + std::sort(sorted_key_list.begin(), sorted_key_list.end(), comparator); + return sorted_key_list; + } + + private: + class MapKeyComparator { + public: + bool operator()(const MapKey& a, const MapKey& b) const { + GOOGLE_DCHECK(a.type() == b.type()); + switch (a.type()) { +#define CASE_TYPE(CppType, CamelCppType) \ + case FieldDescriptor::CPPTYPE_##CppType: { \ + return a.Get##CamelCppType##Value() < b.Get##CamelCppType##Value(); \ + } + CASE_TYPE(STRING, String) + CASE_TYPE(INT64, Int64) + CASE_TYPE(INT32, Int32) + CASE_TYPE(UINT64, UInt64) + CASE_TYPE(UINT32, UInt32) + CASE_TYPE(BOOL, Bool) +#undef CASE_TYPE + + default: + GOOGLE_LOG(DFATAL) << "Invalid key for map field."; + return true; + } + } + }; +}; + +static void SerializeMapEntry(const FieldDescriptor* field, const MapKey& key, + const MapValueRef& value, + io::CodedOutputStream* output) { + const FieldDescriptor* key_field = field->message_type()->field(0); + const FieldDescriptor* value_field = field->message_type()->field(1); + + WireFormatLite::WriteTag(field->number(), + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output); + size_t size = kMapEntryTagByteSize; + size += MapKeyDataOnlyByteSize(key_field, key); + size += MapValueRefDataOnlyByteSize(value_field, value); + output->WriteVarint32(size); + SerializeMapKeyWithCachedSizes(key_field, key, output); + SerializeMapValueRefWithCachedSizes(value_field, value, output); +} + void WireFormat::SerializeFieldWithCachedSizes( const FieldDescriptor* field, const Message& message, @@ -828,10 +970,55 @@ void WireFormat::SerializeFieldWithCachedSizes( return; } + // For map fields, we can use either repeated field reflection or map + // reflection. Our choice has some subtle effects. If we use repeated field + // reflection here, then the repeated field representation becomes + // authoritative for this field: any existing references that came from map + // reflection remain valid for reading, but mutations to them are lost and + // will be overwritten next time we call map reflection! + // + // So far this mainly affects Python, which keeps long-term references to map + // values around, and always uses map reflection. See: b/35918691 + // + // Here we choose to use map reflection API as long as the internal + // map is valid. In this way, the serialization doesn't change map field's + // internal state and existing references that came from map reflection remain + // valid for both reading and writing. + if (field->is_map()) { + MapFieldBase* map_field = + message_reflection->MapData(const_cast<Message*>(&message), field); + if (map_field->IsMapValid()) { + if (output->IsSerializationDeterministic()) { + std::vector<MapKey> sorted_key_list = + MapKeySorter::SortKey(message, message_reflection, field); + for (std::vector<MapKey>::iterator it = sorted_key_list.begin(); + it != sorted_key_list.end(); ++it) { + MapValueRef map_value; + message_reflection->InsertOrLookupMapValue( + const_cast<Message*>(&message), field, *it, &map_value); + SerializeMapEntry(field, *it, map_value, output); + } + } else { + for (MapIterator it = message_reflection->MapBegin( + const_cast<Message*>(&message), field); + it != + message_reflection->MapEnd(const_cast<Message*>(&message), field); + ++it) { + SerializeMapEntry(field, it.GetKey(), it.GetValueRef(), output); + } + } + + return; + } + } + int count = 0; if (field->is_repeated()) { count = message_reflection->FieldSize(message, field); + } else if (field->containing_type()->options().map_entry()) { + // Map entry fields always need to be serialized. + count = 1; } else if (message_reflection->HasField(message, field)) { count = 1; } @@ -982,7 +1169,16 @@ size_t WireFormat::ByteSize(const Message& message) { size_t our_size = 0; std::vector<const FieldDescriptor*> fields; - message_reflection->ListFields(message, &fields); + + // Fields of map entry should always be serialized. + if (descriptor->options().map_entry()) { + for (int i = 0; i < descriptor->field_count(); i++) { + fields.push_back(descriptor->field(i)); + } + } else { + message_reflection->ListFields(message, &fields); + } + for (int i = 0; i < fields.size(); i++) { our_size += FieldByteSize(fields[i], message); } @@ -1013,6 +1209,9 @@ size_t WireFormat::FieldByteSize( size_t count = 0; if (field->is_repeated()) { count = FromIntSize(message_reflection->FieldSize(message, field)); + } else if (field->containing_type()->options().map_entry()) { + // Map entry fields always need to be serialized. + count = 1; } else if (message_reflection->HasField(message, field)) { count = 1; } @@ -1033,20 +1232,124 @@ size_t WireFormat::FieldByteSize( return our_size; } +static size_t MapKeyDataOnlyByteSize(const FieldDescriptor* field, + const MapKey& value) { + GOOGLE_DCHECK_EQ(FieldDescriptor::TypeToCppType(field->type()), value.type()); + switch (field->type()) { + case FieldDescriptor::TYPE_DOUBLE: + case FieldDescriptor::TYPE_FLOAT: + case FieldDescriptor::TYPE_GROUP: + case FieldDescriptor::TYPE_MESSAGE: + case FieldDescriptor::TYPE_BYTES: + case FieldDescriptor::TYPE_ENUM: + GOOGLE_LOG(FATAL) << "Unsupported"; + return 0; +#define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \ + case FieldDescriptor::TYPE_##FieldType: \ + return WireFormatLite::CamelFieldType##Size( \ + value.Get##CamelCppType##Value()); + +#define FIXED_CASE_TYPE(FieldType, CamelFieldType) \ + case FieldDescriptor::TYPE_##FieldType: \ + return WireFormatLite::k##CamelFieldType##Size; + + CASE_TYPE(INT32, Int32, Int32); + CASE_TYPE(INT64, Int64, Int64); + CASE_TYPE(UINT32, UInt32, UInt32); + CASE_TYPE(UINT64, UInt64, UInt64); + CASE_TYPE(SINT32, SInt32, Int32); + CASE_TYPE(SINT64, SInt64, Int64); + CASE_TYPE(STRING, String, String); + FIXED_CASE_TYPE(FIXED32, Fixed32); + FIXED_CASE_TYPE(FIXED64, Fixed64); + FIXED_CASE_TYPE(SFIXED32, SFixed32); + FIXED_CASE_TYPE(SFIXED64, SFixed64); + FIXED_CASE_TYPE(BOOL, Bool); + +#undef CASE_TYPE +#undef FIXED_CASE_TYPE + } + GOOGLE_LOG(FATAL) << "Cannot get here"; + return 0; +} + +static size_t MapValueRefDataOnlyByteSize(const FieldDescriptor* field, + const MapValueRef& value) { + switch (field->type()) { + case FieldDescriptor::TYPE_GROUP: + GOOGLE_LOG(FATAL) << "Unsupported"; + return 0; +#define CASE_TYPE(FieldType, CamelFieldType, CamelCppType) \ + case FieldDescriptor::TYPE_##FieldType: \ + return WireFormatLite::CamelFieldType##Size( \ + value.Get##CamelCppType##Value()); + +#define FIXED_CASE_TYPE(FieldType, CamelFieldType) \ + case FieldDescriptor::TYPE_##FieldType: \ + return WireFormatLite::k##CamelFieldType##Size; + + CASE_TYPE(INT32, Int32, Int32); + CASE_TYPE(INT64, Int64, Int64); + CASE_TYPE(UINT32, UInt32, UInt32); + CASE_TYPE(UINT64, UInt64, UInt64); + CASE_TYPE(SINT32, SInt32, Int32); + CASE_TYPE(SINT64, SInt64, Int64); + CASE_TYPE(STRING, String, String); + CASE_TYPE(BYTES, Bytes, String); + CASE_TYPE(ENUM, Enum, Enum); + CASE_TYPE(MESSAGE, Message, Message); + FIXED_CASE_TYPE(FIXED32, Fixed32); + FIXED_CASE_TYPE(FIXED64, Fixed64); + FIXED_CASE_TYPE(SFIXED32, SFixed32); + FIXED_CASE_TYPE(SFIXED64, SFixed64); + FIXED_CASE_TYPE(DOUBLE, Double); + FIXED_CASE_TYPE(FLOAT, Float); + FIXED_CASE_TYPE(BOOL, Bool); + +#undef CASE_TYPE +#undef FIXED_CASE_TYPE + } + GOOGLE_LOG(FATAL) << "Cannot get here"; + return 0; +} + size_t WireFormat::FieldDataOnlyByteSize( const FieldDescriptor* field, const Message& message) { const Reflection* message_reflection = message.GetReflection(); + size_t data_size = 0; + + if (field->is_map()) { + MapFieldBase* map_field = + message_reflection->MapData(const_cast<Message*>(&message), field); + if (map_field->IsMapValid()) { + MapIterator iter(const_cast<Message*>(&message), field); + MapIterator end(const_cast<Message*>(&message), field); + const FieldDescriptor* key_field = field->message_type()->field(0); + const FieldDescriptor* value_field = field->message_type()->field(1); + for (map_field->MapBegin(&iter), map_field->MapEnd(&end); iter != end; + ++iter) { + size_t size = kMapEntryTagByteSize; + size += MapKeyDataOnlyByteSize(key_field, iter.GetKey()); + size += MapValueRefDataOnlyByteSize(value_field, iter.GetValueRef()); + data_size += WireFormatLite::LengthDelimitedSize(size); + } + return data_size; + } + } + size_t count = 0; if (field->is_repeated()) { count = internal::FromIntSize(message_reflection->FieldSize(message, field)); + } else if (field->containing_type()->options().map_entry()) { + // Map entry fields always need to be serialized. + count = 1; } else if (message_reflection->HasField(message, field)) { count = 1; } - size_t data_size = 0; switch (field->type()) { #define HANDLE_TYPE(TYPE, TYPE_METHOD, CPPTYPE_METHOD) \ case FieldDescriptor::TYPE_##TYPE: \ |