|  | @@ -5,7 +5,9 @@ import json
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  import aiohttp
 |  |  import aiohttp
 | 
											
												
													
														|  |  from applications.articleTools import ArticleDBTools
 |  |  from applications.articleTools import ArticleDBTools
 | 
											
												
													
														|  | 
 |  | +from applications.nlp_task import get_nlp_similarity_score
 | 
											
												
													
														|  |  from applications.config import port
 |  |  from applications.config import port
 | 
											
												
													
														|  | 
 |  | +from applications.textSimilarity import NLPFunction
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  class AccountServer(object):
 |  |  class AccountServer(object):
 | 
											
										
											
												
													
														|  | @@ -13,7 +15,7 @@ class AccountServer(object):
 | 
											
												
													
														|  |      获取标题和公众号文章的相关性
 |  |      获取标题和公众号文章的相关性
 | 
											
												
													
														|  |      """
 |  |      """
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    def __init__(self, mysql_client, params):
 |  | 
 | 
											
												
													
														|  | 
 |  | +    def __init__(self, mysql_client, params, model, embedding_manager):
 | 
											
												
													
														|  |          self.account_name_list = None
 |  |          self.account_name_list = None
 | 
											
												
													
														|  |          self.gh_id_list = None
 |  |          self.gh_id_list = None
 | 
											
												
													
														|  |          self.sim_type = None
 |  |          self.sim_type = None
 | 
											
										
											
												
													
														|  | @@ -23,8 +25,10 @@ class AccountServer(object):
 | 
											
												
													
														|  |          self.rate = None
 |  |          self.rate = None
 | 
											
												
													
														|  |          self.title_list = None
 |  |          self.title_list = None
 | 
											
												
													
														|  |          self.view_count_filter = None
 |  |          self.view_count_filter = None
 | 
											
												
													
														|  | 
 |  | +        self.use_cache = True
 | 
											
												
													
														|  |          self.params = params
 |  |          self.params = params
 | 
											
												
													
														|  |          self.AT = ArticleDBTools(mysql_client)
 |  |          self.AT = ArticleDBTools(mysql_client)
 | 
											
												
													
														|  | 
 |  | +        self.nlp = NLPFunction(model=model, embedding_manager=embedding_manager)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |      async def request_for_nlp(self, title_list, account_interest, interest_weight):
 |  |      async def request_for_nlp(self, title_list, account_interest, interest_weight):
 | 
											
												
													
														|  |          """
 |  |          """
 | 
											
										
											
												
													
														|  | @@ -66,6 +70,7 @@ class AccountServer(object):
 | 
											
												
													
														|  |              self.interest_type = self.params.get("interest_type", "top")
 |  |              self.interest_type = self.params.get("interest_type", "top")
 | 
											
												
													
														|  |              self.sim_type = self.params.get("sim_type", "mean")
 |  |              self.sim_type = self.params.get("sim_type", "mean")
 | 
											
												
													
														|  |              self.view_count_filter = self.params.get("view_count_filter", None)
 |  |              self.view_count_filter = self.params.get("view_count_filter", None)
 | 
											
												
													
														|  | 
 |  | +            self.use_cache = self.params.get("use_cache", True)
 | 
											
												
													
														|  |              return None
 |  |              return None
 | 
											
												
													
														|  |          except Exception as e:
 |  |          except Exception as e:
 | 
											
												
													
														|  |              response = {"error": "Params error", "detail": str(e)}
 |  |              response = {"error": "Params error", "detail": str(e)}
 | 
											
										
											
												
													
														|  | @@ -131,11 +136,21 @@ class AccountServer(object):
 | 
											
												
													
														|  |              interest_weight = extend_dicts['view_count']
 |  |              interest_weight = extend_dicts['view_count']
 | 
											
												
													
														|  |              if self.sim_type == "weighted_by_view_count_rate":
 |  |              if self.sim_type == "weighted_by_view_count_rate":
 | 
											
												
													
														|  |                  interest_weight = extend_dicts['view_count_rate']
 |  |                  interest_weight = extend_dicts['view_count_rate']
 | 
											
												
													
														|  | -            response = await self.request_for_nlp(
 |  | 
 | 
											
												
													
														|  | -                title_list=self.title_list,
 |  | 
 | 
											
												
													
														|  | -                account_interest=account_interest,
 |  | 
 | 
											
												
													
														|  | -                interest_weight=interest_weight
 |  | 
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +            data = {
 | 
											
												
													
														|  | 
 |  | +                "text_list_a": [i.replace("'", "") for i in self.title_list],
 | 
											
												
													
														|  | 
 |  | +                "text_list_b": [i.replace("'", "") for i in account_interest],
 | 
											
												
													
														|  | 
 |  | +                "score_list_b": interest_weight,
 | 
											
												
													
														|  | 
 |  | +                "symbol": 1,
 | 
											
												
													
														|  | 
 |  | +            },
 | 
											
												
													
														|  | 
 |  | +            function = "similarities_cross_mean" if self.sim_type == "mean" else "similarities_cross_avg"
 | 
											
												
													
														|  | 
 |  | +            response = await get_nlp_similarity_score(
 | 
											
												
													
														|  | 
 |  | +                nlp=self.nlp,
 | 
											
												
													
														|  | 
 |  | +                function=function,
 | 
											
												
													
														|  | 
 |  | +                data=data,
 | 
											
												
													
														|  | 
 |  | +                use_cache=self.use_cache
 | 
											
												
													
														|  |              )
 |  |              )
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |              score_list_key = "score_list_mean" if self.sim_type == "mean" else "score_list_avg"
 |  |              score_list_key = "score_list_mean" if self.sim_type == "mean" else "score_list_avg"
 | 
											
												
													
														|  |              return {
 |  |              return {
 | 
											
												
													
														|  |                  "score_list": response[score_list_key],
 |  |                  "score_list": response[score_list_key],
 |