From c1d6939c4348d5e3e30cdd62b848eae0982fcebb Mon Sep 17 00:00:00 2001 From: Luna Date: Tue, 24 Sep 2019 17:47:17 -0300 Subject: [PATCH] add the rest of (currently nonfunctional) emitting of if stmts - add compilation context and basics of type solver --- src/codegen.zig | 57 ++++++++++++++++++++++++++++++- src/comp_ctx.zig | 87 ++++++++++++++++++++++++++++++++++++++++++++++++ src/main.zig | 18 ++++++---- src/types.zig | 48 ++++++++++++++++++++++++++ 4 files changed, 202 insertions(+), 8 deletions(-) create mode 100644 src/comp_ctx.zig create mode 100644 src/types.zig diff --git a/src/codegen.zig b/src/codegen.zig index 485ef4e..f28943a 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -98,7 +98,7 @@ pub const Codegen = struct { }; } - fn emitStmt(self: *Codegen, builder: var, stmt: *const ast.Stmt) !void { + fn emitStmt(self: *Codegen, builder: var, stmt: *const ast.Stmt) anyerror!void { switch (stmt.*) { .Expr => |expr| _ = try self.emitExpr(builder, expr), @@ -111,6 +111,50 @@ pub const Codegen = struct { var cond = try self.emitExpr(builder, ifstmt.condition); var zero = mkLLVMBool(false); var icmp = llvm.LLVMBuildICmp(builder, llvm.LLVMIntPredicate.LLVMIntNE, cond, zero, c"ifcond"); + + var insert = llvm.LLVMGetInsertBlock(builder); + var function = llvm.LLVMGetBasicBlockParent(insert); + var global_ctx = llvm.LLVMGetGlobalContext(); + + var then_bb = llvm.LLVMAppendBasicBlock(function, c"then"); + var else_bb = llvm.LLVMCreateBasicBlockInContext(global_ctx, c"else"); + var merge_bb = llvm.LLVMCreateBasicBlockInContext(global_ctx, c"ifcont"); + + var condbr = llvm.LLVMBuildCondBr(builder, icmp, then_bb, else_bb); + + llvm.LLVMPositionBuilderAtEnd(builder, then_bb); + + // roughly translating to kaleidoscope's + // 'Value *ThenV = Then->codegen();' + for (ifstmt.then_branch.toSlice()) |then_stmt| { + try self.emitStmt(builder, &then_stmt); + } + + _ = llvm.LLVMBuildBr(builder, merge_bb); + then_bb = llvm.LLVMGetInsertBlock(builder); + + llvm.LLVMPositionBuilderAtEnd(builder, else_bb); + + // roughly translating to kaleidoscope's + // 'Else *ElseV = Else->codegen();' + if (ifstmt.else_branch) |else_block| { + for (else_block.toSlice()) |else_stmt| { + try self.emitStmt(builder, &else_stmt); + } + } + + _ = llvm.LLVMBuildBr(builder, merge_bb); + else_bb = llvm.LLVMGetInsertBlock(builder); + + llvm.LLVMPositionBuilderAtEnd(builder, merge_bb); + + var phi = llvm.LLVMBuildPhi(builder, llvm.LLVMVoidType(), c"iftmp"); + + var then_bb_val = llvm.LLVMBasicBlockAsValue(then_bb); + var else_bb_val = llvm.LLVMBasicBlockAsValue(else_bb); + + llvm.LLVMAddIncoming(phi, &then_bb_val, &then_bb, 1); + llvm.LLVMAddIncoming(phi, &else_bb_val, &else_bb, 1); }, else => { @@ -120,6 +164,16 @@ pub const Codegen = struct { } } + fn emitBlock(self: *Codegen, builder: var, block: ast.Block) !llvm.LLVMValueRef { + var entry = llvm.LLVMAppendBasicBlock(func, entry_lbl_cstr.ptr); + + var builder = llvm.LLVMCreateBuilder(); + llvm.LLVMPositionBuilderAtEnd(builder, entry); + for (block.toSlice()) |stmt| { + try self.emitStmt(builder, &stmt); + } + } + fn genNode( self: *Codegen, mod: llvm.LLVMModuleRef, @@ -173,6 +227,7 @@ pub const Codegen = struct { std.debug.warn("cgen: generated function '{}'\n", name); }, + else => { std.debug.warn("got unhandled Node {}\n", node.*); unreachable; diff --git a/src/comp_ctx.zig b/src/comp_ctx.zig new file mode 100644 index 0000000..79e5116 --- /dev/null +++ b/src/comp_ctx.zig @@ -0,0 +1,87 @@ +const std = @import("std"); + +pub const CompilationError = error{TypeError}; + +pub const SymbolTable = std.hash_map.StringHashMap(SymbolData); + +pub const SymbolUnderlyingTypeEnum = enum { + Integer32, + Integer64, + Bool, + CustomType, +}; + +pub const SymbolUnderlyingType = union(SymbolUnderlyingTypeEnum) { + Integer32: void, + Integer64: void, + Bool: void, + CustomType: []const u8, +}; + +// functions, for our purposes, other than symbols, have: +// - a return type +// - TODO parameters +pub const FunctionSymbol = struct { + return_type: SymbolUnderlyingType, + symbols: SymbolTable, +}; + +// structs are hashmaps pointing to SymbolUnderlyingType +pub const UnderlyingTypeMap = std.hash_map.StringHashMap(SymbolUnderlyingType); + +// enums have lists of identifiers +pub const IdentifierList = std.ArrayList([]const u8); + +// TODO const +pub const SymbolType = enum { + Function, + Struct, + Enum, +}; + +pub const SymbolData = union(SymbolType) { + Function: FunctionSymbol, + Struct: UnderlyingTypeMap, + Enum: IdentifierList, +}; + +const builtin_type_identifiers = [_][]const u8{ "i32", "i64", "bool" }; + +const builtin_types = [_]SymbolUnderlyingTypeEnum{ .Integer32, .Integer64, .Bool }; + +pub const CompilationContext = struct { + allocator: *std.mem.Allocator, + symbol_table: SymbolTable, + + pub fn init(allocator: *std.mem.Allocator) CompilationContext { + return CompilationContext{ + .allocator = allocator, + .symbol_table = SymbolTable.init(allocator), + }; + } + + pub fn solveTypeEnum( + self: *@This(), + typ_ident: []const u8, + ) SymbolUnderlyingTypeEnum { + inline for (builtin_type_identifiers) |typ, idx| { + if (std.mem.eql(u8, typ, typ_ident)) return builtin_types[idx]; + } + + return .CustomType; + } + + pub fn solveType( + self: *@This(), + typ_ident: []const u8, + ) SymbolUnderlyingType { + const typ_enum_val = self.solveTypeEnum(typ_ident); + + return switch (typ_enum_val) { + .Integer32 => SymbolUnderlyingType{ .Integer32 = {} }, + .Integer64 => SymbolUnderlyingType{ .Integer64 = {} }, + .Bool => SymbolUnderlyingType{ .Bool = {} }, + .CustomType => SymbolUnderlyingType{ .CustomType = typ_ident }, + }; + } +}; diff --git a/src/main.zig b/src/main.zig index 1000164..a8dd326 100644 --- a/src/main.zig +++ b/src/main.zig @@ -4,6 +4,7 @@ const scanners = @import("scanners.zig"); const parsers = @import("parsers.zig"); const printer = @import("ast_printer.zig"); const codegen = @import("codegen.zig"); +const types = @import("types.zig"); pub const Result = enum { Ok, @@ -38,19 +39,22 @@ pub fn run(allocator: *std.mem.Allocator, slice: []const u8) !Result { scan.reset(); var parser = parsers.Parser.init(allocator, &scan); - var root = try parser.parse(); - if (root == null) { + var root_opt = try parser.parse(); + + if (root_opt == null) { return Result.ParseError; } - std.debug.warn("parse tree\n"); - printer.printNode(root.?, 0); + var root = root_opt.?; - // TODO type pass - // TODO variable pass + std.debug.warn("parse tree\n"); + printer.printNode(root, 0); + + var solver = types.TypeSolver.init(allocator); + var ctx = solver.pass(root); var cgen = codegen.Codegen.init(allocator); - try cgen.gen(root.?); + try cgen.gen(root); return Result.Ok; } diff --git a/src/types.zig b/src/types.zig new file mode 100644 index 0000000..b4084f1 --- /dev/null +++ b/src/types.zig @@ -0,0 +1,48 @@ +const std = @import("std"); +const ast = @import("ast.zig"); + +const comp = @import("comp_ctx.zig"); + +pub const TypeSolver = struct { + allocator: *std.mem.Allocator, + + pub fn init(allocator: *std.mem.Allocator) TypeSolver { + return TypeSolver{ .allocator = allocator }; + } + + pub fn nodePass( + self: *@This(), + ctx: *comp.CompilationContext, + node: *ast.Node, + ) void { + switch (node.*) { + .Root => unreachable, + .FnDecl => |decl| { + var ret_type = ctx.solveType(decl.return_type.lexeme); + + // TODO maybe solve when custom? + + std.debug.warn("fn {} type: {}\n", decl.func_name.lexeme, ret_type); + + // ctx.insertFn(decl.name.lexeme, ret_type); + }, + + // TODO infer type of expr in const + //.ConstDecl => {}, + //.Struct => {}, + //.Enum => {}, + else => unreachable, + } + } + + pub fn pass(self: *@This(), root: *ast.Node) comp.CompilationContext { + var ctx = comp.CompilationContext.init(self.allocator); + + var slice = root.Root.toSlice(); + for (slice) |_, idx| { + self.nodePass(&ctx, &slice[idx]); + } + + return ctx; + } +};