use axum::extract::{Multipart, Query}; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use axum::routing::post; use axum::{Json, Router}; use axum_macros::debug_handler; use base64::{engine::general_purpose, Engine as _}; use core::panic; use ffmpeg_cli::Parameter; use ffprobe::{Format, Stream}; use futures_util::{future::ready, StreamExt}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::io::Read; use std::io::Write; use std::process::Stdio; #[derive(Debug)] enum Runner { DeepDanbooru(String), StableDiffusionWebUI(String), } fn get_upstream_runner() -> Runner { match std::env::var("DD_ADDRESS") { Ok(value) => Runner::DeepDanbooru(value), Err(_) => match std::env::var("SD_ADDRESS") { Ok(value) => Runner::StableDiffusionWebUI(value), Err(_) => panic!("shit no addr"), }, } } #[tokio::main] async fn main() { pretty_env_logger::init(); // build our application with a single route let app = Router::new() .route("/", post(upload_file)) .layer(axum::extract::DefaultBodyLimit::max(512 * 1024 * 1024)); let upstream_runner = get_upstream_runner(); // run it with hyper on localhost:3000 log::info!("running on 0.0.0.0:6679 to {:?}", upstream_runner); axum::Server::bind(&"0.0.0.0:6679".parse().unwrap()) .serve(app.into_make_service()) .await .unwrap(); } #[derive(Deserialize)] struct Options { threshold: String, } #[derive(Serialize, Deserialize)] #[serde(untagged)] enum WrappedResponse { Tags(Vec), Error(String), } // Make our own error that wraps `anyhow::Error`. struct AppError(anyhow::Error); // Tell axum how to convert `AppError` into a response. impl IntoResponse for AppError { fn into_response(self) -> Response { ( StatusCode::INTERNAL_SERVER_ERROR, format!("Something went wrong: {}", self.0), ) .into_response() } } // This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into // `Result<_, AppError>`. That way you don't need to do that manually. impl From for AppError where E: Into, { fn from(err: E) -> Self { Self(err.into()) } } #[derive(Serialize, Deserialize)] struct WD14Response { caption: HashMap, } async fn send_image_to_dd( file_contents: Vec, file_name: String, file_mime_type: &str, options: &Options, ) -> anyhow::Result { let part = reqwest::multipart::Part::bytes(file_contents.clone()) .file_name(file_name) .mime_str(file_mime_type) .unwrap(); let form = reqwest::multipart::Form::new().part("file", part); let upstream_runner = get_upstream_runner(); match upstream_runner { Runner::DeepDanbooru(url) => { log::debug!("calling dd"); let resp = reqwest::Client::new() .post(url) .multipart(form) .header("authorization", "Bearer 123") .query(&[("threshold", options.threshold.clone())]) .send() .await?; let body = resp.text().await?; log::info!("sd body: {}", &body); let json_response: WrappedResponse = serde_json::from_str(&body)?; log::debug!("called!"); Ok(json_response) } Runner::StableDiffusionWebUI(url) => { log::debug!("calling sd"); let mut map: HashMap<&str, &str> = HashMap::new(); let file_base64 = general_purpose::STANDARD.encode(file_contents.clone()); map.insert("model", "wd14-vit-v2-git"); map.insert("threshold", options.threshold.as_ref()); map.insert("image", &file_base64); let serialized_map = serde_json::to_vec(&map).unwrap(); let len = serialized_map.len(); log::info!("wd14 request length {} bytes", len); let resp = reqwest::Client::new() .post(format!("{}/tagger/v1/interrogate", url)) .body(serialized_map) .send() .await?; let body = resp.text().await?; log::info!("wd14 body: {}", &body); let json_response: WD14Response = serde_json::from_str(&body)?; // turn WD14Response into WrappedResponse let mut tags = Vec::::new(); for ele in json_response.caption { tags.push(ele.0.clone()); } log::debug!("called!"); Ok(WrappedResponse::Tags(tags)) } } } async fn fetch_frame_as_image( input_path: &str, output_path: &str, frame_index: usize, frame_rate: f64, // X/1sec ) -> anyhow::Result<()> { let timeline_index: f64 = frame_index as f64 / frame_rate; let timeline_index_param = format!("{:.5}", timeline_index); let builder = ffmpeg_cli::FfmpegBuilder::new() .stderr(Stdio::piped()) .option(Parameter::Single("nostdin")) .option(Parameter::Single("y")) // fast seeking with -ss instead of select filter .input( ffmpeg_cli::File::new(input_path) .option(Parameter::KeyValue("ss", &timeline_index_param)), ) .output(ffmpeg_cli::File::new(output_path).option(Parameter::KeyValue("vframes", "1"))); log::debug!("running {:?}", builder); let ffmpeg = builder.run().await.unwrap(); // TODO ffmpeg doesnt really provide a progress bar for this, its just // a single frame being extracted. should ignoring x.unwrap() be the // actual thing we gotta do? log::debug!("run"); ffmpeg .progress .for_each(|x| { log::debug!("progress x value = {:?}", &x); // lmao // x.unwrap(); ready(()) }) .await; log::debug!("run!"); let output = ffmpeg.process.wait_with_output().unwrap(); log::debug!("out"); log::debug!( "{}\nstderr:\n{}", output.status, std::str::from_utf8(&output.stderr).unwrap() ); Ok(()) } fn fetch_frame_count_full_decode(path: &std::path::Path) -> anyhow::Result { let config = ffprobe::ConfigBuilder::new().count_frames(true).build(); let new_info = ffprobe::ffprobe_config(config, path)?; let res = new_info .streams .get(0) .unwrap() .nb_read_frames .clone() .unwrap() .parse::()?; Ok(res) } fn calculate_frame_count( path: &std::path::Path, stream: &Stream, format: &Format, frame_rate: f64, ) -> anyhow::Result { Ok(if let Some(parseable_data) = stream.nb_frames.clone() { // if we can get it from the stream metadata, use it parseable_data.parse::()? } else if let Some(parseable_data) = format.try_get_duration() { // this is a std::time::duration // multiply that by frame rate and we get total frame count (approximate) log::warn!("fetching duration from format metadata..."); let seconds = parseable_data?.as_secs_f64(); (seconds * frame_rate) as u64 } else { log::warn!("file didn't provide frame metadata, calculating it ourselves..."); fetch_frame_count_full_decode(path)? }) } #[debug_handler] async fn upload_file( options: Query, mut multipart: Multipart, ) -> Result<(StatusCode, Json), AppError> { let mut maybe_file_contents: Option = None; let mut maybe_file_type: Option = None; let mut maybe_file_name: Option = None; while let Some(field) = multipart.next_field().await.unwrap() { let name = field.name().unwrap().to_string(); let content_type = field.content_type().unwrap().to_string(); let filename = field.file_name().unwrap().to_string(); let data = field.bytes().await.unwrap(); log::info!("Length of `{}` is {} bytes", name, data.len()); if name == "file" { maybe_file_contents = Some(data); maybe_file_type = Some(content_type); maybe_file_name = Some(filename); } } if let Some(file_contents) = maybe_file_contents { let file_type = maybe_file_type.unwrap(); let file_name = maybe_file_name.unwrap(); let is_video = file_type.starts_with("video/") || file_name.ends_with(".mp4") || file_name.ends_with(".gif") || file_name.ends_with(".webm"); if is_video { let mut final_tag_set = HashSet::new(); let mut temp_file = tempfile::NamedTempFile::new()?; temp_file.write_all(&file_contents.to_vec())?; log::debug!("tmp path: {:?}", temp_file.path()); let info = ffprobe::ffprobe(temp_file.path())?; let stream = info.streams.get(0).unwrap(); log::debug!("stream = {:?}", stream); log::debug!("format = {:?}", info.format); let frame_rate_str = stream.r_frame_rate.clone(); let parts = frame_rate_str.split("/").into_iter().collect::>(); let frame_rate: f64 = parts.get(0).unwrap().parse::()? / parts.get(1).unwrap().parse::()?; let total_frame_count = calculate_frame_count(temp_file.path(), &stream, &info.format, frame_rate)?; let total_length_in_seconds = total_frame_count as f64 / frame_rate; let wanted_frame_skip_seconds = match total_length_in_seconds as usize { 0..=10 => 2, 11..=60 => 10, 61..=120 => 15, 121..=300 => 20, 301..=1000 => 30, 1001..=1200 => 40, 1201.. => 60, _ => 63, } as f64; let wanted_frame_skip = wanted_frame_skip_seconds * frame_rate; let temporary_frame_dir = tempfile::tempdir()?; let temporary_frame_path = format!("{}/frame.png", temporary_frame_dir.path().to_string_lossy()); log::info!("frame path: '{}'", &temporary_frame_path); log::info!("wanted_frame_skip: {}", &wanted_frame_skip_seconds); for frame_number in (0..total_frame_count).step_by(wanted_frame_skip as usize) { log::info!("extracting frame {:?}", frame_number); fetch_frame_as_image( temp_file.path().to_str().unwrap(), &temporary_frame_path, frame_number.try_into().unwrap(), frame_rate, ) .await?; log::info!("extracted frame {:?}", frame_number); let mut actual_frame_file = std::fs::File::open(&temporary_frame_path)?; let mut frame_data = vec![]; actual_frame_file.read_to_end(&mut frame_data)?; log::info!("sending frame {:?}", frame_number); let tags_from_frame = if let WrappedResponse::Tags(tags_from_frame) = send_image_to_dd(frame_data, "amongus.png".to_string(), "image/png", &options) .await? { tags_from_frame } else { todo!() }; for tag in tags_from_frame { final_tag_set.insert(tag); } } let response = WrappedResponse::Tags(final_tag_set.into_iter().collect::>()); Ok((StatusCode::OK, Json(response))) } else { let json_response = send_image_to_dd(file_contents.to_vec(), file_name, &file_type, &options).await?; Ok((StatusCode::OK, Json(json_response))) } } else { Ok(( StatusCode::INTERNAL_SERVER_ERROR, Json(WrappedResponse::Error( "no file found in request".to_string(), )), )) } }