CrewAI에서 커스텀 LLM 구현을 만드는 방법을 알아보세요.
CrewAI는 `BaseLLM` 추상 기반 클래스를 통해 커스텀 LLM 구현을 지원합니다. 이를 통해 LiteLLM에 내장 지원이 없는 모든 LLM 제공자를 통합하거나, 커스텀 인증 메커니즘을 구현할 수 있습니다.
from crewai import BaseLLM
from typing import Any, Dict, List, Optional, Union
import requests
class CustomLLM(BaseLLM):
    """Custom LLM that POSTs chat messages to an HTTP endpoint.

    The endpoint is expected to answer with JSON shaped like
    ``{"choices": [{"message": {"content": "..."}}]}``; the ``content``
    string of the first choice is what ``call()`` returns.
    """

    def __init__(self, model: str, api_key: str, endpoint: str, temperature: Optional[float] = None):
        """Store connection settings and initialise the base class.

        Args:
            model: Model identifier sent in every request payload.
            api_key: Token sent as a ``Bearer`` Authorization header.
            endpoint: URL the request is POSTed to.
            temperature: Optional sampling temperature; omitted from the
                request when ``None``.
        """
        # IMPORTANT: Call super().__init__() with required parameters
        super().__init__(model=model, temperature=temperature)
        self.api_key = api_key
        self.endpoint = endpoint

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with the given messages.

        Args:
            messages: A prompt string or a list of role/content dicts.
            tools: Optional tool schemas; sent only when
                ``supports_function_calling()`` returns True.
            callbacks: Accepted for interface compatibility; unused here.
            available_functions: Accepted for interface compatibility;
                unused here.

        Returns:
            The ``content`` of the first choice in the response.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
            KeyError, IndexError: If the response JSON lacks the expected
                ``choices[0].message.content`` path.
        """
        # Convert string to message format if needed
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        # Prepare request
        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": messages,
        }
        # Only send a temperature the caller actually configured, so the
        # provider's own default applies instead of an explicit null.
        if self.temperature is not None:
            payload["temperature"] = self.temperature

        # Forward stop sequences when any are set on the instance;
        # getattr keeps this safe if the base class defines no `stop`.
        stop = getattr(self, "stop", None)
        if stop:
            payload["stop"] = stop

        # Add tools if provided and supported
        if tools and self.supports_function_calling():
            payload["tools"] = tools

        # Make API call
        response = requests.post(
            self.endpoint,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json=payload,
            timeout=30,
        )
        response.raise_for_status()

        result = response.json()
        return result["choices"][0]["message"]["content"]

    def supports_function_calling(self) -> bool:
        """Override if your LLM supports function calling."""
        return True  # Change to False if your LLM doesn't support tools

    def get_context_window_size(self) -> int:
        """Return the context window size of your LLM."""
        return 8192  # Adjust based on your model's actual context window
from crewai import Agent, Task, Crew
# Requires the CustomLLM class defined above.

# Instantiate the custom LLM
custom_llm = CustomLLM(
    model="my-custom-model",
    api_key="your-api-key",
    endpoint="https://api.example.com/v1/chat/completions",
    temperature=0.7,
)

# Hand it to an agent through the `llm` argument
agent = Agent(
    role="Research Assistant",
    goal="Find and analyze information",
    backstory="You are a research assistant.",
    llm=custom_llm,
)

# Define a task for that agent and run it in a crew
task = Task(
    description="Research the latest developments in AI",
    expected_output="A comprehensive summary",
    agent=agent,
)

crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()
`__init__()` 생성자에서는 반드시 `super().__init__(model, temperature)`을 호출해야 합니다:
def __init__(self, model: str, api_key: str, temperature: Optional[float] = None):
    """Initialise the base class first, then store custom settings."""
    # Required: call the parent constructor with model and temperature
    super().__init__(model=model, temperature=temperature)

    # Custom initialisation
    self.api_key = api_key
call()
`call()` 메서드는 LLM 구현의 핵심입니다. 반드시 다음을 수행해야 합니다:
def supports_function_calling(self) -> bool:
    """Report whether this LLM supports function calling.

    Returns:
        True, which is also the default.
    """
    return True
def supports_stop_words(self) -> bool:
    """Report whether this LLM supports stop sequences.

    Returns:
        True, which is also the default.
    """
    return True
def get_context_window_size(self) -> int:
    """Report the context window size.

    Returns:
        4096, which is also the default.
    """
    return 4096
import requests
def call(self, messages, tools=None, callbacks=None, available_functions=None):
    """Call the LLM endpoint and translate failures into built-in exceptions.

    Args:
        messages: A prompt string or a list of role/content dicts.
        tools: Accepted for interface compatibility; unused here.
        callbacks: Accepted for interface compatibility; unused here.
        available_functions: Accepted for interface compatibility; unused here.

    Returns:
        The ``content`` of the first choice in the response JSON.

    Raises:
        TimeoutError: If the request times out.
        RuntimeError: If the request fails for any other transport/HTTP reason.
        ValueError: If the response lacks ``choices[0].message.content``.
    """
    # Convert string to message format if needed
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    # Minimal request body; add provider-specific fields as needed.
    payload = {
        "model": self.model,
        "messages": messages,
    }

    try:
        response = requests.post(
            self.endpoint,
            headers={"Authorization": f"Bearer {self.api_key}"},
            json=payload,
            timeout=30,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    # Timeout is caught before the broader RequestException it derives from.
    # `from e` keeps the original error attached for debugging.
    except requests.Timeout as e:
        raise TimeoutError("LLM request timed out") from e
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}") from e
    except (KeyError, IndexError) as e:
        raise ValueError(f"Invalid response format: {str(e)}") from e
from crewai import BaseLLM
from typing import Optional
class CustomAuthLLM(BaseLLM):
    """LLM skeleton that authenticates with a ``Custom <token>`` header."""

    def __init__(self, model: str, auth_token: str, endpoint: str, temperature: Optional[float] = None):
        """Initialise the base class, then keep the token and endpoint."""
        super().__init__(model=model, temperature=temperature)
        self.auth_token = auth_token
        self.endpoint = endpoint

    def call(self, messages, tools=None, callbacks=None, available_functions=None):
        """Build the request headers; the request itself is left to implement."""
        # Custom auth format
        authorization = f"Custom {self.auth_token}"
        headers = {
            "Authorization": authorization,
            "Content-Type": "application/json",
        }
        # Rest of implementation...
CrewAI는 `"\nObservation:"`를 스톱 워드로 자동 추가합니다. 만약 사용 중인 LLM이 스톱 워드를 지원한다면:
def call(self, messages, tools=None, callbacks=None, available_functions=None):
    """Build a request payload that carries the stop words to the API."""
    payload = dict(
        model=self.model,
        messages=messages,
        stop=self.stop,  # Include stop words in API call
    )
    # Make API call...
def supports_stop_words(self) -> bool:
    """Report that this LLM supports stop sequences."""
    return True
def call(self, messages, tools=None, callbacks=None, available_functions=None):
    """Call the API and truncate the reply at the earliest stop word.

    Args:
        messages: Messages forwarded to ``self._make_api_call``.
        tools: Tools forwarded to ``self._make_api_call``.
        callbacks: Accepted for interface compatibility; unused here.
        available_functions: Accepted for interface compatibility; unused here.

    Returns:
        The first choice's ``content``, cut just before the first place any
        entry of ``self.stop`` occurs. Returned unchanged when there are no
        stop words, none occur, or the content is empty/None.
    """
    response = self._make_api_call(messages, tools)
    content = response["choices"][0]["message"]["content"]

    # Manually truncate at stop words
    if self.stop and content:
        # Cut at the earliest occurrence across ALL stop words; cutting at
        # the first list entry that matches could leave another stop word
        # earlier in the text. Empty stop strings are skipped.
        positions = [content.find(stop_word) for stop_word in self.stop if stop_word]
        hits = [pos for pos in positions if pos != -1]
        if hits:
            content = content[:min(hits)]

    return content
def supports_stop_words(self) -> bool:
    """Report no native stop-word support; stop words are handled manually."""
    return False
import json
def call(self, messages, tools=None, callbacks=None, available_functions=None):
    """Call the LLM and resolve any tool calls it requests.

    Args:
        messages: A prompt string or a list of message dicts.
        tools: Tool schemas forwarded to ``self._make_api_call``.
        callbacks: Accepted for interface compatibility; unused here.
        available_functions: Mapping of function name to callable used to
            execute requested tool calls.

    Returns:
        The message ``content``, or the result of
        ``_handle_function_calls`` when the reply requests tool calls and
        ``available_functions`` is provided.
    """
    # Convert string to message format
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    # Make API call
    response = self._make_api_call(messages, tools)
    message = response["choices"][0]["message"]

    # Check for function calls. A "tool_calls" key that is present but
    # null or empty means there is nothing to execute, so test the value
    # rather than the key.
    tool_calls = message.get("tool_calls")
    if tool_calls and available_functions:
        return self._handle_function_calls(
            tool_calls, messages, tools, available_functions
        )

    return message["content"]


def _handle_function_calls(self, tool_calls, messages, tools, available_functions):
    """Handle function calling with proper message flow.

    Executes every requested tool call whose function is available,
    appends one assistant message listing those calls followed by one
    ``tool`` message per result to ``messages`` (in place), then calls the
    LLM again with the updated history.

    Args:
        tool_calls: Tool-call dicts with ``id`` and
            ``function.name`` / ``function.arguments`` (JSON string).
        messages: Message history; extended in place.
        tools: Tool schemas passed through to the follow-up call.
        available_functions: Mapping of function name to callable.

    Returns:
        The follow-up ``call()`` result, or ``"Function call failed"`` when
        none of the requested functions is available.
    """
    executed = []
    for tool_call in tool_calls:
        function_name = tool_call["function"]["name"]
        if function_name not in available_functions:
            continue

        # Parse and execute function; empty/missing arguments mean "no args".
        raw_arguments = tool_call["function"].get("arguments") or "{}"
        function_args = json.loads(raw_arguments)
        function_result = available_functions[function_name](**function_args)
        executed.append((tool_call, function_name, function_result))

    if not executed:
        return "Function call failed"

    # Add function calls and results to message history: a single assistant
    # turn carrying all executed calls, then one tool message per call.
    messages.append({
        "role": "assistant",
        "content": None,
        "tool_calls": [tool_call for tool_call, _, _ in executed],
    })
    for tool_call, function_name, function_result in executed:
        messages.append({
            "role": "tool",
            "tool_call_id": tool_call["id"],
            "name": function_name,
            "content": str(function_result),
        })

    # Call LLM again with updated context
    return self.call(messages, tools, None, available_functions)
# ❌ Wrong - missing required parameters
def __init__(self, api_key: str):
    # Anti-example: the parent constructor is called with no arguments.
    super().__init__()
# ✅ Correct
def __init__(self, model: str, api_key: str, temperature: Optional[float] = None):
    """Pass model and temperature through to the parent constructor."""
    super().__init__(model=model, temperature=temperature)
- `supports_function_calling()`이 `True`를 반환하는지 확인하세요
- `tool_calls`를 처리하는지 확인하세요
- `available_functions` 매개변수가 올바르게 사용되는지 검증하세요

from crewai import Agent, Task, Crew
def test_custom_llm():
    """Smoke-test CustomLLM directly and through a single-agent crew."""
    llm = CustomLLM(
        model="test-model",
        api_key="test-key",
        endpoint="https://api.test.com",
    )

    # A direct call must yield a non-empty string
    reply = llm.call("Hello, world!")
    assert isinstance(reply, str)
    assert len(reply) > 0

    # The same LLM must work when driven by a CrewAI agent
    tester = Agent(
        role="Test Agent",
        goal="Test custom LLM",
        backstory="A test agent.",
        llm=llm,
    )
    greeting_task = Task(
        description="Say hello",
        expected_output="A greeting",
        agent=tester,
    )

    outcome = Crew(agents=[tester], tasks=[greeting_task]).kickoff()
    assert "hello" in outcome.raw.lower()
이 페이지가 도움이 되었나요?