|
@@ -15,6 +15,8 @@
|
|
|
9.调整embedding variable (PAI平台目前还不支持)
|
|
|
"""
|
|
|
|
|
|
+import os
|
|
|
+os.environ['PROCESSOR_TEST'] = "1"
|
|
|
import re
|
|
|
from easy_rec.python.protos.pipeline_pb2 import EasyRecConfig
|
|
|
from easy_rec.python.protos.train_pb2 import TrainConfig
|
|
@@ -61,7 +63,8 @@ sparse_features = [
|
|
|
]
|
|
|
tag_features = [
|
|
|
"user_vid_return_tags_2h", "user_vid_return_tags_1d", "user_vid_return_tags_3d",
|
|
|
- "user_vid_return_tags_7d", "user_vid_return_tags_14d"
|
|
|
+ "user_vid_return_tags_7d", "user_vid_return_tags_14d",
|
|
|
+ "user_conver_ad_class"
|
|
|
]
|
|
|
seq_features = [
|
|
|
"user_cid_click_list", "user_cid_conver_list"
|
|
@@ -74,7 +77,7 @@ input_type_map = {
|
|
|
}
|
|
|
|
|
|
use_ev_features = [
|
|
|
- "cid", "adid", "adverid", "vid"
|
|
|
+ # Intentionally left empty: the TensorFlow build running on PAI appears not to support embedding variables
|
|
|
]
|
|
|
|
|
|
bucket_size_map = {
|
|
@@ -91,6 +94,7 @@ bucket_size_map = {
|
|
|
'root_source_channel': 1000,
|
|
|
'is_first_layer': 100,
|
|
|
'user_has_conver_1y': 100,
|
|
|
+ 'user_conver_ad_class': 10000
|
|
|
}
|
|
|
|
|
|
def create_config():
|
|
@@ -163,6 +167,8 @@ def create_config():
|
|
|
for name in dense_features:
|
|
|
feature_config = FeatureConfig()
|
|
|
feature_config.input_names.append(name)
|
|
|
+ if name not in input_fields:
|
|
|
+ raise Exception(f"{name} not found in input fields")
|
|
|
feature_config.feature_type = FeatureConfig.RawFeature
|
|
|
feature_config.boundaries.extend(boundaries)
|
|
|
feature_config.embedding_dim = 6
|
|
@@ -172,6 +178,8 @@ def create_config():
|
|
|
for name in sparse_features:
|
|
|
feature_config = FeatureConfig()
|
|
|
feature_config.input_names.append(name)
|
|
|
+ if name not in input_fields:
|
|
|
+ raise Exception(f"{name} not found in input fields")
|
|
|
feature_config.feature_type = FeatureConfig.IdFeature
|
|
|
# 只有INT64类型的特征才能使用embedding variable特性
|
|
|
if name in use_ev_features:
|
|
@@ -187,6 +195,8 @@ def create_config():
|
|
|
for name in tag_features + seq_features:
|
|
|
feature_config = FeatureConfig()
|
|
|
feature_config.input_names.append(name)
|
|
|
+ if name not in input_fields:
|
|
|
+ raise Exception(f"{name} not found in input fields")
|
|
|
feature_config.feature_type = FeatureConfig.TagFeature
|
|
|
feature_config.hash_bucket_size = bucket_size_map.get(name, 1000000)
|
|
|
feature_config.embedding_dim = 6
|