Compare commits

..

5 commits

6 changed files with 186 additions and 9 deletions

View file

@ -29,10 +29,101 @@ local function findConfigFile()
return nil return nil
end end
local function insertError(errors, key, message)
errors[key] = errors[key] or {}
table.insert(errors[key], message)
return errors
end
local function fakeTableSchema(value_schema)
return setmetatable({ __list = true }, {
__index = function (self, key)
return value_schema
end
})
end
local SPECIAL_KEYS = {__list = true}
local function validateSchema(schema, input, errors)
local errors = errors or {}
if schema.__list then
-- generate full schema for lists that are same size as input
for k, _ in pairs(input) do schema[k] = schema.__schema_value end
end
for key, field_schema in pairs(schema) do
if not SPECIAL_KEYS[key] then
assert(field_schema, 'schema not provided')
local input_value = input[key]
local input_type = type(input_value)
local wanted_type = field_schema.type
local actual_wanted_type = wanted_type
if wanted_type == 'list' then
actual_wanted_type = 'table'
end
if input_type == actual_wanted_type then
if wanted_type == 'table' then
-- recursive schema validation for generic tables
errors[key] = errors[key] or {}
validateSchema(field_schema.schema, input_value, errors[key])
if next(errors[key]) == nil then
errors[key] = nil
end
elseif wanted_type == 'list' then
-- for lists (which only have schemas for values), interpret
-- it differently
errors[key] = errors[key] or {}
validateSchema(fakeTableSchema(field_schema.schema), input_value, errors[key])
if next(errors[key]) == nil then
errors[key] = nil
end
end
else
insertError(
errors, key,
string.format('wanted %s but got %s', tostring(wanted_type), tostring(input_type))
)
end
end
end
return errors
end
local function validateConfigFile(config_object)
local all_schema_errors = {}
for module_name, module_config in pairs(config_object.wantedScripts) do
local module_manifest = require('scripts.' .. module_name)
local config_schema = module_manifest.config
local schema_errors = validateSchema(config_schema, module_config)
if schema_errors then
all_schema_errors[module_name] = schema_errors
end
end
return all_schema_errors
end
local function writeSchemaErrors(errors, out)
out('sex')
end
local function loadConfigFile() local function loadConfigFile()
local config_file_data = assert(findConfigFile(), 'no config file found, config path: ' .. config_path) local config_file_data = assert(findConfigFile(), 'no config file found, config path: ' .. config_path)
local config_file_function = assert(loadstring(config_file_data)) local config_file_function = assert(loadstring(config_file_data))
return config_file_function() local config_object = config_file_function()
local schema_errors = validateConfigFile(config_object)
local total_count = table.pprint(schema_errors, {call=function() end})
if total_count > 0 then
log('CONFIG ERROR')
table.pprint(schema_errors, {call=log})
end
return config_object
end end
return loadConfigFile() return {
loadConfigFile=loadConfigFile,
validateSchema=validateSchema,
}

View file

@ -7,9 +7,10 @@
-- local config = loadConfig() -- local config = loadConfig()
local ctx = require('ctx') local ctx = require('ctx')
local conf = require('config') local config = require('config')
require('util')
ctx:loadFromConfig(conf) ctx:loadFromConfig(config.loadConfigFile())
return function() return function()
ctx:onRequest() ctx:onRequest()

View file

@ -37,8 +37,8 @@ return {
}, },
config={ config={
['accounts'] = { ['accounts'] = {
type='table', type='list',
value={ schema={
type='string', type='string',
description='ap id' description='ap id'
}, },

View file

@ -1,5 +1,6 @@
lu = require('luaunit') lu = require('luaunit')
local rex = require('rex_pcre2') local rex = require('rex_pcre2')
require('util')
function createNgx() function createNgx()
local ngx = { local ngx = {
@ -54,12 +55,18 @@ function setupFakeRequest(path, options)
end end
local ctx = require('ctx') local ctx = require('ctx')
function setupTest(module_require_path, config) local config = require('config')
function setupTest(module_require_path, input_config)
resetNgx() resetNgx()
local module = require(module_require_path) local module = require(module_require_path)
state = module.init(config)
local schema_errors = config.validateSchema(module.config, input_config)
local count = table.pprint(schema_errors)
lu.assertIs(count, 0)
state = module.init(input_config)
ctx.compiled_chain = { ctx.compiled_chain = {
{module, config, state} {module, input_config, state}
} }
return module return module
end end
@ -74,4 +81,5 @@ function onRequest()
end end
require('tests.webfinger_allowlist') require('tests.webfinger_allowlist')
require('tests.schema_validation')
os.exit(lu.LuaUnit.run()) os.exit(lu.LuaUnit.run())

View file

@ -0,0 +1,49 @@
TestSchemaValidator = {}
local config = require('config')
function TestSchemaValidator:testBasicFields()
local errors = config.validateSchema({a={type='string'}}, {a='test'})
lu.assertIs(table.len(errors), 0)
local errors = config.validateSchema({a={type='number'}}, {a=123})
lu.assertIs(table.len(errors), 0)
local errors = config.validateSchema({a={type='string'}}, {a=123})
lu.assertIs(table.len(errors), 1)
end
function TestSchemaValidator:testList()
local errors = config.validateSchema({a={type='list', schema={type='number'}}}, {a={1,2,3}})
lu.assertIs(table.len(errors), 0)
local errors = config.validateSchema({a={type='list', schema={type='number'}}}, {a={1,2,3,'asd'}})
lu.assertIs(table.len(errors), 1)
end
function TestSchemaValidator:testTable()
local TEST_SCHEMA = {
a={
type='table',
schema={
b={
type='number'
}
}
}
}
local errors = config.validateSchema(
TEST_SCHEMA,
{a=
{b=2}
}
)
lu.assertIs(table.len(errors), 0)
local errors = config.validateSchema(
TEST_SCHEMA,
{a=
{b='sex'}
}
)
lu.assertIs(table.len(errors), 1)
end

28
util.lua Normal file
View file

@ -0,0 +1,28 @@
function table.len(t)
local count = 0
for _ in pairs(t) do count = count + 1 end
return count
end
function table.pprint(t, options, ident, total_count)
local ident = ident or 0
local total_count = total_count or 0
local options = options or {}
local print_function = options.call or print
if type(t) == 'table' then
local count = 0
for k, v in pairs(t) do
print_function(string.rep('\t', ident) .. k)
count = count + 1
total_count = table.pprint(v, options, ident + 1, total_count)
end
if count == 0 then
--print('<empty table>')
end
else
print_function(string.rep('\t', ident) .. tostring(t))
total_count = total_count + 1
end
return total_count
end