Refactor utils and UMP transforming code to new files.

This commit is contained in:
Kavin 2023-12-14 13:37:44 +00:00
parent e90eebcbd7
commit e291ed4e19
No known key found for this signature in database
GPG key ID: 6E4598CA5C92C41F
3 changed files with 262 additions and 248 deletions

View file

@ -1,24 +1,23 @@
mod ump_stream;
mod utils;
use actix_web::http::{Method, StatusCode}; use actix_web::http::{Method, StatusCode};
use actix_web::{web, App, HttpRequest, HttpResponse, HttpResponseBuilder, HttpServer}; use actix_web::{web, App, HttpRequest, HttpResponse, HttpResponseBuilder, HttpServer};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use qstring::QString; use qstring::QString;
use regex::Regex; use regex::Regex;
use reqwest::{Body, Client, Request, Url}; use reqwest::{Body, Client, Request, Url};
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::error::Error; use std::error::Error;
use std::io::ErrorKind; use std::io::ErrorKind;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{env, io}; use std::{env, io};
#[cfg(not(any(feature = "reqwest-native-tls", feature = "reqwest-rustls")))] #[cfg(not(any(feature = "reqwest-native-tls", feature = "reqwest-rustls")))]
compile_error!("feature \"reqwest-native-tls\" or \"reqwest-rustls\" must be set for proxy to have TLS support"); compile_error!("feature \"reqwest-native-tls\" or \"reqwest-rustls\" must be set for proxy to have TLS support");
use bytes::{Bytes, BytesMut}; use futures_util::{TryStreamExt};
use futures_util::{Stream, TryStreamExt};
#[cfg(any(feature = "webp", feature = "avif", feature = "qhash"))] #[cfg(any(feature = "webp", feature = "avif", feature = "qhash"))]
use tokio::task::spawn_blocking; use tokio::task::spawn_blocking;
use ump_stream::UmpTransformStream;
#[cfg(feature = "mimalloc")] #[cfg(feature = "mimalloc")]
#[global_allocator] #[global_allocator]
@ -406,11 +405,13 @@ async fn index(req: HttpRequest) -> Result<HttpResponse, Box<dyn Error>> {
if let Some(captures) = captures { if let Some(captures) = captures {
let url = captures.get(1).unwrap().as_str(); let url = captures.get(1).unwrap().as_str();
if url.starts_with("https://") { if url.starts_with("https://") {
return line return line.replace(
.replace(url, localize_url(url, host.as_str()).as_str()); url,
utils::localize_url(url, host.as_str()).as_str(),
);
} }
} }
localize_url(line, host.as_str()) utils::localize_url(line, host.as_str())
}) })
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join("\n"); .join("\n");
@ -423,8 +424,8 @@ async fn index(req: HttpRequest) -> Result<HttpResponse, Box<dyn Error>> {
let captures = RE_DASH_MANIFEST.captures_iter(&resp_str); let captures = RE_DASH_MANIFEST.captures_iter(&resp_str);
for capture in captures { for capture in captures {
let url = capture.get(1).unwrap().as_str(); let url = capture.get(1).unwrap().as_str();
let new_url = localize_url(url, host.as_str()); let new_url = utils::localize_url(url, host.as_str());
let new_url = escape_xml(new_url.as_str()); let new_url = utils::escape_xml(new_url.as_str());
new_resp = new_resp.replace(url, new_url.as_ref()); new_resp = new_resp.replace(url, new_url.as_ref());
} }
return Ok(response.body(new_resp)); return Ok(response.body(new_resp));
@ -453,240 +454,3 @@ async fn index(req: HttpRequest) -> Result<HttpResponse, Box<dyn Error>> {
// Stream response // Stream response
Ok(response.streaming(resp)) Ok(response.streaming(resp))
} }
fn read_buf(buf: &[u8], pos: &mut usize) -> u8 {
let byte = buf[*pos];
*pos += 1;
byte
}
fn read_variable_integer(buf: &[u8], offset: usize) -> io::Result<(i32, usize)> {
let mut pos = offset;
let prefix = read_buf(buf, &mut pos);
let mut size = 0;
for shift in 1..=5 {
if prefix & (128 >> (shift - 1)) == 0 {
size = shift;
break;
}
}
if !(1..=5).contains(&size) {
return Err(io::Error::new(
ErrorKind::InvalidData,
format!("Invalid integer size {} at position {}", size, offset),
));
}
match size {
1 => Ok((prefix as i32, size)),
2 => {
let value = ((read_buf(buf, &mut pos) as i32) << 6) | (prefix as i32 & 0b111111);
Ok((value, size))
}
3 => {
let value =
(((read_buf(buf, &mut pos) as i32) | ((read_buf(buf, &mut pos) as i32) << 8)) << 5)
| (prefix as i32 & 0b11111);
Ok((value, size))
}
4 => {
let value = (((read_buf(buf, &mut pos) as i32)
| ((read_buf(buf, &mut pos) as i32) << 8)
| ((read_buf(buf, &mut pos) as i32) << 16))
<< 4)
| (prefix as i32 & 0b1111);
Ok((value, size))
}
_ => {
let value = (read_buf(buf, &mut pos) as i32)
| ((read_buf(buf, &mut pos) as i32) << 8)
| ((read_buf(buf, &mut pos) as i32) << 16)
| ((read_buf(buf, &mut pos) as i32) << 24);
Ok((value, size))
}
}
}
struct UmpTransformStream<S>
where
S: Stream<Item = Result<Bytes, io::Error>> + Unpin,
{
inner: S,
buffer: BytesMut,
found_stream: bool,
remaining: usize,
}
impl<S> UmpTransformStream<S>
where
S: Stream<Item = Result<Bytes, io::Error>> + Unpin,
{
pub fn new(stream: S) -> Self {
UmpTransformStream {
inner: stream,
buffer: BytesMut::new(),
found_stream: false,
remaining: 0,
}
}
}
impl<S> Stream for UmpTransformStream<S>
where
S: Stream<Item = Result<Bytes, io::Error>> + Unpin,
{
type Item = Result<Bytes, io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
while let Poll::Ready(item) = Pin::new(&mut this.inner).poll_next(cx) {
match item {
Some(Ok(bytes)) => {
if this.found_stream {
if this.remaining > 0 {
let len = std::cmp::min(this.remaining, bytes.len());
this.remaining -= len;
if this.remaining == 0 {
this.buffer.clear();
this.buffer.extend_from_slice(&bytes[len..]);
this.found_stream = false;
}
return Poll::Ready(Some(Ok(bytes.slice(0..len))));
} else {
this.found_stream = false;
this.buffer.clear();
this.buffer.extend_from_slice(&bytes);
};
} else {
this.buffer.extend_from_slice(&bytes);
}
}
Some(Err(e)) => return Poll::Ready(Some(Err(e))),
None => {
return Poll::Ready(None);
}
}
}
if !this.found_stream && !this.buffer.is_empty() {
let (segment_type, s1) = match read_variable_integer(&this.buffer, 0) {
Ok(result) => result,
Err(e) => return Poll::Ready(Some(Err(e))),
};
let (segment_length, s2) = match read_variable_integer(&this.buffer, s1) {
Ok(result) => result,
Err(e) => return Poll::Ready(Some(Err(e))),
};
if segment_type != 21 {
// Not the stream
if this.buffer.len() > s1 + s2 + segment_length as usize {
let _ = this.buffer.split_to(s1 + s2 + segment_length as usize);
}
} else {
this.remaining = segment_length as usize - 1;
let _ = this.buffer.split_to(s1 + s2 + 1);
if this.buffer.len() > segment_length as usize {
let len = std::cmp::min(this.remaining, this.buffer.len());
this.remaining -= len;
return Poll::Ready(Some(Ok(this.buffer.split_to(len).into())));
} else {
this.remaining -= this.buffer.len();
this.found_stream = true;
return Poll::Ready(Some(Ok(this.buffer.to_vec().into())));
}
}
}
Poll::Pending
}
}
fn finalize_url(path: &str, query: BTreeMap<String, String>) -> String {
#[cfg(feature = "qhash")]
{
use std::collections::BTreeSet;
let qhash = {
let secret = env::var("HASH_SECRET");
if let Ok(secret) = secret {
let set = query
.iter()
.filter(|(key, _)| !matches!(key.as_str(), "qhash" | "range" | "rewrite"))
.map(|(key, value)| (key.as_bytes().to_owned(), value.as_bytes().to_owned()))
.collect::<BTreeSet<_>>();
let mut hasher = blake3::Hasher::new();
for (key, value) in set {
hasher.update(&key);
hasher.update(&value);
}
hasher.update(path.as_bytes());
hasher.update(secret.as_bytes());
let hash = hasher.finalize().to_hex();
Some(hash[..8].to_owned())
} else {
None
}
};
if qhash.is_some() {
let mut query = QString::new(query.into_iter().collect::<Vec<_>>());
query.add_pair(("qhash", qhash.unwrap()));
return format!("{}?{}", path, query);
}
}
let query = QString::new(query.into_iter().collect::<Vec<_>>());
format!("{}?{}", path, query)
}
pub fn escape_xml(raw: &str) -> Cow<'_, str> {
if !raw.contains(&['<', '>', '&', '\'', '"'][..]) {
// If there are no characters to escape, return the original string.
Cow::Borrowed(raw)
} else {
// If there are characters to escape, build a new string with the replacements.
let mut escaped = String::with_capacity(raw.len());
for c in raw.chars() {
match c {
'<' => escaped.push_str("&lt;"),
'>' => escaped.push_str("&gt;"),
'&' => escaped.push_str("&amp;"),
'\'' => escaped.push_str("&apos;"),
'"' => escaped.push_str("&quot;"),
_ => escaped.push(c),
}
}
Cow::Owned(escaped)
}
}
fn localize_url(url: &str, host: &str) -> String {
if url.starts_with("https://") {
let url = Url::parse(url).unwrap();
let host = url.host().unwrap().to_string();
let mut query = url.query_pairs().into_owned().collect::<BTreeMap<_, _>>();
query.insert("host".to_string(), host.clone());
return finalize_url(url.path(), query);
} else if url.ends_with(".m3u8") || url.ends_with(".ts") {
let mut query = BTreeMap::new();
query.insert("host".to_string(), host.to_string());
return finalize_url(url, query);
}
url.to_string()
}

