feat: 绑定 oauth2

This commit is contained in:
naiba
2024-12-28 23:50:59 +08:00
parent 8554f3eba7
commit 18020939da
15 changed files with 360 additions and 24 deletions

View File

@@ -56,6 +56,8 @@ func routers(r *gin.Engine, frontendDist fs.FS) {
}
api := r.Group("api/v1")
api.POST("/login", authMiddleware.LoginHandler)
api.GET("/oauth2/:provider", commonHandler(oauth2redirect))
api.POST("/oauth2/:provider/callback", commonHandler(oauth2callback(authMiddleware)))
optionalAuth := api.Group("", optionalAuthMiddleware(authMiddleware))
optionalAuth.GET("/ws/server", commonHandler(serverStream))
@@ -79,6 +81,9 @@ func routers(r *gin.Engine, frontendDist fs.FS) {
auth.GET("/profile", commonHandler(getProfile))
auth.POST("/profile", commonHandler(updateProfile))
auth.POST("/oauth2/:provider/bind", commonHandler(bindOauth2))
auth.POST("/oauth2/:provider/unbind", commonHandler(unbindOauth2))
auth.GET("/user", adminHandler(listUser))
auth.POST("/user", adminHandler(createUser))
auth.POST("/batch-delete/user", adminHandler(batchDeleteUser))

View File

@@ -0,0 +1,240 @@
package controller
import (
"context"
"fmt"
"io"
"strconv"
"strings"
"time"
jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-gonic/gin"
"github.com/patrickmn/go-cache"
"github.com/tidwall/gjson"
"golang.org/x/oauth2"
"gorm.io/gorm"
"github.com/nezhahq/nezha/model"
"github.com/nezhahq/nezha/pkg/utils"
"github.com/nezhahq/nezha/service/singleton"
)
type Oauth2LoginType uint8
const (
_ Oauth2LoginType = iota
rTypeLogin
rTypeBind
)
func getRedirectURL(c *gin.Context, provider string, rType Oauth2LoginType) string {
scheme := "http://"
referer := c.Request.Referer()
if forwardedProto := c.Request.Header.Get("X-Forwarded-Proto"); forwardedProto == "https" || strings.HasPrefix(referer, "https://") {
scheme = "https://"
}
var suffix string
if rType == rTypeLogin {
suffix = "/dashboard/login?provider=" + provider
} else if rType == rTypeBind {
suffix = "/dashboard/profile?provider=" + provider
}
return scheme + c.Request.Host + suffix
}
// @Summary Get Oauth2 Redirect URL
// @Description Get Oauth2 Redirect URL
// @Produce json
// @Param provider path string true "provider"
// @Param type query int false "type" Enums(1, 2) default(1)
// @Success 200 {object} model.Oauth2LoginResponse
// @Router /api/v1/oauth2/{provider} [get]
func oauth2redirect(c *gin.Context) (*model.Oauth2LoginResponse, error) {
provider := c.Param("provider")
if provider == "" {
return nil, singleton.Localizer.ErrorT("provider is required")
}
rTypeInt, err := strconv.Atoi(c.Query("type"))
if err != nil {
return nil, err
}
o2confRaw, has := singleton.Conf.Oauth2[provider]
if !has {
return nil, singleton.Localizer.ErrorT("provider not found")
}
o2conf := o2confRaw.Setup(getRedirectURL(c, provider, Oauth2LoginType(rTypeInt)))
randomString, err := utils.GenerateRandomString(32)
if err != nil {
return nil, err
}
state, stateKey := randomString[:16], randomString[16:]
singleton.Cache.Set(fmt.Sprintf("%s%s", model.CacheKeyOauth2State, stateKey), state, cache.DefaultExpiration)
url := o2conf.AuthCodeURL(state, oauth2.AccessTypeOnline)
c.SetCookie("nz-o2s", stateKey, 60*5, "", "", false, false)
return &model.Oauth2LoginResponse{Redirect: url}, nil
}
func exchangeOpenId(c *gin.Context, o2confRaw *model.Oauth2Config, provider string, callbackData model.Oauth2Callback) (string, error) {
// 验证登录跳转时的 State
stateKey, err := c.Cookie("nz-o2s")
if err != nil {
return "", singleton.Localizer.ErrorT("invalid state key")
}
state, ok := singleton.Cache.Get(fmt.Sprintf("%s%s", model.CacheKeyOauth2State, stateKey))
if !ok || state.(string) != callbackData.State {
return "", singleton.Localizer.ErrorT("invalid state key")
}
o2conf := o2confRaw.Setup(getRedirectURL(c, provider, rTypeLogin))
ctx := context.Background()
otk, err := o2conf.Exchange(ctx, callbackData.Code)
if err != nil {
return "", err
}
oauth2client := o2conf.Client(ctx, otk)
resp, err := oauth2client.Get(o2confRaw.UserInfoURL)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
}
return gjson.Get(string(body), o2confRaw.UserIDPath).String(), nil
}
// @Summary Oauth2 Callback
// @Description Oauth2 Callback
// @Accept json
// @Produce json
// @Param provider path string true "provider"
// @Param body body model.Oauth2Callback true "body"
// @Success 200 {object} model.LoginResponse
// @Router /api/v1/oauth2/{provider}/callback [post]
func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*model.LoginResponse, error) {
return func(c *gin.Context) (*model.LoginResponse, error) {
provider := c.Param("provider")
if provider == "" {
return nil, singleton.Localizer.ErrorT("provider is required")
}
o2confRaw, has := singleton.Conf.Oauth2[provider]
if !has {
return nil, singleton.Localizer.ErrorT("provider not found")
}
provider = strings.ToLower(provider)
var callbackData model.Oauth2Callback
if err := c.ShouldBind(&callbackData); err != nil {
return nil, err
}
realip := c.GetString(model.CtxKeyRealIPStr)
if callbackData.Code == "" {
model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeBruteForceOauth2, model.BlockIDToken)
return nil, singleton.Localizer.ErrorT("code is required")
}
openId, err := exchangeOpenId(c, o2confRaw, provider, callbackData)
if err != nil {
model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeBruteForceOauth2, model.BlockIDToken)
return nil, err
}
var bind model.Oauth2Bind
if err := singleton.DB.Where("provider = ? AND open_id = ?", provider, openId).First(&bind).Error; err != nil {
return nil, singleton.Localizer.ErrorT("oauth2 user not binded yet")
}
tokenString, expire, err := jwtConfig.TokenGenerator(fmt.Sprintf("%d", bind.UserID))
if err != nil {
return nil, err
}
jwtConfig.SetCookie(c, tokenString)
return &model.LoginResponse{Token: tokenString, Expire: expire.Format(time.RFC3339)}, nil
}
}
// @Summary Bind Oauth2
// @Description Bind Oauth2
// @Accept json
// @Produce json
// @Param provider path string true "provider"
// @Param body body model.Oauth2Callback true "body"
// @Success 200 {object} any
// @Router /api/v1/oauth2/{provider}/bind [post]
func bindOauth2(c *gin.Context) (any, error) {
var bindData model.Oauth2Callback
if err := c.ShouldBind(&bindData); err != nil {
return nil, err
}
provider := c.Param("provider")
o2conf, has := singleton.Conf.Oauth2[provider]
if !has {
return nil, singleton.Localizer.ErrorT("provider not found")
}
provider = strings.ToLower(provider)
openId, err := exchangeOpenId(c, o2conf, provider, bindData)
if err != nil {
return nil, err
}
u := c.MustGet(model.CtxKeyAuthorizedUser).(*model.User)
var bind model.Oauth2Bind
result := singleton.DB.Where("provider = ? AND open_id = ?", provider, openId).Limit(1).Find(&bind)
if result.Error != nil && result.Error != gorm.ErrRecordNotFound {
return nil, newGormError("%v", result.Error)
}
bind.UserID = u.ID
bind.Provider = provider
bind.OpenID = openId
if result.Error == gorm.ErrRecordNotFound {
result = singleton.DB.Create(&bind)
} else {
result = singleton.DB.Save(&bind)
}
if result.Error != nil {
return nil, newGormError("%v", result.Error)
}
return nil, nil
}
// @Summary Unbind Oauth2
// @Description Unbind Oauth2
// @Accept json
// @Produce json
// @Param provider path string true "provider"
// @Success 200 {object} any
// @Router /api/v1/oauth2/{provider}/unbind [post]
func unbindOauth2(c *gin.Context) (any, error) {
provider := c.Param("provider")
if provider == "" {
return nil, singleton.Localizer.ErrorT("provider is required")
}
_, has := singleton.Conf.Oauth2[provider]
if !has {
return nil, singleton.Localizer.ErrorT("provider not found")
}
provider = strings.ToLower(provider)
u := c.MustGet(model.CtxKeyAuthorizedUser).(*model.User)
if err := singleton.DB.Where("provider = ? AND user_id = ?", provider, u.ID).Delete(&model.Oauth2Bind{}).Error; err != nil {
return nil, newGormError("%v", err)
}
return nil, nil
}

