458 lines
12 KiB
Go
458 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
)
|
|
|
|
// RunnerType identifies which upstream tagging backend requests are
// dispatched to.
type RunnerType int

// Known runner types. The values are spelled out explicitly; they are the
// same 0 and 1 an iota sequence would assign.
const (
	// RunnerDeepDanbooru selects the DeepDanbooru backend.
	RunnerDeepDanbooru RunnerType = 0
	// RunnerSDWebUI selects the SD WebUI tagger backend.
	RunnerSDWebUI RunnerType = 1
)
|
|
|
|
// Config holds the application configuration
|
|
type Config struct {
|
|
RunnerType RunnerType
|
|
Address string
|
|
SDModel string
|
|
}
|
|
|
|
// WrappedResponse is the JSON payload returned to clients and decoded from
// the DeepDanbooru backend. It mirrors a Rust untagged enum: on the wire it
// is either a JSON array of tag strings or a bare JSON error string, never
// an object.
type WrappedResponse struct {
	// Tags is the tag list of a successful response.
	Tags []string
	// Error is the message of a failed response. A non-empty Error takes
	// precedence over Tags when marshaling.
	Error string
}

// MarshalJSON encodes the response as a bare JSON string when Error is set
// and as a JSON array of tags otherwise.
//
// A nil Tags slice is encoded as "[]" rather than "null", so consumers that
// expect the array-or-string shape never receive a third form.
func (w WrappedResponse) MarshalJSON() ([]byte, error) {
	if w.Error != "" {
		return json.Marshal(w.Error)
	}
	if w.Tags == nil {
		return []byte("[]"), nil
	}
	return json.Marshal(w.Tags)
}

