471 lines
19 KiB
Python
471 lines
19 KiB
Python
import os
|
|
import json
|
|
import time
|
|
import queue
|
|
import base64
|
|
import urllib3
|
|
import requests
|
|
import threading
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
from enum import Enum, auto
|
|
from dataclasses import dataclass
|
|
from argparse import ArgumentParser
|
|
from typing import Callable, List, Optional
|
|
|
|
@dataclass
class Endpoint:
    """
    Connection details for one model served behind an HTTP API.

    An endpoint bundles a display name, the model identifier that is sent
    in API requests, an optional API key and the single URL of the service.
    """
    # Name of the endpoint
    store_name: str
    # Model name that is used in the api request
    api_name: str
    # API key (if required)
    key: str
    # URL of the endpoint
    url: str

    def get_ollama_url_stub(self) -> str:
        """Return ``self.url`` with its path component replaced by an empty string."""
        parsed_url = urllib3.util.url.parse_url(self.url)
        without_path = parsed_url._replace(path='')
        return without_path.url
|
|
|
|
def ollama_pull(endpoint: Endpoint) -> bool:
    """
    Ask the server behind ``endpoint`` to pull the endpoint's model.

    Sends a non-streaming POST to ``<base>/api/pull`` without TLS
    verification and raises for HTTP error statuses.

    Returns:
        True unless the JSON reply carries a truthy ``"error"`` entry.
    """
    api_base = endpoint.get_ollama_url_stub()
    # Certificates are not verified below, so silence the matching warning.
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
    request_headers = {'Accept': 'application/json', 'Content-Type': 'application/json'}
    request_body = {"model": endpoint.api_name, "stream": False}
    response = requests.request(
        "POST",
        f"{api_base}/api/pull",
        verify=False,
        headers=request_headers,
        json=request_body,
    )
    response.raise_for_status()
    reply = response.json()
    has_error = reply.get("error", False)
    return not has_error
|
|
|
|
def ollama_delete(endpoint: Endpoint) -> bool:
    """
    Ask the server behind ``endpoint`` to delete the endpoint's model.

    Sends a DELETE to ``<base>/api/delete`` without TLS verification.

    Returns:
        True when the server answers with HTTP status 200, else False.
    """
    api_base = endpoint.get_ollama_url_stub()
    # Certificates are not verified below, so silence the matching warning.
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
    request_headers = {'Accept': 'application/json', 'Content-Type': 'application/json'}
    request_body = {"model": endpoint.api_name}
    response = requests.request(
        "DELETE",
        f"{api_base}/api/delete",
        verify=False,
        headers=request_headers,
        json=request_body,
    )
    return response.status_code == 200
|
|
|
|
def ollama_list(endpoint: Endpoint) -> dict:
    """
    List the models reported by ``<base>/api/tags``.

    Returns:
        A dict keyed by each entry's ``model`` value. Every value holds
        ``parameter_size`` with its last character removed and
        ``quantization_level`` reduced to its second character
        (presumably turning values like "7B" / "Q4_0" into "7" / "4";
        confirm against the server's actual output).

    Raises:
        requests.exceptions.HTTPError: for HTTP error statuses.
    """
    api_base = endpoint.get_ollama_url_stub()
    # Certificates are not verified below, so silence the matching warning.
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
    response = requests.get(f"{api_base}/api/tags", verify=False)
    response.raise_for_status()
    listing = response.json()

    models = {}
    for entry in listing['models']:
        details = entry['details']
        models[entry['model']] = {
            'parameter_size': details['parameter_size'][:-1],
            'quantization_level': details['quantization_level'][1:2]
        }
    return models
|
|
|
|
def ollama_pull_endpoint(endpoint: Endpoint) -> Endpoint:
    """
    Make sure the endpoint's model is present on its server.

    The server's model list is queried first; when the model is missing it
    is pulled. Exceptions from the underlying HTTP calls are deliberately
    not caught here, because that shall be done in calling code.

    Args:
        endpoint: The endpoint whose model must be available.

    Returns:
        The same endpoint object, once the model is available.

    Raises:
        RuntimeError: If the pull request completed but reported an error.
    """
    # check if the endpoint server is online and the model is available
    available_models = ollama_list(endpoint)
    if endpoint.api_name in available_models:
        return endpoint

    # pull the model if it is not available
    api_base = endpoint.get_ollama_url_stub()
    print(f"Model {endpoint.api_name} is not available on server {api_base}. Pulling the model...")
    if not ollama_pull(endpoint):
        # ollama_pull() returns False when the reply carries an error entry;
        # do not claim the model is available in that case.
        raise RuntimeError(f"Pulling model {endpoint.api_name} on server {api_base} failed.")
    print(f"Model {endpoint.api_name} is now available on server {api_base}.")
    return endpoint
