Review notes (free text ahead of the first "diff --git" line is skipped by patch tools):
- This patch arrived collapsed onto a single line. The line layout below is
  reconstructed; indentation and the position of blank context lines are
  inferred from the hunk line counts, so regenerate the patch from the
  repository if it does not apply cleanly.
- `HashMap,` and `Vec::::new()` are reproduced exactly as received. Their
  generic arguments appear to have been stripped in transit - TODO restore them
  from src/main.rs before applying.
- Changes relative to the received patch: `get_sd_model` uses `unwrap_or_else`
  and is documented, `WD14ResponseTagHolder` is documented, and the two
  push-and-clone loops are replaced by `extend(... .into_keys())`. Hunk headers
  are recomputed accordingly; the `index` line still names the original
  post-image blob.

diff --git a/src/main.rs b/src/main.rs
index fa0e5f5..2d8649c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -31,13 +31,21 @@ fn get_upstream_runner() -> Runner {
     }
 }
 
+/// Returns the model name used for the `model` field of the upstream request.
+///
+/// Uses the value of the `SD_MODEL` environment variable; if reading it
+/// fails, the default `"wd14-vit-v2-git"` is returned instead.
+fn get_sd_model() -> String {
+    std::env::var("SD_MODEL").unwrap_or_else(|_| "wd14-vit-v2-git".to_string())
+}
+
 #[tokio::main]
 async fn main() {
     pretty_env_logger::init();
     // build our application with a single route
     let app = Router::new()
         .route("/", post(upload_file))
-        .layer(axum::extract::DefaultBodyLimit::max(5 * 1024 * 1024 * 1024));
+        .layer(axum::extract::DefaultBodyLimit::max(8 * 1024 * 1024 * 1024));
 
     let upstream_runner = get_upstream_runner();
 
@@ -88,7 +96,14 @@ where
 
 #[derive(Serialize, Deserialize)]
 struct WD14Response {
-    caption: HashMap,
+    caption: WD14ResponseTagHolder,
+}
+
+/// Contents of the `caption` field of [`WD14Response`]: a `tag` map and a `rating` map.
+#[derive(Serialize, Deserialize)]
+struct WD14ResponseTagHolder {
+    tag: HashMap,
+    rating: HashMap,
 }
 
 async fn send_image_to_dd(
@@ -130,7 +145,9 @@ async fn send_image_to_dd(
     let mut map: HashMap<&str, &str> = HashMap::new();
 
     let file_base64 = general_purpose::STANDARD.encode(file_contents.clone());
-    map.insert("model", "wd14-vit-v2-git");
+    let sd_model = get_sd_model();
+
+    map.insert("model", &sd_model);
     map.insert("threshold", options.threshold.as_ref());
     map.insert("image", &file_base64);
 
@@ -150,7 +167,7 @@ async fn send_image_to_dd(
 
     // turn WD14Response into WrappedResponse
     let mut tags = Vec::::new();
-    for ele in json_response.caption {
-        tags.push(ele.0.clone());
-    }
+    // Keys of both maps become tags (tag keys first); the map values are dropped.
+    tags.extend(json_response.caption.tag.into_keys());
+    tags.extend(json_response.caption.rating.into_keys());
 
@@ -321,6 +338,7 @@ async fn upload_file(
     let is_video = file_type.starts_with("video/")
         || file_name.ends_with(".mp4")
         || file_name.ends_with(".gif")
+        || file_name.ends_with(".mkv")
         || file_name.ends_with(".webm");
     if is_video {
         let mut final_tag_set = HashSet::new();