Ver código fonte

Rename widedeep_v13_3 to widedeep_v13_4

StrayWarrior 6 dias atrás
pai
commit
29c39056cf
1 arquivo alterado com 12 adições e 2 exclusões
  1. 12 2
      widedeep_v13_4.py

+ 12 - 2
widedeep_v13_3.py → widedeep_v13_4.py

@@ -15,6 +15,8 @@
 9.调整embedding variable (PAI平台目前还不支持)
 """
 
+import os
+os.environ['PROCESSOR_TEST'] = "1"
 import re
 from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig
 from easy_rec.python.protos.train_pb2 import TrainConfig
@@ -61,7 +63,8 @@ sparse_features = [
 ]
 tag_features = [
     "user_vid_return_tags_2h", "user_vid_return_tags_1d", "user_vid_return_tags_3d",
-    "user_vid_return_tags_7d", "user_vid_return_tags_14d"
+    "user_vid_return_tags_7d", "user_vid_return_tags_14d",
+    "user_conver_ad_class"
 ]
 seq_features = [
     "user_cid_click_list", "user_cid_conver_list"
@@ -74,7 +77,7 @@ input_type_map = {
 }
 
 use_ev_features = [
-    "cid", "adid", "adverid", "vid"
+    # PAI上运行的TF疑似不支持
 ]
 
 bucket_size_map = {
@@ -91,6 +94,7 @@ bucket_size_map = {
     'root_source_channel': 1000,
     'is_first_layer': 100,
     'user_has_conver_1y': 100,
+    'user_conver_ad_class': 10000
 }
 
 def create_config():
@@ -163,6 +167,8 @@ def create_config():
     for name in dense_features:
         feature_config = FeatureConfig()
         feature_config.input_names.append(name)
+        if name not in input_fields:
+            raise Exception(f"{name} not found in input fields")
         feature_config.feature_type = FeatureConfig.RawFeature
         feature_config.boundaries.extend(boundaries)
         feature_config.embedding_dim = 6
@@ -172,6 +178,8 @@ def create_config():
     for name in sparse_features:
         feature_config = FeatureConfig()
         feature_config.input_names.append(name)
+        if name not in input_fields:
+            raise Exception(f"{name} not found in input fields")
         feature_config.feature_type = FeatureConfig.IdFeature
         # 只有INT64类型的特征才能使用embedding variable特性
         if name in use_ev_features:
@@ -187,6 +195,8 @@ def create_config():
     for name in tag_features + seq_features:
         feature_config = FeatureConfig()
         feature_config.input_names.append(name)
+        if name not in input_fields:
+            raise Exception(f"{name} not found in input fields")
         feature_config.feature_type = FeatureConfig.TagFeature
         feature_config.hash_bucket_size = bucket_size_map.get(name, 1000000)
         feature_config.embedding_dim = 6