added rust code generation

This commit is contained in:
Michael Peter Christen
2025-01-05 18:47:17 +01:00
parent f1dc63a0d6
commit a3cbf9ed68
3 changed files with 172 additions and 93 deletions

View File

@@ -2,6 +2,7 @@ import os
import json
import requests
import urllib3
from ollama_client import ollama_list, ollama_chat_endpoint, ollama_chat
from argparse import ArgumentParser
def read_template(template_path):
@@ -29,7 +30,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
prompt = template_content.replace('$$$PROBLEM$$$', problem_content)
try:
content = ollama_client(endpoint, prompt)
content = ollama_chat(endpoint, prompt)
# Save the response to a file
with open(result_file_path, 'w', encoding='utf-8') as result_file:
@@ -38,79 +39,10 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
except Exception as e:
print(f"Failed to process problem {problem_number}: {e}")
def ollama_client(endpoint, prompt='Hello World', temperature=0.0, max_tokens=4096):
# Disable SSL warnings
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# Prepare the API endpoint URL
stoptokens = ["[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>", "<|end_header_id|>", "<EOS_TOKEN>", "</s>", "<|end|>"]
# Set headers and payload
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json'
}
if endpoint.get("key", ""):
headers['Authorization'] = 'Bearer ' + endpoint["key"]
modelname = endpoint["model"]
messages = []
# o1 has special requirements
if not modelname.startswith("o1"):
messages.append({"content": "You are a helpful assistant", "role": "system"})
else:
temperature = 1.0 # o1 models need temperature 1.0
messages.append({"role": "user", "content": prompt})
if modelname.startswith("o1") or modelname.startswith("4o"):
stoptokens = []
payload = {
"model": modelname,
"messages": messages,
"temperature": temperature,
"response_format": { "type": "text" },
"stream": False
}
if len(stoptokens) > 0:
payload["stop"] = stoptokens
if modelname.startswith("o1"):
payload["max_completion_tokens"] = max_tokens
else:
payload["max_tokens"] = max_tokens
try:
response = requests.post(endpoint["endpoint"], headers=headers, json=payload, verify=False)
response.raise_for_status()
except requests.exceptions.RequestException as e:
# print(f"Failed to access api: {e}")
# Get the error message from the response
if response:
try:
data = response.json()
message = data.get('message', {})
content = message.get('content', '')
raise Exception(f"API request failed: {content}")
except json.JSONDecodeError:
raise Exception(f"API request failed: {e}")
# Parse the response
try:
data = response.json()
#print(data)
choices = data.get('choices', [])
if len(choices) == 0:
raise Exception("No response from the API: " + str(data))
message = choices[0].get('message', {})
content = message.get('content', '')
return content
except json.JSONDecodeError:
raise Exception("Failed to parse JSON response from the API.")
def main():
parser = ArgumentParser(description="Process Euler problems and send them to an LLM.")
parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434')
parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json')
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
parser.add_argument('--language', required=False, default='python', help='Name of the programming language to use, default is python')
parser.add_argument('--skip_existing', action='store_true', help='if set, skip problems that already have a solution')
@@ -130,28 +62,9 @@ def main():
if args.n200: max_problem_number = 200
if args.n400: max_problem_number = 400
if args.nall: max_problem_number = 9999
print(f"Inference: Using model {model_name} and language {language}")
# construct the endpoint object
bench_name = f"{language}-{max_problem_number}"
endpoint_name = args.endpoint
endpoint = {}
if endpoint_name:
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
print(f"Using endpoint file {endpoint_path}")
if not os.path.exists(endpoint_path):
raise Exception(f"Endpoint file {endpoint_path} does not exist.")
with open(endpoint_path, 'r', encoding='utf-8') as file:
endpoint = json.load(file)
else:
# construct the endpoint object from command line arguments considering that ollama is the endpoint
api_base='http://localhost:11434'
endpoint = {
"name": model_name,
"model": model_name,
"key": "",
"endpoint": f"{api_base}/v1/chat/completions",
}
problems_dir = 'problems'
template_path = os.path.join('templates', 'template_' + language + '.md')
@@ -162,7 +75,43 @@ def main():
raise Exception(f"Template file {template_path} does not exist.")
template_content = read_template(template_path)
process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, skip_existing = args.skip_existing)
if args.allmodels:
if endpoint_name:
raise Exception("The --allmodels option cannot be used in combination with --endpoint.")
# loop over all models provided by ollama and run those which are missing in benchmark.json
models = ollama_list()
print(f"Found {len(models)} models in ollama.")
for model in models:
# in every loop we load the benchmark.json again because it might have been updated
with open('benchmark.json', 'r', encoding='utf-8') as json_file:
benchmark = json.load(json_file)
entry = benchmark.get(model, {})
# add metadata to benchmark.json
if not model in benchmark or not bench_name in benchmark[model]:
print(f"Inference: Using model {model} and language {language}")
endpoint = ollama_chat_endpoint('http://localhost:11434', model)
process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, skip_existing = args.skip_existing)
else:
# construct the endpoint object
endpoint = {}
if endpoint_name:
print(f"Inference: Using endpoint {endpoint_name} and language {language}")
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
print(f"Using endpoint file {endpoint_path}")
if not os.path.exists(endpoint_path):
raise Exception(f"Endpoint file {endpoint_path} does not exist.")
with open(endpoint_path, 'r', encoding='utf-8') as file:
endpoint = json.load(file)
else:
print(f"Inference: Using model {model_name} and language {language}")
# construct the endpoint object from command line arguments considering that ollama is the endpoint
endpoint = ollama_chat_endpoint('http://localhost:11434', model_name)
# run the inference
process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, skip_existing = args.skip_existing)
if __name__ == "__main__":
main()

122
ollama_client.py Normal file
View File

@@ -0,0 +1,122 @@
import requests
import urllib3
def ollama_list(api_base='http://localhost:11434'):
# call api http://localhost:11434/api/tags with http get request
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
endpoint = f"{api_base}/api/tags"
response = requests.get(endpoint, verify=False)
response.raise_for_status()
data = response.json()
models_list = data['models']
models_dict = {}
for entry in models_list:
# get parameter_size and quantization_level from data
model = entry['model']
details = entry['details']
attr = {}
parameter_size = details['parameter_size']
quantization_level = details['quantization_level']
parameter_size = parameter_size[:-1]
try:
parameter_size = float(parameter_size)
attr['parameter_size'] = parameter_size
except ValueError:
pass
quantization_level_char = quantization_level[1:2]
try:
quantization_level = int(quantization_level_char)
attr['quantization_level'] = quantization_level
except ValueError:
pass
models_dict[model] = attr
return models_dict
def ollama_chat_endpoint(api_base='http://localhost:11434', model_name='llama3.2:latest'):
endpoint = {
"name": model_name,
"model": model_name,
"key": "",
"endpoint": f"{api_base}/v1/chat/completions",
}
return endpoint
def ollama_chat(endpoint, prompt='Hello World', temperature=0.0, max_tokens=4096):
# Disable SSL warnings
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# Prepare the API endpoint URL
stoptokens = ["[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>", "<|end_header_id|>", "<EOS_TOKEN>", "</s>", "<|end|>"]
# Set headers and payload
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json'
}
if endpoint.get("key", ""):
headers['Authorization'] = 'Bearer ' + endpoint["key"]
modelname = endpoint["model"]
messages = []
# o1 has special requirements
if not modelname.startswith("o1"):
messages.append({"content": "You are a helpful assistant", "role": "system"})
else:
temperature = 1.0 # o1 models need temperature 1.0
messages.append({"role": "user", "content": prompt})
if modelname.startswith("o1") or modelname.startswith("4o"):
stoptokens = []
payload = {
"model": modelname,
"messages": messages,
"temperature": temperature,
"response_format": { "type": "text" },
"stream": False
}
if len(stoptokens) > 0:
payload["stop"] = stoptokens
if modelname.startswith("o1"):
payload["max_completion_tokens"] = max_tokens
else:
payload["max_tokens"] = max_tokens
try:
response = requests.post(endpoint["endpoint"], headers=headers, json=payload, verify=False)
response.raise_for_status()
except requests.exceptions.RequestException as e:
# print(f"Failed to access api: {e}")
# Get the error message from the response
if response:
try:
data = response.json()
message = data.get('message', {})
content = message.get('content', '')
raise Exception(f"API request failed: {content}")
except json.JSONDecodeError:
raise Exception(f"API request failed: {e}")
# Parse the response
try:
data = response.json()
#print(data)
choices = data.get('choices', [])
if len(choices) == 0:
raise Exception("No response from the API: " + str(data))
message = choices[0].get('message', {})
content = message.get('content', '')
return content
except json.JSONDecodeError:
raise Exception("Failed to parse JSON response from the API.")
def main():
models_dict = ollama_list()
print(models_dict)
answer = ollama_chat(ollama_chat_endpoint())
print(answer)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,8 @@
Write a rust program for the following problem description:
$$$PROBLEM$$$
The rust program must not use any std::io module. It should also not have any user inputs. Do not use assert.
The rust program should not use any imports unless absolutely necessary for the solution.
The rust program must not output any comment lines or logging or debugging lines,
it must just output a single value which is the solution of the problem.