BertDemo.py

import time
import torch
from odps import ODPS
from transformers import BertModel, BertTokenizer
import dashvector
from dashvector import Doc
from typing import List
# Connect to the Alibaba Cloud vector database (DashVector)
client = dashvector.Client(
    api_key='sk-TbWSOiwIcp9FZkx0fyM9JRomTxmOtD796E4626C1411EEB3525A6F9FFB919B')
# Index collection
collection = client.get('video_title_performance_01')
assert collection
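# NOTE (assumption, not in the original listing): the collection is expected to
# already exist and to have been created with dimension=768, which is the
# hidden-state size of bert-base-chinese used below.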
# Connect to Alibaba Cloud ODPS (MaxCompute)
access_id = 'LTAIWYUujJAm7CbH'
access_key = 'RfSjdiWwED1sGFlsjXv0DlfTnZTG1P'
endpoint = 'http://service.cn.maxcompute.aliyun.com/api'
project_name = 'loghubods'
odps = ODPS(
    access_id=access_id,
    secret_access_key=access_key,
    project=project_name,
    endpoint=endpoint
)
# Load the pretrained BERT model and its matching tokenizer
model_name = 'bert-base-chinese'
model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
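# The first from_pretrained call downloads the weights from the Hugging Face
# hub and caches them locally; subsequent runs load from the cache.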
def insert_vector(docs: List[Doc]):
    if len(docs) == 0:
        return
    # Batch-insert via dashvector.Doc objects
    resp = collection.insert(docs=docs)
    print(resp)
def text_to_vector(text) -> List[float]:
    # Tokenize the text into the format the model expects; since we only
    # process a single text here, no batch encoding is needed
    inputs = tokenizer(text, return_tensors='pt')
    # Run the input through the BERT model; no_grad skips building the
    # autograd graph, which is unnecessary for inference
    with torch.no_grad():
        outputs = model(**inputs)
    # Hidden states of the last layer, shape (batch, seq_len, 768)
    embeddings = outputs.last_hidden_state
    # Take the [CLS] token embedding of the first (and only) sequence as the
    # sentence vector, converted to a plain Python list
    embeddings = embeddings.detach().numpy().tolist()[0][0]
    return embeddings
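# Quick sanity check (illustrative only, not part of the original listing):
# bert-base-chinese yields 768-dimensional vectors, matching the collection.
# >>> len(text_to_vector('测试标题'))
# 768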
# Query video title performance from Alibaba Cloud ODPS
def query_video_title_performance(start_idx, limit):
    # Note: the misspelling in the table name is preserved from the source table
    sql = f"SELECT * FROM video_perfermance_info_3 WHERE title is not NULL AND title != '' ORDER BY videoid LIMIT {start_idx}, {limit};"
    result = []
    with odps.execute_sql(sql).open_reader() as reader:
        for record in reader:
            # Collect each result row
            result.append(record)
    return result
# Stitch title performance rows into vector-database Doc objects
def video_title_performance_to_vector(startIdx, limit) -> List[Doc]:
    records = query_video_title_performance(startIdx, limit)
    docs = []
    for record in records:
        # Read the field values
        videoid = str(record.videoid)
        title = record.title
        if title is None:
            continue
        rntCount = record.回流次数          # return-visit count
        rntHeadCount = record.回流人数      # return-visit user count
        shareCount = record.分享次数        # share count
        shareHeadCount = record.分享人数    # share user count
        exposureCount = record.曝光次数     # exposure count
        exposureHeadCount = record.曝光人数 # exposure user count
        playCount = record.播放次数         # play count
        playHeadCount = record.播放人数     # play user count
        # Convert the title text to a vector
        vector = text_to_vector(title)
        # Combine the vector and the performance fields into a Doc object
        doc = Doc(id=videoid, vector=vector, fields={
            'title': title, 'rntCount': rntCount, 'rntHeadCount': rntHeadCount,
            'shareCount': shareCount, 'shareHeadCount': shareHeadCount,
            'exposureCount': exposureCount, 'exposureHeadCount': exposureHeadCount,
            'playCount': playCount, 'playHeadCount': playHeadCount
        })
        docs.append(doc)
    return docs
def batchInsert():
    for i in range(100000, 185000, 500):
        print(i)
        # Time each batch
        start = time.time()
        docs = video_title_performance_to_vector(i, 500)
        insert_vector(docs)
        end = time.time()
        print(f'{i} done in {end - start}s')
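# Entry point (added sketch; the original listing defines batchInsert but never
# calls it). Guarding the call lets the module be imported without side effects.
if __name__ == '__main__':
    batchInsert()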