support protocol rule

This commit is contained in:
Yuzuki616
2023-07-20 21:14:18 +08:00
parent 7cc57fe2ba
commit adf98fbc81
6 changed files with 92 additions and 48 deletions

View File

@@ -337,7 +337,13 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
reader: outbound.Reader.(*pipe.Reader),
}
outbound.Reader = cReader
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network, l)
if _, ok := err.(limitedError); ok {
newError(err).AtInfo().WriteToLog()
common.Close(outbound.Writer)
common.Interrupt(outbound.Reader)
return
}
if err == nil {
content.Protocol = result.Protocol()
}
@@ -380,7 +386,13 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
reader: outbound.Reader.(*pipe.Reader),
}
outbound.Reader = cReader
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network)
result, err := sniffer(ctx, cReader, sniffingRequest.MetadataOnly, destination.Network, nil)
if _, ok := err.(limitedError); ok {
newError(err).AtInfo().WriteToLog()
common.Close(outbound.Writer)
common.Interrupt(outbound.Reader)
return
}
if err == nil {
content.Protocol = result.Protocol()
}
@@ -400,18 +412,50 @@ func (d *DefaultDispatcher) DispatchLink(ctx context.Context, destination net.De
return nil
}
func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network) (SniffResult, error) {
type limitedError string
func (l limitedError) Error() string {
return string(l)
}
func sniffer(ctx context.Context, cReader *cachedReader, metadataOnly bool, network net.Network, l *limiter.Limiter) (result SniffResult, err error) {
payload := buf.New()
defer payload.Release()
defer func() {
if err != nil {
return
}
// Check if domain and protocol hit the rule
sessionInbound := session.InboundFromContext(ctx)
// Whether the inbound connection contains a user
if sessionInbound.User != nil {
if l == nil {
l, err = limiter.GetLimiter(sessionInbound.Tag)
if err != nil {
return
}
}
if l.CheckDomainRule(result.Domain()) {
err = limitedError(fmt.Sprintf(
"User %s access domain %s reject by rule",
sessionInbound.User.Email,
result.Domain()))
}
if l.CheckProtocolRule(result.Protocol()) {
err = limitedError(fmt.Sprintf(
"User %s access protocol %s reject by rule",
sessionInbound.User.Email,
result.Protocol()))
}
}
}()
sniffer := NewSniffer(ctx)
metaresult, metadataErr := sniffer.SniffMetadata(ctx)
if metadataOnly {
return metaresult, metadataErr
}
contentResult, contentErr := func() (SniffResult, error) {
totalAttempt := 0
for {
@@ -460,32 +504,17 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
}
}
}
var handler outbound.Handler
// Check if domain and protocol hit the rule
sessionInbound := session.InboundFromContext(ctx)
// Whether the inbound connection contains a user
if sessionInbound.User != nil {
if l == nil {
var err error
l, err = limiter.GetLimiter(sessionInbound.Tag)
if err != nil {
newError("Get limiter error: ", err).AtError().WriteToLog()
common.Close(link.Writer)
common.Interrupt(link.Reader)
return
// del connect count
if l != nil {
sessionInbound := session.InboundFromContext(ctx)
if sessionInbound.User != nil {
if destination.Network == net.Network_TCP {
defer func() {
l.ConnLimiter.DelConnCount(sessionInbound.User.Email, sessionInbound.Source.Address.IP().String())
}()
}
} else if destination.Network == net.Network_TCP {
defer func() {
l.ConnLimiter.DelConnCount(sessionInbound.User.Email, sessionInbound.Source.Address.IP().String())
}()
}
if l.CheckDomainRule(destination.Address.String()) {
newError(fmt.Sprintf("User %s access %s reject by rule", sessionInbound.User.Email, destination.String())).AtError().WriteToLog()
common.Close(link.Writer)
common.Interrupt(link.Reader)
return
}
}