glimbus/main.go
2026-01-11 19:29:20 -03:00

458 lines
12 KiB
Go

package main
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log/slog"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
)
// RunnerType identifies which upstream tagging backend requests are sent to.
type RunnerType int

// Supported backends. Values are assigned sequentially from zero, so the zero
// value of RunnerType is RunnerDeepDanbooru.
const (
	// RunnerDeepDanbooru selects the DeepDanbooru backend.
	RunnerDeepDanbooru RunnerType = iota
	// RunnerSDWebUI selects the SD WebUI tagger backend.
	RunnerSDWebUI
)
// Config holds the application configuration.
type Config struct {
	// RunnerType selects which backend receives tagging requests.
	RunnerType RunnerType
	// Address is the backend address that request URLs are built from.
	Address string
	// SDModel is sent as the "model" field of SD WebUI tagger requests.
	SDModel string
}
// WrappedResponse matches the Rust untagged enum: on the wire it is either a
// JSON array of tag strings or a bare JSON string carrying an error message.
type WrappedResponse struct {
	// Tags holds the tag list; it is serialized whenever Error is empty.
	Tags []string
	// Error, when non-empty, is serialized instead of Tags.
	Error string
}

// MarshalJSON encodes the response as a bare JSON string when Error is set
// and as a JSON array of tags otherwise. A nil Tags slice is written as []
// rather than null, so the non-error form is always an array.
//
// The value receiver is deliberate: it lets both WrappedResponse and
// *WrappedResponse values be encoded through this method.
func (w WrappedResponse) MarshalJSON() ([]byte, error) {
	if w.Error != "" {
		return json.Marshal(w.Error)
	}
	if w.Tags == nil {
		return []byte("[]"), nil
	}
	return json.Marshal(w.Tags)
}

// UnmarshalJSON accepts either a JSON array of strings (stored in Tags) or a
// JSON string (stored in Error). The field that was not matched is cleared so
// that decoding into a reused value never leaves stale data behind. Any other
// JSON shape is rejected with an error.
func (w *WrappedResponse) UnmarshalJSON(data []byte) error {
	// Try the array form first.
	var tags []string
	if err := json.Unmarshal(data, &tags); err == nil {
		w.Tags, w.Error = tags, ""
		return nil
	}

	// Fall back to the bare-string error form.
	var msg string
	if err := json.Unmarshal(data, &msg); err == nil {
		w.Tags, w.Error = nil, msg
		return nil
	}

	return fmt.Errorf("cannot unmarshal response: expected a JSON array of strings or a JSON string")
}
// WD14Response is the reply body decoded from the SD WebUI tagger.
type WD14Response struct {
	// Caption carries the tagger output.
	Caption WD14Caption `json:"caption"`
}

// WD14Caption is the "caption" object of a WD14Response.
type WD14Caption struct {
	// Tag maps a name from the "tag" object to its numeric value.
	Tag map[string]float32 `json:"tag"`
	// Rating maps a name from the "rating" object to its numeric value.
	Rating map[string]float32 `json:"rating"`
}
// config is the process-wide configuration read by the tagger helpers.
var config *Config

