refactoring for parallel processing
This commit is contained in:
@@ -80,6 +80,8 @@ def process_markdown_files(model_name, language):
|
||||
def main():
|
||||
parser = ArgumentParser(description="Extract code blocks from Markdown files.")
|
||||
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
|
||||
parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end')
|
||||
parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end')
|
||||
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure')
|
||||
parser.add_argument('--endpoint', required=False, default='', help='Name of an <endpoint>.json file in the endpoints directory')
|
||||
|
||||
@@ -88,6 +90,7 @@ def main():
|
||||
language = args.language
|
||||
endpoint_name = args.endpoint
|
||||
|
||||
# in case no model name is given but an endpoint name, read the model name from the endpoint file
|
||||
if endpoint_name:
|
||||
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
|
||||
print(f"Using endpoint file {endpoint_path}")
|
||||
@@ -97,6 +100,10 @@ def main():
|
||||
endpoint = json.load(file)
|
||||
model_name = endpoint.get('name', model_name)
|
||||
|
||||
# modify the model name in case soft thinking switches are given
|
||||
if args.think: model_name += "-think"
|
||||
if args.no_think: model_name += "-no_think"
|
||||
|
||||
languages = args.language.split(',')
|
||||
for language in languages:
|
||||
print(f"Processing language: {language}")
|
||||
|
||||
@@ -407,6 +407,8 @@ def main():
|
||||
parser = ArgumentParser(description="Execute solutions and store results in a JSON file.")
|
||||
parser.add_argument('--allmodels', action='store_true', help='loop over all models as provided by benchmark.json and run all of them')
|
||||
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
|
||||
parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end')
|
||||
parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end')
|
||||
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the programming language to use, default is python')
|
||||
parser.add_argument('--endpoint', required=False, default='', help='Name of an <endpoint>.json file in the endpoints directory')
|
||||
parser.add_argument('--n100', action='store_true', help='only 100 problems') # this is the default
|
||||
@@ -432,6 +434,10 @@ def main():
|
||||
endpoint = json.load(file)
|
||||
model_name = endpoint.get('name', model_name)
|
||||
|
||||
# modify the model name in case soft thinking switches are given
|
||||
if args.think: model_name += "-think"
|
||||
if args.no_think: model_name += "-no_think"
|
||||
|
||||
with open('solutions.json', 'r', encoding='utf-8') as json_file:
|
||||
expected_solutions = json.load(json_file)
|
||||
|
||||
|
||||
34
inference.py
34
inference.py
@@ -3,14 +3,18 @@ import json
|
||||
import base64
|
||||
from argparse import ArgumentParser
|
||||
from benchmark import read_benchmark, write_benchmark
|
||||
from ollama_client import ollama_list, ollama_chat_endpoint, ollama_chat, test_multimodal
|
||||
from ollama_client import ollama_list, ollama_chat_endpoints, ollama_chat, test_multimodal
|
||||
|
||||
def read_template(template_path):
|
||||
with open(template_path, 'r', encoding='utf-8') as file:
|
||||
return file.read()
|
||||
|
||||
def process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number=9999, overwrite_existing=False, overwrite_failed=False, expected_solutions={}):
|
||||
model_name = endpoint["name"]
|
||||
def process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number=9999,
|
||||
overwrite_existing=False, overwrite_failed=False, expected_solutions={},
|
||||
think=False, no_think=False):
|
||||
model_name = endpoints["name"]
|
||||
if think: model_name += "-think"
|
||||
if no_think: model_name += "-no_think"
|
||||
results_dir = os.path.join('solutions', model_name, language)
|
||||
os.makedirs(results_dir, exist_ok=True)
|
||||
|
||||
@@ -51,6 +55,10 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
|
||||
# Construct the prompt using the template
|
||||
prompt = template_content.replace('$$$PROBLEM$$$', problem_content)
|
||||
|
||||
# attach soft thinking switches if asked
|
||||
if think: prompt += " /think"
|
||||
if no_think: prompt += " /no_think"
|
||||
|
||||
# check if there is also an image in the problem. We take the problem_file, remove the extension ".txt"
|
||||
# and add either "-0.png", "-0.jpg" or "-0.gif"
|
||||
base64_image = None
|
||||
@@ -64,7 +72,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
|
||||
|
||||
# check if the endpoint is multimodal, we do this only if we have an image
|
||||
if base64_image:
|
||||
is_multimodal = test_multimodal(endpoint) # this is cached
|
||||
is_multimodal = test_multimodal(endpoints) # this is cached
|
||||
if is_multimodal:
|
||||
print(f"Problem {problem_number} is handled with multimodal model.")
|
||||
else:
|
||||
@@ -72,7 +80,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
|
||||
base64_image = None
|
||||
|
||||
try:
|
||||
answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt, base64_image=base64_image)
|
||||
answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt, base64_image=base64_image)
|
||||
|
||||
# Save the response to a file
|
||||
with open(result_file_path, 'w', encoding='utf-8') as result_file:
|
||||
@@ -89,6 +97,8 @@ def main():
|
||||
parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434')
|
||||
parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json')
|
||||
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
|
||||
parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end')
|
||||
parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end')
|
||||
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure')
|
||||
parser.add_argument('--overwrite_existing', action='store_true', help='if set, re-calculate all problems that already have an answer')
|
||||
parser.add_argument('--overwrite_failed', action='store_true', help='if set, re-calculate those problems with wrong answers')
|
||||
@@ -100,7 +110,7 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
api_base = args.api_base
|
||||
api_base = args.api_base.split(",") if "," in args.api_base else [args.api_base]
|
||||
model_name = args.model
|
||||
language = args.language
|
||||
max_problem_number = 100
|
||||
@@ -149,8 +159,10 @@ def main():
|
||||
# add metadata to benchmark.json
|
||||
if not model in benchmark or not bench_name in benchmark[model]:
|
||||
print(f"Inference: Using model {model} and language {language}")
|
||||
endpoint = ollama_chat_endpoint(api_base, model)
|
||||
process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions)
|
||||
endpoints = ollama_chat_endpoints(api_base, model)
|
||||
process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number = max_problem_number,
|
||||
overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions,
|
||||
think = args.think, no_think = args.no_think)
|
||||
else:
|
||||
# construct the endpoint object
|
||||
endpoint = {}
|
||||
@@ -165,10 +177,12 @@ def main():
|
||||
else:
|
||||
print(f"Inference: Using model {model_name} and language {language}")
|
||||
# construct the endpoint object from command line arguments considering that ollama is the endpoint
|
||||
endpoint = ollama_chat_endpoint(api_base, model_name)
|
||||
endpoints = ollama_chat_endpoints(api_base, model_name)
|
||||
|
||||
# run the inference
|
||||
process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions)
|
||||
process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number = max_problem_number,
|
||||
overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions,
|
||||
think = args.think, no_think = args.no_think)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -40,19 +40,31 @@ def ollama_list(api_base='http://localhost:11434') -> dict:
|
||||
models_dict[model] = attr
|
||||
return models_dict
|
||||
|
||||
def ollama_chat_endpoint(api_base='http://localhost:11434', model_name='llama3.2:latest') -> dict:
|
||||
endpoint = {
|
||||
"name": model_name,
|
||||
"model": model_name,
|
||||
"key": "",
|
||||
"endpoint": f"{api_base}/v1/chat/completions",
|
||||
}
|
||||
return endpoint
|
||||
def ollama_chat_endpoints(api_base='http://localhost:11434', model_name='llama3.2:latest') -> dict:
|
||||
# check if api_base is a string
|
||||
if isinstance(api_base, str):
|
||||
endpoint = {
|
||||
"name": model_name,
|
||||
"model": model_name,
|
||||
"key": "",
|
||||
"endpoints": [f"{api_base}/v1/chat/completions"],
|
||||
}
|
||||
return endpoint
|
||||
# check if api_base is a list of strings
|
||||
if isinstance(api_base, list):
|
||||
endpoint = {
|
||||
"name": model_name,
|
||||
"model": model_name,
|
||||
"key": "",
|
||||
"endpoints": [f"{api_stub}/v1/chat/completions" for api_stub in api_base],
|
||||
}
|
||||
return endpoint
|
||||
return {}
|
||||
|
||||
def hex2base64(hex_string) -> str:
|
||||
return base64.b64encode(bytes.fromhex(hex_string)).decode('utf-8')
|
||||
|
||||
def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0.0, max_tokens=8192) -> tuple:
|
||||
def ollama_chat(endpoints, prompt='Hello World', base64_image=None, temperature=0.0, max_tokens=8192) -> tuple:
|
||||
|
||||
# Disable SSL warnings
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
@@ -65,16 +77,16 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
if endpoint.get("key", ""):
|
||||
headers['Authorization'] = 'Bearer ' + endpoint["key"]
|
||||
if endpoints.get("key", ""):
|
||||
headers['Authorization'] = 'Bearer ' + endpoints["key"]
|
||||
|
||||
modelname = endpoint["model"]
|
||||
modelname = endpoints["model"]
|
||||
messages = []
|
||||
messages.append({"role": "system", "content": "You are a helpful assistant"})
|
||||
|
||||
# o1 has special requirements
|
||||
if modelname.startswith("o1") or modelname.startswith("gpt-o1"):
|
||||
temperature = 1.0 # o1 models need temperature 1.0
|
||||
# special requirements of certain models
|
||||
if modelname.startswith("o1") or modelname.startswith("gpt-o1"): temperature = 1.0
|
||||
if modelname.startswith("qwen3"): temperature = 0.6
|
||||
|
||||
if modelname.startswith("4o") or modelname.startswith("gpt-4o") or modelname.startswith("gpt-3.5"):
|
||||
# reduce number of stoptokes to 4
|
||||
@@ -115,10 +127,12 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0
|
||||
"model": modelname,
|
||||
"messages": messages,
|
||||
"response_format": { "type": "text" },
|
||||
"temperature": temperature, # ollama default: 0.8
|
||||
"top_k": 20, # reduces the probability of generating nonsense: high = more diverse, low = more focused; ollama default: 40
|
||||
"top_p": 0.95, # works together with top_k: high = more diverse, low = more focused; ollama default: 0.9
|
||||
"min_p": 0, # alternative to top_p: p is minimum probability for a token to be considered; ollama default: 0.0
|
||||
"stream": False
|
||||
}
|
||||
if not modelname.startswith("o4"):
|
||||
payload["temperature"] = temperature
|
||||
if len(stoptokens) > 0 and not modelname.startswith("o4"):
|
||||
payload["stop"] = stoptokens
|
||||
if modelname.startswith("o1") or modelname.startswith("o4"):
|
||||
@@ -129,7 +143,7 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0
|
||||
try:
|
||||
#print(payload)
|
||||
t0 = time.time()
|
||||
response = requests.post(endpoint["endpoint"], headers=headers, json=payload, verify=False)
|
||||
response = requests.post(endpoints["endpoints"][0], headers=headers, json=payload, verify=False)
|
||||
t1 = time.time()
|
||||
#print(response)
|
||||
response.raise_for_status()
|
||||
@@ -165,8 +179,8 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0
|
||||
|
||||
multimodal_cache = {}
|
||||
|
||||
def test_multimodal(endpoint) -> bool:
|
||||
modelname = endpoint["model"]
|
||||
def test_multimodal(endpoints) -> bool:
|
||||
modelname = endpoints["model"]
|
||||
cached_result = multimodal_cache.get(modelname, None)
|
||||
if cached_result is not None:
|
||||
return cached_result
|
||||
@@ -175,7 +189,7 @@ def test_multimodal(endpoint) -> bool:
|
||||
with open(image_path, "rb") as image_file:
|
||||
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
|
||||
try:
|
||||
answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=base64_image)
|
||||
answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt="what is in the image", base64_image=base64_image)
|
||||
result = "42" in answer
|
||||
if result:
|
||||
print(f"Model {modelname} is multimodal.")
|
||||
@@ -193,13 +207,13 @@ def main():
|
||||
|
||||
# parse the arguments
|
||||
args = parser.parse_args()
|
||||
api_base = args.api_base
|
||||
api_base = args.api_base.split(",") if "," in args.api_base else [args.api_base]
|
||||
endpoint_name = args.endpoint
|
||||
model_name = args.model
|
||||
image_path = args.image
|
||||
|
||||
# load the endpoint file
|
||||
endpoint = {}
|
||||
endpoints = {}
|
||||
if endpoint_name:
|
||||
print(f"Using endpoint {endpoint_name}")
|
||||
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
|
||||
@@ -207,12 +221,12 @@ def main():
|
||||
if not os.path.exists(endpoint_path):
|
||||
raise Exception(f"Endpoint file {endpoint_path} does not exist.")
|
||||
with open(endpoint_path, 'r', encoding='utf-8') as file:
|
||||
endpoint = json.load(file)
|
||||
endpoints = json.load(file)
|
||||
else:
|
||||
endpoint = ollama_chat_endpoint(api_base, model_name)
|
||||
endpoints = ollama_chat_endpoints(api_base, model_name)
|
||||
|
||||
# test if the endpoint is a multimodal model
|
||||
if test_multimodal(endpoint):
|
||||
if test_multimodal(endpoints):
|
||||
print("Endpoint is a multimodal model.")
|
||||
else:
|
||||
print("Endpoint is not a multimodal model.")
|
||||
@@ -224,14 +238,14 @@ def main():
|
||||
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
|
||||
|
||||
# access the ollama API
|
||||
models_dict = ollama_list()
|
||||
models_dict = ollama_list(api_base[0])
|
||||
for (model, attr) in models_dict.items():
|
||||
print(f"Model: {model}: {attr}")
|
||||
try:
|
||||
if base64_image:
|
||||
answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=base64_image)
|
||||
answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt="what is in the image", base64_image=base64_image)
|
||||
else:
|
||||
answer, total_tokens, token_per_second = ollama_chat(endpoint)
|
||||
answer, total_tokens, token_per_second = ollama_chat(endpoints)
|
||||
except Exception as e:
|
||||
answer = f"Error: {str(e)}"
|
||||
print(answer)
|
||||
|
||||
30
test.py
30
test.py
@@ -3,29 +3,41 @@ from argparse import ArgumentParser
|
||||
from benchmark import read_benchmark, write_benchmark
|
||||
from ollama_client import ollama_list
|
||||
|
||||
def test(endpoint_name, model_name, language, overwrite_existing, overwrite_failed, max_problem_number=100):
|
||||
def test(api_base, endpoint_name, model_name, language, overwrite_existing, overwrite_failed, max_problem_number=100, think=False, no_think=False):
|
||||
# call inference.py
|
||||
cmd = f"python3.12 inference.py --language {language}"
|
||||
cmd = f"python3.12 inference.py --language {language} --api_base {api_base}"
|
||||
cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}"
|
||||
if max_problem_number == 200: cmd += " --n200"
|
||||
if overwrite_existing: cmd += " --overwrite_existing"
|
||||
if overwrite_failed: cmd += " --overwrite_failed"
|
||||
if think: cmd += " --think"
|
||||
if no_think: cmd += " --no_think"
|
||||
print(f"Running command: {cmd}")
|
||||
os.system(cmd)
|
||||
|
||||
# call codeextraction.py
|
||||
cmd = f"python3.12 codeextraction.py --language {language}"
|
||||
cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}"
|
||||
if think: cmd += " --think"
|
||||
if no_think: cmd += " --no_think"
|
||||
print(f"Running command: {cmd}")
|
||||
os.system(cmd)
|
||||
|
||||
# call execute.py
|
||||
cmd = f"python3.12 execute.py --language {language}"
|
||||
cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}"
|
||||
if think: cmd += " --think"
|
||||
if no_think: cmd += " --no_think"
|
||||
print(f"Running command: {cmd}")
|
||||
os.system(cmd)
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser(description="Run the complete pipeline to execute solutions and store results in a JSON file.")
|
||||
parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434')
|
||||
parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json')
|
||||
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
|
||||
parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end')
|
||||
parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end')
|
||||
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure')
|
||||
parser.add_argument('--overwrite_existing', action='store_true', help='if set, re-calculate all problems that already have an answer')
|
||||
parser.add_argument('--overwrite_failed', action='store_true', help='if set, re-calculate those problems with wrong answers')
|
||||
@@ -36,6 +48,7 @@ def main():
|
||||
parser.add_argument('--nall', action='store_true', help='all problems')
|
||||
|
||||
args = parser.parse_args()
|
||||
api_base = args.api_base
|
||||
model_name = args.model
|
||||
max_problem_number = 100
|
||||
if args.n100: max_problem_number = 100
|
||||
@@ -72,16 +85,19 @@ def main():
|
||||
|
||||
# in every loop we load the benchmark.json again because it might have been updated
|
||||
benchmark = read_benchmark()
|
||||
entry = benchmark.get(model, {})
|
||||
model_benchmark_name = model
|
||||
if args.think: model_benchmark_name += "-think"
|
||||
if args.no_think: model_benchmark_name += "-no_think"
|
||||
entry = benchmark.get(model_benchmark_name, {})
|
||||
|
||||
# add metadata to benchmark.json
|
||||
if not model in benchmark or not bench_name in benchmark[model] or overwrite_existing or overwrite_failed:
|
||||
if not model_benchmark_name in benchmark or not bench_name in benchmark[model_benchmark_name] or overwrite_existing or overwrite_failed:
|
||||
# run the model; this writes a news entry to benchmark.json
|
||||
test(endpoint_name, model, language, overwrite_existing, overwrite_failed, max_problem_number)
|
||||
test(api_base, endpoint_name, model, language, overwrite_existing, overwrite_failed, max_problem_number, think = args.think, no_think = args.no_think)
|
||||
# load benchmark.json again because the test has updated it
|
||||
benchmark = read_benchmark()
|
||||
# because testing can be interrupted, there is no guarantee that the entry is present
|
||||
entry = benchmark.get(model, {})
|
||||
entry = benchmark.get(model_benchmark_name, {})
|
||||
|
||||
# check if attributes parameter_size and quantization_level are present in benchmark.json
|
||||
if not '_parameter_size' in entry and model_dict[model].get('parameter_size', None):
|
||||
@@ -89,7 +105,7 @@ def main():
|
||||
if not '_quantization_level' in entry and model_dict[model].get('quantization_level', None):
|
||||
entry['_quantization_level'] = model_dict[model].get('quantization_level', None)
|
||||
entry = dict(sorted(entry.items(), key=lambda item: item[0]))
|
||||
benchmark[model] = entry
|
||||
benchmark[model_benchmark_name] = entry
|
||||
|
||||
# write the updated benchmark file
|
||||
write_benchmark(benchmark)
|
||||
|
||||
Reference in New Issue
Block a user