added rust code generation
This commit is contained in:
135
inference.py
135
inference.py
@@ -2,6 +2,7 @@ import os
|
||||
import json
|
||||
import requests
|
||||
import urllib3
|
||||
from ollama_client import ollama_list, ollama_chat_endpoint, ollama_chat
|
||||
from argparse import ArgumentParser
|
||||
|
||||
def read_template(template_path):
|
||||
@@ -29,7 +30,7 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
|
||||
prompt = template_content.replace('$$$PROBLEM$$$', problem_content)
|
||||
|
||||
try:
|
||||
content = ollama_client(endpoint, prompt)
|
||||
content = ollama_chat(endpoint, prompt)
|
||||
|
||||
# Save the response to a file
|
||||
with open(result_file_path, 'w', encoding='utf-8') as result_file:
|
||||
@@ -38,79 +39,10 @@ def process_problem_files(problems_dir, template_content, endpoint, language, ma
|
||||
except Exception as e:
|
||||
print(f"Failed to process problem {problem_number}: {e}")
|
||||
|
||||
def ollama_client(endpoint, prompt='Hello World', temperature=0.0, max_tokens=4096):
    """Send a single-turn chat completion request and return the reply text.

    Args:
        endpoint: dict with the keys "model" (model name) and "endpoint"
            (URL the request is POSTed to). A non-empty "key" entry is sent
            as a bearer token.
        prompt: content of the user message.
        temperature: sampling temperature; overridden with 1.0 for models
            whose name starts with "o1".
        max_tokens: completion token limit, sent as "max_completion_tokens"
            for "o1" models and as "max_tokens" otherwise.

    Returns:
        The "content" of the first choice's message ('' when absent).

    Raises:
        Exception: if the request fails, the response body is not valid
            JSON, or the response contains no choices.
    """
    # Disable SSL warnings; the request below is sent with verify=False.
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    # Set headers and payload
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json'
    }
    api_key = endpoint.get("key", "")
    if api_key:
        headers['Authorization'] = 'Bearer ' + api_key

    modelname = endpoint["model"]
    is_o1 = modelname.startswith("o1")

    # o1 has special requirements: no system message and temperature 1.0
    messages = []
    if is_o1:
        temperature = 1.0
    else:
        messages.append({"content": "You are a helpful assistant", "role": "system"})
    messages.append({"role": "user", "content": prompt})

    payload = {
        "model": modelname,
        "messages": messages,
        "temperature": temperature,
        "response_format": {"type": "text"},
        "stream": False
    }
    # Stop tokens are not sent to models whose name starts with "o1" or "4o".
    if not modelname.startswith(("o1", "4o")):
        payload["stop"] = ["[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>",
                           "<|end_header_id|>", "<EOS_TOKEN>", "</s>", "<|end|>"]
    payload["max_completion_tokens" if is_o1 else "max_tokens"] = max_tokens

    try:
        response = requests.post(endpoint["endpoint"], headers=headers, json=payload, verify=False)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Take the response from the exception: the local name is unbound
        # when requests.post() itself fails, and a Response object is falsy
        # for 4xx/5xx statuses, so `if response:` would skip every HTTP error.
        detail = ''
        failed = e.response
        if failed is not None:
            try:
                body = failed.json()
            except ValueError:
                # Error body is not JSON; fall back to the exception text.
                body = None
            if isinstance(body, dict):
                message = body.get('message')
                if isinstance(message, dict):
                    detail = message.get('content', '')
                if not detail:
                    # OpenAI-style servers report {"error": {"message": ...}}
                    # or {"error": "..."}.
                    error = body.get('error')
                    if isinstance(error, dict):
                        detail = error.get('message', '')
                    elif error:
                        detail = str(error)
        # Always raise so a failed request never reaches the parsing step.
        raise Exception(f"API request failed: {detail or e}") from e

    # Parse the response
    try:
        data = response.json()
    except ValueError as e:
        raise Exception("Failed to parse JSON response from the API.") from e

    choices = data.get('choices', [])
    if not choices:
        raise Exception("No response from the API: " + str(data))
    message = choices[0].get('message', {})
    return message.get('content', '')
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser(description="Process Euler problems and send them to an LLM.")
|
||||
parser.add_argument('--api_base', required=False, default='http://localhost:11434', help='API base URL for the LLM, default is http://localhost:11434')
|
||||
parser.add_argument('--allmodels', action='store_true', help='loop over all models provided by ollama and run those which are missing in benchmark.json')
|
||||
parser.add_argument('--model', required=False, default='llama3.2:latest', help='Name of the model to use, default is llama3.2:latest')
|
||||
parser.add_argument('--language', required=False, default='python', help='Name of the programming language to use, default is python')
|
||||
parser.add_argument('--skip_existing', action='store_true', help='if set, skip problems that already have a solution')
|
||||
@@ -130,28 +62,9 @@ def main():
|
||||
if args.n200: max_problem_number = 200
|
||||
if args.n400: max_problem_number = 400
|
||||
if args.nall: max_problem_number = 9999
|
||||
|
||||
print(f"Inference: Using model {model_name} and language {language}")
|
||||
|
||||
# construct the endpoint object
|
||||
bench_name = f"{language}-{max_problem_number}"
|
||||
endpoint_name = args.endpoint
|
||||
endpoint = {}
|
||||
if endpoint_name:
|
||||
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
|
||||
print(f"Using endpoint file {endpoint_path}")
|
||||
if not os.path.exists(endpoint_path):
|
||||
raise Exception(f"Endpoint file {endpoint_path} does not exist.")
|
||||
with open(endpoint_path, 'r', encoding='utf-8') as file:
|
||||
endpoint = json.load(file)
|
||||
else:
|
||||
# construct the endpoint object from command line arguments considering that ollama is the endpoint
|
||||
api_base='http://localhost:11434'
|
||||
endpoint = {
|
||||
"name": model_name,
|
||||
"model": model_name,
|
||||
"key": "",
|
||||
"endpoint": f"{api_base}/v1/chat/completions",
|
||||
}
|
||||
|
||||
problems_dir = 'problems'
|
||||
template_path = os.path.join('templates', 'template_' + language + '.md')
|
||||
|
||||
@@ -162,7 +75,43 @@ def main():
|
||||
raise Exception(f"Template file {template_path} does not exist.")
|
||||
|
||||
template_content = read_template(template_path)
|
||||
process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, skip_existing = args.skip_existing)
|
||||
|
||||
if args.allmodels:
|
||||
if endpoint_name:
|
||||
raise Exception("The --allmodels option cannot be used in combination with --endpoint.")
|
||||
|
||||
# loop over all models provided by ollama and run those which are missing in benchmark.json
|
||||
models = ollama_list()
|
||||
print(f"Found {len(models)} models in ollama.")
|
||||
for model in models:
|
||||
# in every loop we load the benchmark.json again because it might have been updated
|
||||
with open('benchmark.json', 'r', encoding='utf-8') as json_file:
|
||||
benchmark = json.load(json_file)
|
||||
entry = benchmark.get(model, {})
|
||||
|
||||
# add metadata to benchmark.json
|
||||
if not model in benchmark or not bench_name in benchmark[model]:
|
||||
print(f"Inference: Using model {model} and language {language}")
|
||||
endpoint = ollama_chat_endpoint('http://localhost:11434', model)
|
||||
process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, skip_existing = args.skip_existing)
|
||||
else:
|
||||
# construct the endpoint object
|
||||
endpoint = {}
|
||||
if endpoint_name:
|
||||
print(f"Inference: Using endpoint {endpoint_name} and language {language}")
|
||||
endpoint_path = os.path.join('endpoints', f"{endpoint_name}.json")
|
||||
print(f"Using endpoint file {endpoint_path}")
|
||||
if not os.path.exists(endpoint_path):
|
||||
raise Exception(f"Endpoint file {endpoint_path} does not exist.")
|
||||
with open(endpoint_path, 'r', encoding='utf-8') as file:
|
||||
endpoint = json.load(file)
|
||||
else:
|
||||
print(f"Inference: Using model {model_name} and language {language}")
|
||||
# construct the endpoint object from command line arguments considering that ollama is the endpoint
|
||||
endpoint = ollama_chat_endpoint('http://localhost:11434', model_name)
|
||||
|
||||
# run the inference
|
||||
process_problem_files(problems_dir, template_content, endpoint, language, max_problem_number = max_problem_number, skip_existing = args.skip_existing)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
122
ollama_client.py
Normal file
122
ollama_client.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import requests
|
||||
import urllib3
|
||||
|
||||
def ollama_list(api_base='http://localhost:11434'):
    """Return the models reported by ``GET {api_base}/api/tags``.

    Args:
        api_base: base URL of the ollama server.

    Returns:
        dict mapping each entry's "model" name to an attribute dict. The
        attribute dict holds "parameter_size" (float) and
        "quantization_level" (int) when the corresponding detail string
        could be parsed; unparseable or missing details are left out.

    Raises:
        requests.exceptions.RequestException: if the request fails or the
            server answers with an error status.
        KeyError: if the response has no "models" list or an entry has no
            "model" name.
    """
    import re  # local import: this module does not import re at file level

    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
    response = requests.get(f"{api_base}/api/tags", verify=False)
    response.raise_for_status()
    data = response.json()

    models_dict = {}
    for entry in data['models']:
        # A missing or null "details" must not abort the whole listing.
        details = entry.get('details') or {}
        attr = {}

        # parameter_size looks like "7B": drop the one-character unit suffix.
        # NOTE(review): the unit itself is ignored, so "137M" parses to 137.0
        # just like "137B" would - confirm what consumers expect.
        size_text = str(details.get('parameter_size') or '')
        try:
            attr['parameter_size'] = float(size_text[:-1])
        except ValueError:
            pass

        # quantization_level looks like "Q4_K_M" or "F16": take the whole
        # number after the letter prefix. Reading only the second character
        # turned "F16" into 1 and could not parse "BF16" or "IQ4_XS".
        quant_text = str(details.get('quantization_level') or '')
        quant_match = re.match(r'[A-Za-z]+(\d+)', quant_text)
        if quant_match:
            attr['quantization_level'] = int(quant_match.group(1))

        models_dict[entry['model']] = attr
    return models_dict
