|
|
@@ -6,7 +6,7 @@
|
|
|
|
|
|
import json
|
|
|
from pathlib import Path
|
|
|
-from typing import Dict, List, Optional
|
|
|
+from typing import Dict, List, Optional, Set
|
|
|
import re
|
|
|
import sys
|
|
|
|
|
|
@@ -15,6 +15,7 @@ project_root = Path(__file__).parent.parent.parent
|
|
|
sys.path.insert(0, str(project_root))
|
|
|
|
|
|
from script.detail import get_xiaohongshu_detail
|
|
|
+from script.data_processing.path_config import PathConfig
|
|
|
|
|
|
|
|
|
def extract_post_id_from_filename(filename: str) -> str:
|
|
|
@@ -228,14 +229,20 @@ def merge_results(all_results: List[Dict]) -> Dict:
|
|
|
return merged
|
|
|
|
|
|
|
|
|
-def convert_to_array_format(merged_dict: Dict, fetch_details: bool = True, time_filter: Optional[str] = None) -> Dict:
|
|
|
+def convert_to_array_format(
|
|
|
+ merged_dict: Dict,
|
|
|
+ fetch_details: bool = True,
|
|
|
+ time_filter: Optional[str] = None,
|
|
|
+ exclude_post_ids: Optional[Set[str]] = None
|
|
|
+) -> Dict:
|
|
|
"""
|
|
|
将字典格式转换为数组格式,并添加帖子详情
|
|
|
|
|
|
Args:
|
|
|
merged_dict: 字典格式的结果
|
|
|
fetch_details: 是否获取帖子详情,默认为True
|
|
|
- time_filter: 时间过滤阈值,只保留发布时间>=该时间的帖子,格式为 "YYYY-MM-DD HH:MM:SS"
|
|
|
+ time_filter: 时间过滤阈值,只保留发布时间<该时间的帖子,格式为 "YYYY-MM-DD HH:MM:SS"
|
|
|
+ exclude_post_ids: 要排除的帖子ID集合
|
|
|
|
|
|
Returns:
|
|
|
数组格式的结果
|
|
|
@@ -265,11 +272,23 @@ def convert_to_array_format(merged_dict: Dict, fetch_details: bool = True, time_
|
|
|
|
|
|
print(f"成功获取 {len(post_details)} 个帖子详情")
|
|
|
|
|
|
- # 如果启用时间过滤,过滤帖子(过滤掉发布时间晚于等于阈值的帖子,避免穿越)
|
|
|
- if time_filter:
|
|
|
+ # 应用过滤规则
|
|
|
+ filtered_count = 0
|
|
|
+
|
|
|
+ # 1. 如果启用帖子ID过滤
|
|
|
+ if exclude_post_ids:
|
|
|
+ print(f"\n正在应用帖子ID过滤,排除 {len(exclude_post_ids)} 个当前帖子...")
|
|
|
+ before_count = len(post_details)
|
|
|
+ post_details = {pid: detail for pid, detail in post_details.items() if pid not in exclude_post_ids}
|
|
|
+ filtered_count = before_count - len(post_details)
|
|
|
+ if filtered_count > 0:
|
|
|
+ print(f" ⚠️ 过滤掉 {filtered_count} 个当前帖子")
|
|
|
+ print(f"保留 {len(post_details)} 个帖子")
|
|
|
+
|
|
|
+ # 2. 如果启用时间过滤(过滤掉发布时间晚于等于阈值的帖子,避免穿越)
|
|
|
+ elif time_filter:
|
|
|
print(f"\n正在应用时间过滤 (< {time_filter}),避免使用晚于当前帖子的数据...")
|
|
|
filtered_post_ids = set()
|
|
|
- filtered_count = 0
|
|
|
for post_id, detail in post_details.items():
|
|
|
publish_time = detail.get('publish_time', '')
|
|
|
if publish_time < time_filter:
|
|
|
@@ -288,8 +307,8 @@ def convert_to_array_format(merged_dict: Dict, fetch_details: bool = True, time_
|
|
|
# 为每个来源添加帖子详情
|
|
|
enhanced_sources = []
|
|
|
for source in data["来源"]:
|
|
|
- # 如果启用时间过滤,跳过不符合时间条件的帖子
|
|
|
- if fetch_details and time_filter and source["帖子id"] not in post_details:
|
|
|
+ # 如果启用过滤,跳过不符合条件的帖子
|
|
|
+ if fetch_details and (time_filter or exclude_post_ids) and source["帖子id"] not in post_details:
|
|
|
continue
|
|
|
|
|
|
enhanced_source = source.copy()
|
|
|
@@ -307,6 +326,38 @@ def convert_to_array_format(merged_dict: Dict, fetch_details: bool = True, time_
|
|
|
return result
|
|
|
|
|
|
|
|
|
+def get_current_post_ids(current_posts_dir: Path) -> Set[str]:
|
|
|
+ """
|
|
|
+ 获取当前帖子目录中的所有帖子ID
|
|
|
+
|
|
|
+ Args:
|
|
|
+ current_posts_dir: 当前帖子目录路径
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 当前帖子ID集合
|
|
|
+ """
|
|
|
+ if not current_posts_dir.exists():
|
|
|
+ print(f"警告: 当前帖子目录不存在: {current_posts_dir}")
|
|
|
+ return set()
|
|
|
+
|
|
|
+ json_files = list(current_posts_dir.glob("*.json"))
|
|
|
+ if not json_files:
|
|
|
+ print(f"警告: 当前帖子目录为空: {current_posts_dir}")
|
|
|
+ return set()
|
|
|
+
|
|
|
+ print(f"\n正在获取当前帖子ID...")
|
|
|
+ print(f"找到 {len(json_files)} 个当前帖子")
|
|
|
+
|
|
|
+ post_ids = set()
|
|
|
+ for file_path in json_files:
|
|
|
+ post_id = extract_post_id_from_filename(file_path.name)
|
|
|
+ if post_id:
|
|
|
+ post_ids.add(post_id)
|
|
|
+
|
|
|
+ print(f"提取到 {len(post_ids)} 个帖子ID")
|
|
|
+ return post_ids
|
|
|
+
|
|
|
+
|
|
|
def get_earliest_publish_time(current_posts_dir: Path) -> Optional[str]:
|
|
|
"""
|
|
|
获取当前帖子目录中最早的发布时间
|
|
|
@@ -354,17 +405,23 @@ def get_earliest_publish_time(current_posts_dir: Path) -> Optional[str]:
|
|
|
|
|
|
|
|
|
def main():
|
|
|
- # 输入输出路径(默认使用项目根目录下的 data/data_1117 目录)
|
|
|
- script_dir = Path(__file__).parent
|
|
|
- project_root = script_dir.parent.parent
|
|
|
- data_dir = project_root / "data" / "data_1118"
|
|
|
+ # 使用路径配置
|
|
|
+ config = PathConfig()
|
|
|
+
|
|
|
+ # 确保输出目录存在
|
|
|
+ config.ensure_dirs()
|
|
|
|
|
|
- input_dir = data_dir / "过去帖子_what解构结果"
|
|
|
- current_posts_dir = data_dir / "当前帖子_what解构结果"
|
|
|
- output_file = data_dir / "特征名称_帖子来源.json"
|
|
|
+ # 获取路径
|
|
|
+ input_dir = config.historical_posts_dir
|
|
|
+ current_posts_dir = config.current_posts_dir
|
|
|
+ output_file = config.feature_source_mapping_file
|
|
|
|
|
|
- # 获取当前帖子的最早发布时间
|
|
|
- earliest_time = get_earliest_publish_time(current_posts_dir)
|
|
|
+ print(f"账号: {config.account_name}")
|
|
|
+ print(f"过滤模式: {config.filter_mode}")
|
|
|
+ print(f"过去帖子目录: {input_dir}")
|
|
|
+ print(f"当前帖子目录: {current_posts_dir}")
|
|
|
+ print(f"输出文件: {output_file}")
|
|
|
+ print()
|
|
|
|
|
|
print(f"\n正在扫描目录: {input_dir}")
|
|
|
|
|
|
@@ -383,15 +440,35 @@ def main():
|
|
|
print("\n正在合并结果...")
|
|
|
merged_result = merge_results(all_results)
|
|
|
|
|
|
- # 转换为数组格式(带时间过滤)
|
|
|
+ # 根据配置的过滤模式应用过滤
|
|
|
+ filter_mode = config.filter_mode
|
|
|
+ time_filter = None
|
|
|
+ exclude_post_ids = None
|
|
|
+
|
|
|
+ if filter_mode == "exclude_current_posts":
|
|
|
+ # 新规则:排除当前帖子ID
|
|
|
+ print("\n应用过滤规则: 排除当前帖子ID")
|
|
|
+ exclude_post_ids = get_current_post_ids(current_posts_dir)
|
|
|
+ elif filter_mode == "time_based":
|
|
|
+ # 旧规则:基于发布时间
|
|
|
+ print("\n应用过滤规则: 基于发布时间")
|
|
|
+ time_filter = get_earliest_publish_time(current_posts_dir)
|
|
|
+ elif filter_mode == "none":
|
|
|
+ print("\n过滤模式: none,不应用任何过滤")
|
|
|
+ else:
|
|
|
+ print(f"\n警告: 未知的过滤模式 '{filter_mode}',不应用过滤")
|
|
|
+
|
|
|
+ # 转换为数组格式(带过滤)
|
|
|
print("正在转换为数组格式...")
|
|
|
- final_result = convert_to_array_format(merged_result, fetch_details=True, time_filter=earliest_time)
|
|
|
+ final_result = convert_to_array_format(
|
|
|
+ merged_result,
|
|
|
+ fetch_details=True,
|
|
|
+ time_filter=time_filter,
|
|
|
+ exclude_post_ids=exclude_post_ids
|
|
|
+ )
|
|
|
|
|
|
# 统计信息
|
|
|
- if earliest_time:
|
|
|
- print(f"\n提取统计 (已过滤掉发布时间 >= {earliest_time} 的帖子):")
|
|
|
- else:
|
|
|
- print(f"\n提取统计:")
|
|
|
+ print(f"\n提取统计:")
|
|
|
for category in ["灵感点", "目的点", "关键点"]:
|
|
|
feature_count = len(final_result[category])
|
|
|
source_count = sum(len(item["特征来源"]) for item in final_result[category])
|