diff --git a/Cargo.lock b/Cargo.lock index 9b7015c..8b9f810 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1406,6 +1406,8 @@ version = "0.1.0" dependencies = [ "actix-web", "blake3", + "bytes", + "futures-util", "image", "libwebp-sys", "mimalloc", diff --git a/Cargo.toml b/Cargo.toml index 1c14ffe..a275c58 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,8 @@ rgb = { version = "0.8.37", optional = true } once_cell = "1.18.0" regex = "1.10.2" blake3 = { version = "1.5.0", optional = true } +bytes = "1.5.0" +futures-util = "0.3.29" [features] default = ["webp", "mimalloc", "reqwest-rustls", "qhash"] diff --git a/src/main.rs b/src/main.rs index c3c0451..0c5c523 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,15 +3,21 @@ use actix_web::{web, App, HttpRequest, HttpResponse, HttpResponseBuilder, HttpSe use once_cell::sync::Lazy; use qstring::QString; use regex::Regex; +use reqwest::Error as ReqwestError; use reqwest::{Body, Client, Request, Url}; use std::collections::BTreeMap; -use std::env; use std::error::Error; +use std::io::ErrorKind; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{env, io}; #[cfg(not(any(feature = "reqwest-native-tls", feature = "reqwest-rustls")))] compile_error!("feature \"reqwest-native-tls\" or \"reqwest-rustls\" must be set for proxy to have TLS support"); -#[cfg(any(feature = "webp", feature = "avif"))] +use bytes::{Bytes, BytesMut}; +use futures_util::Stream; +#[cfg(any(feature = "webp", feature = "avif", feature = "qhash"))] use tokio::task::spawn_blocking; #[cfg(feature = "mimalloc")] @@ -226,6 +232,10 @@ async fn index(req: HttpRequest) -> Result> { let video_playback = req.path().eq("/videoplayback"); let is_android = video_playback && query.get("c").unwrap_or("").eq("ANDROID"); + let is_ump = video_playback && query.get("ump").is_some(); + + let mime_type = query.get("mime").map(|s| s.to_string()); + let qs = { let collected = query .into_pairs() @@ -403,8 +413,167 @@ async fn index(req: HttpRequest) -> Result> { response.append_header(("content-length", content_length)); } + let resp = resp.bytes_stream(); + + if is_ump { + if let Some(mime_type) = mime_type { + response.content_type(mime_type); + } + let transformed_stream = UmpTransformStream::new(resp); + return Ok(response.streaming(transformed_stream)); + } + // Stream response - Ok(response.streaming(resp.bytes_stream())) + Ok(response.streaming(resp)) +} + +fn read_buf(buf: &[u8], pos: &mut usize) -> u8 { + let byte = buf[*pos]; + *pos += 1; + byte +} + +fn read_variable_integer(buf: &[u8], offset: usize) -> io::Result<(i32, usize)> { + let mut pos = offset; + let prefix = read_buf(buf, &mut pos); + let mut size = 0; + for shift in 1..=5 { + if prefix & (128 >> (shift - 1)) == 0 { + size = shift; + break; + } + } + if !(1..=5).contains(&size) { + return Err(io::Error::new( + ErrorKind::InvalidData, + format!("Invalid integer size {} at position {}", size, offset), + )); + } + + match size { + 1 => Ok((prefix as i32, size)), + 2 => { + let value = ((read_buf(buf, &mut pos) as i32) << 6) | (prefix as i32 & 0b111111); + Ok((value, size)) + } + 3 => { + let value = + (((read_buf(buf, &mut pos) as i32) | ((read_buf(buf, &mut pos) as i32) << 8)) << 5) + | (prefix as i32 & 0b11111); + Ok((value, size)) + } + 4 => { + let value = (((read_buf(buf, &mut pos) as i32) + | ((read_buf(buf, &mut pos) as i32) << 8) + | ((read_buf(buf, &mut pos) as i32) << 16)) + << 4) + | (prefix as i32 & 0b1111); + Ok((value, size)) + } + _ => { + let value = (read_buf(buf, &mut pos) as i32) + | ((read_buf(buf, &mut pos) as i32) << 8) + | ((read_buf(buf, &mut pos) as i32) << 16) + | ((read_buf(buf, &mut pos) as i32) << 24); + Ok((value, size)) + } + } +} + +struct UmpTransformStream +where + S: Stream> + Unpin, +{ + inner: S, + buffer: BytesMut, + found_stream: bool, + remaining: usize, + finished: bool, +} + +impl UmpTransformStream +where + S: Stream> + Unpin, +{ + pub fn new(stream: S) -> Self { + UmpTransformStream { + inner: stream, + buffer: BytesMut::new(), + found_stream: false, + remaining: 0, + finished: false, + } + } +} + +impl Stream for UmpTransformStream +where + S: Stream> + Unpin, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + while let Poll::Ready(item) = Pin::new(&mut this.inner).poll_next(cx) { + match item { + Some(Ok(bytes)) => { + if this.found_stream { + if this.remaining > 0 { + let len = std::cmp::min(this.remaining, bytes.len()); + this.remaining -= len; + if this.remaining == 0 { + this.buffer.clear(); + this.buffer.extend_from_slice(&bytes[len..]); + this.found_stream = false; + } + return Poll::Ready(Some(Ok(bytes.slice(0..len)))); + } else { + this.finished = true; + + return Poll::Ready(None); + } + } else { + this.buffer.extend_from_slice(&bytes); + } + } + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + None => { + this.finished = true; + return Poll::Ready(None); + } + } + } + + if !this.found_stream && !this.buffer.is_empty() { + let (segment_type, s1) = read_variable_integer(&this.buffer, 0).unwrap(); + let (segment_length, s2) = read_variable_integer(&this.buffer, s1).unwrap(); + if segment_type != 21 { + // Not the stream + if this.buffer.len() > s1 + s2 + segment_length as usize { + let _ = this.buffer.split_to(s1 + s2 + segment_length as usize); + } + } else { + this.found_stream = true; + this.remaining = segment_length as usize - 1; + + let _ = this.buffer.split_to(s1 + s2 + 1); + + if this.buffer.len() > segment_length as usize { + let len = std::cmp::min(this.remaining, this.buffer.len()); + this.remaining -= len; + + return Poll::Ready(Some(Ok(this.buffer.split_to(len).into()))); + } else { + this.remaining -= this.buffer.len(); + + return Poll::Ready(Some(Ok(this.buffer.to_vec().into()))); + } + } + } + + Poll::Pending + } } fn finalize_url(path: &str, query: BTreeMap) -> String { @@ -443,12 +612,12 @@ fn finalize_url(path: &str, query: BTreeMap) -> String { if qhash.is_some() { let mut query = QString::new(query.into_iter().collect::>()); query.add_pair(("qhash", qhash.unwrap())); - return format!("{}?{}", path, query.to_string()); + return format!("{}?{}", path, query); } } let query = QString::new(query.into_iter().collect::>()); - format!("{}?{}", path, query.to_string()) + format!("{}?{}", path, query) } fn localize_url(url: &str, host: &str) -> String {