edit.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. """
  2. Edit Tool - 文件编辑工具
  3. 参考 OpenCode 的 edit.ts 完整实现。
  4. 核心功能:
  5. - 精确字符串替换
  6. - 9 种智能匹配策略(按优先级依次尝试)
  7. - 生成 diff 预览
  8. """
  9. from pathlib import Path
  10. from typing import Optional, Generator
  11. import difflib
  12. import re
  13. from agent.tools import tool, ToolResult, ToolContext
  14. @tool(description="编辑文件,使用精确字符串替换。支持多种智能匹配策略。", hidden_params=["context"])
  15. async def edit_file(
  16. file_path: str,
  17. old_string: str,
  18. new_string: str,
  19. replace_all: bool = False,
  20. context: Optional[ToolContext] = None
  21. ) -> ToolResult:
  22. """
  23. 编辑文件内容
  24. 使用 9 种智能匹配策略,按优先级尝试:
  25. 1. SimpleReplacer - 精确匹配
  26. 2. LineTrimmedReplacer - 忽略行首尾空白
  27. 3. BlockAnchorReplacer - 基于首尾锚点的块匹配(使用 Levenshtein 相似度)
  28. 4. WhitespaceNormalizedReplacer - 空白归一化
  29. 5. IndentationFlexibleReplacer - 灵活缩进匹配
  30. 6. EscapeNormalizedReplacer - 转义序列归一化
  31. 7. TrimmedBoundaryReplacer - 边界空白裁剪
  32. 8. ContextAwareReplacer - 上下文感知匹配
  33. 9. MultiOccurrenceReplacer - 多次出现匹配
  34. Args:
  35. file_path: 文件路径
  36. old_string: 要替换的文本
  37. new_string: 替换后的文本
  38. replace_all: 是否替换所有匹配(默认 False,只替换唯一匹配)
  39. context: 工具上下文
  40. Returns:
  41. ToolResult: 编辑结果和 diff
  42. """
  43. if old_string == new_string:
  44. return ToolResult(
  45. title="无需编辑",
  46. output="old_string 和 new_string 相同",
  47. error="Strings are identical"
  48. )
  49. # 解析路径
  50. path = Path(file_path)
  51. if not path.is_absolute():
  52. path = Path.cwd() / path
  53. # 检查文件
  54. if not path.exists():
  55. return ToolResult(
  56. title="文件未找到",
  57. output=f"文件不存在: {file_path}",
  58. error="File not found"
  59. )
  60. if path.is_dir():
  61. return ToolResult(
  62. title="路径错误",
  63. output=f"路径是目录,不是文件: {file_path}",
  64. error="Path is a directory"
  65. )
  66. # 读取文件
  67. try:
  68. with open(path, 'r', encoding='utf-8') as f:
  69. content_old = f.read()
  70. except Exception as e:
  71. return ToolResult(
  72. title="读取失败",
  73. output=f"无法读取文件: {str(e)}",
  74. error=str(e)
  75. )
  76. # 执行替换
  77. try:
  78. content_new = replace(content_old, old_string, new_string, replace_all)
  79. except ValueError as e:
  80. return ToolResult(
  81. title="替换失败",
  82. output=str(e),
  83. error=str(e)
  84. )
  85. # 生成 diff
  86. diff = _create_diff(file_path, content_old, content_new)
  87. # 写入文件
  88. try:
  89. with open(path, 'w', encoding='utf-8') as f:
  90. f.write(content_new)
  91. except Exception as e:
  92. return ToolResult(
  93. title="写入失败",
  94. output=f"无法写入文件: {str(e)}",
  95. error=str(e)
  96. )
  97. # 统计变更
  98. old_lines = content_old.count('\n')
  99. new_lines = content_new.count('\n')
  100. return ToolResult(
  101. title=path.name,
  102. output=f"编辑成功\n\n{diff}",
  103. metadata={
  104. "diff": diff,
  105. "old_lines": old_lines,
  106. "new_lines": new_lines,
  107. "additions": max(0, new_lines - old_lines),
  108. "deletions": max(0, old_lines - new_lines)
  109. },
  110. long_term_memory=f"已编辑文件 {path.name}"
  111. )
  112. # ============================================================================
  113. # 替换策略(Replacers)
  114. # ============================================================================
  115. def replace(content: str, old_string: str, new_string: str, replace_all: bool = False) -> str:
  116. """
  117. 使用多种策略尝试替换
  118. 按优先级尝试所有策略,直到找到匹配
  119. """
  120. if old_string == new_string:
  121. raise ValueError("old_string 和 new_string 必须不同")
  122. not_found = True
  123. # 按优先级尝试策略
  124. for replacer in [
  125. simple_replacer,
  126. line_trimmed_replacer,
  127. block_anchor_replacer,
  128. whitespace_normalized_replacer,
  129. indentation_flexible_replacer,
  130. escape_normalized_replacer,
  131. trimmed_boundary_replacer,
  132. context_aware_replacer,
  133. multi_occurrence_replacer,
  134. ]:
  135. for search in replacer(content, old_string):
  136. index = content.find(search)
  137. if index == -1:
  138. continue
  139. not_found = False
  140. if replace_all:
  141. return content.replace(search, new_string)
  142. # 检查唯一性
  143. last_index = content.rfind(search)
  144. if index != last_index:
  145. continue
  146. return content[:index] + new_string + content[index + len(search):]
  147. if not_found:
  148. raise ValueError("在文件中未找到匹配的文本")
  149. raise ValueError(
  150. "找到多个匹配。请在 old_string 中提供更多上下文以唯一标识,"
  151. "或使用 replace_all=True 替换所有匹配。"
  152. )
  153. # 1. SimpleReplacer - 精确匹配
  154. def simple_replacer(content: str, find: str) -> Generator[str, None, None]:
  155. """精确匹配"""
  156. yield find
  157. # 2. LineTrimmedReplacer - 忽略行首尾空白
  158. def line_trimmed_replacer(content: str, find: str) -> Generator[str, None, None]:
  159. """忽略行首尾空白进行匹配"""
  160. content_lines = content.split('\n')
  161. find_lines = find.rstrip('\n').split('\n')
  162. for i in range(len(content_lines) - len(find_lines) + 1):
  163. # 检查所有行是否匹配(忽略首尾空白)
  164. matches = all(
  165. content_lines[i + j].strip() == find_lines[j].strip()
  166. for j in range(len(find_lines))
  167. )
  168. if matches:
  169. # 计算原始文本位置
  170. match_start = sum(len(content_lines[k]) + 1 for k in range(i))
  171. match_end = match_start + sum(
  172. len(content_lines[i + k]) + (1 if k < len(find_lines) - 1 else 0)
  173. for k in range(len(find_lines))
  174. )
  175. yield content[match_start:match_end]
  176. # 3. BlockAnchorReplacer - 基于首尾锚点的块匹配