|
|
|
|
def hex2base64(hex_string) -> str:
    """Convert a hexadecimal string into the base64 text of the same bytes."""
    raw_bytes = bytes.fromhex(hex_string)
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
|
|
|
|
def _detect_image_type(base64_image: str) -> str:
    """Guess the image type from the first base64 characters; defaults to "jpeg"."""
    # Only jpeg and png are sent as data type; gif is detected so that it can
    # be converted to png before sending.
    base64_magic = {"/9j/": "jpeg", "iVBO": "png", "R0lG": "gif"}
    for magic, itype in base64_magic.items():
        if base64_image.startswith(magic):
            return itype
    return "jpeg"


def _gif_to_png_base64(base64_image: str) -> str:
    """Re-encode a base64 GIF image as a base64 PNG image."""
    image = Image.open(BytesIO(base64.b64decode(base64_image)))
    png_image = BytesIO()
    image.save(png_image, format="PNG")
    return base64.b64encode(png_image.getvalue()).decode('utf-8')


def _extract_error_detail(data) -> str:
    """Pull a human-readable error text out of a decoded JSON error body, or return ''."""
    if not isinstance(data, dict):
        return ""
    # An "error" entry may be a plain string or an object with a "message".
    error = data.get("error")
    if isinstance(error, dict):
        return str(error.get("message", "") or "")
    if error:
        return str(error)
    # Fall back to the {"message": {"content": ...}} layout.
    message = data.get("message")
    if isinstance(message, dict):
        return str(message.get("content", "") or "")
    if message:
        return str(message)
    return ""


