refactoring for parallel processing

This commit is contained in:
Michael Peter Christen
2025-04-29 21:33:41 +02:00
parent 5e61e1667a
commit 88ac851380
5 changed files with 103 additions and 46 deletions

View File

@@ -80,6 +80,8 @@ def process_markdown_files(model_name, language):
def main(): def main():
parser = ArgumentParser(description="Extract code blocks from Markdown files.") parser = ArgumentParser(description="Extract code blocks from Markdown files.")
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest') parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end')
parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end')
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure') parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure')
parser.add_argument('--endpoint', required=False, default='', help='Name of an <endpoint>.json file in the endpoints directory') parser.add_argument('--endpoint', required=False, default='', help='Name of an <endpoint>.json file in the endpoints directory')
@@ -88,6 +90,7 @@ def main():
language = args.language language = args.language
endpoint_name = args.endpoint endpoint_name = args.endpoint
# in case no model name is given but an endpoint name, read the model name from the endpoint file
if endpoint_name: if endpoint_name:
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json") endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
print(f"Using endpoint file {endpoint_path}") print(f"Using endpoint file {endpoint_path}")
@@ -97,6 +100,10 @@ def main():
endpoint = json.load(file) endpoint = json.load(file)
model_name = endpoint.get('name', model_name) model_name = endpoint.get('name', model_name)
# modify the model name in case soft thinking switches are given
if args.think: model_name += "-think"
if args.no_think: model_name += "-no_think"
languages = args.language.split(',') languages = args.language.split(',')
for language in languages: for language in languages:
print(f"Processing language: {language}") print(f"Processing language: {language}")

View File

@@ -407,6 +407,8 @@ def main():
parser = ArgumentParser(description="Execute solutions and store results in a JSON file.") parser = ArgumentParser(description="Execute solutions and store results in a JSON file.")
parser.add_argument('--allmodels', action='store_true', help='loop over all models as provided by benchmark.json and run all of them') parser.add_argument('--allmodels', action='store_true', help='loop over all models as provided by benchmark.json and run all of them')
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest') parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end')
parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end')
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the programming language to use, default is python') parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the programming language to use, default is python')
parser.add_argument('--endpoint', required=False, default='', help='Name of an <endpoint>.json file in the endpoints directory') parser.add_argument('--endpoint', required=False, default='', help='Name of an <endpoint>.json file in the endpoints directory')
parser.add_argument('--n100', action='store_true', help='only 100 problems') # this is the default parser.add_argument('--n100', action='store_true', help='only 100 problems') # this is the default
@@ -432,6 +434,10 @@ def main():
endpoint = json.load(file) endpoint = json.load(file)
model_name = endpoint.get('name', model_name) model_name = endpoint.get('name', model_name)
# modify the model name in case soft thinking switches are given
if args.think: model_name += "-think"
if args.no_think: model_name += "-no_think"
with open('solutions.json', 'r', encoding='utf-8') as json_file: with open('solutions.json', 'r', encoding='utf-8') as json_file:
expected_solutions = json.load(json_file) expected_solutions = json.load(json_file)

View File

