diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..f065479 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,94 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Glimbus is a Go-based web service that processes images and videos using DeepDanbooru or Stable Diffusion WebUI for image tagging. The service accepts uploaded media files and returns tags by analyzing frames. + +**Core Architecture:** +- Two-file Go web server (`main.go`, `video.go`) +- Uses chi router with standard library net/http +- Listens on `0.0.0.0:6679` +- Accepts multipart file uploads (form field `file`, optional `threshold` query parameter, default `0.5`) via POST to `/` or `/tagger/v1/interrogate` +- For videos: extracts frames at variable intervals, sends each frame to upstream tagger, aggregates tags +- For images: sends directly to upstream tagger +- Supports two upstream runners (configured via env vars): + - `DeepDanbooru` (via `DD_ADDRESS`) + - `StableDiffusionWebUI` (via `SD_ADDRESS`) + +## Build and Run Commands + +**Build:** +```bash +go build -o glimbus . +``` + +**Run (development with logging):** +```bash +# Set environment variables first (see .envrc for example) +export LOG_LEVEL=debug # or RUST_LOG=debug for compatibility +export SD_ADDRESS=http://your-server:port +export SD_MODEL=vitv3 # or wd14-vit-v2-git (default) + +./glimbus +``` + +**Run tests:** +```bash +go test ./... 
+``` + +## Environment Variables + +Required (one of): +- `DD_ADDRESS`: DeepDanbooru server URL (checked first) +- `SD_ADDRESS`: Stable Diffusion WebUI server URL + +Optional: +- `SD_MODEL`: Model name for SD WebUI (default: `wd14-vit-v2-git`) +- `LOG_LEVEL`: Log level (`debug` for verbose logging) +- `RUST_LOG`: Also supported for compatibility (`debug` enables debug logging) + +## Video Processing Logic + +The service adaptively samples video frames based on duration. `getFrameSkipSeconds` in video.go is the authoritative mapping; at the time of writing it is: +- 0-10s: every 1 second +- 11-60s: every 5 seconds +- 61-120s: every 10 seconds +- 121-300s: every 12 seconds +- 301-1000s: every 15 seconds +- 1001-1200s: every 40 seconds +- 1201s+: every 60 seconds + +Frame extraction uses ffmpeg with fast seeking (`-ss` before input), extracting one frame per interval. All tags from all frames are collected into a deduplicated set. + +## Key Implementation Details + +**Frame Rate Calculation** (video.go): +The service first parses the frame rate from the stream metadata (`r_frame_rate`); if that value is unparsable or evaluates to zero (e.g. `0/0`), it falls back once to calling ffprobe directly. + +**Frame Count Calculation** (video.go): +Tries stream metadata first, then format duration x frame rate, finally falls back to full decode with `ffprobe -count_frames`. + +**Upstream Runner Detection** (main.go): +Checks `DD_ADDRESS` first, then `SD_ADDRESS`. Exits with error if neither is set. + +**Response Format:** +- Returns JSON array of tags: `["tag1", "tag2", ...]` +- On error: returns JSON string: `"error message"` + +## Dependencies + +Core (Go standard library + chi): +- `github.com/go-chi/chi/v5`: HTTP router +- `net/http`: HTTP server +- `os/exec`: ffmpeg/ffprobe execution +- `encoding/json`: JSON serialization +- `log/slog`: Structured logging + +External tools required: +- `ffmpeg`: Frame extraction +- `ffprobe`: Video metadata + +Uploads are parsed with `ParseMultipartForm(8 << 30)`: this is an 8GB in-memory buffering threshold for multipart data, not a cap on request size (no body size limit is currently configured). 
diff --git a/Cargo.toml b/Cargo.toml deleted file mode 100644 index de1d817..0000000 --- a/Cargo.toml +++ /dev/null @@ -1,23 +0,0 @@ -[package] -name = "glimbus" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -anyhow = "1.0.70" -axum = { version = "0.6.12", features = ["tokio", "multipart"] } -axum-macros = "0.3.7" -base64 = "0.21.2" -ffmpeg-cli = "0.1.0" -ffprobe = "0.3.3" -filepath = "0.1.2" -futures-util = "0.3.28" -log = "0.4.17" -pretty_env_logger = "0.4.0" -reqwest = { version = "0.11.16", features = ["json", "multipart"] } -serde = { version = "1.0.159", features = ["derive"] } -serde_json = "1.0.95" -tempfile = "3.5.0" -tokio = { version = "1.27.0", features = ["full"] } diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f19f5ec --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module glimbus + +go 1.25.5 + +require github.com/go-chi/chi/v5 v5.2.3 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5bd7be3 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= +github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= diff --git a/main.go b/main.go new file mode 100644 index 0000000..8a54de7 --- /dev/null +++ b/main.go @@ -0,0 +1,458 @@ +package main + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log/slog" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" +) + +// Runner types +type RunnerType int + +const ( + RunnerDeepDanbooru RunnerType = iota + RunnerSDWebUI +) + +// Config holds the application configuration +type Config struct { + RunnerType RunnerType + Address string + SDModel string +} + +// WrappedResponse matches the Rust untagged enum - returns either []string or "error" +type WrappedResponse struct { + 
Tags []string + Error string +} + +func (w WrappedResponse) MarshalJSON() ([]byte, error) { + if w.Error != "" { + return json.Marshal(w.Error) + } + return json.Marshal(w.Tags) +} + +func (w *WrappedResponse) UnmarshalJSON(data []byte) error { + // Try array first + var tags []string + if err := json.Unmarshal(data, &tags); err == nil { + w.Tags = tags + return nil + } + // Try string + var errStr string + if err := json.Unmarshal(data, &errStr); err == nil { + w.Error = errStr + return nil + } + return fmt.Errorf("cannot unmarshal response") +} + +// WD14Response from SD WebUI tagger +type WD14Response struct { + Caption WD14Caption `json:"caption"` +} + +type WD14Caption struct { + Tag map[string]float32 `json:"tag"` + Rating map[string]float32 `json:"rating"` +} + +var ( + config *Config + httpClient = &http.Client{} +) + +func loadConfig() (*Config, error) { + cfg := &Config{ + SDModel: "wd14-vit-v2-git", + } + + if addr := os.Getenv("DD_ADDRESS"); addr != "" { + cfg.RunnerType = RunnerDeepDanbooru + cfg.Address = addr + return cfg, nil + } + + if addr := os.Getenv("SD_ADDRESS"); addr != "" { + cfg.RunnerType = RunnerSDWebUI + cfg.Address = addr + if model := os.Getenv("SD_MODEL"); model != "" { + cfg.SDModel = model + } + return cfg, nil + } + + return nil, fmt.Errorf("neither DD_ADDRESS nor SD_ADDRESS is set") +} + +func sendToDeepDanbooru(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) { + slog.Debug("calling dd") + + // Create multipart form + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + part, err := writer.CreateFormFile("file", fileName) + if err != nil { + return nil, fmt.Errorf("create form file: %w", err) + } + + if _, err := part.Write(fileContents); err != nil { + return nil, fmt.Errorf("write file contents: %w", err) + } + + if err := writer.Close(); err != nil { + return nil, fmt.Errorf("close writer: %w", err) + } + + // Build request with threshold query param + url := config.Address + 
"?threshold=" + threshold + req, err := http.NewRequest("POST", url, body) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Authorization", "Bearer 123") + + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + slog.Debug("dd body", "body", string(respBody)) + + var result WrappedResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("parse response: %w", err) + } + + slog.Debug("called!") + return &result, nil +} + +func sendToSDWebUI(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) { + slog.Debug("calling sd") + + // Encode image as base64 + imageB64 := base64.StdEncoding.EncodeToString(fileContents) + + reqData := map[string]string{ + "model": config.SDModel, + "threshold": threshold, + "image": imageB64, + } + + jsonBody, err := json.Marshal(reqData) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + slog.Info("wd14 request length", "bytes", len(jsonBody)) + + url := config.Address + "/tagger/v1/interrogate" + resp, err := httpClient.Post(url, "application/json", bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + slog.Debug("wd14 body", "body", string(respBody)) + + var wd14Resp WD14Response + if err := json.Unmarshal(respBody, &wd14Resp); err != nil { + return nil, fmt.Errorf("parse response: %w", err) + } + + // Convert to WrappedResponse + tags := make([]string, 0, len(wd14Resp.Caption.Tag)+len(wd14Resp.Caption.Rating)) + for tag := range wd14Resp.Caption.Tag { + tags 
= append(tags, tag) + } + for rating := range wd14Resp.Caption.Rating { + tags = append(tags, rating) + } + + slog.Debug("called!") + return &WrappedResponse{Tags: tags}, nil +} + +func sendImageToTagger(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) { + switch config.RunnerType { + case RunnerDeepDanbooru: + return sendToDeepDanbooru(fileContents, fileName, mimeType, threshold) + case RunnerSDWebUI: + return sendToSDWebUI(fileContents, fileName, mimeType, threshold) + default: + return nil, fmt.Errorf("invalid runner type") + } +} + +func isVideo(contentType, fileName string) bool { + if strings.HasPrefix(contentType, "video/") { + return true + } + + ext := strings.ToLower(filepath.Ext(fileName)) + switch ext { + case ".mp4", ".gif", ".mkv", ".webm": + return true + } + + return false +} + +func writeJSON(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(v) +} + +func writeError(w http.ResponseWriter, msg string, status int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + json.NewEncoder(w).Encode(WrappedResponse{Error: msg}) +} + +func uploadHandler(w http.ResponseWriter, r *http.Request) { + threshold := r.URL.Query().Get("threshold") + if threshold == "" { + threshold = "0.5" + } + + // Parse multipart form (8GB limit set at middleware level) + if err := r.ParseMultipartForm(8 << 30); err != nil { + slog.Error("failed to parse multipart form", "error", err) + writeError(w, "failed to parse multipart form", http.StatusBadRequest) + return + } + + file, header, err := r.FormFile("file") + if err != nil { + slog.Error("no file found in request", "error", err) + writeError(w, "no file found in request", http.StatusInternalServerError) + return + } + defer file.Close() + + contentType := header.Header.Get("Content-Type") + fileName := header.Filename + + slog.Info("received file", "name", fileName, "type", contentType, "size", 
header.Size) + + fileContents, err := io.ReadAll(file) + if err != nil { + slog.Error("failed to read file", "error", err) + writeError(w, "failed to read file", http.StatusInternalServerError) + return + } + + if isVideo(contentType, fileName) { + handleVideo(w, fileContents, fileName, threshold) + } else { + handleImage(w, fileContents, fileName, contentType, threshold) + } +} + +func handleImage(w http.ResponseWriter, data []byte, fileName, contentType, threshold string) { + if !strings.HasPrefix(contentType, "image/") { + slog.Warn("mimetype is not image/", "type", contentType) + } + + resp, err := sendImageToTagger(data, fileName, contentType, threshold) + if err != nil { + slog.Error("upstream tagger failed", "error", err) + writeError(w, "upstream tagger failed", http.StatusInternalServerError) + return + } + + writeJSON(w, resp) +} + +func handleVideo(w http.ResponseWriter, data []byte, fileName, threshold string) { + // Write video to temp file + ext := filepath.Ext(fileName) + tmpFile, err := os.CreateTemp("", "glimbus-video-*"+ext) + if err != nil { + slog.Error("failed to create temp file", "error", err) + writeError(w, "failed to create temp file", http.StatusInternalServerError) + return + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.Write(data); err != nil { + tmpFile.Close() + slog.Error("failed to write temp file", "error", err) + writeError(w, "failed to write temp file", http.StatusInternalServerError) + return + } + tmpFile.Close() + + slog.Debug("tmp path", "path", tmpFile.Name()) + + // Get video info + probeResult, err := ffprobe(tmpFile.Name()) + if err != nil { + slog.Error("ffprobe failed", "error", err) + writeError(w, "ffprobe failed", http.StatusInternalServerError) + return + } + + if len(probeResult.Streams) == 0 { + slog.Error("no streams found") + writeError(w, "no streams found in video", http.StatusInternalServerError) + return + } + + stream := probeResult.Streams[0] + slog.Debug("stream", "data", stream) + 
slog.Debug("format", "data", probeResult.Format) + + frameRate, err := calculateFrameRate(tmpFile.Name(), stream.RFrameRate, true) + if err != nil { + slog.Error("failed to get frame rate", "error", err) + writeError(w, "failed to get frame rate", http.StatusInternalServerError) + return + } + + frameCount, err := calculateFrameCount(tmpFile.Name(), &stream, &probeResult.Format, frameRate) + if err != nil { + slog.Error("failed to get frame count", "error", err) + writeError(w, "failed to get frame count", http.StatusInternalServerError) + return + } + + slog.Debug("video info", "frameCount", frameCount, "frameRate", frameRate) + + totalSeconds := float64(frameCount) / frameRate + skipSeconds := getFrameSkipSeconds(int(totalSeconds)) + frameSkip := max(uint64(float64(skipSeconds)*frameRate), 1) // clamp to 1: a zero step (frame rate below 1/skipSeconds fps, e.g. slow GIFs) would never advance the frame loop + + slog.Info("frame extraction plan", "totalSeconds", totalSeconds, "skipSeconds", skipSeconds) + + // Create temp directory for frames + tmpDir, err := os.MkdirTemp("", "glimbus-frames-*") + if err != nil { + slog.Error("failed to create temp dir", "error", err) + writeError(w, "failed to create temp dir", http.StatusInternalServerError) + return + } + defer os.RemoveAll(tmpDir) + + framePath := filepath.Join(tmpDir, "frame.png") + slog.Info("frame path", "path", framePath) + + // Process frames and collect tags + tagSet := make(map[string]struct{}) + + for frameNum := uint64(0); frameNum < frameCount; frameNum += frameSkip { + slog.Info("extracting frame", "frameNum", frameNum) + + if err := extractFrame(tmpFile.Name(), framePath, int(frameNum), frameRate); err != nil { + slog.Error("frame extraction failed", "frameNum", frameNum, "error", err) + writeError(w, "frame extraction failed", http.StatusInternalServerError) + return + } + + slog.Info("extracted frame", "frameNum", frameNum) + + frameData, err := os.ReadFile(framePath) + if err != nil { + slog.Error("failed to read frame", "error", err) + writeError(w, "failed to read frame", http.StatusInternalServerError) + return + 
} + + slog.Info("sending frame", "frameNum", frameNum) + + resp, err := sendImageToTagger(frameData, "amongus.png", "image/png", threshold) + if err != nil { + slog.Error("upstream tagger failed for frame", "frameNum", frameNum, "error", err) + writeError(w, "upstream tagger failed", http.StatusInternalServerError) + return + } + + if resp.Tags != nil { + slog.Info("tags from frame", "count", len(resp.Tags)) + for _, tag := range resp.Tags { + tagSet[tag] = struct{}{} + } + } + } + + // Convert set to slice + tags := make([]string, 0, len(tagSet)) + for tag := range tagSet { + tags = append(tags, tag) + } + + writeJSON(w, &WrappedResponse{Tags: tags}) +} + +func main() { + // Setup logging + logLevel := slog.LevelInfo + if os.Getenv("LOG_LEVEL") == "debug" || os.Getenv("RUST_LOG") == "debug" { + logLevel = slog.LevelDebug + } + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: logLevel, + })) + slog.SetDefault(logger) + + // Load configuration + var err error + config, err = loadConfig() + if err != nil { + slog.Error("configuration error", "error", err) + os.Exit(1) + } + + runnerName := "DeepDanbooru" + if config.RunnerType == RunnerSDWebUI { + runnerName = "StableDiffusionWebUI" + } + slog.Info("configuration loaded", "runner", runnerName, "address", config.Address) + + // Setup router + r := chi.NewRouter() + r.Use(middleware.Logger) + r.Use(middleware.Recoverer) + + r.Post("/", uploadHandler) + r.Post("/tagger/v1/interrogate", uploadHandler) + + // Start server + addr := "0.0.0.0:6679" + slog.Info("starting server", "address", addr) + + if err := http.ListenAndServe(addr, r); err != nil { + slog.Error("server error", "error", err) + os.Exit(1) + } +} diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index 6773db7..0000000 --- a/src/main.rs +++ /dev/null @@ -1,434 +0,0 @@ -use axum::extract::{Multipart, Query}; -use axum::http::StatusCode; -use axum::response::{IntoResponse, Response}; -use axum::routing::post; 
-use axum::{Json, Router}; -use axum_macros::debug_handler; -use base64::{engine::general_purpose, Engine as _}; -use core::panic; -use ffmpeg_cli::Parameter; -use ffprobe::{Format, Stream}; -use futures_util::{future::ready, StreamExt}; -use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, HashSet}; -use std::io::Read; -use std::io::Write; -use std::process::Stdio; - -#[derive(Debug)] -enum Runner { - DeepDanbooru(String), - StableDiffusionWebUI(String), -} - -fn get_upstream_runner() -> Runner { - match std::env::var("DD_ADDRESS") { - Ok(value) => Runner::DeepDanbooru(value), - Err(_) => match std::env::var("SD_ADDRESS") { - Ok(value) => Runner::StableDiffusionWebUI(value), - Err(_) => panic!("shit no addr"), - }, - } -} - -fn get_sd_model() -> String { - match std::env::var("SD_MODEL") { - Ok(value) => value, - Err(_) => "wd14-vit-v2-git".to_string(), - } -} - -#[tokio::main] -async fn main() { - pretty_env_logger::init(); - // build our application with a single route - let app = Router::new() - .route("/", post(upload_file)) - .layer(axum::extract::DefaultBodyLimit::max(8 * 1024 * 1024 * 1024)); - - let upstream_runner = get_upstream_runner(); - - // run it with hyper on localhost:3000 - log::info!("running on 0.0.0.0:6679 to {:?}", upstream_runner); - axum::Server::bind(&"0.0.0.0:6679".parse().unwrap()) - .serve(app.into_make_service()) - .await - .unwrap(); -} - -#[derive(Deserialize)] -struct Options { - threshold: String, -} - -#[derive(Serialize, Deserialize)] -#[serde(untagged)] -enum WrappedResponse { - Tags(Vec), - Error(String), -} - -// Make our own error that wraps `anyhow::Error`. -struct AppError(anyhow::Error); - -// Tell axum how to convert `AppError` into a response. 
-impl IntoResponse for AppError { - fn into_response(self) -> Response { - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Something went wrong: {}", self.0), - ) - .into_response() - } -} - -// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into -// `Result<_, AppError>`. That way you don't need to do that manually. -impl From for AppError -where - E: Into, -{ - fn from(err: E) -> Self { - Self(err.into()) - } -} - -#[derive(Serialize, Deserialize)] -struct WD14Response { - caption: WD14ResponseTagHolder, -} - -#[derive(Serialize, Deserialize)] -struct WD14ResponseTagHolder { - tag: HashMap, - rating: HashMap, -} - -async fn send_image_to_dd( - file_contents: Vec, - file_name: String, - file_mime_type: &str, - options: &Options, -) -> anyhow::Result { - let part = reqwest::multipart::Part::bytes(file_contents.clone()) - .file_name(file_name) - .mime_str(file_mime_type) - .unwrap(); - let form = reqwest::multipart::Form::new().part("file", part); - - let upstream_runner = get_upstream_runner(); - - match upstream_runner { - Runner::DeepDanbooru(url) => { - log::debug!("calling dd"); - - let resp = reqwest::Client::new() - .post(url) - .multipart(form) - .header("authorization", "Bearer 123") - .query(&[("threshold", options.threshold.clone())]) - .send() - .await?; - - let body = resp.text().await?; - log::info!("sd body: {}", &body); - let json_response: WrappedResponse = serde_json::from_str(&body)?; - - log::debug!("called!"); - Ok(json_response) - } - Runner::StableDiffusionWebUI(url) => { - log::debug!("calling sd"); - - let mut map: HashMap<&str, &str> = HashMap::new(); - let file_base64 = general_purpose::STANDARD.encode(file_contents.clone()); - - let sd_model = get_sd_model(); - - map.insert("model", &sd_model); - map.insert("threshold", options.threshold.as_ref()); - map.insert("image", &file_base64); - - let serialized_map = serde_json::to_vec(&map).unwrap(); - let len = serialized_map.len(); - log::info!("wd14 
request length {} bytes", len); - - let resp = reqwest::Client::new() - .post(format!("{}/tagger/v1/interrogate", url)) - .body(serialized_map) - .send() - .await?; - - let body = resp.text().await?; - log::info!("wd14 body: {}", &body); - let json_response: WD14Response = serde_json::from_str(&body)?; - - // turn WD14Response into WrappedResponse - let mut tags = Vec::::new(); - for ele in json_response.caption.tag { - tags.push(ele.0.clone()); - } - for ele in json_response.caption.rating { - tags.push(ele.0.clone()); - } - - log::debug!("called!"); - Ok(WrappedResponse::Tags(tags)) - } - } -} - -async fn fetch_frame_as_image( - input_path: &str, - output_path: &str, - frame_index: usize, - frame_rate: f64, // X/1sec -) -> anyhow::Result<()> { - let timeline_index: f64 = frame_index as f64 / frame_rate; - let timeline_index_param = format!("{:.5}", timeline_index); - log::debug!( - "construct command with {:?} {:?}", - timeline_index, - timeline_index_param - ); - log::info!( - "ffmpeg command: -nostdin -y -ss {} -i {} -vframes 1 {}", - timeline_index_param, - input_path, - output_path - ); - let builder = ffmpeg_cli::FfmpegBuilder::new() - .stderr(Stdio::piped()) - .option(Parameter::Single("nostdin")) - .option(Parameter::Single("y")) - // fast seeking with -ss instead of select filter - .input( - ffmpeg_cli::File::new(input_path) - .option(Parameter::KeyValue("ss", &timeline_index_param)), - ) - .output(ffmpeg_cli::File::new(output_path).option(Parameter::KeyValue("vframes", "1"))); - - log::debug!("running {:?}", builder); - log::debug!("calling builder.run()..."); - let mut ffmpeg = builder.run().await.unwrap(); - log::debug!("builder.run() completed"); - - // For single frame extraction, we don't need to monitor progress. - // Just wait for the process to complete. 
- log::debug!("waiting for ffmpeg process to complete..."); - let output = ffmpeg.process.wait_with_output().unwrap(); - - log::debug!( - "output:\n{}\nstderr:\n{}", - output.status, - std::str::from_utf8(&output.stderr).unwrap() - ); - - Ok(()) -} - -fn fetch_frame_count_full_decode(path: &std::path::Path) -> anyhow::Result { - let config = ffprobe::ConfigBuilder::new().count_frames(true).build(); - let new_info = ffprobe::ffprobe_config(config, path)?; - let res = new_info - .streams - .get(0) - .unwrap() - .nb_read_frames - .clone() - .unwrap() - .parse::()?; - Ok(res) -} - -fn calculate_frame_count( - path: &std::path::Path, - stream: &Stream, - format: &Format, - frame_rate: f64, -) -> anyhow::Result { - Ok(if let Some(parseable_data) = stream.nb_frames.clone() { - // if we can get it from the stream metadata, use it - parseable_data.parse::()? - } else if let Some(parseable_data) = format.try_get_duration() { - // this is a std::time::duration - // multiply that by frame rate and we get total frame count (approximate) - log::warn!("fetching duration from format metadata..."); - let seconds = parseable_data?.as_secs_f64(); - (seconds * frame_rate) as u64 - } else { - log::warn!("file didn't provide frame metadata, calculating it ourselves..."); - fetch_frame_count_full_decode(path)? 
- }) -} - -fn calculate_frame_rate( - temp_path: &std::path::Path, - frame_rate_str: String, - first_run: bool, -) -> anyhow::Result { - let parts = frame_rate_str.split("/").into_iter().collect::>(); - let frames_per = parts.get(0).unwrap().parse::()?; - let per_second = parts.get(1).unwrap().parse::()?; - if frames_per == 0.0f64 && per_second == 0.0f64 { - log::warn!("got incorrect frame rate, calling ffmpeg again..."); - if !first_run { - std::panic!("couldnt get frame rate shit"); - } - - // call ffprobe directly - let mut cmd = std::process::Command::new("ffprobe"); - let cmd = cmd.args(&[ - "-v", - "error", - "-select_streams", - "v", - "-of", - "default=noprint_wrappers=1:nokey=1", - "-show_entries", - "stream=r_frame_rate", - ]); - let cmd = cmd.arg(temp_path.to_str().unwrap()); - - cmd.get_args().for_each(|arg| { - log::debug!("arg {:?}", arg); - }); - - let output = cmd.output()?; - let possibly_new_frame_rate = String::from_utf8(output.stdout)? - .strip_suffix("\n") - .unwrap() - .to_string(); - - log::debug!("raw ffprobe gave {:?}", possibly_new_frame_rate); - - calculate_frame_rate(temp_path, possibly_new_frame_rate, false) - } else { - Ok(frames_per / per_second) - } -} - -#[debug_handler] -async fn upload_file( - options: Query, - mut multipart: Multipart, -) -> Result<(StatusCode, Json), AppError> { - let mut maybe_file_contents: Option = None; - let mut maybe_file_type: Option = None; - let mut maybe_file_name: Option = None; - - while let Some(field) = multipart.next_field().await.unwrap() { - let name = field.name().unwrap().to_string(); - let content_type = field.content_type().unwrap().to_string(); - let filename = field.file_name().unwrap().to_string(); - let data = field.bytes().await?; - - log::info!("Length of `{}` is {} bytes", name, data.len()); - if name == "file" { - maybe_file_contents = Some(data); - maybe_file_type = Some(content_type); - maybe_file_name = Some(filename); - } - } - - if let Some(file_contents) = maybe_file_contents { 
- let file_type = maybe_file_type.unwrap(); - let file_name = maybe_file_name.unwrap(); - log::info!("file {} {}", file_type, file_name); - let is_video = file_type.starts_with("video/") - || file_name.ends_with(".mp4") - || file_name.ends_with(".gif") - || file_name.ends_with(".mkv") - || file_name.ends_with(".webm"); - if is_video { - let mut final_tag_set = HashSet::new(); - - let mut temp_file = tempfile::NamedTempFile::new()?; - temp_file.write_all(&file_contents.to_vec())?; - - log::debug!("tmp path: {:?}", temp_file.path()); - - let info = ffprobe::ffprobe(temp_file.path())?; - let stream = info.streams.get(0).unwrap(); - - log::debug!("stream = {:?}", stream); - log::debug!("format = {:?}", info.format); - - let frame_rate: f64 = - calculate_frame_rate(temp_file.path(), stream.r_frame_rate.clone(), true)?; - - let total_frame_count = - calculate_frame_count(temp_file.path(), &stream, &info.format, frame_rate)?; - - log::debug!("total frame count = {}", total_frame_count); - log::debug!("frame rate = {}", frame_rate); - - let total_length_in_seconds = total_frame_count as f64 / frame_rate; - let wanted_frame_skip_seconds = match total_length_in_seconds as usize { - 0..=10 => 2, - 11..=60 => 10, - 61..=120 => 15, - 121..=300 => 20, - 301..=1000 => 30, - 1001..=1200 => 40, - 1201.. 
=> 60, - _ => 63, - } as f64; - let wanted_frame_skip = wanted_frame_skip_seconds * frame_rate; - - let temporary_frame_dir = tempfile::tempdir()?; - let temporary_frame_path = - format!("{}/frame.png", temporary_frame_dir.path().to_string_lossy()); - log::info!("frame path: '{}'", &temporary_frame_path); - log::info!("wanted_frame_skip: {}", &wanted_frame_skip_seconds); - - for frame_number in (0..total_frame_count).step_by(wanted_frame_skip as usize) { - log::info!("extracting frame {:?}", frame_number); - fetch_frame_as_image( - temp_file.path().to_str().unwrap(), - &temporary_frame_path, - frame_number.try_into().unwrap(), - frame_rate, - ) - .await?; - log::info!("extracted frame {:?}", frame_number); - - let mut actual_frame_file = std::fs::File::open(&temporary_frame_path)?; - let mut frame_data = vec![]; - actual_frame_file.read_to_end(&mut frame_data)?; - log::info!("sending frame {:?}", frame_number); - let tags_from_frame = if let WrappedResponse::Tags(tags_from_frame) = - send_image_to_dd(frame_data, "amongus.png".to_string(), "image/png", &options) - .await? 
- { - tags_from_frame - } else { - todo!() - }; - - log::info!("{} tags from frame", tags_from_frame.len()); - for tag in tags_from_frame { - final_tag_set.insert(tag); - } - } - - let response = WrappedResponse::Tags(final_tag_set.into_iter().collect::>()); - Ok((StatusCode::OK, Json(response))) - } else { - if !file_type.starts_with("image/") { - log::warn!("warning: mimetype {} is not image/", file_type); - } - - let json_response = - send_image_to_dd(file_contents.to_vec(), file_name, &file_type, &options).await?; - Ok((StatusCode::OK, Json(json_response))) - } - } else { - Ok(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(WrappedResponse::Error( - "no file found in request".to_string(), - )), - )) - } -} diff --git a/video.go b/video.go new file mode 100644 index 0000000..92fc49e --- /dev/null +++ b/video.go @@ -0,0 +1,211 @@ +package main + +import ( + "encoding/json" + "fmt" + "log/slog" + "os/exec" + "strconv" + "strings" +) + +// FFprobe result structures +type ProbeResult struct { + Streams []ProbeStream `json:"streams"` + Format ProbeFormat `json:"format"` +} + +type ProbeStream struct { + RFrameRate string `json:"r_frame_rate"` + NbFrames string `json:"nb_frames,omitempty"` + NbReadFrames string `json:"nb_read_frames,omitempty"` +} + +type ProbeFormat struct { + Duration string `json:"duration,omitempty"` +} + +// ffprobe runs ffprobe and returns structured metadata +func ffprobe(path string) (*ProbeResult, error) { + cmd := exec.Command("ffprobe", + "-v", "error", + "-print_format", "json", + "-show_format", + "-show_streams", + path, + ) + + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("ffprobe failed: %w", err) + } + + var result ProbeResult + if err := json.Unmarshal(output, &result); err != nil { + return nil, fmt.Errorf("parse ffprobe output: %w", err) + } + + return &result, nil +} + +// ffprobeFrameRate gets frame rate via direct ffprobe call (fallback) +func ffprobeFrameRate(path string) (string, error) { + cmd := 
exec.Command("ffprobe", + "-v", "error", + "-select_streams", "v", + "-of", "default=noprint_wrappers=1:nokey=1", + "-show_entries", "stream=r_frame_rate", + path, + ) + + output, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("ffprobe frame rate failed: %w", err) + } + + return strings.TrimSpace(string(output)), nil +} + +// ffprobeFrameCountFullDecode counts frames via full decode (slow fallback) +func ffprobeFrameCountFullDecode(path string) (uint64, error) { + cmd := exec.Command("ffprobe", + "-v", "error", + "-count_frames", + "-select_streams", "v:0", + "-show_entries", "stream=nb_read_frames", + "-of", "default=noprint_wrappers=1:nokey=1", + path, + ) + + output, err := cmd.Output() + if err != nil { + return 0, fmt.Errorf("ffprobe count frames failed: %w", err) + } + + count, err := strconv.ParseUint(strings.TrimSpace(string(output)), 10, 64) + if err != nil { + return 0, fmt.Errorf("parse frame count: %w", err) + } + + return count, nil +} + +// parseFrameRate parses a "num/denom" frame rate string and returns num/denom; it returns an error for a malformed string or a zero denominator (including "0/0"), so callers never receive +Inf or NaN +func parseFrameRate(s string) (float64, error) { + parts := strings.Split(s, "/") + if len(parts) != 2 { + return 0, fmt.Errorf("invalid frame rate format: %s", s) + } + + num, err := strconv.ParseFloat(parts[0], 64) + if err != nil { + return 0, fmt.Errorf("parse numerator: %w", err) + } + + denom, err := strconv.ParseFloat(parts[1], 64) + if err != nil { + return 0, fmt.Errorf("parse denominator: %w", err) + } + + if denom == 0 { // any zero denominator is invalid, not only "0/0": "30/0" would otherwise yield +Inf + return 0, fmt.Errorf("zero frame rate") + } + + return num / denom, nil +} + +// calculateFrameRate gets frame rate with ffprobe fallback +func calculateFrameRate(path string, frameRateStr string, firstRun bool) (float64, error) { + frameRate, err := parseFrameRate(frameRateStr) + if err != nil || frameRate == 0 { + slog.Warn("got incorrect frame rate, calling ffprobe again...") + + if !firstRun { + return 0, fmt.Errorf("couldn't get frame rate") + } + + // Call ffprobe directly + newFrameRateStr, err := 
ffprobeFrameRate(path) + if err != nil { + return 0, err + } + + slog.Debug("raw ffprobe gave", "frameRate", newFrameRateStr) + return calculateFrameRate(path, newFrameRateStr, false) + } + + return frameRate, nil +} + +// calculateFrameCount gets total frame count with fallbacks +func calculateFrameCount(path string, stream *ProbeStream, format *ProbeFormat, frameRate float64) (uint64, error) { + // Try stream metadata first + if stream.NbFrames != "" { + count, err := strconv.ParseUint(stream.NbFrames, 10, 64) + if err == nil { + return count, nil + } + } + + // Try format duration + if format.Duration != "" { + duration, err := strconv.ParseFloat(format.Duration, 64) + if err == nil { + slog.Warn("fetching duration from format metadata...") + return uint64(duration * frameRate), nil + } + } + + // Fallback to full decode + slog.Warn("file didn't provide frame metadata, calculating it ourselves...") + return ffprobeFrameCountFullDecode(path) +} + +// extractFrame extracts a single frame at the given index using ffmpeg +func extractFrame(inputPath, outputPath string, frameIndex int, frameRate float64) error { + // Calculate timeline position + timelineIndex := float64(frameIndex) / frameRate + timelineStr := fmt.Sprintf("%.5f", timelineIndex) + + slog.Debug("construct command", "timelineIndex", timelineIndex, "timelineStr", timelineStr) + slog.Info("ffmpeg command", "args", fmt.Sprintf("-nostdin -y -ss %s -i %s -vframes 1 %s", timelineStr, inputPath, outputPath)) + + // Use fast seeking with -ss before -i + cmd := exec.Command("ffmpeg", + "-nostdin", + "-y", + "-ss", timelineStr, + "-i", inputPath, + "-vframes", "1", + outputPath, + ) + + output, err := cmd.CombinedOutput() + if err != nil { + slog.Error("ffmpeg failed", "output", string(output), "error", err) + return fmt.Errorf("ffmpeg failed: %w", err) + } + + slog.Debug("ffmpeg output", "output", string(output)) + return nil +} + +// getFrameSkipSeconds returns the frame sampling interval based on video 
duration +func getFrameSkipSeconds(totalSeconds int) int { + switch { + case totalSeconds <= 10: + return 1 + case totalSeconds <= 60: + return 5 + case totalSeconds <= 120: + return 10 + case totalSeconds <= 300: + return 12 + case totalSeconds <= 1000: + return 15 + case totalSeconds <= 1200: + return 40 + default: + return 60 + } +}