Module src.core.conversation

Expand source code
import json
from typing import List

from llm_api.iassistant import IAssistant
from models.models import ConversationRolesInternalEnum, LLMType, Message
from prompt_manager.ipromptmanager import IPromptManager
from runtime.iruntime import IRuntime
from core.utils import Colors, print_message


class Conversation:
    """Conversation class that handles the conversation flow and stores the conversation history.

    Two assistants take turns: the analysis assistant writes free-text
    instructions and the code assistant answers with fenced ``python``
    snippets, which are executed in the runtime. The captured output of those
    snippets is appended to the code reply so the next analysis turn sees it.
    """

    # Separator between a code reply and the captured output of its snippets.
    # Shared by the formatter and the error check so the two cannot drift apart.
    _CODE_OUTPUT_HEADER = "\n\nHere is the output of the provided code:\n```"
    # Tracebacks with more lines than this are shortened before being stored.
    _MAX_TRACEBACK_LINES = 20

    def __init__(
        self,
        runtime: IRuntime,
        code_assistant: IAssistant,
        analysis_assistant: IAssistant,
        prompt: IPromptManager,
        conversation: List[Message] = None,
    ):
        """Create a conversation.

        Args:
            runtime: Runtime in which generated code cells are added and executed.
            code_assistant: Assistant that produces the code replies.
            analysis_assistant: Assistant that produces the analysis replies.
            prompt: Builds the context sent to each assistant from the history.
            conversation: Initial history. The given list is used and appended
                to in place; ``None`` (the default) starts an empty history.
        """
        # Storing ``None`` as-is would make the first append / [-1] lookup crash.
        self._conversation: List[Message] = (
            conversation if conversation is not None else []
        )
        self._runtime: IRuntime = runtime
        self._code_assistant: IAssistant = code_assistant
        self._analysis_assistant: IAssistant = analysis_assistant
        self._prompt: IPromptManager = prompt
        # Number of code replies that contained no ```python snippet.
        self.code_messages_missing_snippets: int = 0
        # Runtime index of the first cell executed for the most recent code
        # reply that had python snippets; -1 until this instance executes one.
        self._last_msg_first_cell_idx: int = -1

    @staticmethod
    def format_code_assistant_message(message: str, code_output: str) -> str:
        """Append ``code_output`` to ``message`` after a fixed header, wrapped in triple backticks."""
        return f"{message}{Conversation._CODE_OUTPUT_HEADER}{code_output}```"

    @staticmethod
    def _extract_code_snippets_from_message(message: str) -> list[str]:
        """Return the text inside every ``` fenced block of ``message``.

        Each returned string still starts with the fence's info string (for
        example ``python``). Any error raised while splitting (e.g. a
        non-string message) is re-raised as a generic ``Exception``.
        """
        try:
            code_blocks = message.split("```")[1::2]
            return code_blocks
        except Exception as e:
            raise Exception("No code snippets found in response") from e

    def get_conversation(self) -> List[Message]:
        """Return the history list itself (not a copy)."""
        return self._conversation

    def _add_to_conversation(
        self, role: ConversationRolesInternalEnum, content: str
    ) -> None:
        """Append a new ``Message`` with the given role and content to the history."""
        self._conversation.append(Message(role=role, content=content))

    def _get_last_message(self) -> Message:
        """Return the last message; raises ``IndexError`` if the history is empty."""
        return self._conversation[-1]

    def _send_message_analysis(self) -> None:
        """
        Generates output from the analysis assistant and adds it to the conversation history.

        The reply is also added to the runtime via ``add_description``.
        """
        analysis_conv = self._prompt.generate_conversation_context(
            self._conversation, ConversationRolesInternalEnum.ANALYSIS, LLMType.GPT4
        )
        analysis_response = self._analysis_assistant.generate_response(analysis_conv)
        self._add_to_conversation(
            ConversationRolesInternalEnum.ANALYSIS, analysis_response
        )
        self._runtime.add_description(analysis_response)

    def _execute_python_snippet(self, code: str) -> int:
        """Add ``code`` as a new runtime cell, execute it and return the cell index."""
        cell_idx = self._runtime.add_code(code)
        self._runtime.execute_cell(cell_idx)
        return cell_idx

    @staticmethod
    def _shorten_traceback(cell_output: str) -> str:
        """Shorten the traceback in ``cell_output`` to its first and last lines.

        Text before the first ``"Traceback"`` is kept as-is. If the traceback
        (from ``"Traceback"`` to the end) has more than ``_MAX_TRACEBACK_LINES``
        lines, it becomes its first line, a ``...`` line and its last
        ``_MAX_TRACEBACK_LINES - 1`` lines. Otherwise the input is returned unchanged.
        """
        pos = cell_output.find("Traceback")
        if pos == -1:
            return cell_output
        lines = cell_output[pos:].split("\n")
        limit = Conversation._MAX_TRACEBACK_LINES
        if len(lines) <= limit:
            return cell_output
        shortened = lines[0] + "\n...\n" + "\n".join(lines[-(limit - 1):])
        return cell_output[:pos] + shortened

    def _send_message_code(self) -> None:
        """
        Generates output from the code assistant and executes the code it generates.

        Only fenced blocks whose text starts with ``python`` are executed, one
        by one, each in a new runtime cell. Their outputs are joined and
        appended to the reply via ``format_code_assistant_message``.

        Execution stops after the first output containing ``"Traceback"`` or
        ``"Error"``. A long traceback is shortened by ``_shorten_traceback``.
        If the runtime reports a plot for a successful cell, a note is added to
        that cell's output.

        Replies without any python snippet increment
        ``code_messages_missing_snippets``.
        """
        code_conv = self._prompt.generate_conversation_context(
            self._conversation, ConversationRolesInternalEnum.CODE, LLMType.GPT4
        )

        code_response = self._code_assistant.generate_response(
            code_conv,
            temperature=0.5,
        )
        code_snippets = self._extract_code_snippets_from_message(code_response)

        output: List[str] = []
        first_snippet_idx = -1
        contains_python_snippet = False
        for code_snippet in code_snippets:
            if not code_snippet.startswith("python"):
                continue  # Skip code snippets that are not in python
            contains_python_snippet = True
            code = code_snippet[len("python"):]  # drop the fence's language tag
            try:
                cell_idx = self._execute_python_snippet(code)
            except Exception:
                print("Error executing code snippet:\n")
                print(code)
                raise

            if first_snippet_idx == -1:
                first_snippet_idx = cell_idx

            cell_output = self._runtime.get_cell_output_stream(cell_idx)

            # Heuristic error detection: stop running the remaining snippets.
            if "Traceback" in cell_output or "Error" in cell_output:
                output.append(self._shorten_traceback(cell_output))
                break

            if self._runtime.check_if_plot_in_output(cell_idx):
                cell_output += "\n\nPlot was generated successfully."
            output.append(cell_output)

        if not contains_python_snippet:
            self.code_messages_missing_snippets += 1

        if output:
            code_response = self.format_code_assistant_message(
                code_response, "\n".join(output)
            )

        if first_snippet_idx != -1:
            self._last_msg_first_cell_idx = first_snippet_idx

        self._add_to_conversation(
            role=ConversationRolesInternalEnum.CODE, content=code_response
        )

    def last_msg_contains_execution_errors(self) -> bool:
        """Check whether the last message is a code reply whose captured output shows an error.

        Only the text after the last output header is inspected. An error is
        recognised heuristically by the substrings ``"Traceback"`` or ``"Error"``.
        Raises ``IndexError`` if the history is empty.
        """
        last_message = self._get_last_message()
        header = self._CODE_OUTPUT_HEADER
        if (
            last_message.role != ConversationRolesInternalEnum.CODE
            or header not in last_message.content
        ):
            return False

        code_output = last_message.content.split(header)[-1]
        return "Traceback" in code_output or "Error" in code_output

    def perform_next_step(self) -> Message:
        """Generate the next reply, alternating between the two assistants.

        After a code message the analysis assistant answers; after an analysis
        message the code assistant answers.

        Returns:
            The newly appended message.

        Raises:
            Exception: If the last message has any other role.
            IndexError: If the history is empty.
        """
        last_message = self._get_last_message()
        if last_message.role == ConversationRolesInternalEnum.CODE:
            self._send_message_analysis()
        elif last_message.role == ConversationRolesInternalEnum.ANALYSIS:
            self._send_message_code()
        else:
            raise Exception(f"Invalid conversation role: {last_message.role}")

        return self._get_last_message()

    def fix_last_code_message(self) -> Message:
        """
        Fix the last message in the conversation.
        Only code messages whose output contains an error can be fixed.

        It impersonates the analysis assistant: a fix request is appended and
        the code assistant produces a new reply. Then both the failed code
        message and the fix request are removed from the history. The runtime
        cells of the failed message are removed, so the new cells take their
        position.

        Returns:
            The replacement code message.

        Raises:
            Exception: If the last message is not a code message or contains no
                errors. Exceptions from the code step are re-raised after the
                fix request has been removed from the history again.
        """
        last_message = self._get_last_message()
        if last_message.role != ConversationRolesInternalEnum.CODE:
            raise Exception("Only code messages can be fixed")

        if not self.last_msg_contains_execution_errors():
            raise Exception("Last message does not contain errors")

        # Read before touching the history. -1 means this instance never
        # executed a snippet (e.g. the failing message came from a pre-loaded
        # history), so there are no known runtime cells to replace.
        previous_msg_first_cell_idx = self._last_msg_first_cell_idx

        fix_request_msg = Message(
            role=ConversationRolesInternalEnum.ANALYSIS,
            content="Error during code execution occurred. Please fix it.",
        )

        self._conversation.append(fix_request_msg)
        print_message(fix_request_msg, Colors.BLUE)

        try:
            self.perform_next_step()
        except Exception:
            # Do not leave the impersonated request dangling in the history.
            if self._conversation and self._conversation[-1] is fix_request_msg:
                self._conversation.pop()
            raise

        # History is now [..., failed code msg, fix request, new code msg]:
        # drop the failed message and the fix request.
        del self._conversation[-3:-1]

        if previous_msg_first_cell_idx != -1:
            # The failed message's cells lie between its first cell and the
            # first cell of the new reply. Removing at a fixed index shifts the
            # following cells down each time.
            # NOTE(review): if the new reply has no python snippet, the index
            # is unchanged, nothing is removed and the failed cells remain.
            for _ in range(previous_msg_first_cell_idx, self._last_msg_first_cell_idx):
                self._runtime.remove_cell(previous_msg_first_cell_idx)
            self._last_msg_first_cell_idx = previous_msg_first_cell_idx

        return self._get_last_message()

    def get_conversation_json(self) -> str:
        """Get the conversation in json format.

        The result is a JSON array whose elements are the strings returned by
        each message's ``model_dump_json()``, i.e. JSON-encoded strings rather
        than nested objects. This format is kept for existing consumers.
        """
        return json.dumps([message.model_dump_json() for message in self._conversation])

Classes

class Conversation (runtime: runtime.iruntime.IRuntime, code_assistant: llm_api.iassistant.IAssistant, analysis_assistant: llm_api.iassistant.IAssistant, prompt: prompt_manager.ipromptmanager.IPromptManager, conversation: List[models.models.Message] = None)

Conversation class that handles the conversation flow and stores the conversation history.

Expand source code
class Conversation:
    """Conversation class that handles the conversation flow and stores the conversation history.

    Two assistants take turns: the analysis assistant writes free-text
    instructions and the code assistant answers with fenced ``python``
    snippets, which are executed in the runtime. The captured output of those
    snippets is appended to the code reply so the next analysis turn sees it.
    """

    # Separator between a code reply and the captured output of its snippets.
    # Shared by the formatter and the error check so the two cannot drift apart.
    _CODE_OUTPUT_HEADER = "\n\nHere is the output of the provided code:\n```"
    # Tracebacks with more lines than this are shortened before being stored.
    _MAX_TRACEBACK_LINES = 20

    def __init__(
        self,
        runtime: IRuntime,
        code_assistant: IAssistant,
        analysis_assistant: IAssistant,
        prompt: IPromptManager,
        conversation: List[Message] = None,
    ):
        """Create a conversation.

        Args:
            runtime: Runtime in which generated code cells are added and executed.
            code_assistant: Assistant that produces the code replies.
            analysis_assistant: Assistant that produces the analysis replies.
            prompt: Builds the context sent to each assistant from the history.
            conversation: Initial history. The given list is used and appended
                to in place; ``None`` (the default) starts an empty history.
        """
        # Storing ``None`` as-is would make the first append / [-1] lookup crash.
        self._conversation: List[Message] = (
            conversation if conversation is not None else []
        )
        self._runtime: IRuntime = runtime
        self._code_assistant: IAssistant = code_assistant
        self._analysis_assistant: IAssistant = analysis_assistant
        self._prompt: IPromptManager = prompt
        # Number of code replies that contained no ```python snippet.
        self.code_messages_missing_snippets: int = 0
        # Runtime index of the first cell executed for the most recent code
        # reply that had python snippets; -1 until this instance executes one.
        self._last_msg_first_cell_idx: int = -1

    @staticmethod
    def format_code_assistant_message(message: str, code_output: str) -> str:
        """Append ``code_output`` to ``message`` after a fixed header, wrapped in triple backticks."""
        return f"{message}{Conversation._CODE_OUTPUT_HEADER}{code_output}```"

    @staticmethod
    def _extract_code_snippets_from_message(message: str) -> list[str]:
        """Return the text inside every ``` fenced block of ``message``.

        Each returned string still starts with the fence's info string (for
        example ``python``). Any error raised while splitting (e.g. a
        non-string message) is re-raised as a generic ``Exception``.
        """
        try:
            code_blocks = message.split("```")[1::2]
            return code_blocks
        except Exception as e:
            raise Exception("No code snippets found in response") from e

    def get_conversation(self) -> List[Message]:
        """Return the history list itself (not a copy)."""
        return self._conversation

    def _add_to_conversation(
        self, role: ConversationRolesInternalEnum, content: str
    ) -> None:
        """Append a new ``Message`` with the given role and content to the history."""
        self._conversation.append(Message(role=role, content=content))

    def _get_last_message(self) -> Message:
        """Return the last message; raises ``IndexError`` if the history is empty."""
        return self._conversation[-1]

    def _send_message_analysis(self) -> None:
        """
        Generates output from the analysis assistant and adds it to the conversation history.

        The reply is also added to the runtime via ``add_description``.
        """
        analysis_conv = self._prompt.generate_conversation_context(
            self._conversation, ConversationRolesInternalEnum.ANALYSIS, LLMType.GPT4
        )
        analysis_response = self._analysis_assistant.generate_response(analysis_conv)
        self._add_to_conversation(
            ConversationRolesInternalEnum.ANALYSIS, analysis_response
        )
        self._runtime.add_description(analysis_response)

    def _execute_python_snippet(self, code: str) -> int:
        """Add ``code`` as a new runtime cell, execute it and return the cell index."""
        cell_idx = self._runtime.add_code(code)
        self._runtime.execute_cell(cell_idx)
        return cell_idx

    @staticmethod
    def _shorten_traceback(cell_output: str) -> str:
        """Shorten the traceback in ``cell_output`` to its first and last lines.

        Text before the first ``"Traceback"`` is kept as-is. If the traceback
        (from ``"Traceback"`` to the end) has more than ``_MAX_TRACEBACK_LINES``
        lines, it becomes its first line, a ``...`` line and its last
        ``_MAX_TRACEBACK_LINES - 1`` lines. Otherwise the input is returned unchanged.
        """
        pos = cell_output.find("Traceback")
        if pos == -1:
            return cell_output
        lines = cell_output[pos:].split("\n")
        limit = Conversation._MAX_TRACEBACK_LINES
        if len(lines) <= limit:
            return cell_output
        shortened = lines[0] + "\n...\n" + "\n".join(lines[-(limit - 1):])
        return cell_output[:pos] + shortened

    def _send_message_code(self) -> None:
        """
        Generates output from the code assistant and executes the code it generates.

        Only fenced blocks whose text starts with ``python`` are executed, one
        by one, each in a new runtime cell. Their outputs are joined and
        appended to the reply via ``format_code_assistant_message``.

        Execution stops after the first output containing ``"Traceback"`` or
        ``"Error"``. A long traceback is shortened by ``_shorten_traceback``.
        If the runtime reports a plot for a successful cell, a note is added to
        that cell's output.

        Replies without any python snippet increment
        ``code_messages_missing_snippets``.
        """
        code_conv = self._prompt.generate_conversation_context(
            self._conversation, ConversationRolesInternalEnum.CODE, LLMType.GPT4
        )

        code_response = self._code_assistant.generate_response(
            code_conv,
            temperature=0.5,
        )
        code_snippets = self._extract_code_snippets_from_message(code_response)

        output: List[str] = []
        first_snippet_idx = -1
        contains_python_snippet = False
        for code_snippet in code_snippets:
            if not code_snippet.startswith("python"):
                continue  # Skip code snippets that are not in python
            contains_python_snippet = True
            code = code_snippet[len("python"):]  # drop the fence's language tag
            try:
                cell_idx = self._execute_python_snippet(code)
            except Exception:
                print("Error executing code snippet:\n")
                print(code)
                raise

            if first_snippet_idx == -1:
                first_snippet_idx = cell_idx

            cell_output = self._runtime.get_cell_output_stream(cell_idx)

            # Heuristic error detection: stop running the remaining snippets.
            if "Traceback" in cell_output or "Error" in cell_output:
                output.append(self._shorten_traceback(cell_output))
                break

            if self._runtime.check_if_plot_in_output(cell_idx):
                cell_output += "\n\nPlot was generated successfully."
            output.append(cell_output)

        if not contains_python_snippet:
            self.code_messages_missing_snippets += 1

        if output:
            code_response = self.format_code_assistant_message(
                code_response, "\n".join(output)
            )

        if first_snippet_idx != -1:
            self._last_msg_first_cell_idx = first_snippet_idx

        self._add_to_conversation(
            role=ConversationRolesInternalEnum.CODE, content=code_response
        )

    def last_msg_contains_execution_errors(self) -> bool:
        """Check whether the last message is a code reply whose captured output shows an error.

        Only the text after the last output header is inspected. An error is
        recognised heuristically by the substrings ``"Traceback"`` or ``"Error"``.
        Raises ``IndexError`` if the history is empty.
        """
        last_message = self._get_last_message()
        header = self._CODE_OUTPUT_HEADER
        if (
            last_message.role != ConversationRolesInternalEnum.CODE
            or header not in last_message.content
        ):
            return False

        code_output = last_message.content.split(header)[-1]
        return "Traceback" in code_output or "Error" in code_output

    def perform_next_step(self) -> Message:
        """Generate the next reply, alternating between the two assistants.

        After a code message the analysis assistant answers; after an analysis
        message the code assistant answers.

        Returns:
            The newly appended message.

        Raises:
            Exception: If the last message has any other role.
            IndexError: If the history is empty.
        """
        last_message = self._get_last_message()
        if last_message.role == ConversationRolesInternalEnum.CODE:
            self._send_message_analysis()
        elif last_message.role == ConversationRolesInternalEnum.ANALYSIS:
            self._send_message_code()
        else:
            raise Exception(f"Invalid conversation role: {last_message.role}")

        return self._get_last_message()

    def fix_last_code_message(self) -> Message:
        """
        Fix the last message in the conversation.
        Only code messages whose output contains an error can be fixed.

        It impersonates the analysis assistant: a fix request is appended and
        the code assistant produces a new reply. Then both the failed code
        message and the fix request are removed from the history. The runtime
        cells of the failed message are removed, so the new cells take their
        position.

        Returns:
            The replacement code message.

        Raises:
            Exception: If the last message is not a code message or contains no
                errors. Exceptions from the code step are re-raised after the
                fix request has been removed from the history again.
        """
        last_message = self._get_last_message()
        if last_message.role != ConversationRolesInternalEnum.CODE:
            raise Exception("Only code messages can be fixed")

        if not self.last_msg_contains_execution_errors():
            raise Exception("Last message does not contain errors")

        # Read before touching the history. -1 means this instance never
        # executed a snippet (e.g. the failing message came from a pre-loaded
        # history), so there are no known runtime cells to replace.
        previous_msg_first_cell_idx = self._last_msg_first_cell_idx

        fix_request_msg = Message(
            role=ConversationRolesInternalEnum.ANALYSIS,
            content="Error during code execution occurred. Please fix it.",
        )

        self._conversation.append(fix_request_msg)
        print_message(fix_request_msg, Colors.BLUE)

        try:
            self.perform_next_step()
        except Exception:
            # Do not leave the impersonated request dangling in the history.
            if self._conversation and self._conversation[-1] is fix_request_msg:
                self._conversation.pop()
            raise

        # History is now [..., failed code msg, fix request, new code msg]:
        # drop the failed message and the fix request.
        del self._conversation[-3:-1]

        if previous_msg_first_cell_idx != -1:
            # The failed message's cells lie between its first cell and the
            # first cell of the new reply. Removing at a fixed index shifts the
            # following cells down each time.
            # NOTE(review): if the new reply has no python snippet, the index
            # is unchanged, nothing is removed and the failed cells remain.
            for _ in range(previous_msg_first_cell_idx, self._last_msg_first_cell_idx):
                self._runtime.remove_cell(previous_msg_first_cell_idx)
            self._last_msg_first_cell_idx = previous_msg_first_cell_idx

        return self._get_last_message()

    def get_conversation_json(self) -> str:
        """Get the conversation in json format.

        The result is a JSON array whose elements are the strings returned by
        each message's ``model_dump_json()``, i.e. JSON-encoded strings rather
        than nested objects. This format is kept for existing consumers.
        """
        return json.dumps([message.model_dump_json() for message in self._conversation])

Static methods

def format_code_assistant_message(message: str, code_output: str) ‑> str

Format the code assistant message.

Expand source code
@staticmethod
def format_code_assistant_message(message: str, code_output: str) -> str:
    """Append the captured code output to ``message``.

    The result is ``message``, a blank line, a fixed header line, and then
    ``code_output`` wrapped in triple backticks.
    """
    output_header = "\n\nHere is the output of the provided code:\n```"
    return f"{message}{output_header}{code_output}```"

Methods

def fix_last_code_message(self) ‑> models.models.Message

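Fix the last message in the conversation. Only code messages whose output contains an execution error can be fixed. The method impersonates the analysis assistant by appending a fix request. It then lets the code assistant produce a corrected reply, which replaces the failing message in the conversation history.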

Expand source code
def fix_last_code_message(self) -> Message:
    """
    Fix the last message in the conversation.
    Only code messages whose output contains an error can be fixed.

    It impersonates the analysis assistant: a fix request is appended and the
    code assistant produces a new reply. Then both the failed code message and
    the fix request are removed from the history. The runtime cells of the
    failed message are removed, so the new cells take their position.

    Returns:
        The replacement code message.

    Raises:
        Exception: If the last message is not a code message or contains no
            errors. Exceptions from the code step are re-raised after the fix
            request has been removed from the history again.
    """

    last_message = self._get_last_message()
    if last_message.role != ConversationRolesInternalEnum.CODE:
        raise Exception("Only code messages can be fixed")

    if not self.last_msg_contains_execution_errors():
        raise Exception("Last message does not contain errors")

    # Read before touching the history. The attribute is only set once a
    # snippet has been executed; a missing value (-1) means there are no known
    # runtime cells to replace, e.g. when the failing message came from a
    # pre-loaded history.
    previous_msg_first_cell_idx = getattr(self, "_last_msg_first_cell_idx", -1)

    fix_request_msg = Message(
        role=ConversationRolesInternalEnum.ANALYSIS,
        content="Error during code execution occurred. Please fix it.",
    )

    self._conversation.append(fix_request_msg)
    print_message(fix_request_msg, Colors.BLUE)

    try:
        self.perform_next_step()
    except Exception:
        # Do not leave the impersonated request dangling in the history.
        if self._conversation and self._conversation[-1] is fix_request_msg:
            self._conversation.pop()
        raise

    # History is now [..., failed code msg, fix request, new code msg]:
    # drop the failed message and the fix request.
    del self._conversation[-3:-1]

    if previous_msg_first_cell_idx != -1:
        # The failed message's cells lie between its first cell and the first
        # cell of the new reply. Removing at a fixed index shifts the following
        # cells down each time.
        # NOTE(review): if the new reply has no python snippet, the index is
        # unchanged, nothing is removed and the failed cells remain.
        for _ in range(previous_msg_first_cell_idx, self._last_msg_first_cell_idx):
            self._runtime.remove_cell(previous_msg_first_cell_idx)
        self._last_msg_first_cell_idx = previous_msg_first_cell_idx

    return self._get_last_message()
def get_conversation(self) ‑> List[models.models.Message]

Get the conversation.

Expand source code
def get_conversation(self) -> List[Message]:
    """Return the message history.

    The stored list itself is returned, not a copy, so changes made by the
    caller are visible to the conversation.
    """
    history = self._conversation
    return history
def get_conversation_json(self) ‑> str

Get the conversation in json format.

Expand source code
def get_conversation_json(self) -> str:
    """Serialize the message history as a JSON array.

    Each array element is the string returned by the message's
    ``model_dump_json()``, i.e. a JSON-encoded string rather than a nested
    object.
    """
    encoded_messages = [msg.model_dump_json() for msg in self._conversation]
    serialized = json.dumps(encoded_messages)
    return serialized
def last_msg_contains_execution_errors(self) ‑> bool

Check if the last step in the conversation contains errors.

Expand source code
def last_msg_contains_execution_errors(self) -> bool:
    """Report whether the last message is a code reply whose captured output shows an error.

    Only the text after the last output header is examined. ``"Traceback"`` or
    ``"Error"`` anywhere in it counts as an error. Messages with another role,
    or without the output header, yield ``False``.
    """
    marker = "\n\nHere is the output of the provided code:\n```"
    message = self._get_last_message()
    if message.role != ConversationRolesInternalEnum.CODE:
        return False
    if marker not in message.content:
        return False
    executed_output = message.content.rpartition(marker)[2]
    return "Traceback" in executed_output or "Error" in executed_output
def perform_next_step(self) ‑> models.models.Message

Perform the next step in the conversation.

Expand source code
def perform_next_step(self) -> Message:
    """Generate the next reply, alternating between the two assistants.

    A code message is answered by the analysis assistant and an analysis
    message by the code assistant. Any other role raises ``Exception``.
    Returns the message appended by that step.
    """
    role = self._get_last_message().role
    if role == ConversationRolesInternalEnum.CODE:
        self._send_message_analysis()
    elif role == ConversationRolesInternalEnum.ANALYSIS:
        self._send_message_code()
    else:
        raise Exception(f"Invalid conversation role: {role}")
    return self._get_last_message()