def ollama_chat(
    endpoint: Endpoint,
    prompt: str = 'Hello World',
    base64_image: Optional[str] = None,
    temperature: float = 0.0,
    max_tokens: int = 8192
) -> tuple[str, int, float]:
    """
    Send one chat-completion request to the endpoint and return its answer.

    Args:
        endpoint (Endpoint): Endpoint providing URL, model name and optional API key.
        prompt (str): The prompt to send to the model.
        base64_image (str): Base64 encoded image string (optional). GIF images
            are converted to PNG before sending.
        temperature (float): Temperature for randomness in response. Overridden
            for model names starting with "o1", "gpt-o1" or "qwen3".
        max_tokens (int): Maximum number of tokens for the response.

    Returns:
        tuple: The model's response text, total tokens used, and tokens per
        second (total tokens divided by the request's wall-clock time).

    Raises:
        Exception: If the request fails (connection error or HTTP error
            status), if the response is not valid JSON, or if it contains
            no choices.
    """

    # Disable SSL warnings (the request below is sent with verify=False)
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    # Default stop tokens; reduced or removed for some model families below
    stoptokens = ["[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>", "<|end_header_id|>", "<EOS_TOKEN>", "</s>", "<|end|>"]

    # Set headers
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json'
    }
    if endpoint.key:
        headers['Authorization'] = 'Bearer ' + endpoint.key

    modelname = endpoint.api_name
    messages = [{"role": "system", "content": "You are a helpful assistant"}]

    # special requirements of certain models
    if modelname.startswith(("o1", "gpt-o1")):
        temperature = 1.0
    if modelname.startswith("qwen3"):
        temperature = 0.6

    if modelname.startswith(("4o", "gpt-4o", "gpt-3.5")):
        # reduce number of stop tokens to 4
        stoptokens = ["[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>"]

    if base64_image:
        image_type = _detect_image_type(base64_image)
        # If this is a gif we must convert it to png
        if image_type == "gif":
            base64_image = _gif_to_png_base64(base64_image)
            image_type = "png"

        # Add the image to the message
        image_url_object = {"url": f"data:image/{image_type};base64,{base64_image}"}
        usermessage = {"role": "user", "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": image_url_object}
        ]}
    else:
        usermessage = {"role": "user", "content": prompt}
    messages.append(usermessage)

    if modelname.startswith(("o1", "4o")):
        stoptokens = []

    payload = {
        "model": modelname,
        "messages": messages,
        "response_format": { "type": "text" },
        "temperature": temperature, # ollama default: 0.8
        "top_k": 20, # reduces the probability of generating nonsense: high = more diverse, low = more focused; ollama default: 40
        "top_p": 0.95, # works together with top_k: high = more diverse, low = more focused; ollama default: 0.9
        "min_p": 0, # alternative to top_p: p is minimum probability for a token to be considered; ollama default: 0.0
        "stream": False
    }
    if stoptokens and not modelname.startswith("o4"):
        payload["stop"] = stoptokens
    if modelname.startswith(("o1", "o4")):
        payload["max_completion_tokens"] = max_tokens
    else:
        payload["max_tokens"] = max_tokens

    # Send the request. `response` stays None when the request itself fails
    # (e.g. connection refused), so the error path can tell the cases apart.
    response = None
    try:
        t0 = time.time()
        response = requests.post(endpoint.url, headers=headers, json=payload, verify=False)
        t1 = time.time()
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        detail = ""
        # Compare against None: a Response with an error status is falsy, so
        # a plain truth test would skip exactly the responses we care about.
        if response is not None:
            try:
                detail = _extract_error_detail(response.json())
            except ValueError:
                # Error body is not JSON; fall back to the exception text.
                detail = ""
        raise Exception(f"API request failed: {detail or e}") from e

    # Parse the response
    try:
        data = response.json()
    except ValueError as e:
        raise Exception(f"Failed to parse JSON response from the API: {e}") from e

    usage = data.get('usage') or {}
    total_tokens = usage.get('total_tokens', 0)
    elapsed = t1 - t0
    # Guard against a zero-length interval on coarse clocks.
    token_per_second = total_tokens / elapsed if elapsed > 0 else 0.0

    choices = data.get('choices') or []
    if not choices:
        raise Exception("No response from the API: " + str(data))
    message = choices[0].get('message', {})
    answer = message.get('content', '')
    return answer, total_tokens, token_per_second
|
|
|
|
# Cache of test_multimodal() results, keyed by the endpoint's api_name.
multimodal_cache = {}
|
|
|
|
def test_multimodal(endpoint: Endpoint) -> bool:
    """
    Probe whether the endpoint's model can read images.

    The image ``llmtest/testimage.png`` is sent with the question
    "what is in the image"; the model counts as multimodal when its answer
    contains "42". Results are cached per ``api_name`` in
    ``multimodal_cache``. Any exception raised while chatting yields False
    and is not cached; errors while reading the image file propagate.
    """
    known_result = multimodal_cache.get(endpoint.api_name, None)
    if known_result is not None:
        return known_result

    with open("llmtest/testimage.png", "rb") as handle:
        encoded_image = base64.b64encode(handle.read()).decode('utf-8')

    try:
        reply, _total_tokens, _token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=encoded_image)
        is_multimodal = "42" in reply
        if is_multimodal:
            print(f"Model {endpoint.api_name} is multimodal.")
        multimodal_cache[endpoint.api_name] = is_multimodal
        return is_multimodal
    except Exception:
        return False
|
|
|
|
|
|
# Sleep interval used by LoadBalancer between polls while waiting for servers.
busy_waiting_time = 1 # seconds
|
|
|
|
|
|
@dataclass
class Task:
    """
    One unit of work to be processed by a server.

    A task carries an identifier, a short description for log output, the
    prompt and image to send to the model, and a callback that receives the
    resulting Response.
    """
    # Unique identifier for the task
    id: str
    # a short description of the task (for logging)
    description: str
    # the prompt to be sent to the model
    prompt: str
    # the base64 encoded image to be sent to the model
    base64_image: str
    # a function to process the result
    response_processing: Callable[['Response'], None]
|
|
|
|
@dataclass
class Response:
    """
    Result of processing one task.

    A response keeps the task it belongs to together with the model's
    answer and the token statistics of the request.
    """
    # The task this response belongs to
    task: Task
    # The result of the task processing
    result: str
    # Total tokens used in the response
    total_tokens: int
    # Tokens per second used in the response
    token_per_second: float