154
src/ump_stream.rs Normal file
View file

@ -0,0 +1,154 @@
use crate::utils;
use bytes::{Bytes, BytesMut};
use futures_util::Stream;
use std::io;
use std::io::ErrorKind;
use std::pin::Pin;
use std::task::{Context, Poll};
fn read_variable_integer(buf: &[u8], offset: usize) -> io::Result<(i32, usize)> {
let mut pos = offset;
let prefix = utils::read_buf(buf, &mut pos);
let mut size = 0;
for shift in 1..=5 {
if prefix & (128 >> (shift - 1)) == 0 {
size = shift;
break;
}
}
if !(1..=5).contains(&size) {
return Err(io::Error::new(
ErrorKind::InvalidData,
format!("Invalid integer size {} at position {}", size, offset),
));
}
match size {
1 => Ok((prefix as i32, size)),
2 => {
let value = ((utils::read_buf(buf, &mut pos) as i32) << 6) | (prefix as i32 & 0b111111);
Ok((value, size))
}
3 => {
let value = (((utils::read_buf(buf, &mut pos) as i32)
| ((utils::read_buf(buf, &mut pos) as i32) << 8))
<< 5)
| (prefix as i32 & 0b11111);
Ok((value, size))
}
4 => {
let value = (((utils::read_buf(buf, &mut pos) as i32)
| ((utils::read_buf(buf, &mut pos) as i32) << 8)
| ((utils::read_buf(buf, &mut pos) as i32) << 16))
<< 4)
| (prefix as i32 & 0b1111);
Ok((value, size))
}
_ => {
let value = (utils::read_buf(buf, &mut pos) as i32)
| ((utils::read_buf(buf, &mut pos) as i32) << 8)
| ((utils::read_buf(buf, &mut pos) as i32) << 16)
| ((utils::read_buf(buf, &mut pos) as i32) << 24);
Ok((value, size))
}
}
}
pub struct UmpTransformStream<S>
where
S: Stream<Item = Result<Bytes, io::Error>> + Unpin,
{
inner: S,
buffer: BytesMut,
found_stream: bool,
remaining: usize,
}
impl<S> UmpTransformStream<S>
where
S: Stream<Item = Result<Bytes, io::Error>> + Unpin,
{
pub fn new(stream: S) -> Self {
UmpTransformStream {
inner: stream,
buffer: BytesMut::new(),
found_stream: false,
remaining: 0,
}
}
}
impl<S> Stream for UmpTransformStream<S>
where
S: Stream<Item = Result<Bytes, io::Error>> + Unpin,
{
type Item = Result<Bytes, io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
while let Poll::Ready(item) = Pin::new(&mut this.inner).poll_next(cx) {
match item {
Some(Ok(bytes)) => {
if this.found_stream {
if this.remaining > 0 {
let len = std::cmp::min(this.remaining, bytes.len());
this.remaining -= len;
if this.remaining == 0 {
this.buffer.clear();
this.buffer.extend_from_slice(&bytes[len..]);
this.found_stream = false;
}
return Poll::Ready(Some(Ok(bytes.slice(0..len))));
} else {
this.found_stream = false;
this.buffer.clear();
this.buffer.extend_from_slice(&bytes);
};
} else {
this.buffer.extend_from_slice(&bytes);
}
}
Some(Err(e)) => return Poll::Ready(Some(Err(e))),
None => {
return Poll::Ready(None);
}
}
}
if !this.found_stream && !this.buffer.is_empty() {
let (segment_type, s1) = match read_variable_integer(&this.buffer, 0) {
Ok(result) => result,
Err(e) => return Poll::Ready(Some(Err(e))),
};
let (segment_length, s2) = match read_variable_integer(&this.buffer, s1) {
Ok(result) => result,
Err(e) => return Poll::Ready(Some(Err(e))),
};
if segment_type != 21 {
// Not the stream
if this.buffer.len() > s1 + s2 + segment_length as usize {
let _ = this.buffer.split_to(s1 + s2 + segment_length as usize);
}
} else {
this.remaining = segment_length as usize - 1;
let _ = this.buffer.split_to(s1 + s2 + 1);
if this.buffer.len() > segment_length as usize {
let len = std::cmp::min(this.remaining, this.buffer.len());
this.remaining -= len;
return Poll::Ready(Some(Ok(this.buffer.split_to(len).into())));
} else {
this.remaining -= this.buffer.len();
this.found_stream = true;
return Poll::Ready(Some(Ok(this.buffer.to_vec().into())));
}
}
}
Poll::Pending
}
}

