accountServer.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. """
  2. @author: luojunhui
  3. """
  4. import json
  5. import aiohttp
  6. from applications.articleTools import ArticleDBTools
  7. from applications.nlp_task import get_nlp_similarity_score
  8. from applications.config import port
  9. from applications.textSimilarity import NLPFunction
  10. class AccountServer(object):
  11. """
  12. 获取标题和公众号文章的相关性
  13. """
  14. def __init__(self, mysql_client, params, model, embedding_manager):
  15. self.account_name_list = None
  16. self.gh_id_list = None
  17. self.sim_type = None
  18. self.interest_type = None
  19. self.min_time = None
  20. self.max_time = None
  21. self.rate = None
  22. self.title_list = None
  23. self.view_count_filter = None
  24. self.use_cache = True
  25. self.params = params
  26. self.AT = ArticleDBTools(mysql_client)
  27. self.nlp = NLPFunction(model=model, embedding_manager=embedding_manager)
  28. async def request_for_nlp(self, title_list, account_interest, interest_weight):
  29. """
  30. nlp process
  31. """
  32. headers = {"Content-Type": "application/json"}
  33. url = "http://localhost:{}/nlp".format(port)
  34. body = {
  35. "data": {
  36. "text_list_a": [i.replace("'", "") for i in title_list],
  37. "text_list_b": [i.replace("'", "") for i in account_interest],
  38. "score_list_b": interest_weight,
  39. "symbol": 1,
  40. },
  41. "function": "similarities_cross_mean" if self.sim_type == "mean" else "similarities_cross_avg"
  42. }
  43. async with aiohttp.ClientSession() as session:
  44. async with session.post(url, headers=headers, json=body) as response:
  45. response_text = await response.text()
  46. # print("结果:\t", response_text)
  47. if response_text:
  48. return await response.json()
  49. else:
  50. print("Received empty response")
  51. return {}
  52. def check_params(self):
  53. """
  54. 校验传参
  55. :return:
  56. """
  57. try:
  58. self.title_list = self.params["text_list"]
  59. self.account_name_list = self.params.get("account_nickname_list", [])
  60. self.gh_id_list = self.params.get("gh_id_list", [])
  61. self.rate = self.params.get("rate", 0.1)
  62. self.max_time = self.params.get("max_time")
  63. self.min_time = self.params.get("min_time")
  64. self.interest_type = self.params.get("interest_type", "top")
  65. self.sim_type = self.params.get("sim_type", "mean")
  66. self.view_count_filter = self.params.get("view_count_filter", None)
  67. self.use_cache = self.params.get("use_cache", True)
  68. return None
  69. except Exception as e:
  70. response = {"error": "Params error", "detail": str(e)}
  71. return response
  72. async def get_account_interest(
  73. self,
  74. gh_id,
  75. interest_type,
  76. view_count_filter,
  77. rate=None,
  78. msg_type=None,
  79. index_list=None,
  80. min_time=None,
  81. max_time=None,
  82. ):
  83. """
  84. 获取账号的兴趣类型
  85. :param gh_id:
  86. :param max_time:
  87. :param min_time:
  88. :param index_list:
  89. :param msg_type:
  90. :param rate:
  91. :param interest_type:
  92. :param view_count_filter:
  93. :return:
  94. """
  95. good_df, bad_df = await self.AT.get_good_bad_articles(
  96. gh_id=gh_id,
  97. interest_type=interest_type,
  98. msg_type=msg_type,
  99. index_list=index_list,
  100. min_time=min_time,
  101. max_time=max_time,
  102. rate=rate,
  103. view_count_filter=view_count_filter,
  104. )
  105. extend_dicts = {
  106. 'view_count': good_df["show_view_count"].values.tolist(),
  107. }
  108. if 'view_count_avg' in good_df.columns:
  109. extend_dicts['view_count_rate'] = \
  110. (good_df["show_view_count"] / good_df["view_count_avg"]).values.tolist()
  111. account_interest = good_df["title"].values.tolist()
  112. return account_interest, extend_dicts
  113. async def get_each_account_score_list(self, gh_id):
  114. """
  115. 获取和单个账号的相关性分数
  116. :return:
  117. """
  118. try:
  119. account_interest, extend_dicts = await self.get_account_interest(
  120. gh_id=gh_id,
  121. interest_type=self.interest_type,
  122. rate=self.rate,
  123. view_count_filter=self.view_count_filter,
  124. min_time=self.min_time,
  125. max_time=self.max_time,
  126. )
  127. interest_weight = extend_dicts['view_count']
  128. if self.sim_type == "weighted_by_view_count_rate":
  129. interest_weight = extend_dicts['view_count_rate']
  130. data = {
  131. "text_list_a": [i.replace("'", "") for i in self.title_list],
  132. "text_list_b": [i.replace("'", "") for i in account_interest],
  133. "score_list_b": interest_weight,
  134. "symbol": 1,
  135. },
  136. function = "similarities_cross_mean" if self.sim_type == "mean" else "similarities_cross_avg"
  137. response = await get_nlp_similarity_score(
  138. nlp=self.nlp,
  139. function=function,
  140. data=data,
  141. use_cache=self.use_cache
  142. )
  143. score_list_key = "score_list_mean" if self.sim_type == "mean" else "score_list_avg"
  144. return {
  145. "score_list": response[score_list_key],
  146. "text_list_max": response["text_list_max"],
  147. }
  148. except Exception as e:
  149. print(e)
  150. return {
  151. "score_list": [0] * len(self.title_list),
  152. "text_list_max": self.title_list,
  153. }
  154. async def get_account_list_score_list(self):
  155. """
  156. 获取AccountList中每一个账号的相关性分数
  157. :return:
  158. """
  159. response = {}
  160. for gh_id in self.gh_id_list:
  161. if response.get(gh_id):
  162. continue
  163. else:
  164. response[gh_id] = await self.get_each_account_score_list(gh_id=gh_id)
  165. return response
  166. async def deal(self):
  167. """
  168. Deal Function
  169. :return:
  170. """
  171. return (
  172. self.check_params() if self.check_params() else await self.get_account_list_score_list()
  173. )