558 lines
20 KiB
Zig
558 lines
20 KiB
Zig
const std = @import("std");
|
|
const ast = @import("ast.zig");
|
|
|
|
const comp = @import("comp_ctx.zig");
|
|
|
|
const CompileError = @import("codegen.zig").CompileError;
|
|
const Token = @import("tokens.zig").Token;
|
|
|
|
const SymbolUnderlyingType = comp.SymbolUnderlyingType;
|
|
|
|
pub const TypeSolver = struct {
|
|
allocator: *std.mem.Allocator,
|
|
|
|
// error handling
|
|
err_ctx: ?[]const u8 = null,
|
|
err_tok: ?Token = null,
|
|
hadError: bool = false,
|
|
|
|
err_ctx_buffer: []u8,
|
|
|
|
/// Build a TypeSolver that allocates through `allocator`.
///
/// Reserves the 512-byte scratch buffer that setErrContext formats the
/// error context into. Fails only if that allocation fails.
pub fn init(allocator: *std.mem.Allocator) !TypeSolver {
    const context_buffer = try allocator.alloc(u8, 512);

    return TypeSolver{
        .allocator = allocator,
        .err_ctx_buffer = context_buffer,
    };
}
|
|
|
|
/// Set (or clear, when `fmt` is null) the human readable context that
/// doError prefixes its reports with, e.g. "function foo".
///
/// The context is formatted into the solver-owned `err_ctx_buffer`, so
/// setting a new context overwrites the previous one.
fn setErrContext(self: *@This(), comptime fmt: ?[]const u8, args: ...) void {
    if (fmt == null) {
        self.err_ctx = null;
        return;
    }

    // The context embeds names taken from the program being compiled, so
    // it can legitimately outgrow the fixed buffer. Reporting an error
    // without context is better than invoking undefined behaviour.
    if (std.fmt.bufPrint(self.err_ctx_buffer, fmt.?, args)) |formatted| {
        self.err_ctx = formatted;
    } else |_| {
        self.err_ctx = null;
    }
}
|
|
|
|
/// Remember the token that later doError calls report the line of.
/// Passing null drops the line information from subsequent reports.
fn setErrToken(self: *@This(), tok: ?Token) void {
    self.err_tok = tok;
}
|
|
|
|
/// Report an analysis error through std.debug.warn and mark the solver
/// as having failed (`hadError`).
///
/// The report starts with "analysis error", followed by " at line N"
/// when an error token is set and " on <context>" when an error context
/// is set. The formatted message follows on its own tab-indented line.
fn doError(self: *@This(), comptime fmt: []const u8, args: ...) void {
    self.hadError = true;

    std.debug.warn("analysis error");

    if (self.err_tok != null) {
        std.debug.warn(" at line {}", self.err_tok.?.line);
    }

    if (self.err_ctx != null) {
        std.debug.warn(" on {}", self.err_ctx.?);
    }

    std.debug.warn("\n\t");
    std.debug.warn(fmt, args);
    std.debug.warn("\n");
}
|
|
|
|
/// Resolve a type name in global scope.
///
/// Simple types are returned exactly as `ctx.solveType` gives them. When
/// that yields an OpaqueType, the name is looked up in the symbol table
/// and turned into the matching Struct or Enum type.
///
/// Returns null (after reporting through doError) when the name is not in
/// the symbol table or refers to something that is neither struct nor enum.
fn resolveGlobalType(
    self: *@This(),
    ctx: *comp.CompilationContext,
    identifier: []const u8,
) ?SymbolUnderlyingType {
    const solved = ctx.solveType(identifier);

    switch (solved) {
        .OpaqueType => |type_name| {
            const entry = ctx.symbol_table.get(type_name) orelse {
                self.doError("Unknown type: '{}'", type_name);
                return null;
            };

            switch (entry.value.*) {
                .Struct => return SymbolUnderlyingType{ .Struct = type_name },
                .Enum => return SymbolUnderlyingType{ .Enum = type_name },

                else => {
                    self.doError(
                        "expected struct or enum for '{}', got {}",
                        type_name,
                        @tagName(comp.SymbolType(entry.value.*)),
                    );
                    return null;
                },
            }
        },

        // anything that is not opaque is already fully resolved
        else => return solved,
    }
}
|
|
|
|
/// Ensure `symbol_type` belongs to the category `wanted_type_enum`.
///
/// Only the tags are compared: two different structs (or enums) are
/// considered a match here. Prints a message and returns
/// CompileError.TypeError on mismatch.
pub fn expectSymUnTypeEnum(
    self: *@This(),
    symbol_type: comp.SymbolUnderlyingType,
    wanted_type_enum: comp.SymbolUnderlyingTypeEnum,
) !void {
    const got = comp.SymbolUnderlyingTypeEnum(symbol_type);
    if (got == wanted_type_enum) return;

    std.debug.warn("Expected {}, got {}\n", wanted_type_enum, got);
    return CompileError.TypeError;
}
|
|
|
|
/// Ensure `symbol_type` is one of the numeric types
/// (Integer32, Integer64 or Double).
///
/// Prints a message and returns CompileError.TypeError otherwise.
pub fn expectSymUnTypeNumeric(
    self: *@This(),
    symbol_type: comp.SymbolUnderlyingType,
) !void {
    const is_numeric = switch (symbol_type) {
        .Integer32, .Integer64, .Double => true,
        else => false,
    };
    if (is_numeric) return;

    std.debug.warn(
        "Expected numeric, got {}\n",
        comp.SymbolUnderlyingTypeEnum(symbol_type),
    );
    return CompileError.TypeError;
}
|
|
|
|
/// Ensure two composite type names are byte-for-byte equal.
///
/// `typ_enum` is only used to name the kind of type (struct/enum) in the
/// error report. Reports through doError and returns
/// CompileError.TypeError when the names differ.
fn compositeIdentifierEqual(
    self: *@This(),
    typ_enum: comp.SymbolUnderlyingTypeEnum,
    sym_ident: []const u8,
    expected_ident: []const u8,
) !void {
    if (std.mem.eql(u8, sym_ident, expected_ident)) return;

    self.doError(
        "Expected {} {}, got {}",
        @tagName(typ_enum),
        expected_ident,
        sym_ident,
    );
    return CompileError.TypeError;
}
|
|
|
|
/// Ensure `symbol_type` is the same type as `expected_type`.
///
/// The tags are compared first. Structs and enums additionally have
/// their names compared, so two different structs do not match.
/// Returns CompileError.TypeError on mismatch.
pub fn expectSymUnTypeEqual(
    self: *@This(),
    symbol_type: comp.SymbolUnderlyingType,
    expected_type: comp.SymbolUnderlyingType,
) !void {
    const got_tag = comp.SymbolUnderlyingTypeEnum(symbol_type);
    const want_tag = comp.SymbolUnderlyingTypeEnum(expected_type);

    if (got_tag != want_tag) {
        std.debug.warn("Expected {}, got {}\n", want_tag, got_tag);
        return CompileError.TypeError;
    }

    // The tags are equal from here on, so reading the same-named payload
    // of `symbol_type` is valid. Composite types still need their
    // identifiers compared; for everything else the tag test is enough.
    switch (expected_type) {
        .Struct => |want_name| {
            try self.compositeIdentifierEqual(.Struct, symbol_type.Struct, want_name);
        },

        .Enum => |want_name| {
            try self.compositeIdentifierEqual(.Enum, symbol_type.Enum, want_name);
        },

        else => {},
    }
}
|
|
|
|
// TODO make return type optional and so, skip exprs that
|
|
// fail to be fully resolved, instead of returning CompileError
|
|
|
|
// TODO make the expr ptr a const since we want to implicit cast things
|
|
/// Resolve the type that `expr` evaluates to, validating operands, call
/// arguments and assignments on the way.
///
/// Returns CompileError.TypeError when a check fails; errors coming from
/// the compilation context (symbol and variable lookups) are propagated.
/// Get on a struct and Set expressions are not implemented yet and panic.
pub fn resolveExprType(
    self: *@This(),
    ctx: *comp.CompilationContext,
    expr: *const ast.Expr,
) anyerror!SymbolUnderlyingType {
    switch (expr.*) {
        .Binary => |binary| {
            const left_type = try self.resolveExprType(ctx, binary.left);
            const right_type = try self.resolveExprType(ctx, binary.right);

            return switch (binary.op) {
                // arithmetic needs numeric operands on both sides; the
                // result takes the type of the left-hand side
                .Add, .Sub, .Mul, .Div, .Mod => arith: {
                    try self.expectSymUnTypeNumeric(left_type);
                    try self.expectSymUnTypeNumeric(right_type);

                    break :arith left_type;
                },

                // comparisons take numbers and give back a bool
                .Greater, .GreaterEqual, .Less, .LessEqual => cmp: {
                    try self.expectSymUnTypeNumeric(left_type);
                    try self.expectSymUnTypeNumeric(right_type);

                    break :cmp SymbolUnderlyingType{ .Bool = {} };
                },

                // all boolean ops return bools
                .Equal, .And, .Or => SymbolUnderlyingType{ .Bool = {} },
            };
        },

        .Unary => |unary| {
            const right_type = try self.resolveExprType(ctx, unary.right);

            // both unary operators keep the type of their operand
            return switch (unary.op) {
                .Negate => right_type,
                .Not => right_type,
            };
        },

        .Literal => |literal| {
            return switch (literal) {
                .Bool => SymbolUnderlyingType{ .Bool = {} },

                // TODO recast Integer32 as Integer64 if the type we're
                // checking into is Integer64, but not the other way.
                .Integer32 => SymbolUnderlyingType{ .Integer32 = {} },
                .Integer64 => SymbolUnderlyingType{ .Integer64 = {} },
                .Float => SymbolUnderlyingType{ .Double = {} },

                else => unreachable,
            };
        },

        .Grouping => |group_expr| return try self.resolveExprType(ctx, group_expr),

        .Struct => |struc| {
            const name = struc.name.lexeme;

            if (self.resolveGlobalType(ctx, name)) |typ| return typ;

            self.doError("Unknown struct name '{}'\n", name);
            return CompileError.TypeError;
        },

        .Call => |call| {
            self.setErrToken(call.paren);
            std.debug.assert(ast.ExprType(call.callee.*) == .Variable);
            const func_name = call.callee.*.Variable.lexeme;

            var symbol = try ctx.fetchGlobalSymbol(func_name, .Function);
            var func_sym = symbol.Function;

            // Arguments are matched to parameters by index below, so the
            // counts must agree first: extra arguments would index past
            // the parameter list and missing ones would go unnoticed.
            const arguments = call.arguments.toSlice();
            if (arguments.len != func_sym.parameter_list.len) {
                self.doError(
                    "Expected {} arguments for function '{}', got {}",
                    func_sym.parameter_list.len,
                    func_name,
                    arguments.len,
                );

                return CompileError.TypeError;
            }

            for (arguments) |arg_expr, idx| {
                var param_type = func_sym.parameter_list.at(idx);
                var arg_type = try self.resolveExprType(ctx, &arg_expr);

                self.expectSymUnTypeEqual(arg_type, param_type) catch {
                    self.doError(
                        "Expected parameter {} to be {}, got {}",
                        idx,
                        @tagName(comp.SymbolUnderlyingTypeEnum(param_type)),
                        @tagName(comp.SymbolUnderlyingTypeEnum(arg_type)),
                    );

                    return CompileError.TypeError;
                };
            }

            return func_sym.return_type;
        },

        .Variable => |vari| {
            self.setErrToken(vari);
            var metadata = try ctx.resolveVarType(vari.lexeme, true);
            try ctx.insertMetadata(vari.lexeme, metadata.?);
            return metadata.?.typ;
        },

        .Get => |get| {
            var target = get.target.*;
            if (ast.ExprType(target) != .Variable) {
                std.debug.warn("Expected Variable as get target, got {}\n", ast.ExprType(target));
                return CompileError.TypeError;
            }

            const lexeme = target.Variable.lexeme;
            var global_typ_opt = self.resolveGlobalType(ctx, lexeme);

            // TODO:
            // - name resolution for when global_typ is null + analysis of
            //   the name's type
            // - analysis for structs

            if (global_typ_opt == null) @panic("TODO name resolution");

            var global_typ = global_typ_opt.?;

            switch (global_typ) {

                // TODO we need to fetch the given
                // struct field (on get.name) type and return it
                .Struct => @panic("TODO analysis of struct"),

                .Enum => |enum_identifier| {
                    // fetch an enum off symbol table, then we use the
                    // identifier map to ensure get.name exists in the enum
                    var map = ctx.symbol_table.get(enum_identifier).?.value.Enum;
                    const name = get.name.lexeme;

                    if (map.get(name) == null) {
                        self.doError(
                            "Field {} not found in enum {}",
                            name,
                            lexeme,
                        );
                        return CompileError.TypeError;
                    }

                    return global_typ;
                },

                else => {
                    std.debug.warn(
                        "Expected Struct/Enum as get target, got {}\n",
                        comp.SymbolUnderlyingTypeEnum(global_typ),
                    );

                    return CompileError.TypeError;
                },
            }
        },

        .Assign => |assign| {
            const name = assign.name.lexeme;

            // only the current scope's environment is consulted here; an
            // unknown name is a type error, not a reason to crash
            const entry = ctx.current_scope.?.env.get(name) orelse {
                self.doError(
                    "Variable '{}' is not declared in the current scope",
                    name,
                );
                return CompileError.TypeError;
            };
            const var_type = entry.value;

            // the variable's declared type is what the value must match
            const value_type = try self.resolveExprType(ctx, assign.value);
            try self.expectSymUnTypeEqual(value_type, var_type);
            return var_type;
        },

        .Set => @panic("TODO analysis of Set exprs"),
    }
}
|
|
|
|
/// Type-check a single statement, recursing into nested statements.
///
/// Variable declarations are recorded in the environment of the context's
/// current scope. `if` statements get one scope per branch. For
/// statements are not implemented yet and panic.
pub fn stmtPass(
    self: *@This(),
    ctx: *comp.CompilationContext,
    stmt: ast.Stmt,
) anyerror!void {
    switch (stmt) {

        // There are no side-effects to the type system when the statement
        // is just an expression or a println. we just resolve it
        // to ensure we dont have type errors.
        .Expr, .Println => |expr_ptr| {
            _ = try self.resolveExprType(ctx, expr_ptr);
        },

        // VarDecl means we check the type of the expression and
        // insert it into the context, however we need to know a pointer
        // to where we are, scope-wise, we don't have that info here,
        // so it should be implicit into the context.
        .VarDecl => |vardecl| {
            self.setErrToken(vardecl.name);
            const name = vardecl.name.lexeme;

            const var_type = try self.resolveExprType(ctx, vardecl.value);

            // TODO check current_scope being null
            _ = try ctx.current_scope.?.env.put(name, var_type);
        },

        // Returns dont cause any type system things as they deal with
        // values, however, we must ensure that the expression type
        // matches the function type (must fetch from context, or we could
        // pull a hack with err contexts, lol)
        .Return => |ret| {
            const ret_stmt_type = try self.resolveExprType(ctx, ret.value);
            try self.expectSymUnTypeEqual(ret_stmt_type, ctx.cur_function.?.return_type);
        },

        // If creates two scopes, one for each branch of the if
        .If => |ifstmt| {
            const cond_type = try self.resolveExprType(ctx, ifstmt.condition);
            try self.expectSymUnTypeEnum(cond_type, .Bool);

            {
                try ctx.bumpScope("if_then");
                // dump through defer, like the else branch, so the scope
                // is also left when a statement in the branch fails
                defer ctx.dumpScope();

                for (ifstmt.then_branch.toSlice()) |then_stmt| {
                    try self.stmtPass(ctx, then_stmt);
                }
            }

            if (ifstmt.else_branch) |else_branch| {
                try ctx.bumpScope("if_else");
                defer ctx.dumpScope();

                for (else_branch.toSlice()) |else_stmt| {
                    try self.stmtPass(ctx, else_stmt);
                }
            }
        },

        // Loop (creates 1 scope) asserts that the expression
        // type is a bool
        .Loop => |loop| {
            if (loop.condition) |cond| {
                const cond_type = try self.resolveExprType(ctx, cond);
                try self.expectSymUnTypeEnum(cond_type, .Bool);
            }

            // TODO bump-dump scope
            for (loop.then_branch.toSlice()) |then_stmt| {
                try self.stmtPass(ctx, then_stmt);
            }
        },

        // For (creates 1 scope) receives arrays, which we dont have yet
        .For => @panic("TODO for"),

        else => unreachable,
    }
}
|
|
|
|
/// Analyze one top-level node and register what it declares in `ctx`.
///
/// - FnDecl: resolves the return and parameter types, inserts the function
///   with a fresh root scope, then type-checks every statement of the body.
/// - Struct: resolves the field types and inserts the struct only when all
///   of them resolved.
/// - Enum: inserted as-is.
/// - ConstDecl: each constant is inserted with the type of its expression.
///
/// The error token/context and the context's current function are reset
/// on entry. Any other node kind returns CompileError.TypeError.
pub fn nodePass(
    self: *@This(),
    ctx: *comp.CompilationContext,
    node: *ast.Node,
) !void {
    self.setErrToken(null);
    self.setErrContext(null);

    // always reset the contexts' current function
    ctx.cur_function = null;

    switch (node.*) {
        .Root => unreachable,
        .FnDecl => |decl| {
            self.setErrToken(decl.return_type);
            const name = decl.func_name.lexeme;
            self.setErrContext("function {}", name);
            var ret_type = self.resolveGlobalType(ctx, decl.return_type.lexeme);

            std.debug.warn("start analysis of fn {}, ret type: {}\n", decl.func_name.lexeme, ret_type);

            // parameters that fail to resolve are skipped, which the
            // length comparison below detects
            var parameters = comp.TypeList.init(self.allocator);
            for (decl.params.toSlice()) |param| {
                var param_type = self.resolveGlobalType(ctx, param.typ.lexeme);
                if (param_type == null) continue;
                try parameters.append(param_type.?);
            }

            // for a function, we always create a new root scope for it
            // and force-set it into the current context
            var scope = try comp.Scope.create(self.allocator, null, name);
            errdefer scope.deinit();

            // we intentionally insert the function so that:
            //  - we can do return statement validation
            //  - we have parameter types fully analyzed
            if (ret_type != null and parameters.len == decl.params.len) {
                try ctx.insertFn(decl, ret_type.?, parameters, scope);
            } else {
                // report exactly the parts that failed to resolve
                if (ret_type == null)
                    self.doError("Return type was not fully resolved");

                if (parameters.len != decl.params.len)
                    self.doError("Fully analyzed {} parameters, wanted {}", parameters.len, decl.params.len);

                return CompileError.TypeError;
            }

            // we must always start from a null current scope,
            // functions inside functions are not allowed
            std.debug.assert(ctx.current_scope == null);
            ctx.setScope(scope);

            for (decl.body.toSlice()) |stmt| {
                try self.stmtPass(ctx, stmt);
            }

            // it should be null when we dump from a function. always
            ctx.dumpScope();
            std.debug.assert(ctx.current_scope == null);
        },

        .Struct => |struc| {
            self.setErrToken(struc.name);
            self.setErrContext("struct {}", struc.name.lexeme);

            var types = comp.TypeList.init(self.allocator);

            for (struc.fields.toSlice()) |field| {
                self.setErrToken(field.name);
                var field_type = self.resolveGlobalType(ctx, field.typ.lexeme);
                if (field_type == null) continue;
                try types.append(field_type.?);
            }

            // only determine struct as fully resolved
            // when length of declared types == length of resolved types

            // we don't return type errors from the main loop so we can
            // keep going and find more type errors
            if (types.len == struc.fields.len)
                try ctx.insertStruct(struc, types);
        },

        // TODO change enums to u32
        .Enum => |enu| {
            self.setErrToken(enu.name);
            self.setErrContext("enum {}", enu.name.lexeme);

            try ctx.insertEnum(enu);
        },

        .ConstDecl => |constlist| {
            for (constlist.toSlice()) |constdecl| {
                self.setErrToken(constdecl.name);
                self.setErrContext("const {}", constdecl.name.lexeme);

                var expr_type = try self.resolveExprType(ctx, constdecl.expr);
                try ctx.insertConst(constdecl, expr_type);
            }
        },

        else => {
            std.debug.warn("TODO type analysis of {}\n", node.*);
            return CompileError.TypeError;
        },
    }
}
|
|
|
|
/// Run the analysis over every child of the root node, in order, and
/// return the compilation context that was filled in along the way.
///
/// Stops at, and returns, the first error produced by nodePass.
pub fn pass(self: *@This(), root: *ast.Node) !comp.CompilationContext {
    var ctx = comp.CompilationContext.init(self.allocator);

    const nodes = root.Root.toSlice();
    var idx: usize = 0;
    while (idx < nodes.len) : (idx += 1) {
        try self.nodePass(&ctx, &nodes[idx]);
    }

    return ctx;
}
|
|
};
|