101 lines
3.1 KiB
Python
101 lines
3.1 KiB
Python
from flask import Flask, jsonify
|
|
import uuid
|
|
import json
|
|
from webargs import fields, validate
|
|
from webargs.flaskparser import use_args
|
|
from flask_sqlalchemy import SQLAlchemy
|
|
from sqlalchemy_utils import Timestamp, generic_repr
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.orm import scoped_session, sessionmaker, relationship, backref
|
|
from sqlalchemy.types import Float, String, Boolean
|
|
|
|
# Flask application and its SQLAlchemy binding. Jobs are persisted to a local
# SQLite file; modification tracking is disabled since nothing here uses the
# Flask-SQLAlchemy signalling events.
app = Flask(__name__)
app.config.update(
    SQLALCHEMY_DATABASE_URI="sqlite:///jobs.db",
    SQLALCHEMY_TRACK_MODIFICATIONS=False,
)

db = SQLAlchemy(app)
|
|
|
|
|
|
@generic_repr
class Job(db.Model, Timestamp):
    """A persisted job row.

    Columns mirror the arguments accepted by the ``/route`` endpoint.
    ``systems`` is stored as a JSON-encoded string (the column type is a plain
    ``String``). The ``Timestamp`` mixin from sqlalchemy_utils contributes
    creation/update timestamp columns. ``done``, ``started`` and ``progress``
    are never written in this module -- presumably a worker updates them
    (TODO confirm).
    """

    # Primary key: a random UUID4 rendered as a string.
    id = db.Column(db.String, default=lambda: str(uuid.uuid4()), primary_key=True)
    jump_range = db.Column(db.Float, nullable=False)
    mode = db.Column(db.String, default="bfs")
    # JSON-encoded list of system names; decoded again by ``dict``.
    systems = db.Column(db.String)
    permute = db.Column(db.String, default=None, nullable=True)
    primary = db.Column(db.Boolean, default=False)
    factor = db.Column(db.Float, default=0.5)
    done = db.Column(db.DateTime, nullable=True, default=None)
    started = db.Column(db.DateTime, nullable=True, default=None)
    progress = db.Column(db.Float, default=0.0)

    # ============================================================

    @classmethod
    def new(cls, **kwargs):
        """Construct a Job from column keyword arguments, commit it, and return it.

        The row is added to ``db.session`` and committed immediately, so the
        generated ``id`` is available on the returned object.
        """
        obj = cls(**kwargs)
        db.session.add(obj)
        db.session.commit()
        print(obj)
        return obj

    @property
    def dict(self):
        """Return every column value keyed by column name (used for JSON output).

        ``systems`` is decoded from its stored JSON string back into a list.
        The column is nullable, so a NULL value is passed through as ``None``
        instead of letting ``json.loads(None)`` raise ``TypeError``.
        """
        ret = {col.name: getattr(self, col.name) for col in self.__table__.columns}
        if ret["systems"] is not None:
            ret["systems"] = json.loads(ret["systems"])
        return ret

    @dict.setter
    def dict(self, value):
        """Reject assignment: ``dict`` is a derived, read-only view of the row."""
        raise NotImplementedError

    # The setter was previously declared as ``def set_dict``, which created a
    # second property named ``set_dict`` instead of attaching the setter to
    # ``dict``. Keep that name as an alias so existing references still work.
    set_dict = dict
|
|
|
|
|
|
# Create any missing tables when the module is imported. An application
# context is pushed explicitly: Flask-SQLAlchemy 3.x requires one for
# create_all(), and pushing it is harmless under 2.x.
with app.app_context():
    db.create_all()
|
|
|
|
|
|
@app.errorhandler(422)
@app.errorhandler(400)
def handle_error(err):
    """Render 400/422 errors as JSON of the form ``{"errors": messages}``.

    When webargs' Flask parser aborts on a validation failure, the raised
    HTTPException carries a ``data`` dict with ``messages`` and optionally
    ``headers``. Other 400s (e.g. werkzeug's BadRequest for a malformed body)
    have no ``data`` attribute; those fall back to a generic message instead
    of crashing inside the error handler.

    Returns a ``(response, status)`` tuple, or ``(response, status, headers)``
    when the error supplied headers.
    """
    data = getattr(err, "data", None) or {}
    headers = data.get("headers", None)
    messages = data.get("messages", ["Invalid request."])
    body = jsonify({"errors": messages})
    if headers:
        return body, err.code, headers
    return body, err.code
|
|
|
|
|
|
# Request schema for /route. Optional fields get the same defaults as the
# corresponding Job columns.
_ROUTE_ARGS = {
    "jump_range": fields.Float(required=True),
    "mode": fields.String(
        missing="bfs", validate=validate.OneOf(["bfs", "greedy", "a-star"])
    ),
    "systems": fields.DelimitedList(fields.String, required=True),
    "permute": fields.String(
        missing=None,
        validate=validate.OneOf(["all", "keep_first", "keep_last", "keep_both"]),
    ),
    "primary": fields.Boolean(missing=False),
    "factor": fields.Float(missing=0.5),
}


@app.route("/route", methods=["GET", "POST"])
@use_args(_ROUTE_ARGS)
def route(args):
    """Persist a new Job from the validated request and return ``{"id": ...}``.

    ``args`` is the dict webargs parsed against ``_ROUTE_ARGS``.
    """
    # The systems column is a plain string, so store the parsed list as JSON.
    args["systems"] = json.dumps(args["systems"])
    for name in args:
        print(name, args[name])
    job = Job.new(**args)
    return jsonify({"id": job.id})
|
|
|
|
|
|
@app.route("/status/<uuid:job_id>")
def status(job_id):
    """Return the stored state of one job as JSON, or 404 if no such job exists."""
    # The URL converter yields a uuid.UUID, but Job ids are stored as strings.
    key = str(job_id)
    found = db.session.query(Job).get_or_404(key)
    payload = found.dict
    return jsonify(payload)
|
|
|
|
|
|
if __name__ == "__main__":
    import os

    # Werkzeug's interactive debugger can execute arbitrary code sent from a
    # browser, so it must not be on by default while listening on all
    # interfaces. Opt in for local development with FLASK_DEBUG=1.
    _debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(host="0.0.0.0", port=3777, debug=_debug)
|