Skip to main content

lychee_lib/types/
status_code_selector.rs

1use std::{collections::HashSet, fmt::Display, hash::BuildHasher, str::FromStr, sync::LazyLock};
2
3use http::StatusCode;
4use serde::{Deserialize, de::Visitor};
5use thiserror::Error;
6
7use crate::{StatusRangeError, types::accept::StatusRange};
8
9/// These values are the default status codes which are accepted by lychee.
10pub static DEFAULT_ACCEPTED_STATUS_CODES: LazyLock<HashSet<StatusCode>> =
11    LazyLock::new(|| <HashSet<StatusCode>>::from(StatusCodeSelector::default_accepted()));
12
13#[derive(Debug, Error, PartialEq)]
14pub enum StatusCodeSelectorError {
15    #[error("invalid/empty input")]
16    InvalidInput,
17
18    #[error("failed to parse range: {0}")]
19    RangeError(#[from] StatusRangeError),
20}
21
22/// A [`StatusCodeSelector`] holds ranges of HTTP status codes, and determines
23/// whether a specific code is matched.
24#[derive(Clone, Debug, PartialEq)]
25pub struct StatusCodeSelector {
26    ranges: Vec<StatusRange>,
27}
28
29impl FromStr for StatusCodeSelector {
30    type Err = StatusCodeSelectorError;
31
32    fn from_str(input: &str) -> Result<Self, Self::Err> {
33        let input = input.trim();
34
35        if input.is_empty() {
36            return Ok(Self::empty());
37        }
38
39        let ranges = input
40            .split(',')
41            .map(|part| StatusRange::from_str(part.trim()))
42            .collect::<Result<Vec<StatusRange>, StatusRangeError>>()?;
43
44        Ok(Self::new_from(ranges))
45    }
46}
47
48impl Display for StatusCodeSelector {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        let ranges: Vec<_> = self.ranges.iter().map(ToString::to_string).collect();
51        write!(f, "{}", ranges.join(","))
52    }
53}
54
55impl StatusCodeSelector {
56    /// Creates a new empty selector
57    #[must_use]
58    pub const fn empty() -> Self {
59        Self { ranges: Vec::new() }
60    }
61
62    /// Create a new selector where 100..=103 and 200..300 are selected.
63    /// These are the status codes which are accepted by default.
64    #[must_use]
65    pub fn default_accepted() -> Self {
66        #[expect(clippy::missing_panics_doc, reason = "infallible")]
67        Self::new_from(vec![
68            StatusRange::new(100, 103).unwrap(),
69            StatusRange::new(200, 299).unwrap(),
70        ])
71    }
72
73    /// Creates a new [`StatusCodeSelector`] prefilled with `ranges`.
74    #[must_use]
75    pub fn new_from(ranges: Vec<StatusRange>) -> Self {
76        let mut selector = Self::empty();
77
78        for range in ranges {
79            selector.add_range(range);
80        }
81
82        selector
83    }
84
85    /// Adds a range of HTTP status codes to this [`StatusCodeSelector`].
86    /// This method merges the new and existing ranges if they overlap.
87    pub fn add_range(&mut self, range: StatusRange) -> &mut Self {
88        // Merge with previous range if possible
89        if let Some(last) = self.ranges.last_mut()
90            && last.merge(&range)
91        {
92            return self;
93        }
94
95        // If neither is the case, the ranges have no overlap at all. Just add
96        // to the list of ranges.
97        self.ranges.push(range);
98        self
99    }
100
101    /// Returns whether this [`StatusCodeSelector`] contains `value`.
102    #[must_use]
103    pub fn contains(&self, value: u16) -> bool {
104        self.ranges.iter().any(|range| range.contains(value))
105    }
106
107    #[cfg(test)]
108    pub(crate) const fn len(&self) -> usize {
109        self.ranges.len()
110    }
111}
112
113impl<S: BuildHasher + Default> From<StatusCodeSelector> for HashSet<StatusCode, S> {
114    fn from(value: StatusCodeSelector) -> Self {
115        value
116            .ranges
117            .into_iter()
118            .flat_map(<HashSet<StatusCode>>::from)
119            .collect()
120    }
121}
122
123struct StatusCodeSelectorVisitor;
124
125impl<'de> Visitor<'de> for StatusCodeSelectorVisitor {
126    type Value = StatusCodeSelector;
127
128    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
129        formatter.write_str("a string or a sequence of strings")
130    }
131
132    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
133    where
134        E: serde::de::Error,
135    {
136        StatusCodeSelector::from_str(v).map_err(serde::de::Error::custom)
137    }
138
139    fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
140    where
141        E: serde::de::Error,
142    {
143        let value = u16::try_from(v).map_err(serde::de::Error::custom)?;
144        Ok(StatusCodeSelector::new_from(vec![
145            StatusRange::new(value, value).map_err(serde::de::Error::custom)?,
146        ]))
147    }
148
149    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
150    where
151        A: serde::de::SeqAccess<'de>,
152    {
153        let mut selector = StatusCodeSelector::empty();
154        while let Some(v) = seq.next_element::<toml::Value>()? {
155            if let Some(v) = v.as_integer() {
156                let value = u16::try_from(v).map_err(serde::de::Error::custom)?;
157                selector
158                    .add_range(StatusRange::new(value, value).map_err(serde::de::Error::custom)?);
159                continue;
160            }
161
162            if let Some(s) = v.as_str() {
163                let range = StatusRange::from_str(s).map_err(serde::de::Error::custom)?;
164                selector.add_range(range);
165                continue;
166            }
167
168            return Err(serde::de::Error::custom(
169                "failed to parse sequence of accept ranges",
170            ));
171        }
172        Ok(selector)
173    }
174}
175
176impl<'de> Deserialize<'de> for StatusCodeSelector {
177    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
178    where
179        D: serde::Deserializer<'de>,
180    {
181        deserializer.deserialize_any(StatusCodeSelectorVisitor)
182    }
183}
184
#[cfg(test)]
mod test {
    use super::*;
    use rstest::rstest;

