refactor agent auth & server api

This commit is contained in:
naiba
2024-10-20 23:23:04 +08:00
parent d3f907b5c3
commit aa20c97312
19 changed files with 488 additions and 330 deletions

View File

@@ -2,20 +2,23 @@ package rpc
import (
"context"
"sync"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/service/singleton"
)
type authHandler struct {
ClientSecret string
ClientUUID string
}
func (a *authHandler) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return map[string]string{"client_secret": a.ClientSecret}, nil
return map[string]string{"client_secret": a.ClientSecret, "client_uuid": a.ClientUUID}, nil
}
func (a *authHandler) RequireTransportSecurity() bool {
@@ -33,15 +36,29 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
clientSecret = value[0]
}
if clientSecret != singleton.Conf.AgentSecretKey {
return 0, status.Errorf(codes.Unauthenticated, "客户端认证失败")
}
var clientUUID string
if value, ok := md["client_uuid"]; ok {
clientUUID = value[0]
}
singleton.ServerLock.RLock()
defer singleton.ServerLock.RUnlock()
clientID, hasID := singleton.SecretToID[clientSecret]
clientID, hasID := singleton.ServerUUIDToID[clientUUID]
if !hasID {
return 0, status.Errorf(codes.Unauthenticated, "客户端认证失败")
}
_, hasServer := singleton.ServerList[clientID]
if !hasServer {
return 0, status.Errorf(codes.Unauthenticated, "客户端认证失败")
s := model.Server{UUID: clientUUID}
if err := singleton.DB.Create(&s).Error; err != nil {
return 0, status.Errorf(codes.Unauthenticated, err.Error())
}
s.Host = &model.Host{}
s.State = &model.HostState{}
s.TaskCloseLock = new(sync.Mutex)
singleton.ServerList[s.ID] = &s
singleton.ServerUUIDToID[clientUUID] = s.ID
}
return clientID, nil
}

View File

@@ -103,9 +103,9 @@ func (s *ServerAPIService) GetStatusByIDList(idList []uint64) *ServerStatusRespo
}
ipv4, ipv6, validIP := utils.SplitIPAddr(server.Host.IP)
info := CommonServerInfo{
ID: server.ID,
Name: server.Name,
Tag: server.Tag,
ID: server.ID,
Name: server.Name,
// Tag: server.Tag,
LastActive: server.LastActive.Unix(),
IPV4: ipv4,
IPV6: ipv6,
@@ -125,9 +125,9 @@ func (s *ServerAPIService) GetStatusByIDList(idList []uint64) *ServerStatusRespo
}
// GetStatusByTag 获取传入分组的所有服务器状态信息
func (s *ServerAPIService) GetStatusByTag(tag string) *ServerStatusResponse {
return s.GetStatusByIDList(ServerTagToIDList[tag])
}
// func (s *ServerAPIService) GetStatusByTag(tag string) *ServerStatusResponse {
// return s.GetStatusByIDList(ServerTagToIDList[tag])
// }
// GetAllStatus 获取所有服务器状态信息
func (s *ServerAPIService) GetAllStatus() *ServerStatusResponse {
@@ -143,9 +143,9 @@ func (s *ServerAPIService) GetAllStatus() *ServerStatusResponse {
}
ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP)
info := CommonServerInfo{
ID: v.ID,
Name: v.Name,
Tag: v.Tag,
ID: v.ID,
Name: v.Name,
// Tag: v.Tag,
LastActive: v.LastActive.Unix(),
IPV4: ipv4,
IPV6: ipv6,
@@ -173,23 +173,23 @@ func (s *ServerAPIService) GetListByTag(tag string) *ServerInfoResponse {
ServerLock.RLock()
defer ServerLock.RUnlock()
for _, v := range ServerTagToIDList[tag] {
host := ServerList[v].Host
if host == nil {
continue
}
ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP)
info := &CommonServerInfo{
ID: v,
Name: ServerList[v].Name,
Tag: ServerList[v].Tag,
LastActive: ServerList[v].LastActive.Unix(),
IPV4: ipv4,
IPV6: ipv6,
ValidIP: validIP,
}
res.Result = append(res.Result, info)
}
// for _, v := range ServerTagToIDList[tag] {
// host := ServerList[v].Host
// if host == nil {
// continue
// }
// ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP)
// info := &CommonServerInfo{
// ID: v,
// Name: ServerList[v].Name,
// Tag: ServerList[v].Tag,
// LastActive: ServerList[v].LastActive.Unix(),
// IPV4: ipv4,
// IPV6: ipv6,
// ValidIP: validIP,
// }
// res.Result = append(res.Result, info)
// }
res.CommonResponse = CommonResponse{
Code: 0,
Message: "success",
@@ -211,9 +211,9 @@ func (s *ServerAPIService) GetAllList() *ServerInfoResponse {
}
ipv4, ipv6, validIP := utils.SplitIPAddr(host.IP)
info := &CommonServerInfo{
ID: v.ID,
Name: v.Name,
Tag: v.Tag,
ID: v.ID,
Name: v.Name,
// Tag: v.Tag,
LastActive: v.LastActive.Unix(),
IPV4: ipv4,
IPV6: ipv6,

View File

@@ -8,21 +8,18 @@ import (
)
var (
ServerList map[uint64]*model.Server // [ServerID] -> model.Server
SecretToID map[string]uint64 // [ServerSecret] -> ServerID
ServerTagToIDList map[string][]uint64 // [ServerTag] -> ServerID
ServerLock sync.RWMutex
ServerList map[uint64]*model.Server // [ServerID] -> model.Server
ServerUUIDToID map[string]uint64 // [ServerUUID] -> ServerID
ServerLock sync.RWMutex
SortedServerList []*model.Server // 用于存储服务器列表的 slice按照服务器 ID 排序
SortedServerListForGuest []*model.Server
SortedServerLock sync.RWMutex
)
// InitServer 初始化 ServerID <-> Secret 的映射
func InitServer() {
ServerList = make(map[uint64]*model.Server)
SecretToID = make(map[string]uint64)
ServerTagToIDList = make(map[string][]uint64)
ServerUUIDToID = make(map[string]uint64)
}
// loadServers 加载服务器列表并根据ID排序
@@ -36,8 +33,7 @@ func loadServers() {
innerS.State = &model.HostState{}
innerS.TaskCloseLock = new(sync.Mutex)
ServerList[innerS.ID] = &innerS
SecretToID[innerS.Secret] = innerS.ID
ServerTagToIDList[innerS.Tag] = append(ServerTagToIDList[innerS.Tag], innerS.ID)
ServerUUIDToID[innerS.UUID] = innerS.ID
}
ReSortServer()
}