lr_model.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
# coding: utf-8
  2. import json
  3. import math
  4. def load_json(filename):
  5. with open(filename, 'r') as fin:
  6. json_data = json.load(fin)
  7. return json_data
  8. def wx(w_dict, kv):
  9. k, v = kv
  10. w = w_dict.get(k, 0.0)
  11. return w * v
  12. def sigmoid(x):
  13. return 1.0 / (1.0 + math.exp(-x))
  14. def libsvm_row_to_features(row):
  15. items = row.strip().split(' ')
  16. label = items[0]
  17. features = {}
  18. for kv in items[1:]:
  19. k, v = kv.split(':')
  20. features[k] = float(v)
  21. return label, features
  22. class LrModel:
  23. def __init__(self, w_json_file):
  24. self.w_dict = load_json(w_json_file)
  25. def predict_h(self, features):
  26. h = sum(map(lambda x: wx(self.w_dict, x), features.items()))
  27. return h
  28. def predict(self, features):
  29. bias = self.w_dict.get('bias', 0.0)
  30. h = self.predict_h(features)
  31. score = sigmoid(h + bias)
  32. return score
  33. def test():
  34. lr_model = LrModel('model/ad_out_v2_model_v1.day.json')
  35. rows = [
  36. (0.279004,'0 u_brand#vivo:1 u_device#V1829A:1 u_system#Android:1 u_system_ver#Android10:1 i_id#17015839:1 i_up_id#24811642:1 i_title_len#5:1 i_play_len#8:1 i_days_since_upload#4:1 ctx_week#3:1 ctx_hour#8:1 ctx_region#山西:1 ctx_city#临汾:1 u_3month_exp_cnt#4:1 u_3month_click_cnt#4:1 u_3month_share_cnt#2:1 u_3month_return_cnt#6:1 i_1day_exp_cnt#16:1 i_1day_click_cnt#15:1 i_1day_share_cnt#12:1 i_1day_return_cnt#15:1 i_3day_exp_cnt#18:1 i_3day_click_cnt#17:1 i_3day_share_cnt#14:1 i_3day_return_cnt#15:1 i_7day_exp_cnt#18:1 i_7day_click_cnt#18:1 i_7day_share_cnt#14:1 i_7day_return_cnt#15:1 i_3month_exp_cnt#19:1 i_3month_click_cnt#18:1 i_3month_share_cnt#14:1 i_3month_return_cnt#16:1 u_ctr_3month:0.066667 u_str_3month:0.0375 u_rov_3month:0.283333 u_ros_3month:1.0 i_ctr_1day:0.070518 i_str_1day:0.007219 i_rov_1day:0.044376 i_ros_1day:0.871681 i_ctr_3day:0.070359 i_str_3day:0.007297 i_rov_3day:0.017242 i_ros_3day:0.335814 i_ctr_7day:0.070245 i_str_7day:0.007044 i_rov_7day:0.012268 i_ros_7day:0.247943 i_ctr_3month:0.06989 i_str_3month:0.007203 i_rov_3month:0.012624 i_ros_3month:0.250784'),
  37. (0.454255,'1 u_brand#vivo:1 u_device#V2230A:1 u_system#Android:1 u_system_ver#Android13:1 i_id#17141266:1 i_up_id#65303321:1 i_title_len#5:1 i_play_len#8:1 i_days_since_upload#3:1 ctx_week#3:1 ctx_hour#17:1 ctx_region#河北:1 ctx_city#邯郸:1 u_3month_exp_cnt#3:1 u_3month_click_cnt#3:1 i_1day_exp_cnt#19:1 i_1day_click_cnt#19:1 i_1day_share_cnt#16:1 i_1day_return_cnt#19:1 i_3day_exp_cnt#21:1 i_3day_click_cnt#21:1 i_3day_share_cnt#18:1 i_3day_return_cnt#20:1 i_7day_exp_cnt#22:1 i_7day_click_cnt#21:1 i_7day_share_cnt#18:1 i_7day_return_cnt#20:1 i_3month_exp_cnt#22:1 i_3month_click_cnt#21:1 i_3month_share_cnt#18:1 i_3month_return_cnt#20:1 u_ctr_3month:0.1 i_ctr_1day:0.078281 i_str_1day:0.009801 i_rov_1day:0.066601 i_ros_1day:0.868055 i_ctr_3day:0.076168 i_str_3day:0.010706 i_rov_3day:0.036384 i_ros_3day:0.446198 i_ctr_7day:0.07423 i_str_7day:0.011886 i_rov_7day:0.022732 i_ros_7day:0.257645 i_ctr_3month:0.07423 i_str_3month:0.011886 i_rov_3month:0.022732 i_ros_3month:0.257645'),
  38. (lr_model.predict({}), '0 dsiaod:1 dsaodadsa:1.2'),
  39. ]
  40. for std_score, row in rows:
  41. label, features = libsvm_row_to_features(row)
  42. score = lr_model.predict(features)
  43. score_diff = std_score - score
  44. print(std_score, score, score_diff)
  45. assert(abs(score_diff) < 10e-6)
  46. if __name__ == '__main__':
  47. test()