diff --git a/config.lua b/config.lua index 846d6b4..f83449b 100644 --- a/config.lua +++ b/config.lua @@ -29,101 +29,10 @@ local function findConfigFile() return nil end -local function insertError(errors, key, message) - errors[key] = errors[key] or {} - table.insert(errors[key], message) - return errors -end - -local function fakeTableSchema(value_schema) - return setmetatable({ __list = true }, { - __index = function (self, key) - return value_schema - end - }) -end - -local SPECIAL_KEYS = {__list = true} - -local function validateSchema(schema, input, errors) - local errors = errors or {} - if schema.__list then - -- generate full schema for lists that are same size as input - for k, _ in pairs(input) do schema[k] = schema.__schema_value end - end - - for key, field_schema in pairs(schema) do - if not SPECIAL_KEYS[key] then - assert(field_schema, 'schema not provided') - local input_value = input[key] - local input_type = type(input_value) - local wanted_type = field_schema.type - - local actual_wanted_type = wanted_type - if wanted_type == 'list' then - actual_wanted_type = 'table' - end - - if input_type == actual_wanted_type then - if wanted_type == 'table' then - -- recursive schema validation for generic tables - errors[key] = errors[key] or {} - validateSchema(field_schema.schema, input_value, errors[key]) - if next(errors[key]) == nil then - errors[key] = nil - end - elseif wanted_type == 'list' then - -- for lists (which only have schemas for values), interpret - -- it differently - errors[key] = errors[key] or {} - validateSchema(fakeTableSchema(field_schema.schema), input_value, errors[key]) - if next(errors[key]) == nil then - errors[key] = nil - end - end - else - insertError( - errors, key, - string.format('wanted %s but got %s', tostring(wanted_type), tostring(input_type)) - ) - end - end - end - return errors -end - -local function validateConfigFile(config_object) - local all_schema_errors = {} - for module_name, module_config in pairs(config_object.wantedScripts) do - local module_manifest = require('scripts.' .. module_name) - local config_schema = module_manifest.config - local schema_errors = validateSchema(config_schema, module_config) - if schema_errors then - all_schema_errors[module_name] = schema_errors - end - end - return all_schema_errors -end - -local function writeSchemaErrors(errors, out) - out('sex') -end - local function loadConfigFile() local config_file_data = assert(findConfigFile(), 'no config file found, config path: ' .. config_path) local config_file_function = assert(loadstring(config_file_data)) - local config_object = config_file_function() - local schema_errors = validateConfigFile(config_object) - - local total_count = table.pprint(schema_errors, {call=function() end}) - if total_count > 0 then - log('CONFIG ERROR') - table.pprint(schema_errors, {call=log}) - end - return config_object + return config_file_function() end -return { - loadConfigFile=loadConfigFile, - validateSchema=validateSchema, -} +return loadConfigFile() diff --git a/main.lua b/main.lua index 60fbb8e..c2e8e40 100644 --- a/main.lua +++ b/main.lua @@ -7,10 +7,9 @@ -- local config = loadConfig() local ctx = require('ctx') -local config = require('config') -require('util') +local conf = require('config') -ctx:loadFromConfig(config.loadConfigFile()) +ctx:loadFromConfig(conf) return function() ctx:onRequest() diff --git a/scripts/webfinger_allowlist.lua b/scripts/webfinger_allowlist.lua index 41d5c80..b5f5c7e 100644 --- a/scripts/webfinger_allowlist.lua +++ b/scripts/webfinger_allowlist.lua @@ -37,8 +37,8 @@ return { }, config={ ['accounts'] = { - type='list', - schema={ + type='table', + value={ type='string', description='ap id' }, diff --git a/test.lua b/test.lua index 968117b..63d3a08 100644 --- a/test.lua +++ b/test.lua @@ -1,6 +1,5 @@ lu = require('luaunit') local rex = require('rex_pcre2') -require('util') function createNgx() local ngx = { @@ -55,18 +54,12 @@ function setupFakeRequest(path, options) end local ctx = require('ctx') -local config = require('config') -function setupTest(module_require_path, input_config) +function setupTest(module_require_path, config) resetNgx() local module = require(module_require_path) - - local schema_errors = config.validateSchema(module.config, input_config) - local count = table.pprint(schema_errors) - lu.assertIs(count, 0) - - state = module.init(input_config) + state = module.init(config) ctx.compiled_chain = { - {module, input_config, state} + {module, config, state} } return module end @@ -81,5 +74,4 @@ function onRequest() end require('tests.webfinger_allowlist') -require('tests.schema_validation') os.exit(lu.LuaUnit.run()) diff --git a/tests/schema_validation.lua b/tests/schema_validation.lua deleted file mode 100644 index 7834930..0000000 --- a/tests/schema_validation.lua +++ /dev/null @@ -1,49 +0,0 @@ -TestSchemaValidator = {} - -local config = require('config') - -function TestSchemaValidator:testBasicFields() - local errors = config.validateSchema({a={type='string'}}, {a='test'}) - lu.assertIs(table.len(errors), 0) - local errors = config.validateSchema({a={type='number'}}, {a=123}) - lu.assertIs(table.len(errors), 0) - local errors = config.validateSchema({a={type='string'}}, {a=123}) - lu.assertIs(table.len(errors), 1) -end - -function TestSchemaValidator:testList() - local errors = config.validateSchema({a={type='list', schema={type='number'}}}, {a={1,2,3}}) - lu.assertIs(table.len(errors), 0) - - local errors = config.validateSchema({a={type='list', schema={type='number'}}}, {a={1,2,3,'asd'}}) - lu.assertIs(table.len(errors), 1) -end - -function TestSchemaValidator:testTable() - local TEST_SCHEMA = { - a={ - type='table', - schema={ - b={ - type='number' - } - } - } - } - - local errors = config.validateSchema( - TEST_SCHEMA, - {a= - {b=2} - } - ) - lu.assertIs(table.len(errors), 0) - - local errors = config.validateSchema( - TEST_SCHEMA, - {a= - {b='sex'} - } - ) - lu.assertIs(table.len(errors), 1) -end diff --git a/util.lua b/util.lua deleted file mode 100644 index 4db2229..0000000 --- a/util.lua +++ /dev/null @@ -1,28 +0,0 @@ -function table.len(t) - local count = 0 - for _ in pairs(t) do count = count + 1 end - return count -end - -function table.pprint(t, options, ident, total_count) - local ident = ident or 0 - local total_count = total_count or 0 - - local options = options or {} - local print_function = options.call or print - if type(t) == 'table' then - local count = 0 - for k, v in pairs(t) do - print_function(string.rep('\t', ident) .. k) - count = count + 1 - total_count = table.pprint(v, options, ident + 1, total_count) - end - if count == 0 then - --print('') - end - else - print_function(string.rep('\t', ident) .. tostring(t)) - total_count = total_count + 1 - end - return total_count -end