mirror of
https://github.com/Buriburizaem0n/nezha_domains.git
synced 2026-02-04 12:40:07 +00:00
improve: use stream reduce auth check time
This commit is contained in:
@@ -20,17 +20,26 @@ import (
|
||||
)
|
||||
|
||||
func ServeRPC() *grpc.Server {
|
||||
server := grpc.NewServer(grpc.UnaryInterceptor(getRealIp))
|
||||
server := grpc.NewServer(grpc.ChainUnaryInterceptor(getRealIp, waf))
|
||||
rpcService.NezhaHandlerSingleton = rpcService.NewNezhaHandler()
|
||||
proto.RegisterNezhaServiceServer(server, rpcService.NezhaHandlerSingleton)
|
||||
return server
|
||||
}
|
||||
|
||||
func waf(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
if err := model.CheckIP(singleton.DB, ctx.Value(model.CtxKeyRealIP{}).(string)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
if singleton.Conf.RealIPHeader == "" {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
var ip string
|
||||
|
||||
if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP {
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
@@ -40,18 +49,19 @@ func getRealIp(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = context.WithValue(ctx, model.CtxKeyRealIP{}, addrPort.Addr().String())
|
||||
return handler(ctx, req)
|
||||
ip = addrPort.Addr().String()
|
||||
} else {
|
||||
vals := metadata.ValueFromIncomingContext(ctx, singleton.Conf.RealIPHeader)
|
||||
if len(vals) == 0 {
|
||||
return nil, fmt.Errorf("real ip header not found")
|
||||
}
|
||||
var err error
|
||||
ip, err = utils.GetIPFromHeader(vals[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
vals := metadata.ValueFromIncomingContext(ctx, singleton.Conf.RealIPHeader)
|
||||
if len(vals) == 0 {
|
||||
return nil, fmt.Errorf("real ip header not found")
|
||||
}
|
||||
ip, err := utils.GetIPFromHeader(vals[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = context.WithValue(ctx, model.CtxKeyRealIP{}, ip)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user