picogent-py/picogent/providers/openai.py

188 lines
6.9 KiB
Python

"""
OpenAI-compatible provider (ChatGPT, xAI/Grok, etc.) using httpx
"""
import json
from typing import Dict, List, Any, Optional

import httpx

from .base import BaseProvider
class OpenAIProvider(BaseProvider):
    """OpenAI-compatible provider using direct API calls.

    Works with: OpenAI, xAI (Grok), Azure OpenAI, any OpenAI-compatible API.

    Messages and tool definitions are accepted in Anthropic style
    (``tool_use`` / ``tool_result`` content blocks, ``input_schema`` tool
    schemas). They are translated to the Chat Completions wire format for
    each request. The response is translated back into a plain dict. When the
    model calls tools, the dict also includes an Anthropic-style
    ``raw_content`` block list.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o",
        base_url: str = "https://api.openai.com/v1",
    ):
        """Configure credentials, default model and API root.

        Args:
            api_key: Sent as ``Authorization: Bearer <api_key>``.
            model: Passed to ``BaseProvider`` and read back as ``self.model``
                when building request payloads.
            base_url: API root that ``/chat/completions`` is appended to. A
                trailing slash is stripped so the joined URL has no ``//``.
        """
        super().__init__(api_key, model)
        self.base_url = base_url.rstrip("/")

    def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert Anthropic-style tool defs to OpenAI function calling format.

        Each tool must have a ``name``. ``description`` defaults to ``""``.
        A missing or empty ``input_schema`` becomes an empty object schema,
        because OpenAI requires ``parameters`` to be an object-typed schema.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool.get("description", ""),
                    "parameters": tool.get("input_schema")
                    or {"type": "object", "properties": {}},
                },
            }
            for tool in tools
        ]

    @staticmethod
    def _tool_result_to_text(content: Any) -> str:
        """Flatten an Anthropic ``tool_result`` content value into a string.

        Anthropic allows a string or a list of content blocks here. OpenAI
        ``tool`` messages are sent with string content. Text blocks are
        joined with newlines; non-text blocks are dropped; ``None`` becomes
        ``""``.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return "\n".join(
                block.get("text", "")
                for block in content
                if isinstance(block, dict) and block.get("type") == "text"
            )
        return str(content)

    def _convert_user_message(self, content: Any) -> List[Dict[str, Any]]:
        """Translate one Anthropic user message body into OpenAI messages.

        Returns zero or more messages. Each ``tool_result`` block becomes a
        ``tool`` message. Any accompanying ``text`` blocks are joined with
        spaces into a single trailing ``user`` message.
        """
        if isinstance(content, str):
            return [{"role": "user", "content": content}]
        if not isinstance(content, list):
            return []

        blocks = [b for b in content if isinstance(b, dict)]
        converted: List[Dict[str, Any]] = [
            {
                "role": "tool",
                "tool_call_id": b.get("tool_use_id", ""),
                "content": self._tool_result_to_text(b.get("content")),
            }
            for b in blocks
            if b.get("type") == "tool_result"
        ]
        # Text goes AFTER the tool messages: OpenAI requires tool messages to
        # immediately follow the assistant message that issued the calls.
        text = " ".join(b.get("text", "") for b in blocks if b.get("type") == "text")
        if text:
            converted.append({"role": "user", "content": text})
        return converted

    @staticmethod
    def _convert_assistant_message(content: Any) -> Optional[Dict[str, Any]]:
        """Translate one Anthropic assistant message body into an OpenAI message.

        Returns ``None`` for unsupported content types. For block lists,
        ``text`` blocks are newline-joined into ``content``, and ``tool_use``
        blocks become ``tool_calls`` with JSON-encoded ``arguments``.
        """
        if isinstance(content, str):
            return {"role": "assistant", "content": content}
        if not isinstance(content, list):
            return None

        text_parts: List[str] = []
        tool_calls: List[Dict[str, Any]] = []
        for block in content:
            if not isinstance(block, dict):
                continue
            block_type = block.get("type")
            if block_type == "text":
                text_parts.append(block.get("text", ""))
            elif block_type == "tool_use":
                tool_calls.append({
                    "id": block.get("id", ""),
                    "type": "function",
                    "function": {
                        "name": block.get("name", ""),
                        "arguments": json.dumps(block.get("input") or {}),
                    },
                })

        assistant_msg: Dict[str, Any] = {
            "role": "assistant",
            "content": "\n".join(text_parts) if text_parts else None,
        }
        if tool_calls:
            assistant_msg["tool_calls"] = tool_calls
        elif assistant_msg["content"] is None:
            # A null content is only valid when tool_calls is present.
            assistant_msg["content"] = ""
        return assistant_msg

    def _convert_messages(
        self, messages: List[Dict[str, Any]], system_prompt: str
    ) -> List[Dict[str, Any]]:
        """Convert Anthropic-style messages to OpenAI format.

        A non-empty ``system_prompt`` becomes a leading ``system`` message.
        ``user`` and ``assistant`` messages are translated. Messages with any
        other role are skipped.
        """
        oai_messages: List[Dict[str, Any]] = []
        if system_prompt:
            oai_messages.append({"role": "system", "content": system_prompt})

        for msg in messages:
            role = msg.get("role")
            content = msg.get("content")
            if role == "user":
                oai_messages.extend(self._convert_user_message(content))
            elif role == "assistant":
                converted = self._convert_assistant_message(content)
                if converted is not None:
                    oai_messages.append(converted)
        return oai_messages

    @staticmethod
    def _parse_arguments(arguments: Any) -> Dict[str, Any]:
        """Decode a tool call's ``arguments`` into a dict.

        The expected form is a JSON-encoded string. A dict is defensively
        accepted as-is. Missing, empty, malformed, or non-object values
        yield ``{}`` rather than raising.
        """
        if isinstance(arguments, dict):
            return arguments
        if not arguments:
            return {}
        try:
            parsed = json.loads(arguments)
        except (json.JSONDecodeError, TypeError):
            return {}
        return parsed if isinstance(parsed, dict) else {}

    def _parse_response(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """Turn a decoded Chat Completions response body into our standard dict."""
        # ``or`` guards against both missing keys and explicit null/empty values.
        choices = result.get("choices") or [{}]
        message = choices[0].get("message") or {}
        text_content = message.get("content") or ""
        oai_tool_calls = message.get("tool_calls") or []

        response_data: Dict[str, Any] = {
            "content": text_content,
            "usage": result.get("usage") or {},
            "model": result.get("model") or self.model,
        }
        if not oai_tool_calls:
            return response_data

        # Convert OpenAI tool_calls -> our standard format, and also build
        # Anthropic-style raw_content for session storage.
        tool_calls: List[Dict[str, Any]] = []
        raw_content: List[Dict[str, Any]] = []
        if text_content:
            raw_content.append({"type": "text", "text": text_content})
        for tc in oai_tool_calls:
            fn = tc.get("function") or {}
            call = {
                "id": tc.get("id", ""),
                "name": fn.get("name", ""),
                "input": self._parse_arguments(fn.get("arguments")),
            }
            tool_calls.append(call)
            raw_content.append({"type": "tool_use", **call})

        response_data["tool_calls"] = tool_calls
        response_data["raw_content"] = raw_content
        return response_data

    async def generate_response(
        self,
        messages: List[Dict[str, Any]],
        system_prompt: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 8192,
    ) -> Dict[str, Any]:
        """Generate response using OpenAI Chat Completions API.

        Args:
            messages: Anthropic-style conversation history.
            system_prompt: Sent as a leading ``system`` message when non-empty.
            tools: Optional Anthropic-style tool definitions.
            max_tokens: Forwarded as the ``max_tokens`` request field.

        Returns:
            Dict with ``content``, ``usage`` and ``model``. When the model
            requested tools, it also has ``tool_calls`` (list of
            ``{"id", "name", "input"}``) and ``raw_content``
            (Anthropic-style blocks).

        Raises:
            Exception: if the API responds with a status other than 200.
            httpx.HTTPError: on transport failures, including the 120 s timeout.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        payload: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": max_tokens,
            "messages": self._convert_messages(messages, system_prompt),
        }
        if tools:
            payload["tools"] = self._convert_tools(tools)

        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/chat/completions",
                headers=headers,
                json=payload,
                timeout=120.0,
            )

        if response.status_code != 200:
            raise Exception(
                f"OpenAI API error {response.status_code}: {response.text}"
            )
        return self._parse_response(response.json())