feat: waf 🤡

This commit is contained in:
naiba
2024-11-22 23:57:25 +08:00
parent d699d0ee87
commit 17b02640a9
9 changed files with 214 additions and 35 deletions

View File

@@ -0,0 +1,90 @@
package waf
import (
_ "embed"
"errors"
"log"
"math/big"
"net/http"
"net/netip"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/service/singleton"
"gorm.io/gorm"
)
//go:embed waf.html
var errorPageTemplate string
func RealIp(c *gin.Context) {
if singleton.Conf.RealIPHeader == "" {
c.Next()
return
}
if singleton.Conf.RealIPHeader == model.ConfigUsePeerIP {
c.Set(model.CtxKeyRealIPStr, c.RemoteIP())
c.Next()
return
}
vals := c.Request.Header.Get(singleton.Conf.RealIPHeader)
if vals == "" {
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: "real ip header not found"})
return
}
ip, err := netip.ParseAddr(vals)
if err != nil {
c.AbortWithStatusJSON(http.StatusOK, model.CommonResponse[any]{Success: false, Error: err.Error()})
return
}
c.Set(model.CtxKeyRealIPStr, ip.String())
c.Next()
}
func Waf(c *gin.Context) {
if singleton.Conf.RealIPHeader == "" {
c.Next()
return
}
realipAddr := c.GetString(model.CtxKeyRealIPStr)
if realipAddr == "" {
c.Next()
return
}
var w model.WAF
if err := singleton.DB.First(&w, "ip = ?", realipAddr).Error; err != nil {
if err != gorm.ErrRecordNotFound {
ShowBlockPage(c, err)
return
}
}
now := time.Now().Unix()
if w.LastBlockTimestamp+pow(w.Count, 4) > uint64(now) {
log.Println(w.Count, w.LastBlockTimestamp+pow(w.Count, 4)-uint64(now))
ShowBlockPage(c, errors.New("you are blocked by nezha WAF"))
return
}
c.Next()
}
func pow(x, y uint64) uint64 {
base := big.NewInt(0).SetUint64(x)
exp := big.NewInt(0).SetUint64(y)
result := big.NewInt(1)
result.Exp(base, exp, nil)
if !result.IsUint64() {
return ^uint64(0) // return max uint64 value on overflow
}
return result.Uint64()
}
func ShowBlockPage(c *gin.Context, err error) {
c.Writer.WriteHeader(http.StatusForbidden)
c.Header("Content-Type", "text/html; charset=utf-8")
c.Writer.WriteString(strings.Replace(errorPageTemplate, "{error}", err.Error(), 1))
c.Abort()
}

View File

@@ -0,0 +1,39 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Blocked</title>
<style>
body {
display: flex;
justify-content: center;
align-items: center;
height: 90vh;
font-weight: bolder;
font-family: 'Courier New', Courier, monospace;
}
main {
text-align: center;
}
.emoji {
font-size: 200px;
}
p.secondary {
font-size: 12px;
color: #888;
}
</style>
</head>
<body>
<main>
<div class="emoji">🤡</div>
<h1>Blocked</h1>
<p>{error}</p>
<p class="secondary">nezha WAF</p>
</main>
</body>
</html>

View File

@@ -0,0 +1,29 @@
package waf
import (
"math"
"testing"
)
func TestPow(t *testing.T) {
tests := []struct {
x,
y,
expect uint64
}{
{2, 64, math.MaxUint64}, // 2 的 64 次方,超过 uint64 最大值
{uint64(1 << 63), 2, math.MaxUint64}, // 大数平方,可能溢出
{uint64(^uint64(0)), 2, math.MaxUint64}, // uint64 最大值的平方,溢出
{2, 3, 8},
{5, 0, 1},
{3, 1, 3},
{0, 5, 0},
}
for _, tt := range tests {
result := pow(tt.x, tt.y)
if result != tt.expect {
t.Errorf("pow(%d, %d) = %d; expect %d", tt.x, tt.y, result, tt.expect)
}
}
}