diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index f065479..0000000 --- a/CLAUDE.md +++ /dev/null @@ -1,94 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -Glimbus is a Go-based web service that processes images and videos using DeepDanbooru or Stable Diffusion WebUI for image tagging. The service accepts uploaded media files and returns tags by analyzing frames. - -**Core Architecture:** -- Two-file Go web server (`main.go`, `video.go`) -- Uses chi router with standard library net/http -- Listens on `0.0.0.0:6679` -- Accepts multipart file uploads via POST to `/` -- For videos: extracts frames at variable intervals, sends each frame to upstream tagger, aggregates tags -- For images: sends directly to upstream tagger -- Supports two upstream runners (configured via env vars): - - `DeepDanbooru` (via `DD_ADDRESS`) - - `StableDiffusionWebUI` (via `SD_ADDRESS`) - -## Build and Run Commands - -**Build:** -```bash -go build -o glimbus . -``` - -**Run (development with logging):** -```bash -# Set environment variables first (see .envrc for example) -export LOG_LEVEL=debug # or RUST_LOG=debug for compatibility -export SD_ADDRESS=http://your-server:port -export SD_MODEL=vitv3 # or wd14-vit-v2-git (default) - -./glimbus -``` - -**Run tests:** -```bash -go test ./... -``` - -## Environment Variables - -Required (one of): -- `DD_ADDRESS`: DeepDanbooru server URL (checked first) -- `SD_ADDRESS`: Stable Diffusion WebUI server URL - -Optional: -- `SD_MODEL`: Model name for SD WebUI (default: `wd14-vit-v2-git`) -- `LOG_LEVEL`: Log level (`debug` for verbose logging) -- `RUST_LOG`: Also supported for compatibility (`debug` enables debug logging) - -## Video Processing Logic - -The service adaptively samples video frames based on duration: look at `getFrameSkipSeconds` to find the mapping, but its kind of like this (as an example): -- 0-10s: every 2 seconds -- 11-60s: every 10 seconds -- 61-120s: every 15 seconds -- 121-300s: every 20 seconds -- 301-1000s: every 30 seconds -- 1001-1200s: every 40 seconds -- 1201s+: every 60 seconds - -Frame extraction uses ffmpeg with fast seeking (`-ss` before input), extracting one frame per interval. All tags from all frames are collected into a deduplicated set. - -## Key Implementation Details - -**Frame Rate Calculation** (video.go): -The service attempts to get frame rate from stream metadata, falls back to calling ffprobe directly if the metadata returns 0/0. - -**Frame Count Calculation** (video.go): -Tries stream metadata first, then format duration x frame rate, finally falls back to full decode with `ffprobe -count_frames`. - -**Upstream Runner Detection** (main.go): -Checks `DD_ADDRESS` first, then `SD_ADDRESS`. Exits with error if neither is set. - -**Response Format:** -- Returns JSON array of tags: `["tag1", "tag2", ...]` -- On error: returns JSON string: `"error message"` - -## Dependencies - -Core (Go standard library + chi): -- `github.com/go-chi/chi/v5`: HTTP router -- `net/http`: HTTP server -- `os/exec`: ffmpeg/ffprobe execution -- `encoding/json`: JSON serialization -- `log/slog`: Structured logging - -External tools required: -- `ffmpeg`: Frame extraction -- `ffprobe`: Video metadata - -The service has an 8GB body size limit for file uploads. diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..de1d817 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "glimbus" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = "1.0.70" +axum = { version = "0.6.12", features = ["tokio", "multipart"] } +axum-macros = "0.3.7" +base64 = "0.21.2" +ffmpeg-cli = "0.1.0" +ffprobe = "0.3.3" +filepath = "0.1.2" +futures-util = "0.3.28" +log = "0.4.17" +pretty_env_logger = "0.4.0" +reqwest = { version = "0.11.16", features = ["json", "multipart"] } +serde = { version = "1.0.159", features = ["derive"] } +serde_json = "1.0.95" +tempfile = "3.5.0" +tokio = { version = "1.27.0", features = ["full"] } diff --git a/go.mod b/go.mod deleted file mode 100644 index f19f5ec..0000000 --- a/go.mod +++ /dev/null @@ -1,5 +0,0 @@ -module glimbus - -go 1.25.5 - -require github.com/go-chi/chi/v5 v5.2.3 diff --git a/go.sum b/go.sum deleted file mode 100644 index 5bd7be3..0000000 --- a/go.sum +++ /dev/null @@ -1,2 +0,0 @@ -github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= -github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= diff --git a/main.go b/main.go deleted file mode 100644 index 8a54de7..0000000 --- a/main.go +++ /dev/null @@ -1,458 +0,0 @@ -package main - -import ( - "bytes" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "log/slog" - "mime/multipart" - "net/http" - "os" - "path/filepath" - "strings" - - "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" -) - -// Runner types -type RunnerType int - -const ( - RunnerDeepDanbooru RunnerType = iota - RunnerSDWebUI -) - -// Config holds the application configuration -type Config struct { - RunnerType RunnerType - Address string - SDModel string -} - -// WrappedResponse matches the Rust untagged enum - returns either []string or "error" -type WrappedResponse struct { - Tags []string - Error string -} - -func (w WrappedResponse) MarshalJSON() ([]byte, error) { - if w.Error != "" { - return json.Marshal(w.Error) - } - return json.Marshal(w.Tags) -} - -func (w *WrappedResponse) UnmarshalJSON(data []byte) error { - // Try array first - var tags []string - if err := json.Unmarshal(data, &tags); err == nil { - w.Tags = tags - return nil - } - // Try string - var errStr string - if err := json.Unmarshal(data, &errStr); err == nil { - w.Error = errStr - return nil - } - return fmt.Errorf("cannot unmarshal response") -} - -// WD14Response from SD WebUI tagger -type WD14Response struct { - Caption WD14Caption `json:"caption"` -} - -type WD14Caption struct { - Tag map[string]float32 `json:"tag"` - Rating map[string]float32 `json:"rating"` -} - -var ( - config *Config - httpClient = &http.Client{} -) - -func loadConfig() (*Config, error) { - cfg := &Config{ - SDModel: "wd14-vit-v2-git", - } - - if addr := os.Getenv("DD_ADDRESS"); addr != "" { - cfg.RunnerType = RunnerDeepDanbooru - cfg.Address = addr - return cfg, nil - } - - if addr := os.Getenv("SD_ADDRESS"); addr != "" { - cfg.RunnerType = RunnerSDWebUI - cfg.Address = addr - if model := os.Getenv("SD_MODEL"); model != "" { - cfg.SDModel = model - } - return cfg, nil - } - - return nil, fmt.Errorf("neither DD_ADDRESS nor SD_ADDRESS is set") -} - -func sendToDeepDanbooru(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) { - slog.Debug("calling dd") - - // Create multipart form - body := &bytes.Buffer{} - writer := multipart.NewWriter(body) - - part, err := writer.CreateFormFile("file", fileName) - if err != nil { - return nil, fmt.Errorf("create form file: %w", err) - } - - if _, err := part.Write(fileContents); err != nil { - return nil, fmt.Errorf("write file contents: %w", err) - } - - if err := writer.Close(); err != nil { - return nil, fmt.Errorf("close writer: %w", err) - } - - // Build request with threshold query param - url := config.Address + "?threshold=" + threshold - req, err := http.NewRequest("POST", url, body) - if err != nil { - return nil, fmt.Errorf("create request: %w", err) - } - - req.Header.Set("Content-Type", writer.FormDataContentType()) - req.Header.Set("Authorization", "Bearer 123") - - resp, err := httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("send request: %w", err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response: %w", err) - } - - slog.Debug("dd body", "body", string(respBody)) - - var result WrappedResponse - if err := json.Unmarshal(respBody, &result); err != nil { - return nil, fmt.Errorf("parse response: %w", err) - } - - slog.Debug("called!") - return &result, nil -} - -func sendToSDWebUI(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) { - slog.Debug("calling sd") - - // Encode image as base64 - imageB64 := base64.StdEncoding.EncodeToString(fileContents) - - reqData := map[string]string{ - "model": config.SDModel, - "threshold": threshold, - "image": imageB64, - } - - jsonBody, err := json.Marshal(reqData) - if err != nil { - return nil, fmt.Errorf("marshal request: %w", err) - } - - slog.Info("wd14 request length", "bytes", len(jsonBody)) - - url := config.Address + "/tagger/v1/interrogate" - resp, err := httpClient.Post(url, "application/json", bytes.NewReader(jsonBody)) - if err != nil { - return nil, fmt.Errorf("send request: %w", err) - } - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response: %w", err) - } - - slog.Debug("wd14 body", "body", string(respBody)) - - var wd14Resp WD14Response - if err := json.Unmarshal(respBody, &wd14Resp); err != nil { - return nil, fmt.Errorf("parse response: %w", err) - } - - // Convert to WrappedResponse - tags := make([]string, 0, len(wd14Resp.Caption.Tag)+len(wd14Resp.Caption.Rating)) - for tag := range wd14Resp.Caption.Tag { - tags = append(tags, tag) - } - for rating := range wd14Resp.Caption.Rating { - tags = append(tags, rating) - } - - slog.Debug("called!") - return &WrappedResponse{Tags: tags}, nil -} - -func sendImageToTagger(fileContents []byte, fileName, mimeType, threshold string) (*WrappedResponse, error) { - switch config.RunnerType { - case RunnerDeepDanbooru: - return sendToDeepDanbooru(fileContents, fileName, mimeType, threshold) - case RunnerSDWebUI: - return sendToSDWebUI(fileContents, fileName, mimeType, threshold) - default: - return nil, fmt.Errorf("invalid runner type") - } -} - -func isVideo(contentType, fileName string) bool { - if strings.HasPrefix(contentType, "video/") { - return true - } - - ext := strings.ToLower(filepath.Ext(fileName)) - switch ext { - case ".mp4", ".gif", ".mkv", ".webm": - return true - } - - return false -} - -func writeJSON(w http.ResponseWriter, v any) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(v) -} - -func writeError(w http.ResponseWriter, msg string, status int) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - json.NewEncoder(w).Encode(WrappedResponse{Error: msg}) -} - -func uploadHandler(w http.ResponseWriter, r *http.Request) { - threshold := r.URL.Query().Get("threshold") - if threshold == "" { - threshold = "0.5" - } - - // Parse multipart form (8GB limit set at middleware level) - if err := r.ParseMultipartForm(8 << 30); err != nil { - slog.Error("failed to parse multipart form", "error", err) - writeError(w, "failed to parse multipart form", http.StatusBadRequest) - return - } - - file, header, err := r.FormFile("file") - if err != nil { - slog.Error("no file found in request", "error", err) - writeError(w, "no file found in request", http.StatusInternalServerError) - return - } - defer file.Close() - - contentType := header.Header.Get("Content-Type") - fileName := header.Filename - - slog.Info("received file", "name", fileName, "type", contentType, "size", header.Size) - - fileContents, err := io.ReadAll(file) - if err != nil { - slog.Error("failed to read file", "error", err) - writeError(w, "failed to read file", http.StatusInternalServerError) - return - } - - if isVideo(contentType, fileName) { - handleVideo(w, fileContents, fileName, threshold) - } else { - handleImage(w, fileContents, fileName, contentType, threshold) - } -} - -func handleImage(w http.ResponseWriter, data []byte, fileName, contentType, threshold string) { - if !strings.HasPrefix(contentType, "image/") { - slog.Warn("mimetype is not image/", "type", contentType) - } - - resp, err := sendImageToTagger(data, fileName, contentType, threshold) - if err != nil { - slog.Error("upstream tagger failed", "error", err) - writeError(w, "upstream tagger failed", http.StatusInternalServerError) - return - } - - writeJSON(w, resp) -} - -func handleVideo(w http.ResponseWriter, data []byte, fileName, threshold string) { - // Write video to temp file - ext := filepath.Ext(fileName) - tmpFile, err := os.CreateTemp("", "glimbus-video-*"+ext) - if err != nil { - slog.Error("failed to create temp file", "error", err) - writeError(w, "failed to create temp file", http.StatusInternalServerError) - return - } - defer os.Remove(tmpFile.Name()) - - if _, err := tmpFile.Write(data); err != nil { - tmpFile.Close() - slog.Error("failed to write temp file", "error", err) - writeError(w, "failed to write temp file", http.StatusInternalServerError) - return - } - tmpFile.Close() - - slog.Debug("tmp path", "path", tmpFile.Name()) - - // Get video info - probeResult, err := ffprobe(tmpFile.Name()) - if err != nil { - slog.Error("ffprobe failed", "error", err) - writeError(w, "ffprobe failed", http.StatusInternalServerError) - return - } - - if len(probeResult.Streams) == 0 { - slog.Error("no streams found") - writeError(w, "no streams found in video", http.StatusInternalServerError) - return - } - - stream := probeResult.Streams[0] - slog.Debug("stream", "data", stream) - slog.Debug("format", "data", probeResult.Format) - - frameRate, err := calculateFrameRate(tmpFile.Name(), stream.RFrameRate, true) - if err != nil { - slog.Error("failed to get frame rate", "error", err) - writeError(w, "failed to get frame rate", http.StatusInternalServerError) - return - } - - frameCount, err := calculateFrameCount(tmpFile.Name(), &stream, &probeResult.Format, frameRate) - if err != nil { - slog.Error("failed to get frame count", "error", err) - writeError(w, "failed to get frame count", http.StatusInternalServerError) - return - } - - slog.Debug("video info", "frameCount", frameCount, "frameRate", frameRate) - - totalSeconds := float64(frameCount) / frameRate - skipSeconds := getFrameSkipSeconds(int(totalSeconds)) - frameSkip := uint64(float64(skipSeconds) * frameRate) - - slog.Info("frame extraction plan", "totalSeconds", totalSeconds, "skipSeconds", skipSeconds) - - // Create temp directory for frames - tmpDir, err := os.MkdirTemp("", "glimbus-frames-*") - if err != nil { - slog.Error("failed to create temp dir", "error", err) - writeError(w, "failed to create temp dir", http.StatusInternalServerError) - return - } - defer os.RemoveAll(tmpDir) - - framePath := filepath.Join(tmpDir, "frame.png") - slog.Info("frame path", "path", framePath) - - // Process frames and collect tags - tagSet := make(map[string]struct{}) - - for frameNum := uint64(0); frameNum < frameCount; frameNum += frameSkip { - slog.Info("extracting frame", "frameNum", frameNum) - - if err := extractFrame(tmpFile.Name(), framePath, int(frameNum), frameRate); err != nil { - slog.Error("frame extraction failed", "frameNum", frameNum, "error", err) - writeError(w, "frame extraction failed", http.StatusInternalServerError) - return - } - - slog.Info("extracted frame", "frameNum", frameNum) - - frameData, err := os.ReadFile(framePath) - if err != nil { - slog.Error("failed to read frame", "error", err) - writeError(w, "failed to read frame", http.StatusInternalServerError) - return - } - - slog.Info("sending frame", "frameNum", frameNum) - - resp, err := sendImageToTagger(frameData, "amongus.png", "image/png", threshold) - if err != nil { - slog.Error("upstream tagger failed for frame", "frameNum", frameNum, "error", err) - writeError(w, "upstream tagger failed", http.StatusInternalServerError) - return - } - - if resp.Tags != nil { - slog.Info("tags from frame", "count", len(resp.Tags)) - for _, tag := range resp.Tags { - tagSet[tag] = struct{}{} - } - } - } - - // Convert set to slice - tags := make([]string, 0, len(tagSet)) - for tag := range tagSet { - tags = append(tags, tag) - } - - writeJSON(w, &WrappedResponse{Tags: tags}) -} - -func main() { - // Setup logging - logLevel := slog.LevelInfo - if os.Getenv("LOG_LEVEL") == "debug" || os.Getenv("RUST_LOG") == "debug" { - logLevel = slog.LevelDebug - } - - logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ - Level: logLevel, - })) - slog.SetDefault(logger) - - // Load configuration - var err error - config, err = loadConfig() - if err != nil { - slog.Error("configuration error", "error", err) - os.Exit(1) - } - - runnerName := "DeepDanbooru" - if config.RunnerType == RunnerSDWebUI { - runnerName = "StableDiffusionWebUI" - } - slog.Info("configuration loaded", "runner", runnerName, "address", config.Address) - - // Setup router - r := chi.NewRouter() - r.Use(middleware.Logger) - r.Use(middleware.Recoverer) - - r.Post("/", uploadHandler) - r.Post("/tagger/v1/interrogate", uploadHandler) - - // Start server - addr := "0.0.0.0:6679" - slog.Info("starting server", "address", addr) - - if err := http.ListenAndServe(addr, r); err != nil { - slog.Error("server error", "error", err) - os.Exit(1) - } -} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..6773db7 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,434 @@ +use axum::extract::{Multipart, Query}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::routing::post; +use axum::{Json, Router}; +use axum_macros::debug_handler; +use base64::{engine::general_purpose, Engine as _}; +use core::panic; +use ffmpeg_cli::Parameter; +use ffprobe::{Format, Stream}; +use futures_util::{future::ready, StreamExt}; +use serde::{Deserialize, Serialize}; +use std::collections::{HashMap, HashSet}; +use std::io::Read; +use std::io::Write; +use std::process::Stdio; + +#[derive(Debug)] +enum Runner { + DeepDanbooru(String), + StableDiffusionWebUI(String), +} + +fn get_upstream_runner() -> Runner { + match std::env::var("DD_ADDRESS") { + Ok(value) => Runner::DeepDanbooru(value), + Err(_) => match std::env::var("SD_ADDRESS") { + Ok(value) => Runner::StableDiffusionWebUI(value), + Err(_) => panic!("shit no addr"), + }, + } +} + +fn get_sd_model() -> String { + match std::env::var("SD_MODEL") { + Ok(value) => value, + Err(_) => "wd14-vit-v2-git".to_string(), + } +} + +#[tokio::main] +async fn main() { + pretty_env_logger::init(); + // build our application with a single route + let app = Router::new() + .route("/", post(upload_file)) + .layer(axum::extract::DefaultBodyLimit::max(8 * 1024 * 1024 * 1024)); + + let upstream_runner = get_upstream_runner(); + + // run it with hyper on localhost:3000 + log::info!("running on 0.0.0.0:6679 to {:?}", upstream_runner); + axum::Server::bind(&"0.0.0.0:6679".parse().unwrap()) + .serve(app.into_make_service()) + .await + .unwrap(); +} + +#[derive(Deserialize)] +struct Options { + threshold: String, +} + +#[derive(Serialize, Deserialize)] +#[serde(untagged)] +enum WrappedResponse { + Tags(Vec), + Error(String), +} + +// Make our own error that wraps `anyhow::Error`. +struct AppError(anyhow::Error); + +// Tell axum how to convert `AppError` into a response. +impl IntoResponse for AppError { + fn into_response(self) -> Response { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Something went wrong: {}", self.0), + ) + .into_response() + } +} + +// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into +// `Result<_, AppError>`. That way you don't need to do that manually. +impl From for AppError +where + E: Into, +{ + fn from(err: E) -> Self { + Self(err.into()) + } +} + +#[derive(Serialize, Deserialize)] +struct WD14Response { + caption: WD14ResponseTagHolder, +} + +#[derive(Serialize, Deserialize)] +struct WD14ResponseTagHolder { + tag: HashMap, + rating: HashMap, +} + +async fn send_image_to_dd( + file_contents: Vec, + file_name: String, + file_mime_type: &str, + options: &Options, +) -> anyhow::Result { + let part = reqwest::multipart::Part::bytes(file_contents.clone()) + .file_name(file_name) + .mime_str(file_mime_type) + .unwrap(); + let form = reqwest::multipart::Form::new().part("file", part); + + let upstream_runner = get_upstream_runner(); + + match upstream_runner { + Runner::DeepDanbooru(url) => { + log::debug!("calling dd"); + + let resp = reqwest::Client::new() + .post(url) + .multipart(form) + .header("authorization", "Bearer 123") + .query(&[("threshold", options.threshold.clone())]) + .send() + .await?; + + let body = resp.text().await?; + log::info!("sd body: {}", &body); + let json_response: WrappedResponse = serde_json::from_str(&body)?; + + log::debug!("called!"); + Ok(json_response) + } + Runner::StableDiffusionWebUI(url) => { + log::debug!("calling sd"); + + let mut map: HashMap<&str, &str> = HashMap::new(); + let file_base64 = general_purpose::STANDARD.encode(file_contents.clone()); + + let sd_model = get_sd_model(); + + map.insert("model", &sd_model); + map.insert("threshold", options.threshold.as_ref()); + map.insert("image", &file_base64); + + let serialized_map = serde_json::to_vec(&map).unwrap(); + let len = serialized_map.len(); + log::info!("wd14 request length {} bytes", len); + + let resp = reqwest::Client::new() + .post(format!("{}/tagger/v1/interrogate", url)) + .body(serialized_map) + .send() + .await?; + + let body = resp.text().await?; + log::info!("wd14 body: {}", &body); + let json_response: WD14Response = serde_json::from_str(&body)?; + + // turn WD14Response into WrappedResponse + let mut tags = Vec::::new(); + for ele in json_response.caption.tag { + tags.push(ele.0.clone()); + } + for ele in json_response.caption.rating { + tags.push(ele.0.clone()); + } + + log::debug!("called!"); + Ok(WrappedResponse::Tags(tags)) + } + } +} + +async fn fetch_frame_as_image( + input_path: &str, + output_path: &str, + frame_index: usize, + frame_rate: f64, // X/1sec +) -> anyhow::Result<()> { + let timeline_index: f64 = frame_index as f64 / frame_rate; + let timeline_index_param = format!("{:.5}", timeline_index); + log::debug!( + "construct command with {:?} {:?}", + timeline_index, + timeline_index_param + ); + log::info!( + "ffmpeg command: -nostdin -y -ss {} -i {} -vframes 1 {}", + timeline_index_param, + input_path, + output_path + ); + let builder = ffmpeg_cli::FfmpegBuilder::new() + .stderr(Stdio::piped()) + .option(Parameter::Single("nostdin")) + .option(Parameter::Single("y")) + // fast seeking with -ss instead of select filter + .input( + ffmpeg_cli::File::new(input_path) + .option(Parameter::KeyValue("ss", &timeline_index_param)), + ) + .output(ffmpeg_cli::File::new(output_path).option(Parameter::KeyValue("vframes", "1"))); + + log::debug!("running {:?}", builder); + log::debug!("calling builder.run()..."); + let mut ffmpeg = builder.run().await.unwrap(); + log::debug!("builder.run() completed"); + + // For single frame extraction, we don't need to monitor progress. + // Just wait for the process to complete. + log::debug!("waiting for ffmpeg process to complete..."); + let output = ffmpeg.process.wait_with_output().unwrap(); + + log::debug!( + "output:\n{}\nstderr:\n{}", + output.status, + std::str::from_utf8(&output.stderr).unwrap() + ); + + Ok(()) +} + +fn fetch_frame_count_full_decode(path: &std::path::Path) -> anyhow::Result { + let config = ffprobe::ConfigBuilder::new().count_frames(true).build(); + let new_info = ffprobe::ffprobe_config(config, path)?; + let res = new_info + .streams + .get(0) + .unwrap() + .nb_read_frames + .clone() + .unwrap() + .parse::()?; + Ok(res) +} + +fn calculate_frame_count( + path: &std::path::Path, + stream: &Stream, + format: &Format, + frame_rate: f64, +) -> anyhow::Result { + Ok(if let Some(parseable_data) = stream.nb_frames.clone() { + // if we can get it from the stream metadata, use it + parseable_data.parse::()? + } else if let Some(parseable_data) = format.try_get_duration() { + // this is a std::time::duration + // multiply that by frame rate and we get total frame count (approximate) + log::warn!("fetching duration from format metadata..."); + let seconds = parseable_data?.as_secs_f64(); + (seconds * frame_rate) as u64 + } else { + log::warn!("file didn't provide frame metadata, calculating it ourselves..."); + fetch_frame_count_full_decode(path)? + }) +} + +fn calculate_frame_rate( + temp_path: &std::path::Path, + frame_rate_str: String, + first_run: bool, +) -> anyhow::Result { + let parts = frame_rate_str.split("/").into_iter().collect::>(); + let frames_per = parts.get(0).unwrap().parse::()?; + let per_second = parts.get(1).unwrap().parse::()?; + if frames_per == 0.0f64 && per_second == 0.0f64 { + log::warn!("got incorrect frame rate, calling ffmpeg again..."); + if !first_run { + std::panic!("couldnt get frame rate shit"); + } + + // call ffprobe directly + let mut cmd = std::process::Command::new("ffprobe"); + let cmd = cmd.args(&[ + "-v", + "error", + "-select_streams", + "v", + "-of", + "default=noprint_wrappers=1:nokey=1", + "-show_entries", + "stream=r_frame_rate", + ]); + let cmd = cmd.arg(temp_path.to_str().unwrap()); + + cmd.get_args().for_each(|arg| { + log::debug!("arg {:?}", arg); + }); + + let output = cmd.output()?; + let possibly_new_frame_rate = String::from_utf8(output.stdout)? + .strip_suffix("\n") + .unwrap() + .to_string(); + + log::debug!("raw ffprobe gave {:?}", possibly_new_frame_rate); + + calculate_frame_rate(temp_path, possibly_new_frame_rate, false) + } else { + Ok(frames_per / per_second) + } +} + +#[debug_handler] +async fn upload_file( + options: Query, + mut multipart: Multipart, +) -> Result<(StatusCode, Json), AppError> { + let mut maybe_file_contents: Option = None; + let mut maybe_file_type: Option = None; + let mut maybe_file_name: Option = None; + + while let Some(field) = multipart.next_field().await.unwrap() { + let name = field.name().unwrap().to_string(); + let content_type = field.content_type().unwrap().to_string(); + let filename = field.file_name().unwrap().to_string(); + let data = field.bytes().await?; + + log::info!("Length of `{}` is {} bytes", name, data.len()); + if name == "file" { + maybe_file_contents = Some(data); + maybe_file_type = Some(content_type); + maybe_file_name = Some(filename); + } + } + + if let Some(file_contents) = maybe_file_contents { + let file_type = maybe_file_type.unwrap(); + let file_name = maybe_file_name.unwrap(); + log::info!("file {} {}", file_type, file_name); + let is_video = file_type.starts_with("video/") + || file_name.ends_with(".mp4") + || file_name.ends_with(".gif") + || file_name.ends_with(".mkv") + || file_name.ends_with(".webm"); + if is_video { + let mut final_tag_set = HashSet::new(); + + let mut temp_file = tempfile::NamedTempFile::new()?; + temp_file.write_all(&file_contents.to_vec())?; + + log::debug!("tmp path: {:?}", temp_file.path()); + + let info = ffprobe::ffprobe(temp_file.path())?; + let stream = info.streams.get(0).unwrap(); + + log::debug!("stream = {:?}", stream); + log::debug!("format = {:?}", info.format); + + let frame_rate: f64 = + calculate_frame_rate(temp_file.path(), stream.r_frame_rate.clone(), true)?; + + let total_frame_count = + calculate_frame_count(temp_file.path(), &stream, &info.format, frame_rate)?; + + log::debug!("total frame count = {}", total_frame_count); + log::debug!("frame rate = {}", frame_rate); + + let total_length_in_seconds = total_frame_count as f64 / frame_rate; + let wanted_frame_skip_seconds = match total_length_in_seconds as usize { + 0..=10 => 2, + 11..=60 => 10, + 61..=120 => 15, + 121..=300 => 20, + 301..=1000 => 30, + 1001..=1200 => 40, + 1201.. => 60, + _ => 63, + } as f64; + let wanted_frame_skip = wanted_frame_skip_seconds * frame_rate; + + let temporary_frame_dir = tempfile::tempdir()?; + let temporary_frame_path = + format!("{}/frame.png", temporary_frame_dir.path().to_string_lossy()); + log::info!("frame path: '{}'", &temporary_frame_path); + log::info!("wanted_frame_skip: {}", &wanted_frame_skip_seconds); + + for frame_number in (0..total_frame_count).step_by(wanted_frame_skip as usize) { + log::info!("extracting frame {:?}", frame_number); + fetch_frame_as_image( + temp_file.path().to_str().unwrap(), + &temporary_frame_path, + frame_number.try_into().unwrap(), + frame_rate, + ) + .await?; + log::info!("extracted frame {:?}", frame_number); + + let mut actual_frame_file = std::fs::File::open(&temporary_frame_path)?; + let mut frame_data = vec![]; + actual_frame_file.read_to_end(&mut frame_data)?; + log::info!("sending frame {:?}", frame_number); + let tags_from_frame = if let WrappedResponse::Tags(tags_from_frame) = + send_image_to_dd(frame_data, "amongus.png".to_string(), "image/png", &options) + .await? + { + tags_from_frame + } else { + todo!() + }; + + log::info!("{} tags from frame", tags_from_frame.len()); + for tag in tags_from_frame { + final_tag_set.insert(tag); + } + } + + let response = WrappedResponse::Tags(final_tag_set.into_iter().collect::>()); + Ok((StatusCode::OK, Json(response))) + } else { + if !file_type.starts_with("image/") { + log::warn!("warning: mimetype {} is not image/", file_type); + } + + let json_response = + send_image_to_dd(file_contents.to_vec(), file_name, &file_type, &options).await?; + Ok((StatusCode::OK, Json(json_response))) + } + } else { + Ok(( + StatusCode::INTERNAL_SERVER_ERROR, + Json(WrappedResponse::Error( + "no file found in request".to_string(), + )), + )) + } +} diff --git a/video.go b/video.go deleted file mode 100644 index 92fc49e..0000000 --- a/video.go +++ /dev/null @@ -1,211 +0,0 @@ -package main - -import ( - "encoding/json" - "fmt" - "log/slog" - "os/exec" - "strconv" - "strings" -) - -// FFprobe result structures -type ProbeResult struct { - Streams []ProbeStream `json:"streams"` - Format ProbeFormat `json:"format"` -} - -type ProbeStream struct { - RFrameRate string `json:"r_frame_rate"` - NbFrames string `json:"nb_frames,omitempty"` - NbReadFrames string `json:"nb_read_frames,omitempty"` -} - -type ProbeFormat struct { - Duration string `json:"duration,omitempty"` -} - -// ffprobe runs ffprobe and returns structured metadata -func ffprobe(path string) (*ProbeResult, error) { - cmd := exec.Command("ffprobe", - "-v", "error", - "-print_format", "json", - "-show_format", - "-show_streams", - path, - ) - - output, err := cmd.Output() - if err != nil { - return nil, fmt.Errorf("ffprobe failed: %w", err) - } - - var result ProbeResult - if err := json.Unmarshal(output, &result); err != nil { - return nil, fmt.Errorf("parse ffprobe output: %w", err) - } - - return &result, nil -} - -// ffprobeFrameRate gets frame rate via direct ffprobe call (fallback) -func ffprobeFrameRate(path string) (string, error) { - cmd := exec.Command("ffprobe", - "-v", "error", - "-select_streams", "v", - "-of", "default=noprint_wrappers=1:nokey=1", - "-show_entries", "stream=r_frame_rate", - path, - ) - - output, err := cmd.Output() - if err != nil { - return "", fmt.Errorf("ffprobe frame rate failed: %w", err) - } - - return strings.TrimSpace(string(output)), nil -} - -// ffprobeFrameCountFullDecode counts frames via full decode (slow fallback) -func ffprobeFrameCountFullDecode(path string) (uint64, error) { - cmd := exec.Command("ffprobe", - "-v", "error", - "-count_frames", - "-select_streams", "v:0", - "-show_entries", "stream=nb_read_frames", - "-of", "default=noprint_wrappers=1:nokey=1", - path, - ) - - output, err := cmd.Output() - if err != nil { - return 0, fmt.Errorf("ffprobe count frames failed: %w", err) - } - - count, err := strconv.ParseUint(strings.TrimSpace(string(output)), 10, 64) - if err != nil { - return 0, fmt.Errorf("parse frame count: %w", err) - } - - return count, nil -} - -// parseFrameRate parses "num/denom" format frame rate -func parseFrameRate(s string) (float64, error) { - parts := strings.Split(s, "/") - if len(parts) != 2 { - return 0, fmt.Errorf("invalid frame rate format: %s", s) - } - - num, err := strconv.ParseFloat(parts[0], 64) - if err != nil { - return 0, fmt.Errorf("parse numerator: %w", err) - } - - denom, err := strconv.ParseFloat(parts[1], 64) - if err != nil { - return 0, fmt.Errorf("parse denominator: %w", err) - } - - if num == 0 && denom == 0 { - return 0, fmt.Errorf("zero frame rate") - } - - return num / denom, nil -} - -// calculateFrameRate gets frame rate with ffprobe fallback -func calculateFrameRate(path string, frameRateStr string, firstRun bool) (float64, error) { - frameRate, err := parseFrameRate(frameRateStr) - if err != nil || frameRate == 0 { - slog.Warn("got incorrect frame rate, calling ffprobe again...") - - if !firstRun { - return 0, fmt.Errorf("couldn't get frame rate") - } - - // Call ffprobe directly - newFrameRateStr, err := ffprobeFrameRate(path) - if err != nil { - return 0, err - } - - slog.Debug("raw ffprobe gave", "frameRate", newFrameRateStr) - return calculateFrameRate(path, newFrameRateStr, false) - } - - return frameRate, nil -} - -// calculateFrameCount gets total frame count with fallbacks -func calculateFrameCount(path string, stream *ProbeStream, format *ProbeFormat, frameRate float64) (uint64, error) { - // Try stream metadata first - if stream.NbFrames != "" { - count, err := strconv.ParseUint(stream.NbFrames, 10, 64) - if err == nil { - return count, nil - } - } - - // Try format duration - if format.Duration != "" { - duration, err := strconv.ParseFloat(format.Duration, 64) - if err == nil { - slog.Warn("fetching duration from format metadata...") - return uint64(duration * frameRate), nil - } - } - - // Fallback to full decode - slog.Warn("file didn't provide frame metadata, calculating it ourselves...") - return ffprobeFrameCountFullDecode(path) -} - -// extractFrame extracts a single frame at the given index using ffmpeg -func extractFrame(inputPath, outputPath string, frameIndex int, frameRate float64) error { - // Calculate timeline position - timelineIndex := float64(frameIndex) / frameRate - timelineStr := fmt.Sprintf("%.5f", timelineIndex) - - slog.Debug("construct command", "timelineIndex", timelineIndex, "timelineStr", timelineStr) - slog.Info("ffmpeg command", "args", fmt.Sprintf("-nostdin -y -ss %s -i %s -vframes 1 %s", timelineStr, inputPath, outputPath)) - - // Use fast seeking with -ss before -i - cmd := exec.Command("ffmpeg", - "-nostdin", - "-y", - "-ss", timelineStr, - "-i", inputPath, - "-vframes", "1", - outputPath, - ) - - output, err := cmd.CombinedOutput() - if err != nil { - slog.Error("ffmpeg failed", "output", string(output), "error", err) - return fmt.Errorf("ffmpeg failed: %w", err) - } - - slog.Debug("ffmpeg output", "output", string(output)) - return nil -} - -// getFrameSkipSeconds returns the frame sampling interval based on video duration -func getFrameSkipSeconds(totalSeconds int) int { - switch { - case totalSeconds <= 10: - return 1 - case totalSeconds <= 60: - return 5 - case totalSeconds <= 120: - return 10 - case totalSeconds <= 300: - return 12 - case totalSeconds <= 1000: - return 15 - case totalSeconds <= 1200: - return 40 - default: - return 60 - } -}