From fd5905910176d11307e43570ba81ac317b08863d Mon Sep 17 00:00:00 2001 From: Luna Date: Wed, 7 Dec 2022 14:57:07 -0300 Subject: [PATCH] add draft for config schema validation --- config.lua | 94 ++++++++++++++++++++++++++++++++- main.lua | 4 +- scripts/webfinger_allowlist.lua | 4 +- test.lua | 1 + tests/schema_validation.lua | 43 +++++++++++++++ 5 files changed, 140 insertions(+), 6 deletions(-) create mode 100644 tests/schema_validation.lua diff --git a/config.lua b/config.lua index f83449b..cc148bd 100644 --- a/config.lua +++ b/config.lua @@ -29,10 +29,100 @@ local function findConfigFile() return nil end +local function insertError(errors, key, message) + errors[key] = errors[key] or {} + table.insert(errors[key], message) + return errors +end + +local function fakeTableSchema(value_schema) + return setmetatable({ __list = true }, { + __index = function (self, key) + return value_schema + end + }) +end + +local SPECIAL_KEYS = {__list = true} + +local function validateSchema(schema, input, errors) + local errors = errors or {} + if schema.__list then + -- generate full schema for lists that are same size as input + for k, _ in pairs(input) do schema[k] = schema.__schema_value end + end + + for key, field_schema in pairs(schema) do + if not SPECIAL_KEYS[key] then + assert(field_schema, 'schema not provided') + local input_value = input[key] + local input_type = type(input_value) + local wanted_type = field_schema.type + + local actual_wanted_type = wanted_type + if wanted_type == 'list' then + actual_wanted_type = 'table' + end + + if input_type == actual_wanted_type then + if wanted_type == 'table' then + -- recursive schema validation for generic tables + errors[key] = errors[key] or {} + validateSchema(field_schema.schema, input_value, errors[key]) + if not errors[key] then + errors[key] = nil + end + elseif wanted_type == 'list' then + -- for lists (which only have schemas for values), interpret + -- it differently + errors[key] = errors[key] or {} + validateSchema(fakeTableSchema(field_schema.schema), input_value, errors[key]) + if next(errors[key]) == nil then + errors[key] = nil + end + end + else + insertError( + errors, key, + string.format('wanted %s but got %s', tostring(wanted_type), tostring(input_type)) + ) + end + end + end + return errors +end + +local function validateConfigFile(config_object) + local all_schema_errors = {} + for module_name, module_config in pairs(config_object.wantedScripts) do + local module_manifest = require('scripts.' .. module_name) + local config_schema = module_manifest.config + local schema_errors = validateSchema(config_schema, module_config) + if schema_errors then + all_schema_errors[module_name] = schema_errors + end + end + return all_schema_errors +end + +local function writeSchemaErrors(errors, out) + out('sex') +end + local function loadConfigFile() local config_file_data = assert(findConfigFile(), 'no config file found, config path: ' .. config_path) local config_file_function = assert(loadstring(config_file_data)) - return config_file_function() + local config_object = config_file_function() + local schema_errors = validateConfigFile(config_object) + if schema_errors then + for name, errors in pairs(schema_errors) do + log('CONFIG ERROR' .. writeSchemaErrors(errors, log)) + end + end + return config_object end -return loadConfigFile() +return { + loadConfigFile=loadConfigFile, + validateSchema=validateSchema, +} diff --git a/main.lua b/main.lua index c2e8e40..2be3b01 100644 --- a/main.lua +++ b/main.lua @@ -7,9 +7,9 @@ -- local config = loadConfig() local ctx = require('ctx') -local conf = require('config') +local config = require('config') -ctx:loadFromConfig(conf) +ctx:loadFromConfig(config.loadConfigFile()) return function() ctx:onRequest() diff --git a/scripts/webfinger_allowlist.lua b/scripts/webfinger_allowlist.lua index b5f5c7e..41d5c80 100644 --- a/scripts/webfinger_allowlist.lua +++ b/scripts/webfinger_allowlist.lua @@ -37,8 +37,8 @@ return { }, config={ ['accounts'] = { - type='table', - value={ + type='list', + schema={ type='string', description='ap id' }, diff --git a/test.lua b/test.lua index 63d3a08..76e17fa 100644 --- a/test.lua +++ b/test.lua @@ -74,4 +74,5 @@ function onRequest() end require('tests.webfinger_allowlist') +require('tests.schema_validation') os.exit(lu.LuaUnit.run()) diff --git a/tests/schema_validation.lua b/tests/schema_validation.lua new file mode 100644 index 0000000..7b4d8f8 --- /dev/null +++ b/tests/schema_validation.lua @@ -0,0 +1,43 @@ +TestConfigSchemaValidator = {} + +local config = require('config') + +local function len(t) + local count = 0 + for _ in pairs(t) do count = count + 1 end + return count +end + +function pprint(t, ident) + local ident = ident or 0 + if type(t) == 'table' then + local count = 0 + for k, v in pairs(t) do + print(string.rep('\t', ident) .. k) + pprint(v, ident + 1) + count = count + 1 + end + if count == 0 then + print('') + end + else + print(string.rep('\t', ident) .. tostring(t)) + end +end + +function TestConfigSchemaValidator:testBasicFields() + local errors = config.validateSchema({a={type='string'}}, {a='test'}) + lu.assertIs(len(errors), 0) + local errors = config.validateSchema({a={type='number'}}, {a=123}) + lu.assertIs(len(errors), 0) + local errors = config.validateSchema({a={type='string'}}, {a=123}) + lu.assertIs(len(errors), 1) +end + +function TestConfigSchemaValidator:testList() + local errors = config.validateSchema({a={type='list', schema={type='number'}}}, {a={1,2,3}}) + lu.assertIs(len(errors), 0) + + local errors = config.validateSchema({a={type='list', schema={type='number'}}}, {a={1,2,3,'asd'}}) + lu.assertIs(len(errors), 1) +end