Skip to content

Commit

Permalink
Connection not allowed by ruleset (#11)
Browse files Browse the repository at this point in the history
* RuleSet interface

* server rules

* with allow commands

* test not allowed command

* package comment

* bind command

* white list IPs

* RemoteAddressFromContext return net.Addr type

* test IP rules

* not allowed destination

* rename
  • Loading branch information
TuanKiri authored May 24, 2024
1 parent e64444b commit 9acaf24
Show file tree
Hide file tree
Showing 8 changed files with 267 additions and 21 deletions.
8 changes: 8 additions & 0 deletions address.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 26,14 @@ func (a address) String() string {
return fmt.Sprintf("%s:%s", a.Domain, a.Port)
}

func (a address) getDomainOrIP() string {
if a.IP != nil {
return a.IP.String()
}

return string(a.Domain)
}

type port []byte

func (p port) String() string {
Expand Down
13 changes: 8 additions & 5 deletions context.go
Original file line number Diff line number Diff line change
@@ -1,6 1,9 @@
package socks5

import "context"
import (
"context"
"net"
)

type ctxKey int

Expand All @@ -9,12 12,12 @@ const (
usernameKey
)

func contextWithRemoteAddress(ctx context.Context, address string) context.Context {
return context.WithValue(ctx, remoteAddressKey, address)
func contextWithRemoteAddress(ctx context.Context, addr net.Addr) context.Context {
return context.WithValue(ctx, remoteAddressKey, addr)
}

func RemoteAddressFromContext(ctx context.Context) (string, bool) {
value, ok := ctx.Value(remoteAddressKey).(string)
func RemoteAddressFromContext(ctx context.Context) (net.Addr, bool) {
value, ok := ctx.Value(remoteAddressKey).(net.Addr)
return value, ok
}

Expand Down
60 changes: 60 additions & 0 deletions options.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 8,14 @@ import (
"time"
)

const (
Connect Command = iota 1
Bind
UDPAssociate
)

type Command int

type Option func(*options)

type options struct {
Expand All @@ -20,10 28,14 @@ type options struct {
getPasswordTimeout time.Duration
passwordAuthentication bool
staticCredentials map[string]string
allowCommands map[byte]struct{}
blockListHosts map[string]struct{}
allowIPs []net.IP
logger Logger
store Store
driver Driver
metrics Metrics
rules Rules
}

func (o options) authMethods() map[byte]struct{} {
Expand Down Expand Up @@ -80,6 92,18 @@ func optsWithDefaults(opts *options) *options {
opts.metrics = &nopMetrics{}
}

if opts.rules == nil {
if opts.allowCommands == nil {
opts.allowCommands = permitAllCommands()
}

opts.rules = &serverRules{
allowCommands: opts.allowCommands,
blockListHosts: opts.blockListHosts,
allowIPs: opts.allowIPs,
}
}

return opts
}

Expand Down Expand Up @@ -160,3 184,39 @@ func WithMetrics(val Metrics) Option {
o.metrics = val
}
}

func WithRules(val Rules) Option {
return func(o *options) {
o.rules = val
}
}

func WithAllowCommands(commands ...Command) Option {
allowCommands := map[byte]struct{}{}

for _, command := range commands {
allowCommands[byte(command)] = struct{}{}
}

return func(o *options) {
o.allowCommands = allowCommands
}
}

func WithWhiteListIPs(IPs ...net.IP) Option {
return func(o *options) {
o.allowIPs = IPs
}
}

func WithBlockListHosts(hosts ...string) Option {
blockListHosts := map[string]struct{}{}

for _, host := range hosts {
blockListHosts[host] = struct{}{}
}

return func(o *options) {
o.blockListHosts = blockListHosts
}
}
59 changes: 59 additions & 0 deletions rules.go
Original file line number Diff line number Diff line change
@@ -0,0 1,59 @@
package socks5

import (
"context"
"net"
)

type Rules interface {
IsAllowCommand(ctx context.Context, cmd byte) bool
IsAllowConnection(addr net.Addr) bool
IsAllowDestination(ctx context.Context, host string) bool
}

type serverRules struct {
allowCommands map[byte]struct{}
blockListHosts map[string]struct{}
allowIPs []net.IP
}

func (r *serverRules) IsAllowCommand(ctx context.Context, cmd byte) bool {
_, ok := r.allowCommands[cmd]
return ok
}

func (r *serverRules) IsAllowConnection(addr net.Addr) bool {
if r.allowIPs == nil {
return true
}

tcpAddr, ok := addr.(*net.TCPAddr)
if !ok {
return false
}

for _, allowIP := range r.allowIPs {
if allowIP.Equal(tcpAddr.IP) {
return true
}
}

return false
}

func (r *serverRules) IsAllowDestination(ctx context.Context, host string) bool {
if r.blockListHosts == nil {
return true
}

_, ok := r.blockListHosts[host]
return !ok
}

func permitAllCommands() map[byte]struct{} {
return map[byte]struct{}{
connect: {},
bind: {},
udpAssociate: {},
}
}
51 changes: 51 additions & 0 deletions rules_test.go
Original file line number Diff line number Diff line change
@@ -0,0 1,51 @@
package socks5

import (
"net"
"testing"

"github.com/stretchr/testify/assert"
)

func TestIPRules(t *testing.T) {
cases := map[string]struct {
allowIPs []net.IP
address net.Addr
allow bool
}{
"allow_connection": {
allowIPs: []net.IP{
net.ParseIP("192.168.0.100"),
},
address: &net.TCPAddr{
IP: net.ParseIP("192.168.0.100"),
},
allow: true,
},
"not_allow_connection": {
allowIPs: []net.IP{},
address: &net.TCPAddr{
IP: net.ParseIP("192.168.0.101"),
},
allow: false,
},
"incorrect_address_type": {
allowIPs: []net.IP{},
address: &net.UDPAddr{},
allow: false,
},
"without_slice_allow_IPs": {
allow: true,
},
}

for name, tc := range cases {
t.Run(name, func(t *testing.T) {
rules := &serverRules{
allowIPs: tc.allowIPs,
}

assert.Equal(t, rules.IsAllowConnection(tc.address), tc.allow)
})
}
}
10 changes: 9 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 22,7 @@ type Server struct {
store Store
driver Driver
metrics Metrics
rules Rules
active chan struct{}
done chan struct{}
closeListener func() error
Expand Down Expand Up @@ -50,6 51,7 @@ func New(opts ...Option) *Server {
store: options.store,
driver: options.driver,
metrics: options.metrics,
rules: options.rules,
active: make(chan struct{}),
done: make(chan struct{}),
}
Expand Down Expand Up @@ -104,9 106,15 @@ func (s *Server) Shutdown() error {
func (s *Server) serve(conn net.Conn) {
defer conn.Close()

remoteAddr := conn.RemoteAddr()

if !s.rules.IsAllowConnection(remoteAddr) {
return
}

s.setConnDeadline(conn)

ctx := contextWithRemoteAddress(context.Background(), conn.RemoteAddr().String())
ctx := contextWithRemoteAddress(context.Background(), remoteAddr)

s.handshake(ctx, newConnection(conn))
}
Expand Down
38 changes: 29 additions & 9 deletions socks5.go
Original file line number Diff line number Diff line change
@@ -1,3 1,4 @@
// Package socks5 a fully featured implementation of the SOCKS 5 protocol in golang.
package socks5

import (
Expand All @@ -23,15 24,17 @@ const (
addressTypeIPv6 byte = 0x04

connect byte = 0x01
bind byte = 0x02
udpAssociate byte = 0x03

connectionSuccessful byte = 0x00
generalSOCKSserverFailure byte = 0x01
networkUnreachable byte = 0x03
hostUnreachable byte = 0x04
connectionRefused byte = 0x05
commandNotSupported byte = 0x07
addressTypeNotSupported byte = 0x08
connectionSuccessful byte = 0x00
generalSOCKSserverFailure byte = 0x01
connectionNotAllowedByRuleSet byte = 0x02
networkUnreachable byte = 0x03
hostUnreachable byte = 0x04
connectionRefused byte = 0x05
commandNotSupported byte = 0x07
addressTypeNotSupported byte = 0x08
)

func (s *Server) handshake(ctx context.Context, conn *connection) {
Expand Down Expand Up @@ -69,7 72,6 @@ func (s *Server) handshake(ctx context.Context, conn *connection) {
s.usernamePasswordAuthenticate(ctx, conn)
default:
s.response(ctx, conn, version5, noAcceptableMethods)
return
}
}

Expand Down Expand Up @@ -140,14 142,28 @@ func (s *Server) acceptRequest(ctx context.Context, conn *connection) {
return
}

if !s.rules.IsAllowDestination(ctx, addr.getDomainOrIP()) {
s.replyRequest(ctx, conn, connectionNotAllowedByRuleSet, &addr)
return
}

switch command {
case connect:
if !s.rules.IsAllowCommand(ctx, connect) {
s.replyRequest(ctx, conn, connectionNotAllowedByRuleSet, &addr)
return
}

s.connect(ctx, conn, &addr)
case udpAssociate:
if !s.rules.IsAllowCommand(ctx, udpAssociate) {
s.replyRequest(ctx, conn, connectionNotAllowedByRuleSet, &addr)
return
}

s.udpAssociate(ctx, conn, &addr)
default:
s.replyRequest(ctx, conn, commandNotSupported, &addr)
return
}
}

Expand Down Expand Up @@ -303,6 319,10 @@ func (s *Server) servePacketConn(ctx context.Context, conn *packetConn) {
return
}

if !s.rules.IsAllowDestination(ctx, addr.getDomainOrIP()) {
return
}

res, err := s.sendPacket(ctx, conn.bytes(), &addr)
if err != nil {
s.logger.Error(ctx, "failed to send packet: " err.Error())
Expand Down
Loading

0 comments on commit 9acaf24

Please sign in to comment.