diff --git a/llmtest/testimage.png b/llmtest/testimage.png new file mode 100644 index 0000000..7084653 Binary files /dev/null and b/llmtest/testimage.png differ diff --git a/ollama_client.py b/ollama_client.py index 4be279a..da21227 100644 --- a/ollama_client.py +++ b/ollama_client.py @@ -1,10 +1,14 @@ import os import json import time +import base64 import urllib3 import requests +from PIL import Image +from io import BytesIO from argparse import ArgumentParser + def ollama_list(api_base='http://localhost:11434'): # call api http://localhost:11434/api/tags with http get request urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @@ -45,7 +49,10 @@ def ollama_chat_endpoint(api_base='http://localhost:11434', model_name='llama3.2 } return endpoint -def ollama_chat(endpoint, prompt='Hello World', temperature=0.0, max_tokens=32768): +def hex2base64(hex_string): + return base64.b64encode(bytes.fromhex(hex_string)).decode('utf-8') + +def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0.0, max_tokens=32768): # Disable SSL warnings urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @@ -68,12 +75,38 @@ def ollama_chat(endpoint, prompt='Hello World', temperature=0.0, max_tokens=3276 if modelname.startswith("o1") or modelname.startswith("gpt-o1"): temperature = 1.0 # o1 models need temperature 1.0 else: - messages.append({"content": "You are a helpful assistant", "role": "system"}) + messages.append({"role": "system", "content": "You are a helpful assistant"}) if modelname.startswith("4o") or modelname.startswith("gpt-4o") or modelname.startswith("gpt-3.5"): # reduce number of stoptokes to 4 stoptokens = ["[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>"] - messages.append({"role": "user", "content": prompt}) + if base64_image: + image_type = "jpeg" + #base64_magic = {"/9j/": "jpeg", "iVBO": "png", "Qk": "bmp", "R0lG": "gif", "SUkq": "tiff", "SUkr": "tiff", "TU0A": "tiff", "GkXf": "webp", "UklG": "webp"} + base64_magic = {"/9j/": "jpeg", "iVBO": "png", "R0lG": "gif"} # only jpeg and png are allowed as data type; however all of the types above (but gif!) are supported by the API + for magic, itype in base64_magic.items(): + if base64_image.startswith(magic): + #print(f"Detected {itype} image") + image_type = itype + break + # If this is a gif we must convert it to png + if image_type == "gif": + #print("Converting gif to png") + image = Image.open(BytesIO(base64.b64decode(base64_image))) + png_image = BytesIO() + image.save(png_image, format="PNG") + base64_image = base64.b64encode(png_image.getvalue()).decode('utf-8') + image_type = "png" + + # Add the image to the message + image_url_object = {"url": f"data:image/{image_type};base64,{base64_image}"} + usermessage = {"role": "user", "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": image_url_object} + ]} + else: + usermessage = {"role": "user", "content": prompt} + messages.append(usermessage) if modelname.startswith("o1") or modelname.startswith("4o"): stoptokens = [] @@ -129,17 +162,31 @@ def ollama_chat(endpoint, prompt='Hello World', temperature=0.0, max_tokens=3276 except json.JSONDecodeError as e: raise Exception(f"Failed to parse JSON response from the API: {e}") +def test_multimodal(endpoint): + image_path = "llmtest/testimage.png" + with open(image_path, "rb") as image_file: + base64_image = base64.b64encode(image_file.read()).decode('utf-8') + try: + answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=base64_image) + if "42" in answer: + return True + return False + except Exception as e: + return False + def main(): parser = ArgumentParser(description="Testing the ollama API.") parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434') parser.add_argument('--endpoint', required=False, default='', help='Name of an .json file in the endpoints directory') parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest') + parser.add_argument('--image', required=False, default=None, help='path to an image that shall be processed') # parse the arguments args = parser.parse_args() api_base = args.api_base endpoint_name = args.endpoint model_name = args.model + image_path = args.image # load the endpoint file endpoint = {} @@ -154,11 +201,29 @@ def main(): else: endpoint = ollama_chat_endpoint(api_base, model_name) + # test if the endpoint is a multimodal model + if test_multimodal(endpoint): + print("Endpoint is a multimodal model.") + else: + print("Endpoint is not a multimodal model.") + + # load the image, if a path is given + base64_image = None + if image_path: + with open(image_path, "rb") as image_file: + base64_image = base64.b64encode(image_file.read()).decode('utf-8') + # access the ollama API models_dict = ollama_list() for (model, attr) in models_dict.items(): print(f"Model: {model}") - answer, total_tokens, token_per_second = ollama_chat(endpoint) + try: + if base64_image: + answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=base64_image) + else: + answer, total_tokens, token_per_second = ollama_chat(endpoint) + except Exception as e: + answer = f"Error: {str(e)}" print(answer) if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index c56b0b8..e19dd67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +PIL sympy urllib3 requests