From 88ac8513806215183bd235b25977fd1d3dd652b9 Mon Sep 17 00:00:00 2001 From: Michael Peter Christen Date: Tue, 29 Apr 2025 21:33:41 +0200 Subject: [PATCH] refactoring for parallel processing --- codeextraction.py | 7 +++++ execute.py | 6 ++++ inference.py | 34 +++++++++++++++------- ollama_client.py | 72 ++++++++++++++++++++++++++++------------------- test.py | 30 +++++++++++++++----- 5 files changed, 103 insertions(+), 46 deletions(-) diff --git a/codeextraction.py b/codeextraction.py index 0d6ee16..6c67377 100644 --- a/codeextraction.py +++ b/codeextraction.py @@ -80,6 +80,8 @@ def process_markdown_files(model_name, language): def main(): parser = ArgumentParser(description="Extract code blocks from Markdown files.") parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest') + parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end') + parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end') parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure') parser.add_argument('--endpoint', required=False, default='', help='Name of an .json file in the endpoints directory') @@ -88,6 +90,7 @@ def main(): language = args.language endpoint_name = args.endpoint + # in case no model name is given but an endpoint name, read the model name from the endpoint file if endpoint_name: endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json") print(f"Using endpoint file {endpoint_path}") @@ -97,6 +100,10 @@ def main(): endpoint = json.load(file) model_name = endpoint.get('name', model_name) + # modify the model name in case soft thinking switches are given + if args.think: model_name += "-think" + if args.no_think: model_name += "-no_think" + languages = args.language.split(',') for language in languages: print(f"Processing language: {language}") diff --git a/execute.py b/execute.py index c14e9cd..96acebc 100644 --- a/execute.py +++ b/execute.py @@ -407,6 +407,8 @@ def main(): parser = ArgumentParser(description="Execute solutions and store results in a JSON file.") parser.add_argument('--allmodels', action='store_true', help='loop over all models as provided by benchmark.json and run all of them') parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest') + parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end') + parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end') parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the programming language to use, default is python') parser.add_argument('--endpoint', required=False, default='', help='Name of an .json file in the endpoints directory') parser.add_argument('--n100', action='store_true', help='only 100 problems') # this is the default @@ -432,6 +434,10 @@ def main(): endpoint = json.load(file) model_name = endpoint.get('name', model_name) + # modify the model name in case soft thinking switches are given + if args.think: model_name += "-think" + if args.no_think: model_name += "-no_think" + with open('solutions.json', 'r', encoding='utf-8') as json_file: expected_solutions = json.load(json_file) diff --git a/inference.py b/inference.py index 6b00657..7ddab58 100644 --- a/inference.py +++ b/inference.py @@ -3,14 +3,18 @@ import json import base64 from argparse import ArgumentParser from benchmark import read_benchmark, write_benchmark -from ollama_client import ollama_list, ollama_chat_endpoint, ollama_chat, test_multimodal +from ollama_client import ollama_list, ollama_chat_endpoints, ollama_chat, test_multimodal def read_template(template_path): with open(template_path, 'r', encoding='utf-8') as file: return file.read() -def process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number=9999, overwrite_existing=False, overwrite_failed=False, expected_solutions={}): - model_name = endpoint["name"] +def process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number=9999, + overwrite_existing=False, overwrite_failed=False, expected_solutions={}, + think=False, no_think=False): + model_name = endpoints["name"] + if think: model_name += "-think" + if no_think: model_name += "-no_think" results_dir = os.path.join('solutions', model_name, language) os.makedirs(results_dir, exist_ok=True) @@ -50,6 +54,10 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma # Construct the prompt using the template prompt = template_content.replace('$$$PROBLEM$$$', problem_content) + + # attach soft thinking switches if asked + if think: prompt += " /think" + if no_think: prompt += " /no_think" # check if there is also an image in the problem. We take the problem_file, remove the extension ".txt" # and add either "-0.png", "-0.jpg" or "-0.gif" @@ -64,7 +72,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma # check if the endpoint is multimodal, we do this only if we have an image if base64_image: - is_multimodal = test_multimodal(endpoint) # this is cached + is_multimodal = test_multimodal(endpoints) # this is cached if is_multimodal: print(f"Problem {problem_number} is handled with multimodal model.") else: @@ -72,7 +80,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma base64_image = None try: - answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt, base64_image=base64_image) + answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt, base64_image=base64_image) # Save the response to a file with open(result_file_path, 'w', encoding='utf-8') as result_file: @@ -89,6 +97,8 @@ def main(): parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434') parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json') parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest') + parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end') + parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end') parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure') parser.add_argument('--overwrite_existing', action='store_true', help='if set, re-calculate all problems that already have an answer') parser.add_argument('--overwrite_failed', action='store_true', help='if set, re-calculate those problems with wrong answers') @@ -100,7 +110,7 @@ def main(): args = parser.parse_args() - api_base = args.api_base + api_base = args.api_base.split(",") if "," in args.api_base else [args.api_base] model_name = args.model language = args.language max_problem_number = 100 @@ -149,8 +159,10 @@ def main(): # add metadata to benchmark.json if not model in benchmark or not bench_name in benchmark[model]: print(f"Inference: Using model {model} and language {language}") - endpoint = ollama_chat_endpoint(api_base, model) - process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions) + endpoints = ollama_chat_endpoints(api_base, model) + process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number = max_problem_number, + overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions, + think = args.think, no_think = args.no_think) else: # construct the endpoint object endpoint = {} @@ -165,10 +177,12 @@ def main(): else: print(f"Inference: Using model {model_name} and language {language}") # construct the endpoint object from command line arguments considering that ollama is the endpoint - endpoint = ollama_chat_endpoint(api_base, model_name) + endpoints = ollama_chat_endpoints(api_base, model_name) # run the inference - process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions) + process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number = max_problem_number, + overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions, + think = args.think, no_think = args.no_think) if __name__ == "__main__": main() diff --git a/ollama_client.py b/ollama_client.py index 3e78a9e..3c7294f 100644 --- a/ollama_client.py +++ b/ollama_client.py @@ -40,19 +40,31 @@ def ollama_list(api_base='http://localhost:11434') -> dict: models_dict[model] = attr return models_dict -def ollama_chat_endpoint(api_base='http://localhost:11434', model_name='llama3.2:latest') -> dict: - endpoint = { - "name": model_name, - "model": model_name, - "key": "", - "endpoint": f"{api_base}/v1/chat/completions", - } - return endpoint +def ollama_chat_endpoints(api_base='http://localhost:11434', model_name='llama3.2:latest') -> dict: + # check if api_base is a string + if isinstance(api_base, str): + endpoint = { + "name": model_name, + "model": model_name, + "key": "", + "endpoints": [f"{api_base}/v1/chat/completions"], + } + return endpoint + # check if api_base is a list of strings + if isinstance(api_base, list): + endpoint = { + "name": model_name, + "model": model_name, + "key": "", + "endpoints": [f"{api_stub}/v1/chat/completions" for api_stub in api_base], + } + return endpoint + return {} def hex2base64(hex_string) -> str: return base64.b64encode(bytes.fromhex(hex_string)).decode('utf-8') -def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0.0, max_tokens=8192) -> tuple: +def ollama_chat(endpoints, prompt='Hello World', base64_image=None, temperature=0.0, max_tokens=8192) -> tuple: # Disable SSL warnings urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @@ -65,16 +77,16 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0 'Content-Type': 'application/json', 'Accept': 'application/json' } - if endpoint.get("key", ""): - headers['Authorization'] = 'Bearer ' + endpoint["key"] + if endpoints.get("key", ""): + headers['Authorization'] = 'Bearer ' + endpoints["key"] - modelname = endpoint["model"] + modelname = endpoints["model"] messages = [] messages.append({"role": "system", "content": "You are a helpful assistant"}) - # o1 has special requirements - if modelname.startswith("o1") or modelname.startswith("gpt-o1"): - temperature = 1.0 # o1 models need temperature 1.0 + # special requirements of certain models + if modelname.startswith("o1") or modelname.startswith("gpt-o1"): temperature = 1.0 + if modelname.startswith("qwen3"): temperature = 0.6 if modelname.startswith("4o") or modelname.startswith("gpt-4o") or modelname.startswith("gpt-3.5"): # reduce number of stoptokes to 4 @@ -115,10 +127,12 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0 "model": modelname, "messages": messages, "response_format": { "type": "text" }, + "temperature": temperature, # ollama default: 0.8 + "top_k": 20, # reduces the probability of generating nonsense: high = more diverse, low = more focused; ollama default: 40 + "top_p": 0.95, # works together with top_k: high = more diverse, low = more focused; ollama default: 0.9 + "min_p": 0, # alternative to top_p: p is minimum probability for a token to be considered; ollama default: 0.0 "stream": False } - if not modelname.startswith("o4"): - payload["temperature"] = temperature if len(stoptokens) > 0 and not modelname.startswith("o4"): payload["stop"] = stoptokens if modelname.startswith("o1") or modelname.startswith("o4"): @@ -129,7 +143,7 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0 try: #print(payload) t0 = time.time() - response = requests.post(endpoint["endpoint"], headers=headers, json=payload, verify=False) + response = requests.post(endpoints["endpoints"][0], headers=headers, json=payload, verify=False) t1 = time.time() #print(response) response.raise_for_status() @@ -165,8 +179,8 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0 multimodal_cache = {} -def test_multimodal(endpoint) -> bool: - modelname = endpoint["model"] +def test_multimodal(endpoints) -> bool: + modelname = endpoints["model"] cached_result = multimodal_cache.get(modelname, None) if cached_result is not None: return cached_result @@ -175,7 +189,7 @@ def test_multimodal(endpoint) -> bool: with open(image_path, "rb") as image_file: base64_image = base64.b64encode(image_file.read()).decode('utf-8') try: - answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=base64_image) + answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt="what is in the image", base64_image=base64_image) result = "42" in answer if result: print(f"Model {modelname} is multimodal.") @@ -193,13 +207,13 @@ def main(): # parse the arguments args = parser.parse_args() - api_base = args.api_base + api_base = args.api_base.split(",") if "," in args.api_base else [args.api_base] endpoint_name = args.endpoint model_name = args.model image_path = args.image # load the endpoint file - endpoint = {} + endpoints = {} if endpoint_name: print(f"Using endpoint {endpoint_name}") endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json") @@ -207,12 +221,12 @@ def main(): if not os.path.exists(endpoint_path): raise Exception(f"Endpoint file {endpoint_path} does not exist.") with open(endpoint_path, 'r', encoding='utf-8') as file: - endpoint = json.load(file) + endpoints = json.load(file) else: - endpoint = ollama_chat_endpoint(api_base, model_name) + endpoints = ollama_chat_endpoints(api_base, model_name) # test if the endpoint is a multimodal model - if test_multimodal(endpoint): + if test_multimodal(endpoints): print("Endpoint is a multimodal model.") else: print("Endpoint is not a multimodal model.") @@ -224,14 +238,14 @@ def main(): base64_image = base64.b64encode(image_file.read()).decode('utf-8') # access the ollama API - models_dict = ollama_list() + models_dict = ollama_list(api_base[0]) for (model, attr) in models_dict.items(): print(f"Model: {model}: {attr}") try: if base64_image: - answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=base64_image) + answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt="what is in the image", base64_image=base64_image) else: - answer, total_tokens, token_per_second = ollama_chat(endpoint) + answer, total_tokens, token_per_second = ollama_chat(endpoints) except Exception as e: answer = f"Error: {str(e)}" print(answer) diff --git a/test.py b/test.py index 7c38256..30ec2ad 100644 --- a/test.py +++ b/test.py @@ -3,29 +3,41 @@ from argparse import ArgumentParser from benchmark import read_benchmark, write_benchmark from ollama_client import ollama_list -def test(endpoint_name, model_name, language, overwrite_existing, overwrite_failed, max_problem_number=100): +def test(api_base, endpoint_name, model_name, language, overwrite_existing, overwrite_failed, max_problem_number=100, think=False, no_think=False): # call inference.py - cmd = f"python3.12 inference.py --language {language}" + cmd = f"python3.12 inference.py --language {language} --api_base {api_base}" cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}" if max_problem_number == 200: cmd += " --n200" if overwrite_existing: cmd += " --overwrite_existing" if overwrite_failed: cmd += " --overwrite_failed" + if think: cmd += " --think" + if no_think: cmd += " --no_think" + print(f"Running command: {cmd}") os.system(cmd) # call codeextraction.py cmd = f"python3.12 codeextraction.py --language {language}" cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}" + if think: cmd += " --think" + if no_think: cmd += " --no_think" + print(f"Running command: {cmd}") os.system(cmd) # call execute.py cmd = f"python3.12 execute.py --language {language}" cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}" + if think: cmd += " --think" + if no_think: cmd += " --no_think" + print(f"Running command: {cmd}") os.system(cmd) def main(): parser = ArgumentParser(description="Run the complete pipeline to execute solutions and store results in a JSON file.") + parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434') parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json') parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest') + parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end') + parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end') parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure') parser.add_argument('--overwrite_existing', action='store_true', help='if set, re-calculate all problems that already have an answer') parser.add_argument('--overwrite_failed', action='store_true', help='if set, re-calculate those problems with wrong answers') @@ -36,6 +48,7 @@ def main(): parser.add_argument('--nall', action='store_true', help='all problems') args = parser.parse_args() + api_base = args.api_base model_name = args.model max_problem_number = 100 if args.n100: max_problem_number = 100 @@ -72,16 +85,19 @@ def main(): # in every loop we load the benchmark.json again because it might have been updated benchmark = read_benchmark() - entry = benchmark.get(model, {}) + model_benchmark_name = model + if args.think: model_benchmark_name += "-think" + if args.no_think: model_benchmark_name += "-no_think" + entry = benchmark.get(model_benchmark_name, {}) # add metadata to benchmark.json - if not model in benchmark or not bench_name in benchmark[model] or overwrite_existing or overwrite_failed: + if not model_benchmark_name in benchmark or not bench_name in benchmark[model_benchmark_name] or overwrite_existing or overwrite_failed: # run the model; this writes a news entry to benchmark.json - test(endpoint_name, model, language, overwrite_existing, overwrite_failed, max_problem_number) + test(api_base, endpoint_name, model, language, overwrite_existing, overwrite_failed, max_problem_number, think = args.think, no_think = args.no_think) # load benchmark.json again because the test has updated it benchmark = read_benchmark() # because testing can be interrupted, there is no guarantee that the entry is present - entry = benchmark.get(model, {}) + entry = benchmark.get(model_benchmark_name, {}) # check if attributes parameter_size and quantization_level are present in benchmark.json if not '_parameter_size' in entry and model_dict[model].get('parameter_size', None): @@ -89,7 +105,7 @@ def main(): if not '_quantization_level' in entry and model_dict[model].get('quantization_level', None): entry['_quantization_level'] = model_dict[model].get('quantization_level', None) entry = dict(sorted(entry.items(), key=lambda item: item[0])) - benchmark[model] = entry + benchmark[model_benchmark_name] = entry # write the updated benchmark file write_benchmark(benchmark)