From 49af3e709936e2ce6266039fcebb4d5326e05bde Mon Sep 17 00:00:00 2001 From: shuaiplus <2327005759@qq.com> Date: Thu, 5 Mar 2026 02:26:05 +0800 Subject: [PATCH] feat: enhance rate limiting with new public request budgets and client IP validation --- src/config/limits.ts | 12 +++ src/handlers/accounts.ts | 6 +- src/handlers/identity.ts | 24 +++++- src/handlers/sends.ts | 137 ++++++++++++++++++++++++++++++---- src/router.ts | 52 +++++++++---- src/services/ratelimit.ts | 150 ++++++++++++++++++++++++++++++++++++-- 6 files changed, 345 insertions(+), 36 deletions(-) diff --git a/src/config/limits.ts b/src/config/limits.ts index 58c78b4..01a50c2 100644 --- a/src/config/limits.ts +++ b/src/config/limits.ts @@ -38,6 +38,18 @@ // Public (unauthenticated) request budget per IP per minute. // 公开(未认证)接口每 IP 每分钟请求配额。 publicRequestsPerMinute: 60, + // Public read-only request budget per IP per minute. + // 公开只读接口每 IP 每分钟请求配额。 + publicReadRequestsPerMinute: 120, + // Sensitive public/auth request budget per IP per minute. + // 敏感公开/认证接口每 IP 每分钟请求配额。 + sensitivePublicRequestsPerMinute: 30, + // Register endpoint budget per IP per minute. + // 注册接口每 IP 每分钟请求配额。 + registerRequestsPerMinute: 5, + // Refresh-token grant budget per IP per minute. + // refresh_token 授权每 IP 每分钟请求配额。 + refreshTokenRequestsPerMinute: 30, // Fixed window size for API rate limiting in seconds. // API 限流固定窗口大小(秒)。 apiWindowSeconds: 60, diff --git a/src/handlers/accounts.ts b/src/handlers/accounts.ts index 604f63d..1cfaf8c 100644 --- a/src/handlers/accounts.ts +++ b/src/handlers/accounts.ts @@ -527,7 +527,11 @@ export async function handleRecoverTwoFactor(request: Request, env: Env): Promis const email = String(body.email || body.username || '').trim().toLowerCase(); const masterPasswordHash = String(body.masterPasswordHash || body.password || '').trim(); const recoveryCode = normalizeRecoveryCodeInput(String(body.recoveryCode || body.twoFactorToken || body.recovery_code || '')); - const recoverLimitKey = `${getClientIdentifier(request)}:recover-2fa:${email || 'unknown'}`; + const clientIdentifier = getClientIdentifier(request); + if (!clientIdentifier) { + return errorResponse('Client IP is required', 403); + } + const recoverLimitKey = `${clientIdentifier}:recover-2fa:${email || 'unknown'}`; const recoverAttemptCheck = await rateLimit.checkLoginAttempt(recoverLimitKey); if (!recoverAttemptCheck.allowed) { diff --git a/src/handlers/identity.ts b/src/handlers/identity.ts index a704b8f..907dbe2 100644 --- a/src/handlers/identity.ts +++ b/src/handlers/identity.ts @@ -107,6 +107,9 @@ export async function handleToken(request: Request, env: Env): Promise const grantType = body.grant_type; const clientIdentifier = getClientIdentifier(request); + if (!clientIdentifier) { + return identityErrorResponse('Client IP is required', 'invalid_request', 403); + } if (grantType === 'password') { // Login with password @@ -297,7 +300,14 @@ export async function handleToken(request: Request, env: Env): Promise ).trim() || null; const password = String(body.password || '').trim() || null; - const result = await issueSendAccessToken(env, sendId, passwordHashB64, password); + const result = await issueSendAccessToken( + env, + sendId, + passwordHashB64, + password, + rateLimit, + `${clientIdentifier}:send-password` + ); if ('error' in result) { return result.error; } @@ -310,6 +320,18 @@ export async function handleToken(request: Request, env: Env): Promise unofficialServer: true, }); } else if (grantType === 'refresh_token') { + const refreshLimit = await rateLimit.consumeBudget( + `${clientIdentifier}:identity-refresh`, + LIMITS.rateLimit.refreshTokenRequestsPerMinute + ); + if (!refreshLimit.allowed) { + return identityErrorResponse( + `Rate limit exceeded. Try again in ${refreshLimit.retryAfterSeconds} seconds.`, + 'TooManyRequests', + 429 + ); + } + // Refresh token const refreshToken = body.refresh_token; if (!refreshToken) { diff --git a/src/handlers/sends.ts b/src/handlers/sends.ts index 582114c..34f837f 100644 --- a/src/handlers/sends.ts +++ b/src/handlers/sends.ts @@ -1,5 +1,6 @@ import { Env, Send, SendAuthType, SendResponse, SendType, DEFAULT_DEV_SECRET } from '../types'; import { StorageService } from '../services/storage'; +import { RateLimitService, getClientIdentifier } from '../services/ratelimit'; import { jsonResponse, errorResponse } from '../utils/response'; import { generateUUID } from '../utils/uuid'; import { parsePagination, encodeContinuationToken } from '../utils/pagination'; @@ -13,6 +14,7 @@ import { const SEND_INACCESSIBLE_MSG = 'Send does not exist or is no longer available'; const SEND_PASSWORD_ITERATIONS = 100_000; +const SEND_PASSWORD_LIMIT_SCOPE = 'send-password'; function getAliasedProp(source: unknown, aliases: string[]): { present: boolean; value: unknown } { if (!source || typeof source !== 'object') return { present: false, value: undefined }; @@ -383,12 +385,44 @@ async function getCreatorIdentifier(storage: StorageService, send: Send): Promis return owner?.email ?? null; } -async function validatePublicSendAccess(send: Send, body: unknown): Promise { +type PublicSendAccessValidationResult = + | { ok: true } + | { ok: false; response: Response; reason: 'email_auth_unsupported' | 'password_missing' | 'invalid_password' }; + +function sendPasswordLimitKey(clientIdentifier: string): string { + return `${clientIdentifier}:${SEND_PASSWORD_LIMIT_SCOPE}`; +} + +function sendPasswordLockMessage(retryAfterSeconds: number): string { + return `Too many failed send password attempts. Try again in ${Math.ceil(retryAfterSeconds / 60)} minutes.`; +} + +function sendPasswordLockedErrorResponse(retryAfterSeconds: number): Response { + return errorResponse(sendPasswordLockMessage(retryAfterSeconds), 429); +} + +function sendPasswordLockedOAuthResponse(retryAfterSeconds: number): Response { + const message = sendPasswordLockMessage(retryAfterSeconds); + return jsonResponse( + { + error: 'invalid_grant', + error_description: message, + send_access_error_type: 'too_many_password_attempts', + ErrorModel: { + Message: message, + Object: 'error', + }, + }, + 429 + ); +} + +async function validatePublicSendAccess(send: Send, body: unknown): Promise { if (hasEmailAuth(send)) { - return errorResponse(SEND_INACCESSIBLE_MSG, 404); + return { ok: false, response: errorResponse(SEND_INACCESSIBLE_MSG, 404), reason: 'email_auth_unsupported' }; } - if (!send.passwordHash) return null; + if (!send.passwordHash) return { ok: true }; const passwordRaw = getAliasedProp(body, ['password', 'Password']); const passwordHashB64Raw = getAliasedProp(body, [ @@ -401,7 +435,7 @@ async function validatePublicSendAccess(send: Send, body: unknown): Promise { const jwt = getSafeJwtSecret(env); if (!jwt.ok) { @@ -1267,6 +1355,15 @@ export async function issueSendAccessToken( } if (send.passwordHash) { + if (rateLimit && sendPasswordLimitIpKey) { + const sendPasswordCheck = await rateLimit.checkLoginAttempt(sendPasswordLimitIpKey); + if (!sendPasswordCheck.allowed) { + return { + error: sendPasswordLockedOAuthResponse(sendPasswordCheck.retryAfterSeconds || 60), + }; + } + } + let ok = false; if (passwordHashB64) { ok = verifySendPasswordHashB64(send, passwordHashB64); @@ -1275,6 +1372,14 @@ export async function issueSendAccessToken( } if (!ok) { + if (rateLimit && sendPasswordLimitIpKey) { + const failed = await rateLimit.recordFailedLogin(sendPasswordLimitIpKey); + if (failed.locked) { + return { + error: sendPasswordLockedOAuthResponse(failed.retryAfterSeconds || 60), + }; + } + } return { error: jsonResponse( { @@ -1290,6 +1395,10 @@ export async function issueSendAccessToken( ), }; } + + if (rateLimit && sendPasswordLimitIpKey) { + await rateLimit.clearLoginAttempts(sendPasswordLimitIpKey); + } } const token = await createSendAccessToken(send.id, jwt.secret); diff --git a/src/router.ts b/src/router.ts index 87cf327..14a4df7 100644 --- a/src/router.ts +++ b/src/router.ts @@ -215,9 +215,23 @@ export async function handleRequest(request: Request, env: Env): Promise { + async function enforcePublicRateLimit( + category: string = 'public', + maxRequests: number = LIMITS.rateLimit.publicRequestsPerMinute + ): Promise { + if (!clientId) { + return new Response(JSON.stringify({ + error: 'Forbidden', + error_description: 'Client IP is required', + }), { + status: 403, + headers: { + 'Content-Type': 'application/json', + }, + }); + } const rateLimit = new RateLimitService(env.DB); - const check = await rateLimit.consumeBudget(`${clientId}:public`, LIMITS.rateLimit.publicRequestsPerMinute); + const check = await rateLimit.consumeBudget(`${clientId}:${category}`, maxRequests); if (check.allowed) return null; return new Response(JSON.stringify({ error: 'Too many requests', @@ -254,11 +268,15 @@ export async function handleRequest(request: Request, env: Env): Promise 255) return null; + octets.push(value); + } + return octets; +} + +function parseIpv6Hextets(input: string): number[] | null { + let value = input.trim().toLowerCase(); + if (!value) return null; + + if (value.startsWith('[') && value.endsWith(']')) { + value = value.slice(1, -1); + } + const zoneIndex = value.indexOf('%'); + if (zoneIndex >= 0) { + value = value.slice(0, zoneIndex); + } + if (!value.includes(':')) return null; + + // Handle IPv4-mapped tail (e.g. ::ffff:192.0.2.1). + if (value.includes('.')) { + const lastColon = value.lastIndexOf(':'); + if (lastColon < 0) return null; + const ipv4Tail = value.slice(lastColon + 1); + const octets = parseIpv4Octets(ipv4Tail); + if (!octets) return null; + const high = ((octets[0] << 8) | octets[1]).toString(16); + const low = ((octets[2] << 8) | octets[3]).toString(16); + value = `${value.slice(0, lastColon)}:${high}:${low}`; + } + + const doubleColon = value.indexOf('::'); + if (doubleColon !== value.lastIndexOf('::')) return null; + + const parsePart = (part: string): number | null => { + if (!/^[0-9a-f]{1,4}$/.test(part)) return null; + const n = parseInt(part, 16); + return Number.isNaN(n) ? null : n; + }; + + const parseParts = (parts: string[]): number[] | null => { + const out: number[] = []; + for (const p of parts) { + if (!p) return null; + const n = parsePart(p); + if (n === null) return null; + out.push(n); + } + return out; + }; + + if (doubleColon >= 0) { + const [headRaw, tailRaw] = value.split('::'); + const head = headRaw ? headRaw.split(':') : []; + const tail = tailRaw ? tailRaw.split(':') : []; + + const headNums = parseParts(head); + const tailNums = parseParts(tail); + if (!headNums || !tailNums) return null; + + const missing = 8 - (headNums.length + tailNums.length); + if (missing < 1) return null; + + return [...headNums, ...new Array(missing).fill(0), ...tailNums]; + } + + const all = parseParts(value.split(':')); + if (!all || all.length !== 8) return null; + return all; +} + +function normalizeClientIpForRateLimit(rawIp: string): string | null { + const input = rawIp.trim(); + if (!input) return null; + + const ipv4 = parseIpv4Octets(input); + if (ipv4) { + return `ip4:${ipv4.join('.')}`; + } + + const ipv6 = parseIpv6Hextets(input); + if (!ipv6) return null; + + // Handle IPv4-mapped / IPv4-compatible IPv6 as IPv4 identity. + // Examples: ::ffff:192.0.2.1, ::192.0.2.1 + if ( + ipv6[0] === 0 && + ipv6[1] === 0 && + ipv6[2] === 0 && + ipv6[3] === 0 && + ipv6[4] === 0 && + (ipv6[5] === 0xffff || ipv6[5] === 0) + ) { + const octets = [ipv6[6] >> 8, ipv6[6] & 0xff, ipv6[7] >> 8, ipv6[7] & 0xff]; + return `ip4:${octets.join('.')}`; + } + + // Collapse to /64 to reduce brute-force bypass via IPv6 address rotation. + const prefix64 = ipv6 + .slice(0, 4) + .map(part => part.toString(16).padStart(4, '0')) + .join(':'); + return `ip6:${prefix64}`; +} + +export function getClientIdentifier(request: Request): string | null { + const cfIp = request.headers.get('CF-Connecting-IP'); + if (cfIp) { + return normalizeClientIpForRateLimit(cfIp); + } + + // Local development fallback: + // wrangler dev may not provide CF-Connecting-IP. Allow localhost requests + // to resolve an identifier from X-Forwarded-For or loopback. + try { + const hostname = new URL(request.url).hostname.toLowerCase(); + const isLocalHost = + hostname === 'localhost' || + hostname === '127.0.0.1' || + hostname === '::1' || + hostname === '[::1]'; + if (!isLocalHost) { + return null; + } + + const forwardedFor = request.headers.get('X-Forwarded-For'); + if (forwardedFor) { + const first = forwardedFor.split(',')[0].trim(); + const normalized = normalizeClientIpForRateLimit(first); + if (normalized) return normalized; + } + + return 'ip4:127.0.0.1'; + } catch { + return null; + } }