Skip to main content

lychee_lib/ratelimit/
config.rs

1use http::{HeaderMap, HeaderName, HeaderValue};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::collections::hash_map::Iter;
5use std::time::Duration;
6
7use crate::ratelimit::HostKey;
8
/// Default number of concurrent requests per host.
///
/// Fallback for [`RateLimitConfig::concurrency`] when no value is configured.
const DEFAULT_CONCURRENCY: usize = 10;

/// Default interval between requests to the same host.
///
/// Fallback for [`RateLimitConfig::request_interval`] when no value is configured.
const DEFAULT_REQUEST_INTERVAL: Duration = Duration::from_millis(50);
14
15/// Global rate limiting configuration that applies as defaults to all hosts
16#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
17pub struct RateLimitConfig {
18    /// Default maximum concurrent requests per host
19    #[serde(default = "default_concurrency")]
20    pub concurrency: usize,
21
22    /// Default minimum interval between requests to the same host
23    #[serde(default = "default_request_interval", with = "humantime_serde")]
24    pub request_interval: Duration,
25}
26
27impl Default for RateLimitConfig {
28    fn default() -> Self {
29        Self {
30            concurrency: default_concurrency(),
31            request_interval: default_request_interval(),
32        }
33    }
34}
35
/// Default number of concurrent requests per host.
///
/// Exists as a function because `#[serde(default = "...")]` must name a
/// callable path, not a constant.
const fn default_concurrency() -> usize {
    DEFAULT_CONCURRENCY
}

