redesign of endpoint and task classes

This commit is contained in:
Michael Peter Christen
2025-05-18 22:30:50 +02:00
parent b080928020
commit c049b908dd
3 changed files with 128 additions and 121 deletions

View File

@@ -2,35 +2,29 @@ import os
import json
import time
import base64
from typing import List
from argparse import ArgumentParser
from benchmark import read_benchmark
from ollama_client import ollama_list, ollama_chat_endpoints, test_multimodal, LoadBalancer, Server, Task, Response
from ollama_client import ollama_list, ollama_chat_endpoints, test_multimodal, Endpoint, LoadBalancer, Server, Task, Response
def read_template(template_path):
with open(template_path, 'r', encoding='utf-8') as file:
return file.read()
def process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number=9999,
def process_problem_files(problems_dir, template_content, endpoints: List[Endpoint], language, max_problem_number=9999,
overwrite_existing=False, overwrite_failed=False, expected_solutions={},
think=False, no_think=False):
model_name = endpoint["name"]
model_store_name = endpoints[0].store_name
model_api_name = endpoints[0].api_name
if think: model_name += "-think"
if no_think: model_name += "-no_think"
solutions_dir = os.path.join('solutions', model_name, language)
solutions_dir = os.path.join('solutions', model_store_name, language)
os.makedirs(solutions_dir, exist_ok=True)
# get the results as computed so far (may be none at first run)
results_json_path = os.path.join(solutions_dir, 'results.json') # may not exist yet
solutions = {}
if os.path.exists(results_json_path):
with open(results_json_path, 'r', encoding='utf-8') as json_file:
solutions = json.load(json_file)
# Create load balancer with all available endpoints
server_urls = endpoint["endpoints"]
servers = [Server(endpoint=url) for i, url in enumerate(server_urls)]
lb = LoadBalancer(servers)
servers = [Server(endpoint=endpoint) for endpoint in endpoints]
lb = LoadBalancer()
for server in servers: lb.add_server(server)
lb.start_distribution()
# iterate over all problem files and process them
@@ -44,14 +38,6 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
if not overwrite_existing and not overwrite_failed and os.path.exists(result_file_path):
print(f"Skipping problem {problem_number} as it already has a solution.")
continue
# check if the problem is already solved
actual_solution = expected_solutions.get(problem_number, {}).get('solution', None)
problem_is_solved = problem_number in solutions and solutions[problem_number] == actual_solution
if overwrite_failed and problem_is_solved:
print(f"Skipping problem {problem_number} as it is already solved and overwrite_failed is set.")
continue
# read problem content
with open(problem_path, 'r', encoding='utf-8') as file:
@@ -69,7 +55,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
# check if the endpoint is multimodal if we have an image
if base64_image:
is_multimodal = test_multimodal(endpoint) # this is cached
is_multimodal = test_multimodal(endpoints[0]) # this is cached
if is_multimodal:
print(f"Problem {problem_number} is handled with multimodal model.")
else:
@@ -92,9 +78,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
# Create task and add to load balancer
task = Task(
id = problem_number,
description = f"problem {problem_number}, language {language}, model {endpoint.get("name", model_name)}",
model_name = model_name, # the model storage name
model = endpoint.get("name", model_name), # the actual model name
description = f"problem {problem_number}, language {language}, model {model_api_name}",
prompt = prompt,
base64_image = base64_image,
response_processing = save_solution
@@ -102,17 +86,13 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
while not lb.add_task(task):
print(f"Waiting to add task {problem_number} - queue full")
time.sleep(1)
print(f"Added problem {problem_number}, language {language}, model {model_name} to processing queue")
print(f"Added problem {problem_number}, language {language}, model {model_api_name} to processing queue")
# Wait for all tasks to complete
print("Waiting for all problems to be processed...")
lb.wait_completion()
print("All problems processed!")
# Save solutions to JSON
with open(results_json_path, 'w', encoding='utf-8') as json_file:
json.dump(lb.results, json_file, indent=2)
def main():
parser = ArgumentParser(description="Process Euler problems and send them to an LLM.")

View File

@@ -13,6 +13,22 @@ from dataclasses import dataclass
from argparse import ArgumentParser
from typing import Callable, List, Optional
@dataclass
class Endpoint:
"""
Dataclass for endpoint.
This dataclass is used to represent an endpoint for the API.
Each endpoint has a name, a model name, a key, and a list of URLs.
"""
store_name: str # Name of the endpoint
api_name: str # Model name that is used in the api request
key: str # API key (if required)
url: str # URL of the endpoint
def get_ollama_ls_url(self) -> str:
"""Get the URL for the ollama ls command"""
return urllib3.util.url.parse_url(self.url)._replace(path='/api/tags').url
def ollama_pull(api_base='http://localhost:11434', model='llama3.2:latest') -> bool:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
response = requests.request("POST", f"{api_base}/api/pull", verify=False,
@@ -43,7 +59,7 @@ def ollama_list(api_base='http://localhost:11434') -> dict:
for entry in data['models']
}
def ollama_chat_endpoints(api_base='http://localhost:11434', model_name='llama3.2:latest') -> dict:
def ollama_chat_endpoints(api_base='http://localhost:11434', model_name='llama3.2:latest') -> List[Endpoint]:
if isinstance(api_base, str): api_base = [api_base]
# check if the endpoint servers are online and the model is available
@@ -57,24 +73,21 @@ def ollama_chat_endpoints(api_base='http://localhost:11434', model_name='llama3.
ollama_pull(api_stub, model_name)
print(f"Model {model_name} is now available on server {api_stub}.")
except Exception as e:
# the server is not available
# remove the server from the list
# the server is not available, remove the server from the list
print(f"Server {api_stub} is not available: {e}")
api_base.remove(api_stub)
# return the endpoint object with the model name
return {
"name": model_name,
"model": model_name,
"key": "",
"endpoints": [f"{api_stub}/v1/chat/completions" for api_stub in api_base],
}
# create the endpoint list
return [
Endpoint(store_name=model_name, api_name=model_name, key="",
url=f"{api_stub}/v1/chat/completions") for api_stub in api_base
]
def hex2base64(hex_string) -> str:
return base64.b64encode(bytes.fromhex(hex_string)).decode('utf-8')
def ollama_chat(
endpoint,
endpoint: Endpoint,
prompt: str = 'Hello World',
base64_image: str = None,
temperature: float = 0.0,
@@ -105,10 +118,10 @@ def ollama_chat(
'Content-Type': 'application/json',
'Accept': 'application/json'
}
if endpoint.get("key", ""):
headers['Authorization'] = 'Bearer ' + endpoint["key"]
if endpoint.key:
headers['Authorization'] = 'Bearer ' + endpoint.key
modelname = endpoint["model"]
modelname = endpoint.api_name
messages = []
messages.append({"role": "system", "content": "You are a helpful assistant"})
@@ -169,58 +182,48 @@ def ollama_chat(
payload["max_tokens"] = max_tokens
# use the endpoints array as failover mechanism
endpoint_url_list = endpoint["endpoints"]
failure_exception = Exception("No endpoint available or failure without exception")
for failover_count in range(len(endpoint_url_list)):
try:
#print(payload)
t0 = time.time()
response = requests.post(endpoint["endpoints"][failover_count], headers=headers, json=payload, verify=False)
t1 = time.time()
#print(response)
response.raise_for_status()
except requests.exceptions.RequestException as e:
# print(f"Failed to access api: {e}")
# Get the error message from the response
if response:
try:
#print(response.text)
data = response.json()
message = data.get('message', {})
content = message.get('content', '')
failure_exception = Exception(f"API request failed: {content}")
continue
except json.JSONDecodeError:
failure_exception = Exception(f"API request failed: {e}")
continue
try:
#print(payload)
t0 = time.time()
response = requests.post(endpoint.url, headers=headers, json=payload, verify=False)
t1 = time.time()
#print(response)
response.raise_for_status()
except requests.exceptions.RequestException as e:
# print(f"Failed to access api: {e}")
# Get the error message from the response
if response:
try:
#print(response.text)
data = response.json()
message = data.get('message', {})
content = message.get('content', '')
raise Exception(f"API request failed: {content}")
except json.JSONDecodeError:
raise Exception(f"API request failed: {e}")
# Parse the response
try:
#print(response.text)
data = response.json()
usage = data.get('usage', {})
total_tokens = usage.get('total_tokens', 0)
token_per_second = total_tokens / (t1 - t0)
#print(f"Total tokens: {total_tokens}, tokens per second: {token_per_second:.2f}")
choices = data.get('choices', [])
if len(choices) == 0:
failure_exception = Exception("No response from the API: " + str(data))
continue
message = choices[0].get('message', {})
answer = message.get('content', '')
return answer, total_tokens, token_per_second
except json.JSONDecodeError as e:
failure_exception = Exception(f"Failed to parse JSON response from the API: {e}")
continue
# If we reach here, all endpoints failed
raise failure_exception
# Parse the response
try:
#print(response.text)
data = response.json()
usage = data.get('usage', {})
total_tokens = usage.get('total_tokens', 0)
token_per_second = total_tokens / (t1 - t0)
#print(f"Total tokens: {total_tokens}, tokens per second: {token_per_second:.2f}")
choices = data.get('choices', [])
if len(choices) == 0:
raise Exception("No response from the API: " + str(data))
message = choices[0].get('message', {})
answer = message.get('content', '')
return answer, total_tokens, token_per_second
except json.JSONDecodeError as e:
raise Exception(f"Failed to parse JSON response from the API: {e}")
multimodal_cache = {}
def test_multimodal(endpoints) -> bool:
modelname = endpoints["model"]
cached_result = multimodal_cache.get(modelname, None)
def test_multimodal(endpoint: Endpoint) -> bool:
cached_result = multimodal_cache.get(endpoint.api_name, None)
if cached_result is not None:
return cached_result
@@ -228,11 +231,11 @@ def test_multimodal(endpoints) -> bool:
with open(image_path, "rb") as image_file:
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
try:
answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt="what is in the image", base64_image=base64_image)
answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=base64_image)
result = "42" in answer
if result:
print(f"Model {modelname} is multimodal.")
multimodal_cache[modelname] = result
print(f"Model {endpoint.api_name} is multimodal.")
multimodal_cache[endpoint.api_name] = result
return result
except Exception as e:
return False
@@ -248,12 +251,10 @@ class Task:
This dataclass is used to represent a task that needs to be processed by a server.
Each task has an ID and a dictionary of data that contains the actual task data.
"""
id: str # Unique identifier for the task
description: str # a short description of the task (for logging)
model_name: str # the model storage name
model: str # the actual model name
prompt: str # the prompt to be sent to the model
base64_image: str # the base64 encoded image to be sent to the model
id: str # Unique identifier for the task
description: str # a short description of the task (for logging)
prompt: str # the prompt to be sent to the model
base64_image: str # the base64 encoded image to be sent to the model
response_processing: Callable[['Response'], None] # a function to process the result
@dataclass
@@ -264,8 +265,8 @@ class Response:
Each response has a task ID and the result of the task processing.
"""
task: Task
result: str # The result of the task processing
total_tokens: int # Total tokens used in the response
result: str # The result of the task processing
total_tokens: int # Total tokens used in the response
token_per_second: float # Tokens per second used in the response
@dataclass
@@ -276,7 +277,7 @@ class Server:
Each server has an ID, an endpoint (URL), and a status that indicates whether the server is available or busy.
The status is used to track the current task being processed by the server.
"""
endpoint: str # The endpoint (URL) of the server
endpoint: Endpoint # The endpoint of the server
current_task: Task = None # Track the current task being processed
class LoadBalancer:
@@ -290,16 +291,18 @@ class LoadBalancer:
- It will also retry failed tasks after a short delay.
- The status of each server is updated as tasks are assigned and completed.
"""
def __init__(self, servers: List[Server], max_queue_size: int = 1000):
self.servers = servers
def __init__(self, max_queue_size: int = 1000):
self.servers = []
self.task_queue = queue.Queue[Task](maxsize=max_queue_size)
self.available_servers = queue.Queue[Server](maxsize=len(servers))
self.available_servers = queue.Queue[Server]()
self.lock = threading.Lock()
# Initialize available servers queue
for server in servers:
self.available_servers.put(server)
def add_server(self, server: Server):
"""Add a server to the load balancer"""
self.servers.append(server)
self.available_servers.put(server)
print(f"Server {server.endpoint} added to load balancer.")
def add_task(self, task: Task):
"""Add a task to the processing queue with backpressure"""
try:
@@ -318,6 +321,9 @@ class LoadBalancer:
def get_available_server(self, timeout: float = 10.0) -> Optional[Server]:
"""Get the next available server with timeout"""
try:
# Remove and return the next available server from the queue.
# The only way the server gets available is when the task is finished
# and the task assignes its server back to the available_servers.
return self.available_servers.get(timeout=timeout)
except queue.Empty:
return None
@@ -336,11 +342,12 @@ class LoadBalancer:
def process_task_remote(self, server: Server):
"""Process task on remote server"""
task = server.current_task
endpoint = server.endpoint
try:
#print(f"Processing task ID {task.id} on server {server.endpoint} with model {task.model}")
t0 = time.time()
answer, total_tokens, token_per_second = ollama_chat(
{"name": task.model_name, "model": task.model, "key": "", "endpoints": [server.endpoint]},
endpoint,
task.prompt,
base64_image=task.base64_image
)
@@ -348,7 +355,7 @@ class LoadBalancer:
# Call the response processing function
response = Response(task, answer, total_tokens, token_per_second)
task.response_processing(response)
print(f"Processed {task.description}, on {server.endpoint} with model {task.model} in {t1 - t0:.2f} seconds with {total_tokens} tokens ({token_per_second:.2f} tokens/sec)")
print(f"Processed {task.description}, on {server.endpoint.url} with model {endpoint.api_name} in {t1 - t0:.2f} seconds with {total_tokens} tokens ({token_per_second:.2f} tokens/sec)")
# mark server available
self.mark_server_available(server)
@@ -420,7 +427,7 @@ def main():
image_path = args.image
# load the endpoint file
endpoints = {}
endpoints:List[Endpoint] = []
if endpoint_name:
print(f"Using endpoint {endpoint_name}")
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
@@ -428,7 +435,15 @@ def main():
if not os.path.exists(endpoint_path):
raise Exception(f"Endpoint file {endpoint_path} does not exist.")
with open(endpoint_path, 'r', encoding='utf-8') as file:
endpoints = json.load(file)
endpoint_dict = json.load(file)
endpoints = [
Endpoint(
name=endpoint_dict["name"],
model=endpoint_dict["model"],
key=endpoint_dict["key"],
url=endpoint_dict["endpoint"]
)
]
else:
endpoints = ollama_chat_endpoints(api_base, model_name)

20
test.py
View File

@@ -100,10 +100,22 @@ def main():
entry = benchmark.get(model_benchmark_name, {})
# check if attributes parameter_size and quantization_level are present in benchmark.json
if not '_parameter_size' in entry and model_dict.get(model,{}).get('parameter_size', None):
entry['_parameter_size'] = model_dict.get(model,{}).get('parameter_size', None)
if not '_quantization_level' in entry and model_dict.get(model,{}).get('quantization_level', None):
entry['_quantization_level'] = model_dict.get(model,{}).get('quantization_level', None)
parameter_size = model_dict.get(model,{}).get('parameter_size', None)
if parameter_size:
try:
parameter_size = float(parameter_size)
except ValueError:
print(f"Warning: Could not convert parameter_size '{parameter_size}' to float for model {model}")
parameter_size = None
quantization_level = model_dict.get(model,{}).get('parameter_size', None)
if quantization_level:
try:
quantization_level = int(quantization_level)
except ValueError:
print(f"Warning: Could not convert quantization_level '{quantization_level}' to int for model {model}")
quantization_level = None
if not '_parameter_size' in entry and parameter_size: entry['_parameter_size'] = parameter_size
if not '_quantization_level' in entry and quantization_level: entry['_quantization_level'] = quantization_level
entry = dict(sorted(entry.items(), key=lambda item: item[0]))
benchmark[model_benchmark_name] = entry