    /// Parsing from a string yields the expected number of (merged) ranges
    /// and matches exactly the listed codes.
    #[rstest]
    #[case("", vec![], vec![100, 110, 150, 200, 300, 175, 350], 0)]
    #[case("100..=150,200..=300", vec![100, 110, 150, 200, 300], vec![175, 350], 2)]
    #[case("200..=300,100..=250", vec![100, 150, 200, 250, 300], vec![350], 1)]
    #[case("100..=200,150..=200", vec![100, 150, 200], vec![250, 300], 1)]
    #[case("100..=200,300", vec![100, 110, 200, 300], vec![250, 350], 2)]
    fn test_from_str(
        #[case] input: &str,
        #[case] valid_values: Vec<u16>,
        #[case] invalid_values: Vec<u16>,
        #[case] length: usize,
    ) {
        let parsed: StatusCodeSelector = input.parse().unwrap();
        assert_eq!(parsed.len(), length);

        for code in &valid_values {
            assert!(parsed.contains(*code));
        }

        for code in &invalid_values {
            assert!(!parsed.contains(*code));
        }
    }

    /// Every supported TOML shape (string, integer, list of either)
    /// deserializes into the expected selector.
    #[rstest]
    #[case(r"accept = ['200..204', '429']", vec![200, 203, 429], vec![204, 404], 2)]
    #[case(r"accept = '200..204, 429'", vec![200, 203, 429], vec![204, 404], 2)]
    #[case(r"accept = ['200', '429']", vec![200, 429], vec![404], 2)]
    #[case(r"accept = '200, 429'", vec![200, 429], vec![404], 2)]
    #[case(r"accept = [200, 429]", vec![200, 429], vec![404], 2)]
    #[case(r"accept = '200'", vec![200], vec![404], 1)]
    #[case(r"accept = 200", vec![200], vec![404], 1)]
    fn test_deserialize(
        #[case] input: &str,
        #[case] valid_values: Vec<u16>,
        #[case] invalid_values: Vec<u16>,
        #[case] length: usize,
    ) {
        #[derive(Deserialize)]
        struct Config {
            accept: StatusCodeSelector,
        }

        let Config { accept } = toml::from_str(input).unwrap();
        assert_eq!(accept.len(), length);

        for code in &valid_values {
            assert!(accept.contains(*code));
        }

        for code in &invalid_values {
            assert!(!accept.contains(*code));
        }
    }

    /// `Display` prints every range in inclusive form, comma-separated.
    #[rstest]
    #[case("100..=150,200..=300", "100..=150,200..=300")]
    #[case("100..=150,300", "100..=150,300..=300")]
    fn test_display(#[case] input: &str, #[case] display: &str) {
        let rendered = StatusCodeSelector::from_str(input).unwrap().to_string();
        assert_eq!(rendered, display);
    }

    /// Converting into a set expands the ranges into individual status codes.
    #[rstest]
    #[case("..=102,200..202,999..", HashSet::from([100, 101, 102, 200, 201,999]))]
    fn test_into_u16_set(#[case] input: &str, #[case] expected: HashSet<u16>) {
        let selector = StatusCodeSelector::from_str(input).unwrap();
        let actual = HashSet::<StatusCode>::from(selector);

        let expected: HashSet<StatusCode> = expected
            .into_iter()
            .map(|code| StatusCode::from_u16(code).unwrap())
            .collect();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_default_accepted_values() {
        // Forcing the lazy value must not panic.
        let _codes: &HashSet<StatusCode> = LazyLock::force(&DEFAULT_ACCEPTED_STATUS_CODES);
    }
}