"""OpenAI-compatible provider (ChatGPT, xAI/Grok, etc.) using httpx"""
import json
from typing import Dict, List, Any, Optional

import httpx

from .base import BaseProvider
class OpenAIProvider(BaseProvider):
    """OpenAI-compatible provider using direct API calls.

    Works with: OpenAI, xAI (Grok), Azure OpenAI, any OpenAI-compatible API.

    Messages and tool definitions arrive in Anthropic's block format and are
    translated to the OpenAI Chat Completions schema before each request;
    responses are translated back into a provider-neutral dict.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o",
        base_url: str = "https://api.openai.com/v1",
    ):
        """Store credentials and normalize the API base URL.

        Args:
            api_key: Bearer token for the target API.
            model: Model identifier sent with each request.
            base_url: API root; a trailing slash is stripped so endpoint
                paths can be appended with exactly one "/".
        """
        super().__init__(api_key, model)
        self.base_url = base_url.rstrip("/")

    def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert Anthropic-style tool defs to OpenAI function-calling format.

        Anthropic's ``input_schema`` maps to OpenAI's ``parameters``; a
        missing description or schema degrades to an empty value rather
        than raising.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool.get("description", ""),
                    "parameters": tool.get("input_schema", {}),
                },
            }
            for tool in tools
        ]

    def _convert_messages(
        self, messages: List[Dict[str, Any]], system_prompt: str
    ) -> List[Dict[str, Any]]:
        """Convert Anthropic-style messages to OpenAI chat format.

        - The system prompt becomes a leading ``system`` message.
        - User messages containing ``tool_result`` blocks become ``tool``
          role messages (one per result); otherwise text blocks are joined
          into a single ``user`` message.
        - Assistant ``tool_use`` blocks become OpenAI ``tool_calls`` with
          JSON-encoded arguments.
        """
        oai_messages: List[Dict[str, Any]] = []

        # System prompt as first message
        if system_prompt:
            oai_messages.append({"role": "system", "content": system_prompt})

        for msg in messages:
            role = msg.get("role")

            # --- user message ---
            if role == "user":
                content = msg.get("content")
                if isinstance(content, str):
                    oai_messages.append({"role": "user", "content": content})
                elif isinstance(content, list):
                    # Could contain tool_result blocks (Anthropic format)
                    tool_results = [b for b in content if b.get("type") == "tool_result"]
                    if tool_results:
                        for tr in tool_results:
                            tr_content = tr.get("content", "")
                            # Anthropic allows tool_result content to be a
                            # list of blocks; OpenAI requires a string here.
                            if isinstance(tr_content, list):
                                tr_content = " ".join(
                                    b.get("text", "")
                                    for b in tr_content
                                    if isinstance(b, dict) and b.get("type") == "text"
                                )
                            oai_messages.append({
                                "role": "tool",
                                "tool_call_id": tr.get("tool_use_id", ""),
                                "content": tr_content,
                            })
                    else:
                        # Plain text blocks
                        text = " ".join(
                            b.get("text", "") for b in content if b.get("type") == "text"
                        )
                        if text:
                            oai_messages.append({"role": "user", "content": text})

            # --- assistant message ---
            elif role == "assistant":
                content = msg.get("content")
                if isinstance(content, str):
                    oai_messages.append({"role": "assistant", "content": content})
                elif isinstance(content, list):
                    text_parts = []
                    tool_calls = []
                    for block in content:
                        if block.get("type") == "text":
                            text_parts.append(block.get("text", ""))
                        elif block.get("type") == "tool_use":
                            tool_calls.append({
                                "id": block.get("id", ""),
                                "type": "function",
                                "function": {
                                    "name": block.get("name", ""),
                                    # OpenAI requires arguments as a JSON string
                                    "arguments": json.dumps(block.get("input", {})),
                                },
                            })
                    assistant_msg: Dict[str, Any] = {
                        "role": "assistant",
                        # OpenAI expects null (not "") when there is no text
                        "content": "\n".join(text_parts) if text_parts else None,
                    }
                    if tool_calls:
                        assistant_msg["tool_calls"] = tool_calls
                    oai_messages.append(assistant_msg)

        return oai_messages

    async def generate_response(
        self,
        messages: List[Dict[str, Any]],
        system_prompt: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 8192,
    ) -> Dict[str, Any]:
        """Generate a response using the OpenAI Chat Completions API.

        Args:
            messages: Conversation history in Anthropic block format.
            system_prompt: System instructions (prepended as a system message).
            tools: Optional Anthropic-style tool definitions.
            max_tokens: Completion token cap passed through to the API.

        Returns:
            Dict with ``content`` (text), ``usage``, ``model``, and — when
            the model requested tools — ``tool_calls`` (normalized) plus
            ``raw_content`` (Anthropic-style blocks for session storage).

        Raises:
            Exception: On any non-200 HTTP status, with the response body.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        payload: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": max_tokens,
            "messages": self._convert_messages(messages, system_prompt),
        }
        if tools:
            payload["tools"] = self._convert_tools(tools)

        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/chat/completions",
                headers=headers,
                json=payload,
                timeout=120.0,
            )

            if response.status_code != 200:
                raise Exception(
                    f"OpenAI API error {response.status_code}: {response.text}"
                )

            result = response.json()

        choice = result.get("choices", [{}])[0]
        message = choice.get("message", {})

        # `content` may be an explicit null when the model only calls tools.
        text_content = message.get("content") or ""
        # `tool_calls` may be absent OR explicitly null — normalize to [].
        oai_tool_calls = message.get("tool_calls") or []

        response_data: Dict[str, Any] = {
            "content": text_content,
            "usage": result.get("usage", {}),
            "model": result.get("model", self.model),
        }

        if oai_tool_calls:
            # Convert OpenAI tool_calls → our standard format, and also
            # build Anthropic-style raw_content for session storage.
            tool_calls = []
            raw_content: List[Dict[str, Any]] = []
            if text_content:
                raw_content.append({"type": "text", "text": text_content})

            for tc in oai_tool_calls:
                fn = tc.get("function", {})
                try:
                    args = json.loads(fn.get("arguments", "{}"))
                except json.JSONDecodeError:
                    # Malformed JSON arguments from the model — degrade to {}.
                    args = {}
                tool_calls.append({
                    "id": tc.get("id", ""),
                    "name": fn.get("name", ""),
                    "input": args,
                })
                raw_content.append({
                    "type": "tool_use",
                    "id": tc.get("id", ""),
                    "name": fn.get("name", ""),
                    "input": args,
                })

            response_data["tool_calls"] = tool_calls
            response_data["raw_content"] = raw_content

        return response_data
|