diff --git a/Cargo.toml b/Cargo.toml index 74f6016..de1d817 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ edition = "2021" anyhow = "1.0.70" axum = { version = "0.6.12", features = ["tokio", "multipart"] } axum-macros = "0.3.7" +base64 = "0.21.2" ffmpeg-cli = "0.1.0" ffprobe = "0.3.3" filepath = "0.1.2" diff --git a/src/main.rs b/src/main.rs index db459ac..54828bc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,13 +4,32 @@ use axum::response::{IntoResponse, Response}; use axum::routing::post; use axum::{Json, Router}; use axum_macros::debug_handler; +use base64::{engine::general_purpose, Engine as _}; +use core::panic; use ffmpeg_cli::Parameter; use futures_util::{future::ready, StreamExt}; use serde::{Deserialize, Serialize}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::io::Read; use std::io::Write; use std::process::Stdio; +use std::rc::Rc; + +#[derive(Debug)] +enum Runner { + DeepDanbooru(String), + StableDiffusionWebUI(String), +} + +fn get_upstream_runner() -> Runner { + match std::env::var("DD_ADDRESS") { + Ok(value) => Runner::DeepDanbooru(value), + Err(_) => match std::env::var("SD_ADDRESS") { + Ok(value) => Runner::StableDiffusionWebUI(value), + Err(_) => panic!("shit no addr"), + }, + } +} #[tokio::main] async fn main() { @@ -18,10 +37,12 @@ async fn main() { // build our application with a single route let app = Router::new() .route("/", post(upload_file)) - .layer(axum::extract::DefaultBodyLimit::max(100 * 1024 * 1024)); + .layer(axum::extract::DefaultBodyLimit::max(300 * 1024 * 1024)); + + let upstream_runner = get_upstream_runner(); // run it with hyper on localhost:3000 - log::info!("running on 0.0.0.0:6679 to localhost:4443"); + log::info!("running on 0.0.0.0:6679 to {:?}", upstream_runner); axum::Server::bind(&"0.0.0.0:6679".parse().unwrap()) .serve(app.into_make_service()) .await @@ -74,34 +95,76 @@ async fn test_handler() -> Result<(), AppError> { Ok(()) } +#[derive(Serialize, Deserialize)] +struct WD14Response { + caption: HashMap, +} + async fn send_image_to_dd( file_contents: Vec, file_name: String, file_mime_type: &str, options: &Options, ) -> anyhow::Result { - let part = reqwest::multipart::Part::bytes(file_contents) + let part = reqwest::multipart::Part::bytes(file_contents.clone()) .file_name(file_name) .mime_str(file_mime_type) .unwrap(); let form = reqwest::multipart::Form::new().part("file", part); - log::debug!("calling dd"); + let upstream_runner = get_upstream_runner(); - let resp = reqwest::Client::new() - .post("http://localhost:4443") - .multipart(form) - .header("authorization", "Bearer 123") - .query(&[("threshold", options.threshold.clone())]) - .send() - .await?; + match upstream_runner { + Runner::DeepDanbooru(url) => { + log::debug!("calling dd"); - let body = resp.text().await?; - log::info!("body: {}", &body); - let json_response: WrappedResponse = serde_json::from_str(&body)?; + let resp = reqwest::Client::new() + .post(url) + .multipart(form) + .header("authorization", "Bearer 123") + .query(&[("threshold", options.threshold.clone())]) + .send() + .await?; - log::debug!("called!"); - Ok(json_response) + let body = resp.text().await?; + log::info!("body: {}", &body); + let json_response: WrappedResponse = serde_json::from_str(&body)?; + + log::debug!("called!"); + Ok(json_response) + } + Runner::StableDiffusionWebUI(url) => { + log::debug!("calling sd"); + + let mut map: HashMap<&str, &str> = HashMap::new(); + let file_base64 = general_purpose::STANDARD.encode(file_contents.clone()); + + map.insert("model", "wd14-vit-v2-git"); + map.insert("threshold", options.threshold.as_ref()); + map.insert("image", &file_base64); + + let serialized_map = serde_json::to_vec(&map).unwrap(); + + let resp = reqwest::Client::new() + .post(format!("{}/tagger/v1/interrogate", url)) + .body(serialized_map) + .send() + .await?; + + let body = resp.text().await?; + log::info!("body: {}", &body); + let json_response: WD14Response = serde_json::from_str(&body)?; + + // turn WD14Response into WrappedResponse + let mut tags = Vec::::new(); + for ele in json_response.caption { + tags.push(ele.0.clone()); + } + + log::debug!("called!"); + Ok(WrappedResponse::Tags(tags)) + } + } } async fn fetch_frame_as_image( @@ -139,7 +202,7 @@ async fn fetch_frame_as_image( }) .await; - log::debug!("run"); + log::debug!("run!"); let output = ffmpeg.process.wait_with_output().unwrap(); log::debug!("out");