Expand source code
import os
from main import main, get_runtime_kwargs
from dotenv import load_dotenv
from core.analysis import CodeRetryLimitExceeded
from datetime import datetime
if __name__ == "__main__":
load_dotenv()
together_token = os.getenv("TOGETHER_API_KEY")
openai_token = os.getenv("OPENAI_API_KEY")
prompting_techniques = [
"zero-shot",
"few-shot",
]
assistants = [
"openai",
"mixtral-8x7b",
"llama-chat",
]
analysis_message_limit = 8
runtime = "jupyter-notebook"
report_params_no = {
"zero-shot_llama-chat": 1,
"zero-shot_openai": 2,
"zero-shot_mixtral-8x7b": 3,
"few-shot_llama-chat": 4,
"few-shot_openai": 5,
"few-shot_mixtral-8x7b": 6,
}
dataset_path = "data/wine-quality.csv"
dataset_name = dataset_path.split("/")[-1].split(".")[0]
ITERATIONS = 100
for iteration in range(1, ITERATIONS):
print(f"Iteration: {iteration}")
for assistant in assistants:
for prompting_technique in prompting_techniques:
kwargs = get_runtime_kwargs(
runtime,
prompting_technique,
assistant,
)
kwargs["analysis_assistant_kwargs"]["api_key"] = (
openai_token if assistant == "openai" else together_token
)
kwargs["code_assistant_kwargs"]["api_key"] = (
openai_token if assistant == "openai" else together_token
)
report_no = report_params_no[f"{prompting_technique}_{assistant}"]
try:
output_pdf_path, error_count, code_messages_missing_snippets = main(
dataset_name,
dataset_path,
runtime,
assistant,
assistant,
prompting_technique,
analysis_message_limit=analysis_message_limit,
output_pdf_path=f"../{dataset_name}_{report_no}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.pdf",
**kwargs,
)
except Exception as e:
print(e)
continue
print(output_pdf_path)
print("Error Count:", error_count)
print("Code Messages Missing Snippets:", code_messages_missing_snippets)
# create text file with error count and code messages missing snippets
with open(
f"data/{dataset_name}_{report_no}_{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}.txt",
"w",
) as f:
f.write(f"Error Count: {error_count}\n")
f.write(
f"Code Messages Missing Snippets: {code_messages_missing_snippets}\n"
)