ED_LRR/rust/src/lib.rs

383 lines
12 KiB
Rust
Raw Normal View History

2020-03-28 13:53:52 +00:00
// #![deny(warnings)]
mod common;
mod preprocess;
mod route;
#[macro_use]
extern crate derivative;
use crate::common::SystemSerde;
use crate::common::{find_matches, SysEntry};
use crate::route::{Router, SearchState};
use pyo3::exceptions::*;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyTuple};
use pyo3::PyObjectProtocol;
use std::path::PathBuf;
/*
Fields of `SystemSerde` that `fill_dict` below copies into a Python dict:
pub id: u32,
pub star_type: String,
pub system: String,
pub body: String,
pub mult: f32,
pub distance: u32,
pub x: f32,
pub y: f32,
pub z: f32,
*/
impl SystemSerde {
    /// Replace the contents of `dict` with this system's fields.
    ///
    /// Every entry already present in `dict` is removed first. On success the
    /// dict holds exactly the keys `id`, `star_type`, `system`, `body`, `mult`,
    /// `distance`, `x`, `y` and `z`.
    ///
    /// # Errors
    /// Propagates the first error returned by `PyDict::set_item`. Because the
    /// dict is cleared up front, a failure leaves it only partially filled.
    fn fill_dict(&self, dict: &PyDict) -> PyResult<()> {
        dict.clear();
        dict.set_item("id", self.id)?;
        // Borrow the text fields: converting a reference yields the same Python
        // value without first cloning the Rust-side data.
        dict.set_item("star_type", &self.star_type)?;
        dict.set_item("system", &self.system)?;
        dict.set_item("body", &self.body)?;
        dict.set_item("mult", self.mult)?;
        dict.set_item("distance", self.distance)?;
        dict.set_item("x", self.x)?;
        dict.set_item("y", self.y)?;
        dict.set_item("z", self.z)?;
        Ok(())
    }
}
/// Python-facing wrapper around the native `Router`.
///
/// The constructor takes a single optional `callback`; the advertised text
/// signature mirrors that (it previously also listed a `workers` argument
/// that the constructor does not accept).
#[pyclass(dict)]
#[derive(Derivative)]
#[derivative(Debug)]
#[text_signature = "(callback=None, /)"]
struct PyRouter {
    /// Native router that does the actual work.
    router: Router,
    /// Path most recently passed to `load`; empty until `load` is called.
    stars_path: String,
}
#[pymethods]
impl PyRouter {
#[new]
#[args(callback = "None")]
fn new(callback: Option<PyObject>, py: Python<'static>) -> PyResult<Self> {
let cb_func = move |state: &SearchState| {
return match callback.as_ref() {
Some(cb) => cb.call(py, (state.clone(),), None),
None => Ok(py.None()),
};
};
let router = match Router::new(Box::new(cb_func)) {
Ok(router) => router,
Err(err_msg) => {
return Err(PyErr::new::<ValueError, _>(err_msg));
}
};
let ret = PyRouter {
router,
stars_path: String::from(""),
};
Ok(ret)
}
#[args(filter_func = "None")]
#[text_signature = "(path, /)"]
fn load(&mut self, path: String, py: Python<'static>) -> PyResult<PyObject> {
self.stars_path = path;
return Ok(py.None());
}
#[args(greedyness = "0.5", num_workers = "0", beam_width = "0")]
#[text_signature = "(hops, range, greedyness, beam_width, num_workers, /)"]
fn route(
&mut self,
hops: &PyList,
range: f32,
greedyness: f32,
beam_width: usize,
num_workers: usize,
py: Python,
) -> PyResult<PyObject> {
let route_res = self.router.load(&PathBuf::from(self.stars_path.clone()));
if let Err(err_msg) = route_res {
return Err(PyErr::new::<ValueError, _>(err_msg));
};
let mut sys_entries: Vec<SysEntry> = Vec::new();
for hop in hops {
if let Ok(id) = hop.extract() {
sys_entries.push(SysEntry::ID(id));
} else {
sys_entries.push(SysEntry::parse(hop.extract()?));
}
}
println!("Resolving systems...");
let ids: Vec<u32> = match resolve(&sys_entries, &self.router.path) {
Ok(ids) => ids,
Err(err_msg) => {
return Err(PyErr::new::<ValueError, _>(err_msg));
}
};
match self
.router
.computer_route(&ids, range, greedyness, beam_width, num_workers)
{
// TODO: return list of dicts (or objects)
Ok(route) => Ok(route.len().to_object(py)),
Err(err_msg) => Err(PyErr::new::<RuntimeError, _>(err_msg)),
}
}
#[args(hops = "*")]
#[text_signature = "(sys_1, sys_2, ..., /)"]
fn resolve_systems(&self, hops: &PyTuple, py: Python) -> PyResult<PyObject> {
let mut sys_entries: Vec<SysEntry> = Vec::new();
for hop in hops {
if let Ok(id) = hop.extract() {
sys_entries.push(SysEntry::ID(id));
} else {
sys_entries.push(SysEntry::parse(hop.extract()?));
}
}
println!("Resolving systems...");
let ids: Vec<u32> = match resolve(&sys_entries, &PathBuf::from(self.stars_path.clone())) {
Ok(ids) => ids,
Err(err_msg) => {
return Err(PyErr::new::<ValueError, _>(err_msg));
}
};
let ret: Vec<(_, u32)> = hops.into_iter().zip(ids.into_iter()).collect();
Ok(PyDict::from_sequence(py, ret.to_object(py))?.to_object(py))
}
#[staticmethod]
fn preprocess_edsm() -> PyResult<()> {
unimplemented!()
}
#[staticmethod]
fn preprocess_galaxy() -> PyResult<()> {
unimplemented!()
}
}
// String conversions for Python: both `str()` and `repr()` render the
// derived `Debug` output of the router wrapper.
#[pyproto]
impl PyObjectProtocol for PyRouter {
    fn __str__(&self) -> PyResult<String> {
        let rendered = format!("{:?}", self);
        Ok(rendered)
    }

