asynchronous loading of remote models
This commit is contained in:
28
inference.py
28
inference.py
@@ -2,10 +2,13 @@ import os
|
||||
import json
|
||||
import time
|
||||
import base64
|
||||
import threading
|
||||
from typing import List
|
||||
from argparse import ArgumentParser
|
||||
from benchmark import read_benchmark
|
||||
from ollama_client import ollama_list, ollama_chat_endpoints, test_multimodal, Endpoint, LoadBalancer, Server, Task, Response
|
||||
from ollama_client import ollama_list, test_multimodal, ollama_pull_endpoint, Endpoint, LoadBalancer, Server, Task, Response
|
||||
|
||||
|
||||
|
||||
def read_template(template_path):
|
||||
with open(template_path, 'r', encoding='utf-8') as file:
|
||||
@@ -16,16 +19,21 @@ def process_problem_files(problems_dir, template_content, endpoints: List[Endpoi
|
||||
think=False, no_think=False):
|
||||
model_store_name = endpoints[0].store_name
|
||||
model_api_name = endpoints[0].api_name
|
||||
if think: model_name += "-think"
|
||||
if no_think: model_name += "-no_think"
|
||||
if think: model_store_name += "-think"
|
||||
if no_think: model_store_name += "-no_think"
|
||||
solutions_dir = os.path.join('solutions', model_store_name, language)
|
||||
os.makedirs(solutions_dir, exist_ok=True)
|
||||
|
||||
# Create load balancer with all available endpoints
|
||||
servers = [Server(endpoint=endpoint) for endpoint in endpoints]
|
||||
lb = LoadBalancer()
|
||||
for server in servers: lb.add_server(server)
|
||||
lb.start_distribution()
|
||||
# ensure that the first endpoint is loaded:
|
||||
ollama_pull_endpoint(endpoints[0])
|
||||
# load server concurrently; they will download a model if that is not present so far.
|
||||
loading_thread = threading.Thread(
|
||||
target = lambda: [lb.add_server(Server(endpoint=ollama_pull_endpoint(endpoint))) for endpoint in endpoints]
|
||||
)
|
||||
loading_thread.start()
|
||||
|
||||
# iterate over all problem files and process them
|
||||
for problem_file in sorted(os.listdir(problems_dir)):
|
||||
@@ -162,7 +170,10 @@ def main():
|
||||
# add metadata to benchmark.json
|
||||
if not model in benchmark or not bench_name in benchmark[model]:
|
||||
print(f"Inference: Using model {model} and language {language}")
|
||||
endpoints = ollama_chat_endpoints(api_base, model)
|
||||
endpoints = [
|
||||
Endpoint(store_name=model, api_name=model, key="",
|
||||
url=f"{api_stub}/v1/chat/completions") for api_stub in api_base
|
||||
]
|
||||
process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number = max_problem_number,
|
||||
overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions,
|
||||
think = args.think, no_think = args.no_think)
|
||||
@@ -180,7 +191,10 @@ def main():
|
||||
else:
|
||||
print(f"Inference: Using model {model_name} and language {language}")
|
||||
# construct the endpoint object from command line arguments considering that ollama is the endpoint
|
||||
endpoints = ollama_chat_endpoints(api_base, model_name)
|
||||
endpoints = [
|
||||
Endpoint(store_name=model_name, api_name=model_name, key="",
|
||||
url=f"{api_stub}/v1/chat/completions") for api_stub in api_base
|
||||
]
|
||||
|
||||
# run the inference
|
||||
process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number = max_problem_number,
|
||||
|
||||
@@ -25,9 +25,25 @@ class Endpoint:
|
||||
key: str # API key (if required)
|
||||
url: str # URL of the endpoint
|
||||
|
||||
def get_ollama_url_stub(self) -> str:
|
||||
"""Get the base URL for the ollama API"""
|
||||
return urllib3.util.url.parse_url(self.url)._replace(path='').url
|
||||
|
||||
def get_ollama_pull_url(self) -> str:
|
||||
"""Get the URL for the ollama pull command"""
|
||||
return urllib3.util.url.parse_url(self.url)._replace(path='/api/pull').url
|
||||
|
||||
def get_ollama_delete_url(self) -> str:
|
||||
"""Get the URL for the ollama delete command"""
|
||||
return urllib3.util.url.parse_url(self.url)._replace(path='/api/delete').url
|
||||
|
||||
def get_ollama_ls_url(self) -> str:
|
||||
"""Get the URL for the ollama ls command"""
|
||||
"""Get the URL for the ollama list command"""
|
||||
return urllib3.util.url.parse_url(self.url)._replace(path='/api/tags').url
|
||||
|
||||
def get_ollama_ps_url(self) -> str:
|
||||
"""Get the URL for the ollama ps command"""
|
||||
return urllib3.util.url.parse_url(self.url)._replace(path='/api/ps').url
|
||||
|
||||
def ollama_pull(api_base='http://localhost:11434', model='llama3.2:latest') -> bool:
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
@@ -59,29 +75,18 @@ def ollama_list(api_base='http://localhost:11434') -> dict:
|
||||
for entry in data['models']
|
||||
}
|
||||
|
||||
def ollama_chat_endpoints(api_base='http://localhost:11434', model_name='llama3.2:latest') -> List[Endpoint]:
|
||||
if isinstance(api_base, str): api_base = [api_base]
|
||||
|
||||
def ollama_pull_endpoint(endpoint: Endpoint) -> Endpoint:
|
||||
# check if the endpoint servers are online and the model is available
|
||||
for api_stub in api_base:
|
||||
try:
|
||||
print(f"Loading model list from server {api_stub} to check for model {model_name}...")
|
||||
list = ollama_list(api_stub)
|
||||
if model_name not in list:
|
||||
# pull the model if it is not available
|
||||
print(f"Model {model_name} is not available on server {api_stub}. Pulling the model...")
|
||||
ollama_pull(api_stub, model_name)
|
||||
print(f"Model {model_name} is now available on server {api_stub}.")
|
||||
except Exception as e:
|
||||
# the server is not available, remove the server from the list
|
||||
print(f"Server {api_stub} is not available: {e}")
|
||||
api_base.remove(api_stub)
|
||||
# we do not catch exceptions here, because that shall be done in calling code
|
||||
api_base = endpoint.get_ollama_url_stub()
|
||||
list = ollama_list(api_base)
|
||||
if endpoint.api_name in list: return endpoint
|
||||
|
||||
# create the endpoint list
|
||||
return [
|
||||
Endpoint(store_name=model_name, api_name=model_name, key="",
|
||||
url=f"{api_stub}/v1/chat/completions") for api_stub in api_base
|
||||
]
|
||||
# pull the model if it is not available
|
||||
print(f"Model {endpoint.api_name} is not available on server {api_base}. Pulling the model...")
|
||||
ollama_pull(api_base, endpoint.api_name)
|
||||
print(f"Model {endpoint.api_name} is now available on server {api_base}.")
|
||||
return endpoint
|
||||
|
||||
def hex2base64(hex_string) -> str:
|
||||
return base64.b64encode(bytes.fromhex(hex_string)).decode('utf-8')
|
||||
@@ -301,7 +306,7 @@ class LoadBalancer:
|
||||
"""Add a server to the load balancer"""
|
||||
self.servers.append(server)
|
||||
self.available_servers.put(server)
|
||||
print(f"Server {server.endpoint} added to load balancer.")
|
||||
print(f"Server {server.endpoint.get_ollama_url_stub()} added to load balancer.")
|
||||
|
||||
def add_task(self, task: Task):
|
||||
"""Add a task to the processing queue with backpressure"""
|
||||
@@ -408,7 +413,7 @@ class LoadBalancer:
|
||||
# print out the current status of all servers
|
||||
for server in self.servers:
|
||||
if server.current_task:
|
||||
print(f"Server {server.endpoint} - Current task ID: {server.current_task.id}")
|
||||
print(f"Server {server.endpoint.url} - Current task ID: {server.current_task.id}")
|
||||
|
||||
print("All servers finished processing.")
|
||||
|
||||
@@ -438,17 +443,20 @@ def main():
|
||||
endpoint_dict = json.load(file)
|
||||
endpoints = [
|
||||
Endpoint(
|
||||
name=endpoint_dict["name"],
|
||||
model=endpoint_dict["model"],
|
||||
store_name=endpoint_dict["name"],
|
||||
api_name=endpoint_dict["model"],
|
||||
key=endpoint_dict["key"],
|
||||
url=endpoint_dict["endpoint"]
|
||||
)
|
||||
]
|
||||
else:
|
||||
endpoints = ollama_chat_endpoints(api_base, model_name)
|
||||
|
||||
endpoints = [
|
||||
Endpoint(store_name=model_name, api_name=model_name, key="",
|
||||
url=f"{api_stub}/v1/chat/completions") for api_stub in api_base
|
||||
]
|
||||
|
||||
# test if the endpoint is a multimodal model
|
||||
if test_multimodal(endpoints):
|
||||
if test_multimodal(endpoints[0]):
|
||||
print("Endpoint is a multimodal model.")
|
||||
else:
|
||||
print("Endpoint is not a multimodal model.")
|
||||
|
||||
Reference in New Issue
Block a user