yangxiaohui hai 6 horas
pai
achega
078f06d219
Modificáronse 1 ficheiros con 35 adicións e 13 borrados
  1. 35 13
      lib/text_embedding_api.py

+ 35 - 13
lib/text_embedding_api.py

@@ -267,7 +267,8 @@ def compare_phrases_batch(
 
 def compare_phrases_cartesian(
     phrases_a: List[str],
-    phrases_b: List[str]
+    phrases_b: List[str],
+    batch_size: int = 450
 ) -> List[List[Dict[str, Any]]]:
     """
     计算笛卡尔积相似度(M×N矩阵)
@@ -278,6 +279,7 @@ def compare_phrases_cartesian(
     Args:
         phrases_a: 第一组短语列表 (M个)
         phrases_b: 第二组短语列表 (N个)
+        batch_size: 每批处理的最大数量(API限制500,默认450留余量)
 
     Returns:
         M×N的结果矩阵(嵌套列表)
@@ -302,22 +304,42 @@ def compare_phrases_cartesian(
     if not phrases_a or not phrases_b:
         return [[]]
 
+    M = len(phrases_a)
+    N = len(phrases_b)
+
     try:
-        # 调用API计算笛卡尔积(一次性批量调用,不受max_concurrent限制)
         client = _get_api_client()
-        api_response = client.cartesian_similarity(phrases_a, phrases_b, model_name=None)
-        api_results = api_response["results"]
-
-        M = len(phrases_a)
-        N = len(phrases_b)
 
-        # 返回嵌套列表(带完整说明)
+        # 初始化结果矩阵
         results = [[None for _ in range(N)] for _ in range(M)]
-        for idx, api_result in enumerate(api_results):
-            i = idx // N
-            j = idx % N
-            score = float(api_result["score"])
-            results[i][j] = _format_result(score)
+
+        # 如果 phrases_b 超过 batch_size,分批处理
+        if N <= batch_size:
+            # 不需要分批,直接调用
+            api_response = client.cartesian_similarity(phrases_a, phrases_b, model_name=None)
+            api_results = api_response["results"]
+
+            for idx, api_result in enumerate(api_results):
+                i = idx // N
+                j = idx % N
+                score = float(api_result["score"])
+                results[i][j] = _format_result(score)
+        else:
+            # 需要分批处理 phrases_b
+            for batch_start in range(0, N, batch_size):
+                batch_end = min(batch_start + batch_size, N)
+                batch_b = phrases_b[batch_start:batch_end]
+                batch_n = len(batch_b)
+
+                api_response = client.cartesian_similarity(phrases_a, batch_b, model_name=None)
+                api_results = api_response["results"]
+
+                for idx, api_result in enumerate(api_results):
+                    i = idx // batch_n
+                    j = batch_start + (idx % batch_n)
+                    score = float(api_result["score"])
+                    results[i][j] = _format_result(score)
+
         return results
 
     except Exception as e: