refactoring for parallel processing
This commit is contained in:
@@ -80,6 +80,8 @@ def process_markdown_files(model_name, language):
|
|||||||
def main():
|
def main():
|
||||||
parser = ArgumentParser(description="Extract code blocks from Markdown files.")
|
parser = ArgumentParser(description="Extract code blocks from Markdown files.")
|
||||||
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
|
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
|
||||||
|
parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end')
|
||||||
|
parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end')
|
||||||
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure')
|
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure')
|
||||||
parser.add_argument('--endpoint', required=False, default='', help='Name of an <endpoint>.json file in the endpoints directory')
|
parser.add_argument('--endpoint', required=False, default='', help='Name of an <endpoint>.json file in the endpoints directory')
|
||||||
|
|
||||||
@@ -88,6 +90,7 @@ def main():
|
|||||||
language = args.language
|
language = args.language
|
||||||
endpoint_name = args.endpoint
|
endpoint_name = args.endpoint
|
||||||
|
|
||||||
|
# in case no model name is given but an endpoint name, read the model name from the endpoint file
|
||||||
if endpoint_name:
|
if endpoint_name:
|
||||||
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
|
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
|
||||||
print(f"Using endpoint file {endpoint_path}")
|
print(f"Using endpoint file {endpoint_path}")
|
||||||
@@ -97,6 +100,10 @@ def main():
|
|||||||
endpoint = json.load(file)
|
endpoint = json.load(file)
|
||||||
model_name = endpoint.get('name', model_name)
|
model_name = endpoint.get('name', model_name)
|
||||||
|
|
||||||
|
# modify the model name in case soft thinking switches are given
|
||||||
|
if args.think: model_name += "-think"
|
||||||
|
if args.no_think: model_name += "-no_think"
|
||||||
|
|
||||||
languages = args.language.split(',')
|
languages = args.language.split(',')
|
||||||
for language in languages:
|
for language in languages:
|
||||||
print(f"Processing language: {language}")
|
print(f"Processing language: {language}")
|
||||||
|
|||||||
@@ -407,6 +407,8 @@ def main():
|
|||||||
parser = ArgumentParser(description="Execute solutions and store results in a JSON file.")
|
parser = ArgumentParser(description="Execute solutions and store results in a JSON file.")
|
||||||
parser.add_argument('--allmodels', action='store_true', help='loop over all models as provided by benchmark.json and run all of them')
|
parser.add_argument('--allmodels', action='store_true', help='loop over all models as provided by benchmark.json and run all of them')
|
||||||
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
|
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
|
||||||
|
parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end')
|
||||||
|
parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end')
|
||||||
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the programming language to use, default is python')
|
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the programming language to use, default is python')
|
||||||
parser.add_argument('--endpoint', required=False, default='', help='Name of an <endpoint>.json file in the endpoints directory')
|
parser.add_argument('--endpoint', required=False, default='', help='Name of an <endpoint>.json file in the endpoints directory')
|
||||||
parser.add_argument('--n100', action='store_true', help='only 100 problems') # this is the default
|
parser.add_argument('--n100', action='store_true', help='only 100 problems') # this is the default
|
||||||
@@ -432,6 +434,10 @@ def main():
|
|||||||
endpoint = json.load(file)
|
endpoint = json.load(file)
|
||||||
model_name = endpoint.get('name', model_name)
|
model_name = endpoint.get('name', model_name)
|
||||||
|
|
||||||
|
# modify the model name in case soft thinking switches are given
|
||||||
|
if args.think: model_name += "-think"
|
||||||
|
if args.no_think: model_name += "-no_think"
|
||||||
|
|
||||||
with open('solutions.json', 'r', encoding='utf-8') as json_file:
|
with open('solutions.json', 'r', encoding='utf-8') as json_file:
|
||||||
expected_solutions = json.load(json_file)
|
expected_solutions = json.load(json_file)
|
||||||
|
|
||||||
|
|||||||
34
inference.py
34
inference.py
@@ -3,14 +3,18 @@ import json
|
|||||||
import base64
|
import base64
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from benchmark import read_benchmark, write_benchmark
|
from benchmark import read_benchmark, write_benchmark
|
||||||
from ollama_client import ollama_list, ollama_chat_endpoint, ollama_chat, test_multimodal
|
from ollama_client import ollama_list, ollama_chat_endpoints, ollama_chat, test_multimodal
|
||||||
|
|
||||||
def read_template(template_path):
|
def read_template(template_path):
|
||||||
with open(template_path, 'r', encoding='utf-8') as file:
|
with open(template_path, 'r', encoding='utf-8') as file:
|
||||||
return file.read()
|
return file.read()
|
||||||
|
|
||||||
def process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number=9999, overwrite_existing=False, overwrite_failed=False, expected_solutions={}):
|
def process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number=9999,
|
||||||
model_name = endpoint["name"]
|
overwrite_existing=False, overwrite_failed=False, expected_solutions={},
|
||||||
|
think=False, no_think=False):
|
||||||
|
model_name = endpoints["name"]
|
||||||
|
if think: model_name += "-think"
|
||||||
|
if no_think: model_name += "-no_think"
|
||||||
results_dir = os.path.join('solutions', model_name, language)
|
results_dir = os.path.join('solutions', model_name, language)
|
||||||
os.makedirs(results_dir, exist_ok=True)
|
os.makedirs(results_dir, exist_ok=True)
|
||||||
|
|
||||||
@@ -51,6 +55,10 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
|
|||||||
# Construct the prompt using the template
|
# Construct the prompt using the template
|
||||||
prompt = template_content.replace('$$$PROBLEM$$$', problem_content)
|
prompt = template_content.replace('$$$PROBLEM$$$', problem_content)
|
||||||
|
|
||||||
|
# attach soft thinking switches if asked
|
||||||
|
if think: prompt += " /think"
|
||||||
|
if no_think: prompt += " /no_think"
|
||||||
|
|
||||||
# check if there is also an image in the problem. We take the problem_file, remove the extension ".txt"
|
# check if there is also an image in the problem. We take the problem_file, remove the extension ".txt"
|
||||||
# and add either "-0.png", "-0.jpg" or "-0.gif"
|
# and add either "-0.png", "-0.jpg" or "-0.gif"
|
||||||
base64_image = None
|
base64_image = None
|
||||||
@@ -64,7 +72,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
|
|||||||
|
|
||||||
# check if the endpoint is multimodal, we do this only if we have an image
|
# check if the endpoint is multimodal, we do this only if we have an image
|
||||||
if base64_image:
|
if base64_image:
|
||||||
is_multimodal = test_multimodal(endpoint) # this is cached
|
is_multimodal = test_multimodal(endpoints) # this is cached
|
||||||
if is_multimodal:
|
if is_multimodal:
|
||||||
print(f"Problem {problem_number} is handled with multimodal model.")
|
print(f"Problem {problem_number} is handled with multimodal model.")
|
||||||
else:
|
else:
|
||||||
@@ -72,7 +80,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
|
|||||||
base64_image = None
|
base64_image = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt, base64_image=base64_image)
|
answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt, base64_image=base64_image)
|
||||||
|
|
||||||
# Save the response to a file
|
# Save the response to a file
|
||||||
with open(result_file_path, 'w', encoding='utf-8') as result_file:
|
with open(result_file_path, 'w', encoding='utf-8') as result_file:
|
||||||
@@ -89,6 +97,8 @@ def main():
|
|||||||
parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434')
|
parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434')
|
||||||
parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json')
|
parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json')
|
||||||
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
|
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
|
||||||
|
parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end')
|
||||||
|
parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end')
|
||||||
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure')
|
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure')
|
||||||
parser.add_argument('--overwrite_existing', action='store_true', help='if set, re-calculate all problems that already have an answer')
|
parser.add_argument('--overwrite_existing', action='store_true', help='if set, re-calculate all problems that already have an answer')
|
||||||
parser.add_argument('--overwrite_failed', action='store_true', help='if set, re-calculate those problems with wrong answers')
|
parser.add_argument('--overwrite_failed', action='store_true', help='if set, re-calculate those problems with wrong answers')
|
||||||
@@ -100,7 +110,7 @@ def main():
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
api_base = args.api_base
|
api_base = args.api_base.split(",") if "," in args.api_base else [args.api_base]
|
||||||
model_name = args.model
|
model_name = args.model
|
||||||
language = args.language
|
language = args.language
|
||||||
max_problem_number = 100
|
max_problem_number = 100
|
||||||
@@ -149,8 +159,10 @@ def main():
|
|||||||
# add metadata to benchmark.json
|
# add metadata to benchmark.json
|
||||||
if not model in benchmark or not bench_name in benchmark[model]:
|
if not model in benchmark or not bench_name in benchmark[model]:
|
||||||
print(f"Inference: Using model {model} and language {language}")
|
print(f"Inference: Using model {model} and language {language}")
|
||||||
endpoint = ollama_chat_endpoint(api_base, model)
|
endpoints = ollama_chat_endpoints(api_base, model)
|
||||||
process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions)
|
process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number = max_problem_number,
|
||||||
|
overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions,
|
||||||
|
think = args.think, no_think = args.no_think)
|
||||||
else:
|
else:
|
||||||
# construct the endpoint object
|
# construct the endpoint object
|
||||||
endpoint = {}
|
endpoint = {}
|
||||||
@@ -165,10 +177,12 @@ def main():
|
|||||||
else:
|
else:
|
||||||
print(f"Inference: Using model {model_name} and language {language}")
|
print(f"Inference: Using model {model_name} and language {language}")
|
||||||
# construct the endpoint object from command line arguments considering that ollama is the endpoint
|
# construct the endpoint object from command line arguments considering that ollama is the endpoint
|
||||||
endpoint = ollama_chat_endpoint(api_base, model_name)
|
endpoints = ollama_chat_endpoints(api_base, model_name)
|
||||||
|
|
||||||
# run the inference
|
# run the inference
|
||||||
process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions)
|
process_problem_files(problems_dir, template_content, endpoints, language, max_problem_number = max_problem_number,
|
||||||
|
overwrite_existing = args.overwrite_existing, overwrite_failed = args.overwrite_failed, expected_solutions = expected_solutions,
|
||||||
|
think = args.think, no_think = args.no_think)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -40,19 +40,31 @@ def ollama_list(api_base='http://localhost:11434') -> dict:
|
|||||||
models_dict[model] = attr
|
models_dict[model] = attr
|
||||||
return models_dict
|
return models_dict
|
||||||
|
|
||||||
def ollama_chat_endpoint(api_base='http://localhost:11434', model_name='llama3.2:latest') -> dict:
|
def ollama_chat_endpoints(api_base='http://localhost:11434', model_name='llama3.2:latest') -> dict:
|
||||||
endpoint = {
|
# check if api_base is a string
|
||||||
"name": model_name,
|
if isinstance(api_base, str):
|
||||||
"model": model_name,
|
endpoint = {
|
||||||
"key": "",
|
"name": model_name,
|
||||||
"endpoint": f"{api_base}/v1/chat/completions",
|
"model": model_name,
|
||||||
}
|
"key": "",
|
||||||
return endpoint
|
"endpoints": [f"{api_base}/v1/chat/completions"],
|
||||||
|
}
|
||||||
|
return endpoint
|
||||||
|
# check if api_base is a list of strings
|
||||||
|
if isinstance(api_base, list):
|
||||||
|
endpoint = {
|
||||||
|
"name": model_name,
|
||||||
|
"model": model_name,
|
||||||
|
"key": "",
|
||||||
|
"endpoints": [f"{api_stub}/v1/chat/completions" for api_stub in api_base],
|
||||||
|
}
|
||||||
|
return endpoint
|
||||||
|
return {}
|
||||||
|
|
||||||
def hex2base64(hex_string) -> str:
|
def hex2base64(hex_string) -> str:
|
||||||
return base64.b64encode(bytes.fromhex(hex_string)).decode('utf-8')
|
return base64.b64encode(bytes.fromhex(hex_string)).decode('utf-8')
|
||||||
|
|
||||||
def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0.0, max_tokens=8192) -> tuple:
|
def ollama_chat(endpoints, prompt='Hello World', base64_image=None, temperature=0.0, max_tokens=8192) -> tuple:
|
||||||
|
|
||||||
# Disable SSL warnings
|
# Disable SSL warnings
|
||||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
@@ -65,16 +77,16 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0
|
|||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
'Accept': 'application/json'
|
'Accept': 'application/json'
|
||||||
}
|
}
|
||||||
if endpoint.get("key", ""):
|
if endpoints.get("key", ""):
|
||||||
headers['Authorization'] = 'Bearer ' + endpoint["key"]
|
headers['Authorization'] = 'Bearer ' + endpoints["key"]
|
||||||
|
|
||||||
modelname = endpoint["model"]
|
modelname = endpoints["model"]
|
||||||
messages = []
|
messages = []
|
||||||
messages.append({"role": "system", "content": "You are a helpful assistant"})
|
messages.append({"role": "system", "content": "You are a helpful assistant"})
|
||||||
|
|
||||||
# o1 has special requirements
|
# special requirements of certain models
|
||||||
if modelname.startswith("o1") or modelname.startswith("gpt-o1"):
|
if modelname.startswith("o1") or modelname.startswith("gpt-o1"): temperature = 1.0
|
||||||
temperature = 1.0 # o1 models need temperature 1.0
|
if modelname.startswith("qwen3"): temperature = 0.6
|
||||||
|
|
||||||
if modelname.startswith("4o") or modelname.startswith("gpt-4o") or modelname.startswith("gpt-3.5"):
|
if modelname.startswith("4o") or modelname.startswith("gpt-4o") or modelname.startswith("gpt-3.5"):
|
||||||
# reduce number of stoptokes to 4
|
# reduce number of stoptokes to 4
|
||||||
@@ -115,10 +127,12 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0
|
|||||||
"model": modelname,
|
"model": modelname,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"response_format": { "type": "text" },
|
"response_format": { "type": "text" },
|
||||||
|
"temperature": temperature, # ollama default: 0.8
|
||||||
|
"top_k": 20, # reduces the probability of generating nonsense: high = more diverse, low = more focused; ollama default: 40
|
||||||
|
"top_p": 0.95, # works together with top_k: high = more diverse, low = more focused; ollama default: 0.9
|
||||||
|
"min_p": 0, # alternative to top_p: p is minimum probability for a token to be considered; ollama default: 0.0
|
||||||
"stream": False
|
"stream": False
|
||||||
}
|
}
|
||||||
if not modelname.startswith("o4"):
|
|
||||||
payload["temperature"] = temperature
|
|
||||||
if len(stoptokens) > 0 and not modelname.startswith("o4"):
|
if len(stoptokens) > 0 and not modelname.startswith("o4"):
|
||||||
payload["stop"] = stoptokens
|
payload["stop"] = stoptokens
|
||||||
if modelname.startswith("o1") or modelname.startswith("o4"):
|
if modelname.startswith("o1") or modelname.startswith("o4"):
|
||||||
@@ -129,7 +143,7 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0
|
|||||||
try:
|
try:
|
||||||
#print(payload)
|
#print(payload)
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
response = requests.post(endpoint["endpoint"], headers=headers, json=payload, verify=False)
|
response = requests.post(endpoints["endpoints"][0], headers=headers, json=payload, verify=False)
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
#print(response)
|
#print(response)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@@ -165,8 +179,8 @@ def ollama_chat(endpoint, prompt='Hello World', base64_image=None, temperature=0
|
|||||||
|
|
||||||
multimodal_cache = {}
|
multimodal_cache = {}
|
||||||
|
|
||||||
def test_multimodal(endpoint) -> bool:
|
def test_multimodal(endpoints) -> bool:
|
||||||
modelname = endpoint["model"]
|
modelname = endpoints["model"]
|
||||||
cached_result = multimodal_cache.get(modelname, None)
|
cached_result = multimodal_cache.get(modelname, None)
|
||||||
if cached_result is not None:
|
if cached_result is not None:
|
||||||
return cached_result
|
return cached_result
|
||||||
@@ -175,7 +189,7 @@ def test_multimodal(endpoint) -> bool:
|
|||||||
with open(image_path, "rb") as image_file:
|
with open(image_path, "rb") as image_file:
|
||||||
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
|
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
|
||||||
try:
|
try:
|
||||||
answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=base64_image)
|
answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt="what is in the image", base64_image=base64_image)
|
||||||
result = "42" in answer
|
result = "42" in answer
|
||||||
if result:
|
if result:
|
||||||
print(f"Model {modelname} is multimodal.")
|
print(f"Model {modelname} is multimodal.")
|
||||||
@@ -193,13 +207,13 @@ def main():
|
|||||||
|
|
||||||
# parse the arguments
|
# parse the arguments
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
api_base = args.api_base
|
api_base = args.api_base.split(",") if "," in args.api_base else [args.api_base]
|
||||||
endpoint_name = args.endpoint
|
endpoint_name = args.endpoint
|
||||||
model_name = args.model
|
model_name = args.model
|
||||||
image_path = args.image
|
image_path = args.image
|
||||||
|
|
||||||
# load the endpoint file
|
# load the endpoint file
|
||||||
endpoint = {}
|
endpoints = {}
|
||||||
if endpoint_name:
|
if endpoint_name:
|
||||||
print(f"Using endpoint {endpoint_name}")
|
print(f"Using endpoint {endpoint_name}")
|
||||||
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
|
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
|
||||||
@@ -207,12 +221,12 @@ def main():
|
|||||||
if not os.path.exists(endpoint_path):
|
if not os.path.exists(endpoint_path):
|
||||||
raise Exception(f"Endpoint file {endpoint_path} does not exist.")
|
raise Exception(f"Endpoint file {endpoint_path} does not exist.")
|
||||||
with open(endpoint_path, 'r', encoding='utf-8') as file:
|
with open(endpoint_path, 'r', encoding='utf-8') as file:
|
||||||
endpoint = json.load(file)
|
endpoints = json.load(file)
|
||||||
else:
|
else:
|
||||||
endpoint = ollama_chat_endpoint(api_base, model_name)
|
endpoints = ollama_chat_endpoints(api_base, model_name)
|
||||||
|
|
||||||
# test if the endpoint is a multimodal model
|
# test if the endpoint is a multimodal model
|
||||||
if test_multimodal(endpoint):
|
if test_multimodal(endpoints):
|
||||||
print("Endpoint is a multimodal model.")
|
print("Endpoint is a multimodal model.")
|
||||||
else:
|
else:
|
||||||
print("Endpoint is not a multimodal model.")
|
print("Endpoint is not a multimodal model.")
|
||||||
@@ -224,14 +238,14 @@ def main():
|
|||||||
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
|
base64_image = base64.b64encode(image_file.read()).decode('utf-8')
|
||||||
|
|
||||||
# access the ollama API
|
# access the ollama API
|
||||||
models_dict = ollama_list()
|
models_dict = ollama_list(api_base[0])
|
||||||
for (model, attr) in models_dict.items():
|
for (model, attr) in models_dict.items():
|
||||||
print(f"Model: {model}: {attr}")
|
print(f"Model: {model}: {attr}")
|
||||||
try:
|
try:
|
||||||
if base64_image:
|
if base64_image:
|
||||||
answer, total_tokens, token_per_second = ollama_chat(endpoint, prompt="what is in the image", base64_image=base64_image)
|
answer, total_tokens, token_per_second = ollama_chat(endpoints, prompt="what is in the image", base64_image=base64_image)
|
||||||
else:
|
else:
|
||||||
answer, total_tokens, token_per_second = ollama_chat(endpoint)
|
answer, total_tokens, token_per_second = ollama_chat(endpoints)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
answer = f"Error: {str(e)}"
|
answer = f"Error: {str(e)}"
|
||||||
print(answer)
|
print(answer)
|
||||||
|
|||||||
30
test.py
30
test.py
@@ -3,29 +3,41 @@ from argparse import ArgumentParser
|
|||||||
from benchmark import read_benchmark, write_benchmark
|
from benchmark import read_benchmark, write_benchmark
|
||||||
from ollama_client import ollama_list
|
from ollama_client import ollama_list
|
||||||
|
|
||||||
def test(endpoint_name, model_name, language, overwrite_existing, overwrite_failed, max_problem_number=100):
|
def test(api_base, endpoint_name, model_name, language, overwrite_existing, overwrite_failed, max_problem_number=100, think=False, no_think=False):
|
||||||
# call inference.py
|
# call inference.py
|
||||||
cmd = f"python3.12 inference.py --language {language}"
|
cmd = f"python3.12 inference.py --language {language} --api_base {api_base}"
|
||||||
cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}"
|
cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}"
|
||||||
if max_problem_number == 200: cmd += " --n200"
|
if max_problem_number == 200: cmd += " --n200"
|
||||||
if overwrite_existing: cmd += " --overwrite_existing"
|
if overwrite_existing: cmd += " --overwrite_existing"
|
||||||
if overwrite_failed: cmd += " --overwrite_failed"
|
if overwrite_failed: cmd += " --overwrite_failed"
|
||||||
|
if think: cmd += " --think"
|
||||||
|
if no_think: cmd += " --no_think"
|
||||||
|
print(f"Running command: {cmd}")
|
||||||
os.system(cmd)
|
os.system(cmd)
|
||||||
|
|
||||||
# call codeextraction.py
|
# call codeextraction.py
|
||||||
cmd = f"python3.12 codeextraction.py --language {language}"
|
cmd = f"python3.12 codeextraction.py --language {language}"
|
||||||
cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}"
|
cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}"
|
||||||
|
if think: cmd += " --think"
|
||||||
|
if no_think: cmd += " --no_think"
|
||||||
|
print(f"Running command: {cmd}")
|
||||||
os.system(cmd)
|
os.system(cmd)
|
||||||
|
|
||||||
# call execute.py
|
# call execute.py
|
||||||
cmd = f"python3.12 execute.py --language {language}"
|
cmd = f"python3.12 execute.py --language {language}"
|
||||||
cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}"
|
cmd += f" --endpoint {endpoint_name}" if endpoint_name else f" --model {model_name}"
|
||||||
|
if think: cmd += " --think"
|
||||||
|
if no_think: cmd += " --no_think"
|
||||||
|
print(f"Running command: {cmd}")
|
||||||
os.system(cmd)
|
os.system(cmd)
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = ArgumentParser(description="Run the complete pipeline to execute solutions and store results in a JSON file.")
|
parser = ArgumentParser(description="Run the complete pipeline to execute solutions and store results in a JSON file.")
|
||||||
|
parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434')
|
||||||
parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json')
|
parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json')
|
||||||
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
|
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
|
||||||
|
parser.add_argument('--think', action='store_true', help='if set, the prompt will get an additional "/think" appended at the end')
|
||||||
|
parser.add_argument('--no_think', action='store_true', help='if set, the prompt will get an additional "/no_think" appended at the end')
|
||||||
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure')
|
parser.add_argument('--language', required=False, default='python,java,rust,clojure', help='Name of the languages to test, default is python,java,rust,clojure')
|
||||||
parser.add_argument('--overwrite_existing', action='store_true', help='if set, re-calculate all problems that already have an answer')
|
parser.add_argument('--overwrite_existing', action='store_true', help='if set, re-calculate all problems that already have an answer')
|
||||||
parser.add_argument('--overwrite_failed', action='store_true', help='if set, re-calculate those problems with wrong answers')
|
parser.add_argument('--overwrite_failed', action='store_true', help='if set, re-calculate those problems with wrong answers')
|
||||||
@@ -36,6 +48,7 @@ def main():
|
|||||||
parser.add_argument('--nall', action='store_true', help='all problems')
|
parser.add_argument('--nall', action='store_true', help='all problems')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
api_base = args.api_base
|
||||||
model_name = args.model
|
model_name = args.model
|
||||||
max_problem_number = 100
|
max_problem_number = 100
|
||||||
if args.n100: max_problem_number = 100
|
if args.n100: max_problem_number = 100
|
||||||
@@ -72,16 +85,19 @@ def main():
|
|||||||
|
|
||||||
# in every loop we load the benchmark.json again because it might have been updated
|
# in every loop we load the benchmark.json again because it might have been updated
|
||||||
benchmark = read_benchmark()
|
benchmark = read_benchmark()
|
||||||
entry = benchmark.get(model, {})
|
model_benchmark_name = model
|
||||||
|
if args.think: model_benchmark_name += "-think"
|
||||||
|
if args.no_think: model_benchmark_name += "-no_think"
|
||||||
|
entry = benchmark.get(model_benchmark_name, {})
|
||||||
|
|
||||||
# add metadata to benchmark.json
|
# add metadata to benchmark.json
|
||||||
if not model in benchmark or not bench_name in benchmark[model] or overwrite_existing or overwrite_failed:
|
if not model_benchmark_name in benchmark or not bench_name in benchmark[model_benchmark_name] or overwrite_existing or overwrite_failed:
|
||||||
# run the model; this writes a news entry to benchmark.json
|
# run the model; this writes a news entry to benchmark.json
|
||||||
test(endpoint_name, model, language, overwrite_existing, overwrite_failed, max_problem_number)
|
test(api_base, endpoint_name, model, language, overwrite_existing, overwrite_failed, max_problem_number, think = args.think, no_think = args.no_think)
|
||||||
# load benchmark.json again because the test has updated it
|
# load benchmark.json again because the test has updated it
|
||||||
benchmark = read_benchmark()
|
benchmark = read_benchmark()
|
||||||
# because testing can be interrupted, there is no guarantee that the entry is present
|
# because testing can be interrupted, there is no guarantee that the entry is present
|
||||||
entry = benchmark.get(model, {})
|
entry = benchmark.get(model_benchmark_name, {})
|
||||||
|
|
||||||
# check if attributes parameter_size and quantization_level are present in benchmark.json
|
# check if attributes parameter_size and quantization_level are present in benchmark.json
|
||||||
if not '_parameter_size' in entry and model_dict[model].get('parameter_size', None):
|
if not '_parameter_size' in entry and model_dict[model].get('parameter_size', None):
|
||||||
@@ -89,7 +105,7 @@ def main():
|
|||||||
if not '_quantization_level' in entry and model_dict[model].get('quantization_level', None):
|
if not '_quantization_level' in entry and model_dict[model].get('quantization_level', None):
|
||||||
entry['_quantization_level'] = model_dict[model].get('quantization_level', None)
|
entry['_quantization_level'] = model_dict[model].get('quantization_level', None)
|
||||||
entry = dict(sorted(entry.items(), key=lambda item: item[0]))
|
entry = dict(sorted(entry.items(), key=lambda item: item[0]))
|
||||||
benchmark[model] = entry
|
benchmark[model_benchmark_name] = entry
|
||||||
|
|
||||||
# write the updated benchmark file
|
# write the updated benchmark file
|
||||||
write_benchmark(benchmark)
|
write_benchmark(benchmark)
|
||||||
|
|||||||
Reference in New Issue
Block a user