aboutsummaryrefslogtreecommitdiffhomepage
path: root/vendor/github.com/golang/protobuf/proto/table_marshal.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/golang/protobuf/proto/table_marshal.go')
-rw-r--r--vendor/github.com/golang/protobuf/proto/table_marshal.go229
1 files changed, 162 insertions, 67 deletions
diff --git a/vendor/github.com/golang/protobuf/proto/table_marshal.go b/vendor/github.com/golang/protobuf/proto/table_marshal.go
index 0f212b3..5cb11fa 100644
--- a/vendor/github.com/golang/protobuf/proto/table_marshal.go
+++ b/vendor/github.com/golang/protobuf/proto/table_marshal.go
@@ -87,6 +87,7 @@ type marshalElemInfo struct {
sizer sizer
marshaler marshaler
isptr bool // elem is pointer typed, thus interface of this type is a direct interface (extension only)
+ deref bool // dereference the pointer before operating on it; implies isptr
}
var (
@@ -231,7 +232,7 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
return b, err
}
- var err, errreq error
+ var err, errLater error
// The old marshaler encodes extensions at beginning.
if u.extensions.IsValid() {
e := ptr.offset(u.extensions).toExtensions()
@@ -252,11 +253,13 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
}
}
for _, f := range u.fields {
- if f.required && errreq == nil {
+ if f.required {
if ptr.offset(f.field).getPointer().isNil() {
// Required field is not set.
// We record the error but keep going, to give a complete marshaling.
- errreq = &RequiredNotSetError{f.name}
+ if errLater == nil {
+ errLater = &RequiredNotSetError{f.name}
+ }
continue
}
}
@@ -269,14 +272,21 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
if err1, ok := err.(*RequiredNotSetError); ok {
// Required field in submessage is not set.
// We record the error but keep going, to give a complete marshaling.
- if errreq == nil {
- errreq = &RequiredNotSetError{f.name + "." + err1.field}
+ if errLater == nil {
+ errLater = &RequiredNotSetError{f.name + "." + err1.field}
}
continue
}
if err == errRepeatedHasNil {
err = errors.New("proto: repeated field " + f.name + " has nil element")
}
+ if err == errInvalidUTF8 {
+ if errLater == nil {
+ fullName := revProtoTypes[reflect.PtrTo(u.typ)] + "." + f.name
+ errLater = &invalidUTF8Error{fullName}
+ }
+ continue
+ }
return b, err
}
}
@@ -284,7 +294,7 @@ func (u *marshalInfo) marshal(b []byte, ptr pointer, deterministic bool) ([]byte
s := *ptr.offset(u.unrecognized).toBytes()
b = append(b, s...)
}
- return b, errreq
+ return b, errLater
}
// computeMarshalInfo initializes the marshal info.
@@ -311,8 +321,11 @@ func (u *marshalInfo) computeMarshalInfo() {
// get oneof implementers
var oneofImplementers []interface{}
- if m, ok := reflect.Zero(reflect.PtrTo(t)).Interface().(oneofMessage); ok {
+ switch m := reflect.Zero(reflect.PtrTo(t)).Interface().(type) {
+ case oneofFuncsIface:
_, _, _, oneofImplementers = m.XXX_OneofFuncs()
+ case oneofWrappersIface:
+ oneofImplementers = m.XXX_OneofWrappers()
}
n := t.NumField()
@@ -398,13 +411,22 @@ func (u *marshalInfo) getExtElemInfo(desc *ExtensionDesc) *marshalElemInfo {
panic("tag is not an integer")
}
wt := wiretype(tags[0])
+ if t.Kind() == reflect.Ptr && t.Elem().Kind() != reflect.Struct {
+ t = t.Elem()
+ }
sizer, marshaler := typeMarshaler(t, tags, false, false)
+ var deref bool
+ if t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 {
+ t = reflect.PtrTo(t)
+ deref = true
+ }
e = &marshalElemInfo{
wiretag: uint64(tag)<<3 | wt,
tagsize: SizeVarint(uint64(tag) << 3),
sizer: sizer,
marshaler: marshaler,
isptr: t.Kind() == reflect.Ptr,
+ deref: deref,
}
// update cache
@@ -439,7 +461,7 @@ func (fi *marshalFieldInfo) computeMarshalFieldInfo(f *reflect.StructField) {
func (fi *marshalFieldInfo) computeOneofFieldInfo(f *reflect.StructField, oneofImplementers []interface{}) {
fi.field = toField(f)
- fi.wiretag = 1<<31 - 1 // Use a large tag number, make oneofs sorted at the end. This tag will not appear on the wire.
+ fi.wiretag = math.MaxInt32 // Use a large tag number, make oneofs sorted at the end. This tag will not appear on the wire.
fi.isPointer = true
fi.sizer, fi.marshaler = makeOneOfMarshaler(fi, f)
fi.oneofElems = make(map[reflect.Type]*marshalElemInfo)
@@ -467,10 +489,6 @@ func (fi *marshalFieldInfo) computeOneofFieldInfo(f *reflect.StructField, oneofI
}
}
-type oneofMessage interface {
- XXX_OneofFuncs() (func(Message, *Buffer) error, func(Message, int, int, *Buffer) (bool, error), func(Message) int, []interface{})
-}
-
// wiretype returns the wire encoding of the type.
func wiretype(encoding string) uint64 {
switch encoding {
@@ -530,6 +548,7 @@ func typeMarshaler(t reflect.Type, tags []string, nozero, oneof bool) (sizer, ma
packed := false
proto3 := false
+ validateUTF8 := true
for i := 2; i < len(tags); i++ {
if tags[i] == "packed" {
packed = true
@@ -538,6 +557,7 @@ func typeMarshaler(t reflect.Type, tags []string, nozero, oneof bool) (sizer, ma
proto3 = true
}
}
+ validateUTF8 = validateUTF8 && proto3
switch t.Kind() {
case reflect.Bool:
@@ -735,6 +755,18 @@ func typeMarshaler(t reflect.Type, tags []string, nozero, oneof bool) (sizer, ma
}
return sizeFloat64Value, appendFloat64Value
case reflect.String:
+ if validateUTF8 {
+ if pointer {
+ return sizeStringPtr, appendUTF8StringPtr
+ }
+ if slice {
+ return sizeStringSlice, appendUTF8StringSlice
+ }
+ if nozero {
+ return sizeStringValueNoZero, appendUTF8StringValueNoZero
+ }
+ return sizeStringValue, appendUTF8StringValue
+ }
if pointer {
return sizeStringPtr, appendStringPtr
}
@@ -1984,51 +2016,104 @@ func appendBoolPackedSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byt
}
func appendStringValue(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
v := *ptr.toString()
+ b = appendVarint(b, wiretag)
+ b = appendVarint(b, uint64(len(v)))
+ b = append(b, v...)
+ return b, nil
+}
+func appendStringValueNoZero(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+ v := *ptr.toString()
+ if v == "" {
+ return b, nil
+ }
+ b = appendVarint(b, wiretag)
+ b = appendVarint(b, uint64(len(v)))
+ b = append(b, v...)
+ return b, nil
+}
+func appendStringPtr(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+ p := *ptr.toStringPtr()
+ if p == nil {
+ return b, nil
+ }
+ v := *p
+ b = appendVarint(b, wiretag)
+ b = appendVarint(b, uint64(len(v)))
+ b = append(b, v...)
+ return b, nil
+}
+func appendStringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+ s := *ptr.toStringSlice()
+ for _, v := range s {
+ b = appendVarint(b, wiretag)
+ b = appendVarint(b, uint64(len(v)))
+ b = append(b, v...)
+ }
+ return b, nil
+}
+func appendUTF8StringValue(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+ var invalidUTF8 bool
+ v := *ptr.toString()
if !utf8.ValidString(v) {
- return nil, errInvalidUTF8
+ invalidUTF8 = true
}
b = appendVarint(b, wiretag)
b = appendVarint(b, uint64(len(v)))
b = append(b, v...)
+ if invalidUTF8 {
+ return b, errInvalidUTF8
+ }
return b, nil
}
-func appendStringValueNoZero(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+func appendUTF8StringValueNoZero(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+ var invalidUTF8 bool
v := *ptr.toString()
if v == "" {
return b, nil
}
if !utf8.ValidString(v) {
- return nil, errInvalidUTF8
+ invalidUTF8 = true
}
b = appendVarint(b, wiretag)
b = appendVarint(b, uint64(len(v)))
b = append(b, v...)
+ if invalidUTF8 {
+ return b, errInvalidUTF8
+ }
return b, nil
}
-func appendStringPtr(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+func appendUTF8StringPtr(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+ var invalidUTF8 bool
p := *ptr.toStringPtr()
if p == nil {
return b, nil
}
v := *p
if !utf8.ValidString(v) {
- return nil, errInvalidUTF8
+ invalidUTF8 = true
}
b = appendVarint(b, wiretag)
b = appendVarint(b, uint64(len(v)))
b = append(b, v...)
+ if invalidUTF8 {
+ return b, errInvalidUTF8
+ }
return b, nil
}
-func appendStringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+func appendUTF8StringSlice(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
+ var invalidUTF8 bool
s := *ptr.toStringSlice()
for _, v := range s {
if !utf8.ValidString(v) {
- return nil, errInvalidUTF8
+ invalidUTF8 = true
}
b = appendVarint(b, wiretag)
b = appendVarint(b, uint64(len(v)))
b = append(b, v...)
}
+ if invalidUTF8 {
+ return b, errInvalidUTF8
+ }
return b, nil
}
func appendBytes(b []byte, ptr pointer, wiretag uint64, _ bool) ([]byte, error) {
@@ -2107,7 +2192,8 @@ func makeGroupSliceMarshaler(u *marshalInfo) (sizer, marshaler) {
},
func(b []byte, ptr pointer, wiretag uint64, deterministic bool) ([]byte, error) {
s := ptr.getPointerSlice()
- var err, errreq error
+ var err error
+ var nerr nonFatal
for _, v := range s {
if v.isNil() {
return b, errRepeatedHasNil
@@ -2115,22 +2201,14 @@ func makeGroupSliceMarshaler(u *marshalInfo) (sizer, marshaler) {
b = appendVarint(b, wiretag) // start group
b, err = u.marshal(b, v, deterministic)
b = appendVarint(b, wiretag+(WireEndGroup-WireStartGroup)) // end group
- if err != nil {
- if _, ok := err.(*RequiredNotSetError); ok {
- // Required field in submessage is not set.
- // We record the error but keep going, to give a complete marshaling.
- if errreq == nil {
- errreq = err
- }
- continue
- }
+ if !nerr.Merge(err) {
if err == ErrNil {
err = errRepeatedHasNil
}
return b, err
}
}
- return b, errreq
+ return b, nerr.E
}
}
@@ -2174,7 +2252,8 @@ func makeMessageSliceMarshaler(u *marshalInfo) (sizer, marshaler) {
},
func(b []byte, ptr pointer, wiretag uint64, deterministic bool) ([]byte, error) {
s := ptr.getPointerSlice()
- var err, errreq error
+ var err error
+ var nerr nonFatal
for _, v := range s {
if v.isNil() {
return b, errRepeatedHasNil
@@ -2184,22 +2263,14 @@ func makeMessageSliceMarshaler(u *marshalInfo) (sizer, marshaler) {
b = appendVarint(b, uint64(siz))
b, err = u.marshal(b, v, deterministic)
- if err != nil {
- if _, ok := err.(*RequiredNotSetError); ok {
- // Required field in submessage is not set.
- // We record the error but keep going, to give a complete marshaling.
- if errreq == nil {
- errreq = err
- }
- continue
- }
+ if !nerr.Merge(err) {
if err == ErrNil {
err = errRepeatedHasNil
}
return b, err
}
}
- return b, errreq
+ return b, nerr.E
}
}
@@ -2223,14 +2294,33 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) {
// value.
// Key cannot be pointer-typed.
valIsPtr := valType.Kind() == reflect.Ptr
+
+ // If value is a message with nested maps, calling
+ // valSizer in marshal may be quadratic. We should use
+ // cached version in marshal (but not in size).
+ // If value is not message type, we don't have size cache,
+ // but it cannot be nested either. Just use valSizer.
+ valCachedSizer := valSizer
+ if valIsPtr && valType.Elem().Kind() == reflect.Struct {
+ u := getMarshalInfo(valType.Elem())
+ valCachedSizer = func(ptr pointer, tagsize int) int {
+ // Same as message sizer, but use cache.
+ p := ptr.getPointer()
+ if p.isNil() {
+ return 0
+ }
+ siz := u.cachedsize(p)
+ return siz + SizeVarint(uint64(siz)) + tagsize
+ }
+ }
return func(ptr pointer, tagsize int) int {
m := ptr.asPointerTo(t).Elem() // the map
n := 0
for _, k := range m.MapKeys() {
ki := k.Interface()
vi := m.MapIndex(k).Interface()
- kaddr := toAddrPointer(&ki, false) // pointer to key
- vaddr := toAddrPointer(&vi, valIsPtr) // pointer to value
+ kaddr := toAddrPointer(&ki, false, false) // pointer to key
+ vaddr := toAddrPointer(&vi, valIsPtr, false) // pointer to value
siz := keySizer(kaddr, 1) + valSizer(vaddr, 1) // tag of key = 1 (size=1), tag of val = 2 (size=1)
n += siz + SizeVarint(uint64(siz)) + tagsize
}
@@ -2243,24 +2333,26 @@ func makeMapMarshaler(f *reflect.StructField) (sizer, marshaler) {
if len(keys) > 1 && deterministic {
sort.Sort(mapKeys(keys))
}
+
+ var nerr nonFatal
for _, k := range keys {
ki := k.Interface()
vi := m.MapIndex(k).Interface()
- kaddr := toAddrPointer(&ki, false) // pointer to key
- vaddr := toAddrPointer(&vi, valIsPtr) // pointer to value
+ kaddr := toAddrPointer(&ki, false, false) // pointer to key
+ vaddr := toAddrPointer(&vi, valIsPtr, false) // pointer to value
b = appendVarint(b, tag)
- siz := keySizer(kaddr, 1) + valSizer(vaddr, 1) // tag of key = 1 (size=1), tag of val = 2 (size=1)
+ siz := keySizer(kaddr, 1) + valCachedSizer(vaddr, 1) // tag of key = 1 (size=1), tag of val = 2 (size=1)
b = appendVarint(b, uint64(siz))
b, err = keyMarshaler(b, kaddr, keyWireTag, deterministic)
- if err != nil {
+ if !nerr.Merge(err) {
return b, err
}
b, err = valMarshaler(b, vaddr, valWireTag, deterministic)
- if err != nil && err != ErrNil { // allow nil value in map
+ if err != ErrNil && !nerr.Merge(err) { // allow nil value in map
return b, err
}
}
- return b, nil
+ return b, nerr.E
}
}
@@ -2316,7 +2408,7 @@ func (u *marshalInfo) sizeExtensions(ext *XXX_InternalExtensions) int {
// the last time this function was called.
ei := u.getExtElemInfo(e.desc)
v := e.value
- p := toAddrPointer(&v, ei.isptr)
+ p := toAddrPointer(&v, ei.isptr, ei.deref)
n += ei.sizer(p, ei.tagsize)
}
mu.Unlock()
@@ -2333,6 +2425,7 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
defer mu.Unlock()
var err error
+ var nerr nonFatal
// Fast-path for common cases: zero or one extensions.
// Don't bother sorting the keys.
@@ -2350,13 +2443,13 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
ei := u.getExtElemInfo(e.desc)
v := e.value
- p := toAddrPointer(&v, ei.isptr)
+ p := toAddrPointer(&v, ei.isptr, ei.deref)
b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
- if err != nil {
+ if !nerr.Merge(err) {
return b, err
}
}
- return b, nil
+ return b, nerr.E
}
// Sort the keys to provide a deterministic encoding.
@@ -2381,13 +2474,13 @@ func (u *marshalInfo) appendExtensions(b []byte, ext *XXX_InternalExtensions, de
ei := u.getExtElemInfo(e.desc)
v := e.value
- p := toAddrPointer(&v, ei.isptr)
+ p := toAddrPointer(&v, ei.isptr, ei.deref)
b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
- if err != nil {
+ if !nerr.Merge(err) {
return b, err
}
}
- return b, nil
+ return b, nerr.E
}
// message set format is:
@@ -2426,7 +2519,7 @@ func (u *marshalInfo) sizeMessageSet(ext *XXX_InternalExtensions) int {
ei := u.getExtElemInfo(e.desc)
v := e.value
- p := toAddrPointer(&v, ei.isptr)
+ p := toAddrPointer(&v, ei.isptr, ei.deref)
n += ei.sizer(p, 1) // message, tag = 3 (size=1)
}
mu.Unlock()
@@ -2444,6 +2537,7 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
defer mu.Unlock()
var err error
+ var nerr nonFatal
// Fast-path for common cases: zero or one extensions.
// Don't bother sorting the keys.
@@ -2468,14 +2562,14 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
ei := u.getExtElemInfo(e.desc)
v := e.value
- p := toAddrPointer(&v, ei.isptr)
+ p := toAddrPointer(&v, ei.isptr, ei.deref)
b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
- if err != nil {
+ if !nerr.Merge(err) {
return b, err
}
b = append(b, 1<<3|WireEndGroup)
}
- return b, nil
+ return b, nerr.E
}
// Sort the keys to provide a deterministic encoding.
@@ -2506,14 +2600,14 @@ func (u *marshalInfo) appendMessageSet(b []byte, ext *XXX_InternalExtensions, de
ei := u.getExtElemInfo(e.desc)
v := e.value
- p := toAddrPointer(&v, ei.isptr)
+ p := toAddrPointer(&v, ei.isptr, ei.deref)
b, err = ei.marshaler(b, p, 3<<3|WireBytes, deterministic)
b = append(b, 1<<3|WireEndGroup)
- if err != nil {
+ if !nerr.Merge(err) {
return b, err
}
}
- return b, nil
+ return b, nerr.E
}
// sizeV1Extensions computes the size of encoded data for a V1-API extension field.
@@ -2536,7 +2630,7 @@ func (u *marshalInfo) sizeV1Extensions(m map[int32]Extension) int {
ei := u.getExtElemInfo(e.desc)
v := e.value
- p := toAddrPointer(&v, ei.isptr)
+ p := toAddrPointer(&v, ei.isptr, ei.deref)
n += ei.sizer(p, ei.tagsize)
}
return n
@@ -2556,6 +2650,7 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ
sort.Ints(keys)
var err error
+ var nerr nonFatal
for _, k := range keys {
e := m[int32(k)]
if e.value == nil || e.desc == nil {
@@ -2570,13 +2665,13 @@ func (u *marshalInfo) appendV1Extensions(b []byte, m map[int32]Extension, determ
ei := u.getExtElemInfo(e.desc)
v := e.value
- p := toAddrPointer(&v, ei.isptr)
+ p := toAddrPointer(&v, ei.isptr, ei.deref)
b, err = ei.marshaler(b, p, ei.wiretag, deterministic)
- if err != nil {
+ if !nerr.Merge(err) {
return b, err
}
}
- return b, nil
+ return b, nerr.E
}
// newMarshaler is the interface representing objects that can marshal themselves.