use crate::utils; use bytes::{Bytes, BytesMut}; use futures_util::Stream; use std::io; use std::io::ErrorKind; use std::pin::Pin; use std::task::{Context, Poll}; fn read_variable_integer(buf: &[u8], offset: usize) -> io::Result<(i32, usize)> { let mut pos = offset; let prefix = utils::read_buf(buf, &mut pos); let mut size = 0; for shift in 1..=5 { if prefix & (128 >> (shift - 1)) == 0 { size = shift; break; } } if !(1..=5).contains(&size) { return Err(io::Error::new( ErrorKind::InvalidData, format!("Invalid integer size {} at position {}", size, offset), )); } match size { 1 => Ok((prefix as i32, size)), 2 => { let value = ((utils::read_buf(buf, &mut pos) as i32) << 6) | (prefix as i32 & 0b111111); Ok((value, size)) } 3 => { let value = (((utils::read_buf(buf, &mut pos) as i32) | ((utils::read_buf(buf, &mut pos) as i32) << 8)) << 5) | (prefix as i32 & 0b11111); Ok((value, size)) } 4 => { let value = (((utils::read_buf(buf, &mut pos) as i32) | ((utils::read_buf(buf, &mut pos) as i32) << 8) | ((utils::read_buf(buf, &mut pos) as i32) << 16)) << 4) | (prefix as i32 & 0b1111); Ok((value, size)) } _ => { let value = (utils::read_buf(buf, &mut pos) as i32) | ((utils::read_buf(buf, &mut pos) as i32) << 8) | ((utils::read_buf(buf, &mut pos) as i32) << 16) | ((utils::read_buf(buf, &mut pos) as i32) << 24); Ok((value, size)) } } } pub struct UmpTransformStream where S: Stream> + Unpin, { inner: S, buffer: BytesMut, found_stream: bool, remaining: usize, } impl UmpTransformStream where S: Stream> + Unpin, { pub fn new(stream: S) -> Self { UmpTransformStream { inner: stream, buffer: BytesMut::new(), found_stream: false, remaining: 0, } } } impl Stream for UmpTransformStream where S: Stream> + Unpin, { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); while let Poll::Ready(item) = Pin::new(&mut this.inner).poll_next(cx) { match item { Some(Ok(bytes)) => { if this.found_stream { if this.remaining > 0 { let len = std::cmp::min(this.remaining, bytes.len()); this.remaining -= len; if this.remaining == 0 { this.buffer.clear(); this.buffer.extend_from_slice(&bytes[len..]); this.found_stream = false; } return Poll::Ready(Some(Ok(bytes.slice(0..len)))); } else { this.found_stream = false; this.buffer.clear(); this.buffer.extend_from_slice(&bytes); }; } else { this.buffer.extend_from_slice(&bytes); } } Some(Err(e)) => return Poll::Ready(Some(Err(e))), None => { return Poll::Ready(None); } } } if !this.found_stream && !this.buffer.is_empty() { let (segment_type, s1) = match read_variable_integer(&this.buffer, 0) { Ok(result) => result, Err(_) => return Poll::Pending, }; let (segment_length, s2) = match read_variable_integer(&this.buffer, s1) { Ok(result) => result, Err(_) => return Poll::Pending, }; if segment_type != 21 { // Not the stream if this.buffer.len() > s1 + s2 + segment_length as usize { let _ = this.buffer.split_to(s1 + s2 + segment_length as usize); } } else { this.remaining = segment_length as usize - 1; let _ = this.buffer.split_to(s1 + s2 + 1); if this.buffer.len() > segment_length as usize { let len = std::cmp::min(this.remaining, this.buffer.len()); this.remaining -= len; return Poll::Ready(Some(Ok(this.buffer.split_to(len).into()))); } else { this.remaining -= this.buffer.len(); this.found_stream = true; return Poll::Ready(Some(Ok(this.buffer.to_vec().into()))); } } } Poll::Pending } }