utils.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624
  1. from typing import List, Dict, Any
  2. import json
  3. from .my_trace import get_current_time
  4. import re
  5. import uuid
  6. import datetime
  7. def parse_json_from_text(text: str) -> dict:
  8. """
  9. 从文本中解析JSON,支持多种格式的JSON代码块
  10. Args:
  11. text (str): 包含JSON的文本
  12. Returns:
  13. dict: 解析后的JSON数据,解析失败返回空字典
  14. """
  15. if not text or not isinstance(text, str):
  16. return {}
  17. # 去除首尾空白字符
  18. text = text.strip()
  19. # 定义可能的JSON代码块标记
  20. json_markers = [
  21. ("'''json", "'''"),
  22. ('"""json', '"""'),
  23. ("```json", "```"),
  24. ("```", "```")
  25. ]
  26. # 尝试提取JSON代码块
  27. json_content = text
  28. for start_marker, end_marker in json_markers:
  29. if text.startswith(start_marker):
  30. # 找到开始标记,查找结束标记
  31. start_pos = len(start_marker)
  32. end_pos = text.find(end_marker, start_pos)
  33. if end_pos != -1:
  34. json_content = text[start_pos:end_pos].strip()
  35. break
  36. # 如果没有找到代码块标记,检查是否以结束标记结尾并移除
  37. if json_content == text:
  38. for _, end_marker in json_markers:
  39. if text.endswith(end_marker):
  40. json_content = text[:-len(end_marker)].strip()
  41. break
  42. # 尝试解析JSON
  43. try:
  44. return json.loads(json_content)
  45. except json.JSONDecodeError as e:
  46. print(f"JSON解析失败: {e}")
  47. # 如果直接解析失败,尝试查找第一个{到最后一个}的内容
  48. try:
  49. first_brace = json_content.find('{')
  50. last_brace = json_content.rfind('}')
  51. if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
  52. json_part = json_content[first_brace:last_brace + 1]
  53. return json.loads(json_part)
  54. except json.JSONDecodeError:
  55. pass
  56. return {}
  57. def get_safe_filename(filename: str) -> str:
  58. """
  59. 生成安全的文件名,移除不安全字符
  60. Args:
  61. filename: 原始文件名
  62. Returns:
  63. str: 安全的文件名
  64. """
  65. # 移除不安全的字符,只保留字母、数字、下划线、连字符和点
  66. return re.sub(r'[^\w\-\./]', '_', filename)
  67. def generate_image_filename(mime_type: str, prefix: str = "gemini_img") -> str:
  68. """
  69. 生成合理的图片文件名
  70. Args:
  71. mime_type: 文件MIME类型
  72. prefix: 文件名前缀
  73. Returns:
  74. str: 生成的文件名
  75. """
  76. # 获取当前时间戳
  77. timestamp = datetime.datetime.now().strftime("%Y%m%d/%H%M%S")
  78. # 获取文件扩展名
  79. extension = mime_type.split('/')[-1]
  80. if extension == "jpeg":
  81. extension = "jpg"
  82. # 生成唯一ID (短UUID)
  83. unique_id = str(uuid.uuid4())[:4]
  84. # 组合文件名
  85. filename = f"{prefix}/{timestamp}_{unique_id}.{extension}"
  86. # 确保文件名安全
  87. return get_safe_filename(filename)
  88. def parse_multimodal_content(content: str) -> List[Dict[str, Any]]:
  89. """解析多模态内容,保持上下文顺序,适用于AI参数传递 """
  90. result = []
  91. lines = content.split('\n')
  92. role = ''
  93. for line in lines:
  94. line = line.strip()
  95. if not line:
  96. continue
  97. # 分割前缀和内容
  98. if ':' in line:
  99. prefix, content = line.split(':', 1)
  100. prefix = prefix.strip().lower()
  101. content = content.strip()
  102. row = {}
  103. if prefix == 'image':
  104. row = {
  105. "type": "image_url",
  106. "image_url": {
  107. "url": content
  108. }
  109. }
  110. elif prefix == 'text':
  111. row = {
  112. "type": "text",
  113. "text": content
  114. }
  115. elif prefix == 'role':
  116. role = content
  117. if row:
  118. if role:
  119. row['role'] = role
  120. role = ''
  121. result.append(row)
  122. return result
  123. def read_json(file_path):
  124. """
  125. 读取JSON文件并返回解析后的数据
  126. Args:
  127. file_path: JSON文件路径
  128. Returns:
  129. 解析后的JSON数据
  130. """
  131. try:
  132. with open(file_path, 'r', encoding='utf-8') as f:
  133. return json.load(f)
  134. except Exception as e:
  135. print(f"读取JSON文件时出错: {e}")
  136. return None
  137. def save_json(data, file_path):
  138. """
  139. 保存数据到JSON文件
  140. Args:
  141. data: 要保存的数据
  142. file_path: 保存路径
  143. """
  144. with open(file_path, 'w', encoding='utf-8') as f:
  145. json.dump(data, f, ensure_ascii=False, indent=2)
  146. def get_script_data(file_path):
  147. """
  148. 读取JSON文件并返回解析后的数据
  149. Args:
  150. file_path: JSON文件路径
  151. """
  152. return read_json(file_path)['脚本']
  153. import os
  154. import xml.etree.ElementTree as ET
  155. from typing import Dict, List, Any
  156. import re
  157. import unicodedata
  158. def get_model(model_name):
  159. # return 'gemini/gemini-2.5-flash'
  160. # return 'litellm/gemini/gemini-2.5-flash'
  161. if model_name.startswith('litellm'):
  162. return model_name
  163. else:
  164. from openai import AsyncOpenAI
  165. from agents import OpenAIChatCompletionsModel
  166. BASE_URL = os.getenv("EXAMPLE_BASE_URL") or "https://openrouter.ai/api/v1"
  167. API_KEY = os.getenv("OPENROUTER_API_KEY") or ""
  168. client = AsyncOpenAI(
  169. base_url=BASE_URL,
  170. api_key=API_KEY,
  171. )
  172. return OpenAIChatCompletionsModel(
  173. # model='google/gemini-2.5-pro-preview',
  174. # model='google/gemini-2.5-flash-preview-05-20',
  175. # model='google/gemini-2.5-flash-preview-05-20',
  176. # model='google/gemini-2.5-flash',
  177. # model='google/gemini-2.5-flash',
  178. # model='google/gemini-2.5-flash-preview-05-20:thinking',
  179. # model='google/gemini-2.0-flash-001',
  180. model=model_name,
  181. openai_client=client,
  182. )
  183. def read_file_as_string(file_path):
  184. """读取文件内容并返回字符串"""
  185. try:
  186. with open(file_path, 'r', encoding='utf-8') as file:
  187. content = file.read().strip()
  188. return content
  189. except Exception as e:
  190. print(f"读取文件时出错: {e}")
  191. return None
  192. def save_file_as_string(file_path, content):
  193. """将字符串内容写入文件"""
  194. with open(file_path, 'w', encoding='utf-8') as f:
  195. f.write(content)
  196. def extract_html_from_markdown(text):
  197. """
  198. 从可能包含markdown或其他代码块的文本中提取HTML内容
  199. 参数:
  200. text: 可能包含各种格式的文本
  201. 返回:
  202. 提取出的纯HTML内容
  203. """
  204. # 处理```html```格式(反引号)
  205. backtick_pattern = r"```(?:html)?\s*([\s\S]*?)```"
  206. backtick_matches = re.findall(backtick_pattern, text)
  207. # 处理'''html'''格式(单引号)
  208. single_quote_pattern = r"'''(?:html)?\s*([\s\S]*?)'''"
  209. single_quote_matches = re.findall(single_quote_pattern, text)
  210. # 处理"""html"""格式(双引号)
  211. double_quote_pattern = r'"""(?:html)?\s*([\s\S]*?)"""'
  212. double_quote_matches = re.findall(double_quote_pattern, text)
  213. if backtick_matches:
  214. # 优先使用反引号格式
  215. return backtick_matches[0].strip()
  216. elif single_quote_matches:
  217. # 其次使用单引号格式
  218. return single_quote_matches[0].strip()
  219. elif double_quote_matches:
  220. # 再次使用双引号格式
  221. return double_quote_matches[0].strip()
  222. else:
  223. # 如果没有代码块格式,直接返回原get_current_time始文本
  224. return text
  225. def create_workspace_dir(current_time=None, make_dir=True):
  226. if not current_time:
  227. current_time = get_current_time()
  228. task_dir = f"result/{current_time}"
  229. if make_dir:
  230. os.makedirs(task_dir, exist_ok=True)
  231. task_dir_absolute = os.path.abspath(task_dir)
  232. # print(f"任务目录的绝对路径: {task_dir_absolute}")
  233. return task_dir_absolute, str(current_time)
  234. def extract_tag_content(text, tag_name):
  235. """
  236. 从文本中提取指定标签内的内容
  237. 参数:
  238. text (str): 要处理的文本
  239. tag_name (str): 要提取的标签名称
  240. 返回:
  241. str: 标签内的内容,如果未找到则返回空字符串
  242. """
  243. import re
  244. pattern = f"<{tag_name}>(.*?)</{tag_name}>"
  245. match = re.search(pattern, text, re.DOTALL)
  246. if match:
  247. return match.group(1).strip()
  248. return ""
  249. from typing import Dict, List, Optional
  250. def parse_tasks(tasks_xml: str) -> List[Dict]:
  251. """Parse XML tasks into a list of task dictionaries."""
  252. tasks = []
  253. current_task = {}
  254. for line in tasks_xml.split('\n'):
  255. line = line.strip()
  256. if not line:
  257. continue
  258. if line.startswith("<task>"):
  259. current_task = {}
  260. elif line.startswith("<name>"):
  261. current_task["name"] = line[6:-7].strip()
  262. elif line.startswith("<output>"):
  263. current_task["output"] = line[12:-13].strip()
  264. elif line.startswith("</task>"):
  265. if "description" in current_task:
  266. if "type" not in current_task:
  267. current_task["type"] = "default"
  268. tasks.append(current_task)
  269. return tasks
  270. def parse_xml_content(xml_string: str) -> Dict[str, Any]:
  271. """
  272. 将XML字符串解析成字典,提取main_task、thoughts、tasks和resources
  273. 参数:
  274. xml_string: 包含任务信息的XML字符串
  275. 返回:
  276. 包含所有解析信息的字典
  277. """
  278. # 创建结果字典
  279. result = {
  280. "main_task": {},
  281. "thoughts": "",
  282. "tasks": [],
  283. "resources": []
  284. }
  285. try:
  286. # 提取thoughts内容
  287. thoughts_match = re.search(r'<thoughts>(.*?)</thoughts>', xml_string, re.DOTALL)
  288. if thoughts_match:
  289. result["thoughts"] = thoughts_match.group(1).strip()
  290. # 提取main_task内容
  291. main_task_match = re.search(r'<main_task>(.*?)</main_task>', xml_string, re.DOTALL)
  292. if main_task_match:
  293. main_task_content = main_task_match.group(1)
  294. main_task = {}
  295. # 获取主任务名称
  296. name_match = re.search(r'<name>(.*?)</name>', main_task_content, re.DOTALL)
  297. if name_match:
  298. main_task['name'] = name_match.group(1).strip()
  299. # 获取主任务输出
  300. output_match = re.search(r'<output>(.*?)</output>', main_task_content, re.DOTALL)
  301. if output_match:
  302. main_task['output'] = output_match.group(1).strip()
  303. # 获取主任务描述
  304. description_match = re.search(r'<description>(.*?)</description>', main_task_content, re.DOTALL)
  305. if description_match:
  306. main_task['description'] = description_match.group(1).strip()
  307. result["main_task"] = main_task
  308. # 提取<tasks>...</tasks>部分
  309. tasks_pattern = re.compile(r'<tasks>(.*?)</tasks>', re.DOTALL)
  310. tasks_match = tasks_pattern.search(xml_string)
  311. if tasks_match:
  312. tasks_content = tasks_match.group(1)
  313. # 提取每个task块
  314. task_pattern = re.compile(r'<task>(.*?)</task>', re.DOTALL)
  315. task_matches = task_pattern.finditer(tasks_content)
  316. for task_match in task_matches:
  317. task_content = task_match.group(1)
  318. task_dict = {}
  319. # 获取任务名称
  320. name_match = re.search(r'<name>(.*?)</name>', task_content, re.DOTALL)
  321. if not name_match:
  322. continue # 跳过没有名称的任务
  323. name = name_match.group(1).strip()
  324. task_dict['name'] = name
  325. # 获取输出信息
  326. output_match = re.search(r'<output>(.*?)</output>', task_content, re.DOTALL)
  327. task_dict['output'] = output_match.group(1).strip() if output_match else ""
  328. # 获取描述信息
  329. description_match = re.search(r'<description>(.*?)</description>', task_content, re.DOTALL)
  330. task_dict['description'] = description_match.group(1).strip() if description_match else ""
  331. # 获取依赖任务列表
  332. depend_tasks = []
  333. depend_tasks_section = re.search(r'<depend_tasks>(.*?)</depend_tasks>', task_content, re.DOTALL)
  334. if depend_tasks_section:
  335. depend_task_matches = re.finditer(r'<depend_task>(.*?)</depend_task>',
  336. depend_tasks_section.group(1), re.DOTALL)
  337. for dt_match in depend_task_matches:
  338. if dt_match.group(1).strip():
  339. depend_tasks.append(dt_match.group(1).strip())
  340. task_dict['depend_tasks'] = depend_tasks
  341. # 获取依赖资源列表
  342. depend_resources = []
  343. resources_match = re.search(r'<depend_resources>(.*?)</depend_resources>', task_content, re.DOTALL)
  344. if resources_match and resources_match.group(1).strip():
  345. resources_text = resources_match.group(1).strip()
  346. depend_resources = [res.strip() for res in resources_text.split(',') if res.strip()]
  347. task_dict['depend_resources'] = depend_resources
  348. # 将任务添加到结果字典
  349. result["tasks"].append(task_dict)
  350. # 提取resources内容
  351. resources_pattern = re.compile(r'<resources>(.*?)</resources>', re.DOTALL)
  352. resources_match = resources_pattern.search(xml_string)
  353. if resources_match:
  354. resources_content = resources_match.group(1).strip()
  355. result["resources"] = resources_content
  356. return result
  357. except Exception as e:
  358. raise ValueError(f"处理XML数据时发生错误: {e}")
  359. def parse_planner_result(result):
  360. """
  361. 解析规划结果,并为每个任务添加任务目录名
  362. 参数:
  363. result: 包含thoughts、main_task、tasks和resources的规划结果字符串
  364. 返回:
  365. 解析后的完整规划信息字典
  366. """
  367. # 使用parse_xml_content解析完整内容
  368. parsed_result = parse_xml_content(result)
  369. task_name_to_index = {}
  370. task_dict = {
  371. 'tasks': {},
  372. 'max_index': 1,
  373. }
  374. # 为每个任务添加task_dir字段
  375. for i, task_info in enumerate(parsed_result["tasks"]):
  376. # 使用sanitize_filename生成目录名
  377. task_name = task_info.get("name", "task")
  378. depend_tasks_dir = []
  379. task_info['task_dir'] = get_task_dir(task_name, task_dict)
  380. for depend_task in task_info.get("depend_tasks", []):
  381. depend_tasks_dir.append(get_task_dir(depend_task, task_dict))
  382. task_info['depend_tasks_dir'] = depend_tasks_dir
  383. task_info['status'] = 'todo' # 任务状态,todo: 未开始,doing: 进行中,success: 已完成,fail: 失败
  384. task_name_to_index[task_name] = i
  385. # 为主任务也添加task_dir字段
  386. if parsed_result["main_task"]:
  387. main_task_name = parsed_result["main_task"].get("name", "main_task")
  388. parsed_result["main_task"]["task_dir"] = sanitize_filename(main_task_name)
  389. return parsed_result, task_name_to_index
  390. def get_task_dir(task_name, task_dict, append_index=True):
  391. max_index = task_dict.get('max_index', 1)
  392. if task_name in task_dict['tasks']:
  393. return task_dict['tasks'][task_name]
  394. max_index_str = f"{max_index:02d}"
  395. task_dir_raw = sanitize_filename(task_name)
  396. if append_index:
  397. task_dir = f"{max_index_str}_{task_dir_raw}"
  398. else:
  399. task_dir = task_dir_raw
  400. task_dict['tasks'][task_name] = task_dir
  401. task_dict['max_index'] = max_index + 1
  402. return task_dir
  403. def sanitize_filename(task_name: str, max_length: int = 20) -> str:
  404. """
  405. 将任务名称转换为适合作为文件夹名称的字符串
  406. 参数:
  407. task_name: 需要转换的任务名称
  408. max_length: 文件名最大长度限制,默认80个字符
  409. 返回:
  410. 处理后适合作为文件名/文件夹名的字符串
  411. """
  412. # 替换Windows和Unix系统中不允许的文件名字符
  413. # 替换 / \ : * ? " < > | 等字符为下划线
  414. sanitized = re.sub(r'[\\/*?:"<>|]', '_', task_name)
  415. # 替换连续的空白字符为单个下划线
  416. sanitized = re.sub(r'\s+', '_', sanitized)
  417. # 移除开头和结尾的点和空格
  418. sanitized = sanitized.strip('. ')
  419. # 如果名称过长,截断它
  420. if len(sanitized) > max_length:
  421. # 保留前面的部分和后面的部分,中间用...连接
  422. half_length = (max_length - 3) // 2
  423. sanitized = sanitized[:half_length] + '...' + sanitized[-half_length:]
  424. # 确保名称不为空
  425. if not sanitized:
  426. sanitized = "unnamed_task"
  427. return sanitized
  428. def write_json(data, file_path: str) -> None:
  429. """
  430. 将数据写入JSON文件
  431. 参数:
  432. data: 要写入的数据对象
  433. file_path: 目标文件路径
  434. 返回:
  435. """
  436. import json
  437. with open(file_path, 'w', encoding='utf-8') as f:
  438. json.dump(data, f, ensure_ascii=False, indent=2)
  439. def write_string_to_file(content: str, file_path: str) -> None:
  440. """
  441. 将字符串内容写入文件
  442. 参数:
  443. content: 要写入的字符串内容
  444. file_path: 目标文件路径
  445. 返回:
  446. """
  447. with open(file_path, 'w', encoding='utf-8') as f:
  448. f.write(content)
  449. def pretty_process(result):
  450. def format_output(in_str):
  451. return in_str.replace('\n\n', '\n').replace('\\"', '"')
  452. process_list = []
  453. i = 0
  454. call_dict = {}
  455. # 首先收集所有工具调用输出
  456. for row in result:
  457. if isinstance(row, list):
  458. # 处理列表:递归处理列表中的每个项目
  459. for item in row:
  460. if isinstance(item, dict) and item.get('type', '') == 'function_call_output':
  461. call_id = item['call_id']
  462. call_dict[call_id] = item['output']
  463. elif isinstance(row, dict) and row.get('type', '') == 'function_call_output':
  464. call_id = row['call_id']
  465. call_dict[call_id] = row['output']
  466. # 然后处理每一行
  467. for row in result:
  468. if isinstance(row, list):
  469. # 递归处理列表中的每个项目
  470. for item in row:
  471. if isinstance(item, dict):
  472. process_row(item, process_list, call_dict, i)
  473. i += 1
  474. else:
  475. # 直接处理字典项
  476. process_row(row, process_list, call_dict, i)
  477. i += 1
  478. process_str = '\n'.join(process_list)
  479. return process_str
  480. def process_row(row, process_list, call_dict, i):
  481. """处理单个行项目,添加到处理列表中"""
  482. def format_output(in_str):
  483. return in_str.replace('\n\n', '\n').replace('\\"', '"')
  484. if not isinstance(row, dict):
  485. return
  486. action = ''
  487. out = ''
  488. call_id = ''
  489. role_ = row.get('role', '')
  490. type_ = row.get('type', '')
  491. if type_ == 'function_call':
  492. action = f'工具调用-{row.get("name")}'
  493. out = row['arguments']
  494. call_id = row['call_id']
  495. elif type_ == 'function_call_output':
  496. return # 跳过函数调用输出,它们已经被收集到call_dict中
  497. elif role_ in ('user', 'assistant'):
  498. action = role_
  499. if isinstance(row['content'], str):
  500. out = row['content']
  501. else:
  502. content_text = ""
  503. for this_c in row['content']:
  504. if isinstance(this_c, dict) and 'text' in this_c:
  505. content_text += this_c['text']
  506. out = content_text
  507. process_list.append('\n\n' + f'{i+1}. ' + '## ' + action + ' ' * 4 + '-' * 32 + '\n')
  508. process_list.append(format_output(str(out)))
  509. # 如果存在对应的工具输出,添加它
  510. if call_id and call_id in call_dict:
  511. process_list.append('\n\n' + f'{i+2}. ' + '## ' + '工具输出' + ' ' * 4 + '-' * 32 + '\n')
  512. process_list.append(format_output(call_dict[call_id]))