mirror of
https://github.com/Buriburizaem0n/nezha_domains.git
synced 2026-02-04 04:30:05 +00:00
fix: oauth2 redirect url not consistent (#930)
* fix: oauth2 redirect url not consistent * only use one redirect uri * feat: allow to disable password authentication * generate translation template * update error * redirect * query
This commit is contained in:
@@ -57,7 +57,6 @@ func routers(r *gin.Engine, frontendDist fs.FS) {
|
||||
api := r.Group("api/v1")
|
||||
api.POST("/login", authMiddleware.LoginHandler)
|
||||
api.GET("/oauth2/:provider", commonHandler(oauth2redirect))
|
||||
api.POST("/oauth2/:provider/callback", commonHandler(oauth2callback(authMiddleware)))
|
||||
|
||||
optionalAuth := api.Group("", optionalAuthMiddleware(authMiddleware))
|
||||
optionalAuth.GET("/ws/server", commonHandler(serverStream))
|
||||
@@ -67,6 +66,8 @@ func routers(r *gin.Engine, frontendDist fs.FS) {
|
||||
optionalAuth.GET("/service/:id", commonHandler(listServiceHistory))
|
||||
optionalAuth.GET("/service/server", commonHandler(listServerWithServices))
|
||||
|
||||
optionalAuth.GET("/oauth2/callback", commonHandler(oauth2callback(authMiddleware)))
|
||||
|
||||
optionalAuth.GET("/setting", commonHandler(listConfig))
|
||||
|
||||
auth := api.Group("", authMiddleware.MiddlewareFunc())
|
||||
@@ -81,7 +82,6 @@ func routers(r *gin.Engine, frontendDist fs.FS) {
|
||||
|
||||
auth.GET("/profile", commonHandler(getProfile))
|
||||
auth.POST("/profile", commonHandler(updateProfile))
|
||||
auth.POST("/oauth2/:provider/bind", commonHandler(bindOauth2))
|
||||
auth.POST("/oauth2/:provider/unbind", commonHandler(unbindOauth2))
|
||||
|
||||
auth.GET("/user", adminHandler(listUser))
|
||||
|
||||
@@ -89,6 +89,7 @@ func authenticator() func(c *gin.Context) (interface{}, error) {
|
||||
|
||||
var user model.User
|
||||
realip := c.GetString(model.CtxKeyRealIPStr)
|
||||
|
||||
if err := singleton.DB.Select("id", "password").Where("username = ?", loginVals.Username).First(&user).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeLoginFail, model.BlockIDUnknownUser)
|
||||
@@ -96,6 +97,11 @@ func authenticator() func(c *gin.Context) (interface{}, error) {
|
||||
return nil, jwt.ErrFailedAuthentication
|
||||
}
|
||||
|
||||
if user.RejectPassword {
|
||||
model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeLoginFail, int64(user.ID))
|
||||
return nil, jwt.ErrFailedAuthentication
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(loginVals.Password)); err != nil {
|
||||
model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeLoginFail, int64(user.ID))
|
||||
return nil, jwt.ErrFailedAuthentication
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -20,27 +21,13 @@ import (
|
||||
"github.com/nezhahq/nezha/service/singleton"
|
||||
)
|
||||
|
||||
type Oauth2LoginType uint8
|
||||
|
||||
const (
|
||||
_ Oauth2LoginType = iota
|
||||
rTypeLogin
|
||||
rTypeBind
|
||||
)
|
||||
|
||||
func getRedirectURL(c *gin.Context, provider string, rType Oauth2LoginType) string {
|
||||
func getRedirectURL(c *gin.Context) string {
|
||||
scheme := "http://"
|
||||
referer := c.Request.Referer()
|
||||
if forwardedProto := c.Request.Header.Get("X-Forwarded-Proto"); forwardedProto == "https" || strings.HasPrefix(referer, "https://") {
|
||||
scheme = "https://"
|
||||
}
|
||||
var suffix string
|
||||
if rType == rTypeLogin {
|
||||
suffix = "/dashboard/login?provider=" + provider
|
||||
} else if rType == rTypeBind {
|
||||
suffix = "/dashboard/profile?provider=" + provider
|
||||
}
|
||||
return scheme + c.Request.Host + suffix
|
||||
return scheme + c.Request.Host + "/api/v1/oauth2/callback"
|
||||
}
|
||||
|
||||
// @Summary Get Oauth2 Redirect URL
|
||||
@@ -56,7 +43,7 @@ func oauth2redirect(c *gin.Context) (*model.Oauth2LoginResponse, error) {
|
||||
return nil, singleton.Localizer.ErrorT("provider is required")
|
||||
}
|
||||
|
||||
rTypeInt, err := strconv.Atoi(c.Query("type"))
|
||||
rTypeInt, err := strconv.ParseUint(c.Query("type"), 10, 8)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -65,14 +52,18 @@ func oauth2redirect(c *gin.Context) (*model.Oauth2LoginResponse, error) {
|
||||
if !has {
|
||||
return nil, singleton.Localizer.ErrorT("provider not found")
|
||||
}
|
||||
o2conf := o2confRaw.Setup(getRedirectURL(c, provider, Oauth2LoginType(rTypeInt)))
|
||||
o2conf := o2confRaw.Setup(getRedirectURL(c))
|
||||
|
||||
randomString, err := utils.GenerateRandomString(32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state, stateKey := randomString[:16], randomString[16:]
|
||||
singleton.Cache.Set(fmt.Sprintf("%s%s", model.CacheKeyOauth2State, stateKey), state, cache.DefaultExpiration)
|
||||
singleton.Cache.Set(fmt.Sprintf("%s%s", model.CacheKeyOauth2State, stateKey), &model.Oauth2State{
|
||||
Action: model.Oauth2LoginType(rTypeInt),
|
||||
Provider: provider,
|
||||
State: state,
|
||||
}, cache.DefaultExpiration)
|
||||
|
||||
url := o2conf.AuthCodeURL(state, oauth2.AccessTypeOnline)
|
||||
c.SetCookie("nz-o2s", stateKey, 60*5, "", "", false, false)
|
||||
@@ -80,141 +71,6 @@ func oauth2redirect(c *gin.Context) (*model.Oauth2LoginResponse, error) {
|
||||
return &model.Oauth2LoginResponse{Redirect: url}, nil
|
||||
}
|
||||
|
||||
func exchangeOpenId(c *gin.Context, o2confRaw *model.Oauth2Config, provider string, callbackData model.Oauth2Callback) (string, error) {
|
||||
// 验证登录跳转时的 State
|
||||
stateKey, err := c.Cookie("nz-o2s")
|
||||
if err != nil {
|
||||
return "", singleton.Localizer.ErrorT("invalid state key")
|
||||
}
|
||||
state, ok := singleton.Cache.Get(fmt.Sprintf("%s%s", model.CacheKeyOauth2State, stateKey))
|
||||
if !ok || state.(string) != callbackData.State {
|
||||
return "", singleton.Localizer.ErrorT("invalid state key")
|
||||
}
|
||||
|
||||
o2conf := o2confRaw.Setup(getRedirectURL(c, provider, rTypeLogin))
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
otk, err := o2conf.Exchange(ctx, callbackData.Code)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
oauth2client := o2conf.Client(ctx, otk)
|
||||
resp, err := oauth2client.Get(o2confRaw.UserInfoURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return gjson.Get(string(body), o2confRaw.UserIDPath).String(), nil
|
||||
}
|
||||
|
||||
// @Summary Oauth2 Callback
|
||||
// @Description Oauth2 Callback
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param provider path string true "provider"
|
||||
// @Param body body model.Oauth2Callback true "body"
|
||||
// @Success 200 {object} model.LoginResponse
|
||||
// @Router /api/v1/oauth2/{provider}/callback [post]
|
||||
func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*model.LoginResponse, error) {
|
||||
return func(c *gin.Context) (*model.LoginResponse, error) {
|
||||
provider := c.Param("provider")
|
||||
if provider == "" {
|
||||
return nil, singleton.Localizer.ErrorT("provider is required")
|
||||
}
|
||||
|
||||
o2confRaw, has := singleton.Conf.Oauth2[provider]
|
||||
if !has {
|
||||
return nil, singleton.Localizer.ErrorT("provider not found")
|
||||
}
|
||||
provider = strings.ToLower(provider)
|
||||
|
||||
var callbackData model.Oauth2Callback
|
||||
if err := c.ShouldBind(&callbackData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
realip := c.GetString(model.CtxKeyRealIPStr)
|
||||
|
||||
if callbackData.Code == "" {
|
||||
model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeBruteForceOauth2, model.BlockIDToken)
|
||||
return nil, singleton.Localizer.ErrorT("code is required")
|
||||
}
|
||||
|
||||
openId, err := exchangeOpenId(c, o2confRaw, provider, callbackData)
|
||||
if err != nil {
|
||||
model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeBruteForceOauth2, model.BlockIDToken)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var bind model.Oauth2Bind
|
||||
if err := singleton.DB.Where("provider = ? AND open_id = ?", provider, openId).First(&bind).Error; err != nil {
|
||||
return nil, singleton.Localizer.ErrorT("oauth2 user not binded yet")
|
||||
}
|
||||
|
||||
tokenString, expire, err := jwtConfig.TokenGenerator(fmt.Sprintf("%d", bind.UserID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jwtConfig.SetCookie(c, tokenString)
|
||||
|
||||
return &model.LoginResponse{Token: tokenString, Expire: expire.Format(time.RFC3339)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// @Summary Bind Oauth2
|
||||
// @Description Bind Oauth2
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param provider path string true "provider"
|
||||
// @Param body body model.Oauth2Callback true "body"
|
||||
// @Success 200 {object} any
|
||||
// @Router /api/v1/oauth2/{provider}/bind [post]
|
||||
func bindOauth2(c *gin.Context) (any, error) {
|
||||
var bindData model.Oauth2Callback
|
||||
if err := c.ShouldBind(&bindData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
provider := c.Param("provider")
|
||||
o2conf, has := singleton.Conf.Oauth2[provider]
|
||||
if !has {
|
||||
return nil, singleton.Localizer.ErrorT("provider not found")
|
||||
}
|
||||
provider = strings.ToLower(provider)
|
||||
|
||||
openId, err := exchangeOpenId(c, o2conf, provider, bindData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u := c.MustGet(model.CtxKeyAuthorizedUser).(*model.User)
|
||||
|
||||
var bind model.Oauth2Bind
|
||||
result := singleton.DB.Where("provider = ? AND open_id = ?", provider, openId).Limit(1).Find(&bind)
|
||||
if result.Error != nil && result.Error != gorm.ErrRecordNotFound {
|
||||
return nil, newGormError("%v", result.Error)
|
||||
}
|
||||
bind.UserID = u.ID
|
||||
bind.Provider = provider
|
||||
bind.OpenID = openId
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
result = singleton.DB.Create(&bind)
|
||||
} else {
|
||||
result = singleton.DB.Save(&bind)
|
||||
}
|
||||
if result.Error != nil {
|
||||
return nil, newGormError("%v", result.Error)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// @Summary Unbind Oauth2
|
||||
// @Description Unbind Oauth2
|
||||
// @Accept json
|
||||
@@ -232,9 +88,145 @@ func unbindOauth2(c *gin.Context) (any, error) {
|
||||
return nil, singleton.Localizer.ErrorT("provider not found")
|
||||
}
|
||||
provider = strings.ToLower(provider)
|
||||
|
||||
u := c.MustGet(model.CtxKeyAuthorizedUser).(*model.User)
|
||||
if err := singleton.DB.Where("provider = ? AND user_id = ?", provider, u.ID).Delete(&model.Oauth2Bind{}).Error; err != nil {
|
||||
query := singleton.DB.Where("provider = ? AND user_id = ?", provider, u.ID)
|
||||
|
||||
var bindCount int64
|
||||
if err := query.Model(&model.Oauth2Bind{}).Count(&bindCount).Error; err != nil {
|
||||
return nil, newGormError("%v", err)
|
||||
}
|
||||
|
||||
if bindCount < 2 && u.RejectPassword {
|
||||
return nil, singleton.Localizer.ErrorT("operation not permitted")
|
||||
}
|
||||
|
||||
if err := query.Delete(&model.Oauth2Bind{}).Error; err != nil {
|
||||
return nil, newGormError("%v", err)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// @Summary Oauth2 Callback
|
||||
// @Description Oauth2 Callback
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param state query string true "state"
|
||||
// @Param code query string true "code"
|
||||
// @Success 200 {object} model.LoginResponse
|
||||
// @Router /api/v1/oauth2/callback [get]
|
||||
func oauth2callback(jwtConfig *jwt.GinJWTMiddleware) func(c *gin.Context) (*model.LoginResponse, error) {
|
||||
return func(c *gin.Context) (*model.LoginResponse, error) {
|
||||
callbackData := &model.Oauth2Callback{
|
||||
State: c.Query("state"),
|
||||
Code: c.Query("code"),
|
||||
}
|
||||
|
||||
state, err := verifyState(c, callbackData.State)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o2confRaw, has := singleton.Conf.Oauth2[state.Provider]
|
||||
if !has {
|
||||
return nil, singleton.Localizer.ErrorT("provider not found")
|
||||
}
|
||||
|
||||
realip := c.GetString(model.CtxKeyRealIPStr)
|
||||
if callbackData.Code == "" {
|
||||
model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeBruteForceOauth2, model.BlockIDToken)
|
||||
return nil, singleton.Localizer.ErrorT("code is required")
|
||||
}
|
||||
|
||||
openId, err := exchangeOpenId(c, o2confRaw, callbackData)
|
||||
if err != nil {
|
||||
model.BlockIP(singleton.DB, realip, model.WAFBlockReasonTypeBruteForceOauth2, model.BlockIDToken)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var bind model.Oauth2Bind
|
||||
switch state.Action {
|
||||
case model.RTypeBind:
|
||||
u, authorized := c.Get(model.CtxKeyAuthorizedUser)
|
||||
if !authorized {
|
||||
return nil, singleton.Localizer.ErrorT("unauthorized")
|
||||
}
|
||||
user := u.(*model.User)
|
||||
|
||||
result := singleton.DB.Where("provider = ? AND open_id = ?", strings.ToLower(state.Provider), openId).Limit(1).Find(&bind)
|
||||
if result.Error != nil && result.Error != gorm.ErrRecordNotFound {
|
||||
return nil, newGormError("%v", result.Error)
|
||||
}
|
||||
bind.UserID = user.ID
|
||||
bind.Provider = state.Provider
|
||||
bind.OpenID = openId
|
||||
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
result = singleton.DB.Create(&bind)
|
||||
} else {
|
||||
result = singleton.DB.Save(&bind)
|
||||
}
|
||||
if result.Error != nil {
|
||||
return nil, newGormError("%v", result.Error)
|
||||
}
|
||||
default:
|
||||
if err := singleton.DB.Where("provider = ? AND open_id = ?", strings.ToLower(state.Provider), openId).First(&bind).Error; err != nil {
|
||||
return nil, singleton.Localizer.ErrorT("oauth2 user not binded yet")
|
||||
}
|
||||
}
|
||||
|
||||
tokenString, expire, err := jwtConfig.TokenGenerator(fmt.Sprintf("%d", bind.UserID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
jwtConfig.SetCookie(c, tokenString)
|
||||
c.Redirect(http.StatusFound, utils.IfOr(state.Action == model.RTypeBind, "/dashboard/profile?oauth2=true", "/dashboard/login?oauth2=true"))
|
||||
|
||||
return &model.LoginResponse{Token: tokenString, Expire: expire.Format(time.RFC3339)}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func exchangeOpenId(c *gin.Context, o2confRaw *model.Oauth2Config, callbackData *model.Oauth2Callback) (string, error) {
|
||||
o2conf := o2confRaw.Setup(getRedirectURL(c))
|
||||
ctx := context.Background()
|
||||
|
||||
otk, err := o2conf.Exchange(ctx, callbackData.Code)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
oauth2client := o2conf.Client(ctx, otk)
|
||||
resp, err := oauth2client.Get(o2confRaw.UserInfoURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return gjson.GetBytes(body, o2confRaw.UserIDPath).String(), nil
|
||||
}
|
||||
|
||||
func verifyState(c *gin.Context, state string) (*model.Oauth2State, error) {
|
||||
// 验证登录跳转时的 State
|
||||
stateKey, err := c.Cookie("nz-o2s")
|
||||
if err != nil {
|
||||
return nil, singleton.Localizer.ErrorT("invalid state key")
|
||||
}
|
||||
|
||||
cacheKey := fmt.Sprintf("%s%s", model.CacheKeyOauth2State, stateKey)
|
||||
istate, ok := singleton.Cache.Get(cacheKey)
|
||||
if !ok {
|
||||
return nil, singleton.Localizer.ErrorT("invalid state key")
|
||||
}
|
||||
|
||||
oauth2State, ok := istate.(*model.Oauth2State)
|
||||
if !ok || oauth2State.State != state {
|
||||
return nil, singleton.Localizer.ErrorT("invalid state key")
|
||||
}
|
||||
|
||||
return oauth2State, nil
|
||||
}
|
||||
|
||||
@@ -73,8 +73,18 @@ func updateProfile(c *gin.Context) (any, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var bindCount int64
|
||||
if err := singleton.DB.Where("user_id = ?", auth.(*model.User).ID).Count(&bindCount).Error; err != nil {
|
||||
return nil, newGormError("%v", err)
|
||||
}
|
||||
|
||||
if pf.RejectPassword && bindCount < 1 {
|
||||
return nil, singleton.Localizer.ErrorT("you don't have any oauth2 bindings")
|
||||
}
|
||||
|
||||
user.Username = pf.NewUsername
|
||||
user.Password = string(hash)
|
||||
user.RejectPassword = pf.RejectPassword
|
||||
if err := singleton.DB.Save(&user).Error; err != nil {
|
||||
return nil, newGormError("%v", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user