// Copyright 2018 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Package sockstest provides utilities for SOCKS testing. package sockstest import ( "errors" "io" "net" "golang.org/x/net/internal/nettest" "golang.org/x/net/internal/socks" ) // An AuthRequest represents an authentication request. type AuthRequest struct { Version int Methods []socks.AuthMethod } // ParseAuthRequest parses an authentication request. func ParseAuthRequest(b []byte) (*AuthRequest, error) { if len(b) < 2 { return nil, errors.New("short auth request") } if b[0] != socks.Version5 { return nil, errors.New("unexpected protocol version") } if len(b)-2 < int(b[1]) { return nil, errors.New("short auth request") } req := &AuthRequest{Version: int(b[0])} if b[1] > 0 { req.Methods = make([]socks.AuthMethod, b[1]) for i, m := range b[2 : 2+b[1]] { req.Methods[i] = socks.AuthMethod(m) } } return req, nil } // MarshalAuthReply returns an authentication reply in wire format. func MarshalAuthReply(ver int, m socks.AuthMethod) ([]byte, error) { return []byte{byte(ver), byte(m)}, nil } // A CmdRequest repesents a command request. type CmdRequest struct { Version int Cmd socks.Command Addr socks.Addr } // ParseCmdRequest parses a command request. func ParseCmdRequest(b []byte) (*CmdRequest, error) { if len(b) < 7 { return nil, errors.New("short cmd request") } if b[0] != socks.Version5 { return nil, errors.New("unexpected protocol version") } if socks.Command(b[1]) != socks.CmdConnect { return nil, errors.New("unexpected command") } if b[2] != 0 { return nil, errors.New("non-zero reserved field") } req := &CmdRequest{Version: int(b[0]), Cmd: socks.Command(b[1])} l := 2 off := 4 switch b[3] { case socks.AddrTypeIPv4: l += net.IPv4len req.Addr.IP = make(net.IP, net.IPv4len) case socks.AddrTypeIPv6: l += net.IPv6len req.Addr.IP = make(net.IP, net.IPv6len) case socks.AddrTypeFQDN: l += int(b[4]) off = 5 default: return nil, errors.New("unknown address type") } if len(b[off:]) < l { return nil, errors.New("short cmd request") } if req.Addr.IP != nil { copy(req.Addr.IP, b[off:]) } else { req.Addr.Name = string(b[off : off+l-2]) } req.Addr.Port = int(b[off+l-2])<<8 | int(b[off+l-1]) return req, nil } // MarshalCmdReply returns a command reply in wire format. func MarshalCmdReply(ver int, reply socks.Reply, a *socks.Addr) ([]byte, error) { b := make([]byte, 4) b[0] = byte(ver) b[1] = byte(reply) if a.Name != "" { if len(a.Name) > 255 { return nil, errors.New("fqdn too long") } b[3] = socks.AddrTypeFQDN b = append(b, byte(len(a.Name))) b = append(b, a.Name...) } else if ip4 := a.IP.To4(); ip4 != nil { b[3] = socks.AddrTypeIPv4 b = append(b, ip4...) } else if ip6 := a.IP.To16(); ip6 != nil { b[3] = socks.AddrTypeIPv6 b = append(b, ip6...) } else { return nil, errors.New("unknown address type") } b = append(b, byte(a.Port>>8), byte(a.Port)) return b, nil } // A Server repesents a server for handshake testing. type Server struct { ln net.Listener } // Addr rerurns a server address. func (s *Server) Addr() net.Addr { return s.ln.Addr() } // TargetAddr returns a fake final destination address. // // The returned address is only valid for testing with Server. func (s *Server) TargetAddr() net.Addr { a := s.ln.Addr() switch a := a.(type) { case *net.TCPAddr: if a.IP.To4() != nil { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 5963} } if a.IP.To16() != nil && a.IP.To4() == nil { return &net.TCPAddr{IP: net.IPv6loopback, Port: 5963} } } return nil } // Close closes the server. func (s *Server) Close() error { return s.ln.Close() } func (s *Server) serve(authFunc, cmdFunc func(io.ReadWriter, []byte) error) { c, err := s.ln.Accept() if err != nil { return } defer c.Close() go s.serve(authFunc, cmdFunc) b := make([]byte, 512) n, err := c.Read(b) if err != nil { return } if err := authFunc(c, b[:n]); err != nil { return } n, err = c.Read(b) if err != nil { return } if err := cmdFunc(c, b[:n]); err != nil { return } } // NewServer returns a new server. // // The provided authFunc and cmdFunc must parse requests and return // appropriate replies to clients. func NewServer(authFunc, cmdFunc func(io.ReadWriter, []byte) error) (*Server, error) { var err error s := new(Server) s.ln, err = nettest.NewLocalListener("tcp") if err != nil { return nil, err } go s.serve(authFunc, cmdFunc) return s, nil } // NoAuthRequired handles a no-authentication-required signaling. func NoAuthRequired(rw io.ReadWriter, b []byte) error { req, err := ParseAuthRequest(b) if err != nil { return err } b, err = MarshalAuthReply(req.Version, socks.AuthMethodNotRequired) if err != nil { return err } n, err := rw.Write(b) if err != nil { return err } if n != len(b) { return errors.New("short write") } return nil } // NoProxyRequired handles a command signaling without constructing a // proxy connection to the final destination. func NoProxyRequired(rw io.ReadWriter, b []byte) error { req, err := ParseCmdRequest(b) if err != nil { return err } req.Addr.Port += 1 if req.Addr.Name != "" { req.Addr.Name = "boundaddr.doesnotexist" } else if req.Addr.IP.To4() != nil { req.Addr.IP = net.IPv4(127, 0, 0, 1) } else { req.Addr.IP = net.IPv6loopback } b, err = MarshalCmdReply(socks.Version5, socks.StatusSucceeded, &req.Addr) if err != nil { return err } n, err := rw.Write(b) if err != nil { return err } if n != len(b) { return errors.New("short write") } return nil }