1use std::{collections::HashSet, fmt::Display, hash::BuildHasher, str::FromStr, sync::LazyLock};
2
3use http::StatusCode;
4use serde::{Deserialize, de::Visitor};
5use thiserror::Error;
6
7use crate::{StatusRangeError, types::accept::StatusRange};
8
9pub static DEFAULT_ACCEPTED_STATUS_CODES: LazyLock<HashSet<StatusCode>> =
11 LazyLock::new(|| <HashSet<StatusCode>>::from(StatusCodeSelector::default_accepted()));
12
13#[derive(Debug, Error, PartialEq)]
14pub enum StatusCodeSelectorError {
15 #[error("invalid/empty input")]
16 InvalidInput,
17
18 #[error("failed to parse range: {0}")]
19 RangeError(#[from] StatusRangeError),
20}
21
22#[derive(Clone, Debug, PartialEq)]
25pub struct StatusCodeSelector {
26 ranges: Vec<StatusRange>,
27}
28
29impl FromStr for StatusCodeSelector {
30 type Err = StatusCodeSelectorError;
31
32 fn from_str(input: &str) -> Result<Self, Self::Err> {
33 let input = input.trim();
34
35 if input.is_empty() {
36 return Ok(Self::empty());
37 }
38
39 let ranges = input
40 .split(',')
41 .map(|part| StatusRange::from_str(part.trim()))
42 .collect::<Result<Vec<StatusRange>, StatusRangeError>>()?;
43
44 Ok(Self::new_from(ranges))
45 }
46}
47
48impl Display for StatusCodeSelector {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 let ranges: Vec<_> = self.ranges.iter().map(ToString::to_string).collect();
51 write!(f, "{}", ranges.join(","))
52 }
53}
54
55impl StatusCodeSelector {
56 #[must_use]
58 pub const fn empty() -> Self {
59 Self { ranges: Vec::new() }
60 }
61
62 #[must_use]
65 pub fn default_accepted() -> Self {
66 #[expect(clippy::missing_panics_doc, reason = "infallible")]
67 Self::new_from(vec![
68 StatusRange::new(100, 103).unwrap(),
69 StatusRange::new(200, 299).unwrap(),
70 ])
71 }
72
73 #[must_use]
75 pub fn new_from(ranges: Vec<StatusRange>) -> Self {
76 let mut selector = Self::empty();
77
78 for range in ranges {
79 selector.add_range(range);
80 }
81
82 selector
83 }
84
85 pub fn add_range(&mut self, range: StatusRange) -> &mut Self {
88 if let Some(last) = self.ranges.last_mut()
90 && last.merge(&range)
91 {
92 return self;
93 }
94
95 self.ranges.push(range);
98 self
99 }
100
101 #[must_use]
103 pub fn contains(&self, value: u16) -> bool {
104 self.ranges.iter().any(|range| range.contains(value))
105 }
106
107 #[cfg(test)]
108 pub(crate) const fn len(&self) -> usize {
109 self.ranges.len()
110 }
111}
112
113impl<S: BuildHasher + Default> From<StatusCodeSelector> for HashSet<StatusCode, S> {
114 fn from(value: StatusCodeSelector) -> Self {
115 value
116 .ranges
117 .into_iter()
118 .flat_map(<HashSet<StatusCode>>::from)
119 .collect()
120 }
121}
122
123struct StatusCodeSelectorVisitor;
124
125impl<'de> Visitor<'de> for StatusCodeSelectorVisitor {
126 type Value = StatusCodeSelector;
127
128 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
129 formatter.write_str("a string or a sequence of strings")
130 }
131
132 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
133 where
134 E: serde::de::Error,
135 {
136 StatusCodeSelector::from_str(v).map_err(serde::de::Error::custom)
137 }
138
139 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
140 where
141 E: serde::de::Error,
142 {
143 let value = u16::try_from(v).map_err(serde::de::Error::custom)?;
144 Ok(StatusCodeSelector::new_from(vec![
145 StatusRange::new(value, value).map_err(serde::de::Error::custom)?,
146 ]))
147 }
148
149 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
150 where
151 A: serde::de::SeqAccess<'de>,
152 {
153 let mut selector = StatusCodeSelector::empty();
154 while let Some(v) = seq.next_element::<toml::Value>()? {
155 if let Some(v) = v.as_integer() {
156 let value = u16::try_from(v).map_err(serde::de::Error::custom)?;
157 selector
158 .add_range(StatusRange::new(value, value).map_err(serde::de::Error::custom)?);
159 continue;
160 }
161
162 if let Some(s) = v.as_str() {
163 let range = StatusRange::from_str(s).map_err(serde::de::Error::custom)?;
164 selector.add_range(range);
165 continue;
166 }
167
168 return Err(serde::de::Error::custom(
169 "failed to parse sequence of accept ranges",
170 ));
171 }
172 Ok(selector)
173 }
174}
175
176impl<'de> Deserialize<'de> for StatusCodeSelector {
177 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
178 where
179 D: serde::Deserializer<'de>,
180 {
181 deserializer.deserialize_any(StatusCodeSelectorVisitor)
182 }
183}
184
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    #[rstest]
    #[case("", vec![], vec![100, 110, 150, 200, 300, 175, 350], 0)]
    #[case("100..=150,200..=300", vec![100, 110, 150, 200, 300], vec![175, 350], 2)]
    #[case("200..=300,100..=250", vec![100, 150, 200, 250, 300], vec![350], 1)]
    #[case("100..=200,150..=200", vec![100, 150, 200], vec![250, 300], 1)]
    #[case("100..=200,300", vec![100, 110, 200, 300], vec![250, 350], 2)]
    fn test_from_str(
        #[case] input: &str,
        #[case] valid_values: Vec<u16>,
        #[case] invalid_values: Vec<u16>,
        #[case] length: usize,
    ) {
        let selector = StatusCodeSelector::from_str(input).unwrap();
        assert_eq!(selector.len(), length);

        for valid in valid_values {
            assert!(selector.contains(valid));
        }

        for invalid in invalid_values {
            assert!(!selector.contains(invalid));
        }
    }

    /// A single unparsable element makes the whole parse fail with `RangeError`,
    /// even when the other elements are valid.
    #[rstest]
    #[case("foo")]
    #[case("200,foo")]
    fn test_from_str_invalid(#[case] input: &str) {
        assert!(matches!(
            StatusCodeSelector::from_str(input),
            Err(StatusCodeSelectorError::RangeError(_))
        ));
    }

    #[rstest]
    #[case(r"accept = ['200..204', '429']", vec![200, 203, 429], vec![204, 404], 2)]
    #[case(r"accept = '200..204, 429'", vec![200, 203, 429], vec![204, 404], 2)]
    #[case(r"accept = ['200', '429']", vec![200, 429], vec![404], 2)]
    #[case(r"accept = '200, 429'", vec![200, 429], vec![404], 2)]
    #[case(r"accept = [200, 429]", vec![200, 429], vec![404], 2)]
    #[case(r"accept = '200'", vec![200], vec![404], 1)]
    #[case(r"accept = 200", vec![200], vec![404], 1)]
    fn test_deserialize(
        #[case] input: &str,
        #[case] valid_values: Vec<u16>,
        #[case] invalid_values: Vec<u16>,
        #[case] length: usize,
    ) {
        #[derive(Deserialize)]
        struct Config {
            accept: StatusCodeSelector,
        }

        let config: Config = toml::from_str(input).unwrap();
        assert_eq!(config.accept.len(), length);

        for valid in valid_values {
            assert!(config.accept.contains(valid));
        }

        for invalid in invalid_values {
            assert!(!config.accept.contains(invalid));
        }
    }

    /// Rejects input that is not a range string, a `u16`-convertible integer,
    /// or a sequence of those.
    #[rstest]
    #[case(r"accept = true")]
    #[case(r"accept = -1")]
    #[case(r"accept = [true]")]
    #[case(r"accept = 'foo'")]
    fn test_deserialize_invalid(#[case] input: &str) {
        // A map avoids declaring a struct whose field would never be read.
        let result =
            toml::from_str::<std::collections::HashMap<String, StatusCodeSelector>>(input);
        assert!(result.is_err());
    }

    #[rstest]
    #[case("", "")]
    #[case("100..=150,200..=300", "100..=150,200..=300")]
    #[case("100..=150,300", "100..=150,300..=300")]
    fn test_display(#[case] input: &str, #[case] display: &str) {
        let selector = StatusCodeSelector::from_str(input).unwrap();
        assert_eq!(selector.to_string(), display);
    }

    #[rstest]
    #[case("..=102,200..202,999..", HashSet::from([100, 101, 102, 200, 201, 999]))]
    fn test_into_u16_set(#[case] input: &str, #[case] expected: HashSet<u16>) {
        let actual: HashSet<StatusCode> = StatusCodeSelector::from_str(input).unwrap().into();
        let expected = expected
            .into_iter()
            .map(|v| StatusCode::from_u16(v).unwrap())
            .collect();
        assert_eq!(actual, expected);
    }

    /// The default set must contain exactly `100..=103` and `200..=299`.
    #[test]
    fn test_default_accepted_values() {
        let selector = StatusCodeSelector::default_accepted();
        assert_eq!(selector.len(), 2);
        assert!(selector.contains(103));
        assert!(selector.contains(299));
        assert!(!selector.contains(104));
        assert!(!selector.contains(300));

        let codes = LazyLock::force(&DEFAULT_ACCEPTED_STATUS_CODES);
        // 4 informational + 100 success codes.
        assert_eq!(codes.len(), 104);

        for code in (100..=103).chain(200..=299) {
            assert!(
                codes.contains(&StatusCode::from_u16(code).unwrap()),
                "missing {code}"
            );
        }

        for code in [104, 199, 300, 404, 500] {
            assert!(
                !codes.contains(&StatusCode::from_u16(code).unwrap()),
                "unexpected {code}"
            );
        }
    }
}