diff --git a/src/main/java/me/kavin/piped/utils/DatabaseHelper.java b/src/main/java/me/kavin/piped/utils/DatabaseHelper.java index d70bf05..c4eaa13 100644 --- a/src/main/java/me/kavin/piped/utils/DatabaseHelper.java +++ b/src/main/java/me/kavin/piped/utils/DatabaseHelper.java @@ -16,15 +16,19 @@ public class DatabaseHelper { public static User getUserFromSession(String session) { try (Session s = DatabaseSessionFactory.createSession()) { s.setHibernateFlushMode(FlushMode.MANUAL); - CriteriaBuilder cb = s.getCriteriaBuilder(); - CriteriaQuery cr = cb.createQuery(User.class); - Root root = cr.from(User.class); - cr.select(root).where(cb.equal(root.get("sessionId"), session)); - - return s.createQuery(cr).uniqueResult(); + return getUserFromSession(session, s); } } + public static User getUserFromSession(String session, Session s) { + CriteriaBuilder cb = s.getCriteriaBuilder(); + CriteriaQuery cr = cb.createQuery(User.class); + Root root = cr.from(User.class); + cr.select(root).where(cb.equal(root.get("sessionId"), session)); + + return s.createQuery(cr).uniqueResult(); + } + public static User getUserFromSessionWithSubscribed(String session) { try (Session s = DatabaseSessionFactory.createSession()) { s.setHibernateFlushMode(FlushMode.MANUAL); diff --git a/src/main/java/me/kavin/piped/utils/ResponseHelper.java b/src/main/java/me/kavin/piped/utils/ResponseHelper.java index 4f82343..94616e8 100644 --- a/src/main/java/me/kavin/piped/utils/ResponseHelper.java +++ b/src/main/java/me/kavin/piped/utils/ResponseHelper.java @@ -47,12 +47,9 @@ import org.springframework.security.crypto.argon2.Argon2PasswordEncoder; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import javax.persistence.criteria.CriteriaBuilder; -import javax.persistence.criteria.CriteriaDelete; import javax.persistence.criteria.CriteriaQuery; import javax.persistence.criteria.JoinType; import javax.persistence.criteria.Root; -import javax.persistence.criteria.Subquery; - import java.io.IOException; import java.util.*; import java.util.concurrent.TimeUnit; @@ -626,54 +623,6 @@ public class ResponseHelper { } - private static final Argon2PasswordEncoder argon2PasswordEncoder = new Argon2PasswordEncoder(); - - public static byte[] deleteUserResponse(String session, String pass) throws IOException { - - if (StringUtils.isBlank(pass)) - return Constants.mapper.writeValueAsBytes(new InvalidRequestResponse()); - - try (Session s = DatabaseSessionFactory.createSession()) { - User user = DatabaseHelper.getUserFromSession(session); - - if (user == null) - return Constants.mapper.writeValueAsBytes(new AuthenticationFailureResponse()); - - String hash = user.getPassword(); - boolean passMatch = - (hash.startsWith("$argon2") && argon2PasswordEncoder.matches(pass, hash)) || bcryptPasswordEncoder.matches(pass, hash); - - if (!passMatch) - return Constants.mapper.writeValueAsBytes(new IncorrectCredentialsResponse()); - - try { - CriteriaBuilder plCriteria = s.getCriteriaBuilder(); - CriteriaQuery plQuery = - plCriteria.createQuery(me.kavin.piped.utils.obj.db.Playlist.class); - Root plRoot = - plQuery.from(me.kavin.piped.utils.obj.db.Playlist.class); - plQuery.select(plRoot).where(plCriteria.equal(plRoot.get("owner"), user.getId())); - List playlists = s.createQuery(plQuery).getResultList(); - - Iterator iter = playlists.iterator(); - - while (iter.hasNext()) - s.delete(iter.next()); - - s.delete(user); - - s.getTransaction().begin(); - s.getTransaction().commit(); - - Multithreading.runAsync(() -> pruneUnusedPlaylistVideos()); - } catch (Exception e) { - return Constants.mapper.writeValueAsBytes(new ErrorResponse(ExceptionUtils.getStackTrace(e), e.getMessage())); - } - - return Constants.mapper.writeValueAsBytes(new DeleteUserResponse(user.getUsername())); - } - } - public static byte[] registerResponse(String user, String pass) throws IOException { if (Constants.DISABLE_REGISTRATION) @@ -717,8 +666,16 @@ public class ResponseHelper { } } + private static final Argon2PasswordEncoder argon2PasswordEncoder = new Argon2PasswordEncoder(); + private static final BCryptPasswordEncoder bcryptPasswordEncoder = new BCryptPasswordEncoder(); + private static boolean hashMatch(String hash, String pass) { + return hash.startsWith("$argon2") ? + argon2PasswordEncoder.matches(pass, hash) : + bcryptPasswordEncoder.matches(pass, hash); + } + public static byte[] loginResponse(String user, String pass) throws IOException { @@ -737,11 +694,7 @@ public class ResponseHelper { if (dbuser != null) { String hash = dbuser.getPassword(); - if (hash.startsWith("$argon2")) { - if (argon2PasswordEncoder.matches(pass, hash)) { - return Constants.mapper.writeValueAsBytes(new LoginResponse(dbuser.getSessionId())); - } - } else if (bcryptPasswordEncoder.matches(pass, hash)) { + if (hashMatch(hash, pass)) { return Constants.mapper.writeValueAsBytes(new LoginResponse(dbuser.getSessionId())); } } @@ -750,6 +703,37 @@ public class ResponseHelper { } } + public static byte[] deleteUserResponse(String session, String pass) throws IOException { + + if (StringUtils.isBlank(pass)) + return Constants.mapper.writeValueAsBytes(new InvalidRequestResponse()); + + try (Session s = DatabaseSessionFactory.createSession()) { + User user = DatabaseHelper.getUserFromSession(session); + + if (user == null) + return Constants.mapper.writeValueAsBytes(new AuthenticationFailureResponse()); + + String hash = user.getPassword(); + + if (!hashMatch(hash, pass)) + return Constants.mapper.writeValueAsBytes(new IncorrectCredentialsResponse()); + + try { + s.delete(user); + + s.getTransaction().begin(); + s.getTransaction().commit(); + + Multithreading.runAsync(() -> pruneUnusedPlaylistVideos()); + } catch (Exception e) { + return Constants.mapper.writeValueAsBytes(new ErrorResponse(ExceptionUtils.getStackTrace(e), e.getMessage())); + } + + return Constants.mapper.writeValueAsBytes(new DeleteUserResponse(user.getUsername())); + } + } + public static byte[] subscribeResponse(String session, String channelId) throws IOException { @@ -792,9 +776,7 @@ public class ResponseHelper { if (user != null) { try (Session s = DatabaseSessionFactory.createSession()) { if (user.getSubscribed().contains(channelId)) { - Set subscribed = user.getSubscribed(); - subscribed.removeIf(sub -> sub.equals(channelId)); - user.setSubscribed(subscribed); + user.getSubscribed().remove(channelId); s.update(user); s.getTransaction().begin(); @@ -957,7 +939,7 @@ public class ResponseHelper { Multithreading.runAsync(() -> { try (Session s = DatabaseSessionFactory.createSession()) { var channels = DatabaseHelper.getChannelsFromIds(s, Arrays.asList(channelIds)); - + outer: for (String channelId : channelIds) { for (var channel : channels) @@ -1057,7 +1039,7 @@ public class ResponseHelper { s.getTransaction().begin(); s.getTransaction().commit(); - + Multithreading.runAsync(() -> pruneUnusedPlaylistVideos()); } @@ -1066,21 +1048,16 @@ public class ResponseHelper { public static byte[] playlistsResponse(String session) throws IOException { - User user = DatabaseHelper.getUserFromSession(session); - - if (user == null) - return Constants.mapper.writeValueAsBytes(new AuthenticationFailureResponse()); - try (Session s = DatabaseSessionFactory.createSession()) { - var cb = s.getCriteriaBuilder(); - var query = cb.createQuery(me.kavin.piped.utils.obj.db.Playlist.class); - var root = query.from(me.kavin.piped.utils.obj.db.Playlist.class); - query.select(root); - query.where(cb.equal(root.get("owner"), user)); + + User user = DatabaseHelper.getUserFromSession(session, s); + + if (user == null) + return Constants.mapper.writeValueAsBytes(new AuthenticationFailureResponse()); var playlists = new ObjectArrayList<>(); - for (var playlist : s.createQuery(query).list()) { + for (var playlist : user.getPlaylists()) { ObjectNode node = Constants.mapper.createObjectNode(); node.put("id", String.valueOf(playlist.getPlaylistId())); node.put("name", playlist.getName()); @@ -1221,15 +1198,15 @@ public class ResponseHelper { try (Session s = DatabaseSessionFactory.createSession()) { CriteriaBuilder cb = s.getCriteriaBuilder(); - - CriteriaDelete pvQuery = cb.createCriteriaDelete(PlaylistVideo.class); - Root pvRoot = pvQuery.from(PlaylistVideo.class); - Subquery subQuery = pvQuery.subquery(String.class); - Root subRoot = - subQuery.from(me.kavin.piped.utils.obj.db.Playlist.class); - + var pvQuery = cb.createCriteriaDelete(PlaylistVideo.class); + var pvRoot = pvQuery.from(PlaylistVideo.class); + + var subQuery = pvQuery.subquery(me.kavin.piped.utils.obj.db.Playlist.class); + var subRoot = subQuery.from(me.kavin.piped.utils.obj.db.Playlist.class); + subQuery.select(subRoot.join("videos").get("id")); + pvQuery.where(cb.not(pvRoot.get("id").in(subQuery))); s.getTransaction().begin(); diff --git a/src/main/java/me/kavin/piped/utils/obj/db/User.java b/src/main/java/me/kavin/piped/utils/obj/db/User.java index dc5a0f9..a7f496e 100644 --- a/src/main/java/me/kavin/piped/utils/obj/db/User.java +++ b/src/main/java/me/kavin/piped/utils/obj/db/User.java @@ -34,6 +34,9 @@ public class User implements Serializable { @Column(name = "channel", length = 30) private Set subscribed_ids; + @OneToMany(mappedBy = "owner", cascade = CascadeType.ALL) + private Set playlists; + public User() { } @@ -83,4 +86,12 @@ public class User implements Serializable { public void setSubscribed(Set subscribed_ids) { this.subscribed_ids = subscribed_ids; } + + public Set getPlaylists() { + return playlists; + } + + public void setPlaylists(Set playlists) { + this.playlists = playlists; + } }