|
|
@@ -233,10 +233,24 @@ class QueryGenerationAgent:
|
|
|
# 提取所有扩展的查询词
|
|
|
expanded = data.get("expanded_queries", {})
|
|
|
aggregated: List[str] = []
|
|
|
+ invalid_keywords = ["无关", "超出", "不相关", "不属于", "无法生成"]
|
|
|
|
|
|
- # 收集粗颗粒度查询
|
|
|
+ # 收集粗颗粒度查询并检测是否不符合创作领域
|
|
|
for item in expanded.get("coarse_grained", []) or []:
|
|
|
q = str(item.get("query", "")).strip()
|
|
|
+ reason = str(item.get("reason", "")).strip()
|
|
|
+
|
|
|
+ # 检测是否表明问题不符合创作领域
|
|
|
+ if q and any(keyword in q for keyword in invalid_keywords):
|
|
|
+ error_msg = q if len(q) <= 100 else reason[:100] if reason else "问题不符合内容创作领域"
|
|
|
+ logger.info(f"检测到不符合创作领域的问题: {error_msg}")
|
|
|
+ if state.get("task_id", 0) > 0:
|
|
|
+ self.task_dao.mark_task_failed(state["task_id"], error_msg)
|
|
|
+ state["result_queries"] = []
|
|
|
+ state["initial_queries"] = []
|
|
|
+ state["refined_queries"] = []
|
|
|
+ return state
|
|
|
+
|
|
|
if q:
|
|
|
aggregated.append(q)
|
|
|
|
|
|
@@ -252,6 +266,17 @@ class QueryGenerationAgent:
|
|
|
if q:
|
|
|
aggregated.append(q)
|
|
|
|
|
|
+ # 如果所有查询词都为空,可能表示无法生成有效查询
|
|
|
+ if not aggregated:
|
|
|
+ error_msg = "无法生成有效的内容创作查询词"
|
|
|
+ logger.info(error_msg)
|
|
|
+ if state.get("task_id", 0) > 0:
|
|
|
+ self.task_dao.mark_task_failed(state["task_id"], error_msg)
|
|
|
+ state["result_queries"] = []
|
|
|
+ state["initial_queries"] = []
|
|
|
+ state["refined_queries"] = []
|
|
|
+ return state
|
|
|
+
|
|
|
# 去重,保持顺序
|
|
|
seen = set()
|
|
|
deduped: List[str] = []
|