init
This commit is contained in:
84
backend/internal/server/ratelimit.go
Normal file
84
backend/internal/server/ratelimit.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// RateLimiter provides per-IP rate limiting
|
||||
type RateLimiter struct {
|
||||
limiters map[string]*rate.Limiter
|
||||
mu sync.RWMutex
|
||||
rate rate.Limit
|
||||
burst int
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter
|
||||
// rate is requests per second, burst is max burst size
|
||||
func NewRateLimiter(r float64, burst int) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
rate: rate.Limit(r),
|
||||
burst: burst,
|
||||
}
|
||||
}
|
||||
|
||||
// getLimiter returns the rate limiter for a given IP, creating one if needed
|
||||
func (rl *RateLimiter) getLimiter(ip string) *rate.Limiter {
|
||||
rl.mu.RLock()
|
||||
limiter, exists := rl.limiters[ip]
|
||||
rl.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if limiter, exists = rl.limiters[ip]; exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
limiter = rate.NewLimiter(rl.rate, rl.burst)
|
||||
rl.limiters[ip] = limiter
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Middleware returns a middleware handler for rate limiting
|
||||
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get client IP (chi's RealIP middleware should have set this)
|
||||
ip := r.RemoteAddr
|
||||
|
||||
limiter := rl.getLimiter(ip)
|
||||
if !limiter.Allow() {
|
||||
w.Header().Set("Retry-After", "1")
|
||||
http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// CleanupOldEntries removes stale IP entries periodically
|
||||
// Call this in a goroutine to prevent memory growth
|
||||
func (rl *RateLimiter) CleanupOldEntries(interval time.Duration, maxAge time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
// Simple cleanup: just reset the map periodically
|
||||
// In a more sophisticated implementation, you'd track last access time
|
||||
if len(rl.limiters) > 10000 {
|
||||
rl.limiters = make(map[string]*rate.Limiter)
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user