96
src/utils.rs Normal file
View file

@ -0,0 +1,96 @@
use qstring::QString;
use reqwest::Url;
use std::borrow::Cow;
use std::collections::BTreeMap;
pub fn read_buf(buf: &[u8], pos: &mut usize) -> u8 {
let byte = buf[*pos];
*pos += 1;
byte
}
fn finalize_url(path: &str, query: BTreeMap<String, String>) -> String {
#[cfg(feature = "qhash")]
{
use std::collections::BTreeSet;
use std::env;
let qhash = {
let secret = env::var("HASH_SECRET");
if let Ok(secret) = secret {
let set = query
.iter()
.filter(|(key, _)| !matches!(key.as_str(), "qhash" | "range" | "rewrite"))
.map(|(key, value)| (key.as_bytes().to_owned(), value.as_bytes().to_owned()))
.collect::<BTreeSet<_>>();
let mut hasher = blake3::Hasher::new();
for (key, value) in set {
hasher.update(&key);
hasher.update(&value);
}
hasher.update(path.as_bytes());
hasher.update(secret.as_bytes());
let hash = hasher.finalize().to_hex();
Some(hash[..8].to_owned())
} else {
None
}
};
if qhash.is_some() {
let mut query = QString::new(query.into_iter().collect::<Vec<_>>());
query.add_pair(("qhash", qhash.unwrap()));
return format!("{}?{}", path, query);
}
}
let query = QString::new(query.into_iter().collect::<Vec<_>>());
format!("{}?{}", path, query)
}
pub fn localize_url(url: &str, host: &str) -> String {
if url.starts_with("https://") {
let url = Url::parse(url).unwrap();
let host = url.host().unwrap().to_string();
let mut query = url.query_pairs().into_owned().collect::<BTreeMap<_, _>>();
query.insert("host".to_string(), host.clone());
return finalize_url(url.path(), query);
} else if url.ends_with(".m3u8") || url.ends_with(".ts") {
let mut query = BTreeMap::new();
query.insert("host".to_string(), host.to_string());
return finalize_url(url, query);
}
url.to_string()
}
pub fn escape_xml(raw: &str) -> Cow<'_, str> {
if !raw.contains(&['<', '>', '&', '\'', '"'][..]) {
// If there are no characters to escape, return the original string.
Cow::Borrowed(raw)
} else {
// If there are characters to escape, build a new string with the replacements.
let mut escaped = String::with_capacity(raw.len());
for c in raw.chars() {
match c {
'<' => escaped.push_str("&lt;"),
'>' => escaped.push_str("&gt;"),
'&' => escaped.push_str("&amp;"),
'\'' => escaped.push_str("&apos;"),
'"' => escaped.push_str("&quot;"),
_ => escaped.push(c),
}
}
Cow::Owned(escaped)
}
}