Compare commits

...

4 commits

Author SHA1 Message Date
61250ffaa2 use faster seeking in ffmpeg 2023-06-11 23:00:12 -03:00
1875d3fd48 add gif to allowed extensions 2023-06-11 23:00:03 -03:00
49fb8f4763 add support for stable-diffusion-webui-wd14-tagger extension 2023-06-11 22:59:44 -03:00
7516e47bb7 add log call 2023-06-11 18:05:48 -03:00
2 changed files with 108 additions and 33 deletions

View file

@ -9,6 +9,7 @@ edition = "2021"
anyhow = "1.0.70" anyhow = "1.0.70"
axum = { version = "0.6.12", features = ["tokio", "multipart"] } axum = { version = "0.6.12", features = ["tokio", "multipart"] }
axum-macros = "0.3.7" axum-macros = "0.3.7"
base64 = "0.21.2"
ffmpeg-cli = "0.1.0" ffmpeg-cli = "0.1.0"
ffprobe = "0.3.3" ffprobe = "0.3.3"
filepath = "0.1.2" filepath = "0.1.2"

View file

@ -4,13 +4,32 @@ use axum::response::{IntoResponse, Response};
use axum::routing::post; use axum::routing::post;
use axum::{Json, Router}; use axum::{Json, Router};
use axum_macros::debug_handler; use axum_macros::debug_handler;
use base64::{engine::general_purpose, Engine as _};
use core::panic;
use ffmpeg_cli::Parameter; use ffmpeg_cli::Parameter;
use futures_util::{future::ready, StreamExt}; use futures_util::{future::ready, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashSet; use std::collections::{HashMap, HashSet};
use std::io::Read; use std::io::Read;
use std::io::Write; use std::io::Write;
use std::process::Stdio; use std::process::Stdio;
use std::rc::Rc;
#[derive(Debug)]
enum Runner {
DeepDanbooru(String),
StableDiffusionWebUI(String),
}
fn get_upstream_runner() -> Runner {
match std::env::var("DD_ADDRESS") {
Ok(value) => Runner::DeepDanbooru(value),
Err(_) => match std::env::var("SD_ADDRESS") {
Ok(value) => Runner::StableDiffusionWebUI(value),
Err(_) => panic!("shit no addr"),
},
}
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@ -18,9 +37,12 @@ async fn main() {
// build our application with a single route // build our application with a single route
let app = Router::new() let app = Router::new()
.route("/", post(upload_file)) .route("/", post(upload_file))
.layer(axum::extract::DefaultBodyLimit::max(100 * 1024 * 1024)); .layer(axum::extract::DefaultBodyLimit::max(300 * 1024 * 1024));
let upstream_runner = get_upstream_runner();
// run it with hyper on localhost:3000 // run it with hyper on localhost:3000
log::info!("running on 0.0.0.0:6679 to {:?}", upstream_runner);
axum::Server::bind(&"0.0.0.0:6679".parse().unwrap()) axum::Server::bind(&"0.0.0.0:6679".parse().unwrap())
.serve(app.into_make_service()) .serve(app.into_make_service())
.await .await
@ -73,52 +95,97 @@ async fn test_handler() -> Result<(), AppError> {
Ok(()) Ok(())
} }
#[derive(Serialize, Deserialize)]
struct WD14Response {
caption: HashMap<String, f32>,
}
async fn send_image_to_dd( async fn send_image_to_dd(
file_contents: Vec<u8>, file_contents: Vec<u8>,
file_name: String, file_name: String,
file_mime_type: &str, file_mime_type: &str,
options: &Options, options: &Options,
) -> anyhow::Result<WrappedResponse> { ) -> anyhow::Result<WrappedResponse> {
let part = reqwest::multipart::Part::bytes(file_contents) let part = reqwest::multipart::Part::bytes(file_contents.clone())
.file_name(file_name) .file_name(file_name)
.mime_str(file_mime_type) .mime_str(file_mime_type)
.unwrap(); .unwrap();
let form = reqwest::multipart::Form::new().part("file", part); let form = reqwest::multipart::Form::new().part("file", part);
log::debug!("calling dd"); let upstream_runner = get_upstream_runner();
let resp = reqwest::Client::new() match upstream_runner {
.post("http://localhost:4443") Runner::DeepDanbooru(url) => {
.multipart(form) log::debug!("calling dd");
.header("authorization", "Bearer 123")
.query(&[("threshold", options.threshold.clone())])
.send()
.await?;
let body = resp.text().await?; let resp = reqwest::Client::new()
log::info!("body: {}", &body); .post(url)
let json_response: WrappedResponse = serde_json::from_str(&body)?; .multipart(form)
.header("authorization", "Bearer 123")
.query(&[("threshold", options.threshold.clone())])
.send()
.await?;
log::debug!("called!"); let body = resp.text().await?;
Ok(json_response) log::info!("body: {}", &body);
let json_response: WrappedResponse = serde_json::from_str(&body)?;
log::debug!("called!");
Ok(json_response)
}
Runner::StableDiffusionWebUI(url) => {
log::debug!("calling sd");
let mut map: HashMap<&str, &str> = HashMap::new();
let file_base64 = general_purpose::STANDARD.encode(file_contents.clone());
map.insert("model", "wd14-vit-v2-git");
map.insert("threshold", options.threshold.as_ref());
map.insert("image", &file_base64);
let serialized_map = serde_json::to_vec(&map).unwrap();
let resp = reqwest::Client::new()
.post(format!("{}/tagger/v1/interrogate", url))
.body(serialized_map)
.send()
.await?;
let body = resp.text().await?;
log::info!("body: {}", &body);
let json_response: WD14Response = serde_json::from_str(&body)?;
// turn WD14Response into WrappedResponse
let mut tags = Vec::<String>::new();
for ele in json_response.caption {
tags.push(ele.0.clone());
}
log::debug!("called!");
Ok(WrappedResponse::Tags(tags))
}
}
} }
async fn fetch_frame_as_image( async fn fetch_frame_as_image(
input_path: &str, input_path: &str,
output_path: &str, output_path: &str,
frame_index: usize, frame_index: usize,
frame_rate: f64, // X/1sec
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let frame_index_param = format!("select=eq(n\\,{})", frame_index); let frame_index_param = format!("select=eq(n\\,{})", frame_index);
let timeline_index: f64 = frame_index as f64 / frame_rate;
let timeline_index_param = format!("{:.5}", timeline_index);
let builder = ffmpeg_cli::FfmpegBuilder::new() let builder = ffmpeg_cli::FfmpegBuilder::new()
.stderr(Stdio::piped()) .stderr(Stdio::piped())
.option(Parameter::Single("nostdin")) .option(Parameter::Single("nostdin"))
.option(Parameter::Single("y")) .option(Parameter::Single("y"))
.input(ffmpeg_cli::File::new(input_path)) // fast seeking with -ss instead of select filter
.output( .input(
ffmpeg_cli::File::new(output_path) ffmpeg_cli::File::new(input_path)
.option(Parameter::KeyValue("vf", &frame_index_param)) .option(Parameter::KeyValue("ss", &timeline_index_param)),
.option(Parameter::KeyValue("vframes", "1")), )
); .output(ffmpeg_cli::File::new(output_path).option(Parameter::KeyValue("vframes", "1")));
log::debug!("running {:?}", builder); log::debug!("running {:?}", builder);
let ffmpeg = builder.run().await.unwrap(); let ffmpeg = builder.run().await.unwrap();
@ -138,7 +205,7 @@ async fn fetch_frame_as_image(
}) })
.await; .await;
log::debug!("run"); log::debug!("run!");
let output = ffmpeg.process.wait_with_output().unwrap(); let output = ffmpeg.process.wait_with_output().unwrap();
log::debug!("out"); log::debug!("out");
@ -177,7 +244,9 @@ async fn upload_file(
if let Some(file_contents) = maybe_file_contents { if let Some(file_contents) = maybe_file_contents {
let file_type = maybe_file_type.unwrap(); let file_type = maybe_file_type.unwrap();
let file_name = maybe_file_name.unwrap(); let file_name = maybe_file_name.unwrap();
let is_video = file_type.starts_with("video/") || file_name.ends_with(".mp4"); let is_video = file_type.starts_with("video/")
|| file_name.ends_with(".mp4")
|| file_name.ends_with(".gif");
if is_video { if is_video {
let mut final_tag_set = HashSet::new(); let mut final_tag_set = HashSet::new();
@ -193,36 +262,41 @@ async fn upload_file(
let frame_rate_str = stream.r_frame_rate.clone(); let frame_rate_str = stream.r_frame_rate.clone();
let parts = frame_rate_str.split("/").into_iter().collect::<Vec<_>>(); let parts = frame_rate_str.split("/").into_iter().collect::<Vec<_>>();
let frame_rate = let frame_rate: f64 =
parts.get(0).unwrap().parse::<u32>()? / parts.get(1).unwrap().parse::<u32>()?; parts.get(0).unwrap().parse::<f64>()? / parts.get(1).unwrap().parse::<f64>()?;
let total_length_in_seconds = total_frame_count / frame_rate; let total_length_in_seconds = total_frame_count as f64 / frame_rate;
let wanted_frame_skip_seconds = match total_length_in_seconds { let wanted_frame_skip_seconds = match total_length_in_seconds as usize {
0..=10 => 2, 0..=10 => 2,
11..=60 => 10, 11..=60 => 10,
61..=120 => 15, 61..=120 => 15,
121.. => 20, 121..=300 => 20,
}; 301.. => 30,
let wanted_frame_skip = (wanted_frame_skip_seconds * frame_rate).try_into().unwrap(); _ => 33,
} as f64;
let wanted_frame_skip = wanted_frame_skip_seconds * frame_rate;
let temporary_frame_dir = tempfile::tempdir()?; let temporary_frame_dir = tempfile::tempdir()?;
let temporary_frame_path = let temporary_frame_path =
format!("{}/frame.png", temporary_frame_dir.path().to_string_lossy()); format!("{}/frame.png", temporary_frame_dir.path().to_string_lossy());
log::info!("path: '{}'", &temporary_frame_path); log::info!("frame path: '{}'", &temporary_frame_path);
log::info!("wanted_frame_skip: {}", &wanted_frame_skip_seconds); log::info!("wanted_frame_skip: {}", &wanted_frame_skip_seconds);
for frame_number in (0..total_frame_count).step_by(wanted_frame_skip) { for frame_number in (0..total_frame_count).step_by(wanted_frame_skip as usize) {
log::info!("extracting frame {:?}", frame_number); log::info!("extracting frame {:?}", frame_number);
fetch_frame_as_image( fetch_frame_as_image(
temp_file.path().to_str().unwrap(), temp_file.path().to_str().unwrap(),
&temporary_frame_path, &temporary_frame_path,
frame_number.try_into().unwrap(), frame_number.try_into().unwrap(),
frame_rate,
) )
.await?; .await?;
log::info!("extracted frame {:?}", frame_number);
let mut actual_frame_file = std::fs::File::open(&temporary_frame_path)?; let mut actual_frame_file = std::fs::File::open(&temporary_frame_path)?;
let mut frame_data = vec![]; let mut frame_data = vec![];
actual_frame_file.read_to_end(&mut frame_data)?; actual_frame_file.read_to_end(&mut frame_data)?;
log::info!("sending frame {:?}", frame_number);
let tags_from_frame = if let WrappedResponse::Tags(tags_from_frame) = let tags_from_frame = if let WrappedResponse::Tags(tags_from_frame) =
send_image_to_dd(frame_data, "amongus.png".to_string(), "image/png", &options) send_image_to_dd(frame_data, "amongus.png".to_string(), "image/png", &options)
.await? .await?