picogent-py/picogent/providers/gemini.py

193 lines
6.9 KiB
Python

"""
Google Gemini provider using httpx (REST API, no SDK)
"""
import httpx
import json
import uuid
from typing import Dict, List, Any, Optional
from .base import BaseProvider
class GeminiProvider(BaseProvider):
    """Google Gemini provider using the generateContent REST API.

    Converts Anthropic-style messages and tool definitions into the Gemini
    request format, posts them with httpx, and converts the reply back into
    a plain response dict (``content``, ``usage``, ``model`` and, when the
    model requested tools, ``tool_calls`` / ``raw_content``).
    """

    def __init__(
        self,
        api_key: str,
        model: str = "gemini-2.5-flash",
    ):
        """Initialise the provider.

        Args:
            api_key: Gemini API key, forwarded to ``BaseProvider.__init__``.
            model: Model identifier used in the request path.
        """
        # NOTE(review): generate_response reads self.api_key / self.model,
        # presumably stored by BaseProvider.__init__ -- confirm in base.py.
        super().__init__(api_key, model)
        self.base_url = "https://generativelanguage.googleapis.com/v1beta"

    @staticmethod
    def _strip_additional_properties(schema: Any) -> Any:
        """Return a copy of *schema* with every 'additionalProperties' keyword removed.

        The keyword is removed at every nesting level, not only at the top.
        Inside a ``properties`` mapping the keys are property *names*, so a
        property that happens to be called 'additionalProperties' is kept.
        Non-dict, non-list values are returned unchanged.
        """
        strip = GeminiProvider._strip_additional_properties
        if isinstance(schema, list):
            return [strip(item) for item in schema]
        if not isinstance(schema, dict):
            return schema

        cleaned: Dict[str, Any] = {}
        for key, value in schema.items():
            if key == "additionalProperties":
                continue
            if key == "properties" and isinstance(value, dict):
                # Keys of this mapping are user-chosen property names, not
                # schema keywords: keep every key, clean only the sub-schemas.
                cleaned[key] = {name: strip(sub) for name, sub in value.items()}
            else:
                cleaned[key] = strip(value)
        return cleaned

    def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert Anthropic-style tool defs to Gemini function declarations.

        Args:
            tools: Tool definitions with ``name`` and optional
                ``description`` / ``input_schema`` keys.

        Returns:
            A one-element list holding a ``functionDeclarations`` entry.

        Raises:
            KeyError: If a tool definition has no ``name``.
        """
        declarations = [
            {
                "name": tool["name"],
                "description": tool.get("description", ""),
                # Gemini doesn't support 'additionalProperties' in some cases,
                # so it is removed from the schema at every level.
                "parameters": self._strip_additional_properties(
                    tool.get("input_schema", {})
                ),
            }
            for tool in tools
        ]
        return [{"functionDeclarations": declarations}]

    def _convert_messages(
        self, messages: List[Dict[str, Any]], system_prompt: str
    ) -> tuple[str, List[Dict[str, Any]]]:
        """Convert Anthropic-style messages to Gemini contents format.

        User turns become ``role: "user"`` entries; assistant turns become
        ``role: "model"`` entries. ``tool_use`` blocks map to ``functionCall``
        parts and ``tool_result`` blocks map to ``functionResponse`` parts.
        Messages with any other role are ignored.

        The function name of a ``tool_result`` is taken from its
        ``tool_name`` key when present; otherwise it is looked up through
        ``tool_use_id`` among the ``tool_use`` blocks of earlier assistant
        messages, falling back to ``"unknown"``.

        Returns (system_instruction, contents).
        """
        contents: List[Dict[str, Any]] = []
        # tool_use id -> function name, collected from assistant turns so that
        # later tool_result blocks can be matched to the call they answer.
        tool_names: Dict[Any, str] = {}

        for msg in messages:
            role = msg.get("role")
            content = msg.get("content")

            if role == "user":
                if isinstance(content, str):
                    contents.append({"role": "user", "parts": [{"text": content}]})
                elif isinstance(content, list):
                    tool_results = [
                        b for b in content if b.get("type") == "tool_result"
                    ]
                    if tool_results:
                        # When tool results are present, only they are sent
                        # for this turn (text blocks in the same message are
                        # not forwarded).
                        parts = [
                            {
                                "functionResponse": {
                                    "name": tr.get("tool_name")
                                    or tool_names.get(
                                        tr.get("tool_use_id"), "unknown"
                                    ),
                                    "response": {"result": tr.get("content", "")},
                                }
                            }
                            for tr in tool_results
                        ]
                        contents.append({"role": "user", "parts": parts})
                    else:
                        text = " ".join(
                            b.get("text", "")
                            for b in content
                            if b.get("type") == "text"
                        )
                        if text:
                            contents.append(
                                {"role": "user", "parts": [{"text": text}]}
                            )

            elif role == "assistant":
                if isinstance(content, str):
                    contents.append({"role": "model", "parts": [{"text": content}]})
                elif isinstance(content, list):
                    parts = []
                    for block in content:
                        block_type = block.get("type")
                        if block_type == "text":
                            text = block.get("text", "")
                            if text:
                                parts.append({"text": text})
                        elif block_type == "tool_use":
                            name = block.get("name", "")
                            block_id = block.get("id")
                            if block_id and name:
                                tool_names[block_id] = name
                            parts.append(
                                {
                                    "functionCall": {
                                        "name": name,
                                        "args": block.get("input", {}),
                                    }
                                }
                            )
                    if parts:
                        contents.append({"role": "model", "parts": parts})

        return system_prompt, contents

    def _parse_response(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """Turn a decoded generateContent reply into the provider response dict.

        Only the first candidate is read. Text parts are joined with newlines
        into ``content``; each ``functionCall`` part gets a locally generated
        ``call_<hex>`` id and is reported in ``tool_calls``, with
        ``raw_content`` holding all parts as Anthropic-style blocks.
        ``tool_calls`` and ``raw_content`` are only present when at least one
        function call was returned.
        """
        usage_meta = result.get("usageMetadata", {})
        usage = {
            "input_tokens": usage_meta.get("promptTokenCount", 0),
            "output_tokens": usage_meta.get("candidatesTokenCount", 0),
        }

        candidates = result.get("candidates", [])
        if not candidates:
            # Keep the usage shape identical to the normal path so callers
            # can read token counts without special-casing empty replies.
            return {"content": "", "usage": usage, "model": self.model}

        parts = candidates[0].get("content", {}).get("parts", [])
        text_parts: List[str] = []
        tool_calls: List[Dict[str, Any]] = []
        raw_content: List[Dict[str, Any]] = []

        for part in parts:
            if "text" in part:
                text_parts.append(part["text"])
                raw_content.append({"type": "text", "text": part["text"]})
            elif "functionCall" in part:
                fc = part["functionCall"]
                # The id is generated locally; nothing is read from the reply.
                call_id = f"call_{uuid.uuid4().hex[:12]}"
                name = fc.get("name", "")
                args = fc.get("args", {})
                tool_calls.append({"id": call_id, "name": name, "input": args})
                raw_content.append(
                    {"type": "tool_use", "id": call_id, "name": name, "input": args}
                )

        response_data: Dict[str, Any] = {
            "content": "\n".join(text_parts),
            "usage": usage,
            "model": self.model,
        }
        if tool_calls:
            response_data["tool_calls"] = tool_calls
            response_data["raw_content"] = raw_content
        return response_data

    async def generate_response(
        self,
        messages: List[Dict[str, Any]],
        system_prompt: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 8192,
    ) -> Dict[str, Any]:
        """Generate response using Gemini generateContent API.

        Args:
            messages: Anthropic-style conversation history.
            system_prompt: Sent as ``systemInstruction`` when non-empty.
            tools: Optional Anthropic-style tool definitions.
            max_tokens: Value for ``generationConfig.maxOutputTokens``.

        Returns:
            Dict with ``content``, ``usage`` and ``model``, plus
            ``tool_calls`` and ``raw_content`` when the model called tools.

        Raises:
            Exception: If the API answers with a non-200 status code.
        """
        system_instruction, contents = self._convert_messages(messages, system_prompt)

        payload: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                "maxOutputTokens": max_tokens,
            },
        }
        if system_instruction:
            payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
        if tools:
            payload["tools"] = self._convert_tools(tools)

        url = f"{self.base_url}/models/{self.model}:generateContent"
        # The key travels in a header rather than a ?key= query parameter so
        # it is not part of the request URL.
        headers = {"x-goog-api-key": self.api_key}

        async with httpx.AsyncClient() as client:
            response = await client.post(
                url,
                json=payload,
                headers=headers,
                timeout=120.0,
            )
            if response.status_code != 200:
                raise Exception(
                    f"Gemini API error {response.status_code}: {response.text}"
                )
            result = response.json()

        return self._parse_response(result)