feat: user roles (#852)

* [WIP] feat: user roles

* update

* update

* admin handler

* update

* feat: user-specific connection secret

* simplify some logics

* cleanup

* update waf

* update user api error handling

* update waf api

* fix codeql

* update waf table

* fix several problems

* add pagination for waf api

* update permission checks

* switch to runtime check

* 1

* cover?

* some changes
This commit is contained in:
UUBulb
2024-12-22 00:05:41 +08:00
committed by GitHub
parent 50ee62172f
commit 653d0cf2e9
35 changed files with 841 additions and 180 deletions

View File

@@ -36,12 +36,16 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
ip, _ := ctx.Value(model.CtxKeyRealIP{}).(string)
if clientSecret != singleton.Conf.AgentSecretKey {
model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail)
singleton.UserLock.RLock()
userId, ok := singleton.AgentSecretToUserId[clientSecret]
if !ok && clientSecret != singleton.Conf.AgentSecretKey {
singleton.UserLock.RUnlock()
model.BlockIP(singleton.DB, ip, model.WAFBlockReasonTypeAgentAuthFail, model.BlockIDgRPC)
return 0, status.Error(codes.Unauthenticated, "客户端认证失败")
}
singleton.UserLock.RUnlock()
model.ClearIP(singleton.DB, ip)
model.ClearIP(singleton.DB, ip, model.BlockIDgRPC)
var clientUUID string
if value, ok := md["client_uuid"]; ok {
@@ -53,21 +57,26 @@ func (a *authHandler) Check(ctx context.Context) (uint64, error) {
}
singleton.ServerLock.RLock()
defer singleton.ServerLock.RUnlock()
clientID, hasID := singleton.ServerUUIDToID[clientUUID]
singleton.ServerLock.RUnlock()
if !hasID {
s := model.Server{UUID: clientUUID, Name: petname.Generate(2, "-")}
s := model.Server{UUID: clientUUID, Name: petname.Generate(2, "-"), Common: model.Common{
UserID: userId,
}}
if err := singleton.DB.Create(&s).Error; err != nil {
return 0, status.Error(codes.Unauthenticated, err.Error())
}
s.Host = &model.Host{}
s.State = &model.HostState{}
s.GeoIP = &model.GeoIP{}
// generate a random silly server name
singleton.ServerLock.Lock()
singleton.ServerList[s.ID] = &s
singleton.ServerUUIDToID[clientUUID] = s.ID
singleton.ServerLock.Unlock()
singleton.ReSortServer()
clientID = s.ID
}

View File

@@ -143,8 +143,16 @@ func checkStatus() {
}
for _, server := range ServerList {
// 监测点
UserLock.RLock()
var role uint8
if u, ok := UserInfoMap[server.UserID]; !ok {
role = model.RoleMember
} else {
role = u.Role
}
UserLock.RUnlock()
alertsStore[alert.ID][server.ID] = append(alertsStore[alert.
ID][server.ID], alert.Snapshot(AlertsCycleTransferStatsStore[alert.ID], server, DB))
ID][server.ID], alert.Snapshot(AlertsCycleTransferStatsStore[alert.ID], server, DB, role))
// 发送通知,分为触发报警和恢复通知
max, passed := alert.Check(alertsStore[alert.ID][server.ID])
// 保存当前服务器状态信息

View File

@@ -12,6 +12,7 @@ import (
"github.com/robfig/cron/v3"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
pb "github.com/nezhahq/nezha/proto"
)
@@ -79,10 +80,7 @@ func UpdateCronList() {
CronLock.RLock()
defer CronLock.RUnlock()
CronList = make([]*model.Cron, 0, len(Crons))
for _, c := range Crons {
CronList = append(CronList, c)
}
CronList = utils.MapValuesToSlice(Crons)
slices.SortFunc(CronList, func(a, b *model.Cron) int {
return cmp.Compare(a.ID, b.ID)
})

View File

@@ -13,6 +13,7 @@ import (
ddns2 "github.com/nezhahq/nezha/pkg/ddns"
"github.com/nezhahq/nezha/pkg/ddns/dummy"
"github.com/nezhahq/nezha/pkg/ddns/webhook"
"github.com/nezhahq/nezha/pkg/utils"
)
var (
@@ -24,12 +25,10 @@ var (
func initDDNS() {
DB.Find(&DDNSList)
DDNSCacheLock.Lock()
DDNSCache = make(map[uint64]*model.DDNSProfile)
for i := 0; i < len(DDNSList); i++ {
DDNSCache[DDNSList[i].ID] = DDNSList[i]
}
DDNSCacheLock.Unlock()
OnNameserverUpdate()
}
@@ -56,10 +55,7 @@ func UpdateDDNSList() {
DDNSListLock.Lock()
defer DDNSListLock.Unlock()
DDNSList = make([]*model.DDNSProfile, 0, len(DDNSCache))
for _, p := range DDNSCache {
DDNSList = append(DDNSList, p)
}
DDNSList = utils.MapValuesToSlice(DDNSCache)
slices.SortFunc(DDNSList, func(a, b *model.DDNSProfile) int {
return cmp.Compare(a.ID, b.ID)
})

View File

@@ -6,6 +6,7 @@ import (
"sync"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
)
var (
@@ -19,8 +20,6 @@ var (
func initNAT() {
DB.Find(&NATList)
NATCacheRwLock.Lock()
defer NATCacheRwLock.Unlock()
NATCache = make(map[string]*model.NAT)
for i := 0; i < len(NATList); i++ {
NATCache[NATList[i].Domain] = NATList[i]
@@ -59,10 +58,7 @@ func UpdateNATList() {
NATListLock.Lock()
defer NATListLock.Unlock()
NATList = make([]*model.NAT, 0, len(NATCache))
for _, n := range NATCache {
NATList = append(NATList, n)
}
NATList = utils.MapValuesToSlice(NATCache)
slices.SortFunc(NATList, func(a, b *model.NAT) int {
return cmp.Compare(a.ID, b.ID)
})

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
)
const (
@@ -30,7 +31,7 @@ var (
)
// InitNotification 初始化 GroupID <-> ID <-> Notification 的映射
func InitNotification() {
func initNotification() {
NotificationList = make(map[uint64]map[uint64]*model.Notification)
NotificationIDToGroups = make(map[uint64]map[uint64]struct{})
NotificationGroup = make(map[uint64]string)
@@ -38,9 +39,7 @@ func InitNotification() {
// loadNotifications 从 DB 初始化通知方式相关参数
func loadNotifications() {
InitNotification()
NotificationsLock.Lock()
initNotification()
groupNotifications := make(map[uint64][]uint64)
var ngn []model.NotificationGroupNotification
if err := DB.Find(&ngn).Error; err != nil {
@@ -74,8 +73,6 @@ func loadNotifications() {
}
}
}
NotificationsLock.Unlock()
}
func UpdateNotificationList() {
@@ -85,10 +82,7 @@ func UpdateNotificationList() {
NotificationSortedLock.Lock()
defer NotificationSortedLock.Unlock()
NotificationListSorted = make([]*model.Notification, 0, len(NotificationMap))
for _, n := range NotificationMap {
NotificationListSorted = append(NotificationListSorted, n)
}
NotificationListSorted = utils.MapValuesToSlice(NotificationMap)
slices.SortFunc(NotificationListSorted, func(a, b *model.Notification) int {
return cmp.Compare(a.ID, b.ID)
})

View File

@@ -1,10 +1,12 @@
package singleton
import (
"sort"
"cmp"
"slices"
"sync"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
)
var (
@@ -45,29 +47,21 @@ func ReSortServer() {
SortedServerLock.Lock()
defer SortedServerLock.Unlock()
SortedServerList = make([]*model.Server, 0, len(ServerList))
SortedServerListForGuest = make([]*model.Server, 0)
for _, s := range ServerList {
SortedServerList = append(SortedServerList, s)
SortedServerList = utils.MapValuesToSlice(ServerList)
// 按照服务器 ID 排序的具体实现ID越大越靠前
slices.SortStableFunc(SortedServerList, func(a, b *model.Server) int {
if a.DisplayIndex == b.DisplayIndex {
return cmp.Compare(a.ID, b.ID)
}
return cmp.Compare(b.DisplayIndex, a.DisplayIndex)
})
SortedServerListForGuest = make([]*model.Server, 0, len(SortedServerList))
for _, s := range SortedServerList {
if !s.HideForGuest {
SortedServerListForGuest = append(SortedServerListForGuest, s)
}
}
// 按照服务器 ID 排序的具体实现ID越大越靠前
sort.SliceStable(SortedServerList, func(i, j int) bool {
if SortedServerList[i].DisplayIndex == SortedServerList[j].DisplayIndex {
return SortedServerList[i].ID < SortedServerList[j].ID
}
return SortedServerList[i].DisplayIndex > SortedServerList[j].DisplayIndex
})
sort.SliceStable(SortedServerListForGuest, func(i, j int) bool {
if SortedServerListForGuest[i].DisplayIndex == SortedServerListForGuest[j].DisplayIndex {
return SortedServerListForGuest[i].ID < SortedServerListForGuest[j].ID
}
return SortedServerListForGuest[i].DisplayIndex > SortedServerListForGuest[j].DisplayIndex
})
}
func OnServerDelete(sid []uint64) {

View File

@@ -11,6 +11,7 @@ import (
"github.com/jinzhu/copier"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
pb "github.com/nezhahq/nezha/proto"
)
@@ -174,11 +175,7 @@ func (ss *ServiceSentinel) UpdateServiceList() {
ss.ServiceListLock.Lock()
defer ss.ServiceListLock.Unlock()
ss.ServiceList = make([]*model.Service, 0, len(ss.Services))
for _, v := range ss.Services {
ss.ServiceList = append(ss.ServiceList, v)
}
ss.ServiceList = utils.MapValuesToSlice(ss.Services)
slices.SortFunc(ss.ServiceList, func(a, b *model.Service) int {
return cmp.Compare(a.ID, b.ID)
})
@@ -192,13 +189,6 @@ func (ss *ServiceSentinel) loadServiceHistory() {
panic(err)
}
ss.serviceResponseDataStoreLock.Lock()
defer ss.serviceResponseDataStoreLock.Unlock()
ss.monthlyStatusLock.Lock()
defer ss.monthlyStatusLock.Unlock()
ss.ServicesLock.Lock()
defer ss.ServicesLock.Unlock()
for i := 0; i < len(services); i++ {
task := *services[i]
// 通过cron定时将服务监控任务传递给任务调度管道

View File

@@ -42,6 +42,7 @@ func InitTimezoneAndCache() {
// LoadSingleton 加载子服务并执行
func LoadSingleton() {
initUser() // 加载用户ID绑定表
initI18n() // 加载本地化服务
loadNotifications() // 加载通知服务
loadServers() // 加载服务器列表
@@ -81,8 +82,8 @@ func InitDBFromPath(path string) {
}
err = DB.AutoMigrate(model.Server{}, model.User{}, model.ServerGroup{}, model.NotificationGroup{},
model.Notification{}, model.AlertRule{}, model.Service{}, model.NotificationGroupNotification{},
model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{}, model.UserGroup{},
model.UserGroupUser{}, model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{},
model.ServiceHistory{}, model.Cron{}, model.Transfer{}, model.ServerGroupServer{},
model.NAT{}, model.DDNSProfile{}, model.NotificationGroupNotification{},
model.WAF{})
if err != nil {
panic(err)

135
service/singleton/user.go Normal file
View File

@@ -0,0 +1,135 @@
package singleton
import (
"sync"
"github.com/nezhahq/nezha/model"
"gorm.io/gorm"
)
var (
UserInfoMap map[uint64]model.UserInfo
AgentSecretToUserId map[string]uint64
UserLock sync.RWMutex
)
func initUser() {
UserInfoMap = make(map[uint64]model.UserInfo)
AgentSecretToUserId = make(map[string]uint64)
var users []model.User
DB.Find(&users)
for _, u := range users {
UserInfoMap[u.ID] = model.UserInfo{
Role: u.Role,
AgentSecret: u.AgentSecret,
}
AgentSecretToUserId[u.AgentSecret] = u.ID
}
}
func OnUserUpdate(u *model.User) {
UserLock.Lock()
defer UserLock.Unlock()
if u == nil {
return
}
UserInfoMap[u.ID] = model.UserInfo{
Role: u.Role,
AgentSecret: u.AgentSecret,
}
AgentSecretToUserId[u.AgentSecret] = u.ID
}
func OnUserDelete(id []uint64, errorFunc func(string, ...interface{}) error) error {
UserLock.Lock()
defer UserLock.Unlock()
if len(id) < 1 {
return Localizer.ErrorT("user id not specified")
}
var (
cron, server bool
crons, servers []uint64
)
for _, uid := range id {
err := DB.Transaction(func(tx *gorm.DB) error {
CronLock.RLock()
crons = model.FindByUserID(CronList, uid)
CronLock.RUnlock()
cron = len(crons) > 0
if cron {
if err := tx.Unscoped().Delete(&model.Cron{}, "id in (?)", crons).Error; err != nil {
return err
}
}
SortedServerLock.RLock()
servers = model.FindByUserID(SortedServerList, uid)
SortedServerLock.RUnlock()
server = len(servers) > 0
if server {
if err := tx.Unscoped().Delete(&model.Server{}, "id in (?)", servers).Error; err != nil {
return err
}
if err := tx.Unscoped().Delete(&model.ServerGroupServer{}, "server_id in (?)", servers).Error; err != nil {
return err
}
}
if err := tx.Unscoped().Delete(&model.Transfer{}, "server_id in (?)", servers).Error; err != nil {
return err
}
if err := tx.Where("id IN (?)", id).Delete(&model.User{}).Error; err != nil {
return err
}
return nil
})
if err != nil {
return errorFunc("%v", err)
}
if cron {
OnDeleteCron(crons)
}
if server {
AlertsLock.Lock()
for _, sid := range servers {
for _, alert := range Alerts {
if AlertsCycleTransferStatsStore[alert.ID] != nil {
delete(AlertsCycleTransferStatsStore[alert.ID].ServerName, sid)
delete(AlertsCycleTransferStatsStore[alert.ID].Transfer, sid)
delete(AlertsCycleTransferStatsStore[alert.ID].NextUpdate, sid)
}
}
}
AlertsLock.Unlock()
OnServerDelete(servers)
}
secret := UserInfoMap[uid].AgentSecret
delete(AgentSecretToUserId, secret)
delete(UserInfoMap, uid)
}
if cron {
UpdateCronList()
}
if server {
ReSortServer()
}
return nil
}