|
|
|
|
@dataclass
class Server:
    """
    A server that can process tasks.

    A server is described by its endpoint and the task it is currently
    working on; ``current_task`` is None while no task is assigned.
    """
    # The endpoint of the server
    endpoint: Endpoint
    # Track the current task being processed
    current_task: Optional[Task] = None
|
|
|
|
class LoadBalancer:
    """
    LoadBalancer class for managing task distribution across multiple servers.

    This class is responsible for distributing tasks to available servers and managing their status.
    - It uses a queue to manage tasks and a queue of idle servers to distribute the load.
    - It will only assign tasks to servers that are available.
    - It implements backpressure: add_task() gives up when the task queue stays full.
    - It will wait for a server to become available before assigning a new task.
    - Failed tasks are logged; they are not retried.
    - Each server runs one task at a time in its own daemon thread and is
      returned to the pool when that task finished or failed.
    """

    def __init__(self, max_queue_size: int = 1000):
        """Create an empty load balancer whose task queue holds at most max_queue_size tasks."""
        self.servers: "List[Server]" = []
        self.task_queue: "queue.Queue[Task]" = queue.Queue(maxsize=max_queue_size)
        self.available_servers: "queue.Queue[Server]" = queue.Queue()
        self.lock = threading.Lock()

    def add_server(self, server: "Server"):
        """Add a server to the load balancer"""
        self.servers.append(server)
        self.available_servers.put(server)
        print(f"Server {server.endpoint.get_ollama_url_stub()} added to load balancer.")

    def add_task(self, task: "Task"):
        """
        Add a task to the processing queue with backpressure.

        Returns:
            True when the task was queued, False when the queue was still
            full after waiting one second.
        """
        try:
            self.task_queue.put(task, block=True, timeout=1)
            return True
        except queue.Full:
            print("Task queue full - applying backpressure")
            return False

    def mark_server_available(self, server: "Server"):
        """Mark a server as available for new tasks"""
        with self.lock:
            server.current_task = None
            self.available_servers.put(server)

    def get_available_server(self, timeout: float = 10.0) -> Optional["Server"]:
        """Get the next available server, or None when none frees up within timeout seconds"""
        try:
            # Remove and return the next available server from the queue.
            # The only way the server gets available is when the task is finished
            # and the task assigns its server back to the available_servers.
            return self.available_servers.get(timeout=timeout)
        except queue.Empty:
            return None

    def assign_task_to_server(self, task: "Task", server: "Server"):
        """Assign task to server and process it in a new daemon thread"""
        with self.lock:
            server.current_task = task

        threading.Thread(
            target=self.process_task_remote,
            args=(server,),
            daemon=True
        ).start()

    def process_task_remote(self, server: "Server"):
        """
        Process the server's current task on the remote server.

        The answer is wrapped in a Response and handed to the task's
        response_processing callback. Any exception is logged, not raised.
        The server is always returned to the pool afterwards.
        """
        task = server.current_task
        endpoint = server.endpoint
        try:
            t0 = time.time()
            answer, total_tokens, token_per_second = ollama_chat(
                endpoint,
                task.prompt,
                base64_image=task.base64_image
            )
            t1 = time.time()
            # Call the response processing function
            response = Response(task, answer, total_tokens, token_per_second)
            task.response_processing(response)
            print(f"Processed {task.description}, on {server.endpoint.url} with model {endpoint.api_name} in {t1 - t0:.2f} seconds with {total_tokens} tokens ({token_per_second:.2f} tokens/sec)")

        except Exception as e:
            # write a stack trace to std err
            import traceback
            traceback.print_exc()
            # Log the URL only: formatting the whole endpoint object would
            # also write its API key into the log.
            error_msg = f"Failed to process task ID {task.id} on {endpoint.url}: {str(e)}"
            error_response = getattr(e, 'response', None)
            if error_response is not None:
                try:
                    error_details = error_response.json()
                    error_msg += f" | API Response: {error_details}"
                except Exception:
                    # Body is not JSON (or unreadable); fall back to raw text.
                    error_msg += f" | Raw Response: {getattr(error_response, 'text', '')}"
            print(error_msg)

        finally:
            # Return the server to the pool exactly once, even if logging
            # above raised, so a failing task can never leak a server.
            self.mark_server_available(server)

    def start_distribution(self):
        """Start the task distribution process in a background daemon thread"""
        def distributor():
            while True:
                task = self.task_queue.get()
                assigned = False

                while not assigned:
                    server = self.get_available_server()
                    if server:
                        self.assign_task_to_server(task, server)
                        assigned = True
                    else:
                        # All servers busy, wait and try again
                        time.sleep(busy_waiting_time)

                # The queue entry counts as done once the task is assigned;
                # wait_completion() separately waits for the servers.
                self.task_queue.task_done()

        # Start distributor thread
        threading.Thread(target=distributor, daemon=True).start()

    def wait_completion(self):
        """Wait until all queued tasks are assigned and all servers are idle"""
        self.task_queue.join()
        # Wait for all servers to finish their current tasks
        print("Waiting for all servers to finish processing...")
        while any(s.current_task is not None for s in self.servers):
            time.sleep(busy_waiting_time)
            print("Still waiting for servers to finish...")
            # print out the current status of all servers
            for server in self.servers:
                # Read once: a worker thread may clear current_task between
                # a check and a second attribute access.
                current = server.current_task
                if current:
                    print(f"Server {server.endpoint.url} - Current task ID: {current.id}")

        print("All servers finished processing.")
