forked from TripwireTeam/tripwire
use queryrow, check if player exists before making
This commit is contained in:
parent
e7cc5f4ba7
commit
a9ea786279
3 changed files with 103 additions and 89 deletions
182
db.go
182
db.go
|
@ -57,34 +57,31 @@ func getAuthToken(username string, password string) (string, error) {
|
||||||
sqlStatement := `
|
sqlStatement := `
|
||||||
SELECT auth_token FROM users WHERE username = ? AND password = ?;
|
SELECT auth_token FROM users WHERE username = ? AND password = ?;
|
||||||
`
|
`
|
||||||
rows, err := DB.Query(sqlStatement, username, password)
|
row := DB.QueryRow(sqlStatement, username, password)
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
// check amount of rows
|
// get auth token
|
||||||
if rows.Next() {
|
var authToken string
|
||||||
// get auth token
|
err := row.Scan(&authToken)
|
||||||
var authToken string
|
if err != nil {
|
||||||
rows.Scan(&authToken)
|
if err != sql.ErrNoRows {
|
||||||
rows.Close()
|
return "", err
|
||||||
if authToken == "" {
|
} else {
|
||||||
// generate new authToken
|
return "", &NotFoundError{}
|
||||||
authToken = uuid.New().String()
|
|
||||||
// update authToken
|
|
||||||
sqlStatement := `
|
|
||||||
UPDATE users SET auth_token = ? WHERE username = ? AND password = ?;
|
|
||||||
`
|
|
||||||
_, err := DB.Exec(sqlStatement, authToken, username, password)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return authToken, nil
|
|
||||||
} else {
|
|
||||||
return "", &NotFoundError{}
|
|
||||||
}
|
}
|
||||||
|
if authToken == "" {
|
||||||
|
// generate new authToken
|
||||||
|
authToken = uuid.New().String()
|
||||||
|
// update authToken
|
||||||
|
sqlStatement := `
|
||||||
|
UPDATE users SET auth_token = ? WHERE username = ? AND password = ?;
|
||||||
|
`
|
||||||
|
_, err := DB.Exec(sqlStatement, authToken, username, password)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return authToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkClientToken(clientToken string, userName string) (string, error) {
|
func checkClientToken(clientToken string, userName string) (string, error) {
|
||||||
|
@ -92,14 +89,14 @@ func checkClientToken(clientToken string, userName string) (string, error) {
|
||||||
sqlStatement := `
|
sqlStatement := `
|
||||||
SELECT id FROM users WHERE client_token = ? AND username = ?;
|
SELECT id FROM users WHERE client_token = ? AND username = ?;
|
||||||
`
|
`
|
||||||
rows, err := DB.Query(sqlStatement, clientToken, userName)
|
var x string
|
||||||
if err != nil {
|
err := DB.QueryRow(sqlStatement, clientToken, userName).Scan(&x)
|
||||||
|
|
||||||
|
// check if row exists
|
||||||
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
if err == nil {
|
||||||
|
|
||||||
// check amount of rows
|
|
||||||
if rows.Next() {
|
|
||||||
return clientToken, nil
|
return clientToken, nil
|
||||||
} else {
|
} else {
|
||||||
clientToken = uuid.New().String()
|
clientToken = uuid.New().String()
|
||||||
|
@ -141,95 +138,102 @@ func clearAuthToken(username string) error {
|
||||||
func createUser(username string, adminToken string) (string, error) {
|
func createUser(username string, adminToken string) (string, error) {
|
||||||
// check if adminToken is valid
|
// check if adminToken is valid
|
||||||
if validateAdminToken(adminToken) {
|
if validateAdminToken(adminToken) {
|
||||||
password := uuid.New().String()
|
exists, err := playerExistsByUsername(username)
|
||||||
insertUser(username, password)
|
if err != nil {
|
||||||
return password, nil
|
return "", err
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
password := uuid.New().String()
|
||||||
|
insertUser(username, password)
|
||||||
|
return password, nil
|
||||||
|
} else {
|
||||||
|
return "", &AlreadyExistsError{}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return "", &InvalidCredentialsError{}
|
return "", &InvalidCredentialsError{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func playerExistsByUsername(username string) (bool, error) {
|
||||||
|
sqlStatement := `
|
||||||
|
SELECT username FROM users WHERE username = ?;
|
||||||
|
`
|
||||||
|
var x string
|
||||||
|
err := DB.QueryRow(sqlStatement, username).Scan(&x)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
func getPlayerUUID(username string) (string, error) {
|
func getPlayerUUID(username string) (string, error) {
|
||||||
sqlStatement := `
|
sqlStatement := `
|
||||||
SELECT uuid FROM users WHERE username = ?;
|
SELECT uuid FROM users WHERE username = ?;
|
||||||
`
|
`
|
||||||
rows, err := DB.Query(sqlStatement, username)
|
row := DB.QueryRow(sqlStatement, username)
|
||||||
|
|
||||||
|
var uuid string
|
||||||
|
err := row.Scan(&uuid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
return uuid, nil
|
||||||
|
|
||||||
// check amount of rows
|
|
||||||
if rows.Next() {
|
|
||||||
// get uuid
|
|
||||||
var uuid string
|
|
||||||
err = rows.Scan(&uuid)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return uuid, nil
|
|
||||||
} else {
|
|
||||||
return "", &NotFoundError{}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func refreshTokens(refresh RefreshPayload) (RefreshPayload, error) {
|
func refreshTokens(refresh RefreshPayload) (RefreshPayload, error) {
|
||||||
sqlStatement := `
|
sqlStatement := `
|
||||||
SELECT id FROM users WHERE auth_token = ? and client_token = ?;
|
SELECT id FROM users WHERE auth_token = ? and client_token = ?;
|
||||||
`
|
`
|
||||||
rows, err := DB.Query(sqlStatement, refresh.AccessToken, refresh.ClientToken)
|
row := DB.QueryRow(sqlStatement, refresh.AccessToken, refresh.ClientToken)
|
||||||
|
|
||||||
|
// get id
|
||||||
|
var id int
|
||||||
|
err := row.Scan(&id)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return RefreshPayload{}, err
|
return RefreshPayload{}, err
|
||||||
}
|
}
|
||||||
if rows.Next() {
|
// generate new authToken
|
||||||
// get id
|
authToken := uuid.New().String()
|
||||||
var id int
|
// update authToken
|
||||||
err = rows.Scan(&id)
|
sqlStatement = `
|
||||||
rows.Close()
|
UPDATE users SET auth_token = ? WHERE id = ?;
|
||||||
if err != nil {
|
`
|
||||||
return RefreshPayload{}, err
|
_, err = DB.Exec(sqlStatement, authToken, id)
|
||||||
}
|
if err != nil {
|
||||||
// generate new authToken
|
return RefreshPayload{}, err
|
||||||
authToken := uuid.New().String()
|
|
||||||
// update authToken
|
|
||||||
sqlStatement := `
|
|
||||||
UPDATE users SET auth_token = ? WHERE id = ?;
|
|
||||||
`
|
|
||||||
_, err := DB.Exec(sqlStatement, authToken, id)
|
|
||||||
if err != nil {
|
|
||||||
return RefreshPayload{}, err
|
|
||||||
}
|
|
||||||
// generate new clientToken
|
|
||||||
clientToken := uuid.New().String()
|
|
||||||
// update clientToken
|
|
||||||
sqlStatement = `
|
|
||||||
UPDATE users SET client_token = ? WHERE id = ?;
|
|
||||||
`
|
|
||||||
_, err = DB.Exec(sqlStatement, clientToken, id)
|
|
||||||
if err != nil {
|
|
||||||
return RefreshPayload{}, err
|
|
||||||
}
|
|
||||||
refresh.AccessToken = authToken
|
|
||||||
refresh.ClientToken = clientToken
|
|
||||||
return refresh, nil
|
|
||||||
} else {
|
|
||||||
return refresh, nil
|
|
||||||
}
|
}
|
||||||
|
// generate new clientToken
|
||||||
|
clientToken := uuid.New().String()
|
||||||
|
// update clientToken
|
||||||
|
sqlStatement = `
|
||||||
|
UPDATE users SET client_token = ? WHERE id = ?;
|
||||||
|
`
|
||||||
|
_, err = DB.Exec(sqlStatement, clientToken, id)
|
||||||
|
if err != nil {
|
||||||
|
return RefreshPayload{}, err
|
||||||
|
}
|
||||||
|
refresh.AccessToken = authToken
|
||||||
|
refresh.ClientToken = clientToken
|
||||||
|
return refresh, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateTokens(authToken string, clientToken string) (bool, error) {
|
func validateTokens(authToken string, clientToken string) (bool, error) {
|
||||||
sqlStatement := `
|
sqlStatement := `
|
||||||
SELECT id FROM users WHERE auth_token = ? and client_token = ?;
|
SELECT id FROM users WHERE auth_token = ? and client_token = ?;
|
||||||
`
|
`
|
||||||
rows, err := DB.Query(sqlStatement, authToken, clientToken)
|
var x string
|
||||||
|
err := DB.QueryRow(sqlStatement, authToken, clientToken).Scan(&x)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
if rows.Next() {
|
return true, nil
|
||||||
return true, nil
|
|
||||||
} else {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func invalidateTokens(authToken string, clientToken string) error {
|
func invalidateTokens(authToken string, clientToken string) error {
|
||||||
|
|
6
types.go
6
types.go
|
@ -63,3 +63,9 @@ type InvalidCredentialsError struct{}
|
||||||
func (m *InvalidCredentialsError) Error() string {
|
func (m *InvalidCredentialsError) Error() string {
|
||||||
return "Invalid credentials"
|
return "Invalid credentials"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AlreadyExistsError struct{}
|
||||||
|
|
||||||
|
func (m *AlreadyExistsError) Error() string {
|
||||||
|
return "The specified item already exists"
|
||||||
|
}
|
||||||
|
|
4
util.go
4
util.go
|
@ -33,6 +33,10 @@ func handleError(w http.ResponseWriter, err error) {
|
||||||
switch err.Error() {
|
switch err.Error() {
|
||||||
case "unexpected end of JSON input":
|
case "unexpected end of JSON input":
|
||||||
sendError(w, YggError{Code: 400, Error: "Bad Request", ErrorMessage: "The request data is malformed."})
|
sendError(w, YggError{Code: 400, Error: "Bad Request", ErrorMessage: "The request data is malformed."})
|
||||||
|
case "The specified item already exists":
|
||||||
|
sendError(w, YggError{Code: 400, Error: "Bad Request", ErrorMessage: "The specified item already exists."})
|
||||||
|
case "Invalid credentials":
|
||||||
|
sendError(w, YggError{Code: 400, Error: "Bad Request", ErrorMessage: "Invalid credentials"})
|
||||||
default:
|
default:
|
||||||
sendError(w, YggError{Code: 500, Error: "Unspecified error", ErrorMessage: "An error has occured handling your request."})
|
sendError(w, YggError{Code: 500, Error: "Unspecified error", ErrorMessage: "An error has occured handling your request."})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue