asynchronous loading of remote models

This commit is contained in:
Michael Peter Christen
2025-05-18 22:47:42 +02:00
parent c049b908dd
commit f67c4ce38f
2 changed files with 58 additions and 36 deletions

View File

@@ -2,10 +2,13 @@ import os
import json
import time
import base64
import threading
from typing import List
from argparse import ArgumentParser
from benchmark import read_benchmark
from ollama_client import ollama_list, ollama_chat_endpoints, test_multimodal, Endpoint, LoadBalancer, Server, Task, Response
from ollama_client import ollama_list, test_multimodal, ollama_pull_endpoint, Endpoint, LoadBalancer, Server, Task, Response
def read_template(template_path):
with open(template_path, 'r', encoding='utf-8') as file:
@@ -16,16 +19,21 @@ def process_problem_files(problems_dir, template_content, endpoints: List[Endpoi
think=False, no_think=False):
model_store_name = endpoints[0].store_name
model_api_name = endpoints[0].api_name
if think: model_name += "-think"
if no_think: model_name += "-no_think"
if think: model_store_name += "-think"
if no_think: model_store_name += "-no_think"
solutions_dir = os.path.join('solutions', model_store_name, language)
os.makedirs(solutions_dir, exist_ok=True)
# Create load balancer with all available endpoints
servers = [Server(endpoint=endpoint) for endpoint in endpoints]
lb = LoadBalancer()
for server in servers: lb.add_server(server)
lb.start_distribution()
# ensure that the first endpoint is loaded:
ollama_pull_endpoint(endpoints[0])
# load the servers concurrently; each one will download the model if it is not present yet.
loading_thread = threading.Thread(
target = lambda: [lb.add_server(Server(endpoint=ollama_pull_endpoint(endpoint))) for endpoint in endpoints]
)
loading_thread.start()
# iterate over all problem files and process them
for problem_file in sorted(os.listdir(problems_dir)):
@@ -162,7 +170,10 @@ def main():
# add metadata to benchmark.json
if not model in benchmark or not bench_name in benchmark[model]:
print(f"Inference: Using model {model} and language {language}")
endpoints = ollama_chat_endpoints(api_base, model)
endpoints = [
Endpoint(store_name=model, api_name=model, key="",
url=f"{api_stub}/v1/chat/completions") for api_stub in api_base
]
process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number = max_problem_number,
overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions,
think = args.think, no_think = args.no_think)
@@ -180,7 +191,10 @@ def main():
else:
print(f"Inference: Using model {model_name} and language {language}")
# construct the endpoint object from command line arguments considering that ollama is the endpoint
endpoints = ollama_chat_endpoints(api_base, model_name)
endpoints = [
Endpoint(store_name=model_name, api_name=model_name, key="",
url=f"{api_stub}/v1/chat/completions") for api_stub in api_base
]
# run the inference
process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number = max_problem_number,

View File

@@ -25,9 +25,25 @@ class Endpoint:
key: str # API key (if required)
url: str # URL of the endpoint
def get_ollama_url_stub(self) -> str:
    """Return this endpoint's URL with its path component replaced by an empty string."""
    parsed_url = urllib3.util.url.parse_url(self.url)
    stub = parsed_url._replace(path='')
    return stub.url
def get_ollama_pull_url(self) -> str:
    """Return this endpoint's URL with its path swapped for '/api/pull' (ollama pull command)."""
    parsed_url = urllib3.util.url.parse_url(self.url)
    pull_url = parsed_url._replace(path='/api/pull')
    return pull_url.url
def get_ollama_delete_url(self) -> str:
    """Return this endpoint's URL with its path swapped for '/api/delete' (ollama delete command)."""
    parsed_url = urllib3.util.url.parse_url(self.url)
    delete_url = parsed_url._replace(path='/api/delete')
    return delete_url.url
def get_ollama_ls_url(self) -> str:
    """Get the URL for the ollama list command (path '/api/tags').

    The URL is derived from ``self.url`` by replacing only its path component.
    """
    # A single docstring: previously a stale string literal came first and
    # became ``__doc__``, while the corrected wording was a dead statement.
    return urllib3.util.url.parse_url(self.url)._replace(path='/api/tags').url
def get_ollama_ps_url(self) -> str:
    """Return this endpoint's URL with its path swapped for '/api/ps' (ollama ps command)."""
    parsed_url = urllib3.util.url.parse_url(self.url)
    ps_url = parsed_url._replace(path='/api/ps')
    return ps_url.url
