feat: refactor grpc keepalive

This commit is contained in:
naiba
2024-12-05 00:11:34 +08:00
parent 68e3bb00e4
commit be51c5a1e6
12 changed files with 102 additions and 152 deletions

View File

@@ -8,12 +8,11 @@ import (
"sync"
"time"
"github.com/jinzhu/copier"
"github.com/nezhahq/nezha/pkg/ddns"
geoipx "github.com/nezhahq/nezha/pkg/geoip"
"github.com/nezhahq/nezha/pkg/grpcx"
"github.com/jinzhu/copier"
"github.com/nezhahq/nezha/model"
pb "github.com/nezhahq/nezha/proto"
"github.com/nezhahq/nezha/service/singleton"
@@ -37,63 +36,55 @@ func NewNezhaHandler() *NezhaHandler {
}
}
func (s *NezhaHandler) ReportTask(c context.Context, r *pb.TaskResult) (*pb.Receipt, error) {
var err error
var clientID uint64
if clientID, err = s.Auth.Check(c); err != nil {
return nil, err
}
if r.GetType() == model.TaskTypeCommand {
// 处理上报的计划任务
singleton.CronLock.RLock()
defer singleton.CronLock.RUnlock()
cr := singleton.Crons[r.GetId()]
if cr != nil {
singleton.ServerLock.RLock()
defer singleton.ServerLock.RUnlock()
// 保存当前服务器状态信息
curServer := model.Server{}
copier.Copy(&curServer, singleton.ServerList[clientID])
if cr.PushSuccessful && r.GetSuccessful() {
singleton.SendNotification(cr.NotificationGroupID, fmt.Sprintf("[%s] %s, %s\n%s", singleton.Localizer.T("Scheduled Task Executed Successfully"),
cr.Name, singleton.ServerList[clientID].Name, r.GetData()), nil, &curServer)
}
if !r.GetSuccessful() {
singleton.SendNotification(cr.NotificationGroupID, fmt.Sprintf("[%s] %s, %s\n%s", singleton.Localizer.T("Scheduled Task Executed Failed"),
cr.Name, singleton.ServerList[clientID].Name, r.GetData()), nil, &curServer)
}
singleton.DB.Model(cr).Updates(model.Cron{
LastExecutedAt: time.Now().Add(time.Second * -1 * time.Duration(r.GetDelay())),
LastResult: r.GetSuccessful(),
})
}
} else if model.IsServiceSentinelNeeded(r.GetType()) {
singleton.ServiceSentinelShared.Dispatch(singleton.ReportData{
Data: r,
Reporter: clientID,
})
}
return &pb.Receipt{Proced: true}, nil
}
func (s *NezhaHandler) RequestTask(h *pb.Host, stream pb.NezhaService_RequestTaskServer) error {
func (s *NezhaHandler) RequestTask(stream pb.NezhaService_RequestTaskServer) error {
var clientID uint64
var err error
if clientID, err = s.Auth.Check(stream.Context()); err != nil {
return err
}
closeCh := make(chan error)
singleton.ServerLock.RLock()
singleton.ServerList[clientID].TaskCloseLock.Lock()
// 修复不断的请求 task 但是没有 return 导致内存泄漏
if singleton.ServerList[clientID].TaskClose != nil {
close(singleton.ServerList[clientID].TaskClose)
}
singleton.ServerList[clientID].TaskStream = stream
singleton.ServerList[clientID].TaskClose = closeCh
singleton.ServerList[clientID].TaskCloseLock.Unlock()
singleton.ServerLock.RUnlock()
return <-closeCh
var result *pb.TaskResult
for {
result, err = stream.Recv()
if err != nil {
log.Printf("NEZHA>> RequestTask error: %v, clientID: %d\n", err, clientID)
return nil
}
if result.GetType() == model.TaskTypeCommand {
// 处理上报的计划任务
singleton.CronLock.RLock()
cr := singleton.Crons[result.GetId()]
singleton.CronLock.RUnlock()
if cr != nil {
// 保存当前服务器状态信息
var curServer model.Server
singleton.ServerLock.RLock()
copier.Copy(&curServer, singleton.ServerList[clientID])
singleton.ServerLock.RUnlock()
if cr.PushSuccessful && result.GetSuccessful() {
singleton.SendNotification(cr.NotificationGroupID, fmt.Sprintf("[%s] %s, %s\n%s", singleton.Localizer.T("Scheduled Task Executed Successfully"),
cr.Name, singleton.ServerList[clientID].Name, result.GetData()), nil, &curServer)
}
if !result.GetSuccessful() {
singleton.SendNotification(cr.NotificationGroupID, fmt.Sprintf("[%s] %s, %s\n%s", singleton.Localizer.T("Scheduled Task Executed Failed"),
cr.Name, singleton.ServerList[clientID].Name, result.GetData()), nil, &curServer)
}
singleton.DB.Model(cr).Updates(model.Cron{
LastExecutedAt: time.Now().Add(time.Second * -1 * time.Duration(result.GetDelay())),
LastResult: result.GetSuccessful(),
})
}
} else if model.IsServiceSentinelNeeded(result.GetType()) {
singleton.ServiceSentinelShared.Dispatch(singleton.ReportData{
Data: result,
Reporter: clientID,
})
}
}
}
func (s *NezhaHandler) ReportSystemState(stream pb.NezhaService_ReportSystemStateServer) error {
@@ -157,10 +148,22 @@ func (s *NezhaHandler) IOStream(stream pb.NezhaService_IOStreamServer) error {
if err != nil {
return err
}
// ff05ff05 是 Nezha 的魔数,用于标识流 ID
if id == nil || len(id.Data) < 4 || (id.Data[0] != 0xff && id.Data[1] != 0x05 && id.Data[2] != 0xff && id.Data[3] == 0x05) {
return fmt.Errorf("invalid stream id")
}
go func() {
for {
if err := stream.Send(&pb.IOStreamData{Data: []byte{}}); err != nil {
log.Printf("NEZHA>> IOStream keepAlive error: %v\n", err)
return
}
time.Sleep(time.Second * 30)
}
}()
streamId := string(id.Data[4:])
if _, err := s.GetStream(streamId); err != nil {