diff --git a/src/codegen.zig b/src/codegen.zig index 06ea48e..b1542ff 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -23,12 +23,20 @@ fn mkLLVMBool(val: bool) llvm.LLVMValueRef { } } +pub const LLVMTable = std.StringHashMap(llvm.LLVMValueRef); +pub const LLVMValueList = std.ArrayList(llvm.LLVMValueRef); + pub const Codegen = struct { allocator: *std.mem.Allocator, ctx: *comp.CompilationContext, + llvm_table: LLVMTable, pub fn init(allocator: *std.mem.Allocator, ctx: *comp.CompilationContext) Codegen { - return Codegen{ .allocator = allocator, .ctx = ctx }; + return Codegen{ + .allocator = allocator, + .ctx = ctx, + .llvm_table = LLVMTable.init(allocator), + }; } fn typeToLLVM(self: *@This(), typ: comp.SymbolUnderlyingType) !llvm.LLVMTypeRef { @@ -61,7 +69,11 @@ pub const Codegen = struct { }; } - fn emitExpr(self: *Codegen, builder: var, expr: *const ast.Expr) anyerror!llvm.LLVMValueRef { + fn emitExpr( + self: *Codegen, + builder: var, + expr: *const ast.Expr, + ) anyerror!llvm.LLVMValueRef { // TODO if expr is Variable, we should do a variable lookup // in a symbol table, going up in scope, etc. @@ -154,6 +166,36 @@ pub const Codegen = struct { } }, + .Call => |call| { + const name = call.callee.*.Variable.lexeme; + //var sym = try self.ctx.fetchGlobalSymbol(func_name, .Function); + + var llvm_func = self.llvm_table.get(name); + if (llvm_func == null) { + std.debug.warn("Function '{}' not found\n", name); + return CompileError.EmitError; + } + + // TODO args + var args = LLVMValueList.init(self.allocator); + errdefer args.deinit(); + + for (call.arguments.toSlice()) |arg_expr| { + var arg_val = try self.emitExpr(builder, &arg_expr); + try args.append(arg_val); + } + + var args_slice = args.toSlice(); + + return llvm.LLVMBuildCall( + builder, + llvm_func.?.value, + args_slice.ptr, + @intCast(c_uint, args_slice.len), + c"call", + ); + }, + else => { std.debug.warn("Got unexpected expr {}\n", ast.ExprType(expr.*)); return CompileError.EmitError; @@ -289,6 +331,7 @@ pub const Codegen = struct { ); var func = llvm.LLVMAddFunction(mod, name_cstr.ptr, llvm_ret_type); + _ = try self.llvm_table.put(name, func); var buf = try self.allocator.alloc(u8, 512); var entry_lbl = try std.fmt.bufPrint(buf, "fn_{}_entry", name); @@ -362,10 +405,10 @@ pub const Codegen = struct { return CompileError.LLVMError; } - if (llvm.LLVMWriteBitcodeToFile(mod, c"awoo.bc") != 0) { - std.debug.warn("error writing bitcode to file: {}\n", sliceify(err)); - return CompileError.LLVMError; - } + //if (llvm.LLVMWriteBitcodeToFile(mod, c"awoo.bc") != 0) { + // std.debug.warn("error writing bitcode to file: {}\n", sliceify(err)); + // return CompileError.LLVMError; + //} std.debug.warn("cgen: verify llvm module\n"); _ = llvm.LLVMVerifyModule( diff --git a/src/comp_ctx.zig b/src/comp_ctx.zig index 2cacc07..827c460 100644 --- a/src/comp_ctx.zig +++ b/src/comp_ctx.zig @@ -181,4 +181,26 @@ pub const CompilationContext = struct { pub fn insertConst(self: *@This(), constdecl: ast.SingleConst, typ: SymbolUnderlyingType) !void { _ = try self.symbol_table.put(constdecl.name.lexeme, SymbolData{ .Const = typ }); } + + pub fn fetchGlobalSymbol( + self: *@This(), + identifier: []const u8, + typ: SymbolType, + ) !SymbolData { + var sym_kv = self.symbol_table.get(identifier); + if (sym_kv == null) { + std.debug.warn("Unknown {} '{}'\n", typ, identifier); + return CompilationError.TypeError; + } + + var value = sym_kv.?.value; + var sym_typ = SymbolType(value); + + if (sym_typ != typ) { + std.debug.warn("Expected {}, got {}\n", sym_typ, typ); + return CompilationError.TypeError; + } + + return sym_kv.?.value; + } }; diff --git a/src/types.zig b/src/types.zig index f170223..8afa4c3 100644 --- a/src/types.zig +++ b/src/types.zig @@ -150,14 +150,8 @@ pub const TypeSolver = struct { std.debug.assert(ast.ExprType(call.callee.*) == .Variable); const func_name = call.callee.*.Variable.lexeme; - var sym_kv = ctx.symbol_table.get(func_name); - if (sym_kv == null) { - self.doError("Unknown function '{}'\n", func_name); - return CompileError.TypeError; - } - - std.debug.assert(comp.SymbolType(sym_kv.?.value) == .Function); - var func_sym = sym_kv.?.value.Function; + var symbol = try ctx.fetchGlobalSymbol(func_name, .Function); + var func_sym = symbol.Function; // TODO check parameter type mismatches between // call.arguments and func_sym.parameters