// httpClient is the shared client used for every upstream tagger request.
// NOTE(review): no Timeout is set, so an unresponsive upstream can block a
// request indefinitely — confirm whether that is intended.
var httpClient = &http.Client{}
// loadConfig builds the configuration from environment variables.
//
// DD_ADDRESS takes precedence and selects the DeepDanbooru runner. Otherwise
// SD_ADDRESS selects the SD WebUI runner, with SD_MODEL optionally overriding
// the default model name. An error is returned when neither address is set.
func loadConfig() (*Config, error) {
	const defaultSDModel = "wd14-vit-v2-git"

	ddAddr := os.Getenv("DD_ADDRESS")
	if ddAddr != "" {
		return &Config{
			RunnerType: RunnerDeepDanbooru,
			Address:    ddAddr,
			SDModel:    defaultSDModel,
		}, nil
	}

	sdAddr := os.Getenv("SD_ADDRESS")
	if sdAddr == "" {
		return nil, fmt.Errorf("neither DD_ADDRESS nor SD_ADDRESS is set")
	}

	sdModel := os.Getenv("SD_MODEL")
	if sdModel == "" {
		sdModel = defaultSDModel
	}
	return &Config{
		RunnerType: RunnerSDWebUI,
		Address:    sdAddr,
		SDModel:    sdModel,
	}, nil
}
// sendToDeepDanbooru posts fileContents to config.Address as a multipart form
// (field "file", using fileName) and decodes the reply into a WrappedResponse.
//
// threshold is sent as the "threshold" query parameter. It is set through the
// parsed request URL, so the value is always query-escaped and any query
// string already present in config.Address is preserved.
//
// mimeType is not used by this backend; the parameter exists so the signature
// matches sendToSDWebUI.
//
// An error is returned if the request cannot be built or sent, or if the reply
// is neither a JSON array of strings nor a JSON string.
func sendToDeepDanbooru(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) {
	slog.Debug("calling dd")

	// Build the multipart body holding the file.
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)
	part, err := writer.CreateFormFile("file", fileName)
	if err != nil {
		return nil, fmt.Errorf("create form file: %w", err)
	}
	if _, err := part.Write(fileContents); err != nil {
		return nil, fmt.Errorf("write file contents: %w", err)
	}
	// Close writes the trailing boundary; the body is incomplete without it.
	if err := writer.Close(); err != nil {
		return nil, fmt.Errorf("close writer: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, config.Address, body)
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	// Attach the threshold via url.Values so a value containing '&', '#' or
	// spaces cannot alter the request.
	query := req.URL.Query()
	query.Set("threshold", threshold)
	req.URL.RawQuery = query.Encode()

	req.Header.Set("Content-Type", writer.FormDataContentType())
	// NOTE(review): hard-coded bearer token — presumably a placeholder the
	// backend does not validate; confirm before relying on it.
	req.Header.Set("Authorization", "Bearer 123")

	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("send request: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response: %w", err)
	}
	slog.Debug("dd body", "body", string(respBody))

	var result WrappedResponse
	if err := json.Unmarshal(respBody, &result); err != nil {
		return nil, fmt.Errorf("parse response: %w", err)
	}
	slog.Debug("called!")
	return &result, nil
}
// sendToSDWebUI sends fileContents, base64-encoded, to the tagger endpoint
// "/tagger/v1/interrogate" under config.Address, together with config.SDModel
// and threshold, and converts the reply into a WrappedResponse.
//
// The names found in the reply's "tag" and "rating" objects are flattened
// into one list. Their numeric values are discarded, and the order of the
// list is unspecified because it comes from map iteration.
//
// fileName and mimeType are not used by this backend; they exist so the
// signature matches sendToDeepDanbooru.
//
// An error is returned if the request fails, the backend answers with a
// non-2xx status, or the body is not valid JSON for WD14Response.
func sendToSDWebUI(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) {
	slog.Debug("calling sd")

	// The endpoint takes the image inline as base64 inside a JSON document.
	imageB64 := base64.StdEncoding.EncodeToString(fileContents)
	reqData := map[string]string{
		"model":     config.SDModel,
		"threshold": threshold,
		"image":     imageB64,
	}
	jsonBody, err := json.Marshal(reqData)
	if err != nil {
		return nil, fmt.Errorf("marshal request: %w", err)
	}
	slog.Info("wd14 request length", "bytes", len(jsonBody))

	endpoint := config.Address + "/tagger/v1/interrogate"
	resp, err := httpClient.Post(endpoint, "application/json", bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("send request: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response: %w", err)
	}
	slog.Debug("wd14 body", "body", string(respBody))

	// An error body is still valid JSON and would decode into an empty
	// caption, which would be reported as "no tags". Reject it explicitly.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("unexpected response status: %s", resp.Status)
	}

	var wd14Resp WD14Response
	if err := json.Unmarshal(respBody, &wd14Resp); err != nil {
		return nil, fmt.Errorf("parse response: %w", err)
	}

	// Flatten tag and rating names into a single list.
	tags := make([]string, 0, len(wd14Resp.Caption.Tag)+len(wd14Resp.Caption.Rating))
	for tag := range wd14Resp.Caption.Tag {
		tags = append(tags, tag)
	}
	for rating := range wd14Resp.Caption.Rating {
		tags = append(tags, rating)
	}
	slog.Debug("called!")
	return &WrappedResponse{Tags: tags}, nil
}
// sendImageToTagger forwards the image to the backend selected by
// config.RunnerType and returns that backend's result unchanged. A runner
// type that matches no known backend yields an error.
func sendImageToTagger(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) {
	runner := config.RunnerType
	if runner == RunnerDeepDanbooru {
		return sendToDeepDanbooru(fileContents, fileName, mimeType, threshold)
	}
	if runner == RunnerSDWebUI {
		return sendToSDWebUI(fileContents, fileName, mimeType, threshold)
	}
	return nil, fmt.Errorf("invalid runner type")
}
// isVideo reports whether an upload should go through the frame-extraction
// path. It returns true when contentType has a "video/" media type or when
// fileName ends in .mp4, .gif, .mkv or .webm.
//
// Both checks ignore letter case; media types are case-insensitive, so
// "Video/MP4" is treated the same as "video/mp4". Note that .gif counts as
// video regardless of its content type.
func isVideo(contentType, fileName string) bool {
	if strings.HasPrefix(strings.ToLower(contentType), "video/") {
		return true
	}
	switch strings.ToLower(filepath.Ext(fileName)) {
	case ".mp4", ".gif", ".mkv", ".webm":
		return true
	}
	return false
}
// writeJSON sets the Content-Type header to application/json and writes v to
// w as JSON. No status is set explicitly, so the response uses whatever
// status the ResponseWriter defaults to.
//
// An encoding or write failure is logged rather than discarded. It cannot be
// reported to the client because part of the response may already be sent.
func writeJSON(w http.ResponseWriter, v any) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(v); err != nil {
		slog.Error("failed to write JSON response", "error", err)
	}
}
// writeError responds with the given HTTP status and msg encoded as the error
// form of WrappedResponse, with Content-Type application/json.
//
// A failure to encode or write the body is logged rather than discarded; the
// status line has already been sent by then, so nothing more can be reported
// to the client.
func writeError(w http.ResponseWriter, msg string, status int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(WrappedResponse{Error: msg}); err != nil {
		slog.Error("failed to write error response", "error", err)
	}
}
// uploadHandler handles a multipart upload whose "file" part is the media to
// tag, and responds with the tags produced for it.
//
// The optional "threshold" query parameter is passed to the tagger unchanged
// and defaults to "0.5". Uploads that isVideo classifies as video are handled
// by handleVideo; everything else goes to handleImage.
//
// A body that is too large, a malformed form, or a missing "file" part is
// answered with 400 Bad Request; a failure to read the file with 500.
func uploadHandler(w http.ResponseWriter, r *http.Request) {
	const (
		// maxUploadBytes caps the size of the whole request body.
		maxUploadBytes = 8 << 30
		// maxFormMemory is how much of the form is buffered in memory while
		// parsing; file parts beyond it are spilled to temporary files.
		maxFormMemory = 32 << 20
	)

	threshold := r.URL.Query().Get("threshold")
	if threshold == "" {
		threshold = "0.5"
	}

	// The argument to ParseMultipartForm is a memory budget, not a size limit,
	// so the limit has to be enforced on the body itself.
	r.Body = http.MaxBytesReader(w, r.Body, maxUploadBytes)
	if err := r.ParseMultipartForm(maxFormMemory); err != nil {
		slog.Error("failed to parse multipart form", "error", err)
		writeError(w, "failed to parse multipart form", http.StatusBadRequest)
		return
	}

	file, header, err := r.FormFile("file")
	if err != nil {
		// The form parsed successfully, so a failure here means the client
		// did not send a "file" part.
		slog.Error("no file found in request", "error", err)
		writeError(w, "no file found in request", http.StatusBadRequest)
		return
	}
	defer file.Close()

	contentType := header.Header.Get("Content-Type")
	fileName := header.Filename
	slog.Info("received file", "name", fileName, "type", contentType, "size", header.Size)

	fileContents, err := io.ReadAll(file)
	if err != nil {
		slog.Error("failed to read file", "error", err)
		writeError(w, "failed to read file", http.StatusInternalServerError)
		return
	}

	if isVideo(contentType, fileName) {
		handleVideo(w, fileContents, fileName, threshold)
		return
	}
	handleImage(w, fileContents, fileName, contentType, threshold)
}
// handleImage sends a single image to the configured tagger and writes the
// tagger's result to w. A content type outside "image/" only produces a
// warning; the data is still forwarded. If the tagger call fails, the client
// receives a 500 with a generic message and the cause is logged.
func handleImage(w http.ResponseWriter, data []byte, fileName, contentType, threshold string) {
	if looksLikeImage := strings.HasPrefix(contentType, "image/"); !looksLikeImage {
		slog.Warn("mimetype is not image/", "type", contentType)
	}

	result, err := sendImageToTagger(data, fileName, contentType, threshold)
	if err == nil {
		writeJSON(w, result)
		return
	}

	slog.Error("upstream tagger failed", "error", err)
	writeError(w, "upstream tagger failed", http.StatusInternalServerError)
}
// handleVideo tags a video by sampling frames and merging the tags of every
// sampled frame into one de-duplicated list, which is written to w.
//
// The upload is written to a temporary file and probed with ffprobe. The
// frame rate and frame count of the first stream determine the duration,
// getFrameSkipSeconds turns that into a sampling interval, and each sampled
// frame is extracted to a temporary PNG and sent to the tagger. Both the
// temporary file and the frame directory are removed on return.
//
// Any failure along the way aborts the request with a 500 and a generic
// message; the cause is logged. The order of the returned tags is unspecified
// because they are collected from a map.
func handleVideo(w http.ResponseWriter, data []byte, fileName, threshold string) {
	// Write the video to a temp file, keeping the extension of the upload.
	ext := filepath.Ext(fileName)
	tmpFile, err := os.CreateTemp("", "glimbus-video-*"+ext)
	if err != nil {
		slog.Error("failed to create temp file", "error", err)
		writeError(w, "failed to create temp file", http.StatusInternalServerError)
		return
	}
	defer os.Remove(tmpFile.Name())

	if _, err := tmpFile.Write(data); err != nil {
		tmpFile.Close()
		slog.Error("failed to write temp file", "error", err)
		writeError(w, "failed to write temp file", http.StatusInternalServerError)
		return
	}
	// A failed Close can mean the data never reached the disk, so it is
	// treated the same as a failed write.
	if err := tmpFile.Close(); err != nil {
		slog.Error("failed to write temp file", "error", err)
		writeError(w, "failed to write temp file", http.StatusInternalServerError)
		return
	}
	slog.Debug("tmp path", "path", tmpFile.Name())

	// Get video info.
	probeResult, err := ffprobe(tmpFile.Name())
	if err != nil {
		slog.Error("ffprobe failed", "error", err)
		writeError(w, "ffprobe failed", http.StatusInternalServerError)
		return
	}
	if len(probeResult.Streams) == 0 {
		slog.Error("no streams found")
		writeError(w, "no streams found in video", http.StatusInternalServerError)
		return
	}
	stream := probeResult.Streams[0]
	slog.Debug("stream", "data", stream)
	slog.Debug("format", "data", probeResult.Format)

	frameRate, err := calculateFrameRate(tmpFile.Name(), stream.RFrameRate, true)
	if err != nil {
		slog.Error("failed to get frame rate", "error", err)
		writeError(w, "failed to get frame rate", http.StatusInternalServerError)
		return
	}
	// The frame rate is used as a divisor and a multiplier below; anything
	// that is not strictly positive (including NaN) gives a meaningless plan.
	if !(frameRate > 0) {
		slog.Error("failed to get frame rate", "error", "frame rate is not positive", "frameRate", frameRate)
		writeError(w, "failed to get frame rate", http.StatusInternalServerError)
		return
	}

	frameCount, err := calculateFrameCount(tmpFile.Name(), &stream, &probeResult.Format, frameRate)
	if err != nil {
		slog.Error("failed to get frame count", "error", err)
		writeError(w, "failed to get frame count", http.StatusInternalServerError)
		return
	}
	slog.Debug("video info", "frameCount", frameCount, "frameRate", frameRate)

	totalSeconds := float64(frameCount) / frameRate
	skipSeconds := getFrameSkipSeconds(int(totalSeconds))
	frameSkip := uint64(float64(skipSeconds) * frameRate)
	if frameSkip == 0 {
		// A zero stride would never advance the loop below. Fall back to
		// visiting every frame so the request always terminates.
		frameSkip = 1
	}
	slog.Info("frame extraction plan", "totalSeconds", totalSeconds, "skipSeconds", skipSeconds)

	// Create a temp directory for the extracted frame; the same path is
	// reused for every frame.
	tmpDir, err := os.MkdirTemp("", "glimbus-frames-*")
	if err != nil {
		slog.Error("failed to create temp dir", "error", err)
		writeError(w, "failed to create temp dir", http.StatusInternalServerError)
		return
	}
	defer os.RemoveAll(tmpDir)
	framePath := filepath.Join(tmpDir, "frame.png")
	slog.Info("frame path", "path", framePath)

	// Process frames and collect tags.
	tagSet := make(map[string]struct{})
	for frameNum := uint64(0); frameNum < frameCount; frameNum += frameSkip {
		slog.Info("extracting frame", "frameNum", frameNum)
		if err := extractFrame(tmpFile.Name(), framePath, int(frameNum), frameRate); err != nil {
			slog.Error("frame extraction failed", "frameNum", frameNum, "error", err)
			writeError(w, "frame extraction failed", http.StatusInternalServerError)
			return
		}
		slog.Info("extracted frame", "frameNum", frameNum)

		frameData, err := os.ReadFile(framePath)
		if err != nil {
			slog.Error("failed to read frame", "error", err)
			writeError(w, "failed to read frame", http.StatusInternalServerError)
			return
		}

		slog.Info("sending frame", "frameNum", frameNum)
		resp, err := sendImageToTagger(frameData, "amongus.png", "image/png", threshold)
		if err != nil {
			slog.Error("upstream tagger failed for frame", "frameNum", frameNum, "error", err)
			writeError(w, "upstream tagger failed", http.StatusInternalServerError)
			return
		}
		if resp.Error != "" {
			// The backend answered with an error string instead of tags.
			// The frame is still skipped, but no longer silently.
			slog.Warn("upstream tagger returned error for frame", "frameNum", frameNum, "error", resp.Error)
		}
		if resp.Tags != nil {
			slog.Info("tags from frame", "count", len(resp.Tags))
			for _, tag := range resp.Tags {
				tagSet[tag] = struct{}{}
			}
		}
	}

	// Convert the set to a slice.
	tags := make([]string, 0, len(tagSet))
	for tag := range tagSet {
		tags = append(tags, tag)
	}
	writeJSON(w, &WrappedResponse{Tags: tags})
}
// main sets up logging, loads the backend configuration from the environment,
// registers the upload routes and serves HTTP on 0.0.0.0:6679. The process
// exits with status 1 if the configuration cannot be loaded or the server
// stops with an error.
func main() {
	// Debug logging is enabled when either LOG_LEVEL or RUST_LOG is "debug".
	level := slog.LevelInfo
	for _, name := range []string{"LOG_LEVEL", "RUST_LOG"} {
		if os.Getenv(name) == "debug" {
			level = slog.LevelDebug
			break
		}
	}
	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: level})
	slog.SetDefault(slog.New(handler))

	// Load the configuration into the package-level variable.
	cfg, err := loadConfig()
	if err != nil {
		slog.Error("configuration error", "error", err)
		os.Exit(1)
	}
	config = cfg

	var runnerName string
	switch config.RunnerType {
	case RunnerSDWebUI:
		runnerName = "StableDiffusionWebUI"
	default:
		runnerName = "DeepDanbooru"
	}
	slog.Info("configuration loaded", "runner", runnerName, "address", config.Address)

	// Both routes are served by the same handler.
	router := chi.NewRouter()
	router.Use(middleware.Logger, middleware.Recoverer)
	for _, route := range []string{"/", "/tagger/v1/interrogate"} {
		router.Post(route, uploadHandler)
	}

	const listenAddr = "0.0.0.0:6679"
	slog.Info("starting server", "address", listenAddr)
	if err := http.ListenAndServe(listenAddr, router); err != nil {
		slog.Error("server error", "error", err)
		os.Exit(1)
	}
}