|
|
@@ -267,7 +267,8 @@ def compare_phrases_batch(
|
|
|
|
|
|
def compare_phrases_cartesian(
|
|
|
phrases_a: List[str],
|
|
|
- phrases_b: List[str]
|
|
|
+ phrases_b: List[str],
|
|
|
+ batch_size: int = 450
|
|
|
) -> List[List[Dict[str, Any]]]:
|
|
|
"""
|
|
|
计算笛卡尔积相似度(M×N矩阵)
|
|
|
@@ -278,6 +279,7 @@ def compare_phrases_cartesian(
|
|
|
Args:
|
|
|
phrases_a: 第一组短语列表 (M个)
|
|
|
phrases_b: 第二组短语列表 (N个)
|
|
|
+ batch_size: 每批处理的最大数量(API限制500,默认450留余量)
|
|
|
|
|
|
Returns:
|
|
|
M×N的结果矩阵(嵌套列表)
|
|
|
@@ -302,22 +304,42 @@ def compare_phrases_cartesian(
|
|
|
if not phrases_a or not phrases_b:
|
|
|
return [[]]
|
|
|
|
|
|
+ M = len(phrases_a)
|
|
|
+ N = len(phrases_b)
|
|
|
+
|
|
|
try:
|
|
|
- # 调用API计算笛卡尔积(一次性批量调用,不受max_concurrent限制)
|
|
|
client = _get_api_client()
|
|
|
- api_response = client.cartesian_similarity(phrases_a, phrases_b, model_name=None)
|
|
|
- api_results = api_response["results"]
|
|
|
-
|
|
|
- M = len(phrases_a)
|
|
|
- N = len(phrases_b)
|
|
|
|
|
|
- # 返回嵌套列表(带完整说明)
|
|
|
+ # 初始化结果矩阵
|
|
|
results = [[None for _ in range(N)] for _ in range(M)]
|
|
|
- for idx, api_result in enumerate(api_results):
|
|
|
- i = idx // N
|
|
|
- j = idx % N
|
|
|
- score = float(api_result["score"])
|
|
|
- results[i][j] = _format_result(score)
|
|
|
+
|
|
|
+ # 如果 phrases_b 超过 batch_size,分批处理
|
|
|
+ if N <= batch_size:
|
|
|
+ # 不需要分批,直接调用
|
|
|
+ api_response = client.cartesian_similarity(phrases_a, phrases_b, model_name=None)
|
|
|
+ api_results = api_response["results"]
|
|
|
+
|
|
|
+ for idx, api_result in enumerate(api_results):
|
|
|
+ i = idx // N
|
|
|
+ j = idx % N
|
|
|
+ score = float(api_result["score"])
|
|
|
+ results[i][j] = _format_result(score)
|
|
|
+ else:
|
|
|
+ # 需要分批处理 phrases_b
|
|
|
+ for batch_start in range(0, N, batch_size):
|
|
|
+ batch_end = min(batch_start + batch_size, N)
|
|
|
+ batch_b = phrases_b[batch_start:batch_end]
|
|
|
+ batch_n = len(batch_b)
|
|
|
+
|
|
|
+ api_response = client.cartesian_similarity(phrases_a, batch_b, model_name=None)
|
|
|
+ api_results = api_response["results"]
|
|
|
+
|
|
|
+ for idx, api_result in enumerate(api_results):
|
|
|
+ i = idx // batch_n
|
|
|
+ j = batch_start + (idx % batch_n)
|
|
|
+ score = float(api_result["score"])
|
|
|
+ results[i][j] = _format_result(score)
|
|
|
+
|
|
|
return results
|
|
|
|
|
|
except Exception as e:
|