# -*- coding: utf-8 -*-
import json
import math

def load_json(filename):
    """Load and return the JSON document stored in ``filename``.

    The file is decoded explicitly as UTF-8 so that non-ASCII keys (e.g.
    region/city feature names) load identically regardless of the platform's
    default locale encoding.
    """
    with open(filename, 'r', encoding='utf-8') as fin:
        json_data = json.load(fin)
    return json_data

def wx(w_dict, kv):
    """Return the weighted contribution of a single ``(key, value)`` pair.

    The weight is looked up in ``w_dict``; keys absent from it get weight 0.0.
    """
    feature_name, feature_value = kv
    return w_dict.get(feature_name, 0.0) * feature_value

def sigmoid(x):
    """Return the logistic function 1 / (1 + e^-x) for a real ``x``.

    Computed in a numerically stable way: ``math.exp`` is only ever called on
    a non-positive argument, so large-magnitude inputs (e.g. x = -1000)
    saturate to 0.0 / 1.0 instead of raising OverflowError.
    """
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, exp(-x) can overflow; use the equivalent e^x / (1 + e^x).
    z = math.exp(x)
    return z / (1.0 + z)

def libsvm_row_to_features(row):
    """Parse one libsvm-format line into ``(label, features)``.

    The expected shape is ``<label> <key>:<value> <key>:<value> ...``.
    Tokens may be separated by any run of whitespace (spaces, tabs, trailing
    newline). Each feature token is split on its LAST colon, so keys that
    themselves contain ':' are kept intact.

    Returns:
        label: the first token, as a string (not converted); '' for a blank line.
        features: dict mapping key -> float(value); later duplicates win.

    Raises:
        ValueError: if a feature token has no ':' or its value is not a float.
    """
    items = row.split()
    if not items:
        # Blank line: no label and no features.
        return '', {}
    label = items[0]
    features = {}
    for kv in items[1:]:
        k, v = kv.rsplit(':', 1)
        features[k] = float(v)
    return label, features

class LrModel:
    """Logistic-regression scorer backed by a JSON weight table.

    ``w_dict`` is whatever ``load_json`` returns for the given file;
    presumably a mapping of feature name -> float weight, with the intercept
    stored under the key ``'bias'`` (TODO confirm against the model exporter).
    """

    def __init__(self, w_json_file):
        self.w_dict = load_json(w_json_file)

    def predict_h(self, features):
        """Return the linear term: the sum of weight * value over ``features``.

        Features with no entry in ``w_dict`` contribute 0.0. The ``'bias'``
        intercept is added separately by ``predict``.
        """
        weights = self.w_dict
        return sum(weights.get(name, 0.0) * value
                   for name, value in features.items())

    def predict(self, features):
        """Return sigmoid(linear term + bias) for one feature dict."""
        linear = self.predict_h(features)
        intercept = self.w_dict.get('bias', 0.0)
        return sigmoid(linear + intercept)

def test():
    """Regression-check LrModel scores against stored reference values.

    Loads the model from a fixed relative path, scores each libsvm row,
    prints (reference, predicted, difference), and asserts that every
    absolute difference is below the tolerance.
    """
    # Absolute tolerance (1e-5). Reference scores are rounded to 6 decimal
    # places (max rounding error 5e-7), so this leaves headroom.
    tolerance = 1e-5
    lr_model = LrModel('model/ad_out_v2_model_v1.day.json')
    rows = [
        (0.279004,'0 u_brand#vivo:1 u_device#V1829A:1 u_system#Android:1 u_system_ver#Android10:1 i_id#17015839:1 i_up_id#24811642:1 i_title_len#5:1 i_play_len#8:1 i_days_since_upload#4:1 ctx_week#3:1 ctx_hour#8:1 ctx_region#山西:1 ctx_city#临汾:1 u_3month_exp_cnt#4:1 u_3month_click_cnt#4:1 u_3month_share_cnt#2:1 u_3month_return_cnt#6:1 i_1day_exp_cnt#16:1 i_1day_click_cnt#15:1 i_1day_share_cnt#12:1 i_1day_return_cnt#15:1 i_3day_exp_cnt#18:1 i_3day_click_cnt#17:1 i_3day_share_cnt#14:1 i_3day_return_cnt#15:1 i_7day_exp_cnt#18:1 i_7day_click_cnt#18:1 i_7day_share_cnt#14:1 i_7day_return_cnt#15:1 i_3month_exp_cnt#19:1 i_3month_click_cnt#18:1 i_3month_share_cnt#14:1 i_3month_return_cnt#16:1 u_ctr_3month:0.066667 u_str_3month:0.0375 u_rov_3month:0.283333 u_ros_3month:1.0 i_ctr_1day:0.070518 i_str_1day:0.007219 i_rov_1day:0.044376 i_ros_1day:0.871681 i_ctr_3day:0.070359 i_str_3day:0.007297 i_rov_3day:0.017242 i_ros_3day:0.335814 i_ctr_7day:0.070245 i_str_7day:0.007044 i_rov_7day:0.012268 i_ros_7day:0.247943 i_ctr_3month:0.06989 i_str_3month:0.007203 i_rov_3month:0.012624 i_ros_3month:0.250784'),
        (0.454255,'1 u_brand#vivo:1 u_device#V2230A:1 u_system#Android:1 u_system_ver#Android13:1 i_id#17141266:1 i_up_id#65303321:1 i_title_len#5:1 i_play_len#8:1 i_days_since_upload#3:1 ctx_week#3:1 ctx_hour#17:1 ctx_region#河北:1 ctx_city#邯郸:1 u_3month_exp_cnt#3:1 u_3month_click_cnt#3:1 i_1day_exp_cnt#19:1 i_1day_click_cnt#19:1 i_1day_share_cnt#16:1 i_1day_return_cnt#19:1 i_3day_exp_cnt#21:1 i_3day_click_cnt#21:1 i_3day_share_cnt#18:1 i_3day_return_cnt#20:1 i_7day_exp_cnt#22:1 i_7day_click_cnt#21:1 i_7day_share_cnt#18:1 i_7day_return_cnt#20:1 i_3month_exp_cnt#22:1 i_3month_click_cnt#21:1 i_3month_share_cnt#18:1 i_3month_return_cnt#20:1 u_ctr_3month:0.1 i_ctr_1day:0.078281 i_str_1day:0.009801 i_rov_1day:0.066601 i_ros_1day:0.868055 i_ctr_3day:0.076168 i_str_3day:0.010706 i_rov_3day:0.036384 i_ros_3day:0.446198 i_ctr_7day:0.07423 i_str_7day:0.011886 i_rov_7day:0.022732 i_ros_7day:0.257645 i_ctr_3month:0.07423 i_str_3month:0.011886 i_rov_3month:0.022732 i_ros_3month:0.257645'),
        # Row with only unknown features: its score should equal the
        # bias-only score of an empty feature dict.
        (lr_model.predict({}), '0 dsiaod:1 dsaodadsa:1.2'),
    ]
    for std_score, row in rows:
        _label, features = libsvm_row_to_features(row)
        score = lr_model.predict(features)
        score_diff = std_score - score
        print(std_score, score, score_diff)
        assert abs(score_diff) < tolerance, (
            'score mismatch for row %r...: expected %r, got %r'
            % (row[:40], std_score, score))

# Run the regression check only when executed as a script, not on import.
if __name__ == '__main__':
    test()