1use std::io;
2
3use http::StatusCode;
4
5use crate::{ErrorKind, Status};
6
/// Classifies a failure (an error, status code, or check result) as worth
/// attempting again or not.
pub(crate) trait RetryExt {
    /// Returns `true` if the implementor judges that repeating the operation
    /// that produced `self` may succeed, and `false` if it should not be
    /// retried.
    fn should_retry(&self) -> bool;
}
18
19impl RetryExt for reqwest::StatusCode {
20 fn should_retry(&self) -> bool {
22 self.is_server_error()
23 || self == &StatusCode::REQUEST_TIMEOUT
24 || self == &StatusCode::TOO_MANY_REQUESTS
25 }
26}
27
28impl RetryExt for reqwest::Error {
29 #[allow(clippy::if_same_then_else)]
30 fn should_retry(&self) -> bool {
31 if self.is_timeout() {
32 true
33 } else if self.is_connect() {
34 false
35 } else if self.is_body() || self.is_decode() || self.is_builder() || self.is_redirect() {
36 false
37 } else if self.is_request() {
38 if let Some(hyper_error) = get_source_error_type::<hyper::Error>(&self) {
41 if hyper_error.is_incomplete_message() || hyper_error.is_canceled() {
51 true
52
53 } else if let Some(io_error) = get_source_error_type::<io::Error>(hyper_error) {
56 should_retry_io(io_error)
57 } else {
58 false
59 }
60 } else {
61 false
62 }
63 } else if let Some(status) = self.status() {
64 status.should_retry()
65 } else {
66 false
70 }
71 }
72}
73
74impl RetryExt for http::Error {
75 fn should_retry(&self) -> bool {
76 let inner = self.get_ref();
77 inner
78 .source()
79 .and_then(<dyn std::error::Error + 'static>::downcast_ref)
80 .is_some_and(should_retry_io)
81 }
82}
83
84impl RetryExt for ErrorKind {
85 fn should_retry(&self) -> bool {
86 if let Some(r) = self.reqwest_error() {
88 r.should_retry()
89 } else if let Some(octocrab::Error::Http {
92 source,
93 backtrace: _,
94 }) = self.github_error()
95 {
96 source.should_retry()
97 } else {
98 matches!(
99 self,
100 Self::RejectedStatusCode(StatusCode::TOO_MANY_REQUESTS)
101 )
102 }
103 }
104}
105
106impl RetryExt for Status {
107 fn should_retry(&self) -> bool {
108 match self {
109 Status::Timeout(_) => true,
110 Status::Error(err) => err.should_retry(),
111 Status::Ok(_)
112 | Status::RequestError(_)
113 | Status::Redirected(_, _)
114 | Status::UnknownStatusCode(_)
115 | Status::UnknownMailStatus(_)
116 | Status::Excluded
117 | Status::Unsupported(_)
118 | Status::Cached(_) => false,
119 }
120 }
121}
122
/// Returns `true` when `error` is a connection reset, connection abort, or
/// timeout; every other I/O error kind is not retried.
fn should_retry_io(error: &io::Error) -> bool {
    const RETRYABLE_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::TimedOut,
    ];
    RETRYABLE_KINDS.contains(&error.kind())
}
130
/// Walks the `source()` chain of `err` and returns the first cause that
/// downcasts to `T`.
///
/// The search begins at `err.source()`, so `err` itself is never returned even
/// if it is a `T`. Returns `None` when no cause in the chain has type `T`.
fn get_source_error_type<T: std::error::Error + 'static>(
    err: &dyn std::error::Error,
) -> Option<&T> {
    std::iter::successors(err.source(), |&cause| cause.source())
        .find_map(|cause| cause.downcast_ref::<T>())
}
146
#[cfg(test)]
mod tests {
    use std::{fmt, io};

    use http::StatusCode;

    use super::{get_source_error_type, should_retry_io, RetryExt};

    /// Minimal error type whose `source()` is an optional boxed cause, used to
    /// build error chains for `get_source_error_type`.
    #[derive(Debug)]
    struct Layer(Option<Box<dyn std::error::Error + 'static>>);

    impl fmt::Display for Layer {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("layer")
        }
    }

    impl std::error::Error for Layer {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            self.0.as_deref()
        }
    }

    #[test]
    fn test_should_retry() {
        assert!(StatusCode::REQUEST_TIMEOUT.should_retry());
        assert!(StatusCode::TOO_MANY_REQUESTS.should_retry());
        assert!(!StatusCode::FORBIDDEN.should_retry());
        assert!(StatusCode::INTERNAL_SERVER_ERROR.should_retry());
        // Any 5xx is retried, not just 500.
        assert!(StatusCode::SERVICE_UNAVAILABLE.should_retry());
        // Ordinary client errors are not.
        assert!(!StatusCode::NOT_FOUND.should_retry());
    }

    #[test]
    fn test_should_retry_io() {
        assert!(should_retry_io(&io::Error::from(io::ErrorKind::ConnectionReset)));
        assert!(should_retry_io(&io::Error::from(io::ErrorKind::ConnectionAborted)));
        assert!(should_retry_io(&io::Error::from(io::ErrorKind::TimedOut)));
        assert!(!should_retry_io(&io::Error::from(io::ErrorKind::NotFound)));
        assert!(!should_retry_io(&io::Error::from(io::ErrorKind::ConnectionRefused)));
    }

    #[test]
    fn test_get_source_error_type() {
        let inner = Layer(Some(Box::new(io::Error::from(io::ErrorKind::TimedOut))));
        let outer = Layer(Some(Box::new(inner)));
        // Found even when nested more than one level deep.
        let found = get_source_error_type::<io::Error>(&outer);
        assert_eq!(found.map(io::Error::kind), Some(io::ErrorKind::TimedOut));
        // The error passed in is never itself a match, only its sources.
        let bare = io::Error::from(io::ErrorKind::TimedOut);
        assert!(get_source_error_type::<io::Error>(&bare).is_none());
        // A chain without the requested type yields `None`.
        assert!(get_source_error_type::<io::Error>(&Layer(None)).is_none());
    }
}