Implemented requested changes and added 'auth_enforce_source' option

This commit is contained in:
soraefir 2023-07-13 12:48:18 +02:00
parent 4d14789e7b
commit 8953c105be
No known key found for this signature in database
GPG key ID: A362EA0491E2EEA0
5 changed files with 68 additions and 82 deletions

View file

@ -142,6 +142,7 @@ class Config
property use_quic : Bool = false property use_quic : Bool = false
property auth_type : Array(String) = ["invidious", "oauth"] property auth_type : Array(String) = ["invidious", "oauth"]
property auth_enforce_source : Bool = true
property oauth = {} of String => OAuthConfig property oauth = {} of String => OAuthConfig
# Saved cookies in "name1=value1; name2=value2..." format # Saved cookies in "name1=value1; name2=value2..." format
@ -170,6 +171,14 @@ class Config
end end
end end
def auth_oauth_enabled?
return (@auth_type.find(&.== "oauth") && @oauth.size > 0)
end
def auth_internal_enabled?
return (@auth_type.find(&.== "invidious"))
end
def self.load def self.load
# Load config from file or YAML string env var # Load config from file or YAML string env var
env_config_file = "INVIDIOUS_CONFIG_FILE" env_config_file = "INVIDIOUS_CONFIG_FILE"

View file

