diff --git a/src/invidious/helpers/helpers.cr b/src/invidious/helpers/helpers.cr index ace2a6f7..9cefcf14 100644 --- a/src/invidious/helpers/helpers.cr +++ b/src/invidious/helpers/helpers.cr @@ -87,12 +87,40 @@ end struct Config module ConfigPreferencesConverter + def self.to_yaml(value : Preferences, yaml : YAML::Nodes::Builder) + value.to_yaml(yaml) + end + def self.from_yaml(ctx : YAML::ParseContext, node : YAML::Nodes::Node) : Preferences Preferences.new(*ConfigPreferences.new(ctx, node).to_tuple) end + end - def self.to_yaml(value : Preferences, yaml : YAML::Nodes::Builder) - value.to_yaml(yaml) + module FamilyConverter + def self.to_yaml(value : Socket::Family, yaml : YAML::Nodes::Builder) + case value + when Socket::Family::UNSPEC + yaml.scalar nil + when Socket::Family::INET + yaml.scalar "ipv4" + when Socket::Family::INET6 + yaml.scalar "ipv6" + end + end + + def self.from_yaml(ctx : YAML::ParseContext, node : YAML::Nodes::Node) : Socket::Family + if node.is_a?(YAML::Nodes::Scalar) + case node.value.downcase + when "ipv4" + Socket::Family::INET + when "ipv6" + Socket::Family::INET6 + else + Socket::Family::UNSPEC + end + else + node.raise "Expected scalar, not #{node.class}" + end end end @@ -131,12 +159,13 @@ struct Config default: Preferences.new(*ConfigPreferences.from_yaml("").to_tuple), converter: ConfigPreferencesConverter, }, - dmca_content: {type: Array(String), default: [] of String}, # For compliance with DMCA, disables download widget using list of video IDs - check_tables: {type: Bool, default: false}, # Check table integrity, automatically try to add any missing columns, create tables, etc. - cache_annotations: {type: Bool, default: false}, # Cache annotations requested from IA, will not cache empty annotations or annotations that only contain cards - banner: {type: String?, default: nil}, # Optional banner to be displayed along top of page for announcements, etc. - hsts: {type: Bool?, default: true}, # Enables 'Strict-Transport-Security'. Ensure that `domain` and all subdomains are served securely - disable_proxy: {type: Bool? | Array(String)?, default: false}, # Disable proxying server-wide: options: 'dash', 'livestreams', 'downloads', 'local' + dmca_content: {type: Array(String), default: [] of String}, # For compliance with DMCA, disables download widget using list of video IDs + check_tables: {type: Bool, default: false}, # Check table integrity, automatically try to add any missing columns, create tables, etc. + cache_annotations: {type: Bool, default: false}, # Cache annotations requested from IA, will not cache empty annotations or annotations that only contain cards + banner: {type: String?, default: nil}, # Optional banner to be displayed along top of page for announcements, etc. + hsts: {type: Bool?, default: true}, # Enables 'Strict-Transport-Security'. Ensure that `domain` and all subdomains are served securely + disable_proxy: {type: Bool? | Array(String)?, default: false}, # Disable proxying server-wide: options: 'dash', 'livestreams', 'downloads', 'local' + force_resolve: {type: Socket::Family, default: Socket::Family::UNSPEC, converter: FamilyConverter}, # Connect to YouTube over 'ipv6', 'ipv4'. Will sometimes resolve fix issues with rate-limiting (see https://github.com/ytdl-org/youtube-dl/issues/21729) }) end @@ -650,48 +679,6 @@ def cache_annotation(db, id, annotations) end end -def proxy_file(response, env) - if response.headers.includes_word?("Content-Encoding", "gzip") - Gzip::Writer.open(env.response) do |deflate| - response.pipe(deflate) - end - elsif response.headers.includes_word?("Content-Encoding", "deflate") - Flate::Writer.open(env.response) do |deflate| - response.pipe(deflate) - end - else - response.pipe(env.response) - end -end - -class HTTP::Client::Response - def pipe(io) - HTTP.serialize_body(io, headers, @body, @body_io, @version) - end -end - -# Supports serialize_body without first writing headers -module HTTP - def self.serialize_body(io, headers, body, body_io, version) - if body - io << body - elsif body_io - content_length = content_length(headers) - if content_length - copied = IO.copy(body_io, io) - if copied != content_length - raise ArgumentError.new("Content-Length header is #{content_length} but body had #{copied} bytes") - end - elsif Client::Response.supports_chunked?(version) - headers["Transfer-Encoding"] = "chunked" - serialize_chunked_body(io, body_io) - else - io << body - end - end - end -end - def create_notification_stream(env, config, kemal_config, decrypt_function, topics, connection_channel) connection = Channel(PQ::Notification).new(8) connection_channel.send({true, connection}) @@ -834,3 +821,79 @@ def extract_initial_data(body) return JSON.parse(initial_data) end end + +def proxy_file(response, env) + if response.headers.includes_word?("Content-Encoding", "gzip") + Gzip::Writer.open(env.response) do |deflate| + response.pipe(deflate) + end + elsif response.headers.includes_word?("Content-Encoding", "deflate") + Flate::Writer.open(env.response) do |deflate| + response.pipe(deflate) + end + else + response.pipe(env.response) + end +end + +class HTTP::Client::Response + def pipe(io) + HTTP.serialize_body(io, headers, @body, @body_io, @version) + end +end + +# Supports serialize_body without first writing headers +module HTTP + def self.serialize_body(io, headers, body, body_io, version) + if body + io << body + elsif body_io + content_length = content_length(headers) + if content_length + copied = IO.copy(body_io, io) + if copied != content_length + raise ArgumentError.new("Content-Length header is #{content_length} but body had #{copied} bytes") + end + elsif Client::Response.supports_chunked?(version) + headers["Transfer-Encoding"] = "chunked" + serialize_chunked_body(io, body_io) + else + io << body + end + end + end +end + +class HTTP::Client + property family : Socket::Family = Socket::Family::UNSPEC + + private def socket + socket = @socket + return socket if socket + + hostname = @host.starts_with?('[') && @host.ends_with?(']') ? @host[1..-2] : @host + socket = TCPSocket.new hostname, @port, @dns_timeout, @connect_timeout, @family + socket.read_timeout = @read_timeout if @read_timeout + socket.sync = false + + {% if !flag?(:without_openssl) %} + if tls = @tls + socket = OpenSSL::SSL::Socket::Client.new(socket, context: tls, sync_close: true, hostname: @host) + end + {% end %} + + @socket = socket + end +end + +class TCPSocket + def initialize(host, port, dns_timeout = nil, connect_timeout = nil, family = Socket::Family::UNSPEC) + Addrinfo.tcp(host, port, timeout: dns_timeout, family: family) do |addrinfo| + super(addrinfo.family, addrinfo.type, addrinfo.protocol) + connect(addrinfo, timeout: connect_timeout) do |error| + close + error + end + end + end +end diff --git a/src/invidious/helpers/utils.cr b/src/invidious/helpers/utils.cr index 9ce8efdb..b7deae76 100644 --- a/src/invidious/helpers/utils.cr +++ b/src/invidious/helpers/utils.cr @@ -20,6 +20,7 @@ end def make_client(url : URI, region = nil) client = HTTPClient.new(url) + client.family = CONFIG.force_resolve client.read_timeout = 15.seconds client.connect_timeout = 15.seconds