Refactor UmpTransformStream to use io::Error instead of ReqwestError

- Replaced all occurrences of `ReqwestError` with `io::Error` in the code.
- Updated the type signature and implementation of `UmpTransformStream` to use `io::Error`.
- Handled errors when calling `read_variable_integer` by returning an error result.
This commit is contained in:
Kavin 2023-12-11 22:51:12 +00:00
parent e158c3aef8
commit 7586ae314b
No known key found for this signature in database
GPG key ID: 6E4598CA5C92C41F

View file

@ -3,7 +3,6 @@ use actix_web::{web, App, HttpRequest, HttpResponse, HttpResponseBuilder, HttpSe
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use qstring::QString; use qstring::QString;
use regex::Regex; use regex::Regex;
use reqwest::Error as ReqwestError;
use reqwest::{Body, Client, Request, Url}; use reqwest::{Body, Client, Request, Url};
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::error::Error; use std::error::Error;
@ -16,7 +15,7 @@ use std::{env, io};
compile_error!("feature \"reqwest-native-tls\" or \"reqwest-rustls\" must be set for proxy to have TLS support"); compile_error!("feature \"reqwest-native-tls\" or \"reqwest-rustls\" must be set for proxy to have TLS support");
use bytes::{Bytes, BytesMut}; use bytes::{Bytes, BytesMut};
use futures_util::Stream; use futures_util::{Stream, TryStreamExt};
#[cfg(any(feature = "webp", feature = "avif", feature = "qhash"))] #[cfg(any(feature = "webp", feature = "avif", feature = "qhash"))]
use tokio::task::spawn_blocking; use tokio::task::spawn_blocking;
@ -431,6 +430,7 @@ async fn index(req: HttpRequest) -> Result<HttpResponse, Box<dyn Error>> {
if req.headers().contains_key("range") { if req.headers().contains_key("range") {
response.status(StatusCode::PARTIAL_CONTENT); response.status(StatusCode::PARTIAL_CONTENT);
} }
let resp = resp.map_err(|e| io::Error::new(ErrorKind::Other, e));
let transformed_stream = UmpTransformStream::new(resp); let transformed_stream = UmpTransformStream::new(resp);
return Ok(response.streaming(transformed_stream)); return Ok(response.streaming(transformed_stream));
} }
@ -494,7 +494,7 @@ fn read_variable_integer(buf: &[u8], offset: usize) -> io::Result<(i32, usize)>
struct UmpTransformStream<S> struct UmpTransformStream<S>
where where
S: Stream<Item = Result<Bytes, ReqwestError>> + Unpin, S: Stream<Item = Result<Bytes, io::Error>> + Unpin,
{ {
inner: S, inner: S,
buffer: BytesMut, buffer: BytesMut,
@ -504,7 +504,7 @@ where
impl<S> UmpTransformStream<S> impl<S> UmpTransformStream<S>
where where
S: Stream<Item = Result<Bytes, ReqwestError>> + Unpin, S: Stream<Item = Result<Bytes, io::Error>> + Unpin,
{ {
pub fn new(stream: S) -> Self { pub fn new(stream: S) -> Self {
UmpTransformStream { UmpTransformStream {
@ -518,9 +518,9 @@ where
impl<S> Stream for UmpTransformStream<S> impl<S> Stream for UmpTransformStream<S>
where where
S: Stream<Item = Result<Bytes, ReqwestError>> + Unpin, S: Stream<Item = Result<Bytes, io::Error>> + Unpin,
{ {
type Item = Result<Bytes, ReqwestError>; type Item = Result<Bytes, io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut(); let this = self.get_mut();
@ -555,8 +555,14 @@ where
} }
if !this.found_stream && !this.buffer.is_empty() { if !this.found_stream && !this.buffer.is_empty() {
let (segment_type, s1) = read_variable_integer(&this.buffer, 0).unwrap(); let (segment_type, s1) = match read_variable_integer(&this.buffer, 0) {
let (segment_length, s2) = read_variable_integer(&this.buffer, s1).unwrap(); Ok(result) => result,
Err(e) => return Poll::Ready(Some(Err(e))),
};
let (segment_length, s2) = match read_variable_integer(&this.buffer, s1) {
Ok(result) => result,
Err(e) => return Poll::Ready(Some(Err(e))),
};
if segment_type != 21 { if segment_type != 21 {
// Not the stream // Not the stream
if this.buffer.len() > s1 + s2 + segment_length as usize { if this.buffer.len() > s1 + s2 + segment_length as usize {