# edit.py (16 KB)
  1. """
  2. Edit Tool - 文件编辑工具
  3. 参考 OpenCode 的 edit.ts 完整实现。
  4. 核心功能:
  5. - 精确字符串替换
  6. - 9 种智能匹配策略(按优先级依次尝试)
  7. - 生成 diff 预览
  8. """
  9. from pathlib import Path
  10. from typing import Optional, Generator
  11. import difflib
  12. import re
  13. from agent.tools import tool, ToolResult, ToolContext
  14. @tool(description="编辑文件,使用精确字符串替换。支持多种智能匹配策略。")
  15. async def edit_file(
  16. file_path: str,
  17. old_string: str,
  18. new_string: str,
  19. replace_all: bool = False,
  20. uid: str = "",
  21. context: Optional[ToolContext] = None
  22. ) -> ToolResult:
  23. """
  24. 编辑文件内容
  25. 使用 9 种智能匹配策略,按优先级尝试:
  26. 1. SimpleReplacer - 精确匹配
  27. 2. LineTrimmedReplacer - 忽略行首尾空白
  28. 3. BlockAnchorReplacer - 基于首尾锚点的块匹配(使用 Levenshtein 相似度)
  29. 4. WhitespaceNormalizedReplacer - 空白归一化
  30. 5. IndentationFlexibleReplacer - 灵活缩进匹配
  31. 6. EscapeNormalizedReplacer - 转义序列归一化
  32. 7. TrimmedBoundaryReplacer - 边界空白裁剪
  33. 8. ContextAwareReplacer - 上下文感知匹配
  34. 9. MultiOccurrenceReplacer - 多次出现匹配
  35. Args:
  36. file_path: 文件路径
  37. old_string: 要替换的文本
  38. new_string: 替换后的文本
  39. replace_all: 是否替换所有匹配(默认 False,只替换唯一匹配)
  40. uid: 用户 ID
  41. context: 工具上下文
  42. Returns:
  43. ToolResult: 编辑结果和 diff
  44. """
  45. if old_string == new_string:
  46. return ToolResult(
  47. title="无需编辑",
  48. output="old_string 和 new_string 相同",
  49. error="Strings are identical"
  50. )
  51. # 解析路径
  52. path = Path(file_path)
  53. if not path.is_absolute():
  54. path = Path.cwd() / path
  55. # 检查文件
  56. if not path.exists():
  57. return ToolResult(
  58. title="文件未找到",
  59. output=f"文件不存在: {file_path}",
  60. error="File not found"
  61. )
  62. if path.is_dir():
  63. return ToolResult(
  64. title="路径错误",
  65. output=f"路径是目录,不是文件: {file_path}",
  66. error="Path is a directory"
  67. )
  68. # 读取文件
  69. try:
  70. with open(path, 'r', encoding='utf-8') as f:
  71. content_old = f.read()
  72. except Exception as e:
  73. return ToolResult(
  74. title="读取失败",
  75. output=f"无法读取文件: {str(e)}",
  76. error=str(e)
  77. )
  78. # 执行替换
  79. try:
  80. content_new = replace(content_old, old_string, new_string, replace_all)
  81. except ValueError as e:
  82. return ToolResult(
  83. title="替换失败",
  84. output=str(e),
  85. error=str(e)
  86. )
  87. # 生成 diff
  88. diff = _create_diff(file_path, content_old, content_new)
  89. # 写入文件
  90. try:
  91. with open(path, 'w', encoding='utf-8') as f:
  92. f.write(content_new)
  93. except Exception as e:
  94. return ToolResult(
  95. title="写入失败",
  96. output=f"无法写入文件: {str(e)}",
  97. error=str(e)
  98. )
  99. # 统计变更
  100. old_lines = content_old.count('\n')
  101. new_lines = content_new.count('\n')
  102. return ToolResult(
  103. title=path.name,
  104. output=f"编辑成功\n\n{diff}",
  105. metadata={
  106. "diff": diff,
  107. "old_lines": old_lines,
  108. "new_lines": new_lines,
  109. "additions": max(0, new_lines - old_lines),
  110. "deletions": max(0, old_lines - new_lines)
  111. },
  112. long_term_memory=f"已编辑文件 {path.name}"
  113. )
  114. # ============================================================================
  115. # 替换策略(Replacers)
  116. # ============================================================================
  117. def replace(content: str, old_string: str, new_string: str, replace_all: bool = False) -> str:
  118. """
  119. 使用多种策略尝试替换
  120. 按优先级尝试所有策略,直到找到匹配
  121. """
  122. if old_string == new_string:
  123. raise ValueError("old_string 和 new_string 必须不同")
  124. not_found = True
  125. # 按优先级尝试策略
  126. for replacer in [
  127. simple_replacer,
  128. line_trimmed_replacer,
  129. block_anchor_replacer,
  130. whitespace_normalized_replacer,
  131. indentation_flexible_replacer,
  132. escape_normalized_replacer,
  133. trimmed_boundary_replacer,
  134. context_aware_replacer,
  135. multi_occurrence_replacer,
  136. ]:
  137. for search in replacer(content, old_string):
  138. index = content.find(search)
  139. if index == -1:
  140. continue
  141. not_found = False
  142. if replace_all:
  143. return content.replace(search, new_string)
  144. # 检查唯一性
  145. last_index = content.rfind(search)
  146. if index != last_index:
  147. continue
  148. return content[:index] + new_string + content[index + len(search):]
  149. if not_found:
  150. raise ValueError("在文件中未找到匹配的文本")
  151. raise ValueError(
  152. "找到多个匹配。请在 old_string 中提供更多上下文以唯一标识,"
  153. "或使用 replace_all=True 替换所有匹配。"
  154. )
  155. # 1. SimpleReplacer - 精确匹配
  156. def simple_replacer(content: str, find: str) -> Generator[str, None, None]:
  157. """精确匹配"""
  158. yield find
  159. # 2. LineTrimmedReplacer - 忽略行首尾空白
  160. def line_trimmed_replacer(content: str, find: str) -> Generator[str, None, None]:
  161. """忽略行首尾空白进行匹配"""
  162. content_lines = content.split('\n')
  163. find_lines = find.rstrip('\n').split('\n')
  164. for i in range(len(content_lines) - len(find_lines) + 1):
  165. # 检查所有行是否匹配(忽略首尾空白)
  166. matches = all(
  167. content_lines[i + j].strip() == find_lines[j].strip()
  168. for j in range(len(find_lines))
  169. )
  170. if matches:
  171. # 计算原始文本位置
  172. match_start = sum(len(content_lines[k]) + 1 for k in range(i))
  173. match_end = match_start + sum(
  174. len(content_lines[i + k]) + (1 if k < len(find_lines) - 1 else 0)
  175. for k in range(len(find_lines))
  176. )
  177. yield content[match_start:match_end]
  178. # 3. BlockAnchorReplacer - 基于首尾锚点的块匹配
  179. def block_anchor_replacer(content: str, find: str) -> Generator[str, None, None]:
  180. """
  181. 基于首尾行作为锚点进行块匹配
  182. 使用 Levenshtein 距离计算中间行的相似度
  183. """
  184. content_lines = content.split('\n')
  185. find_lines = find.rstrip('\n').split('\n')
  186. if len(find_lines) < 3:
  187. return
  188. first_line_find = find_lines[0].strip()
  189. last_line_find = find_lines[-1].strip()
  190. find_block_size = len(find_lines)
  191. # 收集所有候选位置(首尾行都匹配)
  192. candidates = []
  193. for i in range(len(content_lines)):
  194. if content_lines[i].strip() != first_line_find:
  195. continue
  196. # 查找匹配的尾行
  197. for j in range(i + 2, len(content_lines)):
  198. if content_lines[j].strip() == last_line_find:
  199. candidates.append((i, j))
  200. break
  201. if not candidates:
  202. return
  203. # 单个候选:使用宽松阈值
  204. if len(candidates) == 1:
  205. start_line, end_line = candidates[0]
  206. actual_block_size = end_line - start_line + 1
  207. similarity = _calculate_block_similarity(
  208. content_lines[start_line:end_line + 1],
  209. find_lines
  210. )
  211. if similarity >= 0.0: # SINGLE_CANDIDATE_SIMILARITY_THRESHOLD
  212. match_start = sum(len(content_lines[k]) + 1 for k in range(start_line))
  213. match_end = match_start + sum(
  214. len(content_lines[k]) + (1 if k < end_line else 0)
  215. for k in range(start_line, end_line + 1)
  216. )
  217. yield content[match_start:match_end]
  218. return
  219. # 多个候选:选择相似度最高的
  220. best_match = None
  221. max_similarity = -1
  222. for start_line, end_line in candidates:
  223. similarity = _calculate_block_similarity(
  224. content_lines[start_line:end_line + 1],
  225. find_lines
  226. )
  227. if similarity > max_similarity:
  228. max_similarity = similarity
  229. best_match = (start_line, end_line)
  230. if max_similarity >= 0.3 and best_match: # MULTIPLE_CANDIDATES_SIMILARITY_THRESHOLD
  231. start_line, end_line = best_match
  232. match_start = sum(len(content_lines[k]) + 1 for k in range(start_line))
  233. match_end = match_start + sum(
  234. len(content_lines[k]) + (1 if k < end_line else 0)
  235. for k in range(start_line, end_line + 1)
  236. )
  237. yield content[match_start:match_end]
  238. def _calculate_block_similarity(content_block: list, find_block: list) -> float:
  239. """计算块相似度(使用 Levenshtein 距离)"""
  240. actual_size = len(content_block)
  241. find_size = len(find_block)
  242. lines_to_check = min(find_size - 2, actual_size - 2)
  243. if lines_to_check <= 0:
  244. return 1.0
  245. similarity = 0.0
  246. for j in range(1, min(find_size - 1, actual_size - 1)):
  247. content_line = content_block[j].strip()
  248. find_line = find_block[j].strip()
  249. max_len = max(len(content_line), len(find_line))
  250. if max_len == 0:
  251. continue
  252. distance = _levenshtein(content_line, find_line)
  253. similarity += (1 - distance / max_len) / lines_to_check
  254. return similarity
  255. def _levenshtein(a: str, b: str) -> int:
  256. """Levenshtein 距离算法"""
  257. if not a:
  258. return len(b)
  259. if not b:
  260. return len(a)
  261. matrix = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
  262. for i in range(len(a) + 1):
  263. matrix[i][0] = i
  264. for j in range(len(b) + 1):
  265. matrix[0][j] = j
  266. for i in range(1, len(a) + 1):
  267. for j in range(1, len(b) + 1):
  268. cost = 0 if a[i - 1] == b[j - 1] else 1
  269. matrix[i][j] = min(
  270. matrix[i - 1][j] + 1, # 删除
  271. matrix[i][j - 1] + 1, # 插入
  272. matrix[i - 1][j - 1] + cost # 替换
  273. )
  274. return matrix[len(a)][len(b)]
  275. # 4. WhitespaceNormalizedReplacer - 空白归一化
  276. def whitespace_normalized_replacer(content: str, find: str) -> Generator[str, None, None]:
  277. """空白归一化匹配(所有空白序列归一化为单个空格)"""
  278. def normalize_ws(text: str) -> str:
  279. return re.sub(r'\s+', ' ', text).strip()
  280. normalized_find = normalize_ws(find)
  281. lines = content.split('\n')
  282. # 单行匹配
  283. for line in lines:
  284. if normalize_ws(line) == normalized_find:
  285. yield line
  286. continue
  287. # 子串匹配
  288. if normalized_find in normalize_ws(line):
  289. words = find.strip().split()
  290. if words:
  291. pattern = r'\s+'.join(re.escape(word) for word in words)
  292. match = re.search(pattern, line)
  293. if match:
  294. yield match.group(0)
  295. # 多行匹配
  296. find_lines = find.split('\n')
  297. if len(find_lines) > 1:
  298. for i in range(len(lines) - len(find_lines) + 1):
  299. block = lines[i:i + len(find_lines)]
  300. if normalize_ws('\n'.join(block)) == normalized_find:
  301. yield '\n'.join(block)
  302. # 5. IndentationFlexibleReplacer - 灵活缩进匹配
  303. def indentation_flexible_replacer(content: str, find: str) -> Generator[str, None, None]:
  304. """移除缩进后匹配"""
  305. def remove_indentation(text: str) -> str:
  306. lines = text.split('\n')
  307. non_empty = [line for line in lines if line.strip()]
  308. if not non_empty:
  309. return text
  310. min_indent = min(len(line) - len(line.lstrip()) for line in non_empty)
  311. return '\n'.join(
  312. line[min_indent:] if line.strip() else line
  313. for line in lines
  314. )
  315. normalized_find = remove_indentation(find)
  316. content_lines = content.split('\n')
  317. find_lines = find.split('\n')
  318. for i in range(len(content_lines) - len(find_lines) + 1):
  319. block = '\n'.join(content_lines[i:i + len(find_lines)])
  320. if remove_indentation(block) == normalized_find:
  321. yield block
  322. # 6. EscapeNormalizedReplacer - 转义序列归一化
  323. def escape_normalized_replacer(content: str, find: str) -> Generator[str, None, None]:
  324. """反转义后匹配"""
  325. def unescape_string(s: str) -> str:
  326. replacements = {
  327. r'\n': '\n',
  328. r'\t': '\t',
  329. r'\r': '\r',
  330. r"\'": "'",
  331. r'\"': '"',
  332. r'\`': '`',
  333. r'\\': '\\',
  334. r'\$': '$',
  335. }
  336. result = s
  337. for escaped, unescaped in replacements.items():
  338. result = result.replace(escaped, unescaped)
  339. return result
  340. unescaped_find = unescape_string(find)
  341. # 直接匹配
  342. if unescaped_find in content:
  343. yield unescaped_find
  344. # 尝试反转义后匹配
  345. lines = content.split('\n')
  346. find_lines = unescaped_find.split('\n')
  347. for i in range(len(lines) - len(find_lines) + 1):
  348. block = '\n'.join(lines[i:i + len(find_lines)])
  349. if unescape_string(block) == unescaped_find:
  350. yield block
  351. # 7. TrimmedBoundaryReplacer - 边界空白裁剪
  352. def trimmed_boundary_replacer(content: str, find: str) -> Generator[str, None, None]:
  353. """裁剪边界空白后匹配"""
  354. trimmed_find = find.strip()
  355. if trimmed_find == find:
  356. return # 已经是 trimmed,无需尝试
  357. # 尝试匹配 trimmed 版本
  358. if trimmed_find in content:
  359. yield trimmed_find
  360. # 尝试块匹配
  361. lines = content.split('\n')
  362. find_lines = find.split('\n')
  363. for i in range(len(lines) - len(find_lines) + 1):
  364. block = '\n'.join(lines[i:i + len(find_lines)])
  365. if block.strip() == trimmed_find:
  366. yield block