|
||||
|
||||
def ollama_chat_endpoint(api_base='http://localhost:11434', model_name='llama3.2:latest'):
    """Build the endpoint description dict for a model served by ollama.

    Args:
        api_base: base URL of the ollama server; trailing slashes are
            ignored so the resulting URL never contains "//v1".
        model_name: model name, used for both "name" and "model".

    Returns:
        dict with the keys "name", "model", "key" (always empty) and
        "endpoint" (the server's /v1/chat/completions URL).
    """
    base_url = api_base.rstrip('/')
    return {
        "name": model_name,
        "model": model_name,
        "key": "",
        "endpoint": f"{base_url}/v1/chat/completions",
    }
|
||||
|
||||
def ollama_chat(endpoint, prompt='Hello World', temperature=0.0, max_tokens=4096):
    """Send a single-turn chat completion request and return the reply text.

    Args:
        endpoint: dict with the keys "model" (model name) and "endpoint"
            (URL the request is POSTed to). A non-empty "key" entry is sent
            as a bearer token.
        prompt: content of the user message.
        temperature: sampling temperature; overridden with 1.0 for models
            whose name starts with "o1".
        max_tokens: completion token limit, sent as "max_completion_tokens"
            for "o1" models and as "max_tokens" otherwise.

    Returns:
        The "content" of the first choice's message ('' when absent).

    Raises:
        Exception: if the request fails, the response body is not valid
            JSON, or the response contains no choices.
    """
    # Disable SSL warnings; the request below is sent with verify=False.
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    # Set headers and payload
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json'
    }
    api_key = endpoint.get("key", "")
    if api_key:
        headers['Authorization'] = 'Bearer ' + api_key

    modelname = endpoint["model"]
    is_o1 = modelname.startswith("o1")

    # o1 has special requirements: no system message and temperature 1.0
    messages = []
    if is_o1:
        temperature = 1.0
    else:
        messages.append({"content": "You are a helpful assistant", "role": "system"})
    messages.append({"role": "user", "content": prompt})

    payload = {
        "model": modelname,
        "messages": messages,
        "temperature": temperature,
        "response_format": {"type": "text"},
        "stream": False
    }
    # Stop tokens are not sent to models whose name starts with "o1" or "4o".
    if not modelname.startswith(("o1", "4o")):
        payload["stop"] = ["[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>",
                           "<|end_header_id|>", "<EOS_TOKEN>", "</s>", "<|end|>"]
    payload["max_completion_tokens" if is_o1 else "max_tokens"] = max_tokens

    try:
        response = requests.post(endpoint["endpoint"], headers=headers, json=payload, verify=False)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Take the response from the exception: the local name is unbound
        # when requests.post() itself fails, and a Response object is falsy
        # for 4xx/5xx statuses, so `if response:` would skip every HTTP error.
        detail = ''
        failed = e.response
        if failed is not None:
            try:
                body = failed.json()
            except ValueError:
                # Error body is not JSON; fall back to the exception text.
                # ValueError is caught because this module never imports
                # json, and JSON decode errors are ValueError subclasses.
                body = None
            if isinstance(body, dict):
                message = body.get('message')
                if isinstance(message, dict):
                    detail = message.get('content', '')
                if not detail:
                    # OpenAI-style servers report {"error": {"message": ...}}
                    # or {"error": "..."}.
                    error = body.get('error')
                    if isinstance(error, dict):
                        detail = error.get('message', '')
                    elif error:
                        detail = str(error)
        # Always raise so a failed request never reaches the parsing step.
        raise Exception(f"API request failed: {detail or e}") from e

    # Parse the response
    try:
        data = response.json()
    except ValueError as e:
        raise Exception("Failed to parse JSON response from the API.") from e

    choices = data.get('choices', [])
    if not choices:
        raise Exception("No response from the API: " + str(data))
    message = choices[0].get('message', {})
    return message.get('content', '')
|
||||
|
||||
|
||||
def main():
    """Smoke test: print the model listing, then a reply from the default endpoint."""
    # List the available models first, then run one chat round trip using
    # the default endpoint and the default prompt.
    print(ollama_list())
    print(ollama_chat(ollama_chat_endpoint()))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
8
templates/template_rust.md
Normal file
8
templates/template_rust.md
Normal file
@@ -0,0 +1,8 @@
|
||||
Write a rust program for the following problem description:
|
||||
|
||||
$$$PROBLEM$$$
|
||||
|
||||
The rust program must not use any std::io module. It should also not have any user inputs. Do not use assert.
|
||||
The rust program should not use any imports unless absolutely necessary for the solution.
|
||||
The rust program must not output any comment lines or logging or debugging lines,
|
||||
it must just output a single value which is the solution of the problem.
|
||||
Reference in New Issue
Block a user