Module src.prompt_manager.tests.test_few_shot
Expand source code
import pytest
from unittest.mock import MagicMock
from prompt_manager.few_shot import (
FewShot,
ConversationRolesInternalEnum,
ConversationRolesEnum,
LLMType,
Message,
)
from unittest.mock import patch
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def mock_message():
    """Return a ``MagicMock`` constrained to the ``Message`` interface.

    ``spec=Message`` makes the mock reject attribute access that ``Message``
    does not define, so misuse surfaces as ``AttributeError``.

    NOTE(review): ``filterwarnings`` marks only affect collected tests. This
    helper is neither a ``test_*`` function nor a ``@pytest.fixture``, so the
    mark is effectively a no-op here. The helper is also not referenced by
    the tests visible in this module; confirm whether it is still needed.
    """
    # Previously the mock was created and discarded (implicit ``None`` return).
    return MagicMock(spec=Message)
@pytest.fixture
def fewshot_instance():
    """Provide a newly constructed ``FewShot`` to each requesting test.

    The fixture uses pytest's default (function) scope, so every test gets
    its own instance.
    """
    instance = FewShot()
    return instance
@pytest.mark.parametrize(
    "conversation",
    [
        [Message(role=ConversationRolesInternalEnum.CODE, content="Sample message")],
        [Message(role=ConversationRolesInternalEnum.ANALYSIS, content="Sample message")],
    ],
)
@pytest.mark.parametrize(
    "agent",
    [
        (ConversationRolesInternalEnum.CODE, FewShot._CODE_GENERATION_PROMPT),
        (
            ConversationRolesInternalEnum.ANALYSIS,
            FewShot._ANALYSIS_SUGGESTION_INTERPRETATION_PROMPT,
        ),
    ],
)
def test_generate_conversation_context(fewshot_instance, conversation, agent):
    """The agent-specific system prompt is placed first in the generated context.

    The stacked ``parametrize`` decorators run this test for every
    (conversation, agent) combination. Each ``agent`` value pairs an agent
    type with the prompt expected as the first message's content.
    """
    agent_role, system_prompt = agent
    context = fewshot_instance.generate_conversation_context(
        conversation=conversation,
        agent_type=agent_role,
        llm_type=LLMType.GPT4,
    )

    assert isinstance(context, list)
    assert len(context) == 2

    first_message = context[0]
    assert first_message.role == ConversationRolesEnum.SYSTEM
    assert first_message.content == system_prompt
@pytest.mark.parametrize(
    "conversation",
    [
        [Message(role=ConversationRolesInternalEnum.CODE, content="Sample message")],
    ],
)
def test_generate_conversation_context_not_implemented(fewshot_instance, conversation):
    """An unrecognised ``agent_type`` must raise ``NotImplementedError``."""
    bogus_agent_type = "invalid_agent_type"  # deliberately not a supported agent type
    with pytest.raises(NotImplementedError):
        fewshot_instance.generate_conversation_context(
            conversation=conversation,
            agent_type=bogus_agent_type,
            llm_type=LLMType.GPT4,
        )
Functions
def fewshot_instance()
Pytest fixture that returns a new `FewShot` instance for each test.
Expand source code
@pytest.fixture
def fewshot_instance():
    """Provide a newly constructed ``FewShot`` to each requesting test.

    The fixture uses pytest's default (function) scope, so every test gets
    its own instance.
    """
    instance = FewShot()
    return instance
def mock_message()
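Helper intended to mock the `Message` class via `MagicMock(spec=Message)`; it is not used by any test in this module.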
Expand source code
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
def mock_message():
    """Return a ``MagicMock`` constrained to the ``Message`` interface.

    ``spec=Message`` makes the mock reject attribute access that ``Message``
    does not define, so misuse surfaces as ``AttributeError``.

    NOTE(review): ``filterwarnings`` marks only affect collected tests. This
    helper is neither a ``test_*`` function nor a ``@pytest.fixture``, so the
    mark is effectively a no-op here. The helper is also not referenced by
    the tests visible in this module; confirm whether it is still needed.
    """
    # Previously the mock was created and discarded (implicit ``None`` return).
    return MagicMock(spec=Message)
def test_generate_conversation_context(fewshot_instance, conversation, agent)
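Checks, for every combination of the parametrized conversations and agents, that `generate_conversation_context` returns a two-element list whose first entry is a system message carrying the agent's prompt.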
Expand source code
@pytest.mark.parametrize(
    "conversation",
    [
        [Message(role=ConversationRolesInternalEnum.CODE, content="Sample message")],
        [Message(role=ConversationRolesInternalEnum.ANALYSIS, content="Sample message")],
    ],
)
@pytest.mark.parametrize(
    "agent",
    [
        (ConversationRolesInternalEnum.CODE, FewShot._CODE_GENERATION_PROMPT),
        (
            ConversationRolesInternalEnum.ANALYSIS,
            FewShot._ANALYSIS_SUGGESTION_INTERPRETATION_PROMPT,
        ),
    ],
)
def test_generate_conversation_context(fewshot_instance, conversation, agent):
    """The agent-specific system prompt is placed first in the generated context.

    The stacked ``parametrize`` decorators run this test for every
    (conversation, agent) combination. Each ``agent`` value pairs an agent
    type with the prompt expected as the first message's content.
    """
    agent_role, system_prompt = agent
    context = fewshot_instance.generate_conversation_context(
        conversation=conversation,
        agent_type=agent_role,
        llm_type=LLMType.GPT4,
    )

    assert isinstance(context, list)
    assert len(context) == 2

    first_message = context[0]
    assert first_message.role == ConversationRolesEnum.SYSTEM
    assert first_message.content == system_prompt
def test_generate_conversation_context_not_implemented(fewshot_instance, conversation)
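Checks that `generate_conversation_context` raises `NotImplementedError` when given an unrecognized agent type.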
Expand source code
@pytest.mark.parametrize(
    "conversation",
    [
        [Message(role=ConversationRolesInternalEnum.CODE, content="Sample message")],
    ],
)
def test_generate_conversation_context_not_implemented(fewshot_instance, conversation):
    """An unrecognised ``agent_type`` must raise ``NotImplementedError``."""
    bogus_agent_type = "invalid_agent_type"  # deliberately not a supported agent type
    with pytest.raises(NotImplementedError):
        fewshot_instance.generate_conversation_context(
            conversation=conversation,
            agent_type=bogus_agent_type,
            llm_type=LLMType.GPT4,
        )