def block_anchor_replacer(content: str, find: str) -> Generator[str, None, None]:
    """Strategy 3: anchor-based block match.

    The first and last lines of `find` serve as anchors; candidate blocks in
    `content` whose anchors match (after stripping) are scored with a
    Levenshtein-based similarity over their interior lines.
    """
    content_lines = content.split('\n')
    find_lines = find.rstrip('\n').split('\n')
    # Anchoring needs at least: opening line + one interior line + closing line.
    if len(find_lines) < 3:
        return
    first_line_find = find_lines[0].strip()
    last_line_find = find_lines[-1].strip()
    find_block_size = len(find_lines)  # NOTE(review): currently unused
    # Collect candidate (start, end) line ranges whose anchors both match.
    candidates = []
    for i in range(len(content_lines)):
        if content_lines[i].strip() != first_line_find:
            continue
        # Earliest closing anchor at least two lines below the opening one.
        for j in range(i + 2, len(content_lines)):
            if content_lines[j].strip() == last_line_find:
                candidates.append((i, j))
                break
    if not candidates:
        return
    # Single candidate: accepted unconditionally (threshold is 0.0).
    if len(candidates) == 1:
        start_line, end_line = candidates[0]
        actual_block_size = end_line - start_line + 1  # NOTE(review): currently unused
        similarity = _calculate_block_similarity(
            content_lines[start_line:end_line + 1],
            find_lines
        )
        if similarity >= 0.0: # SINGLE_CANDIDATE_SIMILARITY_THRESHOLD
            # Character offset of the block start (one '\n' per preceding line).
            match_start = sum(len(content_lines[k]) + 1 for k in range(start_line))
            # '+1' newline for every block line except the final one.
            match_end = match_start + sum(
                len(content_lines[k]) + (1 if k < end_line else 0)
                for k in range(start_line, end_line + 1)
            )
            yield content[match_start:match_end]
        return
    # Multiple candidates: pick the highest-scoring block.
    best_match = None
    max_similarity = -1
    for start_line, end_line in candidates:
        similarity = _calculate_block_similarity(
            content_lines[start_line:end_line + 1],
            find_lines
        )
        if similarity > max_similarity:
            max_similarity = similarity
            best_match = (start_line, end_line)
    # Require at least 30% interior similarity when ambiguous.
    if max_similarity >= 0.3 and best_match: # MULTIPLE_CANDIDATES_SIMILARITY_THRESHOLD
        start_line, end_line = best_match
        match_start = sum(len(content_lines[k]) + 1 for k in range(start_line))
        match_end = match_start + sum(
            len(content_lines[k]) + (1 if k < end_line else 0)
            for k in range(start_line, end_line + 1)
        )
        yield content[match_start:match_end]
  236. def _calculate_block_similarity(content_block: list, find_block: list) -> float:
  237. """计算块相似度(使用 Levenshtein 距离)"""
  238. actual_size = len(content_block)
  239. find_size = len(find_block)
  240. lines_to_check = min(find_size - 2, actual_size - 2)
  241. if lines_to_check <= 0:
  242. return 1.0
  243. similarity = 0.0
  244. for j in range(1, min(find_size - 1, actual_size - 1)):
  245. content_line = content_block[j].strip()
  246. find_line = find_block[j].strip()
  247. max_len = max(len(content_line), len(find_line))
  248. if max_len == 0:
  249. continue
  250. distance = _levenshtein(content_line, find_line)
  251. similarity += (1 - distance / max_len) / lines_to_check
  252. return similarity
  253. def _levenshtein(a: str, b: str) -> int:
  254. """Levenshtein 距离算法"""
  255. if not a:
  256. return len(b)
  257. if not b:
  258. return len(a)
  259. matrix = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
  260. for i in range(len(a) + 1):
  261. matrix[i][0] = i
  262. for j in range(len(b) + 1):
  263. matrix[0][j] = j
  264. for i in range(1, len(a) + 1):
  265. for j in range(1, len(b) + 1):
  266. cost = 0 if a[i - 1] == b[j - 1] else 1
  267. matrix[i][j] = min(
  268. matrix[i - 1][j] + 1, # 删除
  269. matrix[i][j - 1] + 1, # 插入
  270. matrix[i - 1][j - 1] + cost # 替换
  271. )
  272. return matrix[len(a)][len(b)]
  273. # 4. WhitespaceNormalizedReplacer - 空白归一化
  274. def whitespace_normalized_replacer(content: str, find: str) -> Generator[str, None, None]:
  275. """空白归一化匹配(所有空白序列归一化为单个空格)"""
  276. def normalize_ws(text: str) -> str:
  277. return re.sub(r'\s+', ' ', text).strip()
  278. normalized_find = normalize_ws(find)
  279. lines = content.split('\n')
  280. # 单行匹配
  281. for line in lines:
  282. if normalize_ws(line) == normalized_find:
  283. yield line
  284. continue
  285. # 子串匹配
  286. if normalized_find in normalize_ws(line):
  287. words = find.strip().split()
  288. if words:
  289. pattern = r'\s+'.join(re.escape(word) for word in words)
  290. match = re.search(pattern, line)
  291. if match:
  292. yield match.group(0)
  293. # 多行匹配
  294. find_lines = find.split('\n')
  295. if len(find_lines) > 1:
  296. for i in range(len(lines) - len(find_lines) + 1):
  297. block = lines[i:i + len(find_lines)]
  298. if normalize_ws('\n'.join(block)) == normalized_find:
  299. yield '\n'.join(block)
  300. # 5. IndentationFlexibleReplacer - 灵活缩进匹配
  301. def indentation_flexible_replacer(content: str, find: str) -> Generator[str, None, None]:
  302. """移除缩进后匹配"""
  303. def remove_indentation(text: str) -> str:
  304. lines = text.split('\n')
  305. non_empty = [line for line in lines if line.strip()]
  306. if not non_empty:
  307. return text
  308. min_indent = min(len(line) - len(line.lstrip()) for line in non_empty)
  309. return '\n'.join(
  310. line[min_indent:] if line.strip() else line
  311. for line in lines
  312. )
  313. normalized_find = remove_indentation(find)
  314. content_lines = content.split('\n')
  315. find_lines = find.split('\n')
  316. for i in range(len(content_lines) - len(find_lines) + 1):
  317. block = '\n'.join(content_lines[i:i + len(find_lines)])
  318. if remove_indentation(block) == normalized_find:
  319. yield block
  320. # 6. EscapeNormalizedReplacer - 转义序列归一化
  321. def escape_normalized_replacer(content: str, find: str) -> Generator[str, None, None]:
  322. """反转义后匹配"""
  323. def unescape_string(s: str) -> str:
  324. replacements = {
  325. r'\n': '\n',
  326. r'\t': '\t',
  327. r'\r': '\r',
  328. r"\'": "'",
  329. r'\"': '"',
  330. r'\`': '`',
  331. r'\\': '\\',
  332. r'\$': '$',
  333. }
  334. result = s
  335. for escaped, unescaped in replacements.items():
  336. result = result.replace(escaped, unescaped)
  337. return result
  338. unescaped_find = unescape_string(find)
  339. # 直接匹配
  340. if unescaped_find in content:
  341. yield unescaped_find
  342. # 尝试反转义后匹配
  343. lines = content.split('\n')
  344. find_lines = unescaped_find.split('\n')
  345. for i in range(len(lines) - len(find_lines) + 1):
  346. block = '\n'.join(lines[i:i + len(find_lines)])
  347. if unescape_string(block) == unescaped_find:
  348. yield block
  349. # 7. TrimmedBoundaryReplacer - 边界空白裁剪
  350. def trimmed_boundary_replacer(content: str, find: str) -> Generator[str, None, None]:
  351. """裁剪边界空白后匹配"""
  352. trimmed_find = find.strip()
  353. if trimmed_find == find:
  354. return # 已经是 trimmed,无需尝试
  355. # 尝试匹配 trimmed 版本
  356. if trimmed_find in content:
  357. yield trimmed_find
  358. # 尝试块匹配
  359. lines = content.split('\n')
  360. find_lines = find.split('\n')
  361. for i in range(len(lines) - len(find_lines) + 1):
  362. block = '\n'.join(lines[i:i + len(find_lines)])
  363. if block.strip() == trimmed_find:
  364. yield block
  365. # 8. ContextAwareReplacer - 上下文感知匹配
