""" Edit Tool - 文件编辑工具 参考 OpenCode 的 edit.ts 完整实现。 核心功能: - 精确字符串替换 - 9 种智能匹配策略(按优先级依次尝试) - 生成 diff 预览 """ from pathlib import Path from typing import Optional, Generator import difflib import re from agent.tools import tool, ToolResult, ToolContext @tool(description="编辑文件,使用精确字符串替换。支持多种智能匹配策略。") async def edit_file( file_path: str, old_string: str, new_string: str, replace_all: bool = False, context: Optional[ToolContext] = None ) -> ToolResult: """ 编辑文件内容 使用 9 种智能匹配策略,按优先级尝试: 1. SimpleReplacer - 精确匹配 2. LineTrimmedReplacer - 忽略行首尾空白 3. BlockAnchorReplacer - 基于首尾锚点的块匹配(使用 Levenshtein 相似度) 4. WhitespaceNormalizedReplacer - 空白归一化 5. IndentationFlexibleReplacer - 灵活缩进匹配 6. EscapeNormalizedReplacer - 转义序列归一化 7. TrimmedBoundaryReplacer - 边界空白裁剪 8. ContextAwareReplacer - 上下文感知匹配 9. MultiOccurrenceReplacer - 多次出现匹配 Args: file_path: 文件路径 old_string: 要替换的文本 new_string: 替换后的文本 replace_all: 是否替换所有匹配(默认 False,只替换唯一匹配) context: 工具上下文 Returns: ToolResult: 编辑结果和 diff """ if old_string == new_string: return ToolResult( title="无需编辑", output="old_string 和 new_string 相同", error="Strings are identical" ) # 解析路径 path = Path(file_path) if not path.is_absolute(): path = Path.cwd() / path # 检查文件 if not path.exists(): return ToolResult( title="文件未找到", output=f"文件不存在: {file_path}", error="File not found" ) if path.is_dir(): return ToolResult( title="路径错误", output=f"路径是目录,不是文件: {file_path}", error="Path is a directory" ) # 读取文件 try: with open(path, 'r', encoding='utf-8') as f: content_old = f.read() except Exception as e: return ToolResult( title="读取失败", output=f"无法读取文件: {str(e)}", error=str(e) ) # 执行替换 try: content_new = replace(content_old, old_string, new_string, replace_all) except ValueError as e: return ToolResult( title="替换失败", output=str(e), error=str(e) ) # 生成 diff diff = _create_diff(file_path, content_old, content_new) # 写入文件 try: with open(path, 'w', encoding='utf-8') as f: f.write(content_new) except Exception as e: return ToolResult( title="写入失败", output=f"无法写入文件: {str(e)}", 
error=str(e) ) # 统计变更 old_lines = content_old.count('\n') new_lines = content_new.count('\n') return ToolResult( title=path.name, output=f"编辑成功\n\n{diff}", metadata={ "diff": diff, "old_lines": old_lines, "new_lines": new_lines, "additions": max(0, new_lines - old_lines), "deletions": max(0, old_lines - new_lines) }, long_term_memory=f"已编辑文件 {path.name}" ) # ============================================================================ # 替换策略(Replacers) # ============================================================================ def replace(content: str, old_string: str, new_string: str, replace_all: bool = False) -> str: """ 使用多种策略尝试替换 按优先级尝试所有策略,直到找到匹配 """ if old_string == new_string: raise ValueError("old_string 和 new_string 必须不同") not_found = True # 按优先级尝试策略 for replacer in [ simple_replacer, line_trimmed_replacer, block_anchor_replacer, whitespace_normalized_replacer, indentation_flexible_replacer, escape_normalized_replacer, trimmed_boundary_replacer, context_aware_replacer, multi_occurrence_replacer, ]: for search in replacer(content, old_string): index = content.find(search) if index == -1: continue not_found = False if replace_all: return content.replace(search, new_string) # 检查唯一性 last_index = content.rfind(search) if index != last_index: continue return content[:index] + new_string + content[index + len(search):] if not_found: raise ValueError("在文件中未找到匹配的文本") raise ValueError( "找到多个匹配。请在 old_string 中提供更多上下文以唯一标识," "或使用 replace_all=True 替换所有匹配。" ) # 1. SimpleReplacer - 精确匹配 def simple_replacer(content: str, find: str) -> Generator[str, None, None]: """精确匹配""" yield find # 2. 
LineTrimmedReplacer - 忽略行首尾空白 def line_trimmed_replacer(content: str, find: str) -> Generator[str, None, None]: """忽略行首尾空白进行匹配""" content_lines = content.split('\n') find_lines = find.rstrip('\n').split('\n') for i in range(len(content_lines) - len(find_lines) + 1): # 检查所有行是否匹配(忽略首尾空白) matches = all( content_lines[i + j].strip() == find_lines[j].strip() for j in range(len(find_lines)) ) if matches: # 计算原始文本位置 match_start = sum(len(content_lines[k]) + 1 for k in range(i)) match_end = match_start + sum( len(content_lines[i + k]) + (1 if k < len(find_lines) - 1 else 0) for k in range(len(find_lines)) ) yield content[match_start:match_end] # 3. BlockAnchorReplacer - 基于首尾锚点的块匹配 def block_anchor_replacer(content: str, find: str) -> Generator[str, None, None]: """ 基于首尾行作为锚点进行块匹配 使用 Levenshtein 距离计算中间行的相似度 """ content_lines = content.split('\n') find_lines = find.rstrip('\n').split('\n') if len(find_lines) < 3: return first_line_find = find_lines[0].strip() last_line_find = find_lines[-1].strip() find_block_size = len(find_lines) # 收集所有候选位置(首尾行都匹配) candidates = [] for i in range(len(content_lines)): if content_lines[i].strip() != first_line_find: continue # 查找匹配的尾行 for j in range(i + 2, len(content_lines)): if content_lines[j].strip() == last_line_find: candidates.append((i, j)) break if not candidates: return # 单个候选:使用宽松阈值 if len(candidates) == 1: start_line, end_line = candidates[0] actual_block_size = end_line - start_line + 1 similarity = _calculate_block_similarity( content_lines[start_line:end_line + 1], find_lines ) if similarity >= 0.0: # SINGLE_CANDIDATE_SIMILARITY_THRESHOLD match_start = sum(len(content_lines[k]) + 1 for k in range(start_line)) match_end = match_start + sum( len(content_lines[k]) + (1 if k < end_line else 0) for k in range(start_line, end_line + 1) ) yield content[match_start:match_end] return # 多个候选:选择相似度最高的 best_match = None max_similarity = -1 for start_line, end_line in candidates: similarity = _calculate_block_similarity( 
content_lines[start_line:end_line + 1], find_lines ) if similarity > max_similarity: max_similarity = similarity best_match = (start_line, end_line) if max_similarity >= 0.3 and best_match: # MULTIPLE_CANDIDATES_SIMILARITY_THRESHOLD start_line, end_line = best_match match_start = sum(len(content_lines[k]) + 1 for k in range(start_line)) match_end = match_start + sum( len(content_lines[k]) + (1 if k < end_line else 0) for k in range(start_line, end_line + 1) ) yield content[match_start:match_end] def _calculate_block_similarity(content_block: list, find_block: list) -> float: """计算块相似度(使用 Levenshtein 距离)""" actual_size = len(content_block) find_size = len(find_block) lines_to_check = min(find_size - 2, actual_size - 2) if lines_to_check <= 0: return 1.0 similarity = 0.0 for j in range(1, min(find_size - 1, actual_size - 1)): content_line = content_block[j].strip() find_line = find_block[j].strip() max_len = max(len(content_line), len(find_line)) if max_len == 0: continue distance = _levenshtein(content_line, find_line) similarity += (1 - distance / max_len) / lines_to_check return similarity def _levenshtein(a: str, b: str) -> int: """Levenshtein 距离算法""" if not a: return len(b) if not b: return len(a) matrix = [[0] * (len(b) + 1) for _ in range(len(a) + 1)] for i in range(len(a) + 1): matrix[i][0] = i for j in range(len(b) + 1): matrix[0][j] = j for i in range(1, len(a) + 1): for j in range(1, len(b) + 1): cost = 0 if a[i - 1] == b[j - 1] else 1 matrix[i][j] = min( matrix[i - 1][j] + 1, # 删除 matrix[i][j - 1] + 1, # 插入 matrix[i - 1][j - 1] + cost # 替换 ) return matrix[len(a)][len(b)] # 4. 
# 4. WhitespaceNormalizedReplacer - whitespace-normalized matching
def whitespace_normalized_replacer(content: str, find: str) -> Generator[str, None, None]:
    """Match after collapsing every whitespace run into a single space."""

    def collapse(text: str) -> str:
        return re.sub(r'\s+', ' ', text).strip()

    target = collapse(find)
    content_lines = content.split('\n')

    # Single-line candidates: whole-line match first, then substring match.
    for raw_line in content_lines:
        if collapse(raw_line) == target:
            yield raw_line
        elif target in collapse(raw_line):
            tokens = find.strip().split()
            if tokens:
                loose_pattern = r'\s+'.join(re.escape(tok) for tok in tokens)
                hit = re.search(loose_pattern, raw_line)
                if hit:
                    yield hit.group(0)

    # Multi-line candidates: slide a window of the same height.
    find_lines = find.split('\n')
    if len(find_lines) > 1:
        height = len(find_lines)
        for start in range(len(content_lines) - height + 1):
            window = content_lines[start:start + height]
            if collapse('\n'.join(window)) == target:
                yield '\n'.join(window)


