mirror of
https://gitea.invidious.io/iv-org/Invidious-quic-proxy.git
synced 2024-08-15 00:43:22 +00:00
12206c14c6
The code for reestablishing connections begins after we retrieve a request from the queue. However, since the request has already been removed from the queue, we no longer have access to it, resulting in a dropped request. This commit adds another queue to collect any requests retrieved during client restarts, preserving them so they can be handled after the connection has been reestablished. This queue has a higher priority than the normal request queue, meaning that normal requests are processed only when there are no remaining "dropped" requests to handle.
225 lines
7.7 KiB
Python
225 lines
7.7 KiB
Python
"""
|
|
This module contains the QUIC request stuff.
|
|
|
|
Adapted from https://github.com/aiortc/aioquic/blob/3d2708013aa959d70b7e4dbcd6cfb173c5ac4359/examples/http3_client.py
|
|
"""
|
|
|
|
import asyncio
|
|
import collections
|
|
import time
|
|
import logging
|
|
from urllib.parse import urlparse
|
|
|
|
from aioquic.asyncio.protocol import QuicConnectionProtocol
|
|
from aioquic.h3.connection import H3_ALPN, H3Connection
|
|
from aioquic.h3.events import (
|
|
DataReceived,
|
|
HeadersReceived,
|
|
)
|
|
from aioquic.quic.events import ConnectionTerminated
|
|
from aioquic.quic.configuration import QuicConfiguration
|
|
from aioquic.asyncio.client import connect
|
|
|
|
# Module-level logger for the QUIC client; INFO so the per-request transfer
# statistics (see perform_http_request) are actually emitted.
logger = logging.getLogger("client")
logger.setLevel(logging.INFO)

# Browser-like user agent sent on every request — presumably to mimic
# Chrome; confirm whether upstream varies responses by UA.
USER_AGENT = "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/78.0.3904.97 Safari/537.36"

# Lightweight record consumed by HttpClient._request:
#   content: bytes body, headers: dict of extra headers,
#   method: "GET"/"POST", url: a URL instance (pseudo-header parts).
HTTPRequest = collections.namedtuple("HTTPRequest", ["content", "headers", "method", "url"])
|
|
|
|
|
|
class InvidiousRequest:
    """A queued client request awaiting execution by a worker.

    Holds the parsed target URL together with the method, headers and body,
    and exposes a ``completed`` event that is set once the request has been
    served (see RequestProcessor._handle_request).
    """

    def __init__(self, url, method, headers, content):
        self.method, self.headers, self.content = method, headers, content
        # Parse eagerly; downstream code reads the URL components, never the raw string.
        self.url = urlparse(url)
        # Signalled by the worker when a response has been stored.
        self.completed = asyncio.Event()
|
|
|
|
|
|
class URL:
    """Decomposed URL view used to build HTTP/3 pseudo-headers.

    Takes a urllib.parse result and exposes the three pieces an H3 request
    needs: authority (host[:port]), scheme, and the path with any query
    string appended.
    """

    def __init__(self, url):
        query_suffix = "?" + url.query if url.query else ""
        self.full_path = url.path + query_suffix
        self.authority = url.netloc
        self.scheme = url.scheme
