lychee_lib/ratelimit/
config.rs1use http::{HeaderMap, HeaderName, HeaderValue};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::collections::hash_map::Iter;
5use std::time::Duration;
6
7use crate::ratelimit::HostKey;
8
/// Concurrency applied when no explicit value is configured.
const DEFAULT_CONCURRENCY: usize = 10;

/// Request interval applied when no explicit value is configured.
const DEFAULT_REQUEST_INTERVAL: Duration = Duration::from_millis(50);
14
15#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
17pub struct RateLimitConfig {
18 #[serde(default = "default_concurrency")]
20 pub concurrency: usize,
21
22 #[serde(default = "default_request_interval", with = "humantime_serde")]
24 pub request_interval: Duration,
25}
26
27impl Default for RateLimitConfig {
28 fn default() -> Self {
29 Self {
30 concurrency: default_concurrency(),
31 request_interval: default_request_interval(),
32 }
33 }
34}
35
36const fn default_concurrency() -> usize {
38 DEFAULT_CONCURRENCY
39}
40
41const fn default_request_interval() -> Duration {
43 DEFAULT_REQUEST_INTERVAL
44}
45
46impl RateLimitConfig {
47 #[must_use]
49 pub fn from_options(concurrency: Option<usize>, request_interval: Option<Duration>) -> Self {
50 Self {
51 concurrency: concurrency.unwrap_or(DEFAULT_CONCURRENCY),
52 request_interval: request_interval.unwrap_or(DEFAULT_REQUEST_INTERVAL),
53 }
54 }
55}
56
57#[derive(Debug, Clone, Default, PartialEq, Deserialize)]
59pub struct HostConfigs(HashMap<HostKey, HostConfig>);
60
61impl HostConfigs {
62 pub(crate) fn get(&self, key: &HostKey) -> Option<&HostConfig> {
64 self.0.get(key)
65 }
66
67 #[cfg(test)]
69 pub(crate) fn len(&self) -> usize {
70 self.0.len()
71 }
72
73 pub(crate) fn iter(&self) -> Iter<'_, HostKey, HostConfig> {
75 self.0.iter()
76 }
77
78 #[must_use]
80 pub fn merge(mut self, other: HostConfigs) -> HostConfigs {
81 for (key, value) in other.0 {
82 let value = if let Some(s) = self.0.remove(&key) {
83 s.merge(value)
84 } else {
85 value
86 };
87
88 self.0.insert(key, value);
89 }
90
91 self
92 }
93}
94
95impl<'a> IntoIterator for &'a HostConfigs {
96 type Item = (&'a HostKey, &'a HostConfig);
97 type IntoIter = Iter<'a, HostKey, HostConfig>;
98 fn into_iter(self) -> Self::IntoIter {
99 self.0.iter()
100 }
101}
102
103impl<const N: usize> From<[(HostKey, HostConfig); N]> for HostConfigs {
104 fn from(arr: [(HostKey, HostConfig); N]) -> Self {
105 HostConfigs(HashMap::<HostKey, HostConfig>::from_iter(arr))
106 }
107}
108
109#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
111#[serde(deny_unknown_fields)]
112pub struct HostConfig {
113 pub concurrency: Option<usize>,
115
116 #[serde(default, with = "humantime_serde")]
118 pub request_interval: Option<Duration>,
119
120 #[serde(default)]
122 #[serde(deserialize_with = "deserialize_headers")]
123 #[serde(serialize_with = "serialize_headers")]
124 pub headers: HeaderMap,
125}
126
127impl Default for HostConfig {
128 fn default() -> Self {
129 Self {
130 concurrency: None,
131 request_interval: None,
132 headers: HeaderMap::new(),
133 }
134 }
135}
136
137impl HostConfig {
138 #[must_use]
140 pub fn effective_concurrency(&self, global_config: &RateLimitConfig) -> usize {
141 self.concurrency.unwrap_or(global_config.concurrency)
142 }
143
144 #[must_use]
146 pub fn effective_request_interval(&self, global_config: &RateLimitConfig) -> Duration {
147 self.request_interval
148 .unwrap_or(global_config.request_interval)
149 }
150
151 #[must_use]
152 pub(crate) fn merge(mut self, other: Self) -> Self {
153 for (k, v) in other.headers {
154 if let Some(k) = k {
155 self.headers.append(k, v);
156 }
157 }
158
159 Self {
160 concurrency: self.concurrency.or(other.concurrency),
161 request_interval: self.request_interval.or(other.request_interval),
162 headers: self.headers,
163 }
164 }
165}
166
167fn deserialize_headers<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
169where
170 D: serde::Deserializer<'de>,
171{
172 let map = HashMap::<String, String>::deserialize(deserializer)?;
173 let mut header_map = HeaderMap::new();
174
175 for (name, value) in map {
176 let header_name = HeaderName::from_bytes(name.as_bytes())
177 .map_err(|e| serde::de::Error::custom(format!("Invalid header name '{name}': {e}")))?;
178 let header_value = HeaderValue::from_str(&value).map_err(|e| {
179 serde::de::Error::custom(format!("Invalid header value '{value}': {e}"))
180 })?;
181 header_map.insert(header_name, header_value);
182 }
183
184 Ok(header_map)
185}
186
187fn serialize_headers<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
189where
190 S: serde::Serializer,
191{
192 let map: HashMap<String, String> = headers
193 .iter()
194 .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
195 .collect();
196 map.serialize(serializer)
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_rate_limit_config() {
        let RateLimitConfig {
            concurrency,
            request_interval,
        } = RateLimitConfig::default();
        assert_eq!(concurrency, 10);
        assert_eq!(request_interval, Duration::from_millis(50));
    }

    #[test]
    fn test_host_config_effective_values() {
        let global = RateLimitConfig::default();

        // A host without overrides inherits the global values.
        let inherited = HostConfig::default();
        assert_eq!(inherited.effective_concurrency(&global), 10);
        assert_eq!(
            inherited.effective_request_interval(&global),
            Duration::from_millis(50)
        );

        // Host-level overrides win over the global values.
        let overridden = HostConfig {
            concurrency: Some(5),
            request_interval: Some(Duration::from_millis(500)),
            headers: HeaderMap::new(),
        };
        assert_eq!(overridden.effective_concurrency(&global), 5);
        assert_eq!(
            overridden.effective_request_interval(&global),
            Duration::from_millis(500)
        );
    }

    #[test]
    fn test_config_serialization() {
        let original = RateLimitConfig {
            concurrency: 15,
            request_interval: Duration::from_millis(200),
        };

        let encoded = toml::to_string(&original).unwrap();
        let decoded: RateLimitConfig = toml::from_str(&encoded).unwrap();

        assert_eq!(decoded.concurrency, original.concurrency);
        assert_eq!(decoded.request_interval, original.request_interval);
    }

    #[test]
    fn test_headers_serialization() {
        let mut headers = HeaderMap::new();
        for (name, value) in [
            ("Authorization", "Bearer token123"),
            ("User-Agent", "test-agent"),
        ] {
            headers.insert(name, value.parse().unwrap());
        }

        let original = HostConfig {
            concurrency: Some(5),
            request_interval: Some(Duration::from_millis(500)),
            headers,
        };

        let encoded = toml::to_string(&original).unwrap();
        let decoded: HostConfig = toml::from_str(&encoded).unwrap();

        assert_eq!(decoded.concurrency, Some(5));
        assert_eq!(decoded.request_interval, Some(Duration::from_millis(500)));

        // Header names come back lowercased, as `HeaderName` normalizes them.
        assert_eq!(decoded.headers.len(), 2);
        for expected in ["authorization", "user-agent"] {
            assert!(decoded.headers.contains_key(expected));
        }
    }
}