Compare commits

...

4 commits

Author SHA1 Message Date
6d917bdf1f allow mkv 2024-08-04 18:37:00 -03:00
657436f80d support new api 2024-08-04 18:36:56 -03:00
ee275f0fa5 bump body limit lol 2024-08-04 18:36:49 -03:00
22f79bbdf1 let model be configurable 2024-08-04 18:35:21 -03:00

View file

@ -31,13 +31,20 @@ fn get_upstream_runner() -> Runner {
}
}
fn get_sd_model() -> String {
match std::env::var("SD_MODEL") {
Ok(value) => value,
Err(_) => "wd14-vit-v2-git".to_string(),
}
}
#[tokio::main]
async fn main() {
pretty_env_logger::init();
// build our application with a single route
let app = Router::new()
.route("/", post(upload_file))
.layer(axum::extract::DefaultBodyLimit::max(5 * 1024 * 1024 * 1024));
.layer(axum::extract::DefaultBodyLimit::max(8 * 1024 * 1024 * 1024));
let upstream_runner = get_upstream_runner();
@ -88,7 +95,13 @@ where
#[derive(Serialize, Deserialize)]
struct WD14Response {
caption: HashMap<String, f32>,
caption: WD14ResponseTagHolder,
}
#[derive(Serialize, Deserialize)]
struct WD14ResponseTagHolder {
tag: HashMap<String, f32>,
rating: HashMap<String, f32>,
}
async fn send_image_to_dd(
@ -130,7 +143,9 @@ async fn send_image_to_dd(
let mut map: HashMap<&str, &str> = HashMap::new();
let file_base64 = general_purpose::STANDARD.encode(file_contents.clone());
map.insert("model", "wd14-vit-v2-git");
let sd_model = get_sd_model();
map.insert("model", &sd_model);
map.insert("threshold", options.threshold.as_ref());
map.insert("image", &file_base64);
@ -150,7 +165,10 @@ async fn send_image_to_dd(
// turn WD14Response into WrappedResponse
let mut tags = Vec::<String>::new();
for ele in json_response.caption {
for ele in json_response.caption.tag {
tags.push(ele.0.clone());
}
for ele in json_response.caption.rating {
tags.push(ele.0.clone());
}
@ -321,6 +339,7 @@ async fn upload_file(
let is_video = file_type.starts_with("video/")
|| file_name.ends_with(".mp4")
|| file_name.ends_with(".gif")
|| file_name.ends_with(".mkv")
|| file_name.ends_with(".webm");
if is_video {
let mut final_tag_set = HashSet::new();