From e291ed4e190e0f1d29c58c7d1abaa89369b3ab90 Mon Sep 17 00:00:00 2001 From: Kavin <20838718+FireMasterK@users.noreply.github.com> Date: Thu, 14 Dec 2023 13:37:44 +0000 Subject: [PATCH] Refactor utils and UMP transforming code to new files. --- src/main.rs | 260 +++------------------------------------------- src/ump_stream.rs | 154 +++++++++++++++++++++++++++ src/utils.rs | 96 +++++++++++++++++ 3 files changed, 262 insertions(+), 248 deletions(-) create mode 100644 src/ump_stream.rs create mode 100644 src/utils.rs diff --git a/src/main.rs b/src/main.rs index 44ba06a..02f9940 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,24 +1,23 @@ +mod ump_stream; +mod utils; + use actix_web::http::{Method, StatusCode}; use actix_web::{web, App, HttpRequest, HttpResponse, HttpResponseBuilder, HttpServer}; use once_cell::sync::Lazy; use qstring::QString; use regex::Regex; use reqwest::{Body, Client, Request, Url}; -use std::borrow::Cow; -use std::collections::BTreeMap; use std::error::Error; use std::io::ErrorKind; -use std::pin::Pin; -use std::task::{Context, Poll}; use std::{env, io}; #[cfg(not(any(feature = "reqwest-native-tls", feature = "reqwest-rustls")))] compile_error!("feature \"reqwest-native-tls\" or \"reqwest-rustls\" must be set for proxy to have TLS support"); -use bytes::{Bytes, BytesMut}; -use futures_util::{Stream, TryStreamExt}; +use futures_util::{TryStreamExt}; #[cfg(any(feature = "webp", feature = "avif", feature = "qhash"))] use tokio::task::spawn_blocking; +use ump_stream::UmpTransformStream; #[cfg(feature = "mimalloc")] #[global_allocator] @@ -406,11 +405,13 @@ async fn index(req: HttpRequest) -> Result> { if let Some(captures) = captures { let url = captures.get(1).unwrap().as_str(); if url.starts_with("https://") { - return line - .replace(url, localize_url(url, host.as_str()).as_str()); + return line.replace( + url, + utils::localize_url(url, host.as_str()).as_str(), + ); } } - localize_url(line, host.as_str()) + utils::localize_url(line, host.as_str()) }) .collect::>() .join("\n"); @@ -423,8 +424,8 @@ async fn index(req: HttpRequest) -> Result> { let captures = RE_DASH_MANIFEST.captures_iter(&resp_str); for capture in captures { let url = capture.get(1).unwrap().as_str(); - let new_url = localize_url(url, host.as_str()); - let new_url = escape_xml(new_url.as_str()); + let new_url = utils::localize_url(url, host.as_str()); + let new_url = utils::escape_xml(new_url.as_str()); new_resp = new_resp.replace(url, new_url.as_ref()); } return Ok(response.body(new_resp)); @@ -453,240 +454,3 @@ async fn index(req: HttpRequest) -> Result> { // Stream response Ok(response.streaming(resp)) } - -fn read_buf(buf: &[u8], pos: &mut usize) -> u8 { - let byte = buf[*pos]; - *pos += 1; - byte -} - -fn read_variable_integer(buf: &[u8], offset: usize) -> io::Result<(i32, usize)> { - let mut pos = offset; - let prefix = read_buf(buf, &mut pos); - let mut size = 0; - for shift in 1..=5 { - if prefix & (128 >> (shift - 1)) == 0 { - size = shift; - break; - } - } - if !(1..=5).contains(&size) { - return Err(io::Error::new( - ErrorKind::InvalidData, - format!("Invalid integer size {} at position {}", size, offset), - )); - } - - match size { - 1 => Ok((prefix as i32, size)), - 2 => { - let value = ((read_buf(buf, &mut pos) as i32) << 6) | (prefix as i32 & 0b111111); - Ok((value, size)) - } - 3 => { - let value = - (((read_buf(buf, &mut pos) as i32) | ((read_buf(buf, &mut pos) as i32) << 8)) << 5) - | (prefix as i32 & 0b11111); - Ok((value, size)) - } - 4 => { - let value = (((read_buf(buf, &mut pos) as i32) - | ((read_buf(buf, &mut pos) as i32) << 8) - | ((read_buf(buf, &mut pos) as i32) << 16)) - << 4) - | (prefix as i32 & 0b1111); - Ok((value, size)) - } - _ => { - let value = (read_buf(buf, &mut pos) as i32) - | ((read_buf(buf, &mut pos) as i32) << 8) - | ((read_buf(buf, &mut pos) as i32) << 16) - | ((read_buf(buf, &mut pos) as i32) << 24); - Ok((value, size)) - } - } -} - -struct UmpTransformStream -where - S: Stream> + Unpin, -{ - inner: S, - buffer: BytesMut, - found_stream: bool, - remaining: usize, -} - -impl UmpTransformStream -where - S: Stream> + Unpin, -{ - pub fn new(stream: S) -> Self { - UmpTransformStream { - inner: stream, - buffer: BytesMut::new(), - found_stream: false, - remaining: 0, - } - } -} - -impl Stream for UmpTransformStream -where - S: Stream> + Unpin, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.get_mut(); - - while let Poll::Ready(item) = Pin::new(&mut this.inner).poll_next(cx) { - match item { - Some(Ok(bytes)) => { - if this.found_stream { - if this.remaining > 0 { - let len = std::cmp::min(this.remaining, bytes.len()); - this.remaining -= len; - if this.remaining == 0 { - this.buffer.clear(); - this.buffer.extend_from_slice(&bytes[len..]); - this.found_stream = false; - } - return Poll::Ready(Some(Ok(bytes.slice(0..len)))); - } else { - this.found_stream = false; - this.buffer.clear(); - this.buffer.extend_from_slice(&bytes); - }; - } else { - this.buffer.extend_from_slice(&bytes); - } - } - Some(Err(e)) => return Poll::Ready(Some(Err(e))), - None => { - return Poll::Ready(None); - } - } - } - - if !this.found_stream && !this.buffer.is_empty() { - let (segment_type, s1) = match read_variable_integer(&this.buffer, 0) { - Ok(result) => result, - Err(e) => return Poll::Ready(Some(Err(e))), - }; - let (segment_length, s2) = match read_variable_integer(&this.buffer, s1) { - Ok(result) => result, - Err(e) => return Poll::Ready(Some(Err(e))), - }; - if segment_type != 21 { - // Not the stream - if this.buffer.len() > s1 + s2 + segment_length as usize { - let _ = this.buffer.split_to(s1 + s2 + segment_length as usize); - } - } else { - this.remaining = segment_length as usize - 1; - - let _ = this.buffer.split_to(s1 + s2 + 1); - - if this.buffer.len() > segment_length as usize { - let len = std::cmp::min(this.remaining, this.buffer.len()); - this.remaining -= len; - - return Poll::Ready(Some(Ok(this.buffer.split_to(len).into()))); - } else { - this.remaining -= this.buffer.len(); - this.found_stream = true; - - return Poll::Ready(Some(Ok(this.buffer.to_vec().into()))); - } - } - } - - Poll::Pending - } -} - -fn finalize_url(path: &str, query: BTreeMap) -> String { - #[cfg(feature = "qhash")] - { - use std::collections::BTreeSet; - - let qhash = { - let secret = env::var("HASH_SECRET"); - if let Ok(secret) = secret { - let set = query - .iter() - .filter(|(key, _)| !matches!(key.as_str(), "qhash" | "range" | "rewrite")) - .map(|(key, value)| (key.as_bytes().to_owned(), value.as_bytes().to_owned())) - .collect::>(); - - let mut hasher = blake3::Hasher::new(); - - for (key, value) in set { - hasher.update(&key); - hasher.update(&value); - } - - hasher.update(path.as_bytes()); - - hasher.update(secret.as_bytes()); - - let hash = hasher.finalize().to_hex(); - - Some(hash[..8].to_owned()) - } else { - None - } - }; - - if qhash.is_some() { - let mut query = QString::new(query.into_iter().collect::>()); - query.add_pair(("qhash", qhash.unwrap())); - return format!("{}?{}", path, query); - } - } - - let query = QString::new(query.into_iter().collect::>()); - format!("{}?{}", path, query) -} - -pub fn escape_xml(raw: &str) -> Cow<'_, str> { - if !raw.contains(&['<', '>', '&', '\'', '"'][..]) { - // If there are no characters to escape, return the original string. - Cow::Borrowed(raw) - } else { - // If there are characters to escape, build a new string with the replacements. - let mut escaped = String::with_capacity(raw.len()); - for c in raw.chars() { - match c { - '<' => escaped.push_str("<"), - '>' => escaped.push_str(">"), - '&' => escaped.push_str("&"), - '\'' => escaped.push_str("'"), - '"' => escaped.push_str("""), - _ => escaped.push(c), - } - } - Cow::Owned(escaped) - } -} - -fn localize_url(url: &str, host: &str) -> String { - if url.starts_with("https://") { - let url = Url::parse(url).unwrap(); - let host = url.host().unwrap().to_string(); - - let mut query = url.query_pairs().into_owned().collect::>(); - - query.insert("host".to_string(), host.clone()); - - return finalize_url(url.path(), query); - } else if url.ends_with(".m3u8") || url.ends_with(".ts") { - let mut query = BTreeMap::new(); - query.insert("host".to_string(), host.to_string()); - - return finalize_url(url, query); - } - - url.to_string() -} diff --git a/src/ump_stream.rs b/src/ump_stream.rs new file mode 100644 index 0000000..9e87b72 --- /dev/null +++ b/src/ump_stream.rs @@ -0,0 +1,154 @@ +use crate::utils; +use bytes::{Bytes, BytesMut}; +use futures_util::Stream; +use std::io; +use std::io::ErrorKind; +use std::pin::Pin; +use std::task::{Context, Poll}; + +fn read_variable_integer(buf: &[u8], offset: usize) -> io::Result<(i32, usize)> { + let mut pos = offset; + let prefix = utils::read_buf(buf, &mut pos); + let mut size = 0; + for shift in 1..=5 { + if prefix & (128 >> (shift - 1)) == 0 { + size = shift; + break; + } + } + if !(1..=5).contains(&size) { + return Err(io::Error::new( + ErrorKind::InvalidData, + format!("Invalid integer size {} at position {}", size, offset), + )); + } + + match size { + 1 => Ok((prefix as i32, size)), + 2 => { + let value = ((utils::read_buf(buf, &mut pos) as i32) << 6) | (prefix as i32 & 0b111111); + Ok((value, size)) + } + 3 => { + let value = (((utils::read_buf(buf, &mut pos) as i32) + | ((utils::read_buf(buf, &mut pos) as i32) << 8)) + << 5) + | (prefix as i32 & 0b11111); + Ok((value, size)) + } + 4 => { + let value = (((utils::read_buf(buf, &mut pos) as i32) + | ((utils::read_buf(buf, &mut pos) as i32) << 8) + | ((utils::read_buf(buf, &mut pos) as i32) << 16)) + << 4) + | (prefix as i32 & 0b1111); + Ok((value, size)) + } + _ => { + let value = (utils::read_buf(buf, &mut pos) as i32) + | ((utils::read_buf(buf, &mut pos) as i32) << 8) + | ((utils::read_buf(buf, &mut pos) as i32) << 16) + | ((utils::read_buf(buf, &mut pos) as i32) << 24); + Ok((value, size)) + } + } +} + +pub struct UmpTransformStream +where + S: Stream> + Unpin, +{ + inner: S, + buffer: BytesMut, + found_stream: bool, + remaining: usize, +} + +impl UmpTransformStream +where + S: Stream> + Unpin, +{ + pub fn new(stream: S) -> Self { + UmpTransformStream { + inner: stream, + buffer: BytesMut::new(), + found_stream: false, + remaining: 0, + } + } +} + +impl Stream for UmpTransformStream +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + while let Poll::Ready(item) = Pin::new(&mut this.inner).poll_next(cx) { + match item { + Some(Ok(bytes)) => { + if this.found_stream { + if this.remaining > 0 { + let len = std::cmp::min(this.remaining, bytes.len()); + this.remaining -= len; + if this.remaining == 0 { + this.buffer.clear(); + this.buffer.extend_from_slice(&bytes[len..]); + this.found_stream = false; + } + return Poll::Ready(Some(Ok(bytes.slice(0..len)))); + } else { + this.found_stream = false; + this.buffer.clear(); + this.buffer.extend_from_slice(&bytes); + }; + } else { + this.buffer.extend_from_slice(&bytes); + } + } + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + None => { + return Poll::Ready(None); + } + } + } + + if !this.found_stream && !this.buffer.is_empty() { + let (segment_type, s1) = match read_variable_integer(&this.buffer, 0) { + Ok(result) => result, + Err(e) => return Poll::Ready(Some(Err(e))), + }; + let (segment_length, s2) = match read_variable_integer(&this.buffer, s1) { + Ok(result) => result, + Err(e) => return Poll::Ready(Some(Err(e))), + }; + if segment_type != 21 { + // Not the stream + if this.buffer.len() > s1 + s2 + segment_length as usize { + let _ = this.buffer.split_to(s1 + s2 + segment_length as usize); + } + } else { + this.remaining = segment_length as usize - 1; + + let _ = this.buffer.split_to(s1 + s2 + 1); + + if this.buffer.len() > segment_length as usize { + let len = std::cmp::min(this.remaining, this.buffer.len()); + this.remaining -= len; + + return Poll::Ready(Some(Ok(this.buffer.split_to(len).into()))); + } else { + this.remaining -= this.buffer.len(); + this.found_stream = true; + + return Poll::Ready(Some(Ok(this.buffer.to_vec().into()))); + } + } + } + + Poll::Pending + } +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..8c415c3 --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,96 @@ +use qstring::QString; +use reqwest::Url; +use std::borrow::Cow; +use std::collections::BTreeMap; + +pub fn read_buf(buf: &[u8], pos: &mut usize) -> u8 { + let byte = buf[*pos]; + *pos += 1; + byte +} + +fn finalize_url(path: &str, query: BTreeMap) -> String { + #[cfg(feature = "qhash")] + { + use std::collections::BTreeSet; + use std::env; + + let qhash = { + let secret = env::var("HASH_SECRET"); + if let Ok(secret) = secret { + let set = query + .iter() + .filter(|(key, _)| !matches!(key.as_str(), "qhash" | "range" | "rewrite")) + .map(|(key, value)| (key.as_bytes().to_owned(), value.as_bytes().to_owned())) + .collect::>(); + + let mut hasher = blake3::Hasher::new(); + + for (key, value) in set { + hasher.update(&key); + hasher.update(&value); + } + + hasher.update(path.as_bytes()); + + hasher.update(secret.as_bytes()); + + let hash = hasher.finalize().to_hex(); + + Some(hash[..8].to_owned()) + } else { + None + } + }; + + if qhash.is_some() { + let mut query = QString::new(query.into_iter().collect::>()); + query.add_pair(("qhash", qhash.unwrap())); + return format!("{}?{}", path, query); + } + } + + let query = QString::new(query.into_iter().collect::>()); + format!("{}?{}", path, query) +} + +pub fn localize_url(url: &str, host: &str) -> String { + if url.starts_with("https://") { + let url = Url::parse(url).unwrap(); + let host = url.host().unwrap().to_string(); + + let mut query = url.query_pairs().into_owned().collect::>(); + + query.insert("host".to_string(), host.clone()); + + return finalize_url(url.path(), query); + } else if url.ends_with(".m3u8") || url.ends_with(".ts") { + let mut query = BTreeMap::new(); + query.insert("host".to_string(), host.to_string()); + + return finalize_url(url, query); + } + + url.to_string() +} + +pub fn escape_xml(raw: &str) -> Cow<'_, str> { + if !raw.contains(&['<', '>', '&', '\'', '"'][..]) { + // If there are no characters to escape, return the original string. + Cow::Borrowed(raw) + } else { + // If there are characters to escape, build a new string with the replacements. + let mut escaped = String::with_capacity(raw.len()); + for c in raw.chars() { + match c { + '<' => escaped.push_str("<"), + '>' => escaped.push_str(">"), + '&' => escaped.push_str("&"), + '\'' => escaped.push_str("'"), + '"' => escaped.push_str("""), + _ => escaped.push(c), + } + } + Cow::Owned(escaped) + } +}