claude_code_oauth.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. """
  2. Claude Code OAuth Provider
  3. 通过 claude-agent-sdk 复用 `claude` CLI 的 OAuth 登录态调用 Claude(Max 订阅额度)。
  4. 实现方式:使用 `ClaudeSDKClient`(双向 session)+ AsyncIterable[dict] 形式发送
  5. 用户消息。这种模式同时满足:
  6. 1. 协议正确(client 内部管 stdin 生命周期,不会卡死)
  7. 2. 支持多模态(content blocks 可带 image 节点)
  8. Auth:依赖 `~/.claude/.credentials.json` 的 OAuth token;如父进程有
  9. ANTHROPIC_API_KEY / ANTHROPIC_BASE_URL,会从子进程 env 中剥离,让 CLI
  10. 回落到 OAuth。父进程 os.environ 不变。
  11. 输出契约(与现有 llm_call 一致):
  12. {"content": str, "usage": {"input_tokens": int, "output_tokens": int}}
  13. """
  14. import logging
  15. import os
  16. from typing import Any, Dict, List, Optional, Tuple
  17. logger = logging.getLogger(__name__)
  18. def _convert_messages(
  19. messages: List[Dict[str, Any]],
  20. ) -> Tuple[Optional[str], List[Dict[str, Any]], bool]:
  21. """
  22. 把 OpenAI 风格 messages 拆为 (system_prompt, anthropic_content_blocks, has_image)。
  23. - role=system 拼接为 system_prompt
  24. - role=user/assistant 的 content 转为 Anthropic content blocks (text/image)
  25. - OpenAI {"type":"image_url","image_url":{"url":...}} 转为
  26. Anthropic {"type":"image","source":{"type":"url","url":...}}
  27. - has_image:是否包含图片块,用于决定走 string 还是 AsyncIterable 模式
  28. """
  29. system_parts: List[str] = []
  30. blocks: List[Dict[str, Any]] = []
  31. has_image = False
  32. for msg in messages:
  33. role = msg.get("role")
  34. content = msg.get("content")
  35. if role == "system":
  36. if isinstance(content, str):
  37. system_parts.append(content)
  38. continue
  39. if isinstance(content, str):
  40. blocks.append({"type": "text", "text": content})
  41. continue
  42. if isinstance(content, list):
  43. for block in content:
  44. if not isinstance(block, dict):
  45. blocks.append({"type": "text", "text": str(block)})
  46. continue
  47. btype = block.get("type")
  48. if btype == "text":
  49. blocks.append({"type": "text", "text": block.get("text", "")})
  50. elif btype == "image_url":
  51. url = (block.get("image_url") or {}).get("url", "")
  52. if url:
  53. blocks.append(
  54. {"type": "image", "source": {"type": "url", "url": url}}
  55. )
  56. has_image = True
  57. elif btype == "image":
  58. blocks.append(block)
  59. has_image = True
  60. system_prompt = "\n\n".join(system_parts).strip() or None
  61. return system_prompt, blocks, has_image
  62. def _blocks_to_string(blocks: List[Dict[str, Any]]) -> str:
  63. """把 content blocks 拍平成字符串(图片降级为 [图片URL: ...] 占位)— string 模式用"""
  64. parts: List[str] = []
  65. for block in blocks:
  66. btype = block.get("type")
  67. if btype == "text":
  68. parts.append(block.get("text", ""))
  69. elif btype == "image":
  70. src = block.get("source") or {}
  71. url = src.get("url") or src.get("data", "")[:60]
  72. parts.append(f"[图片URL: {url}]")
  73. return "\n\n".join(p for p in parts if p).strip()
  74. def create_claude_code_oauth_llm_call(model: str = "claude-sonnet-4-5"):
  75. """
  76. 工厂:返回兼容 pipeline llm_call 契约的异步函数(基于 ClaudeSDKClient)。
  77. 返回函数签名:
  78. async (messages, model=..., temperature=..., max_tokens=...,
  79. response_schema=None, tools=None, **kwargs) -> dict
  80. 其中 temperature / max_tokens / response_schema / tools 静默忽略
  81. (SDK 不透传这些参数,CLI 用自己的默认值)。
  82. """
  83. from claude_agent_sdk import (
  84. AssistantMessage,
  85. ClaudeAgentOptions,
  86. ClaudeSDKClient,
  87. ClaudeSDKError,
  88. RateLimitEvent,
  89. ResultMessage,
  90. TextBlock,
  91. )
  92. # 让 SDK 子进程看不到 API key 相关变量,回落到 OAuth。
  93. # SDK 内部把 options.env 当作"覆盖层"叠在父进程 os.environ 之上,
  94. # 所以从 dict 里"移除"这些 key 没用 — 必须显式以空串覆盖父值。
  95. # 父进程 os.environ 不变(其他 LLM provider 继续可用 API key)。
  96. _override_env: Dict[str, str] = {
  97. "ANTHROPIC_API_KEY": "",
  98. "ANTHROPIC_BASE_URL": "",
  99. "ANTHROPIC_AUTH_TOKEN": "",
  100. }
  101. if "ANTHROPIC_API_KEY" in os.environ or "ANTHROPIC_BASE_URL" in os.environ:
  102. logger.info(
  103. "[claude_code_oauth] Overriding ANTHROPIC_API_KEY/ANTHROPIC_BASE_URL "
  104. "with empty values in SDK subprocess env so CLI falls back to OAuth."
  105. )
  106. default_model = model
  107. async def llm_call(
  108. messages: List[Dict[str, Any]],
  109. model: Optional[str] = None,
  110. tools: Optional[List[Dict]] = None,
  111. **kwargs: Any,
  112. ) -> Dict[str, Any]:
  113. actual_model = (model or default_model).split("/")[-1]
  114. system_prompt, content_blocks, has_image = _convert_messages(messages)
  115. if not content_blocks:
  116. content_blocks = [{"type": "text", "text": " "}]
  117. stderr_lines: List[str] = []
  118. def _capture_stderr(line: str) -> None:
  119. if line:
  120. stderr_lines.append(line)
  121. options = ClaudeAgentOptions(
  122. model=actual_model,
  123. system_prompt=system_prompt,
  124. allowed_tools=[],
  125. max_turns=1,
  126. env=_override_env,
  127. stderr=_capture_stderr,
  128. # 关键:屏蔽 CLI 加载用户级 ~/.claude/ 配置(output_style/skills/plugins 等)
  129. # 否则这些会被注入 system prompt,浪费 token + 影响输出格式
  130. setting_sources=[],
  131. )
  132. text_parts: List[str] = []
  133. usage: Dict[str, Any] = {}
  134. is_error = False
  135. api_error_status: Optional[int] = None
  136. result_subtype: Optional[str] = None
  137. result_errors: List[str] = []
  138. rate_limit_signal: Optional[str] = None
  139. def _emit(line: str) -> None:
  140. print(f"[claude] {line}", flush=True)
  141. try:
  142. async with ClaudeSDKClient(options=options) as client:
  143. if has_image:
  144. # 多模态:用 AsyncIterable[dict] 模式发送 Anthropic content blocks
  145. async def _input_stream():
  146. yield {
  147. "type": "user",
  148. "message": {"role": "user", "content": content_blocks},
  149. "parent_tool_use_id": None,
  150. "session_id": "default",
  151. }
  152. await client.query(_input_stream())
  153. else:
  154. # 纯文本:走 SDK string 模式(已验证稳定路径)
  155. await client.query(_blocks_to_string(content_blocks))
  156. async for msg in client.receive_response():
  157. msg_type = type(msg).__name__
  158. if isinstance(msg, AssistantMessage):
  159. for block in msg.content:
  160. if hasattr(block, "thinking"):
  161. # thinking 内容太多,跳过
  162. continue
  163. elif isinstance(block, TextBlock):
  164. _emit(f"[text] {block.text}")
  165. text_parts.append(block.text)
  166. elif hasattr(block, "name") and hasattr(block, "input"):
  167. _emit(f"[tool_use] {block.name}({block.input})")
  168. else:
  169. _emit(f"[{type(block).__name__}] {block!r}")
  170. if msg.usage and not usage:
  171. usage = dict(msg.usage)
  172. elif isinstance(msg, ResultMessage):
  173. if msg.usage:
  174. usage = dict(msg.usage)
  175. _emit(
  176. f"[result] subtype={msg.subtype} "
  177. f"is_error={msg.is_error} turns={msg.num_turns} "
  178. f"duration={msg.duration_ms}ms "
  179. f"in={msg.usage.get('input_tokens', 0) if msg.usage else 0} "
  180. f"out={msg.usage.get('output_tokens', 0) if msg.usage else 0}"
  181. )
  182. if msg.is_error:
  183. is_error = True
  184. api_error_status = msg.api_error_status
  185. result_subtype = msg.subtype
  186. result_errors = list(msg.errors or [])
  187. elif isinstance(msg, RateLimitEvent):
  188. # RateLimitEvent 是 SDK 定期播报 quota 状态,不等于被限流。
  189. # 只有 rate_limit_info.status != 'allowed' 才算真限流。
  190. info = getattr(msg, "rate_limit_info", None)
  191. info_status = getattr(info, "status", None) if info else None
  192. _emit(f"[rate_limit] status={info_status!r} type={getattr(info, 'rate_limit_type', None)!r}")
  193. if info_status and info_status != "allowed":
  194. rate_limit_signal = f"status={info_status!r}"
  195. else:
  196. # SystemMessage 简化为关键字段;其他未知类型 fallback
  197. if msg_type == "SystemMessage":
  198. data = getattr(msg, "data", {}) or {}
  199. subtype = getattr(msg, "subtype", "?")
  200. if subtype == "init":
  201. _emit(
  202. f"[init] model={data.get('model')!r} "
  203. f"apiKeySource={data.get('apiKeySource')!r} "
  204. f"session={data.get('session_id', '')[:8]}"
  205. )
  206. else:
  207. _emit(f"[system] subtype={subtype}")
  208. else:
  209. _emit(f"[{msg_type}] {msg!r}")
  210. except ClaudeSDKError as e:
  211. stderr_tail = "\n".join(stderr_lines[-20:])
  212. raise RuntimeError(
  213. f"claude_agent_sdk error: {type(e).__name__}: {e}\n"
  214. f"--- CLI stderr (last 20 lines) ---\n{stderr_tail}"
  215. ) from e
  216. if rate_limit_signal or api_error_status == 429:
  217. raise RuntimeError(
  218. "Claude Code OAuth rate-limited (429). "
  219. "Max subscription quota may be exhausted in current 5-hour window. "
  220. "Run `claude /status` to check remaining."
  221. )
  222. if is_error:
  223. stderr_tail = "\n".join(stderr_lines[-20:])
  224. errors_str = "; ".join(result_errors) or "(empty errors[])"
  225. raise RuntimeError(
  226. f"claude_agent_sdk is_error=True "
  227. f"subtype={result_subtype!r} status={api_error_status} "
  228. f"errors={errors_str}\n"
  229. f"--- CLI stderr (last 20 lines) ---\n{stderr_tail}"
  230. )
  231. content = "".join(text_parts)
  232. normalized_usage = {
  233. "input_tokens": int(usage.get("input_tokens", 0) or 0),
  234. "output_tokens": int(usage.get("output_tokens", 0) or 0),
  235. }
  236. for k in ("cache_creation_input_tokens", "cache_read_input_tokens"):
  237. if k in usage:
  238. normalized_usage[k] = int(usage[k] or 0)
  239. return {"content": content, "usage": normalized_usage}
  240. return llm_call