mirror of
https://github.com/Buriburizaem0n/nezha_domains.git
synced 2026-03-22 02:51:50 +00:00
refactor agent auth & server api
This commit is contained in:
@@ -2,20 +2,23 @@ package rpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/naiba/nezha/model"
|
||||
"github.com/naiba/nezha/service/singleton"
|
||||
)
|
||||
|
||||
type authHandler struct {
|
||||
ClientSecret string
|
||||
ClientUUID string
|
||||
}
|
||||
|
||||
func (a *authHandler) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
|
||||
return map[string]string{"client_secret": a.ClientSecret}, nil
|
||||
return map[string]string{"client_secret": a.ClientSecret, "client_uuid": a.ClientUUID}, nil
|
||||
}
|
||||
|
||||
func (a *authHandler) RequireTransportSecurity() bool {
|
||||
@@ -33,15 +36,29 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
|
||||
clientSecret = value[0]
|
||||
}
|
||||
|
||||
if clientSecret != singleton.Conf.AgentSecretKey {
|
||||
return 0, status.Errorf(codes.Unauthenticated, "客户端认证失败")
|
||||
}
|
||||
|
||||
var clientUUID string
|
||||
if value, ok := md["client_uuid"]; ok {
|
||||
clientUUID = value[0]
|
||||
}
|
||||
|
||||
singleton.ServerLock.RLock()
|
||||
defer singleton.ServerLock.RUnlock()
|
||||
clientID, hasID := singleton.SecretToID[clientSecret]
|
||||
clientID, hasID := singleton.ServerUUIDToID[clientUUID]
|
||||
if !hasID {
|
||||
return 0, status.Errorf(codes.Unauthenticated, "客户端认证失败")
|
||||
}
|
||||
_, hasServer := singleton.ServerList[clientID]
|
||||
if !hasServer {
|
||||
return 0, status.Errorf(codes.Unauthenticated, "客户端认证失败")
|
||||
s := model.Server{UUID: clientUUID}
|
||||
if err := singleton.DB.Create(&s).Error; err != nil {
|
||||
return 0, status.Errorf(codes.Unauthenticated, err.Error())
|
||||
}
|
||||
s.Host = &model.Host{}
|
||||
s.State = &model.HostState{}
|
||||
s.TaskCloseLock = new(sync.Mutex)
|
||||
singleton.ServerList[s.ID] = &s
|
||||
singleton.ServerUUIDToID[clientUUID] = s.ID
|
||||
}
|
||||
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
@@ -103,9 +103,9 @@ func (s *ServerAPIService) GetStatusByIDList(idList []uint64) *ServerStatusRespo
|
||||
}
|
||||
ipv4, ipv6, validIP := utils.SplitIPAddr(server.Host.IP)
|
||||
info := CommonServerInfo{
|
||||
ID: server.ID,
|
||||
Name: server.Name,
|
||||
Tag: server.Tag,
|
||||
ID: server.ID,
|
||||
Name: server.Name,
|
||||
// Tag: server.Tag,
|
||||
LastActive: server.LastActive.Unix(),
|
||||
IPV4: ipv4,
|
||||
IPV6: ipv6,
|
||||
@@ -125,9 +125,9 @@ func (s *ServerAPIService) GetStatusByIDList(idList []uint64) *ServerStatusRespo
|
||||
}
|
||||
|
||||
// GetStatusByTag 获取传入分组的所有服务器状态信息
|
||||
func (s *ServerAPIService) GetStatusByTag(tag string) *ServerStatusResponse {
|
||||
return s.GetStatusByIDList(ServerTagToIDList[tag])
|
||||
}
|
||||
// func (s *ServerAPIService) GetStatusByTag(tag string) *ServerStatusResponse {
|
||||
// return s.GetStatusByIDList(ServerTagToIDList[tag])
|
||||
// }
|
||||
|
||||
// GetAllStatus 获取所有服务器状态信息
|
||||
func (s *ServerAPIService) GetAllStatus() *ServerStatusResponse {
|
||||
@@ -143,9 +143,9 @@ func (s *ServerAPIService) GetAllStatus() *ServerStatusResponse {
|
||||
}
|
||||
ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP)
|
||||
info := CommonServerInfo{
|
||||
ID: v.ID,
|
||||
Name: v.Name,
|
||||
Tag: v.Tag,
|
||||
ID: v.ID,
|
||||
Name: v.Name,
|
||||
// Tag: v.Tag,
|
||||
LastActive: v.LastActive.Unix(),
|
||||
IPV4: ipv4,
|
||||
IPV6: ipv6,
|
||||
@@ -173,23 +173,23 @@ func (s *ServerAPIService) GetListByTag(tag string) *ServerInfoResponse {
|
||||
|
||||
ServerLock.RLock()
|
||||
defer ServerLock.RUnlock()
|
||||
for _, v := range ServerTagToIDList[tag] {
|
||||
host := ServerList[v].Host
|
||||
if host == nil {
|
||||
continue
|
||||
}
|
||||
ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP)
|
||||
info := &CommonServerInfo{
|
||||
ID: v,
|
||||
Name: ServerList[v].Name,
|
||||
Tag: ServerList[v].Tag,
|
||||
LastActive: ServerList[v].LastActive.Unix(),
|
||||
IPV4: ipv4,
|
||||
IPV6: ipv6,
|
||||
ValidIP: validIP,
|
||||
}
|
||||
res.Result = append(res.Result, info)
|
||||
}
|
||||
// for _, v := range ServerTagToIDList[tag] {
|
||||
// host := ServerList[v].Host
|
||||
// if host == nil {
|
||||
// continue
|
||||
// }
|
||||
// ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP)
|
||||
// info := &CommonServerInfo{
|
||||
// ID: v,
|
||||
// Name: ServerList[v].Name,
|
||||
// Tag: ServerList[v].Tag,
|
||||
// LastActive: ServerList[v].LastActive.Unix(),
|
||||
// IPV4: ipv4,
|
||||
// IPV6: ipv6,
|
||||
// ValidIP: validIP,
|
||||
// }
|
||||
// res.Result = append(res.Result, info)
|
||||
// }
|
||||
res.CommonResponse = CommonResponse{
|
||||
Code: 0,
|
||||
Message: "success",
|
||||
@@ -211,9 +211,9 @@ func (s *ServerAPIService) GetAllList() *ServerInfoResponse {
|
||||
}
|
||||
ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP)
|
||||
info := &CommonServerInfo{
|
||||
ID: v.ID,
|
||||
Name: v.Name,
|
||||
Tag: v.Tag,
|
||||
ID: v.ID,
|
||||
Name: v.Name,
|
||||
// Tag: v.Tag,
|
||||
LastActive: v.LastActive.Unix(),
|
||||
IPV4: ipv4,
|
||||
IPV6: ipv6,
|
||||
|
||||
@@ -8,21 +8,18 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ServerList map[uint64]*model.Server // [ServerID] -> model.Server
|
||||
SecretToID map[string]uint64 // [ServerSecret] -> ServerID
|
||||
ServerTagToIDList map[string][]uint64 // [ServerTag] -> ServerID
|
||||
ServerLock sync.RWMutex
|
||||
ServerList map[uint64]*model.Server // [ServerID] -> model.Server
|
||||
ServerUUIDToID map[string]uint64 // [ServerUUID] -> ServerID
|
||||
ServerLock sync.RWMutex
|
||||
|
||||
SortedServerList []*model.Server // 用于存储服务器列表的 slice,按照服务器 ID 排序
|
||||
SortedServerListForGuest []*model.Server
|
||||
SortedServerLock sync.RWMutex
|
||||
)
|
||||
|
||||
// InitServer 初始化 ServerID <-> Secret 的映射
|
||||
func InitServer() {
|
||||
ServerList = make(map[uint64]*model.Server)
|
||||
SecretToID = make(map[string]uint64)
|
||||
ServerTagToIDList = make(map[string][]uint64)
|
||||
ServerUUIDToID = make(map[string]uint64)
|
||||
}
|
||||
|
||||
// loadServers 加载服务器列表并根据ID排序
|
||||
@@ -36,8 +33,7 @@ func loadServers() {
|
||||
innerS.State = &model.HostState{}
|
||||
innerS.TaskCloseLock = new(sync.Mutex)
|
||||
ServerList[innerS.ID] = &innerS
|
||||
SecretToID[innerS.Secret] = innerS.ID
|
||||
ServerTagToIDList[innerS.Tag] = append(ServerTagToIDList[innerS.Tag], innerS.ID)
|
||||
ServerUUIDToID[innerS.UUID] = innerS.ID
|
||||
}
|
||||
ReSortServer()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user