diff --git a/build.gradle b/build.gradle index 8ceed09..330ea06 100644 --- a/build.gradle +++ b/build.gradle @@ -18,6 +18,7 @@ dependencies { implementation 'org.bouncycastle:bcprov-jdk15on:1.70' implementation 'com.github.FireMasterK.NewPipeExtractor:NewPipeExtractor:8cf9a4aef0919df2ef1baafd30ab5bfefefc0844' implementation 'com.github.FireMasterK:nanojson:9f4af3b739cc13f3d0d9d4b758bbe2b2ae7119d7' + implementation 'com.nimbusds:oauth2-oidc-sdk:11.5' implementation 'com.fasterxml.jackson.core:jackson-core:2.15.2' implementation 'com.fasterxml.jackson.core:jackson-annotations:2.15.2' implementation 'com.fasterxml.jackson.core:jackson-databind:2.15.2' diff --git a/config.properties b/config.properties index 9b5fed0..f538750 100644 --- a/config.properties +++ b/config.properties @@ -82,3 +82,11 @@ hibernate.connection.password:changeme # Frontend configuration #frontend.statusPageUrl:https://kavin.rocks #frontend.donationUrl:https://kavin.rocks + +# Oidc configuration +#oidc.provider.INSERT_HERE.name:INSERT_HERE +#oidc.provider.INSERT_HERE.clientId:INSERT_HERE +#oidc.provider.INSERT_HERE.clientSecret:INSERT_HERE +#oidc.provider.INSERT_HERE.authUri:INSERT_HERE +#oidc.provider.INSERT_HERE.tokenUri:INSERT_HERE +#oidc.provider.INSERT_HERE.userinfoUri:INSERT_HERE diff --git a/src/main/java/me/kavin/piped/consts/Constants.java b/src/main/java/me/kavin/piped/consts/Constants.java index 516ece6..07c5b81 100644 --- a/src/main/java/me/kavin/piped/consts/Constants.java +++ b/src/main/java/me/kavin/piped/consts/Constants.java @@ -3,12 +3,15 @@ package me.kavin.piped.consts; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.json.JsonMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.fasterxml.jackson.databind.node.ObjectNode; import io.minio.MinioClient; import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectArrayList; import me.kavin.piped.utils.PageMixin; import me.kavin.piped.utils.RequestUtils; +import me.kavin.piped.utils.obj.OidcProvider; import me.kavin.piped.utils.resp.ListLinkHandlerMixin; import okhttp3.OkHttpClient; import okhttp3.brotli.BrotliInterceptor; @@ -23,9 +26,8 @@ import rocks.kavin.reqwest4j.ReqwestUtils; import java.io.File; import java.io.FileReader; -import java.net.InetSocketAddress; -import java.net.ProxySelector; import java.util.List; +import java.util.Map; import java.util.Properties; import java.util.regex.Pattern; @@ -102,6 +104,7 @@ public class Constants { public static final String YOUTUBE_COUNTRY; public static final String VERSION; + public static final List OIDC_PROVIDERS; public static final ObjectMapper mapper = JsonMapper.builder() .addMixIn(Page.class, PageMixin.class) @@ -168,12 +171,37 @@ public class Constants { MATRIX_SERVER = getProperty(prop, "MATRIX_SERVER", "https://matrix-client.matrix.org"); MATRIX_TOKEN = getProperty(prop, "MATRIX_TOKEN"); GEO_RESTRICTION_CHECKER_URL = getProperty(prop, "GEO_RESTRICTION_CHECKER_URL"); + + OIDC_PROVIDERS = new ObjectArrayList<>(); + + Map> oidcProviderConfig = new Object2ObjectOpenHashMap<>(); + ArrayNode providerNames = frontendProperties.putArray("oidcProviders"); prop.forEach((_key, _value) -> { String key = String.valueOf(_key), value = String.valueOf(_value); if (key.startsWith("hibernate")) hibernateProperties.put(key, value); else if (key.startsWith("frontend.")) frontendProperties.put(StringUtils.substringAfter(key, "frontend."), value); + else if (key.startsWith("oidc.provider")) { + String[] split = key.split("\\."); + if (split.length != 4) return; + oidcProviderConfig + .computeIfAbsent(split[2], k -> new Object2ObjectOpenHashMap<>()) + .put(split[3], value); + } + }); + oidcProviderConfig.forEach((provider, config) -> { + ObjectNode providerNode = frontendProperties.putObject(provider); + OIDC_PROVIDERS.add(new OidcProvider( + getRequiredMapValue(config, "name"), + getRequiredMapValue(config, "clientId"), + getRequiredMapValue(config, "clientSecret"), + getRequiredMapValue(config, "authUri"), + getRequiredMapValue(config, "tokenUri"), + getRequiredMapValue(config, "userinfoUri") + )); + providerNames.add(provider); + config.forEach(providerNode::put); }); frontendProperties.put("imageProxyUrl", IMAGE_PROXY_PART); frontendProperties.putArray("countries").addAll( @@ -230,4 +258,13 @@ public class Constants { return prop.getProperty(key, def); } + + private static String getRequiredMapValue(final Map map, Object key) { + String value = map.get(key); + if (StringUtils.isBlank(value)) { + System.err.println("Missing '" + key + "' in sub-configuration"); + System.exit(1); + } + return value; + } } diff --git a/src/main/java/me/kavin/piped/server/ServerLauncher.java b/src/main/java/me/kavin/piped/server/ServerLauncher.java index 21b2e2b..3098bfd 100644 --- a/src/main/java/me/kavin/piped/server/ServerLauncher.java +++ b/src/main/java/me/kavin/piped/server/ServerLauncher.java @@ -18,8 +18,11 @@ import me.kavin.piped.server.handlers.auth.AuthPlaylistHandlers; import me.kavin.piped.server.handlers.auth.FeedHandlers; import me.kavin.piped.server.handlers.auth.StorageHandlers; import me.kavin.piped.server.handlers.auth.UserHandlers; +import me.kavin.piped.utils.ErrorResponse; import me.kavin.piped.utils.*; import me.kavin.piped.utils.obj.MatrixHelper; +import me.kavin.piped.utils.obj.OidcData; +import me.kavin.piped.utils.obj.OidcProvider; import me.kavin.piped.utils.obj.federation.FederatedVideoInfo; import me.kavin.piped.utils.resp.*; import org.apache.commons.lang3.StringUtils; @@ -33,10 +36,9 @@ import org.xml.sax.InputSource; import java.io.ByteArrayInputStream; import java.net.InetSocketAddress; -import java.util.List; +import java.net.URI; import java.util.Objects; import java.util.concurrent.Executor; -import java.util.concurrent.TimeUnit; import static io.activej.config.converter.ConfigConverters.ofInetSocketAddress; import static io.activej.http.HttpHeaders.*; @@ -271,6 +273,22 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher { } catch (Exception e) { return getErrorResponse(e, request.getPath()); } + })).map(GET, "/oidc/:provider/:function", AsyncServlet.ofBlocking(executor, request -> { + try { + String function = request.getPathParameter("function"); + OidcProvider provider = getOidcProvider(request.getPathParameter("provider")); + if (provider == null) + return HttpResponse.ofCode(500).withHtml("Can't find the provider on the server"); + + return switch (function) { + case "login" -> UserHandlers.oidcLoginResponse(provider, request.getQueryParameter("redirect")); + case "callback" -> UserHandlers.oidcCallbackResponse(provider, URI.create(request.getFullUrl())); + case "delete" -> UserHandlers.oidcDeleteResponse(provider, URI.create(request.getFullUrl())); + default -> HttpResponse.ofCode(500).withHtml("Invalid function `" + function + "`"); + }; + } catch (Exception e) { + return getErrorResponse(e, request.getPath()); + } })).map(POST, "/login", AsyncServlet.ofBlocking(executor, request -> { try { LoginRequest body = mapper.readValue(request.loadBody().getResult().asArray(), @@ -517,6 +535,15 @@ public class ServerLauncher extends MultithreadedHttpServerLauncher { return new CustomServletDecorator(router); } + private static OidcProvider getOidcProvider(String provider) { + for (int i = 0; i < Constants.OIDC_PROVIDERS.size(); i++) { + OidcProvider curr = Constants.OIDC_PROVIDERS.get(i); + if (curr == null || !curr.name.equals(provider)) continue; + return curr; + } + return null; + } + private static String[] getArray(String s) { if (s == null) { diff --git a/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java b/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java index 3e0bfe5..7fb56ca 100644 --- a/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java +++ b/src/main/java/me/kavin/piped/server/handlers/auth/UserHandlers.java @@ -1,6 +1,14 @@ package me.kavin.piped.server.handlers.auth; import com.fasterxml.jackson.core.JsonProcessingException; +import com.nimbusds.jwt.JWTClaimsSet; +import com.nimbusds.oauth2.sdk.*; +import com.nimbusds.oauth2.sdk.auth.ClientAuthentication; +import com.nimbusds.oauth2.sdk.auth.ClientSecretBasic; +import com.nimbusds.oauth2.sdk.id.State; +import com.nimbusds.openid.connect.sdk.*; +import com.nimbusds.openid.connect.sdk.claims.UserInfo; +import io.activej.http.HttpResponse; import jakarta.persistence.criteria.CriteriaBuilder; import jakarta.persistence.criteria.CriteriaQuery; import jakarta.persistence.criteria.Root; @@ -9,6 +17,9 @@ import me.kavin.piped.utils.DatabaseHelper; import me.kavin.piped.utils.DatabaseSessionFactory; import me.kavin.piped.utils.ExceptionHandler; import me.kavin.piped.utils.RequestUtils; +import me.kavin.piped.utils.obj.OidcData; +import me.kavin.piped.utils.obj.OidcProvider; +import me.kavin.piped.utils.obj.db.OidcUserData; import me.kavin.piped.utils.obj.db.User; import me.kavin.piped.utils.resp.*; import org.apache.commons.codec.digest.DigestUtils; @@ -19,6 +30,10 @@ import org.springframework.security.crypto.argon2.Argon2PasswordEncoder; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import java.io.IOException; +import java.net.URI; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; import java.util.Set; import java.util.UUID; @@ -27,6 +42,7 @@ import static me.kavin.piped.consts.Constants.mapper; public class UserHandlers { private static final Argon2PasswordEncoder argon2PasswordEncoder = Argon2PasswordEncoder.defaultsForSpringSecurity_v5_8(); private static final BCryptPasswordEncoder bcryptPasswordEncoder = new BCryptPasswordEncoder(); + public static final Map PENDING_OIDC = new HashMap<>(); public static byte[] registerResponse(String user, String pass) throws Exception { @@ -109,10 +125,169 @@ public class UserHandlers { } } - public static byte[] deleteUserResponse(String session, String pass) throws IOException { + public static HttpResponse oidcLoginResponse(OidcProvider provider, String redirectUri) throws Exception{ + if (StringUtils.isBlank(redirectUri)) { + return HttpResponse.ofCode(400).withHtml("redirect is a required parameter"); + } - if (StringUtils.isBlank(session) || StringUtils.isBlank(pass)) - ExceptionHandler.throwErrorResponse(new InvalidRequestResponse("session and password are required parameters")); + URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback"); + OidcData data = new OidcData(redirectUri); + String state = data.getState(); + + PENDING_OIDC.put(state, data); + + AuthenticationRequest oidcRequest = new AuthenticationRequest.Builder( + new ResponseType("code"), + new Scope("openid"), + provider.clientID, callback).endpointURI(provider.authUri) + .state(new State(state)).nonce(data.nonce).build(); + + if (redirectUri.equals(Constants.FRONTEND_URL + "/login")) { + return HttpResponse.redirect302(oidcRequest.toURI().toString()); + } + return HttpResponse.ok200().withHtml( + "" + + "

Warning:

You are trying to give
" +
+                        redirectUri +
+                        "
access to your Piped account. If you wish to continue click " + + "here"); + } + public static HttpResponse oidcCallbackResponse(OidcProvider provider, URI requestUri) throws Exception { + ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret); + + AuthenticationSuccessResponse sr = parseOidcUri(requestUri); + + OidcData data = PENDING_OIDC.get(sr.getState().toString()); + if (data == null) { + return HttpResponse.ofCode(400).withHtml( + "Your oidc provider sent invalid state data. Try again or contact your oidc admin" + ); + } + + URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/callback"); + AuthorizationCode code = sr.getAuthorizationCode(); + AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback); + + + TokenRequest tokenReq = new TokenRequest(provider.tokenUri, clientAuth, codeGrant); + OIDCTokenResponse tokenResponse = (OIDCTokenResponse) OIDCTokenResponseParser.parse(tokenReq.toHTTPRequest().send()); + + if (!tokenResponse.indicatesSuccess()) { + TokenErrorResponse errorResponse = tokenResponse.toErrorResponse(); + return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription()); + } + + OIDCTokenResponse successResponse = tokenResponse.toSuccessResponse(); + + if (data.isInvalidNonce((String) successResponse.getOIDCTokens().getIDToken().getJWTClaimsSet().getClaim("nonce"))) { + return HttpResponse.ofCode(400).withHtml( + "Your oidc provider sent an invalid nonce. Try again or contact your oidc admin" + ); + } + + UserInfoRequest ur = new UserInfoRequest(provider.userinfoUri, successResponse.getOIDCTokens().getBearerAccessToken()); + UserInfoResponse userInfoResponse = UserInfoResponse.parse(ur.toHTTPRequest().send()); + + if (!userInfoResponse.indicatesSuccess()) { + System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getCode()); + System.out.println(userInfoResponse.toErrorResponse().getErrorObject().getDescription()); + return HttpResponse.ofCode(500).withHtml( + "The userinfo endpoint returned an error. Please try again or contact your oidc admin\n\n" + + userInfoResponse.toErrorResponse().getErrorObject().getDescription()); + } + + UserInfo userInfo = userInfoResponse.toSuccessResponse().getUserInfo(); + + + String uid = userInfo.getSubject().toString(); + String sessionId; + try (Session s = DatabaseSessionFactory.createSession()) { + // TODO: Add oidc provider to database + String dbName = provider + "-" + uid; + CriteriaBuilder cb = s.getCriteriaBuilder(); + CriteriaQuery cr = cb.createQuery(User.class); + Root root = cr.from(User.class); + cr.select(root).where(root.get("username").in( + dbName + )); + + User dbuser = s.createQuery(cr).uniqueResult(); + + if (dbuser == null) { + User newuser = new User(dbName, "", Set.of()); + + var tr = s.beginTransaction(); + s.persist(newuser); + tr.commit(); + + + sessionId = newuser.getSessionId(); + } else sessionId = dbuser.getSessionId(); + } + return HttpResponse.redirect302(data.data + "?session=" + sessionId); + + } + + public static HttpResponse oidcDeleteResponse(OidcProvider provider, URI requestUri) throws Exception { + ClientAuthentication clientAuth = new ClientSecretBasic(provider.clientID, provider.clientSecret); + + AuthenticationSuccessResponse sr = parseOidcUri(requestUri); + + OidcData data = UserHandlers.PENDING_OIDC.get(sr.getState().toString()); + if (data == null) { + return HttpResponse.ofCode(400).withHtml( + "Your oidc provider sent invalid state data. Try again or contact your oidc admin" + ); + } + + long start = Long.parseLong(data.data.split("\\|")[1]); + String session = data.data.split("\\|")[0]; + + URI callback = new URI(Constants.PUBLIC_URL + "/oidc/" + provider.name + "/delete"); + AuthorizationCode code = sr.getAuthorizationCode(); + AuthorizationGrant codeGrant = new AuthorizationCodeGrant(code, callback); + + + TokenRequest tokenRequest = new TokenRequest(provider.tokenUri, clientAuth, codeGrant); + TokenResponse tokenResponse = OIDCTokenResponseParser.parse(tokenRequest.toHTTPRequest().send()); + + if (!tokenResponse.indicatesSuccess()) { + TokenErrorResponse errorResponse = tokenResponse.toErrorResponse(); + return HttpResponse.ofCode(500).withHtml("Failure while trying to request token:\n\n" + errorResponse.getErrorObject().getDescription()); + } + + OIDCTokenResponse successResponse = (OIDCTokenResponse) tokenResponse.toSuccessResponse(); + + JWTClaimsSet claims = successResponse.getOIDCTokens().getIDToken().getJWTClaimsSet(); + + if (data.isInvalidNonce((String) claims.getClaim("nonce"))) { + return HttpResponse.ofCode(400).withHtml( + "Your oidc provider sent an invalid nonce. Please try again or contact your oidc admin." + ); + } + + long authTime = (long) claims.getClaim("auth_time"); + + if (authTime < start) { + return HttpResponse.ofCode(500).withHtml( + "Your oidc provider didn't verify your identity. Please try again or contact your oidc admin." + ); + } + + try (Session s = DatabaseSessionFactory.createSession()) { + + var tr = s.beginTransaction(); + s.remove(DatabaseHelper.getUserFromSession(session)); + tr.commit(); + } + return HttpResponse.redirect302(Constants.FRONTEND_URL + "/preferences?deleted=" + session); + } + + public static byte[] deleteUserResponse(String session, String pass) throws IOException { + if (StringUtils.isBlank(session)) + ExceptionHandler.throwErrorResponse(new InvalidRequestResponse("session is a required parameter")); try (Session s = DatabaseSessionFactory.createSession()) { User user = DatabaseHelper.getUserFromSession(session); @@ -122,6 +297,31 @@ public class UserHandlers { String hash = user.getPassword(); + if (hash.isEmpty()) { + + CriteriaBuilder cb = s.getCriteriaBuilder(); + CriteriaQuery cr = cb.createQuery(OidcUserData.class); + Root root = cr.from(OidcUserData.class); + cr.select(root).where(cb.equal(root.get("user"), user.getId())); + + OidcUserData oidcUserData = s.createQuery(cr).uniqueResult(); + + //TODO: Get user from oidc table and lookup provider + OidcProvider provider = Constants.OIDC_PROVIDERS.get(0); + URI callback = URI.create(String.format("%s/oidc/%s/delete", Constants.PUBLIC_URL, provider.name)); + OidcData data = new OidcData(session + "|" + Instant.now().getEpochSecond()); + String state = data.getState(); + PENDING_OIDC.put(state, data); + + AuthenticationRequest oidcRequest = new AuthenticationRequest.Builder( + new ResponseType("code"), + new Scope("openid"), provider.clientID, callback).endpointURI(provider.authUri) + .state(new State(state)).nonce(data.nonce).maxAge(0).build(); + + + return mapper.writeValueAsBytes(mapper.createObjectNode() + .put("redirect", oidcRequest.toURI().toString())); + } if (!hashMatch(hash, pass)) ExceptionHandler.throwErrorResponse(new IncorrectCredentialsResponse()); @@ -133,6 +333,7 @@ public class UserHandlers { } } + public static byte[] logoutResponse(String session) throws JsonProcessingException { if (StringUtils.isBlank(session)) @@ -151,4 +352,14 @@ public class UserHandlers { return Constants.mapper.writeValueAsBytes(new AuthenticationFailureResponse()); } + + private static AuthenticationSuccessResponse parseOidcUri(URI uri) throws Exception { + AuthenticationResponse response = AuthenticationResponseParser.parse(uri); + + if (response instanceof AuthenticationErrorResponse) { + System.err.println(response.toErrorResponse().getErrorObject()); + throw new Exception(response.toErrorResponse().getErrorObject().toString()); + } + return response.toSuccessResponse(); + } } diff --git a/src/main/java/me/kavin/piped/utils/obj/OidcData.java b/src/main/java/me/kavin/piped/utils/obj/OidcData.java new file mode 100644 index 0000000..f1bfd6c --- /dev/null +++ b/src/main/java/me/kavin/piped/utils/obj/OidcData.java @@ -0,0 +1,36 @@ +package me.kavin.piped.utils.obj; + +import com.nimbusds.openid.connect.sdk.Nonce; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; + +public class OidcData { + public final Nonce nonce; + + public String data; + + public OidcData(String data) { + this.nonce = new Nonce(); + this.data = data; + } + + public boolean isInvalidNonce(String nonce) { + return !nonce.equals(this.nonce.toString()); + } + + public String getState() { + String value = nonce + data; + + try { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] hash = md.digest(value.getBytes(StandardCharsets.UTF_8)); + return Base64.getEncoder().encodeToString(hash); + + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException("SHA-256 not supported", e); + } + } +} diff --git a/src/main/java/me/kavin/piped/utils/obj/OidcProvider.java b/src/main/java/me/kavin/piped/utils/obj/OidcProvider.java new file mode 100644 index 0000000..aedce63 --- /dev/null +++ b/src/main/java/me/kavin/piped/utils/obj/OidcProvider.java @@ -0,0 +1,30 @@ +package me.kavin.piped.utils.obj; + +import com.nimbusds.oauth2.sdk.auth.Secret; +import com.nimbusds.oauth2.sdk.id.ClientID; + +import java.net.URI; +import java.net.URISyntaxException; + +public class OidcProvider { + public String name; + public ClientID clientID; + public Secret clientSecret; + public URI authUri; + public URI tokenUri; + public URI userinfoUri; + + public OidcProvider(String name, String clientID, String clientSecret, String authUri, String tokenUri, String userinfoUri) { + this.name = name; + this.clientID = new ClientID(clientID); + this.clientSecret = new Secret(clientSecret); + try { + this.authUri = new URI(authUri); + this.tokenUri = new URI(tokenUri); + this.userinfoUri = new URI(userinfoUri); + } catch (URISyntaxException e) { + System.err.println("Malformed URI for oidc provider '" + name + "' found."); + System.exit(1); + } + } +} diff --git a/src/main/java/me/kavin/piped/utils/obj/db/OidcUserData.java b/src/main/java/me/kavin/piped/utils/obj/db/OidcUserData.java new file mode 100644 index 0000000..e23d621 --- /dev/null +++ b/src/main/java/me/kavin/piped/utils/obj/db/OidcUserData.java @@ -0,0 +1,19 @@ +package me.kavin.piped.utils.obj.db; + +import jakarta.persistence.*; + +@Entity +@Table(name = "oidc_user_data") +public class OidcUserData { + + @Column(unique = true) + @Id + private String sub; + + @OneToOne + private User user; + + private String provider; + + +} diff --git a/src/main/java/me/kavin/piped/utils/obj/db/User.java b/src/main/java/me/kavin/piped/utils/obj/db/User.java index 2b8f603..4a3e343 100644 --- a/src/main/java/me/kavin/piped/utils/obj/db/User.java +++ b/src/main/java/me/kavin/piped/utils/obj/db/User.java @@ -21,7 +21,7 @@ public class User implements Serializable { @Column(name = "id") private long id; - @Column(name = "username", unique = true, length = 24) + @Column(name = "username", unique = true, length = 32) private String username; @Column(name = "password", columnDefinition = "text")