/// Default interval between requests to the same host.
///
/// Exists as a function because `#[serde(default = "...")]` must name a
/// callable path, not a constant.
const fn default_request_interval() -> Duration {
    DEFAULT_REQUEST_INTERVAL
}
45
46impl RateLimitConfig {
47    /// Create a `RateLimitConfig` from CLI options, using defaults for missing values
48    #[must_use]
49    pub fn from_options(concurrency: Option<usize>, request_interval: Option<Duration>) -> Self {
50        Self {
51            concurrency: concurrency.unwrap_or(DEFAULT_CONCURRENCY),
52            request_interval: request_interval.unwrap_or(DEFAULT_REQUEST_INTERVAL),
53        }
54    }
55}
56
/// Per-host configuration overrides, keyed by [`HostKey`].
///
/// Two sets of overrides can be combined with [`HostConfigs::merge`].
#[derive(Debug, Clone, Default, PartialEq, Deserialize)]
pub struct HostConfigs(HashMap<HostKey, HostConfig>);
60
61impl HostConfigs {
62    /// Get a reference to the [`HostConfig`] associated to the [`HostKey`]
63    pub(crate) fn get(&self, key: &HostKey) -> Option<&HostConfig> {
64        self.0.get(key)
65    }
66
67    /// Get the number of [`HostConfig`]s
68    #[cfg(test)]
69    pub(crate) fn len(&self) -> usize {
70        self.0.len()
71    }
72
73    /// Get the iterator over all elements
74    pub(crate) fn iter(&self) -> Iter<'_, HostKey, HostConfig> {
75        self.0.iter()
76    }
77
78    /// Merge `self` with another `HostConfigs`
79    #[must_use]
80    pub fn merge(mut self, other: HostConfigs) -> HostConfigs {
81        for (key, value) in other.0 {
82            let value = if let Some(s) = self.0.remove(&key) {
83                s.merge(value)
84            } else {
85                value
86            };
87
88            self.0.insert(key, value);
89        }
90
91        self
92    }
93}
94
95impl<'a> IntoIterator for &'a HostConfigs {
96    type Item = (&'a HostKey, &'a HostConfig);
97    type IntoIter = Iter<'a, HostKey, HostConfig>;
98    fn into_iter(self) -> Self::IntoIter {
99        self.0.iter()
100    }
101}
102
103impl<const N: usize> From<[(HostKey, HostConfig); N]> for HostConfigs {
104    fn from(arr: [(HostKey, HostConfig); N]) -> Self {
105        HostConfigs(HashMap::<HostKey, HostConfig>::from_iter(arr))
106    }
107}
108
109/// Configuration for a specific host's rate limiting behavior
110#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
111#[serde(deny_unknown_fields)]
112pub struct HostConfig {
113    /// Maximum concurrent requests allowed to this host
114    pub concurrency: Option<usize>,
115
116    /// Minimum interval between requests to this host
117    #[serde(default, with = "humantime_serde")]
118    pub request_interval: Option<Duration>,
119
120    /// Custom headers to send with requests to this host
121    #[serde(default)]
122    #[serde(deserialize_with = "deserialize_headers")]
123    #[serde(serialize_with = "serialize_headers")]
124    pub headers: HeaderMap,
125}
126
127impl Default for HostConfig {
128    fn default() -> Self {
129        Self {
130            concurrency: None,
131            request_interval: None,
132            headers: HeaderMap::new(),
133        }
134    }
135}
136
137impl HostConfig {
138    /// Get the effective maximum concurrency, falling back to the global default
139    #[must_use]
140    pub fn effective_concurrency(&self, global_config: &RateLimitConfig) -> usize {
141        self.concurrency.unwrap_or(global_config.concurrency)
142    }
143
144    /// Get the effective request interval, falling back to the global default
145    #[must_use]
146    pub fn effective_request_interval(&self, global_config: &RateLimitConfig) -> Duration {
147        self.request_interval
148            .unwrap_or(global_config.request_interval)
149    }
150
151    #[must_use]
152    pub(crate) fn merge(mut self, other: Self) -> Self {
153        for (k, v) in other.headers {
154            if let Some(k) = k {
155                self.headers.append(k, v);
156            }
157        }
158
159        Self {
160            concurrency: self.concurrency.or(other.concurrency),
161            request_interval: self.request_interval.or(other.request_interval),
162            headers: self.headers,
163        }
164    }
165}
166
167/// Custom deserializer for headers from TOML config format
168fn deserialize_headers<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
169where
170    D: serde::Deserializer<'de>,
171{
172    let map = HashMap::<String, String>::deserialize(deserializer)?;
173    let mut header_map = HeaderMap::new();
174
175    for (name, value) in map {
176        let header_name = HeaderName::from_bytes(name.as_bytes())
177            .map_err(|e| serde::de::Error::custom(format!("Invalid header name '{name}': {e}")))?;
178        let header_value = HeaderValue::from_str(&value).map_err(|e| {
179            serde::de::Error::custom(format!("Invalid header value '{value}': {e}"))
180        })?;
181        header_map.insert(header_name, header_value);
182    }
183
184    Ok(header_map)
185}
186
187/// Custom serializer for headers to TOML config format
188fn serialize_headers<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
189where
190    S: serde::Serializer,
191{
192    let map: HashMap<String, String> = headers
193        .iter()
194        .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
195        .collect();
196    map.serialize(serializer)
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in defaults are 10 concurrent requests and 50 ms spacing.
    #[test]
    fn test_default_rate_limit_config() {
        let RateLimitConfig {
            concurrency,
            request_interval,
        } = RateLimitConfig::default();

        assert_eq!(concurrency, 10);
        assert_eq!(request_interval, Duration::from_millis(50));
    }

    /// Per-host values win when set and fall back to the global config otherwise.
    #[test]
    fn test_host_config_effective_values() {
        let global = RateLimitConfig::default();

        // Without overrides, every value is inherited from the global config.
        let inherited = HostConfig::default();
        assert_eq!(inherited.effective_concurrency(&global), 10);
        assert_eq!(
            inherited.effective_request_interval(&global),
            Duration::from_millis(50)
        );

        // Explicit overrides take precedence over the global config.
        let overridden = HostConfig {
            concurrency: Some(5),
            request_interval: Some(Duration::from_millis(500)),
            headers: HeaderMap::new(),
        };
        assert_eq!(overridden.effective_concurrency(&global), 5);
        assert_eq!(
            overridden.effective_request_interval(&global),
            Duration::from_millis(500)
        );
    }

    /// A `RateLimitConfig` survives a TOML round-trip unchanged.
    #[test]
    fn test_config_serialization() {
        let original = RateLimitConfig {
            concurrency: 15,
            request_interval: Duration::from_millis(200),
        };

        let serialized = toml::to_string(&original).unwrap();
        let round_tripped: RateLimitConfig = toml::from_str(&serialized).unwrap();

        assert_eq!(round_tripped.concurrency, original.concurrency);
        assert_eq!(round_tripped.request_interval, original.request_interval);
    }

    /// A `HostConfig` with custom headers survives a TOML round-trip.
    #[test]
    fn test_headers_serialization() {
        let mut headers = HeaderMap::new();
        for (name, value) in [
            ("Authorization", "Bearer token123"),
            ("User-Agent", "test-agent"),
        ] {
            headers.insert(name, value.parse().unwrap());
        }

        let original = HostConfig {
            concurrency: Some(5),
            request_interval: Some(Duration::from_millis(500)),
            headers,
        };

        let serialized = toml::to_string(&original).unwrap();
        let round_tripped: HostConfig = toml::from_str(&serialized).unwrap();

        assert_eq!(round_tripped.concurrency, Some(5));
        assert_eq!(
            round_tripped.request_interval,
            Some(Duration::from_millis(500))
        );
        assert_eq!(round_tripped.headers.len(), 2);
        for expected in ["authorization", "user-agent"] {
            assert!(round_tripped.headers.contains_key(expected));
        }
    }
}