def context_aware_replacer(content: str, find: str) -> Generator[str, None, None]:
    """Strategy 8: anchor on the first/last lines and tolerate interior drift.

    A candidate block must have the same line count as `find` and at least
    50% of its non-empty interior lines matching (after stripping).
    """
    find_lines = find.split('\n')
    # Drop a trailing empty element left by a trailing newline in `find`.
    if find_lines and find_lines[-1] == '':
        find_lines.pop()
    # Needs opening line + at least one interior line + closing line.
    if len(find_lines) < 3:
        return
    content_lines = content.split('\n')
    first_line = find_lines[0].strip()
    last_line = find_lines[-1].strip()
    # Find a block whose first and last lines match the anchors.
    for i in range(len(content_lines)):
        if content_lines[i].strip() != first_line:
            continue
        for j in range(i + 2, len(content_lines)):
            if content_lines[j].strip() == last_line:
                block_lines = content_lines[i:j + 1]
                # Only consider blocks of exactly the same length as `find`.
                if len(block_lines) == len(find_lines):
                    # Score the interior lines (anchors excluded).
                    matching_lines = 0
                    total_non_empty = 0
                    for k in range(1, len(block_lines) - 1):
                        block_line = block_lines[k].strip()
                        find_line = find_lines[k].strip()
                        if block_line or find_line:
                            total_non_empty += 1
                            if block_line == find_line:
                                matching_lines += 1
                    # Accept when at least half the interior lines agree.
                    if total_non_empty == 0 or matching_lines / total_non_empty >= 0.5:
                        yield '\n'.join(block_lines)
                break
        # NOTE: this break stops after the FIRST opening-anchor hit, whether
        # or not the candidate was accepted — later occurrences are not tried.
        break
  400. # 9. MultiOccurrenceReplacer - 多次出现匹配
  401. def multi_occurrence_replacer(content: str, find: str) -> Generator[str, None, None]:
  402. """yield 所有精确匹配,用于 replace_all"""
  403. start_index = 0
  404. while True:
  405. index = content.find(find, start_index)
  406. if index == -1:
  407. break
  408. yield find
  409. start_index = index + len(find)
  410. # ============================================================================
  411. # 辅助函数
  412. # ============================================================================
  413. def _create_diff(filepath: str, old_content: str, new_content: str) -> str:
  414. """生成 unified diff"""
  415. old_lines = old_content.splitlines(keepends=True)
  416. new_lines = new_content.splitlines(keepends=True)
  417. diff_lines = list(difflib.unified_diff(
  418. old_lines,
  419. new_lines,
  420. fromfile=f"a/{filepath}",
  421. tofile=f"b/{filepath}",
  422. lineterm=''
  423. ))
  424. if not diff_lines:
  425. return "(无变更)"
  426. return ''.join(diff_lines)