    fn __repr__(&self) -> PyResult<String> {
        let rendered = format!("{:?}", self);
        Ok(rendered)
    }
}
/// Resolve system entries to system IDs using the star list at `path`.
///
/// Entries that already carry an ID are passed through unchanged; named
/// entries are looked up with `find_matches`. The returned IDs are in the
/// same order as `entries`. A warning is printed to stdout for every name
/// whose match score is below 0.75.
///
/// # Errors
/// Returns a message if `path` does not exist, if `find_matches` fails, or if
/// a name has no matching system.
fn resolve(entries: &[SysEntry], path: &PathBuf) -> Result<Vec<u32>, String> {
    // Fail fast before doing any other work.
    if !path.exists() {
        // `{}` rather than `{:?}`: the debug form adds its own quotes inside
        // the literal ones, producing `""path""`.
        return Err(format!(
            "Source file \"{}\" does not exist!",
            path.display()
        ));
    }
    // Only named entries need a lookup.
    let names: Vec<String> = entries
        .iter()
        .filter_map(|ent| match ent {
            SysEntry::Name(name) => Some(name.to_owned()),
            SysEntry::ID(_) => None,
        })
        .collect();
    let name_ids = find_matches(path, names, false)?;
    let mut ret: Vec<u32> = Vec::with_capacity(entries.len());
    for ent in entries {
        match ent {
            SysEntry::Name(name) => {
                // A name can be absent from the result map, or present
                // without a matched system; both count as "not found".
                let ent_res = name_ids
                    .get(name)
                    .ok_or_else(|| format!("System {} not found", name))?;
                let sys = ent_res
                    .1
                    .as_ref()
                    .ok_or_else(|| format!("System {} not found", name))?;
                if ent_res.0 < 0.75 {
                    println!(
                        "WARNING: {} match to {} with low confidence ({:.2}%)",
                        name,
                        sys.system,
                        ent_res.0 * 100.0
                    );
                }
                ret.push(sys.id);
            }
            SysEntry::ID(id) => ret.push(*id),
        }
    }
    Ok(ret)
}
// Entry point of the `_ed_lrr` Python extension module.
#[pymodule]
pub fn _ed_lrr(_py: Python, m: &PyModule) -> PyResult<()> {
    // Install the `better_panic` hook before anything else runs.
    better_panic::install();
    // Registering the router class is the only fallible step, so its result
    // is the module initialisation result.
    m.add_class::<PyRouter>()
}
/*
/// Preprocess bodies.json and systemsWithCoordinates.json into stars.csv
#[pyfn(m, "preprocess")]
#[text_signature = "(infile_systems, infile_bodies, outfile, callback, /)"]
fn ed_lrr_preprocess(
py: Python<'static>,
infile_systems: String,
infile_bodies: String,
outfile: String,
callback: PyObject,
) -> PyResult<PyObject> {
use preprocess::*;
let state = PyDict::new(py);
let state_dict = PyDict::new(py);
callback.call(py, (state_dict,), None).unwrap();
let callback_wrapped = move |state: &PreprocessState| {
// println!("SEND: {:?}",state);
state_dict.set_item("file", state.file.clone())?;
state_dict.set_item("total", state.total)?;
state_dict.set_item("count", state.count)?;
state_dict.set_item("done", state.done)?;
state_dict.set_item("message", state.message.clone())?;
callback.call(py, (state_dict,), None)
};
preprocess_files(
&PathBuf::from(infile_bodies),
&PathBuf::from(infile_systems),
&PathBuf::from(outfile),
&callback_wrapped,
)
.unwrap();
Ok(state.to_object(py))
}
/// Find system by name
#[pyfn(m, "find_sys")]
#[text_signature = "(sys_names, sys_list_path, /)"]
fn find_sys(py: Python, sys_names: Vec<String>, sys_list: String) -> PyResult<PyObject> {
let path = PathBuf::from(sys_list);
match find_matches(&path, sys_names, false) {
Ok(vals) => {
let ret = PyDict::new(py);
for (key, (diff, sys)) in vals {
let ret_dict = PyDict::new(py);
if let Some(val) = sys {
let pos = PyList::new(py, val.pos.iter());
ret_dict.set_item("star_type", val.star_type.clone())?;
ret_dict.set_item("system", val.system.clone())?;
ret_dict.set_item("body", val.body.clone())?;
ret_dict.set_item("distance", val.distance)?;
ret_dict.set_item("pos", pos)?;
ret_dict.set_item("id", val.id)?;
ret.set_item(key, (diff, ret_dict).to_object(py))?;
}
}
Ok(ret.to_object(py))
}
Err(e) => Err(PyErr::new::<ValueError, _>(e)),
}
}
/// Compute a Route using the supplied parameters
#[pyfn(m, "route")]
#[text_signature = "(hops, range, mode, primary, permute, keep_first, keep_last, greedyness, precomp, path, num_workers, callback, /)"]
#[allow(clippy::too_many_arguments)]
fn py_route(
py: Python<'static>,
hops: Vec<&str>,
range: f32,
mode: String,
primary: bool,
permute: bool,
keep_first: bool,
keep_last: bool,
greedyness: Option<f32>,
precomp: Option<String>,
path: String,
num_workers: Option<usize>,
callback: PyObject,
) -> PyResult<PyObject> {
use route::*;
let num_workers = num_workers.unwrap_or(1);
let mode = match Mode::parse(&mode) {
Ok(val) => val,
Err(e) => {
return Err(PyErr::new::<ValueError, _>(e));
}
};
let state_dict = PyDict::new(py);
{
let cb_res = callback.call(py, (state_dict,), None);
if cb_res.is_err() {
println!("Error: {:?}", cb_res);
}
}
let callback_wrapped = move |state: &SearchState| {
state_dict.set_item("mode", state.mode.clone())?;
state_dict.set_item("system", state.system.clone())?;
state_dict.set_item("body", state.body.clone())?;
state_dict.set_item("depth", state.depth)?;
state_dict.set_item("queue_size", state.queue_size)?;
state_dict.set_item("d_rem", state.d_rem)?;
state_dict.set_item("d_total", state.d_total)?;
state_dict.set_item("prc_done", state.prc_done)?;
state_dict.set_item("n_seen", state.n_seen)?;
state_dict.set_item("prc_seen", state.prc_seen)?;
state_dict.set_item("from", state.from.clone())?;
state_dict.set_item("to", state.to.clone())?;
let cb_res = callback.call(py, (state_dict,), None);
if cb_res.is_err() {
println!("Error: {:?}", cb_res);
}
cb_res
};
let hops: Vec<SysEntry> = (hops.iter().map(|v| SysEntry::from_str(&v)).collect::<Result<Vec<SysEntry>,_>>())?;
println!("Resolving systems...");
let hops: Vec<u32> = match resolve(&hops, &PathBuf::from(&path)) {
Ok(ids) => ids,
Err(err_msg) => {
return Err(PyErr::new::<ValueError, _>(err_msg));
}
};
let opts = RouteOpts {
systems: hops,
range: Some(range),
file_path: PathBuf::from(path),
precomp_file: precomp.map(PathBuf::from),
callback: Box::new(callback_wrapped),
mode,
factor: greedyness,
precompute: false,
permute,
keep_first,
keep_last,
primary,
workers: num_workers,
};
match route(opts) {
Ok(Some(route)) => {
let hops = route.iter().map(|hop| {
let pos = PyList::new(py, hop.pos.iter());
let elem = PyDict::new(py);
elem.set_item("star_type", hop.star_type.clone()).unwrap();
elem.set_item("system", hop.system.clone()).unwrap();
elem.set_item("body", hop.body.clone()).unwrap();
elem.set_item("distance", hop.distance).unwrap();
elem.set_item("pos", pos).unwrap();
elem
});
let lst = PyList::new(py, hops);
Ok(lst.to_object(py))
}
Ok(None) => Ok(py.None()),
Err(e) => Err(PyErr::new::<ValueError, _>(e)),
}
}
*/