120 lines
3.7 KiB
Python
120 lines
3.7 KiB
Python
import pytest
|
|
import random
|
|
import csv
|
|
from tempfile import mkstemp
|
|
from pathlib import Path
|
|
|
|
|
|
def get_mult(star_type):
    """Return the multiplier assigned to a star type name.

    The name is matched by prefix, in order: names starting with
    ``"Neutron"`` map to 4.0, names starting with ``"White Dwarf"`` map to
    1.5, and every other name maps to 1.0.
    """
    prefix_multipliers = (
        ("Neutron", 4.0),
        ("White Dwarf", 1.5),
    )
    for prefix, multiplier in prefix_multipliers:
        if star_type.startswith(prefix):
            return multiplier
    # No special prefix matched: neutral multiplier.
    return 1.0
|
|
|
|
|
|
def gen_pos(p_distrib):
    """Draw one random coordinate per entry of *p_distrib*.

    Each entry ``extent`` produces a sample from
    ``random.triangular(-extent, extent)``, so the result is a list of the
    same length as *p_distrib*, with samples drawn in input order.
    """
    return [random.triangular(-extent, extent) for extent in p_distrib]
|
|
|
|
|
|
def make_stars(num, p_distrib):
    """Generate random star records for a synthetic set of systems.

    Each yielded record is a list in the order
    ``[id, star_type, system, body, mult, distance, x, y, z]``.

    For every system one primary star (distance 0, body suffix ``Star 0``)
    is produced, followed by 0 to 4 companion stars at a random distance
    between 100 and 10000. All stars of a system share the system name and
    the position drawn by ``gen_pos``. The id counter advances once per
    record, and a new system is named after the counter value at the moment
    it is started.

    Args:
        num: Id bound. New systems are started while the id counter is
            below this value; because the bound is only checked between
            systems, up to four extra companion records may be yielded.
        p_distrib: Three per-axis extents passed to ``gen_pos``.

    Yields:
        list: One record per star, as described above.
    """
    # NOTE: the original list also contained the literal "star_type", which
    # is the CSV column header rather than a star type; it is left out here.
    star_types = [
        "A (Blue-White) Star",
        "A (Blue-White super giant) Star",
        "B (Blue-White) Star",
        "B (Blue-White super giant) Star",
        "Black Hole",
        "CJ Star",
        "CN Star",
        "C Star",
        "F (White) Star",
        "F (White super giant) Star",
        "G (White-Yellow) Star",
        "G (White-Yellow super giant) Star",
        "Herbig Ae/Be Star",
        "K (Yellow-Orange giant) Star",
        "K (Yellow-Orange) Star",
        "L (Brown dwarf) Star",
        "M (Red dwarf) Star",
        "M (Red giant) Star",
        "M (Red super giant) Star",
        "MS-type Star",
        "Neutron Star",
        "O (Blue-White) Star",
        "S-type Star",
        "Supermassive Black Hole",
        "T (Brown dwarf) Star",
        "T Tauri Star",
        "White Dwarf (DAB) Star",
        "White Dwarf (DA) Star",
        "White Dwarf (DAV) Star",
        "White Dwarf (DAZ) Star",
        "White Dwarf (DB) Star",
        "White Dwarf (DBV) Star",
        "White Dwarf (DBZ) Star",
        "White Dwarf (DC) Star",
        "White Dwarf (DCV) Star",
        "White Dwarf (DQ) Star",
        "White Dwarf (D) Star",
        "Wolf-Rayet C Star",
        "Wolf-Rayet NC Star",
        "Wolf-Rayet N Star",
        "Wolf-Rayet O Star",
        "Wolf-Rayet Star",
        "Y (Brown dwarf) Star",
    ]
    id_n = 0
    while id_n < num:
        name = "System {}".format(id_n)

        # Primary star of the system.
        body = "System {} Star {}".format(id_n, 0)
        distance = 0
        # A single draw feeds both the recorded type and its multiplier so
        # the two columns always agree.
        star_type = random.choice(star_types)
        mult = get_mult(star_type)
        x, y, z = gen_pos(p_distrib)
        yield [id_n, star_type, name, body, mult, distance, x, y, z]
        id_n += 1

        # Companion stars: same system name and position as the primary.
        for sub_id in range(random.randint(0, 4)):
            star_type = random.choice(star_types)
            mult = get_mult(star_type)
            distance = random.randint(100, 10000)
            # The body label uses the running record id, as before.
            body = "System {} Star {}".format(id_n, sub_id + 1)
            yield [id_n, star_type, name, body, mult, distance, x, y, z]
            id_n += 1
|
|
|
|
|
|
@pytest.fixture(scope="module")
def stars_path():
    """Write a large random star CSV and provide its path to tests.

    The fixture writes the records produced by ``make_stars`` to a
    temporary ``stars_*.csv`` file with the header
    ``id, star_type, system, body, mult, distance, x, y, z``.

    Yields:
        tuple: ``(path, rand_sys)`` where ``path`` is the resolved CSV path
        as a string and ``rand_sys`` is a list of ten ``"System N"`` names
        built from distinct random ids below the star count.

    On teardown the CSV file and a sibling file with the ``.idx`` suffix
    (if one exists) are removed.
    """
    num_stars = int(1e7)
    p_distrib = [5000, 5000, 500]
    fields = ["id", "star_type", "system", "body", "mult"]
    fields += ["distance", "x", "y", "z"]

    fd, raw_path = mkstemp(suffix=".csv", prefix="stars_", text=True)
    filename = Path(raw_path)
    idx = filename.with_suffix(".idx")
    try:
        # newline="" lets the csv writer control line endings, as the csv
        # module documentation requires for files handed to a writer.
        with open(fd, "w", encoding="utf-8", newline="") as tmpfile:
            csv_writer = csv.DictWriter(tmpfile, fields)
            csv_writer.writeheader()
            # Generator expression keeps the rows streaming instead of
            # materializing all of them in memory.
            csv_writer.writerows(
                dict(zip(fields, row))
                for row in make_stars(num_stars, p_distrib)
            )

        # Ten distinct ids, drawn without replacement.
        # NOTE(review): make_stars appears to advance its id for every row,
        # including companion stars, so some of these "System N" names may
        # not occur in the file -- confirm the consumer tolerates that.
        sys_ids = random.sample(range(num_stars), 10)
        rand_sys = list(map("System {}".format, sys_ids))

        yield str(filename.resolve()), rand_sys
    finally:
        # Clean up even when writing the file failed part-way through.
        for leftover in (filename, idx):
            if leftover.exists():
                leftover.unlink()
|