@ -11,7 +11,7 @@ module Invidious::OAuthHelper
end end
end end
def get(key) def make_client(key)
if HOST_URL == "" if HOST_URL == ""
raise Exception.new("Missing domain and port configuration") raise Exception.new("Missing domain and port configuration")
end end
@ -27,24 +27,22 @@ module Invidious::OAuthHelper
) )
end end
def get_uri_host_pair(host, uri) def get_uri_host_pair(host, url)
if (uri.starts_with?("https://")) if (url.starts_with?(/https*\:\/\//))
res = uri.gsub(/https*\:\/\//, "").split('/', 2) uri = URI.parse url
[res[0], "/" + res[1]] [uri.host || host, uri.path || "/"]
else else
[host, uri] [host, url]
end end
end end
def get_info(key, token) def get_info(key, token)
provider = self.get_provider(key) provider = self.get_provider(key)
uri_host_pair = self.get_uri_host_pair(provider.host, provider.info_uri) uri_host_pair = self.get_uri_host_pair(provider.host, provider.info_uri)
LOGGER.info(uri_host_pair[0] + " " + uri_host_pair[1])
client = HTTP::Client.new(uri_host_pair[0], tls: true) client = HTTP::Client.new(uri_host_pair[0], tls: true)
token.authenticate(client) token.authenticate(client)
response = client.get uri_host_pair[1] response = client.get uri_host_pair[1]
client.close client.close
LOGGER.info(response.body)
response.body response.body
end end

View file

@ -3,11 +3,10 @@
module Invidious::Routes::Login module Invidious::Routes::Login
def self.login_page(env) def self.login_page(env)
locale = env.get("preferences").as(Preferences).locale locale = env.get("preferences").as(Preferences).locale
referer = get_referer(env, "/feed/subscriptions")
user = env.get? "user" user = env.get? "user"
referer = get_referer(env, "/feed/subscriptions")
return env.redirect referer if user return env.redirect referer if user
if !CONFIG.login_enabled if !CONFIG.login_enabled
@ -19,18 +18,14 @@ module Invidious::Routes::Login
captcha = nil captcha = nil
account_type = env.params.query["type"]? account_type = env.params.query["type"]?
account_type ||= "invidious" account_type ||= ""
if CONFIG.auth_type.find(&.== account_type).nil? if CONFIG.auth_type.size == 0
if CONFIG.auth_type.size == 0 return error_template(401, "No authentication backend enabled.")
account_type = "invidious" elsif CONFIG.auth_type.find(&.== account_type).nil? && CONFIG.auth_type.size == 1
else account_type = CONFIG.auth_type[0]
account_type = CONFIG.auth_type[0]
end
end end
oauth = CONFIG.auth_type.find(&.== "oauth") && (CONFIG.oauth.size > 0)
captcha_type = env.params.query["captcha"]? captcha_type = env.params.query["captcha"]?
captcha_type ||= "image" captcha_type ||= "image"
@ -39,36 +34,36 @@ module Invidious::Routes::Login
def self.login_oauth(env) def self.login_oauth(env)
locale = env.get("preferences").as(Preferences).locale locale = env.get("preferences").as(Preferences).locale
referer = get_referer(env, "/feed/subscriptions") referer = get_referer(env, "/feed/subscriptions")
authorization_code = env.params.query["code"]? authorization_code = env.params.query["code"]?
provider_k = env.params.url["provider"] provider_k = env.params.url["provider"]
if authorization_code
begin
token = OAuthHelper.get(provider_k).get_access_token_using_authorization_code(authorization_code)
email = OAuthHelper.info_field(provider_k, token)
if email if authorization_code.nil?
user = Invidious::Database::Users.select(email: email)
if user
user_flow_existing(env, email)
else
user_flow_new(env, email, nil)
end
end
rescue ex
return error_template(500, "Internal Error" + (ex.message || ""))
end
else
return error_template(403, "Missing Authorization Code") return error_template(403, "Missing Authorization Code")
end end
begin
token = OAuthHelper.make_client(provider_k).get_access_token_using_authorization_code(authorization_code)
if email = OAuthHelper.info_field(provider_k, token)
if user = Invidious::Database::Users.select(email: email)
if CONFIG.auth_enforce_source && user.password != ("oauth:" + provider_k)
return error_template(401, "Wrong provider")
else
user_flow_existing(env, email)
end
else
user_flow_new(env, email, nil, "oauth:" + provider_k)
end
end
rescue ex
return error_template(500, "Internal Error")
end
env.redirect referer env.redirect referer
end end
def self.login(env) def self.login(env)
locale = env.get("preferences").as(Preferences).locale locale = env.get("preferences").as(Preferences).locale
referer = get_referer(env, "/feed/subscriptions") referer = get_referer(env, "/feed/subscriptions")
if !CONFIG.login_enabled if !CONFIG.login_enabled
@ -78,23 +73,20 @@ module Invidious::Routes::Login
# https://stackoverflow.com/a/574698 # https://stackoverflow.com/a/574698
email = env.params.body["email"]?.try &.downcase.byte_slice(0, 254) email = env.params.body["email"]?.try &.downcase.byte_slice(0, 254)
password = env.params.body["password"]? password = env.params.body["password"]?
oauth = CONFIG.auth_type.find(&.== "oauth") && (CONFIG.oauth.size > 0)
account_type = env.params.query["type"]? account_type = env.params.query["type"]?
account_type ||= "invidious" account_type ||= ""
if CONFIG.auth_type.size == 0 if CONFIG.auth_type.size == 0
return error_template(401, "No authentication backend enabled.") return error_template(401, "No authentication backend enabled.")
end elsif CONFIG.auth_type.find(&.== account_type).nil? && CONFIG.auth_type.size == 1
if CONFIG.auth_type.find(&.== account_type).nil?
account_type = CONFIG.auth_type[0] account_type = CONFIG.auth_type[0]
end end
case account_type case account_type
when "oauth" when "oauth"
provider_k = env.params.body["provider"] provider_k = env.params.body["provider"]
env.redirect OAuthHelper.get(provider_k).get_authorize_uri("openid email profile") env.redirect OAuthHelper.make_client(provider_k).get_authorize_uri("openid email profile")
when "saml" when "saml"
return error_template(501, "Not implemented") return error_template(501, "Not implemented")
when "ldap" when "ldap"
@ -108,18 +100,14 @@ module Invidious::Routes::Login
return error_template(401, "Password is a required field") return error_template(401, "Password is a required field")
end end
user = Invidious::Database::Users.select(email: email) if user = Invidious::Database::Users.select(email: email)
if user.password.not_nil!.starts_with? "oauth"
if user return error_template(401, "Wrong provider")
if Crypto::Bcrypt::Password.new(user.password.not_nil!).verify(password.byte_slice(0, 55)) elsif Crypto::Bcrypt::Password.new(user.password.not_nil!).verify(password.byte_slice(0, 55))
sid = Base64.urlsafe_encode(Random::Secure.random_bytes(32)) user_flow_existing(env, email)
Invidious::Database::SessionIDs.insert(sid, email)
env.response.cookies["SID"] = Invidious::User::Cookies.sid(CONFIG.domain, sid)
else else
return error_template(401, "Wrong username or password") return error_template(401, "Wrong username or password")
end end
user_flow_existing(env, email)
else else
if !CONFIG.registration_enabled if !CONFIG.registration_enabled
return error_template(400, "Registration has been disabled by administrator.") return error_template(400, "Registration has been disabled by administrator.")
@ -196,7 +184,7 @@ module Invidious::Routes::Login
end end
end end
end end
user_flow_new(env, email, password) user_flow_new(env, email, password, "internal")
end end
env.redirect referer env.redirect referer
@ -249,12 +237,12 @@ module Invidious::Routes::Login
end end
end end
def self.user_flow_new(env, email, password) def self.user_flow_new(env, email, password, provider)
sid = Base64.urlsafe_encode(Random::Secure.random_bytes(32)) sid = Base64.urlsafe_encode(Random::Secure.random_bytes(32))
if password if provider == "internal"
user, sid = create_user(sid, email, password) user, sid = create_internal_user(sid, email, password)
else else
user, sid = create_user(sid, email) user, sid = create_user(sid, email, provider)
end end
if language_header = env.request.headers["Accept-Language"]? if language_header = env.request.headers["Accept-Language"]?

View file

@ -3,7 +3,7 @@ require "crypto/bcrypt/password"
# Materialized views may not be defined using bound parameters (`$1` as used elsewhere) # Materialized views may not be defined using bound parameters (`$1` as used elsewhere)
MATERIALIZED_VIEW_SQL = ->(email : String) { "SELECT cv.* FROM channel_videos cv WHERE EXISTS (SELECT subscriptions FROM users u WHERE cv.ucid = ANY (u.subscriptions) AND u.email = E'#{email.gsub({'\'' => "\\'", '\\' => "\\\\"})}') ORDER BY published DESC" } MATERIALIZED_VIEW_SQL = ->(email : String) { "SELECT cv.* FROM channel_videos cv WHERE EXISTS (SELECT subscriptions FROM users u WHERE cv.ucid = ANY (u.subscriptions) AND u.email = E'#{email.gsub({'\'' => "\\'", '\\' => "\\\\"})}') ORDER BY published DESC" }
def create_user(sid, email) def create_user(sid, email, password)
token = Base64.urlsafe_encode(Random::Secure.random_bytes(32)) token = Base64.urlsafe_encode(Random::Secure.random_bytes(32))
user = Invidious::User.new({ user = Invidious::User.new({
@ -12,7 +12,7 @@ def create_user(sid, email)
subscriptions: [] of String, subscriptions: [] of String,
email: email, email: email,
preferences: Preferences.new(CONFIG.default_user_preferences.to_tuple), preferences: Preferences.new(CONFIG.default_user_preferences.to_tuple),
password: nil, password: password,
token: token, token: token,
watched: [] of String, watched: [] of String,
feed_needs_update: true, feed_needs_update: true,
@ -21,23 +21,9 @@ def create_user(sid, email)
return user, sid return user, sid
end end
def create_user(sid, email, password) def create_internal_user(sid, email, password)
password = Crypto::Bcrypt::Password.create(password, cost: 10) password = Crypto::Bcrypt::Password.create(password.not_nil!, cost: 10)
token = Base64.urlsafe_encode(Random::Secure.random_bytes(32)) create_user(sid, email, password.to_s)
user = Invidious::User.new({
updated: Time.utc,
notifications: [] of String,
subscriptions: [] of String,
email: email,
preferences: Preferences.new(CONFIG.default_user_preferences.to_tuple),
password: password.to_s,
token: token,
watched: [] of String,
feed_needs_update: true,
})
return user, sid
end end
def get_subscription_feed(user, max_results = 40, page = 1) def get_subscription_feed(user, max_results = 40, page = 1)

View file

@ -11,14 +11,14 @@
<form class="pure-form pure-form-stacked" action="/login?referer=<%= URI.encode_www_form(referer) %>&type=oauth" method="post"> <form class="pure-form pure-form-stacked" action="/login?referer=<%= URI.encode_www_form(referer) %>&type=oauth" method="post">
<fieldset> <fieldset>
<select name="provider" id="provider"> <select name="provider" id="provider">
<% CONFIG.oauth.each_key do |k| %> <% CONFIG.oauth.each_key do |key| %>
<option value="<%= k %>"><%= k %></option> <option value="<%= key %>"><%= key %></option>
<% end %> <% end %>
</select> </select>
<button type="submit" class="pure-button pure-button-primary"><%= translate(locale, "Sign In via OAuth") %></button> <button type="submit" class="pure-button pure-button-primary"><%= translate(locale, "Sign In via OAuth") %></button>
</fieldset> </fieldset>
</form> </form>
<% else # "invidious" %> <% when "invidious" %>
<form class="pure-form pure-form-stacked" action="/login?referer=<%= URI.encode_www_form(referer) %>&type=invidious" method="post"> <form class="pure-form pure-form-stacked" action="/login?referer=<%= URI.encode_www_form(referer) %>&type=invidious" method="post">
<fieldset> <fieldset>
<% if email %> <% if email %>
@ -79,11 +79,16 @@
<%= translate(locale, "Sign In") %>/<%= translate(locale, "Register") %> <%= translate(locale, "Sign In") %>/<%= translate(locale, "Register") %>
</button> </button>
<% end %> <% end %>
<% if oauth %>
<a class="pure-button pure-button-secondary" href="/login?referer=<%= URI.encode_www_form(referer) %>&type=oauth">OAuth</a>
<% end %>
</fieldset> </fieldset>
</form> </form>
<% else %>
<% if CONFIG.auth_internal_enabled? %>
<a class="pure-button pure-button-secondary" href="/login?referer=<%= URI.encode_www_form(referer) %>&type=invidious">Internal</a>
<% end %>
<% if CONFIG.auth_oauth_enabled? %>
<a class="pure-button pure-button-secondary" href="/login?referer=<%= URI.encode_www_form(referer) %>&type=oauth">OAuth</a>
<% end %>
<label></label>
<% end %> <% end %>
</div> </div>
</div> </div>