diff --git a/config.lua b/config.lua index f83449b..278ee9b 100644 --- a/config.lua +++ b/config.lua @@ -29,10 +29,107 @@ local function findConfigFile() return nil end -local function loadConfigFile() - local config_file_data = assert(findConfigFile(), 'no config file found, config path: ' .. config_path) - local config_file_function = assert(loadstring(config_file_data)) - return config_file_function() +local function insertError(errors, key, message) + errors[key] = errors[key] or {} + table.insert(errors[key], message) + return errors end -return loadConfigFile() +local function fakeTableSchema(value_schema) + return setmetatable({ __list = true }, { + __index = function (self, key) + return value_schema + end + }) +end + +local SPECIAL_KEYS = {__list = true} + +local function validateSchema(schema, input, errors) + local errors = errors or {} + if schema.__list then + -- generate full schema for lists that are same size as input + for k, _ in pairs(input) do schema[k] = schema.__schema_value end + end + + for key, field_schema in pairs(schema) do + if not SPECIAL_KEYS[key] then + assert(field_schema, 'schema not provided') + local input_value = input[key] + local input_type = type(input_value) + local wanted_type = field_schema.type + + local actual_wanted_type = wanted_type + if wanted_type == 'list' then + actual_wanted_type = 'table' + end + + if input_type == actual_wanted_type then + if wanted_type == 'table' then + -- recursive schema validation for generic tables + errors[key] = errors[key] or {} + validateSchema(field_schema.schema, input_value, errors[key]) + if next(errors[key]) == nil then + errors[key] = nil + end + elseif wanted_type == 'list' then + -- for lists (which only have schemas for values), interpret + -- it differently + errors[key] = errors[key] or {} + validateSchema(fakeTableSchema(field_schema.schema), input_value, errors[key]) + if next(errors[key]) == nil then + errors[key] = nil + end + end + else + insertError( + errors, key, + string.format('wanted %s but got %s', tostring(wanted_type), tostring(input_type)) + ) + end + end + end + return errors +end + +local function validateConfigFile(config_object) + local all_schema_errors = {} + for module_name, module_config in pairs(config_object.wantedScripts) do + local module_manifest = require('scripts.' .. module_name) + local config_schema = module_manifest.config + local schema_errors = validateSchema(config_schema, module_config) + if schema_errors then + all_schema_errors[module_name] = schema_errors + end + end + return all_schema_errors +end + +local function writeSchemaErrors(errors, out) + out('sex') +end + +local function loadConfigFile(options) + local options = options or {} + + local config_file_data = assert(findConfigFile(), 'no config file found, config path: ' .. config_path) + local config_file_function = assert(loadstring(config_file_data)) + local config_object = config_file_function() + + if options.validate then + local schema_errors = validateConfigFile(config_object) + + local total_count = table.pprint(schema_errors, {call=function() end}) + if total_count > 0 then + log('CONFIG ERROR') + table.pprint(schema_errors, {call=log}) + end + end + + return config_object +end + +return { + loadConfigFile=loadConfigFile, + validateSchema=validateSchema, +} diff --git a/main.lua b/main.lua index c2e8e40..b0cf8ff 100644 --- a/main.lua +++ b/main.lua @@ -7,10 +7,17 @@ -- local config = loadConfig() local ctx = require('ctx') -local conf = require('config') +local config = require('config') +require('util') -ctx:loadFromConfig(conf) +ctx:loadFromConfig(config.loadConfigFile()) -return function() - ctx:onRequest() -end +return { + init=function () + -- validate config and print out errors + config.loadConfigFile({validate = true}) + end, + access=function() + ctx:onRequest() + end +} diff --git a/nginx.conf b/nginx.conf index 81904be..4b6e8bf 100644 --- a/nginx.conf +++ b/nginx.conf @@ -1,4 +1,8 @@ - server { +init_by_lua_block { + require("aproxy.main").init() +} + +server { listen 80; lua_code_cache off; @@ -7,9 +11,9 @@ # must happen before proxy_pass access_by_lua_block { - require("aproxy.main")() + require("aproxy.main").access() } proxy_pass http://localhost:9999; } - } +} diff --git a/scripts/webfinger_allowlist.lua b/scripts/webfinger_allowlist.lua index b5f5c7e..41d5c80 100644 --- a/scripts/webfinger_allowlist.lua +++ b/scripts/webfinger_allowlist.lua @@ -37,8 +37,8 @@ return { }, config={ ['accounts'] = { - type='table', - value={ + type='list', + schema={ type='string', description='ap id' }, diff --git a/test.lua b/test.lua index 63d3a08..968117b 100644 --- a/test.lua +++ b/test.lua @@ -1,5 +1,6 @@ lu = require('luaunit') local rex = require('rex_pcre2') +require('util') function createNgx() local ngx = { @@ -54,12 +55,18 @@ function setupFakeRequest(path, options) end local ctx = require('ctx') -function setupTest(module_require_path, config) +local config = require('config') +function setupTest(module_require_path, input_config) resetNgx() local module = require(module_require_path) - state = module.init(config) + + local schema_errors = config.validateSchema(module.config, input_config) + local count = table.pprint(schema_errors) + lu.assertIs(count, 0) + + state = module.init(input_config) ctx.compiled_chain = { - {module, config, state} + {module, input_config, state} } return module end @@ -74,4 +81,5 @@ function onRequest() end require('tests.webfinger_allowlist') +require('tests.schema_validation') os.exit(lu.LuaUnit.run()) diff --git a/tests/schema_validation.lua b/tests/schema_validation.lua new file mode 100644 index 0000000..7834930 --- /dev/null +++ b/tests/schema_validation.lua @@ -0,0 +1,49 @@ +TestSchemaValidator = {} + +local config = require('config') + +function TestSchemaValidator:testBasicFields() + local errors = config.validateSchema({a={type='string'}}, {a='test'}) + lu.assertIs(table.len(errors), 0) + local errors = config.validateSchema({a={type='number'}}, {a=123}) + lu.assertIs(table.len(errors), 0) + local errors = config.validateSchema({a={type='string'}}, {a=123}) + lu.assertIs(table.len(errors), 1) +end + +function TestSchemaValidator:testList() + local errors = config.validateSchema({a={type='list', schema={type='number'}}}, {a={1,2,3}}) + lu.assertIs(table.len(errors), 0) + + local errors = config.validateSchema({a={type='list', schema={type='number'}}}, {a={1,2,3,'asd'}}) + lu.assertIs(table.len(errors), 1) +end + +function TestSchemaValidator:testTable() + local TEST_SCHEMA = { + a={ + type='table', + schema={ + b={ + type='number' + } + } + } + } + + local errors = config.validateSchema( + TEST_SCHEMA, + {a= + {b=2} + } + ) + lu.assertIs(table.len(errors), 0) + + local errors = config.validateSchema( + TEST_SCHEMA, + {a= + {b='sex'} + } + ) + lu.assertIs(table.len(errors), 1) +end diff --git a/util.lua b/util.lua new file mode 100644 index 0000000..4db2229 --- /dev/null +++ b/util.lua @@ -0,0 +1,28 @@ +function table.len(t) + local count = 0 + for _ in pairs(t) do count = count + 1 end + return count +end + +function table.pprint(t, options, ident, total_count) + local ident = ident or 0 + local total_count = total_count or 0 + + local options = options or {} + local print_function = options.call or print + if type(t) == 'table' then + local count = 0 + for k, v in pairs(t) do + print_function(string.rep('\t', ident) .. k) + count = count + 1 + total_count = table.pprint(v, options, ident + 1, total_count) + end + if count == 0 then + --print('') + end + else + print_function(string.rep('\t', ident) .. tostring(t)) + total_count = total_count + 1 + end + return total_count +end