@@ -3,14 +3,18 @@ import json
import base64 import base64
from argparse import ArgumentParser from argparse import ArgumentParser
from benchmark import read_benchmark, write_benchmark from benchmark import read_benchmark, write_benchmark
from ollama_client import ollama_list, ollama_chat_endpoint, ollama_chat, test_multimodal from ollama_client import ollama_list, ollama_chat_endpoints, ollama_chat, test_multimodal
def read_template(template_path): def read_template(template_path):
with open(template_path, 'r', encoding='utf-8') as file: with open(template_path, 'r', encoding='utf-8') as file:
return file.read() return file.read()
def process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number=9999, overwrite_existing=False, overwrite_failed=False, expected_solutions={}): def process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number=9999,
model_name = endpoint["name"] overwrite_existing=False, overwrite_failed=False, expected_solutions={},
think=False, no_think=False):
model_name = endpoints["name"]
if think: model_name += "-think"
if no_think: model_name += "-no_think"
results_dir = os.path.join('solutions', model_name, language) results_dir = os.path.join('solutions', model_name, language)
os.makedirs(results_dir, exist_ok=True) os.makedirs(results_dir, exist_ok=True)
@@ -51,6 +55,10 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
# Construct the prompt using the template # Construct the prompt using the template
prompt = template_content.replace('$$$PROBLEM$$$', problem_content) prompt = template_content.replace('$$$PROBLEM$$$', problem_content)
# attach soft thinking switches if asked
if think: prompt += " /think"
if no_think: prompt += " /no_think"
# check if there is also an image in the problem. We take the problem_file, remove the extension ".txt" # check if there is also an image in the problem. We take the problem_file, remove the extension ".txt"
# and add either "-0.png", "-0.jpg" or "-0.gif" # and add either "-0.png", "-0.jpg" or "-0.gif"
base64_image = None base64_image = None
@@ -64,7 +72,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
# check if the endpoint is multimodal, we do this only if we have an image # check if the endpoint is multimodal, we do this only if we have an image
if base64_image: if base64_image:
is_multimodal = test_multimodal(endpoint) # this is cached is_multimodal = test_multimodal(endpoints) # this is cached
if is_multimodal: if is_multimodal:
print(f"Problem {problem_number} is handled with multimodal model.") print(f"Problem {problem_number} is handled with multimodal model.")
else: else:
@@ -72,7 +80,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
base64_image = None base64_image = None
try: try:
answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt, base64_image=base64_image) answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt, base64_image=base64_image)
# Save the response to a file # Save the response to a file
with open(result_file_path, 'w', encoding='utf-8') as result_file: with open(result_file_path, 'w', encoding='utf-8') as result_file:
@@ -89,6 +97,8 @@ def main():
parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434') parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434')
parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json') parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json')
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest') parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end')
parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end')
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure') parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure')
parser.add_argument('--overwrite_existing', action='store_true', help='if set, re-calculate all problems that already have an answer') parser.add_argument('--overwrite_existing', action='store_true', help='if set, re-calculate all problems that already have an answer')
parser.add_argument('--overwrite_failed', action='store_true', help='if set, re-calculate those problems with wrong answers') parser.add_argument('--overwrite_failed', action='store_true', help='if set, re-calculate those problems with wrong answers')
@@ -100,7 +110,7 @@ def main():
args = parser.parse_args() args = parser.parse_args()
api_base = args.api_base api_base = args.api_base.split(",") if "," in args.api_base else [args.api_base]
model_name = args.model model_name = args.model
language = args.language language = args.language
max_problem_number = 100 max_problem_number = 100
@@ -149,8 +159,10 @@ def main():
# add metadata to benchmark.json # add metadata to benchmark.json
if not model in benchmark or not bench_name in benchmark[model]: if not model in benchmark or not bench_name in benchmark[model]:
print(f"Inference: Using model {model} and language {language}") print(f"Inference: Using model {model} and language {language}")
endpoint = ollama_chat_endpoint(api_base, model) endpoints = ollama_chat_endpoints(api_base, model)
process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions) process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number = max_problem_number,
overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions,
think = args.think, no_think = args.no_think)
else: else:
# construct the endpoint object # construct the endpoint object
endpoint = {} endpoint = {}
@@ -165,10 +177,12 @@ def main():
else: else:
print(f"Inference: Using model {model_name} and language {language}") print(f"Inference: Using model {model_name} and language {language}")
# construct the endpoint object from command line arguments considering that ollama is the endpoint # construct the endpoint object from command line arguments considering that ollama is the endpoint
endpoint = ollama_chat_endpoint(api_base, model_name) endpoints = ollama_chat_endpoints(api_base, model_name)
# run the inference # run the inference
process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions) process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number = max_problem_number,
overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions,
think = args.think, no_think = args.no_think)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -40,19 +40,31 @@ def ollama_list(api_base='http://localhost:11434') -> dict:
models_dict[model] = attr models_dict[model] = attr
return models_dict return models_dict
def ollama_chat_endpoint(api_base='http://localhost:11434', model_name='llama3.2:latest') -> dict: def ollama_chat_endpoints(api_base='http://localhost:11434', model_name='llama3.2:latest') -> dict:
endpoint = { # check if api_base is a string
"name": model_name, if isinstance(api_base, str):
"model": model_name, endpoint = {
"key": "", "name": model_name,
"endpoint": f"{api_base}/v1/chat/completions", "model": model_name,
} "key": "",
return endpoint "endpoints": [f"{api_base}/v1/chat/completions"],
}
return endpoint
# check if api_base is a list of strings
if isinstance(api_base, list):
endpoint = {
"name": model_name,
"model": model_name,
"key": "",
"endpoints": [f"{api_stub}/v1/chat/completions" for api_stub in api_base],
}
return endpoint
return {}
def hex2base64(hex_string) -> str: def hex2base64(hex_string) -> str:
return base64.b64encode(bytes.fromhex(hex_string)).decode('utf-8') return base64.b64encode(bytes.fromhex(hex_string)).decode('utf-8')
def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0.0, max_tokens=8192) -> tuple: def ollama_chat(endpoints, prompt='Hello World', base64_image=None, temperature=0.0, max_tokens=8192) -> tuple:
# Disable SSL warnings # Disable SSL warnings
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@@ -65,16 +77,16 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0
'Content-Type': 'application/json', 'Content-Type': 'application/json',
'Accept': 'application/json' 'Accept': 'application/json'
} }
if endpoint.get("key", ""): if endpoints.get("key", ""):
headers['Authorization'] = 'Bearer ' + endpoint["key"] headers['Authorization'] = 'Bearer ' + endpoints["key"]
modelname = endpoint["model"] modelname = endpoints["model"]
messages = [] messages = []
messages.append({"role": "system", "content": "You are a helpful assistant"}) messages.append({"role": "system", "content": "You are a helpful assistant"})
# o1 has special requirements # special requirements of certain models
if modelname.startswith("o1") or modelname.startswith("gpt-o1"): if modelname.startswith("o1") or modelname.startswith("gpt-o1"): temperature = 1.0
temperature = 1.0 # o1 models need temperature 1.0 if modelname.startswith("qwen3"): temperature = 0.6
if modelname.startswith("4o") or modelname.startswith("gpt-4o") or modelname.startswith("gpt-3.5"): if modelname.startswith("4o") or modelname.startswith("gpt-4o") or modelname.startswith("gpt-3.5"):
# reduce number of stoptokes to 4 # reduce number of stoptokes to 4
@@ -115,10 +127,12 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0
"model": modelname, "model": modelname,
"messages": messages, "messages": messages,
"response_format": { "type": "text" }, "response_format": { "type": "text" },
"temperature": temperature, # ollama default: 0.8
"top_k": 20, # reduces the probability of generating nonsense: high = more diverse, low = more focused; ollama default: 40
"top_p": 0.95, # works together with top_k: high = more diverse, low = more focused; ollama default: 0.9
"min_p": 0, # alternative to top_p: p is minimum probability for a token to be considered; ollama default: 0.0
"stream": False "stream": False
} }
if not modelname.startswith("o4"):
payload["temperature"] = temperature
if len(stoptokens) > 0 and not modelname.startswith("o4"): if len(stoptokens) > 0 and not modelname.startswith("o4"):
payload["stop"] = stoptokens payload["stop"] = stoptokens
if modelname.startswith("o1") or modelname.startswith("o4"): if modelname.startswith("o1") or modelname.startswith("o4"):
@@ -129,7 +143,7 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0
try: try:
#print(payload) #print(payload)
t0 = time.time() t0 = time.time()
response = requests.post(endpoint["endpoint"], headers=headers, json=payload, verify=False) response = requests.post(endpoints["endpoints"][0], headers=headers, json=payload, verify=False)
t1 = time.time() t1 = time.time()
#print(response) #print(response)
response.raise_for_status() response.raise_for_status()
@@ -165,8 +179,8 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0
multimodal_cache = {} multimodal_cache = {}
def test_multimodal(endpoint) -> bool: def test_multimodal(endpoints) -> bool:
modelname = endpoint["model"] modelname = endpoints["model"]
cached_result = multimodal_cache.get(modelname, None) cached_result = multimodal_cache.get(modelname, None)
if cached_result is not None: if cached_result is not None:
return cached_result return cached_result
@@ -175,7 +189,7 @@ def test_multimodal(endpoint) -> bool:
with open(image_path, "rb") as image_file: with open(image_path, "rb") as image_file:
base64_image = base64.b64encode(image_file.read()).decode('utf-8') base64_image = base64.b64encode(image_file.read()).decode('utf-8')
try: try:
answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=base64_image) answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt="what is in the image", base64_image=base64_image)
result = "42" in answer result = "42" in answer
if result: if result:
print(f"Model {modelname} is multimodal.") print(f"Model {modelname} is multimodal.")
@@ -193,13 +207,13 @@ def main():
# parse the arguments # parse the arguments
args = parser.parse_args() args = parser.parse_args()
api_base = args.api_base api_base = args.api_base.split(",") if "," in args.api_base else [args.api_base]
endpoint_name = args.endpoint endpoint_name = args.endpoint
model_name = args.model model_name = args.model
image_path = args.image image_path = args.image
# load the endpoint file # load the endpoint file
endpoint = {} endpoints = {}
if endpoint_name: if endpoint_name:
print(f"Using endpoint {endpoint_name}") print(f"Using endpoint {endpoint_name}")
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json") endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
@@ -207,12 +221,12 @@ def main():
if not os.path.exists(endpoint_path): if not os.path.exists(endpoint_path):
raise Exception(f"Endpoint file {endpoint_path} does not exist.") raise Exception(f"Endpoint file {endpoint_path} does not exist.")
with open(endpoint_path, 'r', encoding='utf-8') as file: with open(endpoint_path, 'r', encoding='utf-8') as file:
endpoint = json.load(file) endpoints = json.load(file)
else: else:
endpoint = ollama_chat_endpoint(api_base, model_name) endpoints = ollama_chat_endpoints(api_base, model_name)
# test if the endpoint is a multimodal model # test if the endpoint is a multimodal model
if test_multimodal(endpoint): if test_multimodal(endpoints):
print("Endpoint is a multimodal model.") print("Endpoint is a multimodal model.")
else: else:
print("Endpoint is not a multimodal model.") print("Endpoint is not a multimodal model.")
@@ -224,14 +238,14 @@ def main():
base64_image = base64.b64encode(image_file.read()).decode('utf-8') base64_image = base64.b64encode(image_file.read()).decode('utf-8')
# access the ollama API # access the ollama API
models_dict = ollama_list() models_dict = ollama_list(api_base[0])
for (model, attr) in models_dict.items(): for (model, attr) in models_dict.items():
print(f"Model: {model}: {attr}") print(f"Model: {model}: {attr}")
try: try:
if base64_image: if base64_image:
answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=base64_image) answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt="what is in the image", base64_image=base64_image)
else: else:
answer, total_tokens, token_per_second = ollama_chat(endpoint) answer, total_tokens, token_per_second = ollama_chat(endpoints)
except Exception as e: except Exception as e:
answer = f"Error: {str(e)}" answer = f"Error: {str(e)}"
print(answer) print(answer)

