ED_LRR/rust/src/lib.rs

222 lines
8.0 KiB
Rust

#![deny(warnings)]
mod common;
mod preprocess;
mod route;
use std::path::PathBuf;
use common::find_matches;
use pyo3::exceptions::*;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};
/// A route waypoint as supplied by the caller: either a numeric system ID
/// or a system name that still has to be looked up.
enum SysEntry {
    /// Waypoint given directly as a numeric ID.
    ID(u32),
    /// Waypoint given as a name.
    Name(String),
}

impl From<&str> for SysEntry {
    /// Classifies `s`: anything that parses as a `u32` becomes
    /// `SysEntry::ID`, everything else is kept verbatim as `SysEntry::Name`.
    fn from(s: &str) -> SysEntry {
        match s.parse::<u32>() {
            Ok(id) => SysEntry::ID(id),
            Err(_) => SysEntry::Name(String::from(s)),
        }
    }
}
/// Resolve a mixed list of system IDs and system names into system IDs.
///
/// Every `SysEntry::Name` is looked up via `find_matches` against the system
/// list at `path`; every `SysEntry::ID` is passed through unchanged. The
/// returned vector holds one ID per entry, in the same order as `entries`.
///
/// A warning is printed to stdout when the score reported for a name is
/// below 0.75.
///
/// # Errors
///
/// Returns `Err` with a message when `find_matches` fails, or
/// `"System <name> not found"` when a name has no entry in the lookup result
/// or its entry carries no matched system.
fn resolve(entries: &[SysEntry], path: &PathBuf) -> Result<Vec<u32>, String> {
    // Only the names need a lookup; numeric IDs are already final.
    let names: Vec<String> = entries
        .iter()
        .filter_map(|ent| match ent {
            SysEntry::Name(name) => Some(name.clone()),
            SysEntry::ID(_) => None,
        })
        .collect();
    let name_ids = find_matches(path, names, false)?;

    let mut ret: Vec<u32> = Vec::with_capacity(entries.len());
    for ent in entries {
        match ent {
            SysEntry::Name(name) => {
                // The error strings are built lazily so the success path
                // does not allocate them.
                let ent_res = name_ids
                    .get(name)
                    .ok_or_else(|| format!("System {} not found", name))?;
                let sys = ent_res
                    .1
                    .as_ref()
                    .ok_or_else(|| format!("System {} not found", name))?;
                if ent_res.0 < 0.75 {
                    println!(
                        "WARNING: {} match to {} with low confidence ({}%)",
                        name,
                        sys.system,
                        ent_res.0 * 100.0
                    );
                }
                ret.push(sys.id);
            }
            SysEntry::ID(id) => ret.push(*id),
        }
    }
    Ok(ret)
}
#[pymodule]
pub fn _ed_lrr(_py: Python, m: &PyModule) -> PyResult<()> {
better_panic::install();
/// preprocess(infile_systems, infile_bodies, outfile, callback)
/// --
///
/// Preprocess bodies.json and systemsWithCoordinates.json into stars.csv
///
/// `callback` is called with a dict holding the keys "file", "total",
/// "count", "done" and "message". A failure of the first callback call or
/// of the preprocessing step is raised as a Python exception.
#[pyfn(m, "preprocess")]
fn ed_lrr_preprocess(
    py: Python<'static>,
    infile_systems: String,
    infile_bodies: String,
    outfile: String,
    callback: PyObject,
) -> PyResult<PyObject> {
    use preprocess::*;
    // NOTE(review): `state` is never written to, so the value returned to
    // Python is always an empty dict; progress is only visible through
    // `state_dict` inside the callback. Confirm this is intended.
    let state = PyDict::new(py);
    let state_dict = PyDict::new(py);
    // First call hands the still-empty dict to the callback. An exception
    // raised by the callback is propagated instead of panicking.
    callback.call(py, (state_dict,), None)?;
    // Copies the current progress into the shared dict, then notifies Python.
    let callback_wrapped = move |state: &PreprocessState| {
        state_dict.set_item("file", state.file.clone())?;
        state_dict.set_item("total", state.total)?;
        state_dict.set_item("count", state.count)?;
        state_dict.set_item("done", state.done)?;
        state_dict.set_item("message", state.message.clone())?;
        callback.call(py, (state_dict,), None)
    };
    // NOTE(review): the bodies path is passed first and the systems path
    // second, the reverse of this function's own parameter order. Confirm
    // against the signature of `preprocess_files`.
    preprocess_files(
        &PathBuf::from(infile_bodies),
        &PathBuf::from(infile_systems),
        &PathBuf::from(outfile),
        &callback_wrapped,
    )
    // Surface the failure as a Python exception rather than a panic. The
    // error is rendered with `Debug`, which is the only bound the previous
    // `.unwrap()` guaranteed for it.
    .map_err(|e| PyErr::new::<ValueError, _>(format!("{:?}", e)))?;
    Ok(state.to_object(py))
}
///find_sys(sys_names, sys_list_path)
/// --
///
/// Find system by name
///
/// Returns a dict mapping each name that produced a match to a tuple
/// `(score, info)`, where `info` is a dict with the keys "star_type",
/// "system", "body", "distance", "pos" and "id". Names without a match are
/// left out. A lookup failure is raised as `ValueError`.
#[pyfn(m, "find_sys")]
fn find_sys(py: Python, sys_names: Vec<String>, sys_list: String) -> PyResult<PyObject> {
    let matches = find_matches(&PathBuf::from(sys_list), sys_names, false)
        .map_err(|msg| PyErr::new::<ValueError, _>(msg))?;

    let result = PyDict::new(py);
    for (name, (score, found)) in matches {
        // Names that did not resolve to a system are skipped.
        let system = match found {
            Some(system) => system,
            None => continue,
        };
        let info = PyDict::new(py);
        info.set_item("star_type", system.star_type.clone())?;
        info.set_item("system", system.system.clone())?;
        info.set_item("body", system.body.clone())?;
        info.set_item("distance", system.distance)?;
        info.set_item("pos", PyList::new(py, system.pos.iter()))?;
        info.set_item("id", system.id)?;
        result.set_item(name, (score, info).to_object(py))?;
    }
    Ok(result.to_object(py))
}
/// route(hops, range, prune, mode, primary, permute, keep_first, keep_last, greedyness, precomp, path, num_workers, callback)
/// --
///
/// Compute a Route using the supplied parameters
///
/// Entries of `hops` that parse as an unsigned integer are used as system
/// IDs; all other entries are resolved by name against the file at `path`.
/// Returns a list with one dict per hop (keys "star_type", "system",
/// "body", "distance", "pos"), or `None` when the router reports no route.
/// An invalid `mode`, an unresolvable hop or a routing failure is raised as
/// `ValueError`.
#[pyfn(m, "route")]
#[allow(clippy::too_many_arguments)] // each argument is a parameter of the Python-facing function
fn py_route(
    py: Python<'static>,
    hops: Vec<&str>,
    range: f32,
    prune: Option<(usize, f32)>,
    mode: String,
    primary: bool,
    permute: bool,
    keep_first: bool,
    keep_last: bool,
    greedyness: Option<f32>,
    precomp: Option<String>,
    path: String,
    num_workers: Option<usize>,
    callback: PyObject,
) -> PyResult<PyObject> {
    use route::*;
    // A missing worker count falls back to 1.
    let num_workers = num_workers.unwrap_or(1);
    let mode = Mode::parse(&mode).map_err(|e| PyErr::new::<ValueError, _>(e))?;

    let state_dict = PyDict::new(py);
    {
        // First call hands the still-empty dict to the callback. A failure
        // here is only logged; it does not abort the routing.
        let cb_res = callback.call(py, (state_dict,), None);
        if cb_res.is_err() {
            println!("Error: {:?}", cb_res);
        }
    }
    // Copies the current search state into the shared dict, then notifies
    // Python. A callback failure is logged and also handed back to the caller
    // of this closure.
    let callback_wrapped = move |state: &SearchState| {
        state_dict.set_item("mode", state.mode.clone())?;
        state_dict.set_item("system", state.system.clone())?;
        state_dict.set_item("body", state.body.clone())?;
        state_dict.set_item("depth", state.depth)?;
        state_dict.set_item("queue_size", state.queue_size)?;
        state_dict.set_item("d_rem", state.d_rem)?;
        state_dict.set_item("d_total", state.d_total)?;
        state_dict.set_item("prc_done", state.prc_done)?;
        state_dict.set_item("n_seen", state.n_seen)?;
        state_dict.set_item("prc_seen", state.prc_seen)?;
        state_dict.set_item("from", state.from.clone())?;
        state_dict.set_item("to", state.to.clone())?;
        let cb_res = callback.call(py, (state_dict,), None);
        if cb_res.is_err() {
            println!("Error: {:?}", cb_res);
        }
        cb_res
    };

    let hops: Vec<SysEntry> = hops.into_iter().map(SysEntry::from).collect();
    println!("Resolving systems...");
    let hops: Vec<u32> = resolve(&hops, &PathBuf::from(&path))
        .map_err(|err_msg| PyErr::new::<ValueError, _>(err_msg))?;

    let opts = RouteOpts {
        systems: hops,
        range: Some(range),
        file_path: PathBuf::from(path),
        precomp_file: precomp.map(PathBuf::from),
        callback: Box::new(callback_wrapped),
        mode,
        factor: greedyness,
        precompute: false,
        permute,
        keep_first,
        keep_last,
        primary,
        prune,
        workers: num_workers,
    };

    match route(opts) {
        Ok(Some(route)) => {
            // Build the per-hop dicts in a loop so that a failing insertion
            // is raised as a Python exception instead of panicking.
            let mut hop_dicts = Vec::new();
            for hop in route.iter() {
                let elem = PyDict::new(py);
                elem.set_item("star_type", hop.star_type.clone())?;
                elem.set_item("system", hop.system.clone())?;
                elem.set_item("body", hop.body.clone())?;
                elem.set_item("distance", hop.distance)?;
                elem.set_item("pos", PyList::new(py, hop.pos.iter()))?;
                hop_dicts.push(elem);
            }
            Ok(PyList::new(py, hop_dicts).to_object(py))
        }
        // `()` converts to Python's `None`.
        Ok(None) => Ok(().to_object(py)),
        Err(e) => Err(PyErr::new::<ValueError, _>(e)),
    }
}
Ok(())
}