View File

@@ -17,9 +17,9 @@ import (
// @Security BearerAuth
// @Tags common
// @Produce json
// @Success 200 {object} model.CommonResponse[model.SettingResponse]
// @Success 200 {object} model.CommonResponse[model.SettingResponse[model.Config]]
// @Router /setting [get]
func listConfig(c *gin.Context) (model.SettingResponse, error) {
func listConfig(c *gin.Context) (model.SettingResponse[any], error) {
u, authorized := c.Get(model.CtxKeyAuthorizedUser)
var isAdmin bool
if authorized {
@@ -27,30 +27,32 @@ func listConfig(c *gin.Context) (model.SettingResponse, error) {
isAdmin = user.Role == model.RoleAdmin
}
conf := model.SettingResponse{
Config: *singleton.Conf,
config := *singleton.Conf
config.Language = strings.Replace(config.Language, "_", "-", -1)
conf := model.SettingResponse[any]{
Config: config,
Version: singleton.Version,
FrontendTemplates: singleton.FrontendTemplates,
}
if !authorized || !isAdmin {
conf = model.SettingResponse{
Config: model.Config{
SiteName: conf.SiteName,
Language: conf.Language,
CustomCode: conf.CustomCode,
CustomCodeDashboard: conf.CustomCodeDashboard,
},
configForGuests := model.ConfigForGuests{
Language: config.Language,
SiteName: config.SiteName,
CustomCode: config.CustomCode,
CustomCodeDashboard: config.CustomCodeDashboard,
Oauth2Providers: config.Oauth2Providers,
}
if authorized {
config.TLS = singleton.Conf.TLS
config.InstallHost = singleton.Conf.InstallHost
}
conf = model.SettingResponse[any]{
Config: configForGuests,
}
}
if !isAdmin {
conf.Config.TLS = singleton.Conf.TLS
conf.Config.InstallHost = singleton.Conf.InstallHost
}
conf.Config.Language = strings.Replace(conf.Config.Language, "_", "-", -1)
return conf, nil
}

View File

@@ -26,9 +26,18 @@ func getProfile(c *gin.Context) (*model.Profile, error) {
if !ok {
return nil, singleton.Localizer.ErrorT("unauthorized")
}
var ob []model.Oauth2Bind
if err := singleton.DB.Where("user_id = ?", auth.(*model.User).ID).Find(&ob).Error; err != nil {
return nil, newGormError("%v", err)
}
var obMap = make(map[string]string)
for _, v := range ob {
obMap[v.Provider] = v.OpenID
}
return &model.Profile{
User: *auth.(*model.User),
LoginIP: c.GetString(model.CtxKeyRealIPStr),
User: *auth.(*model.User),
LoginIP: c.GetString(model.CtxKeyRealIPStr),
Oauth2Bind: obMap,
}, nil
}