package main import ( "bytes" "encoding/base64" "encoding/json" "fmt" "io" "log/slog" "mime/multipart" "net/http" "os" "path/filepath" "strings" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" ) // Runner types type RunnerType int const ( RunnerDeepDanbooru RunnerType = iota RunnerSDWebUI ) // Config holds the application configuration type Config struct { RunnerType RunnerType Address string SDModel string } // WrappedResponse matches the Rust untagged enum - returns either []string or "error" type WrappedResponse struct { Tags []string Error string } func (w WrappedResponse) MarshalJSON() ([]byte, error) { if w.Error != "" { return json.Marshal(w.Error) } return json.Marshal(w.Tags) } func (w *WrappedResponse) UnmarshalJSON(data []byte) error { // Try array first var tags []string if err := json.Unmarshal(data, &tags); err == nil { w.Tags = tags return nil } // Try string var errStr string if err := json.Unmarshal(data, &errStr); err == nil { w.Error = errStr return nil } return fmt.Errorf("cannot unmarshal response") } // WD14Response from SD WebUI tagger type WD14Response struct { Caption WD14Caption `json:"caption"` } type WD14Caption struct { Tag map[string]float32 `json:"tag"` Rating map[string]float32 `json:"rating"` } var ( config *Config httpClient = &http.Client{} ) func loadConfig() (*Config, error) { cfg := &Config{ SDModel: "wd14-vit-v2-git", } if addr := os.Getenv("DD_ADDRESS"); addr != "" { cfg.RunnerType = RunnerDeepDanbooru cfg.Address = addr return cfg, nil } if addr := os.Getenv("SD_ADDRESS"); addr != "" { cfg.RunnerType = RunnerSDWebUI cfg.Address = addr if model := os.Getenv("SD_MODEL"); model != "" { cfg.SDModel = model } return cfg, nil } return nil, fmt.Errorf("neither DD_ADDRESS nor SD_ADDRESS is set") } func sendToDeepDanbooru(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) { slog.Debug("calling dd") // Create multipart form body := &bytes.Buffer{} writer := multipart.NewWriter(body) part, err := writer.CreateFormFile("file", fileName) if err != nil { return nil, fmt.Errorf("create form file: %w", err) } if _, err := part.Write(fileContents); err != nil { return nil, fmt.Errorf("write file contents: %w", err) } if err := writer.Close(); err != nil { return nil, fmt.Errorf("close writer: %w", err) } // Build request with threshold query param url := config.Address + "?threshold=" + threshold req, err := http.NewRequest("POST", url, body) if err != nil { return nil, fmt.Errorf("create request: %w", err) } req.Header.Set("Content-Type", writer.FormDataContentType()) req.Header.Set("Authorization", "Bearer 123") resp, err := httpClient.Do(req) if err != nil { return nil, fmt.Errorf("send request: %w", err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("read response: %w", err) } slog.Debug("dd body", "body", string(respBody)) var result WrappedResponse if err := json.Unmarshal(respBody, &result); err != nil { return nil, fmt.Errorf("parse response: %w", err) } slog.Debug("called!") return &result, nil } func sendToSDWebUI(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) { slog.Debug("calling sd") // Encode image as base64 imageB64 := base64.StdEncoding.EncodeToString(fileContents) reqData := map[string]string{ "model": config.SDModel, "threshold": threshold, "image": imageB64, } jsonBody, err := json.Marshal(reqData) if err != nil { return nil, fmt.Errorf("marshal request: %w", err) } slog.Info("wd14 request length", "bytes", len(jsonBody)) url := config.Address + "/tagger/v1/interrogate" resp, err := httpClient.Post(url, "application/json", bytes.NewReader(jsonBody)) if err != nil { return nil, fmt.Errorf("send request: %w", err) } defer resp.Body.Close() respBody, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("read response: %w", err) } slog.Debug("wd14 body", "body", string(respBody)) var wd14Resp WD14Response if err := json.Unmarshal(respBody, &wd14Resp); err != nil { return nil, fmt.Errorf("parse response: %w", err) } // Convert to WrappedResponse tags := make([]string, 0, len(wd14Resp.Caption.Tag)+len(wd14Resp.Caption.Rating)) for tag := range wd14Resp.Caption.Tag { tags = append(tags, tag) } for rating := range wd14Resp.Caption.Rating { tags = append(tags, rating) } slog.Debug("called!") return &WrappedResponse{Tags: tags}, nil } func sendImageToTagger(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) { switch config.RunnerType { case RunnerDeepDanbooru: return sendToDeepDanbooru(fileContents, fileName, mimeType, threshold) case RunnerSDWebUI: return sendToSDWebUI(fileContents, fileName, mimeType, threshold) default: return nil, fmt.Errorf("invalid runner type") } } func isVideo(contentType, fileName string) bool { if strings.HasPrefix(contentType, "video/") { return true } ext := strings.ToLower(filepath.Ext(fileName)) switch ext { case ".mp4", ".gif", ".mkv", ".webm": return true } return false } func writeJSON(w http.ResponseWriter, v any) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(v) } func writeError(w http.ResponseWriter, msg string, status int) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) json.NewEncoder(w).Encode(WrappedResponse{Error: msg}) } func uploadHandler(w http.ResponseWriter, r *http.Request) { threshold := r.URL.Query().Get("threshold") if threshold == "" { threshold = "0.5" } // Parse multipart form (8GB limit set at middleware level) if err := r.ParseMultipartForm(8 << 30); err != nil { slog.Error("failed to parse multipart form", "error", err) writeError(w, "failed to parse multipart form", http.StatusBadRequest) return } file, header, err := r.FormFile("file") if err != nil { slog.Error("no file found in request", "error", err) writeError(w, "no file found in request", http.StatusInternalServerError) return } defer file.Close() contentType := header.Header.Get("Content-Type") fileName := header.Filename slog.Info("received file", "name", fileName, "type", contentType, "size", header.Size) fileContents, err := io.ReadAll(file) if err != nil { slog.Error("failed to read file", "error", err) writeError(w, "failed to read file", http.StatusInternalServerError) return } if isVideo(contentType, fileName) { handleVideo(w, fileContents, fileName, threshold) } else { handleImage(w, fileContents, fileName, contentType, threshold) } } func handleImage(w http.ResponseWriter, data []byte, fileName, contentType, threshold string) { if !strings.HasPrefix(contentType, "image/") { slog.Warn("mimetype is not image/", "type", contentType) } resp, err := sendImageToTagger(data, fileName, contentType, threshold) if err != nil { slog.Error("upstream tagger failed", "error", err) writeError(w, "upstream tagger failed", http.StatusInternalServerError) return } writeJSON(w, resp) } func handleVideo(w http.ResponseWriter, data []byte, fileName, threshold string) { // Write video to temp file ext := filepath.Ext(fileName) tmpFile, err := os.CreateTemp("", "glimbus-video-*"+ext) if err != nil { slog.Error("failed to create temp file", "error", err) writeError(w, "failed to create temp file", http.StatusInternalServerError) return } defer os.Remove(tmpFile.Name()) if _, err := tmpFile.Write(data); err != nil { tmpFile.Close() slog.Error("failed to write temp file", "error", err) writeError(w, "failed to write temp file", http.StatusInternalServerError) return } tmpFile.Close() slog.Debug("tmp path", "path", tmpFile.Name()) // Get video info probeResult, err := ffprobe(tmpFile.Name()) if err != nil { slog.Error("ffprobe failed", "error", err) writeError(w, "ffprobe failed", http.StatusInternalServerError) return } if len(probeResult.Streams) == 0 { slog.Error("no streams found") writeError(w, "no streams found in video", http.StatusInternalServerError) return } stream := probeResult.Streams[0] slog.Debug("stream", "data", stream) slog.Debug("format", "data", probeResult.Format) frameRate, err := calculateFrameRate(tmpFile.Name(), stream.RFrameRate, true) if err != nil { slog.Error("failed to get frame rate", "error", err) writeError(w, "failed to get frame rate", http.StatusInternalServerError) return } frameCount, err := calculateFrameCount(tmpFile.Name(), &stream, &probeResult.Format, frameRate) if err != nil { slog.Error("failed to get frame count", "error", err) writeError(w, "failed to get frame count", http.StatusInternalServerError) return } slog.Debug("video info", "frameCount", frameCount, "frameRate", frameRate) totalSeconds := float64(frameCount) / frameRate skipSeconds := getFrameSkipSeconds(int(totalSeconds)) frameSkip := uint64(float64(skipSeconds) * frameRate) slog.Info("frame extraction plan", "totalSeconds", totalSeconds, "skipSeconds", skipSeconds) // Create temp directory for frames tmpDir, err := os.MkdirTemp("", "glimbus-frames-*") if err != nil { slog.Error("failed to create temp dir", "error", err) writeError(w, "failed to create temp dir", http.StatusInternalServerError) return } defer os.RemoveAll(tmpDir) framePath := filepath.Join(tmpDir, "frame.png") slog.Info("frame path", "path", framePath) // Process frames and collect tags tagSet := make(map[string]struct{}) for frameNum := uint64(0); frameNum < frameCount; frameNum += frameSkip { slog.Info("extracting frame", "frameNum", frameNum) if err := extractFrame(tmpFile.Name(), framePath, int(frameNum), frameRate); err != nil { slog.Error("frame extraction failed", "frameNum", frameNum, "error", err) writeError(w, "frame extraction failed", http.StatusInternalServerError) return } slog.Info("extracted frame", "frameNum", frameNum) frameData, err := os.ReadFile(framePath) if err != nil { slog.Error("failed to read frame", "error", err) writeError(w, "failed to read frame", http.StatusInternalServerError) return } slog.Info("sending frame", "frameNum", frameNum) resp, err := sendImageToTagger(frameData, "amongus.png", "image/png", threshold) if err != nil { slog.Error("upstream tagger failed for frame", "frameNum", frameNum, "error", err) writeError(w, "upstream tagger failed", http.StatusInternalServerError) return } if resp.Tags != nil { slog.Info("tags from frame", "count", len(resp.Tags)) for _, tag := range resp.Tags { tagSet[tag] = struct{}{} } } } // Convert set to slice tags := make([]string, 0, len(tagSet)) for tag := range tagSet { tags = append(tags, tag) } writeJSON(w, &WrappedResponse{Tags: tags}) } func main() { // Setup logging logLevel := slog.LevelInfo if os.Getenv("LOG_LEVEL") == "debug" || os.Getenv("RUST_LOG") == "debug" { logLevel = slog.LevelDebug } logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ Level: logLevel, })) slog.SetDefault(logger) // Load configuration var err error config, err = loadConfig() if err != nil { slog.Error("configuration error", "error", err) os.Exit(1) } runnerName := "DeepDanbooru" if config.RunnerType == RunnerSDWebUI { runnerName = "StableDiffusionWebUI" } slog.Info("configuration loaded", "runner", runnerName, "address", config.Address) // Setup router r := chi.NewRouter() r.Use(middleware.Logger) r.Use(middleware.Recoverer) r.Post("/", uploadHandler) r.Post("/tagger/v1/interrogate", uploadHandler) // Start server addr := "0.0.0.0:6679" slog.Info("starting server", "address", addr) if err := http.ListenAndServe(addr, r); err != nil { slog.Error("server error", "error", err) os.Exit(1) } }