Skip to main content

lychee_lib/
retry.rs

1use std::io;
2
3use http::StatusCode;
4
5use crate::{ErrorKind, Status};
6
/// Extension trait answering one question: is it worth retrying the HTTP
/// request that produced this value?
///
/// Adapted from `Retryable` in [reqwest-middleware].
/// The code is vendored so that we do not have to depend on
/// `reqwest-middleware` and can freely tailor the retry rules.
///
/// [reqwest-middleware]: https://github.com/TrueLayer/reqwest-middleware/blob/f854725791ccf4a02c401a26cab3d9db753f468c/reqwest-retry/src/retryable.rs
pub(crate) trait RetryExt {
    /// Returns `true` when the request associated with `self` should be
    /// attempted again, and `false` when retrying is not warranted.
    fn should_retry(&self) -> bool;
}
18
19impl RetryExt for reqwest::StatusCode {
20    /// Try to map a `reqwest` response into `Retryable`.
21    fn should_retry(&self) -> bool {
22        self.is_server_error()
23            || self == &StatusCode::REQUEST_TIMEOUT
24            || self == &StatusCode::TOO_MANY_REQUESTS
25    }
26}
27
28impl RetryExt for reqwest::Error {
29    #[allow(clippy::if_same_then_else)]
30    fn should_retry(&self) -> bool {
31        if self.is_timeout() {
32            true
33        } else if self.is_connect() {
34            false
35        } else if self.is_body() || self.is_decode() || self.is_builder() || self.is_redirect() {
36            false
37        } else if self.is_request() {
38            // It seems that hyper::Error(IncompleteMessage) is not correctly handled by reqwest.
39            // Here we check if the Reqwest error was originated by hyper and map it consistently.
40            if let Some(hyper_error) = get_source_error_type::<hyper::Error>(&self) {
41                // The hyper::Error(IncompleteMessage) is raised if the HTTP
42                // response is well formatted but does not contain all the
43                // bytes. This can happen when the server has started sending
44                // back the response but the connection is cut halfway through.
45                // We can safely retry the call, hence marking this error as
46                // transient.
47                //
48                // Instead hyper::Error(Canceled) is raised when the connection is
49                // gracefully closed on the server side.
50                if hyper_error.is_incomplete_message() || hyper_error.is_canceled() {
51                    true
52
53                // Try and downcast the hyper error to [`io::Error`] if that is the
54                // underlying error, and try and classify it.
55                } else if let Some(io_error) = get_source_error_type::<io::Error>(hyper_error) {
56                    should_retry_io(io_error)
57                } else {
58                    false
59                }
60            } else {
61                false
62            }
63        } else if let Some(status) = self.status() {
64            status.should_retry()
65        } else {
66            // We omit checking if error.is_status() since we check that already.
67            // However, if Response::error_for_status is used the status will still
68            // remain in the response object.
69            false
70        }
71    }
72}
73
74impl RetryExt for http::Error {
75    fn should_retry(&self) -> bool {
76        let inner = self.get_ref();
77        inner
78            .source()
79            .and_then(<dyn std::error::Error + 'static>::downcast_ref)
80            .is_some_and(should_retry_io)
81    }
82}
83
84impl RetryExt for ErrorKind {
85    fn should_retry(&self) -> bool {
86        // If the error is a `reqwest::Error`, delegate to that
87        if let Some(r) = self.reqwest_error() {
88            r.should_retry()
89        // GitHub errors sometimes wrap `reqwest` errors.
90        // In that case, delegate to the underlying error.
91        } else if let Some(octocrab::Error::Http {
92            source,
93            backtrace: _,
94        }) = self.github_error()
95        {
96            source.should_retry()
97        } else {
98            matches!(
99                self,
100                Self::RejectedStatusCode(StatusCode::TOO_MANY_REQUESTS)
101            )
102        }
103    }
104}
105
106impl RetryExt for Status {
107    fn should_retry(&self) -> bool {
108        match self {
109            Status::Timeout(_) => true,
110            Status::Error(err) => err.should_retry(),
111            Status::Ok(_)
112            | Status::RequestError(_)
113            | Status::Redirected(_, _)
114            | Status::UnknownStatusCode(_)
115            | Status::UnknownMailStatus(_)
116            | Status::Excluded
117            | Status::Unsupported(_)
118            | Status::Cached(_) => false,
119        }
120    }
121}
122
/// Tells whether an `io::Error` is retryable.
///
/// Only connection resets, aborted connections and timeouts are retried;
/// every other error kind is not.
fn should_retry_io(error: &io::Error) -> bool {
    const RETRYABLE_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::TimedOut,
    ];

    RETRYABLE_KINDS.contains(&error.kind())
}
130
/// Searches the source chain of `err` for an error of type `T`.
///
/// The chain is followed via `source()`, starting at the direct source of
/// `err`; `err` itself is never a candidate. The first error that downcasts
/// to `T` is returned, or `None` if the chain ends without a match.
fn get_source_error_type<T: std::error::Error + 'static>(
    err: &dyn std::error::Error,
) -> Option<&T> {
    std::iter::successors(err.source(), |cause| (*cause).source())
        .find_map(|cause| cause.downcast_ref::<T>())
}
146
#[cfg(test)]
mod tests {
    use http::StatusCode;

    use super::RetryExt;

    /// Checks the retry decision for a few representative status codes.
    #[test]
    fn test_should_retry() {
        let cases = [
            (StatusCode::REQUEST_TIMEOUT, true),
            (StatusCode::TOO_MANY_REQUESTS, true),
            (StatusCode::FORBIDDEN, false),
            (StatusCode::INTERNAL_SERVER_ERROR, true),
        ];

        for (status, expected) in cases {
            assert_eq!(
                status.should_retry(),
                expected,
                "unexpected retry decision for {status}"
            );
        }
    }
}