rov_train_paddle.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import pandas as pd
  2. import numpy as np
  3. from paddle.io import Dataset
  4. class RovDataset(Dataset):
  5. def clean_data(df):
  6. #y = df['futre7dayreturn'].apply(lambda x: np.log(df['futre7dayreturn']+1))
  7. y = df['futre7dayreturn']
  8. df_vids = df['videoid']
  9. #drop string
  10. #x = df.drop(['videoid', 'videotags', 'videotitle', 'videodescr', 'videodistribute_title', 'videoallwords', 'words_without_tags'], axis=1)
  11. x = df.drop(['videoid', 'videotags', 'words_without_tags', 'dt'], axis=1)
  12. #drop future
  13. #x = df.drop(['futr5viewcount', 'futr5returncount', 'futre7dayreturn'], axis=1)
  14. x = x.drop(['futre7dayreturn'], axis=1)
  15. features = list(x)
  16. drop_features = [f for f in features if (f.find('day30')!=-1 or f.find('day60')!=-1)]
  17. x = x.drop(drop_features, axis=1)
  18. features = [f for f in features if f not in drop_features]
  19. return x, y , df_vids, features
  20. def pack_result(y_, y, vid, fp):
  21. #y_ = y_.astype(int)
  22. y_.reshape(len(y_),1)
  23. df = pd.DataFrame(data=y_, columns=['score'])
  24. if len(vid) >0:
  25. df['vid'] = vid
  26. df['y'] = y
  27. df = df.sort_values(by=['score'], ascending=False)
  28. df.to_csv(fp, index=False)
  29. if __name__ == '__main__':
  30. with open(r"train_data.pickle", "rb") as input_file:
  31. train_data = cPickle.load(input_file)
  32. with open(r"predict_data.pickle", "rb") as input_file:
  33. predict_data = cPickle.load(input_file)
  34. #train
  35. x,y,_,features = clean_data(train_data)
  36. _, model, _ = train(x, y, features)
  37. with open('model.pickle','wb') as output_file:
  38. cPickle.dump(model, output_file)
  39. '''
  40. with open(r"model.pickle", "rb") as input_file:
  41. model = cPickle.load(input_file)
  42. '''
  43. x,y,vid,_ = clean_data(predict_data)
  44. y_ = model.predict(x, num_iteration=model.best_iteration)
  45. pack_result(y_, y, vid, 'pred.csv')