rayoko/src/codegen.zig

249 lines
8.9 KiB
Zig

const std = @import("std");
const ast = @import("ast.zig");
const llvm = @import("llvm.zig");
/// Views an optional null-terminated C string as a Zig slice.
/// A null pointer yields the empty string. The result aliases the input
/// buffer; nothing is copied or allocated.
fn sliceify(non_slice: ?[*]const u8) []const u8 {
    const ptr = non_slice orelse return "";
    const length = std.mem.len(u8, ptr);
    return ptr[0..length];
}
/// Errors reported by the code generator.
pub const CompileError = error{
    // An AST node, statement, expression, literal or operator that codegen
    // does not support yet.
    EmitError,
    // A call into the LLVM C API reported failure.
    LLVMError,
    // A type name could not be mapped to an LLVM type.
    TypeError,
};
/// Maps a primitive type name from the source language to its LLVM type.
/// Used for both function parameter types and return types.
/// Only "i32" and "bool" are recognised. Custom types (e.g. structs) are not
/// handled yet; proper type resolution is still needed.
/// Prints a warning and returns CompileError.TypeError for any other name.
fn basicTypeToLLVM(type_name: []const u8) !llvm.LLVMTypeRef {
    if (std.mem.eql(u8, type_name, "i32")) {
        return llvm.LLVMInt32Type();
    } else if (std.mem.eql(u8, type_name, "bool")) {
        return llvm.LLVMInt1Type();
    } else {
        // Not "return type": this helper also resolves parameter types.
        std.debug.warn("Invalid type: {}\n", type_name);
        return CompileError.TypeError;
    }
}
pub const Codegen = struct {
    allocator: *std.mem.Allocator,

    /// Creates a code generator. `allocator` is used for the temporary
    /// null-terminated strings handed to the LLVM C API.
    pub fn init(allocator: *std.mem.Allocator) Codegen {
        return Codegen{ .allocator = allocator };
    }

    /// Emits IR for `expr` at the current position of `builder` (the
    /// LLVM builder created in genNode) and returns the resulting value.
    /// Supports integer, float and bool literals and the Add/Sub/Mul/And/Or
    /// binary operators. Anything else prints a warning and returns
    /// CompileError.EmitError.
    fn emitExpr(self: *Codegen, builder: var, expr: *const ast.Expr) anyerror!llvm.LLVMValueRef {
        // TODO if expr is Variable, we should do a variable lookup
        // in a symbol table, going up in scope, etc.
        // TODO VarDecl add things to the symbol table
        // TODO Assign modify symbol table
        // TODO Calls fetch symbol table, check arity of it at codegen level
        return switch (expr.*) {
            // TODO handle all literals, construct llvm values for them
            .Literal => |literal| blk: {
                break :blk switch (literal) {
                    .Integer => |val| blk2: {
                        // LLVM only reads the text while building the constant,
                        // so the temporary C string is freed once it is built.
                        var val_cstr = try std.cstr.addNullByte(self.allocator, val);
                        defer self.allocator.free(val_cstr);
                        break :blk2 llvm.LLVMConstIntOfString(llvm.LLVMInt32Type(), val_cstr.ptr, 10);
                    },
                    .Float => |val| blk2: {
                        var val_cstr = try std.cstr.addNullByte(self.allocator, val);
                        defer self.allocator.free(val_cstr);
                        break :blk2 llvm.LLVMConstRealOfString(llvm.LLVMDoubleType(), val_cstr.ptr);
                    },
                    .Bool => |val| blk2: {
                        break :blk2 llvm.LLVMConstInt(llvm.LLVMInt1Type(), @boolToInt(val), 1);
                    },
                    // TODO other literals. Report them instead of hitting
                    // `unreachable`, which would be UB on user input in release builds.
                    else => {
                        std.debug.warn("Got unexpected literal {}\n", literal);
                        return CompileError.EmitError;
                    },
                };
            },
            .Binary => |binary| {
                var left = try self.emitExpr(builder, binary.left);
                var right = try self.emitExpr(builder, binary.right);
                return switch (binary.op) {
                    .Add => llvm.LLVMBuildAdd(builder, left, right, c"addtmp"),
                    .Sub => llvm.LLVMBuildSub(builder, left, right, c"subtmp"),
                    .Mul => llvm.LLVMBuildMul(builder, left, right, c"multmp"),
                    // TODO .Div needs LLVMBuildSDiv/UDiv/FDiv depending on operand type
                    .And => llvm.LLVMBuildAnd(builder, left, right, c"andtmp"),
                    .Or => llvm.LLVMBuildOr(builder, left, right, c"ortmp"),
                    else => {
                        std.debug.warn("Unexpected binary operator: '{}'\n", binary.op);
                        return CompileError.EmitError;
                    },
                };
            },
            // TODO codegen errors
            else => {
                std.debug.warn("Got unexpected expr {}\n", expr.*);
                return CompileError.EmitError;
            },
        };
    }

    /// Emits IR for a single statement.
    /// Expression statements discard their value. Return statements emit a
    /// `ret` of the evaluated value. Other statements print a warning and
    /// return CompileError.EmitError.
    fn emitStmt(self: *Codegen, builder: var, stmt: *const ast.Stmt) !void {
        switch (stmt.*) {
            .Expr => |expr| _ = try self.emitExpr(builder, expr),
            .Return => |ret| {
                var ret_expr = try self.emitExpr(builder, ret.value);
                _ = llvm.LLVMBuildRet(builder, ret_expr);
            },
            else => {
                std.debug.warn("Got unexpected statement {}\n", stmt.*);
                return CompileError.EmitError;
            },
        }
    }

    /// Generates the LLVM function for a top-level FnDecl node into `mod`.
    /// Creates an "entry" block and emits every body statement into it.
    /// Root nodes are a caller bug (panic). Other unhandled nodes print a
    /// warning and return CompileError.EmitError.
    fn genNode(
        self: *Codegen,
        mod: llvm.LLVMModuleRef,
        node: *const ast.Node,
    ) !void {
        switch (node.*) {
            .Root => @panic("Should not have gotten Root"),
            .FnDecl => |decl| {
                const name = decl.func_name.lexeme;
                // LLVM copies the function name, so the C string is only
                // needed until LLVMAddFunction returns.
                const name_cstr = try std.cstr.addNullByte(self.allocator, name);
                defer self.allocator.free(name_cstr);

                const fn_ret_type = decl.return_type.lexeme;

                // LLVMFunctionType copies the parameter type array, so the
                // list can be released when this scope ends.
                var param_types = llvm.LLVMTypeList.init(self.allocator);
                defer param_types.deinit();
                for (decl.params.toSlice()) |param| {
                    // TODO better type resolution
                    try param_types.append(try basicTypeToLLVM(param.typ.lexeme));
                }

                var fn_type = llvm.LLVMFunctionType(
                    try basicTypeToLLVM(fn_ret_type),
                    param_types.toSlice().ptr,
                    @intCast(c_uint, param_types.len),
                    0,
                );
                var func = llvm.LLVMAddFunction(mod, name_cstr.ptr, fn_type);
                var entry = llvm.LLVMAppendBasicBlock(func, c"entry");

                var builder = llvm.LLVMCreateBuilder();
                defer llvm.LLVMDisposeBuilder(builder);
                llvm.LLVMPositionBuilderAtEnd(builder, entry);

                for (decl.body.toSlice()) |stmt| {
                    // TODO custom function context for us
                    try self.emitStmt(builder, &stmt);
                }
                std.debug.warn("cgen: fn decl done\n");
            },
            else => {
                // Report instead of `unreachable`: the AST comes from user input.
                std.debug.warn("got unhandled Node {}\n", node.*);
                return CompileError.EmitError;
            },
        }
    }

    /// Generates code for every child of `root` (a Root node) into module "awoo".
    /// Then it verifies the module, writes bitcode to "awoo.bc", and emits a
    /// native object file to "outpath.o".
    /// Returns CompileError.LLVMError if an LLVM step fails.
    pub fn gen(self: *Codegen, root: *ast.Node) !void {
        std.debug.warn("cgen: start gen\n");
        _ = llvm.LLVMInitializeNativeTarget();

        var mod = llvm.LLVMModuleCreateWithName(c"awoo").?;
        // Ownership of the module passes to LLVM once the execution engine is
        // requested. After that it must not be disposed here as well.
        var mod_owned = true;
        defer if (mod_owned) llvm.LLVMDisposeModule(mod);

        for (root.Root.toSlice()) |child| {
            std.debug.warn("cgen: gen child {}\n", child);
            try self.genNode(mod, &child);
        }

        // LLVM calls that report messages allocate them into `err`. Each one
        // is released before `err` is reused, and the last one on scope exit.
        var err: ?[*]u8 = null;
        defer llvm.LLVMDisposeMessage(err);

        // LLVMAbortProcessAction aborts the process on an invalid module, so
        // only the (possibly empty) message needs releasing afterwards.
        _ = llvm.LLVMVerifyModule(
            mod,
            llvm.LLVMVerifierFailureAction.LLVMAbortProcessAction,
            &err,
        );
        llvm.LLVMDisposeMessage(err);
        err = null;

        // LLVMWriteBitcodeToFile does not produce an error message.
        if (llvm.LLVMWriteBitcodeToFile(mod, c"awoo.bc") != 0) {
            std.debug.warn("error writing bitcode to file\n");
            return CompileError.LLVMError;
        }

        llvm.LLVMInitializeAllTargetInfos();
        llvm.LLVMInitializeAllTargets();
        llvm.LLVMInitializeAllTargetMCs();
        llvm.LLVMInitializeAllAsmParsers();
        llvm.LLVMInitializeAllAsmPrinters();

        var engine: llvm.LLVMExecutionEngineRef = undefined;
        // The engine builder takes the module whether or not creation succeeds.
        mod_owned = false;
        if (llvm.LLVMCreateExecutionEngineForModule(&engine, mod, &err) != 0) {
            std.debug.warn("failed to create execution engine: {}\n", sliceify(err));
            return CompileError.LLVMError;
        }
        // Disposing the engine also releases the module and target machine it owns.
        defer llvm.LLVMDisposeExecutionEngine(engine);

        // Owned by the execution engine; must not be disposed separately.
        var machine = llvm.LLVMGetExecutionEngineTargetMachine(engine);
        var target = llvm.LLVMGetTargetMachineTarget(machine);

        var target_data = llvm.LLVMCreateTargetDataLayout(machine);
        defer llvm.LLVMDisposeTargetData(target_data);
        var data_layout = llvm.LLVMCopyStringRepOfTargetData(target_data);
        defer llvm.LLVMDisposeMessage(data_layout);
        llvm.LLVMSetDataLayout(mod, data_layout);

        // LLVMTargetMachineEmitToFile takes a mutable `char *`, hence the copy.
        var outpath_cstr = try std.cstr.addNullByte(self.allocator, "outpath.o");
        defer self.allocator.free(outpath_cstr);

        // The description is a static string; features and triple are
        // allocated and must be released.
        var desc = llvm.LLVMGetTargetDescription(target);
        var features = llvm.LLVMGetTargetMachineFeatureString(machine);
        defer llvm.LLVMDisposeMessage(features);
        var triple = llvm.LLVMGetTargetMachineTriple(machine);
        defer llvm.LLVMDisposeMessage(triple);
        std.debug.warn("target: {}\n", sliceify(desc));
        std.debug.warn("triple: {}\n", sliceify(triple));
        std.debug.warn("features: {}\n", sliceify(features));

        // TODO optionally also emit assembly (LLVMAssemblyFile) next to the object file.
        if (llvm.LLVMTargetMachineEmitToFile(
            machine,
            mod,
            outpath_cstr.ptr,
            llvm.LLVMCodeGenFileType.LLVMObjectFile,
            &err,
        ) != 0) {
            std.debug.warn("failed to emit to file: {}\n", sliceify(err));
            return CompileError.LLVMError;
        }
    }
};