def ollama_pull(api_base='http://localhost:11434', model='llama3.2:latest') -> bool:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@@ -59,29 +75,18 @@ def ollama_list(api_base='http://localhost:11434') -> dict:
for entry in data['models']
}
def ollama_chat_endpoints(api_base='http://localhost:11434', model_name='llama3.2:latest') -> List[Endpoint]:
    """Build chat-completion endpoints for every ollama server that can be reached.

    Args:
        api_base: A single server base URL or a list of base URLs.
        model_name: Model that must be present on each server; it is pulled
            via ``ollama_pull`` when ``ollama_list`` does not report it.

    Returns:
        One ``Endpoint`` (url ``{api_stub}/v1/chat/completions``) per server
        for which listing and, if needed, pulling raised no exception.

    The passed-in ``api_base`` list is not modified.
    """
    # Work on a private copy: removing entries from the list that is being
    # iterated skips the element after each removed one, and would also
    # mutate the caller's list.
    api_stubs = [api_base] if isinstance(api_base, str) else list(api_base)

    # check if the endpoint servers are online and the model is available
    online_stubs = []
    for api_stub in api_stubs:
        try:
            print(f"Loading model list from server {api_stub} to check for model {model_name}...")
            available_models = ollama_list(api_stub)
            if model_name not in available_models:
                # pull the model if it is not available
                print(f"Model {model_name} is not available on server {api_stub}. Pulling the model...")
                ollama_pull(api_stub, model_name)
                print(f"Model {model_name} is now available on server {api_stub}.")
        except Exception as e:
            # the server is not available, leave it out of the result
            # (deliberately broad: any failure counts as "not available")
            print(f"Server {api_stub} is not available: {e}")
            continue
        online_stubs.append(api_stub)

    # create the endpoint list
    return [
        Endpoint(store_name=model_name, api_name=model_name, key="",
                 url=f"{api_stub}/v1/chat/completions")
        for api_stub in online_stubs
    ]


def ollama_pull_endpoint(endpoint: Endpoint) -> Endpoint:
    """Make sure the endpoint's model is present on its ollama server.

    Looks up ``endpoint.api_name`` in the result of ``ollama_list`` for the
    endpoint's URL stub and calls ``ollama_pull`` when it is missing.

    Args:
        endpoint: The endpoint whose server and model are checked.

    Returns:
        The same ``endpoint`` object that was passed in.

    Raises:
        Whatever ``ollama_list`` / ``ollama_pull`` raise; nothing is caught here.
    """
    # we do not catch exceptions here, because that shall be done in calling code
    api_base = endpoint.get_ollama_url_stub()
    available_models = ollama_list(api_base)  # not named `list`: avoids shadowing the builtin
    if endpoint.api_name in available_models:
        return endpoint

    # pull the model if it is not available
    print(f"Model {endpoint.api_name} is not available on server {api_base}. Pulling the model...")
    # NOTE(review): ollama_pull is annotated to return bool and the result is
    # ignored here - confirm whether a False result should suppress the
    # "now available" message.
    ollama_pull(api_base, endpoint.api_name)
    print(f"Model {endpoint.api_name} is now available on server {api_base}.")
    return endpoint
def hex2base64(hex_string) -> str:
    """Convert a hexadecimal string into the base64 text of the bytes it encodes."""
    raw_bytes = bytes.fromhex(hex_string)
    encoded = base64.b64encode(raw_bytes)
    return str(encoded, 'utf-8')
@@ -301,7 +306,7 @@ class LoadBalancer:
"""Add a server to the load balancer"""
self.servers.append(server)
self.available_servers.put(server)
print(f"Server {server.endpoint} added to load balancer.")
print(f"Server {server.endpoint.get_ollama_url_stub()} added to load balancer.")
def add_task(self, task: Task):
"""Add a task to the processing queue with backpressure"""
@@ -408,7 +413,7 @@ class LoadBalancer:
# print out the current status of all servers
for server in self.servers:
if server.current_task:
print(f"Server {server.endpoint} - Current task ID: {server.current_task.id}")
print(f"Server {server.endpoint.url} - Current task ID: {server.current_task.id}")
print("All servers finished processing.")
@@ -438,17 +443,20 @@ def main():
endpoint_dict = json.load(file)
endpoints = [
Endpoint(
name=endpoint_dict["name"],
model=endpoint_dict["model"],
store_name=endpoint_dict["name"],
api_name=endpoint_dict["model"],
key=endpoint_dict["key"],
url=endpoint_dict["endpoint"]
)
]
else:
endpoints = ollama_chat_endpoints(api_base, model_name)
endpoints = [
Endpoint(store_name=model_name, api_name=model_name, key="",
url=f"{api_stub}/v1/chat/completions") for api_stub in api_base
]
# test if the endpoint is a multimodal model
if test_multimodal(endpoints):
if test_multimodal(endpoints[0]):
print("Endpoint is a multimodal model.")
else:
print("Endpoint is not a multimodal model.")