diff --git a/src/comp_ctx.zig b/src/comp_ctx.zig index 3549762..e04b18a 100644 --- a/src/comp_ctx.zig +++ b/src/comp_ctx.zig @@ -128,7 +128,7 @@ pub const CompilationContext = struct { allocator: *std.mem.Allocator, symbol_table: SymbolTable, - current_function: ?*FunctionSymbol = null, + cur_function: ?*FunctionSymbol = null, current_scope: ?*Scope = null, pub fn init(allocator: *std.mem.Allocator) CompilationContext { @@ -172,6 +172,10 @@ pub const CompilationContext = struct { self.current_scope = self.current_scope.?.parent; } + pub fn setCurrentFunction(self: *@This(), func_ctx: ?FunctionAnalysisContext) void { + self.cur_function = func_ctx; + } + /// Solve a given type as a string into a SymbolUnderlyingTypeEnum /// This does not help if you want a full SymbolUnderlyingType, use /// solveType() for that. @@ -227,7 +231,9 @@ pub const CompilationContext = struct { _ = try type_map.put(param.name.lexeme, param_types.at(idx)); } - _ = try self.symbol_table.put(decl.func_name.lexeme, SymbolData{ + const lex = decl.func_name.lexeme; + + _ = try self.symbol_table.put(lex, SymbolData{ .Function = FunctionSymbol{ .decl = decl, .return_type = ret_type, @@ -235,6 +241,9 @@ pub const CompilationContext = struct { .scope = scope, }, }); + + var kv = self.symbol_table.get(lex); + self.cur_function = &kv.?.value.Function; } pub fn insertEnum(self: *@This(), enu: ast.Enum) !void { diff --git a/src/types.zig b/src/types.zig index ab61922..d1d7a3b 100644 --- a/src/types.zig +++ b/src/types.zig @@ -238,7 +238,7 @@ pub const TypeSolver = struct { // pull a hack with err contexts, lol) .Return => |ret| { var ret_stmt_type = try self.resolveExprType(ctx, ret.value); - // TODO check if ret_stmt_type == ctx.cur_function.return_type + try self.expectSymUnType(ret_stmt_type, ctx.cur_function.?.return_type); }, // If create two scopes for each branch of the if @@ -292,11 +292,15 @@ pub const TypeSolver = struct { self.setErrToken(null); self.setErrContext(null); + // always reset the contexts' current function + ctx.cur_function = null; + switch (node.*) { .Root => unreachable, .FnDecl => |decl| { self.setErrToken(decl.return_type); - self.setErrContext("function {}", decl.func_name.lexeme); + const name = decl.func_name.lexeme; + self.setErrContext("function {}", name); var ret_type = self.resolveGlobalType(ctx, decl.return_type.lexeme); std.debug.warn("start analysis of fn {} ret_type: {}\n", decl.func_name.lexeme, ret_type); @@ -313,6 +317,21 @@ pub const TypeSolver = struct { var scope = try comp.Scope.create(self.allocator, null, "function"); errdefer scope.deinit(); + // we intentionally insert the function so that: + // - we can do return statement validation + // - we have parameter types fully analyzed + if (ret_type != null and parameters.len == decl.params.len) { + try ctx.insertFn(decl, ret_type.?, parameters, scope); + } else { + if (ret_type != null) + self.doError("Return type was not fully resolved"); + + if (parameters.len != decl.params.len) + self.doError("Fully analyzed {} parameters, wanted {}", parameters.len, decl.params.len); + + return CompileError.TypeError; + } + // we must always start from a null current scope, // functions inside functions are not allowed std.debug.assert(ctx.current_scope == null); @@ -325,10 +344,6 @@ pub const TypeSolver = struct { // it should be null when we dump from a function. always ctx.dumpScope(); std.debug.assert(ctx.current_scope == null); - - if (ret_type != null and parameters.len == decl.params.len) { - try ctx.insertFn(decl, ret_type.?, parameters, scope); - } }, .Struct => |struc| {