const std = @import("std");
const ast = @import("ast.zig");
const comp = @import("comp_ctx.zig");
const CompileError = @import("codegen.zig").CompileError;
const Token = @import("tokens.zig").Token;
const SymbolUnderlyingType = comp.SymbolUnderlyingType;

/// Walks top-level AST nodes, resolves the types they reference and
/// registers the resolved declarations in a `comp.CompilationContext`.
///
/// Most type errors are reported through `doError` (which sets `hadError`)
/// instead of being returned, so one pass can report several of them.
/// NOTE(review): callers presumably inspect `hadError` after `pass`.
pub const TypeSolver = struct {
    allocator: *std.mem.Allocator,

    // error handling

    /// Description of what is currently being checked (e.g. "struct Foo").
    /// Allocated with `allocator`; replaced and freed by `setErrContext`.
    err_ctx: ?[]const u8 = null,
    /// Token whose line is printed with the next error, if set.
    err_tok: ?Token = null,
    /// Set to true by `doError`; never reset by this struct.
    hadError: bool = false,

    pub fn init(allocator: *std.mem.Allocator) TypeSolver {
        return TypeSolver{ .allocator = allocator };
    }

    /// Replace the error context string; passing `null` clears it.
    ///
    /// The previous context string is freed. Formatting is best-effort:
    /// if allocation fails the context is left empty rather than crashing.
    fn setErrContext(self: *@This(), comptime fmt: ?[]const u8, args: ...) void {
        if (self.err_ctx) |old_ctx| self.allocator.free(old_ctx);
        self.err_ctx = null;

        if (fmt == null) return;

        // allocPrint sizes the buffer to fit the formatted text, so long
        // identifiers can no longer overflow a fixed-size buffer.
        self.err_ctx = std.fmt.allocPrint(self.allocator, fmt.?, args) catch null;
    }

    fn setErrToken(self: *@This(), tok: ?Token) void {
        self.err_tok = tok;
    }

    /// Print a type error to stderr (with the current token line and
    /// context, when set) and mark this solver as having failed.
    fn doError(self: *@This(), comptime fmt: []const u8, args: ...) void {
        self.hadError = true;

        std.debug.warn("type error");
        if (self.err_tok) |tok| {
            std.debug.warn(" at line {}", tok.line);
        }

        if (self.err_ctx) |ctx| {
            std.debug.warn(" on {}", ctx);
        }

        std.debug.warn("\n\t");
        std.debug.warn(fmt, args);
        std.debug.warn("\n");
    }

    /// Resolve a type name in global scope.
    ///
    /// Names that `ctx.solveType` reports as `.OpaqueType` are looked up in
    /// the context's symbol table and narrowed to `.Struct` or `.Enum`.
    /// Returns null (after reporting via `doError`) when the name is not in
    /// the symbol table or refers to something other than a struct or enum.
    fn resolveGlobalType(
        self: *@This(),
        ctx: *comp.CompilationContext,
        identifier: []const u8,
    ) ?SymbolUnderlyingType {
        // assume the identifier references a builtin
        const typ = ctx.solveType(identifier);

        switch (typ) {
            .OpaqueType => |val| {
                // solve for opaque so it isn't opaque
                const kv = ctx.symbol_table.get(val) orelse {
                    self.doError("Unknown type: '{}'", val);
                    return null;
                };

                return switch (kv.value) {
                    .Struct => SymbolUnderlyingType{ .Struct = val },
                    .Enum => SymbolUnderlyingType{ .Enum = val },
                    else => blk: {
                        self.doError(
                            "expected struct or enum for type '{}', got {}",
                            val,
                            @tagName(kv.value),
                        );
                        break :blk null;
                    },
                };
            },
            else => return typ,
        }
    }

    /// Compute the type of an expression.
    ///
    /// Returns `CompileError.TypeError` (after reporting via `doError`) when
    /// a struct name is unknown or the literal/expression kind is not yet
    /// supported; errors from sub-expressions are propagated.
    pub fn resolveExprType(
        self: *@This(),
        ctx: *comp.CompilationContext,
        expr: *const ast.Expr,
    ) anyerror!SymbolUnderlyingType {
        switch (expr.*) {
            .Binary => |binary| {
                // `try` both operands so a type error on either side is
                // propagated instead of being silently discarded.
                const left_type = try self.resolveExprType(ctx, binary.left);
                const right_type = try self.resolveExprType(ctx, binary.right);

                return switch (binary.op) {
                    // all numeric operations return numeric types
                    // TODO check left_type and right_type as numeric
                    .Add, .Sub, .Mul, .Div, .Mod => left_type,

                    // comparisons return bools
                    .Greater, .GreaterEqual, .Less, .LessEqual => SymbolUnderlyingType{ .Bool = {} },

                    // all boolean ops return bools
                    .Equal, .And, .Or => SymbolUnderlyingType{ .Bool = {} },
                };
            },

            // both unary operators currently yield their operand's type
            .Unary => |unary| {
                const right_type = try self.resolveExprType(ctx, unary.right);
                return switch (unary.op) {
                    .Negate => right_type,
                    .Not => right_type,
                };
            },

            .Literal => |literal| {
                switch (literal) {
                    .Bool => return SymbolUnderlyingType{ .Bool = {} },

                    // TODO determine its i64 depending of parseInt results
                    .Integer => return SymbolUnderlyingType{ .Integer32 = {} },

                    // Unsupported literal kinds come from user source, so
                    // report them instead of hitting `unreachable`.
                    else => {
                        self.doError("unsupported literal type");
                        return CompileError.TypeError;
                    },
                }
            },

            .Grouping => |group_expr| return try self.resolveExprType(ctx, group_expr),

            .Struct => |struc| {
                const name = struc.name.lexeme;
                return self.resolveGlobalType(ctx, name) orelse {
                    self.doError("Unknown struct name '{}'", name);
                    return CompileError.TypeError;
                };
            },

            // TODO variable resolution
            // TODO Get (for structs and enums)
            else => {
                // Not-yet-implemented expression kinds are reachable from
                // user source; fail with a type error like nodePass does.
                self.doError("TODO resolve expr {}", ast.ExprType(expr.*));
                return CompileError.TypeError;
            },
        }
    }

    /// Type-check one top-level node and register it in `ctx`.
    ///
    /// Function and struct declarations whose types do not all resolve are
    /// reported via `doError` and not inserted, so the pass can continue and
    /// find more errors. Const declarations propagate expression errors;
    /// unsupported node kinds return `CompileError.TypeError`.
    pub fn nodePass(
        self: *@This(),
        ctx: *comp.CompilationContext,
        node: *ast.Node,
    ) !void {
        self.setErrToken(null);
        self.setErrContext(null);

        switch (node.*) {
            .Root => unreachable,

            .FnDecl => |decl| {
                self.setErrToken(decl.return_type);
                self.setErrContext("function {}", decl.func_name.lexeme);

                const ret_type = self.resolveGlobalType(ctx, decl.return_type.lexeme);
                std.debug.warn("resolved fn {} type: {}\n", decl.func_name.lexeme, ret_type);

                var parameters = comp.TypeList.init(self.allocator);
                for (decl.params.toSlice()) |param| {
                    // point errors at the offending parameter's type token
                    self.setErrToken(param.typ);
                    const param_type = self.resolveGlobalType(ctx, param.typ.lexeme) orelse continue;
                    try parameters.append(param_type);
                }

                // TODO symbols and scope resolution, that's
                // its own can of worms
                var symbols = comp.SymbolTable.init(self.allocator);

                // TODO go through body, resolve statements, expressions
                // and everything else

                // only insert when the return type and every parameter resolved
                if (ret_type != null and parameters.len == decl.params.len) {
                    try ctx.insertFn(decl, ret_type.?, parameters, symbols);
                }
            },

            .Struct => |struc| {
                self.setErrToken(struc.name);
                self.setErrContext("struct {}", struc.name.lexeme);

                var types = comp.TypeList.init(self.allocator);

                for (struc.fields.toSlice()) |field| {
                    self.setErrToken(field.name);
                    const field_type = self.resolveGlobalType(ctx, field.typ.lexeme) orelse continue;
                    try types.append(field_type);
                }

                // only determine struct as fully resolved
                // when length of declared types == length of resolved types

                // we don't return type errors from the main loop so we can
                // keep going and find more type errors
                if (types.len == struc.fields.len)
                    try ctx.insertStruct(struc, types);
            },

            // TODO change enums to u32
            .Enum => |enu| {
                self.setErrToken(enu.name);
                self.setErrContext("enum {}", enu.name.lexeme);
                try ctx.insertEnum(enu);
            },

            .ConstDecl => |constlist| {
                for (constlist.toSlice()) |constdecl| {
                    self.setErrToken(constdecl.name);
                    self.setErrContext("const {}", constdecl.name.lexeme);

                    const expr_type = try self.resolveExprType(ctx, constdecl.expr);
                    try ctx.insertConst(constdecl, expr_type);
                }
            },

            else => {
                std.debug.warn("TODO type analysis of {}\n", node.*);
                return CompileError.TypeError;
            },
        }
    }

    /// Run `nodePass` over every child of a `.Root` node and return the
    /// populated compilation context. An error returned by `nodePass`
    /// aborts the pass and is propagated.
    pub fn pass(self: *@This(), root: *ast.Node) !comp.CompilationContext {
        var ctx = comp.CompilationContext.init(self.allocator);

        for (root.Root.toSlice()) |*node| {
            try self.nodePass(&ctx, node);
        }

        return ctx;
    }
};