LLM-Self-Detection-Research/tests/testing.PY
buzz-lightsnack-2007 f5c6380b77 add: testing program
This script contains prompt generation and LLM testing.
2024-12-07 21:37:30 +08:00

#! /Library/Frameworks/Python.framework/Versions/3.12/bin/python3
# testing.py
# To fulfill the test of asking the LLMs to describe the author of each passage.
# Import modules.
import ollama;
import json;
import os;
import datetime;
# Add source files.
IMPORTED = {'Strings': "data/datasets/strings.JSON", 'Prompts': "tests/config/prompts.json", 'Models': 'tests/config/models.JSON'}
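# For reference, the code below assumes roughly the following layouts for the
# source files (an illustrative sketch inferred from how the keys are used,
# not the literal contents of the files):
#   strings.JSON -> {"training": {<author type>: [<passage>, …]},
#                    "testing": {<author type>: {<source>: [<passage>, …]}}}
#   prompts.json -> {"sample": …, "bridge": …, "introduction": …,
#                    "classify": …, "judge": …, "answer format": …}
#   models.JSON  -> {<model display name>: <Ollama model ID>}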
# Set up the main variables.
RESPONSES = {};
PROMPTS = {};
RESULTS = {};
# This is the testing configuration.
TEST_CONFIG = {};
# Read the files.
for NAME in list(IMPORTED.keys()):
    try:
        # Parse the file as JSON when possible.
        DATA = json.load(open(IMPORTED[NAME]))
    except json.JSONDecodeError:
        # Otherwise, fall back to the raw text contents.
        DATA = open(IMPORTED[NAME]).read()
    IMPORTED[NAME] = DATA;
# Download the models.
def download_models():
    for MODEL_NAME in IMPORTED["Models"].keys():
        MODEL_ID = IMPORTED["Models"][MODEL_NAME];
        ollama.pull(MODEL_ID);
# Let the user choose the testing type.
def select_testing_type():
    if not('CoT' in list(TEST_CONFIG.keys())):
        RESPONSE = abs(int(input("Type “1” for a chain-of-thought test or “0” for a classic test: ")));
        TEST_CONFIG['CoT'] = RESPONSE > 0;
    if not('multi-shot' in list(TEST_CONFIG.keys())):
        RESPONSE = abs(int(input("Type “1” for a multi-shot test or “0” for a zero-shot test: ")));
        TEST_CONFIG['multi-shot'] = RESPONSE > 0;
    for CONFIGURATION_TYPE in list(TEST_CONFIG.keys()):
        if (TEST_CONFIG[CONFIGURATION_TYPE]):
            print(f"{CONFIGURATION_TYPE}:\tEnabled")
        else:
            print(f"{CONFIGURATION_TYPE}:\tDisabled")
    return (TEST_CONFIG);
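# For example, TEST_CONFIG may end up as {'CoT': True, 'multi-shot': False}.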
'''
Check for the questions cache.
Returns: (bool) cache validity
'''
def test_questions_cache():
    CACHE_VALID = False;
    # Check that both cache files exist.
    SOURCES = ["tests/cache/config.JSON", "tests/cache/prompts.JSON"];
    for FILE_NAME in SOURCES:
        CACHE_VALID = os.path.isfile(FILE_NAME);
        if (not(CACHE_VALID)):
            break;
    if (CACHE_VALID):
        try:
            # Open the last configuration file and check whether it matches the current testing configuration.
            CACHED_DATA = json.load(open('tests/cache/config.JSON'));
            CACHE_VALID = CACHED_DATA == TEST_CONFIG;
        except (json.JSONDecodeError, OSError):
            CACHE_VALID = False;
    if (not(CACHE_VALID)):
        # Remove the invalid caches.
        for FILE_NAME in ["tests/cache/config.JSON", "tests/cache/prompts.JSON", "tests/cache/responses.JSON"]:
            if (os.path.isfile(FILE_NAME)):
                os.remove(FILE_NAME);
    return (CACHE_VALID);
'''
Format the questions.
Returns: (dict) the prompts
'''
def format_questions():
    CACHE_USABLE = test_questions_cache();
    if (CACHE_USABLE):
        print("Using cache.");
        IMPORTED['Cached Prompts'] = json.load(open("tests/cache/prompts.JSON"));
        for TYPE_NAME in list(IMPORTED['Strings']['testing'].keys()):
            if (type(TYPE_NAME) is str):
                PROMPTS[TYPE_NAME.strip()] = IMPORTED['Cached Prompts'][TYPE_NAME];
            else:
                PROMPTS[TYPE_NAME] = IMPORTED['Cached Prompts'][TYPE_NAME];
    else:
        print("Generating questions…")
        # Loop through each author type
        for TYPE_NAME in list(IMPORTED['Strings']['testing'].keys()):
            PROMPTS[TYPE_NAME] = {};
            # Loop through each source
            for SOURCE_NAME in list(IMPORTED['Strings']['testing'][TYPE_NAME].keys()):
                PROMPTS[TYPE_NAME][SOURCE_NAME] = [];
                for PASSAGE in IMPORTED['Strings']['testing'][TYPE_NAME][SOURCE_NAME]:
                    PROMPT = "";
                    if TEST_CONFIG['multi-shot']:
                        # Prepend the labelled training samples before the question.
                        PROMPT = f"{IMPORTED['Prompts']['sample']}"
                        for GENERATION_TYPE in IMPORTED['Strings']['training'].keys():
                            for TEXT_NUMBER in range(len(IMPORTED['Strings']['training'][GENERATION_TYPE])):
                                SAMPLE_TEXT = '\n\n\t'.join(IMPORTED['Strings']['training'][GENERATION_TYPE][TEXT_NUMBER].strip().split('\n\n'));
                                PROMPT = f"{PROMPT}\n\n{GENERATION_TYPE}-written #{str(TEXT_NUMBER + 1)}: \n{SAMPLE_TEXT}";
                        PROMPT = f"{PROMPT}\n\n{IMPORTED['Prompts']['bridge']}\n\n";
                    # Append the passage under test and the classification instruction.
                    PASSAGE_TEXT = '\n\t'.join(PASSAGE.strip().split('\n'));
                    PROMPT = f"{PROMPT}{IMPORTED['Prompts']['introduction']}\n\n{PASSAGE_TEXT}\n\n{IMPORTED['Prompts']['classify']}"
                    PROMPTS[TYPE_NAME][SOURCE_NAME].append(PROMPT);
        # Cache the freshly generated prompts together with the configuration.
        create_cache(exclude=['responses']);
    return (PROMPTS);
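# A generated multi-shot prompt therefore follows roughly this outline (a sketch
# of the concatenation order, not the literal template text):
#   <sample preamble>
#   <author type>-written #1: <training passage>
#   … (one block per training passage per author type)
#   <bridge>
#   <introduction>
#   <passage under test>
#   <classify instruction>
# A zero-shot prompt keeps only the last three parts.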
def asking_execution():
    if (screen_asking()):
        ask_AI();
        save_responses();
'''
Request the user's final review before running the LLMs. This function also ensures that the testing conditions are satisfied before executing.
Returns: (bool) user's proceed state
'''
def screen_asking():
    RESPONSE = '';
    # Check the testing conditions.
    CONTINUE = len(list(IMPORTED['Models'].keys())) > 0;
    if (CONTINUE):
        try:
            RESPONSE = input("\n\nDo you now want to begin interaction with the LLMs? \nThis process will take about 20 minutes. \n");
        except KeyboardInterrupt:
            CONTINUE = False;
        else:
            # Treat an answer containing "n" (e.g. "no") as a refusal.
            if ("n" in RESPONSE.lower().strip().rstrip('.').rstrip('!')):
                CONTINUE = False;
    else:
        print("No testing models configured. Change that configuration and run this script again once you're ready.")
    return CONTINUE;
'''
Ask the AI models.
The models are read from IMPORTED['Models'] and the answers are stored in RESPONSES.
Returns: (dict) the responses
'''
def ask_AI():
    # Loop through each author type
    for TYPE_NAME in list(PROMPTS.keys()):
        RESPONSES[TYPE_NAME] = {};
        # Loop through each source
        for SOURCE_NAME in list(PROMPTS[TYPE_NAME].keys()):
            RESPONSES[TYPE_NAME][SOURCE_NAME] = [];
            print("\n");
            TARGET_LENGTH = len(PROMPTS[TYPE_NAME][SOURCE_NAME]);
            # Answer every prompt for this author type and source.
            for PROMPT_NUMBER in range(TARGET_LENGTH):
                print(f"\033[FAnswering prompt {PROMPT_NUMBER + 1} of {TARGET_LENGTH} from {TYPE_NAME} work in {SOURCE_NAME}");
                PROMPT = PROMPTS[TYPE_NAME][SOURCE_NAME][PROMPT_NUMBER];
                MODEL_RESPONSES = {};
                for MODEL_NAME in list(IMPORTED['Models'].keys()):
                    # Get the model ID.
                    MODEL_ID = IMPORTED['Models'][MODEL_NAME];
                    # Send update log.
                    print(f"\033[FAnswering prompt {PROMPT_NUMBER + 1} of {TARGET_LENGTH} from {TYPE_NAME} work in {SOURCE_NAME} using {MODEL_NAME}");
                    # Prepare the messages.
                    MESSAGES = {};
                    # Set the messages.
                    MESSAGES['User'] = [];
                    MESSAGES['Model'] = [];
                    MESSAGES['History'] = [];
                    # Add the order of the messages.
                    MESSAGES['User'].append(PROMPT);
                    if (TEST_CONFIG['CoT']):
                        MESSAGES['User'].append(IMPORTED['Prompts']['judge']);
                    MESSAGES['User'].append(IMPORTED['Prompts']['answer format']);
                    # Send each user message in turn, keeping the running history.
                    for MESSAGE in (MESSAGES['User']):
                        MESSAGES['History'].append({'role': 'user', 'content': MESSAGE});
                        MESSAGE_LAST = ((ollama.chat(model=MODEL_ID, messages=MESSAGES['History']))['message']['content']).strip("\t\n").strip();
                        MESSAGES['Model'].append(MESSAGE_LAST);
                        MESSAGES['History'].append({'role': 'assistant', 'content': MESSAGE_LAST});
                    # Associate with the correct LLM model.
                    del MESSAGES['User'];
                    del MESSAGES['History'];
                    MODEL_RESPONSES[MODEL_NAME] = MESSAGES['Model'];
                # Append the messages.
                RESPONSES[TYPE_NAME][SOURCE_NAME].append(MODEL_RESPONSES);
                # Cache the responses.
                create_cache(include=['responses']);
                # Update the status.
                print(f"\033[FAnswered prompt {PROMPT_NUMBER + 1} of {TARGET_LENGTH} from {TYPE_NAME} work in {SOURCE_NAME}.");
            print(f"\033[FFinished answering all {TARGET_LENGTH} prompts consisting of {TYPE_NAME} work in {SOURCE_NAME}.");
    return (RESPONSES);
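# The resulting RESPONSES structure is nested as (illustrative sketch):
#   RESPONSES[<author type>][<source>] = [{<model name>: [<reply>, …]}, …],
# with one dictionary per prompt and one reply per user message sent to the model.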
'''
Generate a cache.
Parameters:
    exclude (list): items to leave out of the cache
    include (list): the only items to write to the cache
'''
def create_cache(**params):
    def should_cache(ITEM):
        # Cache everything by default; an 'exclude' list takes precedence over an 'include' list.
        if ('exclude' in list(params.keys())):
            return (not(ITEM in params['exclude']));
        elif ('include' in list(params.keys())):
            return (ITEM in params['include']);
        return True;
    if (should_cache('test config')):
        save_data(dictionary=TEST_CONFIG, filename='tests/cache/config.JSON');
    if (should_cache('prompts')):
        save_data(dictionary=PROMPTS, filename='tests/cache/prompts.JSON');
    if (should_cache('responses')):
        save_data(dictionary=RESPONSES, filename="tests/cache/responses.JSON");
'''
Export the responses.
'''
def save_responses():
    # Record the test configuration and the export time in the output file name.
    SUFFIX = f"{' multi-shot' if TEST_CONFIG.get('multi-shot') else ''}{' CoT' if TEST_CONFIG.get('CoT') else ''}";
    save_data(dictionary=RESPONSES, filename=f"tests/outputs/responses{SUFFIX} {str(datetime.datetime.now().time())}.JSON");
    if (os.path.isfile("tests/cache/responses.JSON")):
        os.remove("tests/cache/responses.JSON");
'''
Save the data as JSON.
Parameters:
    dictionary (dict): the data to save
    filename (str): the file name
'''
def save_data(**parameters):
    if (parameters['filename'].strip()):
        with open(parameters['filename'], 'w') as file:
            # print(f"Saving {parameters['filename']}…");
            json.dump(parameters['dictionary'], file);
# Run the code.
def main():
    select_testing_type();
    download_models();
    format_questions();
    asking_execution();

if (__name__ == "__main__"):
    main();