Merge pull request #19 from InazumaV/dev_new

Dev new
This commit is contained in:
Yuzuki
2023-09-25 16:24:07 +08:00
committed by GitHub

View File

@@ -164,12 +164,12 @@ func (d *DefaultDispatcher) getLink(ctx context.Context, network net.Network) (*
if user != nil && len(user.Email) > 0 { if user != nil && len(user.Email) > 0 {
limit, err = limiter.GetLimiter(sessionInbound.Tag) limit, err = limiter.GetLimiter(sessionInbound.Tag)
if err != nil { if err != nil {
newError("Get limit info error: ", err).AtError().WriteToLog() newError("get limiter ", sessionInbound.Tag, " error: ", err).AtError().WriteToLog()
common.Close(outboundLink.Writer) common.Close(outboundLink.Writer)
common.Close(inboundLink.Writer) common.Close(inboundLink.Writer)
common.Interrupt(outboundLink.Reader) common.Interrupt(outboundLink.Reader)
common.Interrupt(inboundLink.Reader) common.Interrupt(inboundLink.Reader)
return nil, nil, nil, newError("Get limit info error: ", err) return nil, nil, nil, newError("get limiter ", sessionInbound.Tag, " error: ", err)
} }
// Speed Limit and Device Limit // Speed Limit and Device Limit
w, reject := limit.CheckLimit(user.Email, w, reject := limit.CheckLimit(user.Email,
@@ -414,49 +414,48 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
} }
sessionInbound := session.InboundFromContext(ctx) sessionInbound := session.InboundFromContext(ctx)
if l != nil { if sessionInbound.User != nil {
// del connect count if l != nil {
if sessionInbound.User != nil { // del connect count
if destination.Network == net.Network_TCP { if destination.Network == net.Network_TCP {
defer func() { defer func() {
l.ConnLimiter.DelConnCount(sessionInbound.User.Email, sessionInbound.Source.Address.IP().String()) l.ConnLimiter.DelConnCount(sessionInbound.User.Email, sessionInbound.Source.Address.IP().String())
}() }()
} }
} else {
var err error
l, err = limiter.GetLimiter(sessionInbound.Tag)
if err != nil {
newError("get limiter ", sessionInbound.Tag, " error: ", err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
}
} }
} else { if l != nil {
var err error var destStr string
l, err = limiter.GetLimiter(sessionInbound.Tag) if destination.Address.Family().IsDomain() {
if err != nil { destStr = destination.Address.Domain()
newError("get limiter error: ", err).AtError().WriteToLog(session.ExportIDToError(ctx)) } else {
common.Close(link.Writer) destStr = destination.Address.IP().String()
common.Interrupt(link.Reader) }
return if l.CheckDomainRule(destStr) {
} newError(fmt.Sprintf(
} "User %s access domain %s reject by rule",
var destStr string sessionInbound.User.Email,
if destination.Address.Family().IsDomain() { destStr)).AtWarning().WriteToLog(session.ExportIDToError(ctx))
destStr = destination.Address.Domain() common.Close(link.Writer)
} else { common.Interrupt(link.Reader)
destStr = destination.Address.IP().String() return
} }
if l.CheckDomainRule(destStr) { if len(protocol) != 0 {
newError(fmt.Sprintf( if l.CheckProtocolRule(protocol) {
"User %s access domain %s reject by rule", newError(fmt.Sprintf(
sessionInbound.User.Email, "User %s access protocol %s reject by rule",
destStr)).AtWarning().WriteToLog(session.ExportIDToError(ctx)) sessionInbound.User.Email,
common.Close(link.Writer) protocol)).AtWarning().WriteToLog(session.ExportIDToError(ctx))
common.Interrupt(link.Reader) common.Close(link.Writer)
return common.Interrupt(link.Reader)
} return
if len(protocol) != 0 { }
if l.CheckProtocolRule(protocol) { }
newError(fmt.Sprintf(
"User %s access protocol %s reject by rule",
sessionInbound.User.Email,
protocol)).AtWarning().WriteToLog(session.ExportIDToError(ctx))
common.Close(link.Writer)
common.Interrupt(link.Reader)
return
} }
} }