# BertDemo.py

import time
from typing import List

import torch
from odps import ODPS
from transformers import BertModel, BertTokenizer

import dashvector
from dashvector import Doc

# Connect to the Alibaba Cloud vector database (DashVector)
client = dashvector.Client(
    api_key='sk-TbWSOiwIcp9FZkx0fyM9JRomTxmOtD796E4626C1411EEB3525A6F9FFB919B')

# Index collection
collection = client.get('video_title_performance_01')
assert collection
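
# Optional setup sketch (not in the original script): if the collection does
# not exist yet, it can be created once up front. This assumes the DashVector
# SDK's client.create() call; 768 matches the hidden size of bert-base-chinese.
#
# resp = client.create(name='video_title_performance_01', dimension=768)
# assert resp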

# Connect to Alibaba Cloud ODPS (MaxCompute)
access_id = 'LTAIWYUujJAm7CbH'
access_key = 'RfSjdiWwED1sGFlsjXv0DlfTnZTG1P'
endpoint = 'http://service.cn.maxcompute.aliyun.com/api'
project_name = 'loghubods'
odps = ODPS(
    access_id=access_id,
    secret_access_key=access_key,
    project=project_name,
    endpoint=endpoint,
)

# Load the pre-trained BERT model and its tokenizer
model_name = 'bert-base-chinese'
model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
model.eval()  # inference only: disable dropout for deterministic embeddings
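
# Optional speed-up sketch (assumption: a CUDA-capable GPU may be available).
# Embedding tens of thousands of titles one at a time on CPU is slow; moving
# the model (and each tokenized `inputs` dict) to the GPU would help:
#
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model = model.to(device)
# # ...and inside text_to_vector:
# # inputs = {k: v.to(device) for k, v in inputs.items()}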


def insert_vector(docs: List[Doc]):
    if len(docs) == 0:
        return
    # Batch-insert via dashvector.Doc objects
    resp = collection.insert(docs=docs)
    print(resp)
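
# Robustness sketch (assumption: DashVector insert responses expose a numeric
# `code` field where 0 means success); the original script only prints the
# response, so failed batches would go unnoticed:
#
#     if resp.code != 0:
#         raise RuntimeError(f'insert failed: {resp}')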


def text_to_vector(text) -> List[float]:
    # Tokenize the text into the tensor format the model expects; only a
    # single text is encoded here, so no batch encoding is needed
    inputs = tokenizer(text, return_tensors='pt')
    # Run the input through the BERT model; no gradients are needed at inference
    with torch.no_grad():
        outputs = model(**inputs)
    # The last layer's hidden states, shape (batch, seq_len, 768)
    embeddings = outputs.last_hidden_state
    # Use the [CLS] token vector of the first (only) sequence as the sentence
    # embedding, converted to a plain Python list of 768 floats
    return embeddings[0][0].tolist()
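
# Quick sanity-check sketch (not part of the original pipeline): cosine
# similarity between two title embeddings. The sample titles are made up.
def _cosine_similarity_demo():
    import numpy as np
    v1 = np.array(text_to_vector('今天天气真好'))
    v2 = np.array(text_to_vector('今日天气不错'))
    sim = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
    print(f'cosine similarity: {sim:.4f}')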


# Query video title performance from ODPS
def query_video_title_performance(start_idx, limit):
    sql = (f"SELECT * FROM video_perfermance_info_3 "
           f"WHERE title IS NOT NULL AND title != '' "
           f"ORDER BY videoid LIMIT {start_idx}, {limit};")
    result = []
    with odps.execute_sql(sql).open_reader() as reader:
        for record in reader:
            # Collect each row of the query result
            result.append(record)
    return result


# Join each title-performance row and its embedding into a vector-database Doc
def video_title_performance_to_vector(start_idx, limit) -> List[Doc]:
    records = query_video_title_performance(start_idx, limit)
    docs = []
    for record in records:
        # Read the field values (the ODPS columns are named in Chinese)
        videoid = str(record.videoid)
        title = record.title
        if title is None:
            continue
        rntCount = record.回流次数            # return-visit count
        rntHeadCount = record.回流人数        # returning-user count
        shareCount = record.分享次数          # share count
        shareHeadCount = record.分享人数      # sharing-user count
        exposureCount = record.曝光次数       # impression count
        exposureHeadCount = record.曝光人数   # exposed-user count
        playCount = record.播放次数           # play count
        playHeadCount = record.播放人数       # playing-user count
        # Convert the title text to a vector
        vector = text_to_vector(title)
        # Combine the vector and the performance metrics into a Doc object
        doc = Doc(id=videoid, vector=vector, fields={
            'title': title,
            'rntCount': rntCount, 'rntHeadCount': rntHeadCount,
            'shareCount': shareCount, 'shareHeadCount': shareHeadCount,
            'exposureCount': exposureCount, 'exposureHeadCount': exposureHeadCount,
            'playCount': playCount, 'playHeadCount': playHeadCount,
        })
        docs.append(doc)
    return docs


# Walk the table in 500-row batches, embedding and inserting each batch
def batchInsert():
    for i in range(100000, 185000, 500):
        print(i)
        # Time each batch
        start = time.time()
        docs = video_title_performance_to_vector(i, 500)
        insert_vector(docs)
        end = time.time()
        print(f'{i} done in {end - start}s')
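

# Entry point. The original file defines batchInsert() but never calls it;
# a standard main guard is assumed here so the script actually runs.
if __name__ == '__main__':
    batchInsert()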