Compare commits

...

4 commits

Author SHA1 Message Date
6d917bdf1f allow mkv 2024-08-04 18:37:00 -03:00
657436f80d support new api 2024-08-04 18:36:56 -03:00
ee275f0fa5 bump body limit lol 2024-08-04 18:36:49 -03:00
22f79bbdf1 let model be configurable 2024-08-04 18:35:21 -03:00

View file

@ -31,13 +31,20 @@ fn get_upstream_runner() -> Runner {
} }
} }
fn get_sd_model() -> String {
match std::env::var("SD_MODEL") {
Ok(value) => value,
Err(_) => "wd14-vit-v2-git".to_string(),
}
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
pretty_env_logger::init(); pretty_env_logger::init();
// build our application with a single route // build our application with a single route
let app = Router::new() let app = Router::new()
.route("/", post(upload_file)) .route("/", post(upload_file))
.layer(axum::extract::DefaultBodyLimit::max(5 * 1024 * 1024 * 1024)); .layer(axum::extract::DefaultBodyLimit::max(8 * 1024 * 1024 * 1024));
let upstream_runner = get_upstream_runner(); let upstream_runner = get_upstream_runner();
@ -88,7 +95,13 @@ where
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct WD14Response { struct WD14Response {
caption: HashMap<String, f32>, caption: WD14ResponseTagHolder,
}
#[derive(Serialize, Deserialize)]
struct WD14ResponseTagHolder {
tag: HashMap<String, f32>,
rating: HashMap<String, f32>,
} }
async fn send_image_to_dd( async fn send_image_to_dd(
@ -130,7 +143,9 @@ async fn send_image_to_dd(
let mut map: HashMap<&str, &str> = HashMap::new(); let mut map: HashMap<&str, &str> = HashMap::new();
let file_base64 = general_purpose::STANDARD.encode(file_contents.clone()); let file_base64 = general_purpose::STANDARD.encode(file_contents.clone());
map.insert("model", "wd14-vit-v2-git"); let sd_model = get_sd_model();
map.insert("model", &sd_model);
map.insert("threshold", options.threshold.as_ref()); map.insert("threshold", options.threshold.as_ref());
map.insert("image", &file_base64); map.insert("image", &file_base64);
@ -150,7 +165,10 @@ async fn send_image_to_dd(
// turn WD14Response into WrappedResponse // turn WD14Response into WrappedResponse
let mut tags = Vec::<String>::new(); let mut tags = Vec::<String>::new();
for ele in json_response.caption { for ele in json_response.caption.tag {
tags.push(ele.0.clone());
}
for ele in json_response.caption.rating {
tags.push(ele.0.clone()); tags.push(ele.0.clone());
} }
@ -321,6 +339,7 @@ async fn upload_file(
let is_video = file_type.starts_with("video/") let is_video = file_type.starts_with("video/")
|| file_name.ends_with(".mp4") || file_name.ends_with(".mp4")
|| file_name.ends_with(".gif") || file_name.ends_with(".gif")
|| file_name.ends_with(".mkv")
|| file_name.ends_with(".webm"); || file_name.ends_with(".webm");
if is_video { if is_video {
let mut final_tag_set = HashSet::new(); let mut final_tag_set = HashSet::new();