|
@@ -21,22 +21,25 @@ class StaticModel():
|
|
|
|
|
|
def create_feeds(self, is_infer=False):
|
|
def create_feeds(self, is_infer=False):
|
|
# 定义输入数据占位符
|
|
# 定义输入数据占位符
|
|
- # sample_id = paddle.static.data(
|
|
|
|
- # name="sample_id", shape=[-1, 1], dtype='int64')
|
|
|
|
|
|
+
|
|
feeds_list = []
|
|
feeds_list = []
|
|
if not is_infer:
|
|
if not is_infer:
|
|
label = paddle.static.data(
|
|
label = paddle.static.data(
|
|
name="label", shape=[-1, 1], dtype='float32')
|
|
name="label", shape=[-1, 1], dtype='float32')
|
|
- feeds_list.append(label)
|
|
|
|
-
|
|
|
|
- left_features = paddle.static.data(
|
|
|
|
- name="left_features", shape=[-1, self.feature_num], dtype='float32')
|
|
|
|
- feeds_list.append(left_features)
|
|
|
|
- right_features = paddle.static.data(
|
|
|
|
- name="right_features", shape=[-1, self.feature_num], dtype='float32')
|
|
|
|
- feeds_list.append(right_features)
|
|
|
|
-
|
|
|
|
-
|
|
|
|
|
|
+ feeds_list.append(label)
|
|
|
|
+ left_features = paddle.static.data(
|
|
|
|
+ name="left_features", shape=[-1, self.feature_num], dtype='float32')
|
|
|
|
+ feeds_list.append(left_features)
|
|
|
|
+ right_features = paddle.static.data(
|
|
|
|
+ name="right_features", shape=[-1, self.feature_num], dtype='float32')
|
|
|
|
+ feeds_list.append(right_features)
|
|
|
|
+ else:
|
|
|
|
+ sample_id = paddle.static.data(
|
|
|
|
+ name="sample_id", shape=[-1, 1], dtype='int64')
|
|
|
|
+ feeds_list.append(label)
|
|
|
|
+ left_features = paddle.static.data(
|
|
|
|
+ name="left_features", shape=[-1, self.feature_num], dtype='float32')
|
|
|
|
+ feeds_list.append(left_features)
|
|
|
|
|
|
|
|
|
|
return feeds_list
|
|
return feeds_list
|
|
@@ -52,10 +55,11 @@ class StaticModel():
|
|
)
|
|
)
|
|
|
|
|
|
if is_infer:
|
|
if is_infer:
|
|
- left_features = input
|
|
|
|
|
|
+ sample_id,left_features = input
|
|
left_vec = dssm_model(left_features,None,is_infer=True)
|
|
left_vec = dssm_model(left_features,None,is_infer=True)
|
|
self.inference_target_var = left_vec
|
|
self.inference_target_var = left_vec
|
|
fetch_dict = {
|
|
fetch_dict = {
|
|
|
|
+ 'sample_id': sample_id
|
|
'left_vector': left_vec
|
|
'left_vector': left_vec
|
|
}
|
|
}
|
|
return fetch_dict
|
|
return fetch_dict
|
|
@@ -68,8 +72,6 @@ class StaticModel():
|
|
self.left_vector = left_vec
|
|
self.left_vector = left_vec
|
|
self.right_vector = right_vec
|
|
self.right_vector = right_vec
|
|
|
|
|
|
-
|
|
|
|
-
|
|
|
|
# 计算损失
|
|
# 计算损失
|
|
# 使用带margin的二元交叉熵损失
|
|
# 使用带margin的二元交叉熵损失
|
|
pos_mask = paddle.cast(label > 0.5, 'float32')
|
|
pos_mask = paddle.cast(label > 0.5, 'float32')
|