package server import ( "net/http" "sync" "time" "golang.org/x/time/rate" ) // RateLimiter provides per-IP rate limiting type RateLimiter struct { limiters map[string]*rate.Limiter mu sync.RWMutex rate rate.Limit burst int } // NewRateLimiter creates a new rate limiter // rate is requests per second, burst is max burst size func NewRateLimiter(r float64, burst int) *RateLimiter { return &RateLimiter{ limiters: make(map[string]*rate.Limiter), rate: rate.Limit(r), burst: burst, } } // getLimiter returns the rate limiter for a given IP, creating one if needed func (rl *RateLimiter) getLimiter(ip string) *rate.Limiter { rl.mu.RLock() limiter, exists := rl.limiters[ip] rl.mu.RUnlock() if exists { return limiter } rl.mu.Lock() defer rl.mu.Unlock() // Double-check after acquiring write lock if limiter, exists = rl.limiters[ip]; exists { return limiter } limiter = rate.NewLimiter(rl.rate, rl.burst) rl.limiters[ip] = limiter return limiter } // Middleware returns a middleware handler for rate limiting func (rl *RateLimiter) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Get client IP (chi's RealIP middleware should have set this) ip := r.RemoteAddr limiter := rl.getLimiter(ip) if !limiter.Allow() { w.Header().Set("Retry-After", "1") http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) } // CleanupOldEntries removes stale IP entries periodically // Call this in a goroutine to prevent memory growth func (rl *RateLimiter) CleanupOldEntries(interval time.Duration, maxAge time.Duration) { ticker := time.NewTicker(interval) defer ticker.Stop() for range ticker.C { rl.mu.Lock() // Simple cleanup: just reset the map periodically // In a more sophisticated implementation, you'd track last access time if len(rl.limiters) > 10000 { rl.limiters = make(map[string]*rate.Limiter) } rl.mu.Unlock() } }