diff options
Diffstat (limited to 'vendor/github.com/golang/protobuf/proto/extensions_test.go')
-rw-r--r-- | vendor/github.com/golang/protobuf/proto/extensions_test.go | 192 |
1 files changed, 172 insertions, 20 deletions
diff --git a/vendor/github.com/golang/protobuf/proto/extensions_test.go b/vendor/github.com/golang/protobuf/proto/extensions_test.go index b6d9114..dc69fe9 100644 --- a/vendor/github.com/golang/protobuf/proto/extensions_test.go +++ b/vendor/github.com/golang/protobuf/proto/extensions_test.go @@ -34,12 +34,14 @@ package proto_test import ( "bytes" "fmt" + "io" "reflect" "sort" + "strings" "testing" "github.com/golang/protobuf/proto" - pb "github.com/golang/protobuf/proto/testdata" + pb "github.com/golang/protobuf/proto/test_proto" "golang.org/x/sync/errgroup" ) @@ -64,7 +66,107 @@ func TestGetExtensionsWithMissingExtensions(t *testing.T) { } } -func TestExtensionDescsWithMissingExtensions(t *testing.T) { +func TestGetExtensionWithEmptyBuffer(t *testing.T) { + // Make sure that GetExtension returns an error if its + // undecoded buffer is empty. + msg := &pb.MyMessage{} + proto.SetRawExtension(msg, pb.E_Ext_More.Field, []byte{}) + _, err := proto.GetExtension(msg, pb.E_Ext_More) + if want := io.ErrUnexpectedEOF; err != want { + t.Errorf("unexpected error in GetExtension from empty buffer: got %v, want %v", err, want) + } +} + +func TestGetExtensionForIncompleteDesc(t *testing.T) { + msg := &pb.MyMessage{Count: proto.Int32(0)} + extdesc1 := &proto.ExtensionDesc{ + ExtendedType: (*pb.MyMessage)(nil), + ExtensionType: (*bool)(nil), + Field: 123456789, + Name: "a.b", + Tag: "varint,123456789,opt", + } + ext1 := proto.Bool(true) + if err := proto.SetExtension(msg, extdesc1, ext1); err != nil { + t.Fatalf("Could not set ext1: %s", err) + } + extdesc2 := &proto.ExtensionDesc{ + ExtendedType: (*pb.MyMessage)(nil), + ExtensionType: ([]byte)(nil), + Field: 123456790, + Name: "a.c", + Tag: "bytes,123456790,opt", + } + ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7} + if err := proto.SetExtension(msg, extdesc2, ext2); err != nil { + t.Fatalf("Could not set ext2: %s", err) + } + extdesc3 := &proto.ExtensionDesc{ + ExtendedType: (*pb.MyMessage)(nil), + ExtensionType: (*pb.Ext)(nil), + Field: 123456791, + Name: "a.d", + Tag: "bytes,123456791,opt", + } + ext3 := &pb.Ext{Data: proto.String("foo")} + if err := proto.SetExtension(msg, extdesc3, ext3); err != nil { + t.Fatalf("Could not set ext3: %s", err) + } + + b, err := proto.Marshal(msg) + if err != nil { + t.Fatalf("Could not marshal msg: %v", err) + } + if err := proto.Unmarshal(b, msg); err != nil { + t.Fatalf("Could not unmarshal into msg: %v", err) + } + + var expected proto.Buffer + if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil { + t.Fatalf("failed to compute expected prefix for ext1: %s", err) + } + if err := expected.EncodeVarint(1 /* bool true */); err != nil { + t.Fatalf("failed to compute expected value for ext1: %s", err) + } + + if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil { + t.Fatalf("Failed to get raw value for ext1: %s", err) + } else if !reflect.DeepEqual(b, expected.Bytes()) { + t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes()) + } + + expected = proto.Buffer{} // reset + if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil { + t.Fatalf("failed to compute expected prefix for ext2: %s", err) + } + if err := expected.EncodeRawBytes(ext2); err != nil { + t.Fatalf("failed to compute expected value for ext2: %s", err) + } + + if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil { + t.Fatalf("Failed to get raw value for ext2: %s", err) + } else if !reflect.DeepEqual(b, expected.Bytes()) { + t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes()) + } + + expected = proto.Buffer{} // reset + if err := expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil { + t.Fatalf("failed to compute expected prefix for ext3: %s", err) + } + if b, err := proto.Marshal(ext3); err != nil { + t.Fatalf("failed to compute expected value for ext3: %s", err) + } else if err := expected.EncodeRawBytes(b); err != nil { + t.Fatalf("failed to compute expected value for ext3: %s", err) + } + + if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil { + t.Fatalf("Failed to get raw value for ext3: %s", err) + } else if !reflect.DeepEqual(b, expected.Bytes()) { + t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes()) + } +} + +func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) { msg := &pb.MyMessage{Count: proto.Int32(0)} extdesc1 := pb.E_Ext_More if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil { @@ -100,7 +202,7 @@ func TestExtensionDescsWithMissingExtensions(t *testing.T) { t.Fatalf("proto.ExtensionDescs: got error %v", err) } sortExtDescs(descs) - wantDescs := []*proto.ExtensionDesc{extdesc1, &proto.ExtensionDesc{Field: extdesc2.Field}} + wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}} if !reflect.DeepEqual(descs, wantDescs) { t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs) } @@ -200,7 +302,7 @@ func TestGetExtensionDefaults(t *testing.T) { {pb.E_DefaultSfixed64, setInt64, int64(51)}, {pb.E_DefaultBool, setBool, true}, {pb.E_DefaultBool, setBool2, true}, - {pb.E_DefaultString, setString, "Hello, string"}, + {pb.E_DefaultString, setString, "Hello, string,def=foo"}, {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")}, {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE}, } @@ -287,6 +389,44 @@ func TestGetExtensionDefaults(t *testing.T) { } } +func TestNilMessage(t *testing.T) { + name := "nil interface" + if got, err := proto.GetExtension(nil, pb.E_Ext_More); err == nil { + t.Errorf("%s: got %T %v, expected to fail", name, got, got) + } else if !strings.Contains(err.Error(), "extendable") { + t.Errorf("%s: got error %v, expected not-extendable error", name, err) + } + + // Regression tests: all functions of the Extension API + // used to panic when passed (*M)(nil), where M is a concrete message + // type. Now they handle this gracefully as a no-op or reported error. + var nilMsg *pb.MyMessage + desc := pb.E_Ext_More + + isNotExtendable := func(err error) bool { + return strings.Contains(fmt.Sprint(err), "not extendable") + } + + if proto.HasExtension(nilMsg, desc) { + t.Error("HasExtension(nil) = true") + } + + if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) { + t.Errorf("GetExtensions(nil) = %q (wrong error)", err) + } + + if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) { + t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err) + } + + if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) { + t.Errorf("SetExtension(nil) = %q (wrong error)", err) + } + + proto.ClearExtension(nilMsg, desc) // no-op + proto.ClearAllExtensions(nilMsg) // no-op +} + func TestExtensionsRoundTrip(t *testing.T) { msg := &pb.MyMessage{} ext1 := &pb.Ext{ @@ -311,7 +451,7 @@ func TestExtensionsRoundTrip(t *testing.T) { } x, ok := e.(*pb.Ext) if !ok { - t.Errorf("e has type %T, expected testdata.Ext", e) + t.Errorf("e has type %T, expected test_proto.Ext", e) } else if *x.Data != "there" { t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x) } @@ -339,7 +479,7 @@ func TestNilExtension(t *testing.T) { } if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil { t.Error("expected SetExtension to fail due to a nil extension") - } else if want := "proto: SetExtension called with nil value of type *testdata.Ext"; err.Error() != want { + } else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb.Ext)); err.Error() != want { t.Errorf("expected error %v, got %v", want, err) } // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update @@ -402,8 +542,13 @@ func TestMarshalUnmarshalRepeatedExtension(t *testing.T) { if ext == nil { t.Fatalf("[%s] Invalid extension", test.name) } - if !reflect.DeepEqual(ext, test.ext) { - t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, test.ext) + if len(ext) != len(test.ext) { + t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext)) + } + for i := range test.ext { + if !proto.Equal(ext[i], test.ext[i]) { + t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i]) + } } } } @@ -477,8 +622,8 @@ func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) { if ext == nil { t.Fatalf("[%s] Invalid extension", test.name) } - if !reflect.DeepEqual(*ext, want) { - t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, want) + if !proto.Equal(ext, &want) { + t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, &want) } } } @@ -509,19 +654,22 @@ func TestClearAllExtensions(t *testing.T) { } func TestMarshalRace(t *testing.T) { - // unregistered extension - desc := &proto.ExtensionDesc{ - ExtendedType: (*pb.MyMessage)(nil), - ExtensionType: (*bool)(nil), - Field: 101010100, - Name: "emptyextension", - Tag: "varint,0,opt", + ext := &pb.Ext{} + m := &pb.MyMessage{Count: proto.Int32(4)} + if err := proto.SetExtension(m, pb.E_Ext_More, ext); err != nil { + t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err) } - m := &pb.MyMessage{Count: proto.Int32(4)} - if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil { - t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err) + b, err := proto.Marshal(m) + if err != nil { + t.Fatalf("Could not marshal message: %v", err) + } + if err := proto.Unmarshal(b, m); err != nil { + t.Fatalf("Could not unmarshal message: %v", err) } + // after Unmarshal, the extension is in undecoded form. + // GetExtension will decode it lazily. Make sure this does + // not race against Marshal. var g errgroup.Group for n := 3; n > 0; n-- { @@ -529,6 +677,10 @@ func TestMarshalRace(t *testing.T) { _, err := proto.Marshal(m) return err }) + g.Go(func() error { + _, err := proto.GetExtension(m, pb.E_Ext_More) + return err + }) } if err := g.Wait(); err != nil { t.Fatal(err) |