// UnmarshalJSON decodes either a JSON array of strings (stored in Tags) or a
// JSON string (stored in Error). The field that was not decoded is reset, so
// reusing a WrappedResponse value never leaves stale data behind. Any other
// JSON shape yields an error that wraps the underlying decode failure.
func (w *WrappedResponse) UnmarshalJSON(data []byte) error {
	// Try the array form first.
	var tags []string
	if err := json.Unmarshal(data, &tags); err == nil {
		w.Tags, w.Error = tags, ""
		return nil
	}

	// Fall back to the bare error string form.
	var msg string
	if err := json.Unmarshal(data, &msg); err != nil {
		return fmt.Errorf("cannot unmarshal response: want array of strings or string: %w", err)
	}
	w.Tags, w.Error = nil, msg
	return nil
}
|
|
|
|
// Response shapes decoded from the SD WebUI tagger endpoint.
type (
	// WD14Caption is the "caption" object of a tagger response.
	WD14Caption struct {
		// Tag maps a tag name to a float32 value (presumably a confidence
		// score; confirm against the tagger API).
		Tag map[string]float32 `json:"tag"`
		// Rating maps a rating name to a float32 value (presumably a
		// confidence score; confirm against the tagger API).
		Rating map[string]float32 `json:"rating"`
	}

	// WD14Response is the top-level body returned by the SD WebUI tagger.
	WD14Response struct {
		Caption WD14Caption `json:"caption"`
	}
)
|
|
|
|
var (
|
|
config *Config
|
|
httpClient = &http.Client{}
|
|
)
|
|
|
|
func loadConfig() (*Config, error) {
|
|
cfg := &Config{
|
|
SDModel: "wd14-vit-v2-git",
|
|
}
|
|
|
|
if addr := os.Getenv("DD_ADDRESS"); addr != "" {
|
|
cfg.RunnerType = RunnerDeepDanbooru
|
|
cfg.Address = addr
|
|
return cfg, nil
|
|
}
|
|
|
|
if addr := os.Getenv("SD_ADDRESS"); addr != "" {
|
|
cfg.RunnerType = RunnerSDWebUI
|
|
cfg.Address = addr
|
|
if model := os.Getenv("SD_MODEL"); model != "" {
|
|
cfg.SDModel = model
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("neither DD_ADDRESS nor SD_ADDRESS is set")
|
|
}
|
|
|
|
func sendToDeepDanbooru(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) {
|
|
slog.Debug("calling dd")
|
|
|
|
// Create multipart form
|
|
body := &bytes.Buffer{}
|
|
writer := multipart.NewWriter(body)
|
|
|
|
part, err := writer.CreateFormFile("file", fileName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create form file: %w", err)
|
|
}
|
|
|
|
if _, err := part.Write(fileContents); err != nil {
|
|
return nil, fmt.Errorf("write file contents: %w", err)
|
|
}
|
|
|
|
if err := writer.Close(); err != nil {
|
|
return nil, fmt.Errorf("close writer: %w", err)
|
|
}
|
|
|
|
// Build request with threshold query param
|
|
url := config.Address + "?threshold=" + threshold
|
|
req, err := http.NewRequest("POST", url, body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
|
req.Header.Set("Authorization", "Bearer 123")
|
|
|
|
resp, err := httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read response: %w", err)
|
|
}
|
|
|
|
slog.Debug("dd body", "body", string(respBody))
|
|
|
|
var result WrappedResponse
|
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
|
return nil, fmt.Errorf("parse response: %w", err)
|
|
}
|
|
|
|
slog.Debug("called!")
|
|
return &result, nil
|
|
}
|
|
|
|
func sendToSDWebUI(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) {
|
|
slog.Debug("calling sd")
|
|
|
|
// Encode image as base64
|
|
imageB64 := base64.StdEncoding.EncodeToString(fileContents)
|
|
|
|
reqData := map[string]string{
|
|
"model": config.SDModel,
|
|
"threshold": threshold,
|
|
"image": imageB64,
|
|
}
|
|
|
|
jsonBody, err := json.Marshal(reqData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("marshal request: %w", err)
|
|
}
|
|
|
|
slog.Info("wd14 request length", "bytes", len(jsonBody))
|
|
|
|
url := config.Address + "/tagger/v1/interrogate"
|
|
resp, err := httpClient.Post(url, "application/json", bytes.NewReader(jsonBody))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("send request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read response: %w", err)
|
|
}
|
|
|
|
slog.Debug("wd14 body", "body", string(respBody))
|
|
|
|
var wd14Resp WD14Response
|
|
if err := json.Unmarshal(respBody, &wd14Resp); err != nil {
|
|
return nil, fmt.Errorf("parse response: %w", err)
|
|
}
|
|
|
|
// Convert to WrappedResponse
|
|
tags := make([]string, 0, len(wd14Resp.Caption.Tag)+len(wd14Resp.Caption.Rating))
|
|
for tag := range wd14Resp.Caption.Tag {
|
|
tags = append(tags, tag)
|
|
}
|
|
for rating := range wd14Resp.Caption.Rating {
|
|
tags = append(tags, rating)
|
|
}
|
|
|
|
slog.Debug("called!")
|
|
return &WrappedResponse{Tags: tags}, nil
|
|
}
|
|
|
|
func sendImageToTagger(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) {
|
|
switch config.RunnerType {
|
|
case RunnerDeepDanbooru:
|
|
return sendToDeepDanbooru(fileContents, fileName, mimeType, threshold)
|
|
case RunnerSDWebUI:
|
|
return sendToSDWebUI(fileContents, fileName, mimeType, threshold)
|
|
default:
|
|
return nil, fmt.Errorf("invalid runner type")
|
|
}
|
|
}
|
|
|
|
// isVideo reports whether an upload should go through the video pipeline.
//
// It returns true when contentType has a "video/" media type, or when
// fileName ends in one of the extensions .mp4, .gif, .mkv or .webm. Both
// checks are case-insensitive: media types are case-insensitive, so
// "Video/MP4" must match just like "video/mp4". Surrounding whitespace in
// contentType is ignored.
func isVideo(contentType, fileName string) bool {
	mediaType := strings.ToLower(strings.TrimSpace(contentType))
	if strings.HasPrefix(mediaType, "video/") {
		return true
	}

	// Fall back to the extension for uploads with a missing or generic
	// content type. Note that .gif is routed to the video path as well.
	switch strings.ToLower(filepath.Ext(fileName)) {
	case ".mp4", ".gif", ".mkv", ".webm":
		return true
	}
	return false
}
|
|
|
|
// writeJSON sets the Content-Type header to application/json and writes v
// as JSON to w. The status code is left to the ResponseWriter's default.
//
// An encoding or write failure is logged rather than ignored. It cannot be
// reported to the client because the response may already be partly written.
func writeJSON(w http.ResponseWriter, v any) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(v); err != nil {
		slog.Error("failed to write JSON response", "error", err)
	}
}
|
|
|
|
func writeError(w http.ResponseWriter, msg string, status int) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(status)
|
|
json.NewEncoder(w).Encode(WrappedResponse{Error: msg})
|
|
}
|
|
|
|
func uploadHandler(w http.ResponseWriter, r *http.Request) {
|
|
threshold := r.URL.Query().Get("threshold")
|
|
if threshold == "" {
|
|
threshold = "0.5"
|
|
}
|
|
|
|
// Parse multipart form (8GB limit set at middleware level)
|
|
if err := r.ParseMultipartForm(8 << 30); err != nil {
|
|
slog.Error("failed to parse multipart form", "error", err)
|
|
writeError(w, "failed to parse multipart form", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
file, header, err := r.FormFile("file")
|
|
if err != nil {
|
|
slog.Error("no file found in request", "error", err)
|
|
writeError(w, "no file found in request", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer file.Close()
|
|
|
|
contentType := header.Header.Get("Content-Type")
|
|
fileName := header.Filename
|
|
|
|
slog.Info("received file", "name", fileName, "type", contentType, "size", header.Size)
|
|
|
|
fileContents, err := io.ReadAll(file)
|
|
if err != nil {
|
|
slog.Error("failed to read file", "error", err)
|
|
writeError(w, "failed to read file", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if isVideo(contentType, fileName) {
|
|
handleVideo(w, fileContents, fileName, threshold)
|
|
} else {
|
|
handleImage(w, fileContents, fileName, contentType, threshold)
|
|
}
|
|
}
|
|
|
|
func handleImage(w http.ResponseWriter, data []byte, fileName, contentType, threshold string) {
|
|
if !strings.HasPrefix(contentType, "image/") {
|
|
slog.Warn("mimetype is not image/", "type", contentType)
|
|
}
|
|
|
|
resp, err := sendImageToTagger(data, fileName, contentType, threshold)
|
|
if err != nil {
|
|
slog.Error("upstream tagger failed", "error", err)
|
|
writeError(w, "upstream tagger failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
writeJSON(w, resp)
|
|
}
|
|
|
|
func handleVideo(w http.ResponseWriter, data []byte, fileName, threshold string) {
|
|
// Write video to temp file
|
|
ext := filepath.Ext(fileName)
|
|
tmpFile, err := os.CreateTemp("", "glimbus-video-*"+ext)
|
|
if err != nil {
|
|
slog.Error("failed to create temp file", "error", err)
|
|
writeError(w, "failed to create temp file", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer os.Remove(tmpFile.Name())
|
|
|
|
if _, err := tmpFile.Write(data); err != nil {
|
|
tmpFile.Close()
|
|
slog.Error("failed to write temp file", "error", err)
|
|
writeError(w, "failed to write temp file", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
tmpFile.Close()
|
|
|
|
slog.Debug("tmp path", "path", tmpFile.Name())
|
|
|
|
// Get video info
|
|
probeResult, err := ffprobe(tmpFile.Name())
|
|
if err != nil {
|
|
slog.Error("ffprobe failed", "error", err)
|
|
writeError(w, "ffprobe failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if len(probeResult.Streams) == 0 {
|
|
slog.Error("no streams found")
|
|
writeError(w, "no streams found in video", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
stream := probeResult.Streams[0]
|
|
slog.Debug("stream", "data", stream)
|
|
slog.Debug("format", "data", probeResult.Format)
|
|
|
|
frameRate, err := calculateFrameRate(tmpFile.Name(), stream.RFrameRate, true)
|
|
if err != nil {
|
|
slog.Error("failed to get frame rate", "error", err)
|
|
writeError(w, "failed to get frame rate", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
frameCount, err := calculateFrameCount(tmpFile.Name(), &stream, &probeResult.Format, frameRate)
|
|
if err != nil {
|
|
slog.Error("failed to get frame count", "error", err)
|
|
writeError(w, "failed to get frame count", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
slog.Debug("video info", "frameCount", frameCount, "frameRate", frameRate)
|
|
|
|
totalSeconds := float64(frameCount) / frameRate
|
|
skipSeconds := getFrameSkipSeconds(int(totalSeconds))
|
|
frameSkip := uint64(float64(skipSeconds) * frameRate)
|
|
|
|
slog.Info("frame extraction plan", "totalSeconds", totalSeconds, "skipSeconds", skipSeconds)
|
|
|
|
// Create temp directory for frames
|
|
tmpDir, err := os.MkdirTemp("", "glimbus-frames-*")
|
|
if err != nil {
|
|
slog.Error("failed to create temp dir", "error", err)
|
|
writeError(w, "failed to create temp dir", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
defer os.RemoveAll(tmpDir)
|
|
|
|
framePath := filepath.Join(tmpDir, "frame.png")
|
|
slog.Info("frame path", "path", framePath)
|
|
|
|
// Process frames and collect tags
|
|
tagSet := make(map[string]struct{})
|
|
|
|
for frameNum := uint64(0); frameNum < frameCount; frameNum += frameSkip {
|
|
slog.Info("extracting frame", "frameNum", frameNum)
|
|
|
|
if err := extractFrame(tmpFile.Name(), framePath, int(frameNum), frameRate); err != nil {
|
|
slog.Error("frame extraction failed", "frameNum", frameNum, "error", err)
|
|
writeError(w, "frame extraction failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
slog.Info("extracted frame", "frameNum", frameNum)
|
|
|
|
frameData, err := os.ReadFile(framePath)
|
|
if err != nil {
|
|
slog.Error("failed to read frame", "error", err)
|
|
writeError(w, "failed to read frame", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
slog.Info("sending frame", "frameNum", frameNum)
|
|
|
|
resp, err := sendImageToTagger(frameData, "amongus.png", "image/png", threshold)
|
|
if err != nil {
|
|
slog.Error("upstream tagger failed for frame", "frameNum", frameNum, "error", err)
|
|
writeError(w, "upstream tagger failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if resp.Tags != nil {
|
|
slog.Info("tags from frame", "count", len(resp.Tags))
|
|
for _, tag := range resp.Tags {
|
|
tagSet[tag] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Convert set to slice
|
|
tags := make([]string, 0, len(tagSet))
|
|
for tag := range tagSet {
|
|
tags = append(tags, tag)
|
|
}
|
|
|
|
writeJSON(w, &WrappedResponse{Tags: tags})
|
|
}
|
|
|
|
func main() {
|
|
// Setup logging
|
|
logLevel := slog.LevelInfo
|
|
if os.Getenv("LOG_LEVEL") == "debug" || os.Getenv("RUST_LOG") == "debug" {
|
|
logLevel = slog.LevelDebug
|
|
}
|
|
|
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
|
Level: logLevel,
|
|
}))
|
|
slog.SetDefault(logger)
|
|
|
|
// Load configuration
|
|
var err error
|
|
config, err = loadConfig()
|
|
if err != nil {
|
|
slog.Error("configuration error", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
runnerName := "DeepDanbooru"
|
|
if config.RunnerType == RunnerSDWebUI {
|
|
runnerName = "StableDiffusionWebUI"
|
|
}
|
|
slog.Info("configuration loaded", "runner", runnerName, "address", config.Address)
|
|
|
|
// Setup router
|
|
r := chi.NewRouter()
|
|
r.Use(middleware.Logger)
|
|
r.Use(middleware.Recoverer)
|
|
|
|
r.Post("/", uploadHandler)
|
|
r.Post("/tagger/v1/interrogate", uploadHandler)
|
|
|
|
// Start server
|
|
addr := "0.0.0.0:6679"
|
|
slog.Info("starting server", "address", addr)
|
|
|
|
if err := http.ListenAndServe(addr, r); err != nil {
|
|
slog.Error("server error", "error", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|