From 540c1cb1e7ee9ff12fadc5133a238d99ae1b149d Mon Sep 17 00:00:00 2001 From: naiba Date: Thu, 9 Oct 2025 21:36:59 +0800 Subject: [PATCH] feat: binding ip with session MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🛡️staying safe even your frontend was hacked --- .github/workflows/release.yml | 2 +- .github/workflows/test.yml | 2 +- cmd/dashboard/controller/jwt.go | 29 +++++++--- cmd/dashboard/controller/jwt_test.go | 79 ++++++++++++++++++++++++++++ go.mod | 5 +- pkg/ddns/ddns.go | 4 ++ pkg/ddns/ddns_test.go | 5 -- 7 files changed, 109 insertions(+), 17 deletions(-) create mode 100644 cmd/dashboard/controller/jwt_test.go diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 84639a8..d206721 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -48,7 +48,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.24.x" + go-version: "1.25.x" - name: generate swagger docs run: | diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f6e0989..2b5a6f2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,7 +27,7 @@ jobs: - uses: actions/setup-go@v5 with: - go-version: "1.24.x" + go-version: "1.25.x" - name: generate swagger docs run: | diff --git a/cmd/dashboard/controller/jwt.go b/cmd/dashboard/controller/jwt.go index 476b594..9054c40 100644 --- a/cmd/dashboard/controller/jwt.go +++ b/cmd/dashboard/controller/jwt.go @@ -50,10 +50,8 @@ func initParams() *jwt.GinJWTMiddleware { func payloadFunc() func(data any) jwt.MapClaims { return func(data any) jwt.MapClaims { - if v, ok := data.(string); ok { - return jwt.MapClaims{ - model.CtxKeyAuthorizedUser: v, - } + if v, ok := data.(map[string]interface{}); ok { + return v } return jwt.MapClaims{} } @@ -62,7 +60,15 @@ func payloadFunc() func(data any) jwt.MapClaims { func identityHandler() func(c *gin.Context) any { return func(c *gin.Context) any { claims := jwt.ExtractClaims(c) - userId := claims[model.CtxKeyAuthorizedUser].(string) + userId := claims["user_id"].(string) + tokenIP := claims["ip"].(string) + currentIP := c.GetString(model.CtxKeyRealIPStr) + + if tokenIP != currentIP { + // IP地址不匹配,token无效 + return nil + } + var user model.User if err := singleton.DB.First(&user, userId).Error; err != nil { return nil @@ -109,7 +115,12 @@ func authenticator() func(c *gin.Context) (any, error) { model.UnblockIP(singleton.DB, realip, model.BlockIDUnknownUser) model.UnblockIP(singleton.DB, realip, int64(user.ID)) - return utils.Itoa(user.ID), nil + + // 返回用户ID和IP地址的组合,用于在payloadFunc中设置JWT claims + return map[string]interface{}{ + "user_id": utils.Itoa(user.ID), + "ip": realip, + }, nil } } @@ -174,14 +185,16 @@ func fallbackAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) { return } + realIP := c.GetString(model.CtxKeyRealIPStr) + c.Set("JWT_PAYLOAD", claims) identity := mw.IdentityHandler(c) if identity != nil { - model.UnblockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.BlockIDToken) + model.UnblockIP(singleton.DB, realIP, model.BlockIDToken) c.Set(mw.IdentityKey, identity) } else { - if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken, model.BlockIDToken); err != nil { + if err := model.BlockIP(singleton.DB, realIP, model.WAFBlockReasonTypeBruteForceToken, model.BlockIDToken); err != nil { waf.ShowBlockPage(c, err) return } diff --git a/cmd/dashboard/controller/jwt_test.go b/cmd/dashboard/controller/jwt_test.go new file mode 100644 index 0000000..f2b58ea --- /dev/null +++ b/cmd/dashboard/controller/jwt_test.go @@ -0,0 +1,79 @@ +package controller + +import ( + "testing" + "time" + + jwt "github.com/appleboy/gin-jwt/v2" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func TestPayloadFunc(t *testing.T) { + payloadFn := payloadFunc() + + // 测试包含IP的格式 + t.Run("format with IP", func(t *testing.T) { + data := map[string]interface{}{ + "user_id": "123", + "ip": "192.168.1.1", + } + claims := payloadFn(data) + assert.Equal(t, "123", claims["user_id"]) + assert.Equal(t, "192.168.1.1", claims["ip"]) + }) + + // 测试不包含IP的格式 + t.Run("format without IP", func(t *testing.T) { + data := map[string]interface{}{ + "user_id": "123", + } + claims := payloadFn(data) + assert.Equal(t, "123", claims["user_id"]) + assert.Nil(t, claims["ip"]) + }) + + // 测试无效数据格式 + t.Run("invalid data format", func(t *testing.T) { + claims := payloadFn("123") // 字符串类型不再支持 + assert.Empty(t, claims) + }) + + // 测试空的map + t.Run("empty map", func(t *testing.T) { + data := map[string]interface{}{} + claims := payloadFn(data) + assert.Empty(t, claims) + }) +} + +func TestIPBinding(t *testing.T) { + // 创建测试用的gin context + gin.SetMode(gin.TestMode) + + t.Run("IP mismatch should invalidate token", func(t *testing.T) { + // 模拟JWT claims包含IP绑定 + claims := jwt.MapClaims{ + "user_id": "123", + "ip": "192.168.1.1", + "exp": float64(time.Now().Add(time.Hour).Unix()), + } + + // 这里需要实际的数据库和用户设置来完全测试 + // 但可以测试claims的基本结构 + assert.Equal(t, "123", claims["user_id"]) + assert.Equal(t, "192.168.1.1", claims["ip"]) + }) + + t.Run("no IP in token should deny access", func(t *testing.T) { + // 没有IP绑定的token应该被拒绝 + claims := jwt.MapClaims{ + "user_id": "123", + "exp": float64(time.Now().Add(time.Hour).Unix()), + } + + // 验证token结构 + assert.Equal(t, "123", claims["user_id"]) + assert.Nil(t, claims["ip"]) + }) +} diff --git a/go.mod b/go.mod index 60041b6..48d186f 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/nezhahq/nezha -go 1.24.0 +go 1.25 require ( github.com/appleboy/gin-jwt/v2 v2.10.3 @@ -26,8 +26,10 @@ require ( github.com/oschwald/maxminddb-golang v1.13.1 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/robfig/cron/v3 v3.0.1 + github.com/stretchr/testify v1.10.0 github.com/swaggo/files v1.0.1 github.com/swaggo/gin-swagger v1.6.0 + github.com/swaggo/swag v1.16.4 github.com/tidwall/gjson v1.18.0 golang.org/x/crypto v0.37.0 golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 @@ -74,7 +76,6 @@ require ( github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/swaggo/swag v1.16.4 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect diff --git a/pkg/ddns/ddns.go b/pkg/ddns/ddns.go index 99b1d4d..1dc5c2e 100644 --- a/pkg/ddns/ddns.go +++ b/pkg/ddns/ddns.go @@ -114,6 +114,10 @@ func (provider *Provider) splitDomainSOA(ctx context.Context, domain string) (pr if soa, ok := r.Answer[0].(*dns.SOA); ok { zone := soa.Hdr.Name prefix := libdns.RelativeName(domain, zone) + // Convert "@" to empty string for zone apex + if prefix == "@" { + prefix = "" + } return prefix, zone, nil } } diff --git a/pkg/ddns/ddns_test.go b/pkg/ddns/ddns_test.go index c802d2b..70d7cb9 100644 --- a/pkg/ddns/ddns_test.go +++ b/pkg/ddns/ddns_test.go @@ -2,7 +2,6 @@ package ddns import ( "context" - "os" "testing" ) @@ -13,10 +12,6 @@ type testSt struct { } func TestSplitDomainSOA(t *testing.T) { - if ci := os.Getenv("CI"); ci != "" { // skip if test on CI - return - } - cases := []testSt{ { domain: "www.example.co.uk",