客户端授权验证

This commit is contained in:
奶爸
2019-12-09 16:02:49 +08:00
parent 986dc6114a
commit 58277ba0b6
13 changed files with 211 additions and 73 deletions

View File

@@ -2,7 +2,10 @@ package rpc
import (
"context"
"fmt"
"github.com/p14yground/nezha/model"
"github.com/p14yground/nezha/service/dao"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
@@ -10,51 +13,41 @@ import (
// AuthHandler ..
type AuthHandler struct {
AppKey string
AppSecret string
ClientID string
ClientSecret string
}
// GetRequestMetadata ..
func (a *AuthHandler) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
return map[string]string{"app_key": a.AppKey, "app_secret": a.AppSecret}, nil
return map[string]string{"app_key": a.ClientID, "app_secret": a.ClientSecret}, nil
}
// RequireTransportSecurity ..
func (a *AuthHandler) RequireTransportSecurity() bool {
return false
return !dao.Conf.Debug
}
// Check ..
func (a *AuthHandler) Check(ctx context.Context) error {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return status.Errorf(codes.Unauthenticated, "metadata.FromIncomingContext err")
return status.Errorf(codes.Unauthenticated, "获取 metaData 失败")
}
var (
AppKey string
AppSecret string
ClientID string
ClientSecret string
)
if value, ok := md["app_key"]; ok {
AppKey = value[0]
ClientID = value[0]
}
if value, ok := md["app_secret"]; ok {
AppSecret = value[0]
ClientSecret = value[0]
}
if AppKey != a.GetAppKey() || AppSecret != a.GetAppSecret() {
return status.Errorf(codes.Unauthenticated, "invalid token")
if _, ok := dao.Cache.Get(fmt.Sprintf("%s%s%s", model.CtxKeyServer, ClientID, ClientSecret)); !ok {
return status.Errorf(codes.Unauthenticated, "客户端认证失败")
}
return nil
}
// GetAppKey ..
func (a *AuthHandler) GetAppKey() string {
return a.AppKey
}
// GetAppSecret ..
func (a *AuthHandler) GetAppSecret() string {
return a.AppSecret
}

View File

@@ -3,7 +3,9 @@ package rpc
import (
"context"
"fmt"
"log"
"github.com/p14yground/nezha/model"
pb "github.com/p14yground/nezha/proto"
)
@@ -23,11 +25,17 @@ func (s *NezhaHandler) ReportState(c context.Context, r *pb.State) (*pb.Receipt,
// Heartbeat ..
func (s *NezhaHandler) Heartbeat(r *pb.Beat, stream pb.NezhaService_HeartbeatServer) error {
defer log.Println("Heartbeat exit")
if err := s.Auth.Check(stream.Context()); err != nil {
return err
}
fmt.Printf("ReportState receive: %s\n", r)
return nil
err := stream.Send(&pb.Command{
Type: model.MTReportState,
})
if err != nil {
log.Printf("Heartbeat stream.Send err:%v", err)
}
select {}
}
// Register ..