// Source listing: 416 lines, 14 KiB, Rust.
use axum::extract::{Multipart, Query};
|
|
use axum::http::StatusCode;
|
|
use axum::response::{IntoResponse, Response};
|
|
use axum::routing::post;
|
|
use axum::{Json, Router};
|
|
use axum_macros::debug_handler;
|
|
use base64::{engine::general_purpose, Engine as _};
|
|
use core::panic;
|
|
use ffmpeg_cli::Parameter;
|
|
use ffprobe::{Format, Stream};
|
|
use futures_util::{future::ready, StreamExt};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::{HashMap, HashSet};
|
|
use std::io::Read;
|
|
use std::io::Write;
|
|
use std::process::Stdio;
|
|
|
|
/// Upstream tagging backend that uploaded images are forwarded to.
///
/// Each variant carries the backend address exactly as it was read from the
/// environment (see `get_upstream_runner`).
#[derive(Debug, Clone, PartialEq, Eq)]
enum Runner {
    /// DeepDanbooru-style backend; the payload is the value of `DD_ADDRESS`.
    DeepDanbooru(String),
    /// Stable Diffusion WebUI backend; the payload is the value of `SD_ADDRESS`.
    StableDiffusionWebUI(String),
}
|
|
|
|
fn get_upstream_runner() -> Runner {
|
|
match std::env::var("DD_ADDRESS") {
|
|
Ok(value) => Runner::DeepDanbooru(value),
|
|
Err(_) => match std::env::var("SD_ADDRESS") {
|
|
Ok(value) => Runner::StableDiffusionWebUI(value),
|
|
Err(_) => panic!("shit no addr"),
|
|
},
|
|
}
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() {
|
|
pretty_env_logger::init();
|
|
// build our application with a single route
|
|
let app = Router::new()
|
|
.route("/", post(upload_file))
|
|
.layer(axum::extract::DefaultBodyLimit::max(5 * 1024 * 1024 * 1024));
|
|
|
|
let upstream_runner = get_upstream_runner();
|
|
|
|
// run it with hyper on localhost:3000
|
|
log::info!("running on 0.0.0.0:6679 to {:?}", upstream_runner);
|
|
axum::Server::bind(&"0.0.0.0:6679".parse().unwrap())
|
|
.serve(app.into_make_service())
|
|
.await
|
|
.unwrap();
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct Options {
|
|
threshold: String,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
#[serde(untagged)]
|
|
enum WrappedResponse {
|
|
Tags(Vec<String>),
|
|
Error(String),
|
|
}
|
|
|
|
// Make our own error that wraps `anyhow::Error`.
|
|
struct AppError(anyhow::Error);
|
|
|
|
// Tell axum how to convert `AppError` into a response.
|
|
impl IntoResponse for AppError {
|
|
fn into_response(self) -> Response {
|
|
(
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
format!("Something went wrong: {}", self.0),
|
|
)
|
|
.into_response()
|
|
}
|
|
}
|
|
|
|
// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
|
|
// `Result<_, AppError>`. That way you don't need to do that manually.
|
|
impl<E> From<E> for AppError
|
|
where
|
|
E: Into<anyhow::Error>,
|
|
{
|
|
fn from(err: E) -> Self {
|
|
Self(err.into())
|
|
}
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize)]
|
|
struct WD14Response {
|
|
caption: HashMap<String, f32>,
|
|
}
|
|
|
|
async fn send_image_to_dd(
|
|
file_contents: Vec<u8>,
|
|
file_name: String,
|
|
file_mime_type: &str,
|
|
options: &Options,
|
|
) -> anyhow::Result<WrappedResponse> {
|
|
let part = reqwest::multipart::Part::bytes(file_contents.clone())
|
|
.file_name(file_name)
|
|
.mime_str(file_mime_type)
|
|
.unwrap();
|
|
let form = reqwest::multipart::Form::new().part("file", part);
|
|
|
|
let upstream_runner = get_upstream_runner();
|
|
|
|
match upstream_runner {
|
|
Runner::DeepDanbooru(url) => {
|
|
log::debug!("calling dd");
|
|
|
|
let resp = reqwest::Client::new()
|
|
.post(url)
|
|
.multipart(form)
|
|
.header("authorization", "Bearer 123")
|
|
.query(&[("threshold", options.threshold.clone())])
|
|
.send()
|
|
.await?;
|
|
|
|
let body = resp.text().await?;
|
|
log::info!("sd body: {}", &body);
|
|
let json_response: WrappedResponse = serde_json::from_str(&body)?;
|
|
|
|
log::debug!("called!");
|
|
Ok(json_response)
|
|
}
|
|
Runner::StableDiffusionWebUI(url) => {
|
|
log::debug!("calling sd");
|
|
|
|
let mut map: HashMap<&str, &str> = HashMap::new();
|
|
let file_base64 = general_purpose::STANDARD.encode(file_contents.clone());
|
|
|
|
map.insert("model", "wd14-vit-v2-git");
|
|
map.insert("threshold", options.threshold.as_ref());
|
|
map.insert("image", &file_base64);
|
|
|
|
let serialized_map = serde_json::to_vec(&map).unwrap();
|
|
let len = serialized_map.len();
|
|
log::info!("wd14 request length {} bytes", len);
|
|
|
|
let resp = reqwest::Client::new()
|
|
.post(format!("{}/tagger/v1/interrogate", url))
|
|
.body(serialized_map)
|
|
.send()
|
|
.await?;
|
|
|
|
let body = resp.text().await?;
|
|
log::info!("wd14 body: {}", &body);
|
|
let json_response: WD14Response = serde_json::from_str(&body)?;
|
|
|
|
// turn WD14Response into WrappedResponse
|
|
let mut tags = Vec::<String>::new();
|
|
for ele in json_response.caption {
|
|
tags.push(ele.0.clone());
|
|
}
|
|
|
|
log::debug!("called!");
|
|
Ok(WrappedResponse::Tags(tags))
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn fetch_frame_as_image(
|
|
input_path: &str,
|
|
output_path: &str,
|
|
frame_index: usize,
|
|
frame_rate: f64, // X/1sec
|
|
) -> anyhow::Result<()> {
|
|
let timeline_index: f64 = frame_index as f64 / frame_rate;
|
|
let timeline_index_param = format!("{:.5}", timeline_index);
|
|
let builder = ffmpeg_cli::FfmpegBuilder::new()
|
|
.stderr(Stdio::piped())
|
|
.option(Parameter::Single("nostdin"))
|
|
.option(Parameter::Single("y"))
|
|
// fast seeking with -ss instead of select filter
|
|
.input(
|
|
ffmpeg_cli::File::new(input_path)
|
|
.option(Parameter::KeyValue("ss", &timeline_index_param)),
|
|
)
|
|
.output(ffmpeg_cli::File::new(output_path).option(Parameter::KeyValue("vframes", "1")));
|
|
|
|
log::debug!("running {:?}", builder);
|
|
let ffmpeg = builder.run().await.unwrap();
|
|
|
|
// TODO ffmpeg doesnt really provide a progress bar for this, its just
|
|
// a single frame being extracted. should ignoring x.unwrap() be the
|
|
// actual thing we gotta do?
|
|
|
|
log::debug!("run");
|
|
ffmpeg
|
|
.progress
|
|
.for_each(|x| {
|
|
log::debug!("progress x value = {:?}", &x);
|
|
// lmao
|
|
// x.unwrap();
|
|
ready(())
|
|
})
|
|
.await;
|
|
|
|
log::debug!("run!");
|
|
let output = ffmpeg.process.wait_with_output().unwrap();
|
|
|
|
log::debug!("out");
|
|
log::debug!(
|
|
"{}\nstderr:\n{}",
|
|
output.status,
|
|
std::str::from_utf8(&output.stderr).unwrap()
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn fetch_frame_count_full_decode(path: &std::path::Path) -> anyhow::Result<u64> {
|
|
let config = ffprobe::ConfigBuilder::new().count_frames(true).build();
|
|
let new_info = ffprobe::ffprobe_config(config, path)?;
|
|
let res = new_info
|
|
.streams
|
|
.get(0)
|
|
.unwrap()
|
|
.nb_read_frames
|
|
.clone()
|
|
.unwrap()
|
|
.parse::<u64>()?;
|
|
Ok(res)
|
|
}
|
|
|
|
fn calculate_frame_count(
|
|
path: &std::path::Path,
|
|
stream: &Stream,
|
|
format: &Format,
|
|
frame_rate: f64,
|
|
) -> anyhow::Result<u64> {
|
|
Ok(if let Some(parseable_data) = stream.nb_frames.clone() {
|
|
// if we can get it from the stream metadata, use it
|
|
parseable_data.parse::<u64>()?
|
|
} else if let Some(parseable_data) = format.try_get_duration() {
|
|
// this is a std::time::duration
|
|
// multiply that by frame rate and we get total frame count (approximate)
|
|
log::warn!("fetching duration from format metadata...");
|
|
let seconds = parseable_data?.as_secs_f64();
|
|
(seconds * frame_rate) as u64
|
|
} else {
|
|
log::warn!("file didn't provide frame metadata, calculating it ourselves...");
|
|
fetch_frame_count_full_decode(path)?
|
|
})
|
|
}
|
|
|
|
fn calculate_frame_rate(
|
|
temp_path: &std::path::Path,
|
|
frame_rate_str: String,
|
|
first_run: bool,
|
|
) -> anyhow::Result<f64> {
|
|
let parts = frame_rate_str.split("/").into_iter().collect::<Vec<_>>();
|
|
let frames_per = parts.get(0).unwrap().parse::<f64>()?;
|
|
let per_second = parts.get(1).unwrap().parse::<f64>()?;
|
|
if frames_per == 0.0f64 && per_second == 0.0f64 {
|
|
log::warn!("got incorrect frame rate, calling ffmpeg again...");
|
|
if !first_run {
|
|
std::panic!("couldnt get frame rate shit");
|
|
}
|
|
|
|
// call ffprobe directly
|
|
let mut cmd = std::process::Command::new("ffprobe");
|
|
let cmd = cmd.args(&[
|
|
"-v",
|
|
"error",
|
|
"-select_streams",
|
|
"v",
|
|
"-of",
|
|
"default=noprint_wrappers=1:nokey=1",
|
|
"-show_entries",
|
|
"stream=r_frame_rate",
|
|
]);
|
|
let cmd = cmd.arg(temp_path.to_str().unwrap());
|
|
|
|
cmd.get_args().for_each(|arg| {
|
|
log::debug!("arg {:?}", arg);
|
|
});
|
|
|
|
let output = cmd.output()?;
|
|
let possibly_new_frame_rate = String::from_utf8(output.stdout)?
|
|
.strip_suffix("\n")
|
|
.unwrap()
|
|
.to_string();
|
|
|
|
log::debug!("raw ffprobe gave {:?}", possibly_new_frame_rate);
|
|
|
|
calculate_frame_rate(temp_path.clone(), possibly_new_frame_rate, false)
|
|
} else {
|
|
Ok(frames_per / per_second)
|
|
}
|
|
}
|
|
|
|
#[debug_handler]
|
|
async fn upload_file(
|
|
options: Query<Options>,
|
|
mut multipart: Multipart,
|
|
) -> Result<(StatusCode, Json<WrappedResponse>), AppError> {
|
|
let mut maybe_file_contents: Option<axum::body::Bytes> = None;
|
|
let mut maybe_file_type: Option<String> = None;
|
|
let mut maybe_file_name: Option<String> = None;
|
|
|
|
while let Some(field) = multipart.next_field().await.unwrap() {
|
|
let name = field.name().unwrap().to_string();
|
|
let content_type = field.content_type().unwrap().to_string();
|
|
let filename = field.file_name().unwrap().to_string();
|
|
let data = field.bytes().await.unwrap();
|
|
|
|
log::info!("Length of `{}` is {} bytes", name, data.len());
|
|
if name == "file" {
|
|
maybe_file_contents = Some(data);
|
|
maybe_file_type = Some(content_type);
|
|
maybe_file_name = Some(filename);
|
|
}
|
|
}
|
|
|
|
if let Some(file_contents) = maybe_file_contents {
|
|
let file_type = maybe_file_type.unwrap();
|
|
let file_name = maybe_file_name.unwrap();
|
|
log::info!("file {} {}", file_type, file_name);
|
|
let is_video = file_type.starts_with("video/")
|
|
|| file_name.ends_with(".mp4")
|
|
|| file_name.ends_with(".gif")
|
|
|| file_name.ends_with(".webm");
|
|
if is_video {
|
|
let mut final_tag_set = HashSet::new();
|
|
|
|
let mut temp_file = tempfile::NamedTempFile::new()?;
|
|
temp_file.write_all(&file_contents.to_vec())?;
|
|
|
|
log::debug!("tmp path: {:?}", temp_file.path());
|
|
|
|
let info = ffprobe::ffprobe(temp_file.path())?;
|
|
let stream = info.streams.get(0).unwrap();
|
|
|
|
log::debug!("stream = {:?}", stream);
|
|
log::debug!("format = {:?}", info.format);
|
|
|
|
let frame_rate: f64 =
|
|
calculate_frame_rate(temp_file.path(), stream.r_frame_rate.clone(), true)?;
|
|
|
|
let total_frame_count =
|
|
calculate_frame_count(temp_file.path(), &stream, &info.format, frame_rate)?;
|
|
|
|
log::debug!("total frame count = {}", total_frame_count);
|
|
log::debug!("frame rate = {}", frame_rate);
|
|
|
|
let total_length_in_seconds = total_frame_count as f64 / frame_rate;
|
|
let wanted_frame_skip_seconds = match total_length_in_seconds as usize {
|
|
0..=10 => 2,
|
|
11..=60 => 10,
|
|
61..=120 => 15,
|
|
121..=300 => 20,
|
|
301..=1000 => 30,
|
|
1001..=1200 => 40,
|
|
1201.. => 60,
|
|
_ => 63,
|
|
} as f64;
|
|
let wanted_frame_skip = wanted_frame_skip_seconds * frame_rate;
|
|
|
|
let temporary_frame_dir = tempfile::tempdir()?;
|
|
let temporary_frame_path =
|
|
format!("{}/frame.png", temporary_frame_dir.path().to_string_lossy());
|
|
log::info!("frame path: '{}'", &temporary_frame_path);
|
|
log::info!("wanted_frame_skip: {}", &wanted_frame_skip_seconds);
|
|
|
|
for frame_number in (0..total_frame_count).step_by(wanted_frame_skip as usize) {
|
|
log::info!("extracting frame {:?}", frame_number);
|
|
fetch_frame_as_image(
|
|
temp_file.path().to_str().unwrap(),
|
|
&temporary_frame_path,
|
|
frame_number.try_into().unwrap(),
|
|
frame_rate,
|
|
)
|
|
.await?;
|
|
log::info!("extracted frame {:?}", frame_number);
|
|
|
|
let mut actual_frame_file = std::fs::File::open(&temporary_frame_path)?;
|
|
let mut frame_data = vec![];
|
|
actual_frame_file.read_to_end(&mut frame_data)?;
|
|
log::info!("sending frame {:?}", frame_number);
|
|
let tags_from_frame = if let WrappedResponse::Tags(tags_from_frame) =
|
|
send_image_to_dd(frame_data, "amongus.png".to_string(), "image/png", &options)
|
|
.await?
|
|
{
|
|
tags_from_frame
|
|
} else {
|
|
todo!()
|
|
};
|
|
|
|
for tag in tags_from_frame {
|
|
final_tag_set.insert(tag);
|
|
}
|
|
}
|
|
|
|
let response = WrappedResponse::Tags(final_tag_set.into_iter().collect::<Vec<_>>());
|
|
Ok((StatusCode::OK, Json(response)))
|
|
} else {
|
|
if !file_type.starts_with("image/") {
|
|
log::warn!("warning: mimetype {} is not image/", file_type);
|
|
}
|
|
|
|
let json_response =
|
|
send_image_to_dd(file_contents.to_vec(), file_name, &file_type, &options).await?;
|
|
Ok((StatusCode::OK, Json(json_response)))
|
|
}
|
|
} else {
|
|
Ok((
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
Json(WrappedResponse::Error(
|
|
"no file found in request".to_string(),
|
|
)),
|
|
))
|
|
}
|
|
}
|