app_dssm_0329.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. # encoding: utf-8
  2. import base64
  3. import json
  4. # from meinheld import server
  5. import flask
  6. from flask import request, Flask
  7. from flask import Flask
  8. from embedding_manager import EmbeddingManager
  9. from embedding_manager_user import EmbeddingManagerUser
  10. import time
  11. import logging
  12. from logging.handlers import TimedRotatingFileHandler
  13. app = Flask(__name__)
  14. def setLog():
  15. log_fmt = '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
  16. formatter = logging.Formatter(log_fmt)
  17. fh = TimedRotatingFileHandler(filename="log/run_dss_server" + str(time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())) + ".log", when="H", interval=1,
  18. backupCount=72)
  19. fh.setFormatter(formatter)
  20. logging.basicConfig(level=logging.INFO)
  21. log = logging.getLogger()
  22. log.addHandler(fh)
  23. setLog()
  24. print("load user embedding")
  25. mgr_user_embedding = EmbeddingManagerUser(
  26. "/work/xielixun/DeepMatch/DSSM/tensorflow_user_embedding-dssm-tzld-210327-2-bak.csv",
  27. "mid", "emb")
  28. print("load video embedding")
  29. mgr_video_embedding = EmbeddingManager(
  30. # "/root/xielixun/ant-learn-recsys/datas/tzld_video_embedding-1106-sort.csv",
  31. "/work/xielixun/DeepMatch/DSSM/tensorflow_video_embedding-dssm-tzld-210327-2-bak.csv",
  32. "videoid", "emb")
  33. @app.route("/")
  34. def index():
  35. return "test"
  36. """
  37. 健康检查
  38. """
  39. @app.route("/healthcheck", methods=['GET'])
  40. def index_health_check():
  41. logging.info("I'm ok")
  42. return "ok"
  43. # 定义路由
  44. @app.route("/ai/v1/user2video", methods=['POST'])
  45. def get_video_by_user_mid2vid():
  46. try:
  47. start_time = time.time()
  48. resParm = flask.request.data
  49. # 转字符串
  50. resParm = str(resParm, encoding="utf-8")
  51. resParm = eval(resParm)
  52. requestId = resParm.get('requestId')
  53. # 服务鉴权
  54. token = resParm.get('token')
  55. if not token:
  56. res = {'code': 3, 'msg': 'token fail'}
  57. logging.error("code: 3 msg: token fail ")
  58. return json.dumps(res)
  59. # 按照debase64进行处理
  60. mid = resParm.get("mid")
  61. vid = resParm.get("vid")
  62. page_size = resParm.get("pageSize")
  63. # 1. 获取该用户的embedding
  64. user_embedding_str = mgr_user_embedding.get_embedding(mid)
  65. user_str = "["
  66. target_video_ids = list()
  67. if user_embedding_str != "":
  68. user_list = user_embedding_str[1:-1].strip('\n').split()
  69. for idx, emb in enumerate(user_list):
  70. if idx < 31:
  71. user_str += emb + ","
  72. else:
  73. user_str += emb + "]"
  74. # 2. 获取该用户看过的电影ID列表
  75. # watch_ids = obj_user_rating.get_user_watched_ids(user_id)
  76. # 3. 使用近邻搜索获取用户可能喜欢的视频ID列表
  77. target_video_ids = mgr_video_embedding.search_ids_by_embedding(user_str,
  78. page_size)
  79. timeUsed = time.time() - start_time
  80. data = {'requestId': requestId, 'videoIds': str(target_video_ids), 'timeUsed': timeUsed, 'mid': mid}
  81. res = {'code': 0, 'msg': 'success', 'data': data}
  82. logging.info(f"code:0 msg:success user2video cost Time is: {str(timeUsed)} ")
  83. return json.dumps(res)
  84. except Exception as x:
  85. logging.exception(x)
  86. res = {'code': 6, 'msg': 'request exception', 'data': {}, 'mid': mid}
  87. return json.dumps(res)
  88. # 定义路由
  89. @app.route("/ai/v1/video2video", methods=['POST'])
  90. def get_video_by_video_vid2vid():
  91. try:
  92. start_time = time.time()
  93. resParm = flask.request.data
  94. # 转字符串
  95. resParm = str(resParm, encoding="utf-8")
  96. resParm = eval(resParm)
  97. requestId = resParm.get('requestId')
  98. # 服务鉴权
  99. token = resParm.get('token')
  100. if not token:
  101. res = {'code': 3, 'msg': 'token fail'}
  102. logging.error("code: 3 msg: token fail ")
  103. return json.dumps(res)
  104. # 按照debase64进行处理
  105. # mid = resParm.get("mid")
  106. vid = resParm.get("vid")
  107. page_size = resParm.get("pageSize")
  108. video_ids = list()
  109. video_str = "["
  110. # target_video_ids = list()
  111. # 查询自己的embedding
  112. video_embedding = mgr_video_embedding.get_embedding(vid)
  113. if video_embedding != "":
  114. video_list = video_embedding[1:-1].strip('\n').split()
  115. for idx, emb in enumerate(video_list):
  116. if idx < 31:
  117. video_str += emb + ","
  118. else:
  119. video_str += emb + "]"
  120. # 查询相似的视频
  121. video_ids = mgr_video_embedding.search_ids_by_embedding(video_str, page_size)
  122. timeUsed = time.time() - start_time
  123. data = {'requestId': requestId, 'videoIds': str(video_ids), 'timeUsed': timeUsed, 'vid': vid}
  124. res = {'code': 0, 'msg': 'success', 'data': data}
  125. logging.info(f"code:0 msg:success video2video cost Time is: {str(timeUsed)} ")
  126. return json.dumps(res)
  127. except Exception as x:
  128. logging.exception(x)
  129. res = {'code': 6, 'msg': 'request exception', 'data': {}, 'vid': vid}
  130. return json.dumps(res)
  131. @app.route("/ai/v1/videolist2video", methods=['POST'])
  132. def get_video_by_video_vidList2vid():
  133. try:
  134. start_time = time.time()
  135. resParm = flask.request.data
  136. # 转字符串
  137. resParm = str(resParm, encoding="utf-8")
  138. resParm = eval(resParm)
  139. requestId = resParm.get('requestId')
  140. # 服务鉴权
  141. token = resParm.get('token')
  142. if not token:
  143. res = {'code': 3, 'msg': 'token fail'}
  144. logging.error("code: 3 msg: token fail ")
  145. return json.dumps(res)
  146. # 按照debase64进行处理
  147. # mid = resParm.get("mid")
  148. vid_str = resParm.get("vidList")
  149. vid_list = vid_str.split(",")
  150. # vid_list = list(map(int, vid_list))
  151. # print("\n\nvid_list is: ")
  152. # print(vid_list)
  153. page_size = resParm.get("pageSize")
  154. video_embedding_list = list()
  155. # 查询自己的embedding
  156. for vid in vid_list:
  157. video_embedding = mgr_video_embedding.get_embedding(vid)
  158. if video_embedding == "":
  159. continue
  160. video_str = "["
  161. if video_embedding != "":
  162. video_list = video_embedding[1:-1].strip('\n').split()
  163. for idx, emb in enumerate(video_list):
  164. if idx < 31:
  165. video_str += emb + ","
  166. else:
  167. video_str += emb + "]"
  168. video_embedding_list.append(json.loads(video_str))
  169. # 查询相似的视频
  170. video_ids = list()
  171. if len(video_embedding_list) > 0:
  172. video_ids = mgr_video_embedding.search_ids_by_embedding_list(video_embedding_list, page_size)
  173. timeUsed = time.time() - start_time
  174. data = {'requestId': requestId, 'videoIds': str(video_ids), 'timeUsed': timeUsed, 'vid': vid}
  175. res = {'code': 0, 'msg': 'success', 'data': data}
  176. logging.info(f"code:0 msg:success Faiss videolist2video cost Time is: {str(timeUsed)} ")
  177. return json.dumps(res)
  178. except Exception as x:
  179. logging.exception(x)
  180. res = {'code': 6, 'msg': 'request exception', 'data': {}, 'vid': vid}
  181. return json.dumps(res)
  182. if __name__ == '__main__':
  183. # 启动服务
  184. # app.run(host="0.0.0.0", port=9996) # test
  185. # app.run(host="0.0.0.0", port=9996) # fm
  186. # app.run(host="0.0.0.0", port=9999) # item2vec
  187. app.run(host="0.0.0.0", port=9997) # dssm