Compare commits
1 Commits
dependabot
...
socks5
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b0b1c4d27f |
86
internal/socks5/constants.go
Normal file
86
internal/socks5/constants.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-3
|
||||||
|
type authMethod byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
authNotRequired authMethod = 0
|
||||||
|
authGssapi authMethod = 1
|
||||||
|
authUsernamePassword authMethod = 2
|
||||||
|
authNotAcceptable authMethod = 255
|
||||||
|
)
|
||||||
|
|
||||||
|
func (a authMethod) String() string {
|
||||||
|
switch a {
|
||||||
|
case authNotRequired:
|
||||||
|
return "no authentication required"
|
||||||
|
case authGssapi:
|
||||||
|
return "GSSAPI"
|
||||||
|
case authUsernamePassword:
|
||||||
|
return "username/password"
|
||||||
|
case authNotAcceptable:
|
||||||
|
return "no acceptable methods"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("unknown method (%d)", a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subnegotiation version
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1929#section-2
|
||||||
|
const (
|
||||||
|
authUsernamePasswordSubNegotiation1 byte = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// SOCKS versions.
|
||||||
|
const (
|
||||||
|
socks5Version byte = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-4
|
||||||
|
type cmdType byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
connect cmdType = 1
|
||||||
|
bind cmdType = 2
|
||||||
|
udpAssociate cmdType = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c cmdType) String() string {
|
||||||
|
switch c {
|
||||||
|
case connect:
|
||||||
|
return "connect"
|
||||||
|
case bind:
|
||||||
|
return "bind"
|
||||||
|
case udpAssociate:
|
||||||
|
return "UDP associate"
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("unknown command (%d)", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-4 and
|
||||||
|
// https://datatracker.ietf.org/doc/html/rfc1928#section-5
|
||||||
|
type addrType byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
ipv4 addrType = 1
|
||||||
|
domainName addrType = 3
|
||||||
|
ipv6 addrType = 4
|
||||||
|
)
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-6
|
||||||
|
type replyCode byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
succeeded replyCode = iota
|
||||||
|
generalServerFailure
|
||||||
|
connectionNotAllowedByRuleset
|
||||||
|
networkUnreachable
|
||||||
|
hostUnreachable
|
||||||
|
connectionRefused
|
||||||
|
ttlExpired
|
||||||
|
commandNotSupported
|
||||||
|
addressTypeNotSupported
|
||||||
|
)
|
||||||
6
internal/socks5/interfaces.go
Normal file
6
internal/socks5/interfaces.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
type Logger interface {
|
||||||
|
Infof(format string, a ...interface{})
|
||||||
|
Warnf(format string, a ...interface{})
|
||||||
|
}
|
||||||
103
internal/socks5/response.go
Normal file
103
internal/socks5/response.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-6
|
||||||
|
func (c *socksConn) encodeFailedResponse(writer io.Writer, socksVersion byte, reply replyCode) {
|
||||||
|
_, err := writer.Write([]byte{
|
||||||
|
socksVersion,
|
||||||
|
byte(reply),
|
||||||
|
0, // RSV byte
|
||||||
|
// TODO do we need to set the bind addr type to 0??
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Warnf("failed writing failed response: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-6
|
||||||
|
func (c *socksConn) encodeSuccessResponse(writer io.Writer, socksVersion byte,
|
||||||
|
reply replyCode, bindAddrType addrType, bindAddress string,
|
||||||
|
bindPort uint16) (err error) {
|
||||||
|
bindData, err := encodeBindData(bindAddrType, bindAddress, bindPort)
|
||||||
|
if err != nil { // TODO encode with below block if this changes
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
const initialPacketLength = 3
|
||||||
|
capacity := initialPacketLength + len(bindData)
|
||||||
|
packet := make([]byte, initialPacketLength, capacity)
|
||||||
|
packet[0] = socksVersion
|
||||||
|
packet[1] = byte(reply)
|
||||||
|
packet[2] = 0 // RSV byte
|
||||||
|
packet = append(packet, bindData...)
|
||||||
|
|
||||||
|
_, err = writer.Write(packet)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Warnf("failed writing success response: %s", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrIPVersionUnexpected = errors.New("ip version is unexpected")
|
||||||
|
ErrDomainNameTooLong = errors.New("domain name is too long")
|
||||||
|
)
|
||||||
|
|
||||||
|
func encodeBindData(addrType addrType, address string, port uint16) (
|
||||||
|
data []byte, err error) {
|
||||||
|
capacity := bindDataLength(addrType, address)
|
||||||
|
data = make([]byte, 0, capacity)
|
||||||
|
|
||||||
|
data = append(data, byte(addrType))
|
||||||
|
switch addrType {
|
||||||
|
case ipv4, ipv6:
|
||||||
|
ip, err := netip.ParseAddr(address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing IP address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case addrType == ipv4 && !ip.Is4():
|
||||||
|
return nil, fmt.Errorf("%w: expected IPv4 for %s", ErrIPVersionUnexpected, ip)
|
||||||
|
case addrType == ipv6 && !ip.Is6():
|
||||||
|
return nil, fmt.Errorf("%w: expected IPv6 for %s", ErrIPVersionUnexpected, ip)
|
||||||
|
}
|
||||||
|
data = append(data, ip.AsSlice()...)
|
||||||
|
case domainName:
|
||||||
|
const maxDomainNameLength = 255
|
||||||
|
if len(address) > maxDomainNameLength {
|
||||||
|
return nil, fmt.Errorf("%w: %s", ErrDomainNameTooLong, address)
|
||||||
|
}
|
||||||
|
data = append(data, byte(len(address)))
|
||||||
|
data = append(data, []byte(address)...)
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unsupported address type %d", addrType))
|
||||||
|
}
|
||||||
|
data = binary.BigEndian.AppendUint16(data, port)
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func bindDataLength(addrType addrType, address string) (maxLength int) {
|
||||||
|
maxLength++ // address type
|
||||||
|
switch addrType {
|
||||||
|
case ipv4:
|
||||||
|
maxLength += net.IPv4len
|
||||||
|
case domainName:
|
||||||
|
maxLength++ // domain name length
|
||||||
|
maxLength += len([]byte(address))
|
||||||
|
case ipv6:
|
||||||
|
maxLength += net.IPv6len
|
||||||
|
default:
|
||||||
|
panic("unsupported address type: " + fmt.Sprint(addrType))
|
||||||
|
}
|
||||||
|
maxLength += 2 // port
|
||||||
|
return maxLength
|
||||||
|
}
|
||||||
105
internal/socks5/server.go
Normal file
105
internal/socks5/server.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
address string
|
||||||
|
logger Logger
|
||||||
|
|
||||||
|
// internal fields
|
||||||
|
listener net.Listener
|
||||||
|
listening atomic.Bool
|
||||||
|
socksConnCtx context.Context //nolint:containedctx
|
||||||
|
socksConnCancel context.CancelFunc
|
||||||
|
done <-chan struct{}
|
||||||
|
stopping atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(settings Settings) *Server {
|
||||||
|
return &Server{
|
||||||
|
username: settings.Username,
|
||||||
|
password: settings.Password,
|
||||||
|
address: settings.Address,
|
||||||
|
logger: settings.Logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Start(_ context.Context) (runErr <-chan error, err error) {
|
||||||
|
s.listener, err = net.Listen("tcp", s.address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listening on %s: %w", s.address, err)
|
||||||
|
}
|
||||||
|
s.listening.Store(true)
|
||||||
|
|
||||||
|
s.socksConnCtx, s.socksConnCancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
ready := make(chan struct{})
|
||||||
|
runErrCh := make(chan error)
|
||||||
|
runErr = runErrCh
|
||||||
|
done := make(chan struct{})
|
||||||
|
s.done = done
|
||||||
|
go s.runServer(ready, runErrCh, done)
|
||||||
|
<-ready
|
||||||
|
return runErr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) runServer(ready chan<- struct{},
|
||||||
|
runErrCh chan<- error, done chan<- struct{}) {
|
||||||
|
close(ready)
|
||||||
|
defer close(done)
|
||||||
|
wg := new(sync.WaitGroup)
|
||||||
|
defer wg.Wait()
|
||||||
|
|
||||||
|
dialer := &net.Dialer{}
|
||||||
|
for {
|
||||||
|
connection, err := s.listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if !s.stopping.Load() {
|
||||||
|
_ = s.Stop()
|
||||||
|
runErrCh <- fmt.Errorf("accepting connection: %w", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(ctx context.Context, connection net.Conn,
|
||||||
|
dialer *net.Dialer, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
socksConn := &socksConn{
|
||||||
|
dialer: dialer,
|
||||||
|
username: s.username,
|
||||||
|
password: s.password,
|
||||||
|
clientConn: connection,
|
||||||
|
logger: s.logger,
|
||||||
|
}
|
||||||
|
err := socksConn.run(ctx)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Infof("running socks connection: %s", err)
|
||||||
|
}
|
||||||
|
}(s.socksConnCtx, connection, dialer, wg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Stop() (err error) {
|
||||||
|
s.stopping.Store(true)
|
||||||
|
s.listening.Store(false)
|
||||||
|
err = s.listener.Close()
|
||||||
|
s.socksConnCancel() // stop ongoing socks connections
|
||||||
|
<-s.done // wait for run goroutine to finish
|
||||||
|
s.stopping.Store(false)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) listeningAddress() net.Addr {
|
||||||
|
if s.listening.Load() {
|
||||||
|
return s.listener.Addr()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
8
internal/socks5/settings.go
Normal file
8
internal/socks5/settings.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
type Settings struct {
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
Address string
|
||||||
|
Logger Logger
|
||||||
|
}
|
||||||
283
internal/socks5/socks5.go
Normal file
283
internal/socks5/socks5.go
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type socksConn struct {
|
||||||
|
// Injected fields
|
||||||
|
dialer *net.Dialer
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
clientConn net.Conn
|
||||||
|
logger Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *socksConn) closeClientConn(ctxErr error) {
|
||||||
|
err := c.clientConn.Close()
|
||||||
|
if err != nil && ctxErr == nil {
|
||||||
|
c.logger.Warnf("closing client connection: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *socksConn) run(ctx context.Context) error {
|
||||||
|
authMethod := authNotRequired
|
||||||
|
if c.username != "" || c.password != "" {
|
||||||
|
authMethod = authUsernamePassword
|
||||||
|
}
|
||||||
|
|
||||||
|
err := verifyFirstNegotiation(c.clientConn, authMethod)
|
||||||
|
if err != nil {
|
||||||
|
replyMethod := authMethod
|
||||||
|
if errors.Is(err, ErrNoMethodIdentifiers) || errors.Is(err, ErrNoValidMethodIdentifier) {
|
||||||
|
replyMethod = authNotAcceptable
|
||||||
|
}
|
||||||
|
_, writeErr := c.clientConn.Write([]byte{socks5Version, byte(replyMethod)})
|
||||||
|
if writeErr != nil {
|
||||||
|
c.logger.Warnf("failed writing first negotiation reply: %s", writeErr)
|
||||||
|
}
|
||||||
|
c.closeClientConn(ctx.Err())
|
||||||
|
return fmt.Errorf("verifying first negotiation: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.clientConn.Write([]byte{socks5Version, byte(authMethod)})
|
||||||
|
if err != nil {
|
||||||
|
c.closeClientConn(ctx.Err())
|
||||||
|
return fmt.Errorf("writing first negotiation reply: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch authMethod {
|
||||||
|
case authNotRequired, authNotAcceptable:
|
||||||
|
case authGssapi:
|
||||||
|
panic("not implemented")
|
||||||
|
// TODO
|
||||||
|
case authUsernamePassword:
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1929#section-2
|
||||||
|
err = usernamePasswordSubnegotiate(c.clientConn, c.username, c.password)
|
||||||
|
if err != nil {
|
||||||
|
// If the server returns a `failure' (STATUS value other than X'00') status,
|
||||||
|
// it MUST close the connection.
|
||||||
|
c.closeClientConn(ctx.Err())
|
||||||
|
return fmt.Errorf("subnegotiating username and password: %w", err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unimplemented auth method %d", authMethod))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.handleRequest(ctx)
|
||||||
|
c.closeClientConn(ctx.Err())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("handling request: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrCommandNotSupported = errors.New("command not supported")
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *socksConn) handleRequest(ctx context.Context) error {
|
||||||
|
const socksVersion = socks5Version
|
||||||
|
request, err := decodeRequest(c.clientConn, socksVersion)
|
||||||
|
if err != nil {
|
||||||
|
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if request.command != connect {
|
||||||
|
c.encodeFailedResponse(c.clientConn, socksVersion, commandNotSupported)
|
||||||
|
return fmt.Errorf("%w: %s", ErrCommandNotSupported, request.command)
|
||||||
|
}
|
||||||
|
|
||||||
|
destinationAddress := net.JoinHostPort(request.destination, fmt.Sprint(request.port))
|
||||||
|
destinationConn, err := c.dialer.DialContext(ctx, "tcp", destinationAddress)
|
||||||
|
if err != nil {
|
||||||
|
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer destinationConn.Close()
|
||||||
|
|
||||||
|
destinationServerAddress := destinationConn.LocalAddr().String()
|
||||||
|
destinationAddr, destinationPortStr, err := net.SplitHostPort(destinationServerAddress)
|
||||||
|
fmt.Println("===", destinationServerAddress)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
destinationPort, err := strconv.Atoi(destinationPortStr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var bindAddrType addrType
|
||||||
|
if ip := net.ParseIP(destinationAddr); ip != nil {
|
||||||
|
if ip.To4() != nil {
|
||||||
|
bindAddrType = ipv4
|
||||||
|
} else {
|
||||||
|
bindAddrType = ipv6
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bindAddrType = domainName
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.encodeSuccessResponse(c.clientConn, socksVersion, succeeded, bindAddrType,
|
||||||
|
destinationAddr, uint16(destinationPort))
|
||||||
|
if err != nil {
|
||||||
|
c.encodeFailedResponse(c.clientConn, socksVersion, generalServerFailure)
|
||||||
|
return fmt.Errorf("writing successful %s response: %w", request.command, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
errc := make(chan error)
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(c.clientConn, destinationConn)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("from backend to client: %w", err)
|
||||||
|
}
|
||||||
|
errc <- err
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(destinationConn, c.clientConn)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("from client to backend: %w", err)
|
||||||
|
}
|
||||||
|
errc <- err
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case err := <-errc:
|
||||||
|
return err
|
||||||
|
case <-ctx.Done():
|
||||||
|
_ = destinationConn.Close()
|
||||||
|
_ = c.clientConn.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrVersionNotSupported = errors.New("version not supported")
|
||||||
|
ErrNoMethodIdentifiers = errors.New("no method identifiers")
|
||||||
|
)
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-3
|
||||||
|
func verifyFirstNegotiation(reader io.Reader, requiredMethod authMethod) error {
|
||||||
|
const headerLength = 2 // version + nMethods bytes
|
||||||
|
header := make([]byte, headerLength)
|
||||||
|
_, err := io.ReadFull(reader, header[:])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if header[0] != socks5Version {
|
||||||
|
return fmt.Errorf("%w: %d", ErrVersionNotSupported, header[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
nMethods := header[1]
|
||||||
|
if nMethods == 0 {
|
||||||
|
return fmt.Errorf("%w", ErrNoMethodIdentifiers)
|
||||||
|
}
|
||||||
|
|
||||||
|
methodIdentifiers := make([]byte, nMethods)
|
||||||
|
_, err = io.ReadFull(reader, methodIdentifiers)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("reading method identifiers: %w", err)
|
||||||
|
}
|
||||||
|
for _, methodIdentifier := range methodIdentifiers {
|
||||||
|
if methodIdentifier == byte(requiredMethod) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return makeNoAcceptableMethodError(requiredMethod, methodIdentifiers)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNoValidMethodIdentifier = errors.New("no valid method identifier")
|
||||||
|
)
|
||||||
|
|
||||||
|
func makeNoAcceptableMethodError(requiredAuthMethod authMethod, methodIdentifiers []byte) error {
|
||||||
|
methodNames := make([]string, len(methodIdentifiers))
|
||||||
|
for i, methodIdentifier := range methodIdentifiers {
|
||||||
|
methodNames[i] = fmt.Sprintf("%q", authMethod(methodIdentifier))
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("%w: none of %s matches %s",
|
||||||
|
ErrNoValidMethodIdentifier, strings.Join(methodNames, ", "),
|
||||||
|
requiredAuthMethod)
|
||||||
|
}
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1928#section-4
|
||||||
|
type request struct {
|
||||||
|
command cmdType
|
||||||
|
destination string
|
||||||
|
port uint16
|
||||||
|
addressType addrType
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrRequestSocksVersionMismatch = errors.New("request SOCKS version mismatch")
|
||||||
|
ErrAddressTypeNotSupported = errors.New("address type not supported")
|
||||||
|
)
|
||||||
|
|
||||||
|
func decodeRequest(reader io.Reader, expectedVersion byte) (req request, err error) {
|
||||||
|
const headerLength = 4
|
||||||
|
header := [headerLength]byte{}
|
||||||
|
_, err = io.ReadFull(reader, header[:])
|
||||||
|
if err != nil {
|
||||||
|
return request{}, fmt.Errorf("reading header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
version := header[0]
|
||||||
|
if header[0] != expectedVersion {
|
||||||
|
return request{}, fmt.Errorf("%w: expected %d and got %d",
|
||||||
|
ErrRequestSocksVersionMismatch, expectedVersion, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.command = cmdType(header[1])
|
||||||
|
// header[2] is RSV byte
|
||||||
|
req.addressType = addrType(header[3])
|
||||||
|
|
||||||
|
switch req.addressType {
|
||||||
|
case ipv4:
|
||||||
|
var ip [4]byte
|
||||||
|
_, err = io.ReadFull(reader, ip[:])
|
||||||
|
if err != nil {
|
||||||
|
return request{}, fmt.Errorf("reading IPv4 address: %w", err)
|
||||||
|
}
|
||||||
|
req.destination = netip.AddrFrom4(ip).String()
|
||||||
|
case ipv6:
|
||||||
|
var ip [16]byte
|
||||||
|
_, err = io.ReadFull(reader, ip[:])
|
||||||
|
if err != nil {
|
||||||
|
return request{}, fmt.Errorf("reading IPv6 address: %w", err)
|
||||||
|
}
|
||||||
|
req.destination = netip.AddrFrom16(ip).String()
|
||||||
|
case domainName:
|
||||||
|
var header [1]byte
|
||||||
|
_, err = io.ReadFull(reader, header[:])
|
||||||
|
if err != nil {
|
||||||
|
return request{}, fmt.Errorf("reading domain name header: %w", err)
|
||||||
|
}
|
||||||
|
domainName := make([]byte, header[0])
|
||||||
|
_, err = io.ReadFull(reader, domainName)
|
||||||
|
if err != nil {
|
||||||
|
return request{}, fmt.Errorf("reading domain name bytes: %w", err)
|
||||||
|
}
|
||||||
|
req.destination = string(domainName)
|
||||||
|
default:
|
||||||
|
return request{}, fmt.Errorf("%w: %d", ErrAddressTypeNotSupported, req.addressType)
|
||||||
|
}
|
||||||
|
|
||||||
|
var portBytes [2]byte
|
||||||
|
_, err = io.ReadFull(reader, portBytes[:])
|
||||||
|
if err != nil {
|
||||||
|
return request{}, fmt.Errorf("reading port: %w", err)
|
||||||
|
}
|
||||||
|
req.port = binary.BigEndian.Uint16(portBytes[:])
|
||||||
|
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
175
internal/socks5/socks5_test.go
Normal file
175
internal/socks5/socks5_test.go
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/qdm12/log"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test(t *testing.T) {
|
||||||
|
server := New(Settings{
|
||||||
|
Username: "test",
|
||||||
|
Password: "test",
|
||||||
|
Address: ":8000",
|
||||||
|
Logger: log.New(),
|
||||||
|
})
|
||||||
|
|
||||||
|
runErr, startErr := server.Start(context.Background())
|
||||||
|
require.NoError(t, startErr)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-runErr:
|
||||||
|
require.NoError(t, err)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("SlEEPING")
|
||||||
|
time.Sleep(15 * time.Second)
|
||||||
|
t.Log("Done sleeping")
|
||||||
|
|
||||||
|
err := server.Stop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func backendServer(listener net.Listener) {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
conn.Write([]byte("Test"))
|
||||||
|
conn.Close()
|
||||||
|
listener.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRead(t *testing.T) {
|
||||||
|
// backend server which we'll use SOCKS5 to connect to
|
||||||
|
listener, err := net.Listen("tcp", ":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
backendServerPort := listener.Addr().(*net.TCPAddr).Port
|
||||||
|
go backendServer(listener)
|
||||||
|
|
||||||
|
// SOCKS5 server
|
||||||
|
server := New(Settings{
|
||||||
|
Address: ":0",
|
||||||
|
})
|
||||||
|
_, err = server.Start(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = server.Stop()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
socks5Port := server.listeningAddress().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
addr := fmt.Sprintf("localhost:%d", socks5Port)
|
||||||
|
socksDialer, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr = fmt.Sprintf("localhost:%d", backendServerPort)
|
||||||
|
conn, err := socksDialer.Dial("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 4)
|
||||||
|
_, err = io.ReadFull(conn, buf)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if string(buf) != "Test" {
|
||||||
|
t.Fatalf("got: %q want: Test", buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadPassword(t *testing.T) {
|
||||||
|
// backend server which we'll use SOCKS5 to connect to
|
||||||
|
ln, err := net.Listen("tcp", ":0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
backendServerPort := ln.Addr().(*net.TCPAddr).Port
|
||||||
|
go backendServer(ln)
|
||||||
|
|
||||||
|
auth := &proxy.Auth{User: "foo", Password: "bar"}
|
||||||
|
|
||||||
|
server := Server{
|
||||||
|
logger: log.New(),
|
||||||
|
username: auth.User,
|
||||||
|
password: auth.Password,
|
||||||
|
address: ":0",
|
||||||
|
}
|
||||||
|
_, err = server.Start(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err = server.Stop()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
addr := fmt.Sprintf("localhost:%d", server.listeningAddress().(*net.TCPAddr).Port)
|
||||||
|
|
||||||
|
if d, err := proxy.SOCKS5("tcp", addr, nil, proxy.Direct); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else {
|
||||||
|
if _, err := d.Dial("tcp", addr); err == nil {
|
||||||
|
t.Fatal("expected no-auth dial error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
badPwd := &proxy.Auth{User: "foo", Password: "not right"}
|
||||||
|
if d, err := proxy.SOCKS5("tcp", addr, badPwd, proxy.Direct); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else {
|
||||||
|
if _, err := d.Dial("tcp", addr); err == nil {
|
||||||
|
t.Fatal("expected bad password dial error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
badUsr := &proxy.Auth{User: "not right", Password: "bar"}
|
||||||
|
if d, err := proxy.SOCKS5("tcp", addr, badUsr, proxy.Direct); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else {
|
||||||
|
if _, err := d.Dial("tcp", addr); err == nil {
|
||||||
|
t.Fatal("expected bad username dial error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
socksDialer, err := proxy.SOCKS5("tcp", addr, auth, proxy.Direct)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
addr = fmt.Sprintf("localhost:%d", backendServerPort)
|
||||||
|
conn, err := socksDialer.Dial("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 4)
|
||||||
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if string(buf) != "Test" {
|
||||||
|
t.Fatalf("got: %q want: Test", buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
69
internal/socks5/usernamepassword.go
Normal file
69
internal/socks5/usernamepassword.go
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
package socks5
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrSubnegotiationVersionNotSupported = errors.New("subnegotiation version not supported")
|
||||||
|
ErrUsernameNotValid = errors.New("username not valid")
|
||||||
|
ErrPasswordNotValid = errors.New("password not valid")
|
||||||
|
)
|
||||||
|
|
||||||
|
// See https://datatracker.ietf.org/doc/html/rfc1929#section-2
|
||||||
|
func usernamePasswordSubnegotiate(conn io.ReadWriter, username, password string) (err error) {
|
||||||
|
status := byte(1)
|
||||||
|
const defaultVersion = byte(1)
|
||||||
|
|
||||||
|
const headerLength = 2
|
||||||
|
var header [headerLength]byte
|
||||||
|
_, err = io.ReadFull(conn, header[:])
|
||||||
|
if err != nil {
|
||||||
|
_, _ = conn.Write([]byte{defaultVersion, status})
|
||||||
|
return fmt.Errorf("reading header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if header[0] != authUsernamePasswordSubNegotiation1 {
|
||||||
|
_, _ = conn.Write([]byte{defaultVersion, status})
|
||||||
|
return fmt.Errorf("%w: %d", ErrSubnegotiationVersionNotSupported, header[0])
|
||||||
|
}
|
||||||
|
version := header[0]
|
||||||
|
|
||||||
|
usernameBytes := make([]byte, header[1])
|
||||||
|
_, err = io.ReadFull(conn, usernameBytes)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = conn.Write([]byte{version, status})
|
||||||
|
return fmt.Errorf("reading username bytes: %w", err)
|
||||||
|
} else if username != string(usernameBytes) {
|
||||||
|
_, _ = conn.Write([]byte{version, status})
|
||||||
|
return fmt.Errorf("%w: %s", ErrUsernameNotValid, string(usernameBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
const passwordHeaderLength = 1
|
||||||
|
passwordHeader := make([]byte, passwordHeaderLength)
|
||||||
|
_, err = io.ReadFull(conn, passwordHeader[:])
|
||||||
|
if err != nil {
|
||||||
|
_, _ = conn.Write([]byte{version, status})
|
||||||
|
return fmt.Errorf("reading password length: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
passwordBytes := make([]byte, passwordHeader[0])
|
||||||
|
_, err = io.ReadFull(conn, passwordBytes)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = conn.Write([]byte{version, status})
|
||||||
|
return fmt.Errorf("reading password bytes: %w", err)
|
||||||
|
} else if password != string(passwordBytes) {
|
||||||
|
_, _ = conn.Write([]byte{version, status})
|
||||||
|
return fmt.Errorf("%w: %s", ErrPasswordNotValid, string(passwordBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
status = 0
|
||||||
|
_, err = conn.Write([]byte{version, status})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("writing success status: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user