Add OpenAI, Gemini, xAI providers (all via httpx, no SDKs)
This commit is contained in:
parent
cb7188d857
commit
4f07bf7bec
@ -1,7 +1,34 @@
|
|||||||
{
|
{
|
||||||
|
"_comment": "Provider: anthropic | openai | gemini | xai (or any OpenAI-compatible)",
|
||||||
|
|
||||||
"provider": "anthropic",
|
"provider": "anthropic",
|
||||||
"model": "claude-sonnet-4-20250514",
|
"model": "claude-sonnet-4-20250514",
|
||||||
"api_key": "env:ANTHROPIC_API_KEY",
|
"api_key": "env:ANTHROPIC_API_KEY",
|
||||||
|
|
||||||
|
"_examples": {
|
||||||
|
"openai": {
|
||||||
|
"provider": "openai",
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"api_key": "env:OPENAI_API_KEY"
|
||||||
|
},
|
||||||
|
"gemini": {
|
||||||
|
"provider": "gemini",
|
||||||
|
"model": "gemini-2.5-flash",
|
||||||
|
"api_key": "env:GEMINI_API_KEY"
|
||||||
|
},
|
||||||
|
"xai": {
|
||||||
|
"provider": "xai",
|
||||||
|
"model": "grok-3",
|
||||||
|
"api_key": "env:XAI_API_KEY"
|
||||||
|
},
|
||||||
|
"custom_openai_compatible": {
|
||||||
|
"provider": "openai",
|
||||||
|
"model": "my-model",
|
||||||
|
"api_key": "env:MY_API_KEY",
|
||||||
|
"base_url": "https://my-provider.com/v1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
"max_tokens": 8192,
|
"max_tokens": 8192,
|
||||||
"max_iterations": 20,
|
"max_iterations": 20,
|
||||||
"workspace": ".",
|
"workspace": ".",
|
||||||
|
|||||||
@ -10,6 +10,8 @@ from .config import Config
|
|||||||
from .session import Session
|
from .session import Session
|
||||||
from .context import ContextBuilder
|
from .context import ContextBuilder
|
||||||
from .providers.anthropic import AnthropicProvider
|
from .providers.anthropic import AnthropicProvider
|
||||||
|
from .providers.openai import OpenAIProvider
|
||||||
|
from .providers.gemini import GeminiProvider
|
||||||
from .tools.registry import ToolRegistry
|
from .tools.registry import ToolRegistry
|
||||||
from .tools.read import ReadTool
|
from .tools.read import ReadTool
|
||||||
from .tools.write import WriteTool
|
from .tools.write import WriteTool
|
||||||
@ -26,10 +28,27 @@ class Agent:
|
|||||||
self.context_builder = ContextBuilder(config.workspace)
|
self.context_builder = ContextBuilder(config.workspace)
|
||||||
|
|
||||||
# Initialize provider
|
# Initialize provider
|
||||||
|
base_url = getattr(config, 'base_url', None)
|
||||||
if config.provider == "anthropic":
|
if config.provider == "anthropic":
|
||||||
self.provider = AnthropicProvider(config.api_key, config.model)
|
self.provider = AnthropicProvider(config.api_key, config.model)
|
||||||
|
elif config.provider == "openai":
|
||||||
|
self.provider = OpenAIProvider(
|
||||||
|
config.api_key, config.model,
|
||||||
|
base_url=base_url or "https://api.openai.com/v1"
|
||||||
|
)
|
||||||
|
elif config.provider == "xai":
|
||||||
|
self.provider = OpenAIProvider(
|
||||||
|
config.api_key, config.model,
|
||||||
|
base_url=base_url or "https://api.x.ai/v1"
|
||||||
|
)
|
||||||
|
elif config.provider == "gemini":
|
||||||
|
self.provider = GeminiProvider(config.api_key, config.model)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown provider: {config.provider}")
|
# Assume OpenAI-compatible for unknown providers
|
||||||
|
self.provider = OpenAIProvider(
|
||||||
|
config.api_key, config.model,
|
||||||
|
base_url=base_url or "https://api.openai.com/v1"
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize tools
|
# Initialize tools
|
||||||
self.tool_registry = ToolRegistry()
|
self.tool_registry = ToolRegistry()
|
||||||
|
|||||||
@ -18,6 +18,7 @@ class Config:
|
|||||||
max_iterations: int = 20
|
max_iterations: int = 20
|
||||||
workspace: str = "."
|
workspace: str = "."
|
||||||
system_prompt: str = "You are a helpful coding assistant."
|
system_prompt: str = "You are a helpful coding assistant."
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_file(cls, config_path: str) -> "Config":
|
def from_file(cls, config_path: str) -> "Config":
|
||||||
@ -43,7 +44,8 @@ class Config:
|
|||||||
max_tokens=data.get('max_tokens', 8192),
|
max_tokens=data.get('max_tokens', 8192),
|
||||||
max_iterations=data.get('max_iterations', 20),
|
max_iterations=data.get('max_iterations', 20),
|
||||||
workspace=data.get('workspace', '.'),
|
workspace=data.get('workspace', '.'),
|
||||||
system_prompt=data.get('system_prompt', 'You are a helpful coding assistant.')
|
system_prompt=data.get('system_prompt', 'You are a helpful coding assistant.'),
|
||||||
|
base_url=data.get('base_url', None)
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
"""
|
|
||||||
PicoGent Providers Package
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .base import BaseProvider
|
from .base import BaseProvider
|
||||||
from .anthropic import AnthropicProvider
|
from .anthropic import AnthropicProvider
|
||||||
|
from .openai import OpenAIProvider
|
||||||
|
from .gemini import GeminiProvider
|
||||||
|
|
||||||
__all__ = ["BaseProvider", "AnthropicProvider"]
|
__all__ = ["BaseProvider", "AnthropicProvider", "OpenAIProvider", "GeminiProvider"]
|
||||||
|
|||||||
192
picogent/providers/gemini.py
Normal file
192
picogent/providers/gemini.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
"""
|
||||||
|
Google Gemini provider using httpx (REST API, no SDK)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
from .base import BaseProvider
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiProvider(BaseProvider):
    """Google Gemini provider using the generateContent REST API.

    Messages and tool definitions are accepted in Anthropic-style format and
    translated to Gemini's ``contents`` / ``functionDeclarations`` shapes.
    Responses are translated back into the provider-neutral dict assembled at
    the end of :meth:`generate_response`.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gemini-2.5-flash",
    ):
        """Store credentials and the fixed v1beta REST endpoint.

        Args:
            api_key: Gemini API key, sent on every request.
            model: Model identifier interpolated into the request URL.
        """
        super().__init__(api_key, model)
        self.base_url = "https://generativelanguage.googleapis.com/v1beta"

    @staticmethod
    def _strip_additional_properties(schema: Any) -> Any:
        """Return a copy of *schema* without any ``additionalProperties`` keyword.

        The keyword is removed at every nesting level. Keys directly under a
        ``properties`` mapping are user-chosen property names, not schema
        keywords, so they are always preserved (even if one happens to be
        called ``additionalProperties``).
        """
        strip = GeminiProvider._strip_additional_properties
        if isinstance(schema, list):
            return [strip(item) for item in schema]
        if not isinstance(schema, dict):
            return schema

        cleaned: Dict[str, Any] = {}
        for key, value in schema.items():
            if key == "additionalProperties":
                continue
            if key == "properties" and isinstance(value, dict):
                cleaned[key] = {name: strip(sub) for name, sub in value.items()}
            else:
                cleaned[key] = strip(value)
        return cleaned

    def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert Anthropic-style tool defs to Gemini function declarations.

        Args:
            tools: Tool dicts with ``name`` and optional ``description`` /
                ``input_schema`` keys.

        Returns:
            A single-element list wrapping all ``functionDeclarations``.
        """
        declarations = [
            {
                "name": tool["name"],
                "description": tool.get("description", ""),
                # Gemini doesn't support 'additionalProperties' in some cases,
                # so it is removed from the schema, including nested schemas.
                "parameters": self._strip_additional_properties(
                    tool.get("input_schema") or {}
                ),
            }
            for tool in tools
        ]
        return [{"functionDeclarations": declarations}]

    def _convert_messages(
        self, messages: List[Dict[str, Any]], system_prompt: str
    ) -> tuple[str, List[Dict[str, Any]]]:
        """Convert Anthropic-style messages to Gemini contents format.

        ``assistant`` turns become ``model`` turns; ``tool_use`` blocks become
        ``functionCall`` parts and ``tool_result`` blocks become
        ``functionResponse`` parts. Messages with any other role are ignored.

        Returns (system_instruction, contents).
        """
        contents: List[Dict[str, Any]] = []
        # tool_use id -> function name. tool_result blocks reference the call
        # only by ``tool_use_id``, while a functionResponse part is keyed by
        # the function name, so the name is recovered from the earlier
        # assistant turn that issued the call.
        call_names: Dict[str, str] = {}

        for msg in messages:
            role = msg.get("role")
            content = msg.get("content")

            if role == "user":
                if isinstance(content, str):
                    contents.append({"role": "user", "parts": [{"text": content}]})
                elif isinstance(content, list):
                    tool_results = [b for b in content if b.get("type") == "tool_result"]
                    if tool_results:
                        parts = []
                        for tr in tool_results:
                            # An explicit ``tool_name`` still wins, as before.
                            name = (
                                tr.get("tool_name")
                                or call_names.get(tr.get("tool_use_id"))
                                or "unknown"
                            )
                            parts.append({
                                "functionResponse": {
                                    "name": name,
                                    "response": {"result": tr.get("content", "")},
                                }
                            })
                        contents.append({"role": "user", "parts": parts})
                    else:
                        text = " ".join(
                            b.get("text", "") for b in content if b.get("type") == "text"
                        )
                        if text:
                            contents.append({"role": "user", "parts": [{"text": text}]})

            elif role == "assistant":
                if isinstance(content, str):
                    contents.append({"role": "model", "parts": [{"text": content}]})
                elif isinstance(content, list):
                    parts = []
                    for block in content:
                        block_type = block.get("type")
                        if block_type == "text":
                            text = block.get("text", "")
                            if text:
                                parts.append({"text": text})
                        elif block_type == "tool_use":
                            name = block.get("name", "")
                            if block.get("id"):
                                call_names[block["id"]] = name
                            parts.append({
                                "functionCall": {
                                    "name": name,
                                    "args": block.get("input", {}),
                                }
                            })
                    # Turns with no usable parts are dropped entirely.
                    if parts:
                        contents.append({"role": "model", "parts": parts})

        return system_prompt, contents

    async def generate_response(
        self,
        messages: List[Dict[str, Any]],
        system_prompt: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 8192,
    ) -> Dict[str, Any]:
        """Generate response using Gemini generateContent API.

        Args:
            messages: Conversation history in Anthropic-style format.
            system_prompt: Sent as ``systemInstruction`` when non-empty.
            tools: Optional Anthropic-style tool definitions.
            max_tokens: Forwarded as ``generationConfig.maxOutputTokens``.

        Returns:
            Dict with ``content`` (text parts joined by newlines), ``usage``
            (``input_tokens`` / ``output_tokens``) and ``model``. When the
            model requested function calls, ``tool_calls`` and an
            Anthropic-style ``raw_content`` block list are included as well.
            If the API returns no candidates, ``content`` is empty and
            ``usage`` is an empty dict.

        Raises:
            Exception: If the API answers with a non-200 status code.
        """
        system_instruction, contents = self._convert_messages(messages, system_prompt)

        payload: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                "maxOutputTokens": max_tokens,
            },
        }

        if system_instruction:
            payload["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }

        if tools:
            payload["tools"] = self._convert_tools(tools)

        # The key travels in a header rather than the query string so it does
        # not end up in URLs that get logged or embedded in error messages.
        url = f"{self.base_url}/models/{self.model}:generateContent"
        headers = {"x-goog-api-key": self.api_key}

        async with httpx.AsyncClient() as client:
            response = await client.post(
                url,
                headers=headers,
                json=payload,
                timeout=120.0,
            )

        if response.status_code != 200:
            raise Exception(
                f"Gemini API error {response.status_code}: {response.text}"
            )

        result = response.json()

        candidates = result.get("candidates") or []
        if not candidates:
            return {"content": "", "usage": {}, "model": self.model}

        # ``content`` / ``parts`` may be absent or null; treat both as "no parts".
        parts = (candidates[0].get("content") or {}).get("parts") or []

        text_parts = []
        tool_calls = []
        raw_content = []

        for part in parts:
            if "text" in part:
                text_parts.append(part["text"])
                raw_content.append({"type": "text", "text": part["text"]})
            elif "functionCall" in part:
                fc = part["functionCall"]
                # The parsed response carries no call id, so a local one is
                # generated to link this call to its later tool_result.
                call_id = f"call_{uuid.uuid4().hex[:12]}"
                name = fc.get("name", "")
                args = fc.get("args", {})
                tool_calls.append({"id": call_id, "name": name, "input": args})
                raw_content.append({
                    "type": "tool_use",
                    "id": call_id,
                    "name": name,
                    "input": args,
                })

        usage_meta = result.get("usageMetadata") or {}

        response_data: Dict[str, Any] = {
            "content": "\n".join(text_parts),
            "usage": {
                "input_tokens": usage_meta.get("promptTokenCount", 0),
                "output_tokens": usage_meta.get("candidatesTokenCount", 0),
            },
            "model": self.model,
        }

        if tool_calls:
            response_data["tool_calls"] = tool_calls
            response_data["raw_content"] = raw_content

        return response_data
|
||||||
187
picogent/providers/openai.py
Normal file
187
picogent/providers/openai.py
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
"""
|
||||||
|
OpenAI-compatible provider (ChatGPT, xAI/Grok, etc.) using httpx
|
||||||
|
"""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
from .base import BaseProvider
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIProvider(BaseProvider):
    """OpenAI-compatible provider using direct API calls.

    Works with: OpenAI, xAI (Grok), Azure OpenAI, any OpenAI-compatible API.

    Messages and tool definitions are accepted in Anthropic-style format and
    translated to the Chat Completions format; responses are translated back
    into the provider-neutral dict assembled at the end of
    :meth:`generate_response`.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gpt-4o",
        base_url: str = "https://api.openai.com/v1",
    ):
        """Store credentials and the normalized API base URL.

        Args:
            api_key: Sent as a ``Bearer`` token on every request.
            model: Model identifier placed in the request payload.
            base_url: API root; a trailing slash is removed. A falsy value
                falls back to the public OpenAI endpoint.
        """
        super().__init__(api_key, model)
        self.base_url = (base_url or "https://api.openai.com/v1").rstrip("/")

    @staticmethod
    def _tool_result_text(content: Any) -> str:
        """Flatten a tool_result ``content`` value into a plain string.

        Strings pass through unchanged and ``None`` becomes ``""``. A list of
        blocks is reduced to its ``text`` blocks joined by newlines; anything
        else is JSON-encoded (falling back to ``str``) so the tool message
        always carries string content.
        """
        import json

        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            texts = [
                block.get("text", "")
                for block in content
                if isinstance(block, dict) and block.get("type") == "text"
            ]
            if texts:
                return "\n".join(texts)
        try:
            return json.dumps(content)
        except (TypeError, ValueError):
            return str(content)

    def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert Anthropic-style tool defs to OpenAI function calling format.

        Args:
            tools: Tool dicts with ``name`` and optional ``description`` /
                ``input_schema`` keys.

        Returns:
            One ``{"type": "function", "function": {...}}`` entry per tool.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool.get("description", ""),
                    "parameters": tool.get("input_schema", {}),
                },
            }
            for tool in tools
        ]

    def _convert_messages(
        self, messages: List[Dict[str, Any]], system_prompt: str
    ) -> List[Dict[str, Any]]:
        """Convert Anthropic-style messages to OpenAI format.

        The system prompt (when non-empty) becomes the first message.
        ``tool_result`` blocks become ``tool`` messages and ``tool_use``
        blocks become ``tool_calls`` on the assistant message. Messages with
        any other role are ignored.
        """
        import json

        oai_messages: List[Dict[str, Any]] = []

        if system_prompt:
            oai_messages.append({"role": "system", "content": system_prompt})

        for msg in messages:
            role = msg.get("role")
            content = msg.get("content")

            # --- user message ---
            if role == "user":
                if isinstance(content, str):
                    oai_messages.append({"role": "user", "content": content})
                elif isinstance(content, list):
                    # Tool results go first so they directly follow the
                    # assistant message that issued the calls.
                    for block in content:
                        if block.get("type") == "tool_result":
                            oai_messages.append({
                                "role": "tool",
                                "tool_call_id": block.get("tool_use_id", ""),
                                "content": self._tool_result_text(block.get("content", "")),
                            })
                    # Text blocks are forwarded as well, so text that
                    # accompanies tool results is no longer dropped.
                    text = " ".join(
                        b.get("text", "") for b in content if b.get("type") == "text"
                    )
                    if text:
                        oai_messages.append({"role": "user", "content": text})

            # --- assistant message ---
            elif role == "assistant":
                if isinstance(content, str):
                    oai_messages.append({"role": "assistant", "content": content})
                elif isinstance(content, list):
                    text_parts = []
                    tool_calls = []
                    for block in content:
                        block_type = block.get("type")
                        if block_type == "text":
                            text_parts.append(block.get("text", ""))
                        elif block_type == "tool_use":
                            tool_calls.append({
                                "id": block.get("id", ""),
                                "type": "function",
                                "function": {
                                    "name": block.get("name", ""),
                                    # Arguments are sent as a JSON string.
                                    "arguments": json.dumps(block.get("input", {})),
                                },
                            })

                    text = "\n".join(text_parts)
                    if not text and not tool_calls:
                        # An assistant turn with neither text nor tool calls
                        # carries nothing to send, so it is skipped.
                        continue

                    assistant_msg: Dict[str, Any] = {
                        "role": "assistant",
                        "content": text or None,
                    }
                    if tool_calls:
                        assistant_msg["tool_calls"] = tool_calls
                    oai_messages.append(assistant_msg)

        return oai_messages

    async def generate_response(
        self,
        messages: List[Dict[str, Any]],
        system_prompt: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 8192,
    ) -> Dict[str, Any]:
        """Generate response using OpenAI Chat Completions API.

        Args:
            messages: Conversation history in Anthropic-style format.
            system_prompt: Sent as the leading ``system`` message when non-empty.
            tools: Optional Anthropic-style tool definitions.
            max_tokens: Forwarded as ``max_tokens`` in the payload.

        Returns:
            Dict with ``content`` (assistant text, ``""`` if none), ``usage``
            (the API's usage object as returned) and ``model``. When the model
            requested tool calls, ``tool_calls`` and an Anthropic-style
            ``raw_content`` block list are included as well.

        Raises:
            Exception: If the API answers with a non-200 status code.
        """
        import json

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        payload: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": max_tokens,
            "messages": self._convert_messages(messages, system_prompt),
        }

        if tools:
            payload["tools"] = self._convert_tools(tools)

        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/chat/completions",
                headers=headers,
                json=payload,
                timeout=120.0,
            )

        if response.status_code != 200:
            raise Exception(
                f"OpenAI API error {response.status_code}: {response.text}"
            )

        result = response.json()
        # ``choices`` may be missing, null or an empty list; all three are
        # treated as an empty message instead of raising IndexError.
        choices = result.get("choices") or [{}]
        message = choices[0].get("message") or {}

        text_content = message.get("content") or ""
        oai_tool_calls = message.get("tool_calls") or []

        response_data: Dict[str, Any] = {
            "content": text_content,
            "usage": result.get("usage", {}),
            "model": result.get("model", self.model),
        }

        if oai_tool_calls:
            # Convert OpenAI tool_calls -> our standard format
            tool_calls = []
            # Also build Anthropic-style raw_content for session storage
            raw_content = []
            if text_content:
                raw_content.append({"type": "text", "text": text_content})

            for tc in oai_tool_calls:
                fn = tc.get("function") or {}
                try:
                    args = json.loads(fn.get("arguments") or "{}")
                except (json.JSONDecodeError, TypeError):
                    # Best effort: unparsable arguments become an empty input.
                    args = {}
                if not isinstance(args, dict):
                    # Valid JSON that is not an object is treated the same way.
                    args = {}

                call_id = tc.get("id", "")
                name = fn.get("name", "")
                tool_calls.append({"id": call_id, "name": name, "input": args})
                raw_content.append({
                    "type": "tool_use",
                    "id": call_id,
                    "name": name,
                    "input": args,
                })

            response_data["tool_calls"] = tool_calls
            response_data["raw_content"] = raw_content

        return response_data
|
||||||
Loading…
Reference in New Issue
Block a user