feat: binding ip with session

🛡️staying safe even your frontend was hacked
This commit is contained in:
naiba
2025-10-09 21:36:59 +08:00
parent 1db4fe4679
commit 540c1cb1e7
7 changed files with 109 additions and 17 deletions
+1 -1
View File
@@ -48,7 +48,7 @@ jobs:
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: "1.24.x" go-version: "1.25.x"
- name: generate swagger docs - name: generate swagger docs
run: | run: |
+1 -1
View File
@@ -27,7 +27,7 @@ jobs:
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "1.24.x" go-version: "1.25.x"
- name: generate swagger docs - name: generate swagger docs
run: | run: |
+21 -8
View File
@@ -50,10 +50,8 @@ func initParams() *jwt.GinJWTMiddleware {
func payloadFunc() func(data any) jwt.MapClaims { func payloadFunc() func(data any) jwt.MapClaims {
return func(data any) jwt.MapClaims { return func(data any) jwt.MapClaims {
if v, ok := data.(string); ok { if v, ok := data.(map[string]interface{}); ok {
return jwt.MapClaims{ return v
model.CtxKeyAuthorizedUser: v,
}
} }
return jwt.MapClaims{} return jwt.MapClaims{}
} }
@@ -62,7 +60,15 @@ func payloadFunc() func(data any) jwt.MapClaims {
func identityHandler() func(c *gin.Context) any { func identityHandler() func(c *gin.Context) any {
return func(c *gin.Context) any { return func(c *gin.Context) any {
claims := jwt.ExtractClaims(c) claims := jwt.ExtractClaims(c)
userId := claims[model.CtxKeyAuthorizedUser].(string) userId := claims["user_id"].(string)
tokenIP := claims["ip"].(string)
currentIP := c.GetString(model.CtxKeyRealIPStr)
if tokenIP != currentIP {
// IP地址不匹配,token无效
return nil
}
var user model.User var user model.User
if err := singleton.DB.First(&user, userId).Error; err != nil { if err := singleton.DB.First(&user, userId).Error; err != nil {
return nil return nil
@@ -109,7 +115,12 @@ func authenticator() func(c *gin.Context) (any, error) {
model.UnblockIP(singleton.DB, realip, model.BlockIDUnknownUser) model.UnblockIP(singleton.DB, realip, model.BlockIDUnknownUser)
model.UnblockIP(singleton.DB, realip, int64(user.ID)) model.UnblockIP(singleton.DB, realip, int64(user.ID))
return utils.Itoa(user.ID), nil
// 返回用户ID和IP地址的组合,用于在payloadFunc中设置JWT claims
return map[string]interface{}{
"user_id": utils.Itoa(user.ID),
"ip": realip,
}, nil
} }
} }
@@ -174,14 +185,16 @@ func fallbackAuthMiddleware(mw *jwt.GinJWTMiddleware) func(c *gin.Context) {
return return
} }
realIP := c.GetString(model.CtxKeyRealIPStr)
c.Set("JWT_PAYLOAD", claims) c.Set("JWT_PAYLOAD", claims)
identity := mw.IdentityHandler(c) identity := mw.IdentityHandler(c)
if identity != nil { if identity != nil {
model.UnblockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.BlockIDToken) model.UnblockIP(singleton.DB, realIP, model.BlockIDToken)
c.Set(mw.IdentityKey, identity) c.Set(mw.IdentityKey, identity)
} else { } else {
if err := model.BlockIP(singleton.DB, c.GetString(model.CtxKeyRealIPStr), model.WAFBlockReasonTypeBruteForceToken, model.BlockIDToken); err != nil { if err := model.BlockIP(singleton.DB, realIP, model.WAFBlockReasonTypeBruteForceToken, model.BlockIDToken); err != nil {
waf.ShowBlockPage(c, err) waf.ShowBlockPage(c, err)
return return
} }
+79
View File
@@ -0,0 +1,79 @@
package controller
import (
"testing"
"time"
jwt "github.com/appleboy/gin-jwt/v2"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)
func TestPayloadFunc(t *testing.T) {
payloadFn := payloadFunc()
// 测试包含IP的格式
t.Run("format with IP", func(t *testing.T) {
data := map[string]interface{}{
"user_id": "123",
"ip": "192.168.1.1",
}
claims := payloadFn(data)
assert.Equal(t, "123", claims["user_id"])
assert.Equal(t, "192.168.1.1", claims["ip"])
})
// 测试不包含IP的格式
t.Run("format without IP", func(t *testing.T) {
data := map[string]interface{}{
"user_id": "123",
}
claims := payloadFn(data)
assert.Equal(t, "123", claims["user_id"])
assert.Nil(t, claims["ip"])
})
// 测试无效数据格式
t.Run("invalid data format", func(t *testing.T) {
claims := payloadFn("123") // 字符串类型不再支持
assert.Empty(t, claims)
})
// 测试空的map
t.Run("empty map", func(t *testing.T) {
data := map[string]interface{}{}
claims := payloadFn(data)
assert.Empty(t, claims)
})
}
func TestIPBinding(t *testing.T) {
// 创建测试用的gin context
gin.SetMode(gin.TestMode)
t.Run("IP mismatch should invalidate token", func(t *testing.T) {
// 模拟JWT claims包含IP绑定
claims := jwt.MapClaims{
"user_id": "123",
"ip": "192.168.1.1",
"exp": float64(time.Now().Add(time.Hour).Unix()),
}
// 这里需要实际的数据库和用户设置来完全测试
// 但可以测试claims的基本结构
assert.Equal(t, "123", claims["user_id"])
assert.Equal(t, "192.168.1.1", claims["ip"])
})
t.Run("no IP in token should deny access", func(t *testing.T) {
// 没有IP绑定的token应该被拒绝
claims := jwt.MapClaims{
"user_id": "123",
"exp": float64(time.Now().Add(time.Hour).Unix()),
}
// 验证token结构
assert.Equal(t, "123", claims["user_id"])
assert.Nil(t, claims["ip"])
})
}
+3 -2
View File
@@ -1,6 +1,6 @@
module github.com/nezhahq/nezha module github.com/nezhahq/nezha
go 1.24.0 go 1.25
require ( require (
github.com/appleboy/gin-jwt/v2 v2.10.3 github.com/appleboy/gin-jwt/v2 v2.10.3
@@ -26,8 +26,10 @@ require (
github.com/oschwald/maxminddb-golang v1.13.1 github.com/oschwald/maxminddb-golang v1.13.1
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/robfig/cron/v3 v3.0.1 github.com/robfig/cron/v3 v3.0.1
github.com/stretchr/testify v1.10.0
github.com/swaggo/files v1.0.1 github.com/swaggo/files v1.0.1
github.com/swaggo/gin-swagger v1.6.0 github.com/swaggo/gin-swagger v1.6.0
github.com/swaggo/swag v1.16.4
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
golang.org/x/crypto v0.37.0 golang.org/x/crypto v0.37.0
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
@@ -74,7 +76,6 @@ require (
github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/swaggo/swag v1.16.4 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect github.com/tidwall/sjson v1.2.5 // indirect
+4
View File
@@ -114,6 +114,10 @@ func (provider *Provider) splitDomainSOA(ctx context.Context, domain string) (pr
if soa, ok := r.Answer[0].(*dns.SOA); ok { if soa, ok := r.Answer[0].(*dns.SOA); ok {
zone := soa.Hdr.Name zone := soa.Hdr.Name
prefix := libdns.RelativeName(domain, zone) prefix := libdns.RelativeName(domain, zone)
// Convert "@" to empty string for zone apex
if prefix == "@" {
prefix = ""
}
return prefix, zone, nil return prefix, zone, nil
} }
} }
-5
View File
@@ -2,7 +2,6 @@ package ddns
import ( import (
"context" "context"
"os"
"testing" "testing"
) )
@@ -13,10 +12,6 @@ type testSt struct {
} }
func TestSplitDomainSOA(t *testing.T) { func TestSplitDomainSOA(t *testing.T) {
if ci := os.Getenv("CI"); ci != "" { // skip if test on CI
return
}
cases := []testSt{ cases := []testSt{
{ {
domain: "www.example.co.uk", domain: "www.example.co.uk",