# 5. IndentationFlexibleReplacer - indentation-insensitive matching
def indentation_flexible_replacer(content: str, find: str) -> Generator[str, None, None]:
    """Match after stripping the common leading indentation from both sides."""

    def dedent(text: str) -> str:
        all_lines = text.split('\n')
        indents = [len(ln) - len(ln.lstrip()) for ln in all_lines if ln.strip()]
        if not indents:
            # Nothing but blank lines: leave the text untouched.
            return text
        width = min(indents)
        return '\n'.join(
            ln[width:] if ln.strip() else ln
            for ln in all_lines
        )

    target = dedent(find)
    haystack = content.split('\n')
    height = len(find.split('\n'))

    for start in range(len(haystack) - height + 1):
        window = '\n'.join(haystack[start:start + height])
        if dedent(window) == target:
            yield window
EscapeNormalizedReplacer - 转义序列归一化 def escape_normalized_replacer(content: str, find: str) -> Generator[str, None, None]: """反转义后匹配""" def unescape_string(s: str) -> str: replacements = { r'\n': '\n', r'\t': '\t', r'\r': '\r', r"\'": "'", r'\"': '"', r'\`': '`', r'\\': '\\', r'\$': '$', } result = s for escaped, unescaped in replacements.items(): result = result.replace(escaped, unescaped) return result unescaped_find = unescape_string(find) # 直接匹配 if unescaped_find in content: yield unescaped_find # 尝试反转义后匹配 lines = content.split('\n') find_lines = unescaped_find.split('\n') for i in range(len(lines) - len(find_lines) + 1): block = '\n'.join(lines[i:i + len(find_lines)]) if unescape_string(block) == unescaped_find: yield block # 7. TrimmedBoundaryReplacer - 边界空白裁剪 def trimmed_boundary_replacer(content: str, find: str) -> Generator[str, None, None]: """裁剪边界空白后匹配""" trimmed_find = find.strip() if trimmed_find == find: return # 已经是 trimmed,无需尝试 # 尝试匹配 trimmed 版本 if trimmed_find in content: yield trimmed_find # 尝试块匹配 lines = content.split('\n') find_lines = find.split('\n') for i in range(len(lines) - len(find_lines) + 1): block = '\n'.join(lines[i:i + len(find_lines)]) if block.strip() == trimmed_find: yield block # 8. 
ContextAwareReplacer - 上下文感知匹配 def context_aware_replacer(content: str, find: str) -> Generator[str, None, None]: """基于上下文(首尾行)匹配,允许中间行有差异""" find_lines = find.split('\n') if find_lines and find_lines[-1] == '': find_lines.pop() if len(find_lines) < 3: return content_lines = content.split('\n') first_line = find_lines[0].strip() last_line = find_lines[-1].strip() # 查找首尾匹配的块 for i in range(len(content_lines)): if content_lines[i].strip() != first_line: continue for j in range(i + 2, len(content_lines)): if content_lines[j].strip() == last_line: block_lines = content_lines[i:j + 1] # 检查块长度是否匹配 if len(block_lines) == len(find_lines): # 计算中间行匹配率 matching_lines = 0 total_non_empty = 0 for k in range(1, len(block_lines) - 1): block_line = block_lines[k].strip() find_line = find_lines[k].strip() if block_line or find_line: total_non_empty += 1 if block_line == find_line: matching_lines += 1 # 至少 50% 的中间行匹配 if total_non_empty == 0 or matching_lines / total_non_empty >= 0.5: yield '\n'.join(block_lines) break break # 9. MultiOccurrenceReplacer - 多次出现匹配 def multi_occurrence_replacer(content: str, find: str) -> Generator[str, None, None]: """yield 所有精确匹配,用于 replace_all""" start_index = 0 while True: index = content.find(find, start_index) if index == -1: break yield find start_index = index + len(find) # ============================================================================ # 辅助函数 # ============================================================================ def _create_diff(filepath: str, old_content: str, new_content: str) -> str: """生成 unified diff""" old_lines = old_content.splitlines(keepends=True) new_lines = new_content.splitlines(keepends=True) diff_lines = list(difflib.unified_diff( old_lines, new_lines, fromfile=f"a/{filepath}", tofile=f"b/{filepath}", lineterm='' )) if not diff_lines: return "(无变更)" return ''.join(diff_lines)