add support for stable-diffusion-webui-wd14-tagger extension
This commit is contained in:
parent
7516e47bb7
commit
49fb8f4763
2 changed files with 82 additions and 18 deletions
|
@ -9,6 +9,7 @@ edition = "2021"
|
|||
anyhow = "1.0.70"
|
||||
axum = { version = "0.6.12", features = ["tokio", "multipart"] }
|
||||
axum-macros = "0.3.7"
|
||||
base64 = "0.21.2"
|
||||
ffmpeg-cli = "0.1.0"
|
||||
ffprobe = "0.3.3"
|
||||
filepath = "0.1.2"
|
||||
|
|
99
src/main.rs
99
src/main.rs
|
@ -4,13 +4,32 @@ use axum::response::{IntoResponse, Response};
|
|||
use axum::routing::post;
|
||||
use axum::{Json, Router};
|
||||
use axum_macros::debug_handler;
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
use core::panic;
|
||||
use ffmpeg_cli::Parameter;
|
||||
use futures_util::{future::ready, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashSet;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::io::Read;
|
||||
use std::io::Write;
|
||||
use std::process::Stdio;
|
||||
use std::rc::Rc;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Runner {
|
||||
DeepDanbooru(String),
|
||||
StableDiffusionWebUI(String),
|
||||
}
|
||||
|
||||
fn get_upstream_runner() -> Runner {
|
||||
match std::env::var("DD_ADDRESS") {
|
||||
Ok(value) => Runner::DeepDanbooru(value),
|
||||
Err(_) => match std::env::var("SD_ADDRESS") {
|
||||
Ok(value) => Runner::StableDiffusionWebUI(value),
|
||||
Err(_) => panic!("shit no addr"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
|
@ -18,10 +37,12 @@ async fn main() {
|
|||
// build our application with a single route
|
||||
let app = Router::new()
|
||||
.route("/", post(upload_file))
|
||||
.layer(axum::extract::DefaultBodyLimit::max(100 * 1024 * 1024));
|
||||
.layer(axum::extract::DefaultBodyLimit::max(300 * 1024 * 1024));
|
||||
|
||||
let upstream_runner = get_upstream_runner();
|
||||
|
||||
// run it with hyper on localhost:3000
|
||||
log::info!("running on 0.0.0.0:6679 to localhost:4443");
|
||||
log::info!("running on 0.0.0.0:6679 to {:?}", upstream_runner);
|
||||
axum::Server::bind(&"0.0.0.0:6679".parse().unwrap())
|
||||
.serve(app.into_make_service())
|
||||
.await
|
||||
|
@ -74,34 +95,76 @@ async fn test_handler() -> Result<(), AppError> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct WD14Response {
|
||||
caption: HashMap<String, f32>,
|
||||
}
|
||||
|
||||
async fn send_image_to_dd(
|
||||
file_contents: Vec<u8>,
|
||||
file_name: String,
|
||||
file_mime_type: &str,
|
||||
options: &Options,
|
||||
) -> anyhow::Result<WrappedResponse> {
|
||||
let part = reqwest::multipart::Part::bytes(file_contents)
|
||||
let part = reqwest::multipart::Part::bytes(file_contents.clone())
|
||||
.file_name(file_name)
|
||||
.mime_str(file_mime_type)
|
||||
.unwrap();
|
||||
let form = reqwest::multipart::Form::new().part("file", part);
|
||||
|
||||
log::debug!("calling dd");
|
||||
let upstream_runner = get_upstream_runner();
|
||||
|
||||
let resp = reqwest::Client::new()
|
||||
.post("http://localhost:4443")
|
||||
.multipart(form)
|
||||
.header("authorization", "Bearer 123")
|
||||
.query(&[("threshold", options.threshold.clone())])
|
||||
.send()
|
||||
.await?;
|
||||
match upstream_runner {
|
||||
Runner::DeepDanbooru(url) => {
|
||||
log::debug!("calling dd");
|
||||
|
||||
let body = resp.text().await?;
|
||||
log::info!("body: {}", &body);
|
||||
let json_response: WrappedResponse = serde_json::from_str(&body)?;
|
||||
let resp = reqwest::Client::new()
|
||||
.post(url)
|
||||
.multipart(form)
|
||||
.header("authorization", "Bearer 123")
|
||||
.query(&[("threshold", options.threshold.clone())])
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
log::debug!("called!");
|
||||
Ok(json_response)
|
||||
let body = resp.text().await?;
|
||||
log::info!("body: {}", &body);
|
||||
let json_response: WrappedResponse = serde_json::from_str(&body)?;
|
||||
|
||||
log::debug!("called!");
|
||||
Ok(json_response)
|
||||
}
|
||||
Runner::StableDiffusionWebUI(url) => {
|
||||
log::debug!("calling sd");
|
||||
|
||||
let mut map: HashMap<&str, &str> = HashMap::new();
|
||||
let file_base64 = general_purpose::STANDARD.encode(file_contents.clone());
|
||||
|
||||
map.insert("model", "wd14-vit-v2-git");
|
||||
map.insert("threshold", options.threshold.as_ref());
|
||||
map.insert("image", &file_base64);
|
||||
|
||||
let serialized_map = serde_json::to_vec(&map).unwrap();
|
||||
|
||||
let resp = reqwest::Client::new()
|
||||
.post(format!("{}/tagger/v1/interrogate", url))
|
||||
.body(serialized_map)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let body = resp.text().await?;
|
||||
log::info!("body: {}", &body);
|
||||
let json_response: WD14Response = serde_json::from_str(&body)?;
|
||||
|
||||
// turn WD14Response into WrappedResponse
|
||||
let mut tags = Vec::<String>::new();
|
||||
for ele in json_response.caption {
|
||||
tags.push(ele.0.clone());
|
||||
}
|
||||
|
||||
log::debug!("called!");
|
||||
Ok(WrappedResponse::Tags(tags))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_frame_as_image(
|
||||
|
@ -139,7 +202,7 @@ async fn fetch_frame_as_image(
|
|||
})
|
||||
.await;
|
||||
|
||||
log::debug!("run");
|
||||
log::debug!("run!");
|
||||
let output = ffmpeg.process.wait_with_output().unwrap();
|
||||
|
||||
log::debug!("out");
|
||||
|
|
Loading…
Reference in a new issue