|
|
@@ -0,0 +1,234 @@
|
|
|
+"""
|
|
|
+Claude Code OAuth Provider
|
|
|
+
|
|
|
+通过 claude-agent-sdk 复用 `claude` CLI 的 OAuth 登录态调用 Claude(Max 订阅额度)。
|
|
|
+
|
|
|
+实现方式:使用 `ClaudeSDKClient`(双向 session)+ AsyncIterable[dict] 形式发送
|
|
|
+用户消息。这种模式同时满足:
|
|
|
+ 1. 协议正确(client 内部管 stdin 生命周期,不会卡死)
|
|
|
+ 2. 支持多模态(content blocks 可带 image 节点)
|
|
|
+
|
|
|
+Auth:依赖 `~/.claude/.credentials.json` 的 OAuth token;如父进程有
|
|
|
+ ANTHROPIC_API_KEY / ANTHROPIC_BASE_URL,会从子进程 env 中剥离,让 CLI
|
|
|
+ 回落到 OAuth。父进程 os.environ 不变。
|
|
|
+
|
|
|
+输出契约(与现有 llm_call 一致):
|
|
|
+ {"content": str, "usage": {"input_tokens": int, "output_tokens": int}}
|
|
|
+"""
|
|
|
+
|
|
|
+import logging
|
|
|
+import os
|
|
|
+from typing import Any, Dict, List, Optional, Tuple
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+def _flatten_messages_to_string(
|
|
|
+ messages: List[Dict[str, Any]],
|
|
|
+) -> Tuple[Optional[str], str]:
|
|
|
+ """
|
|
|
+ 把 OpenAI 风格 messages 折叠成 (system_prompt, user_text)。
|
|
|
+
|
|
|
+ - role=system 拼接为 system_prompt
|
|
|
+ - role=user/assistant 的 content 全部拍平为字符串
|
|
|
+ - image_url 类型块降级为 `[图片URL: ...]` 文本占位(模型看到 URL 字符串而非画面)
|
|
|
+
|
|
|
+ 使用 string 模式而非 AsyncIterable[dict],是为了走 SDK 中被生产验证的稳定路径。
|
|
|
+ 多模态真图传输需要切到 AsyncIterable + Anthropic content block 协议,单独迭代。
|
|
|
+ """
|
|
|
+ system_parts: List[str] = []
|
|
|
+ user_parts: List[str] = []
|
|
|
+
|
|
|
+ for msg in messages:
|
|
|
+ role = msg.get("role")
|
|
|
+ content = msg.get("content")
|
|
|
+
|
|
|
+ if role == "system":
|
|
|
+ if isinstance(content, str):
|
|
|
+ system_parts.append(content)
|
|
|
+ continue
|
|
|
+
|
|
|
+ if isinstance(content, str):
|
|
|
+ user_parts.append(content)
|
|
|
+ continue
|
|
|
+
|
|
|
+ if isinstance(content, list):
|
|
|
+ for block in content:
|
|
|
+ if not isinstance(block, dict):
|
|
|
+ user_parts.append(str(block))
|
|
|
+ continue
|
|
|
+ btype = block.get("type")
|
|
|
+ if btype == "text":
|
|
|
+ user_parts.append(block.get("text", ""))
|
|
|
+ elif btype == "image_url":
|
|
|
+ url = (block.get("image_url") or {}).get("url", "")
|
|
|
+ if url:
|
|
|
+ user_parts.append(f"[图片URL: {url}]")
|
|
|
+ elif btype == "image":
|
|
|
+ src = block.get("source") or {}
|
|
|
+ url = src.get("url") or src.get("data", "")[:60]
|
|
|
+ user_parts.append(f"[图片: {url}]")
|
|
|
+
|
|
|
+ system_prompt = "\n\n".join(system_parts).strip() or None
|
|
|
+ user_text = "\n\n".join(p for p in user_parts if p).strip()
|
|
|
+ return system_prompt, user_text
|
|
|
+
|
|
|
+
|
|
|
def create_claude_code_oauth_llm_call(model: str = "claude-sonnet-4-5"):
    """
    Factory: return an async function compatible with the pipeline's
    llm_call contract, backed by ClaudeSDKClient.

    Returned function signature:
        async (messages, model=..., temperature=..., max_tokens=...,
               response_schema=None, tools=None, **kwargs) -> dict
    where temperature / max_tokens / response_schema / tools are silently
    ignored (the SDK does not pass these through; the CLI applies its own
    defaults).

    Args:
        model: default model name used when a call does not override it.

    Returns:
        Async callable returning
        {"content": str, "usage": {"input_tokens": int, "output_tokens": int}}.
    """
    # Imported lazily so importing this module never requires the SDK.
    # NOTE(review): assumes this claude-agent-sdk version exports
    # RateLimitEvent — confirm against the installed SDK.
    from claude_agent_sdk import (
        AssistantMessage,
        ClaudeAgentOptions,
        ClaudeSDKClient,
        ClaudeSDKError,
        RateLimitEvent,
        ResultMessage,
        TextBlock,
    )

    # Strip API-key-related variables from the subprocess env so the CLI
    # falls back to OAuth; the parent process's os.environ is untouched
    # (other LLM providers can still use API keys).
    _stripped_env = {
        k: v
        for k, v in os.environ.items()
        if k not in ("ANTHROPIC_API_KEY", "ANTHROPIC_BASE_URL", "ANTHROPIC_AUTH_TOKEN")
    }
    if "ANTHROPIC_API_KEY" in os.environ or "ANTHROPIC_BASE_URL" in os.environ:
        logger.info(
            "[claude_code_oauth] Stripping ANTHROPIC_API_KEY/ANTHROPIC_BASE_URL "
            "from SDK subprocess env so CLI falls back to OAuth credentials."
        )

    # Captured by the closure; the inner `model` parameter shadows it.
    default_model = model

    async def llm_call(
        messages: List[Dict[str, Any]],
        model: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        # Accept provider-prefixed names ("anthropic/claude-x") and keep
        # only the bare model id after the last "/".
        actual_model = (model or default_model).split("/")[-1]

        system_prompt, user_text = _flatten_messages_to_string(messages)
        if not user_text:
            # CLI requires a non-empty prompt; send a single space.
            user_text = " "

        # Collected CLI stderr; only the tail is surfaced on errors.
        stderr_lines: List[str] = []

        def _capture_stderr(line: str) -> None:
            if line:
                stderr_lines.append(line)

        options = ClaudeAgentOptions(
            model=actual_model,
            system_prompt=system_prompt,
            allowed_tools=[],
            max_turns=1,
            env=_stripped_env,
            stderr=_capture_stderr,
            # Key point: block the CLI from loading user-level ~/.claude/
            # config (output_style/skills/plugins etc.). Otherwise these
            # get injected into the system prompt, wasting tokens and
            # affecting output format.
            setting_sources=[],
        )

        # Accumulated response state across the streamed messages.
        text_parts: List[str] = []
        usage: Dict[str, Any] = {}
        is_error = False
        api_error_status: Optional[int] = None
        result_subtype: Optional[str] = None
        result_errors: List[str] = []
        rate_limit_signal: Optional[str] = None

        def _emit(line: str) -> None:
            # Live progress to stdout, flushed so it shows during streaming.
            print(f"[claude] {line}", flush=True)

        try:
            async with ClaudeSDKClient(options=options) as client:
                await client.query(user_text)
                async for msg in client.receive_response():
                    msg_type = type(msg).__name__

                    if isinstance(msg, AssistantMessage):
                        for block in msg.content:
                            # Thinking blocks are logged but not returned.
                            if hasattr(block, "thinking"):
                                _emit(f"[think] {block.thinking}")
                            elif isinstance(block, TextBlock):
                                _emit(f"[text] {block.text}")
                                text_parts.append(block.text)
                            elif hasattr(block, "name") and hasattr(block, "input"):
                                # Duck-typed tool-use block.
                                _emit(f"[tool_use] {block.name}({block.input})")
                            else:
                                _emit(f"[{type(block).__name__}] {block!r}")
                        # NOTE(review): assumes AssistantMessage carries a
                        # `usage` attribute in this SDK version — confirm.
                        # First-seen usage wins; ResultMessage overrides it.
                        if msg.usage and not usage:
                            usage = dict(msg.usage)
                    elif isinstance(msg, ResultMessage):
                        if msg.usage:
                            usage = dict(msg.usage)
                        _emit(
                            f"[result] subtype={msg.subtype} "
                            f"is_error={msg.is_error} turns={msg.num_turns} "
                            f"duration={msg.duration_ms}ms "
                            f"in={msg.usage.get('input_tokens', 0) if msg.usage else 0} "
                            f"out={msg.usage.get('output_tokens', 0) if msg.usage else 0}"
                        )
                        if msg.is_error:
                            # Defer raising until the stream finishes so
                            # stderr capture is complete.
                            is_error = True
                            api_error_status = msg.api_error_status
                            result_subtype = msg.subtype
                            result_errors = list(msg.errors or [])
                    elif isinstance(msg, RateLimitEvent):
                        # RateLimitEvent is the SDK's periodic quota-status
                        # broadcast, NOT necessarily a throttle. Only
                        # rate_limit_info.status != 'allowed' means actually
                        # rate-limited.
                        info = getattr(msg, "rate_limit_info", None)
                        info_status = getattr(info, "status", None) if info else None
                        _emit(f"[rate_limit] status={info_status!r} type={getattr(info, 'rate_limit_type', None)!r}")
                        if info_status and info_status != "allowed":
                            rate_limit_signal = f"status={info_status!r}"
                    else:
                        # SystemMessage or other unknown message types.
                        _emit(f"[{msg_type}] {msg!r}")
        except ClaudeSDKError as e:
            # Re-raise as RuntimeError with the CLI stderr tail attached
            # so pipeline callers get an actionable message.
            stderr_tail = "\n".join(stderr_lines[-20:])
            raise RuntimeError(
                f"claude_agent_sdk error: {type(e).__name__}: {e}\n"
                f"--- CLI stderr (last 20 lines) ---\n{stderr_tail}"
            ) from e

        if rate_limit_signal or api_error_status == 429:
            raise RuntimeError(
                "Claude Code OAuth rate-limited (429). "
                "Max subscription quota may be exhausted in current 5-hour window. "
                "Run `claude /status` to check remaining."
            )

        if is_error:
            stderr_tail = "\n".join(stderr_lines[-20:])
            errors_str = "; ".join(result_errors) or "(empty errors[])"
            raise RuntimeError(
                f"claude_agent_sdk is_error=True "
                f"subtype={result_subtype!r} status={api_error_status} "
                f"errors={errors_str}\n"
                f"--- CLI stderr (last 20 lines) ---\n{stderr_tail}"
            )

        content = "".join(text_parts)

        # Normalize usage to the pipeline contract; `or 0` guards against
        # None values in the SDK's usage dict.
        normalized_usage = {
            "input_tokens": int(usage.get("input_tokens", 0) or 0),
            "output_tokens": int(usage.get("output_tokens", 0) or 0),
        }
        # Pass cache-token counters through when present (extra keys are
        # additive and do not break the contract).
        for k in ("cache_creation_input_tokens", "cache_read_input_tokens"):
            if k in usage:
                normalized_usage[k] = int(usage[k] or 0)

        return {"content": content, "usage": normalized_usage}

    return llm_call
|