30
test.py
View File

@@ -3,29 +3,41 @@ from argparse import ArgumentParser
from benchmark import read_benchmark, write_benchmark from benchmark import read_benchmark, write_benchmark
from ollama_client import ollama_list from ollama_client import ollama_list
def test(endpoint_name, model_name, language, overwrite_existing, overwrite_failed, max_problem_number=100): def test(api_base, endpoint_name, model_name, language, overwrite_existing, overwrite_failed, max_problem_number=100, think=False, no_think=False):
# call inference.py # call inference.py
cmd = f"python3.12 inference.py --language {language}" cmd = f"python3.12 inference.py --language {language} --api_base {api_base}"
cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}" cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}"
if max_problem_number == 200: cmd += " --n200" if max_problem_number == 200: cmd += " --n200"
if overwrite_existing: cmd += " --overwrite_existing" if overwrite_existing: cmd += " --overwrite_existing"
if overwrite_failed: cmd += " --overwrite_failed" if overwrite_failed: cmd += " --overwrite_failed"
if think: cmd += " --think"
if no_think: cmd += " --no_think"
print(f"Running command: {cmd}")
os.system(cmd) os.system(cmd)
# call codeextraction.py # call codeextraction.py
cmd = f"python3.12 codeextraction.py --language {language}" cmd = f"python3.12 codeextraction.py --language {language}"
cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}" cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}"
if think: cmd += " --think"
if no_think: cmd += " --no_think"
print(f"Running command: {cmd}")
os.system(cmd) os.system(cmd)
# call execute.py # call execute.py
cmd = f"python3.12 execute.py --language {language}" cmd = f"python3.12 execute.py --language {language}"
cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}" cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}"
if think: cmd += " --think"
if no_think: cmd += " --no_think"
print(f"Running command: {cmd}")
os.system(cmd) os.system(cmd)
def main(): def main():
parser = ArgumentParser(description="Run the complete pipeline to execute solutions and store results in a JSON file.") parser = ArgumentParser(description="Run the complete pipeline to execute solutions and store results in a JSON file.")
parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434')
parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json') parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json')
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest') parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end')
parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end')
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure') parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure')
parser.add_argument('--overwrite_existing', action='store_true', help='if set, re-calculate all problems that already have an answer') parser.add_argument('--overwrite_existing', action='store_true', help='if set, re-calculate all problems that already have an answer')
parser.add_argument('--overwrite_failed', action='store_true', help='if set, re-calculate those problems with wrong answers') parser.add_argument('--overwrite_failed', action='store_true', help='if set, re-calculate those problems with wrong answers')
@@ -36,6 +48,7 @@ def main():
parser.add_argument('--nall', action='store_true', help='all problems') parser.add_argument('--nall', action='store_true', help='all problems')
args = parser.parse_args() args = parser.parse_args()
api_base = args.api_base
model_name = args.model model_name = args.model
max_problem_number = 100 max_problem_number = 100
if args.n100: max_problem_number = 100 if args.n100: max_problem_number = 100
@@ -72,16 +85,19 @@ def main():
# in every loop we load the benchmark.json again because it might have been updated # in every loop we load the benchmark.json again because it might have been updated
benchmark = read_benchmark() benchmark = read_benchmark()
entry = benchmark.get(model, {}) model_benchmark_name = model
if args.think: model_benchmark_name += "-think"
if args.no_think: model_benchmark_name += "-no_think"
entry = benchmark.get(model_benchmark_name, {})
# add metadata to benchmark.json # add metadata to benchmark.json
if not model in benchmark or not bench_name in benchmark[model] or overwrite_existing or overwrite_failed: if not model_benchmark_name in benchmark or not bench_name in benchmark[model_benchmark_name] or overwrite_existing or overwrite_failed:
# run the model; this writes a news entry to benchmark.json # run the model; this writes a news entry to benchmark.json
test(endpoint_name, model, language, overwrite_existing, overwrite_failed, max_problem_number) test(api_base, endpoint_name, model, language, overwrite_existing, overwrite_failed, max_problem_number, think = args.think, no_think = args.no_think)
# load benchmark.json again because the test has updated it # load benchmark.json again because the test has updated it
benchmark = read_benchmark() benchmark = read_benchmark()
# because testing can be interrupted, there is no guarantee that the entry is present # because testing can be interrupted, there is no guarantee that the entry is present
entry = benchmark.get(model, {}) entry = benchmark.get(model_benchmark_name, {})
# check if attributes parameter_size and quantization_level are present in benchmark.json # check if attributes parameter_size and quantization_level are present in benchmark.json
if not '_parameter_size' in entry and model_dict[model].get('parameter_size', None): if not '_parameter_size' in entry and model_dict[model].get('parameter_size', None):
@@ -89,7 +105,7 @@ def main():
if not '_quantization_level' in entry and model_dict[model].get('quantization_level', None): if not '_quantization_level' in entry and model_dict[model].get('quantization_level', None):
entry['_quantization_level'] = model_dict[model].get('quantization_level', None) entry['_quantization_level'] = model_dict[model].get('quantization_level', None)
entry = dict(sorted(entry.items(), key=lambda item: item[0])) entry = dict(sorted(entry.items(), key=lambda item: item[0]))
benchmark[model] = entry benchmark[model_benchmark_name] = entry
# write the updated benchmark file # write the updated benchmark file
write_benchmark(benchmark) write_benchmark(benchmark)