Files
paragliding/backend/internal/server/ratelimit.go
2026-01-03 14:16:16 -08:00

85 lines
2.0 KiB
Go

package server
import (
	"net"
	"net/http"
	"sync"
	"time"

	"golang.org/x/time/rate"
)
// RateLimiter provides per-client token-bucket rate limiting for HTTP
// handlers. Each distinct client address gets its own rate.Limiter, created
// on first use; CleanupOldEntries evicts clients that have gone idle.
type RateLimiter struct {
	mu       sync.Mutex
	limiters map[string]*rate.Limiter // per-client token buckets
	lastSeen map[string]time.Time     // time of each client's most recent request; drives eviction
	rate     rate.Limit
	burst    int
}

// NewRateLimiter creates a RateLimiter that allows each client r requests per
// second on average, with bursts of up to burst requests.
func NewRateLimiter(r float64, burst int) *RateLimiter {
	return &RateLimiter{
		limiters: make(map[string]*rate.Limiter),
		lastSeen: make(map[string]time.Time),
		rate:     rate.Limit(r),
		burst:    burst,
	}
}

// getLimiter returns the limiter for the given client key, creating one on
// first use, and records the access time so idle clients can be evicted.
//
// A single exclusive lock is taken because every call writes lastSeen; the
// critical section is only a few map operations.
func (rl *RateLimiter) getLimiter(ip string) *rate.Limiter {
	now := time.Now()

	rl.mu.Lock()
	defer rl.mu.Unlock()

	limiter, ok := rl.limiters[ip]
	if !ok {
		limiter = rate.NewLimiter(rl.rate, rl.burst)
		rl.limiters[ip] = limiter
	}
	rl.lastSeen[ip] = now
	return limiter
}

// Middleware returns a handler that rate-limits next per client. A request
// over the client's budget is rejected with 429 Too Many Requests and a
// "Retry-After: 1" header; otherwise it is passed to next unchanged.
//
// The client key is the host part of r.RemoteAddr. net/http's server sets
// RemoteAddr to "IP:port", and keying on that verbatim would give every new
// connection (new ephemeral port) its own fresh bucket. If an upstream
// middleware (e.g. chi's RealIP) has replaced RemoteAddr with a bare address,
// it is used as-is.
// NOTE(review): header-derived addresses are client-controlled unless a
// trusted proxy sets them — confirm the deployment strips spoofed headers.
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ip := r.RemoteAddr
		if host, _, err := net.SplitHostPort(ip); err == nil {
			ip = host
		}

		if !rl.getLimiter(ip).Allow() {
			w.Header().Set("Retry-After", "1")
			http.Error(w, "rate limit exceeded", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// CleanupOldEntries evicts, every interval, each client whose last request is
// older than maxAge. Evicted clients simply get a fresh limiter on their next
// request. It never returns, so run it in its own goroutine; interval must be
// positive (time.NewTicker panics otherwise).
func (rl *RateLimiter) CleanupOldEntries(interval time.Duration, maxAge time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for now := range ticker.C {
		rl.evictStale(now.Add(-maxAge))
	}
}

// evictStale removes every client last seen before cutoff. If the table is
// still larger than maxEntries afterwards (e.g. a flood of distinct
// addresses), it is cleared entirely as a last-resort memory bound, which
// resets every remaining client's budget.
func (rl *RateLimiter) evictStale(cutoff time.Time) {
	const maxEntries = 10000

	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Deleting map entries during range is permitted in Go.
	for ip, seen := range rl.lastSeen {
		if seen.Before(cutoff) {
			delete(rl.lastSeen, ip)
			delete(rl.limiters, ip)
		}
	}

	if len(rl.limiters) > maxEntries {
		rl.limiters = make(map[string]*rate.Limiter)
		rl.lastSeen = make(map[string]time.Time)
	}
}