From a9ea786279eea2e590b93f5e064d06615872e1be Mon Sep 17 00:00:00 2001 From: TaiAurori <31465218+TaiAurori@users.noreply.github.com> Date: Wed, 22 Jun 2022 21:20:15 -0400 Subject: [PATCH] use queryrow, check if player exists before making --- db.go | 182 ++++++++++++++++++++++++++++--------------------------- types.go | 6 ++ util.go | 4 ++ 3 files changed, 103 insertions(+), 89 deletions(-) diff --git a/db.go b/db.go index 78f46f8..4fd369c 100644 --- a/db.go +++ b/db.go @@ -57,34 +57,31 @@ func getAuthToken(username string, password string) (string, error) { sqlStatement := ` SELECT auth_token FROM users WHERE username = ? AND password = ?; ` - rows, err := DB.Query(sqlStatement, username, password) - if err != nil { - return "", err - } - defer rows.Close() + row := DB.QueryRow(sqlStatement, username, password) - // check amount of rows - if rows.Next() { - // get auth token - var authToken string - rows.Scan(&authToken) - rows.Close() - if authToken == "" { - // generate new authToken - authToken = uuid.New().String() - // update authToken - sqlStatement := ` - UPDATE users SET auth_token = ? WHERE username = ? AND password = ?; - ` - _, err := DB.Exec(sqlStatement, authToken, username, password) - if err != nil { - return "", err - } + // get auth token + var authToken string + err := row.Scan(&authToken) + if err != nil { + if err != sql.ErrNoRows { + return "", err + } else { + return "", &NotFoundError{} } - return authToken, nil - } else { - return "", &NotFoundError{} } + if authToken == "" { + // generate new authToken + authToken = uuid.New().String() + // update authToken + sqlStatement := ` + UPDATE users SET auth_token = ? WHERE username = ? AND password = ?; + ` + _, err := DB.Exec(sqlStatement, authToken, username, password) + if err != nil { + return "", err + } + } + return authToken, nil } func checkClientToken(clientToken string, userName string) (string, error) { @@ -92,14 +89,14 @@ func checkClientToken(clientToken string, userName string) (string, error) { sqlStatement := ` SELECT id FROM users WHERE client_token = ? AND username = ?; ` - rows, err := DB.Query(sqlStatement, clientToken, userName) - if err != nil { + var x string + err := DB.QueryRow(sqlStatement, clientToken, userName).Scan(&x) + + // check if row exists + if err != nil && err != sql.ErrNoRows { return "", err } - defer rows.Close() - - // check amount of rows - if rows.Next() { + if err == nil { return clientToken, nil } else { clientToken = uuid.New().String() @@ -141,95 +138,102 @@ func clearAuthToken(username string) error { func createUser(username string, adminToken string) (string, error) { // check if adminToken is valid if validateAdminToken(adminToken) { - password := uuid.New().String() - insertUser(username, password) - return password, nil + exists, err := playerExistsByUsername(username) + if err != nil { + return "", err + } + if !exists { + password := uuid.New().String() + insertUser(username, password) + return password, nil + } else { + return "", &AlreadyExistsError{} + } } else { return "", &InvalidCredentialsError{} } } +func playerExistsByUsername(username string) (bool, error) { + sqlStatement := ` + SELECT username FROM users WHERE username = ?; + ` + var x string + err := DB.QueryRow(sqlStatement, username).Scan(&x) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + return true, nil +} + func getPlayerUUID(username string) (string, error) { sqlStatement := ` SELECT uuid FROM users WHERE username = ?; ` - rows, err := DB.Query(sqlStatement, username) + row := DB.QueryRow(sqlStatement, username) + + var uuid string + err := row.Scan(&uuid) if err != nil { return "", err } - defer rows.Close() - - // check amount of rows - if rows.Next() { - // get uuid - var uuid string - err = rows.Scan(&uuid) - if err != nil { - return "", err - } - return uuid, nil - } else { - return "", &NotFoundError{} - } + return uuid, nil } func refreshTokens(refresh RefreshPayload) (RefreshPayload, error) { sqlStatement := ` SELECT id FROM users WHERE auth_token = ? and client_token = ?; ` - rows, err := DB.Query(sqlStatement, refresh.AccessToken, refresh.ClientToken) + row := DB.QueryRow(sqlStatement, refresh.AccessToken, refresh.ClientToken) + + // get id + var id int + err := row.Scan(&id) + if err != nil { return RefreshPayload{}, err } - if rows.Next() { - // get id - var id int - err = rows.Scan(&id) - rows.Close() - if err != nil { - return RefreshPayload{}, err - } - // generate new authToken - authToken := uuid.New().String() - // update authToken - sqlStatement := ` - UPDATE users SET auth_token = ? WHERE id = ?; - ` - _, err := DB.Exec(sqlStatement, authToken, id) - if err != nil { - return RefreshPayload{}, err - } - // generate new clientToken - clientToken := uuid.New().String() - // update clientToken - sqlStatement = ` - UPDATE users SET client_token = ? WHERE id = ?; - ` - _, err = DB.Exec(sqlStatement, clientToken, id) - if err != nil { - return RefreshPayload{}, err - } - refresh.AccessToken = authToken - refresh.ClientToken = clientToken - return refresh, nil - } else { - return refresh, nil + // generate new authToken + authToken := uuid.New().String() + // update authToken + sqlStatement = ` + UPDATE users SET auth_token = ? WHERE id = ?; + ` + _, err = DB.Exec(sqlStatement, authToken, id) + if err != nil { + return RefreshPayload{}, err } + // generate new clientToken + clientToken := uuid.New().String() + // update clientToken + sqlStatement = ` + UPDATE users SET client_token = ? WHERE id = ?; + ` + _, err = DB.Exec(sqlStatement, clientToken, id) + if err != nil { + return RefreshPayload{}, err + } + refresh.AccessToken = authToken + refresh.ClientToken = clientToken + return refresh, nil } func validateTokens(authToken string, clientToken string) (bool, error) { sqlStatement := ` SELECT id FROM users WHERE auth_token = ? and client_token = ?; ` - rows, err := DB.Query(sqlStatement, authToken, clientToken) + var x string + err := DB.QueryRow(sqlStatement, authToken, clientToken).Scan(&x) if err != nil { + if err == sql.ErrNoRows { + return false, nil + } return false, err } - if rows.Next() { - return true, nil - } else { - return false, nil - } + return true, nil } func invalidateTokens(authToken string, clientToken string) error { diff --git a/types.go b/types.go index 5725387..3320731 100644 --- a/types.go +++ b/types.go @@ -63,3 +63,9 @@ type InvalidCredentialsError struct{} func (m *InvalidCredentialsError) Error() string { return "Invalid credentials" } + +type AlreadyExistsError struct{} + +func (m *AlreadyExistsError) Error() string { + return "The specified item already exists" +} diff --git a/util.go b/util.go index d6f9e14..52db0e3 100644 --- a/util.go +++ b/util.go @@ -33,6 +33,10 @@ func handleError(w http.ResponseWriter, err error) { switch err.Error() { case "unexpected end of JSON input": sendError(w, YggError{Code: 400, Error: "Bad Request", ErrorMessage: "The request data is malformed."}) + case "The specified item already exists": + sendError(w, YggError{Code: 400, Error: "Bad Request", ErrorMessage: "The specified item already exists."}) + case "Invalid credentials": + sendError(w, YggError{Code: 400, Error: "Bad Request", ErrorMessage: "Invalid credentials"}) default: sendError(w, YggError{Code: 500, Error: "Unspecified error", ErrorMessage: "An error has occured handling your request."}) }