# 8. ContextAwareReplacer - context-aware matching
def context_aware_replacer(content: str, find: str) -> Generator[str, None, None]:
    """Strategy 8: match a block by its first and last lines, tolerating
    differences in the interior lines.

    A candidate block must have the same line count as the needle, and at
    least 50% of its non-empty interior lines must match exactly
    (whitespace-trimmed). Yields the matching block from the original text.
    """
    find_lines = find.split('\n')
    # Drop a trailing empty line produced by a needle ending in '\n'.
    if find_lines and find_lines[-1] == '':
        find_lines.pop()
    # Need first line + at least one interior line + last line.
    if len(find_lines) < 3:
        return
    content_lines = content.split('\n')
    first_line = find_lines[0].strip()
    last_line = find_lines[-1].strip()
    # Scan for blocks whose boundary lines match the anchors.
    for i in range(len(content_lines)):
        if content_lines[i].strip() != first_line:
            continue
        for j in range(i + 2, len(content_lines)):
            if content_lines[j].strip() == last_line:
                block_lines = content_lines[i:j + 1]
                # Only consider blocks with the exact expected line count.
                if len(block_lines) == len(find_lines):
                    # Count how many interior lines match (trimmed);
                    # all-empty pairs are excluded from the ratio.
                    matching_lines = 0
                    total_non_empty = 0
                    for k in range(1, len(block_lines) - 1):
                        block_line = block_lines[k].strip()
                        find_line = find_lines[k].strip()
                        if block_line or find_line:
                            total_non_empty += 1
                            if block_line == find_line:
                                matching_lines += 1
                    # Accept when at least 50% of interior lines match.
                    if total_non_empty == 0 or matching_lines / total_non_empty >= 0.5:
                        yield '\n'.join(block_lines)
                    break
                # NOTE(review): both breaks stop at the FIRST last-line
                # occurrence for this anchor — later occurrences are never
                # tried; presumably intentional (mirrors the anchor scan in
                # block_anchor_replacer). Confirm against upstream edit.ts.
                break
  402. # 9. MultiOccurrenceReplacer - 多次出现匹配
  403. def multi_occurrence_replacer(content: str, find: str) -> Generator[str, None, None]:
  404. """yield 所有精确匹配,用于 replace_all"""
  405. start_index = 0
  406. while True:
  407. index = content.find(find, start_index)
  408. if index == -1:
  409. break
  410. yield find
  411. start_index = index + len(find)
  412. # ============================================================================
  413. # 辅助函数
  414. # ============================================================================
  415. def _create_diff(filepath: str, old_content: str, new_content: str) -> str:
  416. """生成 unified diff"""
  417. old_lines = old_content.splitlines(keepends=True)
  418. new_lines = new_content.splitlines(keepends=True)
  419. diff_lines = list(difflib.unified_diff(
  420. old_lines,
  421. new_lines,
  422. fromfile=f"a/{filepath}",
  423. tofile=f"b/{filepath}",
  424. lineterm=''
  425. ))
  426. if not diff_lines:
  427. return "(无变更)"
  428. return ''.join(diff_lines)