redesign of endpoint and task classes
This commit is contained in:
44
inference.py
44
inference.py
@@ -2,35 +2,29 @@ import os
|
||||
import json
|
||||
import time
|
||||
import base64
|
||||
from typing import List
|
||||
from argparse import ArgumentParser
|
||||
from benchmark import read_benchmark
|
||||
from ollama_client import ollama_list, ollama_chat_endpoints, test_multimodal, LoadBalancer, Server, Task, Response
|
||||
from ollama_client import ollama_list, ollama_chat_endpoints, test_multimodal, Endpoint, LoadBalancer, Server, Task, Response
|
||||
|
||||
def read_template(template_path):
    """Return the complete text of the template file, decoded as UTF-8.

    Args:
        template_path: Path of the template file to read.

    Returns:
        The file content as a single string.
    """
    with open(template_path, mode='r', encoding='utf-8') as template_file:
        content = template_file.read()
    return content
|
||||
|
||||
def process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number=9999,
|
||||
def process_problem_files(problems_dir, template_content, endpoints: List[Endpoint], language, max_problem_number=9999,
|
||||
overwrite_existing=False, overwrite_failed=False, expected_solutions={},
|
||||
think=False, no_think=False):
|
||||
model_name = endpoint["name"]
|
||||
model_store_name = endpoints[0].store_name
|
||||
model_api_name = endpoints[0].api_name
|
||||
if think: model_name += "-think"
|
||||
if no_think: model_name += "-no_think"
|
||||
solutions_dir = os.path.join('solutions', model_name, language)
|
||||
solutions_dir = os.path.join('solutions', model_store_name, language)
|
||||
os.makedirs(solutions_dir, exist_ok=True)
|
||||
|
||||
# get the results as computed so far (may be none at first run)
|
||||
results_json_path = os.path.join(solutions_dir, 'results.json') # may not exist yet
|
||||
solutions = {}
|
||||
if os.path.exists(results_json_path):
|
||||
with open(results_json_path, 'r', encoding='utf-8') as json_file:
|
||||
solutions = json.load(json_file)
|
||||
|
||||
# Create load balancer with all available endpoints
|
||||
server_urls = endpoint["endpoints"]
|
||||
servers = [Server(endpoint=url) for i, url in enumerate(server_urls)]
|
||||
|
||||
lb = LoadBalancer(servers)
|
||||
servers = [Server(endpoint=endpoint) for endpoint in endpoints]
|
||||
lb = LoadBalancer()
|
||||
for server in servers: lb.add_server(server)
|
||||
lb.start_distribution()
|
||||
|
||||
# iterate over all problem files and process them
|
||||
@@ -44,14 +38,6 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
|
||||
if not overwrite_existing and not overwrite_failed and os.path.exists(result_file_path):
|
||||
print(f"Skipping problem {problem_number} as it already has a solution.")
|
||||
continue
|
||||
|
||||
# check if the problem is already solved
|
||||
actual_solution = expected_solutions.get(problem_number, {}).get('solution', None)
|
||||
problem_is_solved = problem_number in solutions and solutions[problem_number] == actual_solution
|
||||
|
||||
if overwrite_failed and problem_is_solved:
|
||||
print(f"Skipping problem {problem_number} as it is already solved and overwrite_failed is set.")
|
||||
continue
|
||||
|
||||
# read problem content
|
||||
with open(problem_path, 'r', encoding='utf-8') as file:
|
||||
@@ -69,7 +55,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
|
||||
|
||||
# check if the endpoint is multimodal if we have an image
|
||||
if base64_image:
|
||||
is_multimodal = test_multimodal(endpoint) # this is cached
|
||||
is_multimodal = test_multimodal(endpoints[0]) # this is cached
|
||||
if is_multimodal:
|
||||
print(f"Problem {problem_number} is handled with multimodal model.")
|
||||
else:
|
||||
@@ -92,9 +78,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
|
||||
# Create task and add to load balancer
|
||||
task = Task(
|
||||
id = problem_number,
|
||||
description = f"problem {problem_number}, language {language}, model {endpoint.get("name", model_name)}",
|
||||
model_name = model_name, # the model storage name
|
||||
model = endpoint.get("name", model_name), # the actual model name
|
||||
description = f"problem {problem_number}, language {language}, model {model_api_name}",
|
||||
prompt = prompt,
|
||||
base64_image = base64_image,
|
||||
response_processing = save_solution
|
||||
@@ -102,17 +86,13 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
|
||||
while not lb.add_task(task):
|
||||
print(f"Waiting to add task {problem_number} - queue full")
|
||||
time.sleep(1)
|
||||
print(f"Added problem {problem_number}, language {language}, model {model_name} to processing queue")
|
||||
print(f"Added problem {problem_number}, language {language}, model {model_api_name} to processing queue")
|
||||
|
||||
# Wait for all tasks to complete
|
||||
print("Waiting for all problems to be processed...")
|
||||
lb.wait_completion()
|
||||
print("All problems processed!")
|
||||
|
||||
# Save solutions to JSON
|
||||
with open(results_json_path, 'w', encoding='utf-8') as json_file:
|
||||
json.dump(lb.results, json_file, indent=2)
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser(description="Process Euler problems and send them to an LLM.")
|
||||
|
||||
185
ollama_client.py
185
ollama_client.py
@@ -13,6 +13,22 @@ from dataclasses import dataclass
|
||||
from argparse import ArgumentParser
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
@dataclass
class Endpoint:
    """
    Description of a single API endpoint.

    An endpoint bundles the name under which the model is stored, the model
    name sent in API requests, an API key and exactly one URL.
    """
    store_name: str  # Name under which the model (and its results) is stored
    api_name: str    # Model name that is used in the API request
    key: str         # API key (empty if none is required)
    url: str         # URL of the endpoint

    def get_ollama_ls_url(self) -> str:
        """Return the URL for the ollama ls command.

        The result is this endpoint's URL with its path replaced by
        '/api/tags'; all other URL components are kept as parsed.
        """
        parsed_url = urllib3.util.url.parse_url(self.url)
        tags_url = parsed_url._replace(path='/api/tags')
        return tags_url.url
|
||||
|
||||
def ollama_pull(api_base='http://localhost:11434', model='llama3.2:latest') -> bool:
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
response = requests.request("POST", f"{api_base}/api/pull", verify=False,
|
||||
@@ -43,7 +59,7 @@ def ollama_list(api_base='http://localhost:11434') -> dict:
|
||||
for entry in data['models']
|
||||
}
|
||||
|
||||
def ollama_chat_endpoints(api_base='http://localhost:11434', model_name='llama3.2:latest') -> dict:
|
||||
def ollama_chat_endpoints(api_base='http://localhost:11434', model_name='llama3.2:latest') -> List[Endpoint]:
|
||||
if isinstance(api_base, str): api_base = [api_base]
|
||||
|
||||
# check if the endpoint servers are online and the model is available
|
||||
@@ -57,24 +73,21 @@ def ollama_chat_endpoints(api_base='http://localhost:11434', model_name='llama3.
|
||||
ollama_pull(api_stub, model_name)
|
||||
print(f"Model {model_name} is now available on server {api_stub}.")
|
||||
except Exception as e:
|
||||
# the server is not available
|
||||
# remove the server from the list
|
||||
# the server is not available, remove the server from the list
|
||||
print(f"Server {api_stub} is not available: {e}")
|
||||
api_base.remove(api_stub)
|
||||
|
||||
# return the endpoint object with the model name
|
||||
return {
|
||||
"name": model_name,
|
||||
"model": model_name,
|
||||
"key": "",
|
||||
"endpoints": [f"{api_stub}/v1/chat/completions" for api_stub in api_base],
|
||||
}
|
||||
# create the endpoint list
|
||||
return [
|
||||
Endpoint(store_name=model_name, api_name=model_name, key="",
|
||||
url=f"{api_stub}/v1/chat/completions") for api_stub in api_base
|
||||
]
|
||||
|
||||
def hex2base64(hex_string) -> str:
    """Convert a hexadecimal string into the base64 text of the same bytes.

    Args:
        hex_string: String of hexadecimal digits, as accepted by ``bytes.fromhex``.

    Returns:
        The base64 encoding of the decoded bytes as a ``str``.
    """
    raw_bytes = bytes.fromhex(hex_string)
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
|
||||
|
||||
def ollama_chat(
|
||||
endpoint,
|
||||
endpoint: Endpoint,
|
||||
prompt: str = 'Hello World',
|
||||
base64_image: str = None,
|
||||
temperature: float = 0.0,
|
||||
@@ -105,10 +118,10 @@ def ollama_chat(
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
if endpoint.get("key", ""):
|
||||
headers['Authorization'] = 'Bearer ' + endpoint["key"]
|
||||
if endpoint.key:
|
||||
headers['Authorization'] = 'Bearer ' + endpoint.key
|
||||
|
||||
modelname = endpoint["model"]
|
||||
modelname = endpoint.api_name
|
||||
messages = []
|
||||
messages.append({"role": "system", "content": "You are a helpful assistant"})
|
||||
|
||||
@@ -169,58 +182,48 @@ def ollama_chat(
|
||||
payload["max_tokens"] = max_tokens
|
||||
|
||||
# use the endpoints array as failover mechanism
|
||||
endpoint_url_list = endpoint["endpoints"]
|
||||
failure_exception = Exception("No endpoint available or failure without exception")
|
||||
for failover_count in range(len(endpoint_url_list)):
|
||||
try:
|
||||
#print(payload)
|
||||
t0 = time.time()
|
||||
response = requests.post(endpoint["endpoints"][failover_count], headers=headers, json=payload, verify=False)
|
||||
t1 = time.time()
|
||||
#print(response)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.RequestException as e:
|
||||
# print(f"Failed to access api: {e}")
|
||||
# Get the error message from the response
|
||||
if response:
|
||||
try:
|
||||
#print(response.text)
|
||||
data = response.json()
|
||||
message = data.get('message', {})
|
||||
content = message.get('content', '')
|
||||
failure_exception = Exception(f"API request failed: {content}")
|
||||
continue
|
||||
except json.JSONDecodeError:
|
||||
failure_exception = Exception(f"API request failed: {e}")
|
||||
continue
|
||||
try:
|
||||
#print(payload)
|
||||
t0 = time.time()
|
||||
response = requests.post(endpoint.url, headers=headers, json=payload, verify=False)
|
||||
t1 = time.time()
|
||||
#print(response)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.RequestException as e:
|
||||
# print(f"Failed to access api: {e}")
|
||||
# Get the error message from the response
|
||||
if response:
|
||||
try:
|
||||
#print(response.text)
|
||||
data = response.json()
|
||||
message = data.get('message', {})
|
||||
content = message.get('content', '')
|
||||
raise Exception(f"API request failed: {content}")
|
||||
except json.JSONDecodeError:
|
||||
raise Exception(f"API request failed: {e}")
|
||||
|
||||
# Parse the response
|
||||
try:
|
||||
#print(response.text)
|
||||
data = response.json()
|
||||
usage = data.get('usage', {})
|
||||
total_tokens = usage.get('total_tokens', 0)
|
||||
token_per_second = total_tokens / (t1 - t0)
|
||||
#print(f"Total tokens: {total_tokens}, tokens per second: {token_per_second:.2f}")
|
||||
choices = data.get('choices', [])
|
||||
if len(choices) == 0:
|
||||
failure_exception = Exception("No response from the API: " + str(data))
|
||||
continue
|
||||
message = choices[0].get('message', {})
|
||||
answer = message.get('content', '')
|
||||
return answer, total_tokens, token_per_second
|
||||
except json.JSONDecodeError as e:
|
||||
failure_exception = Exception(f"Failed to parse JSON response from the API: {e}")
|
||||
continue
|
||||
|
||||
# If we reach here, all endpoints failed
|
||||
raise failure_exception
|
||||
# Parse the response
|
||||
try:
|
||||
#print(response.text)
|
||||
data = response.json()
|
||||
usage = data.get('usage', {})
|
||||
total_tokens = usage.get('total_tokens', 0)
|
||||
token_per_second = total_tokens / (t1 - t0)
|
||||
#print(f"Total tokens: {total_tokens}, tokens per second: {token_per_second:.2f}")
|
||||
choices = data.get('choices', [])
|
||||
if len(choices) == 0:
|
||||
raise Exception("No response from the API: " + str(data))
|
||||
message = choices[0].get('message', {})
|
||||
answer = message.get('content', '')
|
||||
return answer, total_tokens, token_per_second
|
||||
except json.JSONDecodeError as e:
|
||||
raise Exception(f"Failed to parse JSON response from the API: {e}")
|
||||
|
||||
multimodal_cache = {}
|
||||
|
||||
def test_multimodal(endpoints) -> bool:
|
||||
modelname = endpoints["model"]
|
||||
cached_result = multimodal_cache.get(modelname, None)
|
||||
def test_multimodal(endpoint: Endpoint) -> bool:
|
||||
|
||||
cached_result = multimodal_cache.get(endpoint.api_name, None)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
|
||||
@@ -228,11 +231,11 @@ def test_multimodal(endpoints) -> bool:
|
||||
with open(image_path, "rb") as image_file:
|
||||
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
|
||||
try:
|
||||
answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt="what is in the image", base64_image=base64_image)
|
||||
answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=base64_image)
|
||||
result = "42" in answer
|
||||
if result:
|
||||
print(f"Model {modelname} is multimodal.")
|
||||
multimodal_cache[modelname] = result
|
||||
print(f"Model {endpoint.api_name} is multimodal.")
|
||||
multimodal_cache[endpoint.api_name] = result
|
||||
return result
|
||||
except Exception as e:
|
||||
return False
|
||||
@@ -248,12 +251,10 @@ class Task:
|
||||
This dataclass is used to represent a task that needs to be processed by a server.
|
||||
Each task has an ID and a dictionary of data that contains the actual task data.
|
||||
"""
|
||||
id: str # Unique identifier for the task
|
||||
description: str # a short description of the task (for logging)
|
||||
model_name: str # the model storage name
|
||||
model: str # the actual model name
|
||||
prompt: str # the prompt to be sent to the model
|
||||
base64_image: str # the base64 encoded image to be sent to the model
|
||||
id: str # Unique identifier for the task
|
||||
description: str # a short description of the task (for logging)
|
||||
prompt: str # the prompt to be sent to the model
|
||||
base64_image: str # the base64 encoded image to be sent to the model
|
||||
response_processing: Callable[['Response'], None] # a function to process the result
|
||||
|
||||
@dataclass
|
||||
@@ -264,8 +265,8 @@ class Response:
|
||||
Each response has a task ID and the result of the task processing.
|
||||
"""
|
||||
task: Task
|
||||
result: str # The result of the task processing
|
||||
total_tokens: int # Total tokens used in the response
|
||||
result: str # The result of the task processing
|
||||
total_tokens: int # Total tokens used in the response
|
||||
token_per_second: float # Tokens per second used in the response
|
||||
|
||||
@dataclass
|
||||
@@ -276,7 +277,7 @@ class Server:
|
||||
Each server has an ID, an endpoint (URL), and a status that indicates whether the server is available or busy.
|
||||
The status is used to track the current task being processed by the server.
|
||||
"""
|
||||
endpoint: str # The endpoint (URL) of the server
|
||||
endpoint: Endpoint # The endpoint of the server
|
||||
current_task: Task = None # Track the current task being processed
|
||||
|
||||
class LoadBalancer:
|
||||
@@ -290,16 +291,18 @@ class LoadBalancer:
|
||||
- It will also retry failed tasks after a short delay.
|
||||
- The status of each server is updated as tasks are assigned and completed.
|
||||
"""
|
||||
def __init__(self, servers: List[Server], max_queue_size: int = 1000):
|
||||
self.servers = servers
|
||||
def __init__(self, max_queue_size: int = 1000):
|
||||
self.servers = []
|
||||
self.task_queue = queue.Queue[Task](maxsize=max_queue_size)
|
||||
self.available_servers = queue.Queue[Server](maxsize=len(servers))
|
||||
self.available_servers = queue.Queue[Server]()
|
||||
self.lock = threading.Lock()
|
||||
|
||||
# Initialize available servers queue
|
||||
for server in servers:
|
||||
self.available_servers.put(server)
|
||||
|
||||
def add_server(self, server: Server):
|
||||
"""Add a server to the load balancer"""
|
||||
self.servers.append(server)
|
||||
self.available_servers.put(server)
|
||||
print(f"Server {server.endpoint} added to load balancer.")
|
||||
|
||||
def add_task(self, task: Task):
|
||||
"""Add a task to the processing queue with backpressure"""
|
||||
try:
|
||||
@@ -318,6 +321,9 @@ class LoadBalancer:
|
||||
def get_available_server(self, timeout: float = 10.0) -> Optional[Server]:
|
||||
"""Get the next available server with timeout"""
|
||||
try:
|
||||
# Remove and return the next available server from the queue.
|
||||
# The only way the server gets available is when the task is finished
|
||||
# and the task assigns its server back to the available_servers.
|
||||
return self.available_servers.get(timeout=timeout)
|
||||
except queue.Empty:
|
||||
return None
|
||||
@@ -336,11 +342,12 @@ class LoadBalancer:
|
||||
def process_task_remote(self, server: Server):
|
||||
"""Process task on remote server"""
|
||||
task = server.current_task
|
||||
endpoint = server.endpoint
|
||||
try:
|
||||
#print(f"Processing task ID {task.id} on server {server.endpoint} with model {task.model}")
|
||||
t0 = time.time()
|
||||
answer, total_tokens, token_per_second = ollama_chat(
|
||||
{"name": task.model_name, "model": task.model, "key": "", "endpoints": [server.endpoint]},
|
||||
endpoint,
|
||||
task.prompt,
|
||||
base64_image=task.base64_image
|
||||
)
|
||||
@@ -348,7 +355,7 @@ class LoadBalancer:
|
||||
# Call the response processing function
|
||||
response = Response(task, answer, total_tokens, token_per_second)
|
||||
task.response_processing(response)
|
||||
print(f"Processed {task.description}, on {server.endpoint} with model {task.model} in {t1 - t0:.2f} seconds with {total_tokens} tokens ({token_per_second:.2f} tokens/sec)")
|
||||
print(f"Processed {task.description}, on {server.endpoint.url} with model {endpoint.api_name} in {t1 - t0:.2f} seconds with {total_tokens} tokens ({token_per_second:.2f} tokens/sec)")
|
||||
|
||||
# mark server available
|
||||
self.mark_server_available(server)
|
||||
@@ -420,7 +427,7 @@ def main():
|
||||
image_path = args.image
|
||||
|
||||
# load the endpoint file
|
||||
endpoints = {}
|
||||
endpoints:List[Endpoint] = []
|
||||
if endpoint_name:
|
||||
print(f"Using endpoint {endpoint_name}")
|
||||
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
|
||||
@@ -428,7 +435,15 @@ def main():
|
||||
if not os.path.exists(endpoint_path):
|
||||
raise Exception(f"Endpoint file {endpoint_path} does not exist.")
|
||||
with open(endpoint_path, 'r', encoding='utf-8') as file:
|
||||
endpoints = json.load(file)
|
||||
endpoint_dict = json.load(file)
|
||||
# Build the single Endpoint described by the endpoint JSON file.
# The Endpoint dataclass declares the fields store_name/api_name/key/url, so the
# JSON keys "name" and "model" must be passed as store_name and api_name;
# passing name=/model= raises a TypeError (unexpected keyword argument).
endpoints = [
    Endpoint(
        store_name=endpoint_dict["name"],   # storage name of the model
        api_name=endpoint_dict["model"],    # model name sent in the API request
        key=endpoint_dict["key"],
        url=endpoint_dict["endpoint"]
    )
]
|
||||
else:
|
||||
endpoints = ollama_chat_endpoints(api_base, model_name)
|
||||
|
||||
|
||||
20
test.py
20
test.py
@@ -100,10 +100,22 @@ def main():
|
||||
entry = benchmark.get(model_benchmark_name, {})
|
||||
|
||||
# check if attributes parameter_size and quantization_level are present in benchmark.json
|
||||
if not '_parameter_size' in entry and model_dict.get(model,{}).get('parameter_size', None):
|
||||
entry['_parameter_size'] = model_dict.get(model,{}).get('parameter_size', None)
|
||||
if not '_quantization_level' in entry and model_dict.get(model,{}).get('quantization_level', None):
|
||||
entry['_quantization_level'] = model_dict.get(model,{}).get('quantization_level', None)
|
||||
parameter_size = model_dict.get(model,{}).get('parameter_size', None)
|
||||
if parameter_size:
|
||||
try:
|
||||
parameter_size = float(parameter_size)
|
||||
except ValueError:
|
||||
print(f"Warning: Could not convert parameter_size '{parameter_size}' to float for model {model}")
|
||||
parameter_size = None
|
||||
# Read the quantization level of the model. The lookup key must be
# 'quantization_level'; reading 'parameter_size' here would store the
# parameter size under the '_quantization_level' attribute.
quantization_level = model_dict.get(model, {}).get('quantization_level', None)
if quantization_level:
    try:
        # only integer-convertible values are kept; anything else is dropped with a warning
        quantization_level = int(quantization_level)
    except ValueError:
        print(f"Warning: Could not convert quantization_level '{quantization_level}' to int for model {model}")
        quantization_level = None
|
||||
if not '_parameter_size' in entry and parameter_size: entry['_parameter_size'] = parameter_size
|
||||
if not '_quantization_level' in entry and quantization_level: entry['_quantization_level'] = quantization_level
|
||||
entry = dict(sorted(entry.items(), key=lambda item: item[0]))
|
||||
benchmark[model_benchmark_name] = entry
|
||||
|
||||
|
||||
Reference in New Issue
Block a user