|
|
|
|
|
|
class HttpClient(QuicConnectionProtocol):
    """HTTP/3 client layered on an aioquic QUIC connection.

    Tracks one pending future per request stream and resolves it with the
    full list of H3 events received once the server ends that stream.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        # H3 layer multiplexed over the underlying QUIC connection.
        self._http = H3Connection(self._quic)
        # stream_id -> deque of H3 events received on that stream.
        self._request_events = {}
        # stream_id -> future resolved when the stream ends.
        self._request_waiter = {}

        # Set to True on ConnectionTerminated; RequestProcessor checks this
        # to know the client must be recreated.
        self.terminated = False

    async def get(self, url, headers=None):
        """
        Perform a GET request.

        ``url`` is a parsed URL object (urllib.parse result), not a string —
        it is wrapped in URL() for pseudo-header extraction. Returns the list
        of H3 events received on the request's stream.
        """
        # Avoid a mutable default; fresh dict when no headers were given.
        headers = {} if not headers else headers

        return await self._request(
            HTTPRequest(method="GET", url=URL(url), headers=headers, content=b"")
        )

    async def post(self, url: str, data: bytes, headers=None):
        """
        Perform a POST request with ``data`` as the raw body.

        NOTE(review): the ``url: str`` annotation looks wrong — callers pass
        a parsed URL object (see perform_http_request) and URL() reads
        ``.netloc``/``.path``; confirm and correct the annotation.
        """
        headers = {} if not headers else headers

        return await self._request(
            HTTPRequest(method="POST", url=URL(url), content=data, headers=headers)
        )

    def http_event_received(self, event):
        # Buffer events for streams we initiated; once the peer ends the
        # stream, resolve the waiter with everything accumulated so far.
        if isinstance(event, (HeadersReceived, DataReceived)):
            stream_id = event.stream_id
            if stream_id in self._request_events:
                # http
                self._request_events[event.stream_id].append(event)
                if event.stream_ended:
                    request_waiter = self._request_waiter.pop(stream_id)
                    request_waiter.set_result(self._request_events.pop(stream_id))

    def quic_event_received(self, event):
        # Remember termination so callers know this client is unusable.
        if isinstance(event, ConnectionTerminated):
            self.terminated = True

        # pass event to the HTTP layer
        for http_event in self._http.handle_event(event):
            self.http_event_received(http_event)

    async def _request(self, http_request):
        """Send ``http_request`` on a fresh stream and await all its events."""
        stream_id = self._quic.get_next_available_stream_id()
        # Mandatory pseudo-headers first; dict insertion order keeps them
        # ahead of the custom headers added below.
        headers = await self.fetch_default_headers(http_request)

        # Custom headers
        for k, v in http_request.headers.items():
            headers[k.encode()] = v.encode()

        self._http.send_headers(
            stream_id=stream_id,
            headers=[(k, v) for k, v in headers.items()],
        )
        self._http.send_data(stream_id=stream_id, data=http_request.content, end_stream=True)

        waiter = self._loop.create_future()
        self._request_events[stream_id] = collections.deque()
        self._request_waiter[stream_id] = waiter
        # Flush the pending QUIC datagrams to the network.
        self.transmit()

        # shield() so a cancelled caller doesn't cancel the waiter itself.
        return await asyncio.shield(waiter)

    async def fetch_default_headers(self, http_request):
        """Build the mandatory H3 pseudo-headers (plus user-agent) as bytes."""
        # ``or b"/"`` applies to the *encoded* value: an empty path encodes
        # to b"" (falsy) and is replaced by b"/".
        return {b":method": http_request.method.encode(),
                b":scheme": http_request.url.scheme.encode(),
                b":authority": http_request.url.authority.encode(),
                b":path": http_request.url.full_path.encode() or b"/",
                b"user-agent": USER_AGENT.encode()}
|
|
|
|
|
|
async def perform_http_request(client, url, method, headers, data, store_at):
    """Execute one HTTP/3 request on *client* and store the parsed response.

    Args:
        client: an HttpClient with an open connection.
        url: parsed URL (urllib.parse result) of the target resource.
        method: "POST" sends *data* (str, encoded to UTF-8) as the body;
            anything else is issued as a GET.
        headers: dict of extra request headers.
        data: request body for POST (ignored otherwise).
        store_at: dict that handle_response() fills with headers/response.
    """
    # Perform the request, timing the full network round trip.
    start = time.time()
    if method == "POST":
        http_events = await client.post(url, data.encode(), headers=headers)
    else:
        http_events = await client.get(url, headers=headers)
    elapsed = time.time() - start

    # Log transfer size and speed; clamp elapsed to avoid a ZeroDivisionError
    # when the clock granularity makes a very fast request measure as 0.
    octets = 0
    for http_event in http_events:
        if isinstance(http_event, DataReceived):
            octets += len(http_event.data)
    logger.info(
        "Response received for %s %s : %d bytes in %.1f s (%.3f Mbps)"
        % (method, url.path, octets, elapsed, octets * 8 / max(elapsed, 1e-9) / 1000000)
    )

    store_start = time.time()
    await handle_response(http_events, store_at)
    store_elapsed = time.time() - store_start

    # BUG FIX: the original logged `elapsed` (the network time) here instead
    # of the measured storage time, and the message had a stray ')'.
    logger.info(
        "Response has been stored in %.1f s"
        % store_elapsed
    )
|
|
|
|
|
|
async def handle_response(http_events, store_at=None):
    """Collapse one stream's H3 events into {"headers": ..., "response": ...}.

    Args:
        http_events: iterable of HeadersReceived / DataReceived events for a
            single request stream, in arrival order.
        store_at: optional dict updated in place with the parsed result.

    Returns:
        The updated *store_at* dict, or None when no content-type header was
        received (nothing to store) or when *store_at* is None.
    """
    resulting_data = {}
    _accumulator = b""

    for http_event in http_events:
        if isinstance(http_event, HeadersReceived):
            # Decode the raw (bytes, bytes) header pairs into a str dict.
            resulting_data["headers"] = {
                k.decode(): v.decode() for k, v in http_event.headers
            }

        elif isinstance(http_event, DataReceived):
            _accumulator += http_event.data
            if not http_event.stream_ended:
                continue

            # BUG FIX: the original indexed resulting_data["headers"] directly,
            # raising KeyError when the stream ended without any
            # HeadersReceived event; treat that case as "no response" instead.
            response_headers = resulting_data.get("headers", {})

            # If a content type header isn't available then we don't have a response to get.
            if not response_headers.get("content-type"):
                return

            if response_headers["content-type"].startswith("image"):
                # Keep binary payloads (images) as raw bytes.
                resulting_data["response"] = _accumulator
            else:
                resulting_data["response"] = _accumulator.decode()

            if store_at is not None:
                store_at.update(resulting_data)
                return store_at
|
|
|
|
|
|
class RequestProcessor:
    """Runs worker loops that serve queued requests over QUIC connections.

    Maintains two queues of [InvidiousRequest, storage_dict] pairs: the main
    queue, plus a priority queue holding requests that were pulled off the
    main queue while a connection was down — so those requests are retried
    after reconnecting instead of being dropped.
    """

    def __init__(self):
        # Each item is a [InvidiousRequest, storage_dict] pair.
        self.requests_to_do = asyncio.Queue(0)
        # Requests rescued from a dead connection; drained before the main queue.
        self.paused_requests = asyncio.Queue(0)
        # Worker bookkeeping; not used anywhere in the visible code.
        self.processors = []

    async def request_worker(self):
        """Loop forever: (re)connect and serve requests until the connection dies."""
        configuration = QuicConfiguration(is_client=True, alpn_protocols=H3_ALPN)

        while self.recreate_connection_check():
            async with connect("youtube.com", 443, configuration=configuration, create_protocol=HttpClient) as client:
                while True:
                    status = await self._handle_request(client)
                    # Connection has been terminated.
                    if status is False:
                        break

    async def _handle_request(self, client):
        """Serve one queued request on ``client``.

        Returns False when the client is terminated (the request is parked on
        ``paused_requests`` for retry); returns None after a request completes.
        """
        # Paused requests has a higher priority than the normal staged requests.
        try:
            request, storage = self.paused_requests.get_nowait()
        except asyncio.QueueEmpty:
            request, storage = await self.requests_to_do.get()

        if client.terminated:
            # Client has to be recreated so we'll put the current request on pause.
            await self.paused_requests.put([request, storage])
            return False

        await perform_http_request(client=client, url=request.url, method=request.method,
                                   headers=request.headers, data=request.content,
                                   store_at=storage)

        # Wake anything awaiting this request's completion event.
        request.completed.set()

    def recreate_connection_check(self):
        """Return True when request_worker should open another connection."""
        # TODO in the future this code should calculate whether or not to recreate a connection based on the amount
        # of connections currently available and the amount of traffic we're currently receiving. For now we'll
        # just have it recreate a connection anytime it's broken.
        return True
|