|
|
|
|
def main():
    """
    Command-line smoke test for the chat API.

    Builds the endpoint list either from an ``endpoints/<name>.json`` file
    or from ``--api_base`` (comma-separated) and ``--model``, probes the
    first endpoint for multimodal support, lists the models of the first
    endpoint's server and finally sends one chat request (with the image
    given by ``--image``, if any) to the first endpoint and prints the answer.

    Raises:
        Exception: If the requested endpoint file does not exist.
    """
    parser = ArgumentParser(description="Testing the ollama API.")
    parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434')
    parser.add_argument('--endpoint', required=False, default='', help='Name of an <endpoint>.json file in the endpoints directory')
    parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
    parser.add_argument('--image', required=False, default=None, help='path to an image that shall be processed')

    # parse the arguments
    args = parser.parse_args()
    # split() already yields a one-element list when there is no comma
    api_base = args.api_base.split(",")
    endpoint_name = args.endpoint
    model_name = args.model
    image_path = args.image

    # load the endpoint file
    endpoints: List[Endpoint] = []
    if endpoint_name:
        print(f"Using endpoint {endpoint_name}")
        endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
        print(f"Using endpoint file {endpoint_path}")
        if not os.path.exists(endpoint_path):
            raise Exception(f"Endpoint file {endpoint_path} does not exist.")
        with open(endpoint_path, 'r', encoding='utf-8') as file:
            endpoint_dict = json.load(file)
            endpoints = [
                Endpoint(
                    store_name=endpoint_dict["name"],
                    api_name=endpoint_dict["model"],
                    key=endpoint_dict["key"],
                    url=endpoint_dict["endpoint"]
                )
            ]
    else:
        endpoints = [
            Endpoint(store_name=model_name, api_name=model_name, key="",
                     url=f"{api_stub}/v1/chat/completions") for api_stub in api_base
        ]

    # All requests below go to the first endpoint only.
    endpoint = endpoints[0]

    # test if the endpoint is a multimodal model
    if test_multimodal(endpoint):
        print("Endpoint is a multimodal model.")
    else:
        print("Endpoint is not a multimodal model.")

    # load the image, if a path is given
    base64_image = None
    if image_path:
        with open(image_path, "rb") as image_file:
            base64_image = base64.b64encode(image_file.read()).decode('utf-8')

    # access the ollama API
    models_dict = ollama_list(endpoint)
    for (model, attr) in models_dict.items():
        print(f"Model: {model}: {attr}")
    try:
        # ollama_chat() expects a single Endpoint, not the endpoint list.
        if base64_image:
            answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=base64_image)
        else:
            answer, total_tokens, token_per_second = ollama_chat(endpoint)
    except Exception as e:
        answer = f"Error: {str(e)}"
    print(answer)
|
|
|
|
# Run the command-line smoke test only when this file is executed as a script.
if __name__ == "__main__":
    main()
|