193 lines
6.9 KiB
Python
193 lines
6.9 KiB
Python
"""
Google Gemini provider using httpx (REST API, no SDK)
"""
|
|
|
|
import httpx
|
|
import json
|
|
import uuid
|
|
from typing import Dict, List, Any, Optional
|
|
from .base import BaseProvider
|
|
|
|
|
|
class GeminiProvider(BaseProvider):
    """Google Gemini provider using the generateContent REST API.

    Accepts Anthropic-style messages and tool definitions, converts them to
    the Gemini request format, and normalizes the Gemini response back into
    a dict with ``content``, ``usage``, ``model`` and, when the model asked
    for tool calls, ``tool_calls`` and ``raw_content``.
    """

    # JSON Schema keywords whose value maps arbitrary names to sub-schemas.
    # The names themselves must never be filtered (a property may legitimately
    # be called "additionalProperties").
    _SCHEMA_MAP_KEYS = ("properties", "$defs", "definitions")
    # JSON Schema keywords whose value is a sub-schema or a list of sub-schemas.
    _SCHEMA_NODE_KEYS = ("items", "anyOf", "oneOf", "allOf")

    def __init__(
        self,
        api_key: str,
        model: str = "gemini-2.5-flash",
    ):
        """Create a provider for ``model`` authenticated with ``api_key``.

        Both values are handed to ``BaseProvider.__init__``; this class later
        reads them back as ``self.api_key`` and ``self.model``.
        NOTE(review): assumes BaseProvider stores them under those names.
        """
        super().__init__(api_key, model)
        self.base_url = "https://generativelanguage.googleapis.com/v1beta"

    @classmethod
    def _clean_schema(cls, schema: Any) -> Any:
        """Return a copy of ``schema`` without any ``additionalProperties`` keyword.

        The keyword is removed at every nesting level reachable through the
        schema-bearing keywords listed in ``_SCHEMA_MAP_KEYS`` and
        ``_SCHEMA_NODE_KEYS``. All other values (enums, defaults, ...) are
        copied through untouched.
        """
        if isinstance(schema, list):
            return [cls._clean_schema(item) for item in schema]
        if not isinstance(schema, dict):
            return schema

        cleaned: Dict[str, Any] = {}
        for key, value in schema.items():
            if key == "additionalProperties":
                continue
            if key in cls._SCHEMA_MAP_KEYS and isinstance(value, dict):
                # Keys here are user-chosen names, only the values are schemas.
                cleaned[key] = {
                    name: cls._clean_schema(sub) for name, sub in value.items()
                }
            elif key in cls._SCHEMA_NODE_KEYS:
                cleaned[key] = cls._clean_schema(value)
            else:
                cleaned[key] = value
        return cleaned

    def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert Anthropic-style tool defs to Gemini function declarations.

        Args:
            tools: Tool definitions with ``name``, optional ``description``
                and optional ``input_schema`` (JSON Schema).

        Returns:
            A one-element list ``[{"functionDeclarations": [...]}]``.

        Raises:
            KeyError: If a tool definition has no ``name``.
        """
        declarations = []
        for tool in tools:
            declaration: Dict[str, Any] = {
                "name": tool["name"],
                "description": tool.get("description", ""),
            }
            # Gemini doesn't support 'additionalProperties' in some cases, so
            # strip it at every level, not just the top one.
            parameters = self._clean_schema(tool.get("input_schema") or {})
            # A tool without arguments is declared by omitting 'parameters'
            # rather than sending an empty schema object.
            if parameters:
                declaration["parameters"] = parameters
            declarations.append(declaration)
        return [{"functionDeclarations": declarations}]

    @staticmethod
    def _user_parts(content: Any, tool_names: Dict[Any, str]) -> List[Dict[str, Any]]:
        """Build the Gemini ``parts`` list for one user message.

        ``tool_names`` maps tool_use ids seen in earlier assistant messages to
        the function name, so a ``tool_result`` block that only carries
        ``tool_use_id`` can still be answered under the right function name.
        """
        if isinstance(content, str):
            return [{"text": content}]
        if not isinstance(content, list):
            return []

        # Ignore anything that is not a block dict instead of raising.
        blocks = [block for block in content if isinstance(block, dict)]

        parts: List[Dict[str, Any]] = []
        for block in blocks:
            if block.get("type") != "tool_result":
                continue
            # Prefer an explicit tool_name; otherwise resolve it through the
            # id of the tool_use block this result answers.
            name = block.get("tool_name") or tool_names.get(
                block.get("tool_use_id"), "unknown"
            )
            parts.append({
                "functionResponse": {
                    "name": name,
                    "response": {"result": block.get("content", "")},
                }
            })

        # Keep user text even when the message also carries tool results.
        text = " ".join(
            block.get("text", "") for block in blocks if block.get("type") == "text"
        )
        if text:
            parts.append({"text": text})
        return parts

    @staticmethod
    def _assistant_parts(content: Any, tool_names: Dict[Any, str]) -> List[Dict[str, Any]]:
        """Build the Gemini ``parts`` list for one assistant message.

        Every ``tool_use`` block that has an ``id`` is recorded in
        ``tool_names`` (id -> function name) for later ``tool_result`` lookup.
        """
        if isinstance(content, str):
            return [{"text": content}]
        if not isinstance(content, list):
            return []

        parts: List[Dict[str, Any]] = []
        for block in content:
            if not isinstance(block, dict):
                continue
            block_type = block.get("type")
            if block_type == "text":
                text = block.get("text", "")
                if text:
                    parts.append({"text": text})
            elif block_type == "tool_use":
                name = block.get("name", "")
                if block.get("id") is not None:
                    tool_names[block["id"]] = name
                parts.append({
                    "functionCall": {
                        "name": name,
                        "args": block.get("input", {}),
                    }
                })
        return parts

    def _convert_messages(
        self, messages: List[Dict[str, Any]], system_prompt: str
    ) -> tuple[str, List[Dict[str, Any]]]:
        """Convert Anthropic-style messages to Gemini contents format.

        ``user`` messages keep the role ``user``; ``assistant`` messages become
        role ``model``. Messages with any other role, and messages that yield
        no parts, are skipped.

        Returns (system_instruction, contents). The system prompt is passed
        through unchanged.
        """
        contents: List[Dict[str, Any]] = []
        # tool_use id -> function name, filled while walking assistant turns.
        tool_names: Dict[Any, str] = {}

        for msg in messages:
            role = msg.get("role")
            content = msg.get("content")

            if role == "user":
                gemini_role = "user"
                parts = self._user_parts(content, tool_names)
            elif role == "assistant":
                gemini_role = "model"
                parts = self._assistant_parts(content, tool_names)
            else:
                continue

            if parts:
                contents.append({"role": gemini_role, "parts": parts})

        return system_prompt, contents

    def _parse_response(self, result: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize a decoded generateContent response body.

        Only the first candidate is read. Text parts are joined with newlines
        into ``content``; each ``functionCall`` part becomes a tool call with a
        locally generated id.
        """
        usage_meta = result.get("usageMetadata", {})
        usage = {
            "input_tokens": usage_meta.get("promptTokenCount", 0),
            "output_tokens": usage_meta.get("candidatesTokenCount", 0),
        }

        candidates = result.get("candidates", [])
        if not candidates:
            # Same shape as the normal path so callers can read token counts
            # without special-casing an empty reply.
            return {"content": "", "usage": usage, "model": self.model}

        parts = candidates[0].get("content", {}).get("parts", [])

        text_parts: List[str] = []
        tool_calls: List[Dict[str, Any]] = []
        raw_content: List[Dict[str, Any]] = []

        for part in parts:
            if "text" in part:
                text_parts.append(part["text"])
                raw_content.append({"type": "text", "text": part["text"]})
            elif "functionCall" in part:
                fc = part["functionCall"]
                # The id is generated here so the tool result can later be
                # matched back to this call.
                call = {
                    "id": f"call_{uuid.uuid4().hex[:12]}",
                    "name": fc.get("name", ""),
                    "input": fc.get("args", {}),
                }
                tool_calls.append(call)
                raw_content.append({"type": "tool_use", **call})

        response_data: Dict[str, Any] = {
            "content": "\n".join(text_parts),
            "usage": usage,
            "model": self.model,
        }

        if tool_calls:
            response_data["tool_calls"] = tool_calls
            response_data["raw_content"] = raw_content

        return response_data

    async def generate_response(
        self,
        messages: List[Dict[str, Any]],
        system_prompt: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_tokens: int = 8192,
    ) -> Dict[str, Any]:
        """Generate response using Gemini generateContent API.

        Args:
            messages: Anthropic-style conversation messages.
            system_prompt: Sent as ``systemInstruction`` when non-empty.
            tools: Optional Anthropic-style tool definitions.
            max_tokens: Value for ``generationConfig.maxOutputTokens``.

        Returns:
            Dict with ``content`` (str), ``usage`` (input/output token counts)
            and ``model``; plus ``tool_calls`` and ``raw_content`` when the
            model returned function calls.

        Raises:
            Exception: If the API answers with a non-200 status code. The
                message contains the status code and the response body.
        """
        system_instruction, contents = self._convert_messages(messages, system_prompt)

        payload: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": {
                "maxOutputTokens": max_tokens,
            },
        }

        if system_instruction:
            payload["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }

        if tools:
            payload["tools"] = self._convert_tools(tools)

        url = f"{self.base_url}/models/{self.model}:generateContent"
        # Send the key as a header rather than a query parameter so it does
        # not end up in URLs that get logged or echoed in error messages.
        headers = {"x-goog-api-key": self.api_key}

        async with httpx.AsyncClient() as client:
            response = await client.post(
                url,
                headers=headers,
                json=payload,
                timeout=120.0,
            )

        if response.status_code != 200:
            raise Exception(
                f"Gemini API error {response.status_code}: {response.text}"
            )

        return self._parse_response(response.json())
|