|
@@ -126,9 +126,18 @@ public class PAIModelV1 {
|
|
|
public List<Float> score(final List<AdRankItem> items,
|
|
public List<Float> score(final List<AdRankItem> items,
|
|
|
final Map<String, String> userFeatureMap,
|
|
final Map<String, String> userFeatureMap,
|
|
|
final Map<String, String> sceneFeatureMap) {
|
|
final Map<String, String> sceneFeatureMap) {
|
|
|
|
|
+ long totalStart = System.currentTimeMillis();
|
|
|
|
|
+ long buildUserFeatureTime = 0;
|
|
|
|
|
+ long buildItemFeatureTime = 0;
|
|
|
|
|
+ long buildRequestTime = 0;
|
|
|
|
|
+ long predictTime = 0;
|
|
|
|
|
+ long parseResponseTime = 0;
|
|
|
|
|
+
|
|
|
try {
|
|
try {
|
|
|
TFRequest request = new TFRequest();
|
|
TFRequest request = new TFRequest();
|
|
|
|
|
|
|
|
|
|
+ // 阶段1: 构建用户特征
|
|
|
|
|
+ long stageStart = System.currentTimeMillis();
|
|
|
for (String feature : sparseUserStrFeatures) {
|
|
for (String feature : sparseUserStrFeatures) {
|
|
|
String v = userFeatureMap.getOrDefault(feature, "");
|
|
String v = userFeatureMap.getOrDefault(feature, "");
|
|
|
request.addFeed(feature, TFDataType.DT_STRING, new long[]{1}, new String[]{v});
|
|
request.addFeed(feature, TFDataType.DT_STRING, new long[]{1}, new String[]{v});
|
|
@@ -144,11 +153,14 @@ public class PAIModelV1 {
|
|
|
request.addFeed(feature, TFDataType.DT_INT64, new long[]{1}, new long[]{v});
|
|
request.addFeed(feature, TFDataType.DT_INT64, new long[]{1}, new long[]{v});
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-
|
|
|
|
|
for (String feature : userFeatures) {
|
|
for (String feature : userFeatures) {
|
|
|
double v = NumberUtils.toDouble(userFeatureMap.getOrDefault(feature, "0.0"), 0.0);
|
|
double v = NumberUtils.toDouble(userFeatureMap.getOrDefault(feature, "0.0"), 0.0);
|
|
|
request.addFeed(feature.toLowerCase(), TFDataType.DT_DOUBLE, new long[]{1}, new double[]{v});
|
|
request.addFeed(feature.toLowerCase(), TFDataType.DT_DOUBLE, new long[]{1}, new double[]{v});
|
|
|
}
|
|
}
|
|
|
|
|
+ buildUserFeatureTime = System.currentTimeMillis() - stageStart;
|
|
|
|
|
+
|
|
|
|
|
+ // 阶段2: 构建广告Item特征
|
|
|
|
|
+ stageStart = System.currentTimeMillis();
|
|
|
Map<String, double[]> doubleFeed = new HashMap<>();
|
|
Map<String, double[]> doubleFeed = new HashMap<>();
|
|
|
Map<String, long[]> longFeed = new HashMap<>();
|
|
Map<String, long[]> longFeed = new HashMap<>();
|
|
|
Map<String, String[]> strFeed = new HashMap<>();
|
|
Map<String, String[]> strFeed = new HashMap<>();
|
|
@@ -184,6 +196,10 @@ public class PAIModelV1 {
|
|
|
strs[i] = v;
|
|
strs[i] = v;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+ buildItemFeatureTime = System.currentTimeMillis() - stageStart;
|
|
|
|
|
+
|
|
|
|
|
+ // 阶段3: 构建请求体
|
|
|
|
|
+ stageStart = System.currentTimeMillis();
|
|
|
for (Map.Entry<String, double[]> entry : doubleFeed.entrySet()) {
|
|
for (Map.Entry<String, double[]> entry : doubleFeed.entrySet()) {
|
|
|
request.addFeed(entry.getKey(), TFDataType.DT_DOUBLE, new long[]{items.size()}, entry.getValue());
|
|
request.addFeed(entry.getKey(), TFDataType.DT_DOUBLE, new long[]{items.size()}, entry.getValue());
|
|
|
}
|
|
}
|
|
@@ -196,13 +212,35 @@ public class PAIModelV1 {
|
|
|
request.addFeed(entry.getKey(), TFDataType.DT_STRING, new long[]{items.size()}, entry.getValue());
|
|
request.addFeed(entry.getKey(), TFDataType.DT_STRING, new long[]{items.size()}, entry.getValue());
|
|
|
}
|
|
}
|
|
|
request.addFetch("probs");
|
|
request.addFetch("probs");
|
|
|
|
|
+ buildRequestTime = System.currentTimeMillis() - stageStart;
|
|
|
|
|
+
|
|
|
|
|
+ // 阶段4: PAI-EAS 远程调用
|
|
|
|
|
+ stageStart = System.currentTimeMillis();
|
|
|
TFResponse response = client.predict(request);
|
|
TFResponse response = client.predict(request);
|
|
|
|
|
+ predictTime = System.currentTimeMillis() - stageStart;
|
|
|
|
|
+
|
|
|
|
|
+ // 阶段5: 解析响应
|
|
|
|
|
+ stageStart = System.currentTimeMillis();
|
|
|
List<Float> result = response.getFloatVals("probs");
|
|
List<Float> result = response.getFloatVals("probs");
|
|
|
|
|
+ parseResponseTime = System.currentTimeMillis() - stageStart;
|
|
|
|
|
+
|
|
|
|
|
+ long totalTime = System.currentTimeMillis() - totalStart;
|
|
|
|
|
+
|
|
|
|
|
+ // 输出耗时分析日志
|
|
|
|
|
+ LOGGER.info("PAIModelV1.score cost: total={}ms, itemSize={}, buildUserFeature={}ms, " +
|
|
|
|
|
+ "buildItemFeature={}ms, buildRequest={}ms, predict={}ms, parseResponse={}ms",
|
|
|
|
|
+ totalTime, items.size(), buildUserFeatureTime, buildItemFeatureTime,
|
|
|
|
|
+ buildRequestTime, predictTime, parseResponseTime);
|
|
|
|
|
+
|
|
|
if (!CollectionUtils.isEmpty(result)) {
|
|
if (!CollectionUtils.isEmpty(result)) {
|
|
|
return result;
|
|
return result;
|
|
|
}
|
|
}
|
|
|
} catch (Exception e) {
|
|
} catch (Exception e) {
|
|
|
- LOGGER.error("PAIModel score error", e);
|
|
|
|
|
|
|
+ long totalTime = System.currentTimeMillis() - totalStart;
|
|
|
|
|
+ LOGGER.error("PAIModelV1.score error: total={}ms, itemSize={}, buildUserFeature={}ms, " +
|
|
|
|
|
+ "buildItemFeature={}ms, buildRequest={}ms, predict={}ms, parseResponse={}ms",
|
|
|
|
|
+ totalTime, items.size(), buildUserFeatureTime, buildItemFeatureTime,
|
|
|
|
|
+ buildRequestTime, predictTime, parseResponseTime, e);
|
|
|
}
|
|
}
|
|
|
return new ArrayList<>(Collections.nCopies(items.size(), 0.0f));
|
|
return new ArrayList<>(Collections.nCopies(items.size(), 0.0f));
|
|
|
}
|
|
}
|