|
@@ -0,0 +1,339 @@
|
|
|
+package com.tzld.piaoquan.recommend.server.framework.recaller;
|
|
|
+
|
|
|
+
|
|
|
+import com.google.common.base.Function;
|
|
|
+import com.google.common.base.Optional;
|
|
|
+import com.google.common.base.Predicate;
|
|
|
+import com.google.common.collect.FluentIterable;
|
|
|
+import com.google.common.collect.Lists;
|
|
|
+import com.google.common.collect.Maps;
|
|
|
+
|
|
|
+
|
|
|
+import com.tzld.piaoquan.recommend.server.common.base.RankItem;
|
|
|
+import com.tzld.piaoquan.recommend.server.framework.candidiate.*;
|
|
|
+import com.tzld.piaoquan.recommend.server.framework.common.User;
|
|
|
+import com.tzld.piaoquan.recommend.server.framework.recaller.provider.ItemProvider;
|
|
|
+import com.tzld.piaoquan.recommend.server.framework.recaller.provider.QueueProvider;
|
|
|
+import com.tzld.piaoquan.recommend.server.gen.recommend.RecommendRequest;
|
|
|
+import org.apache.commons.collections4.CollectionUtils;
|
|
|
+import org.apache.commons.lang.exception.ExceptionUtils;
|
|
|
+import org.slf4j.Logger;
|
|
|
+import org.slf4j.LoggerFactory;
|
|
|
+
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.Arrays;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.concurrent.Callable;
|
|
|
+import java.util.concurrent.ExecutionException;
|
|
|
+import java.util.concurrent.ExecutorService;
|
|
|
+import java.util.concurrent.Executors;
|
|
|
+import java.util.concurrent.Future;
|
|
|
+import java.util.concurrent.TimeUnit;
|
|
|
+
|
|
|
+
|
|
|
+public class BaseRecaller<InMemoryItem> {
|
|
|
+
|
|
|
+ private static final Logger LOGGER = LoggerFactory.getLogger(BaseRecaller.class);
|
|
|
+ private static final long DEFAULT_QUEUE_LOAD_TIMEOUT = 150; // ms
|
|
|
+ private static final long DEFAULT_PARALLEL_FILTER_TIMEOUT = 200; // ms
|
|
|
+ private static final ExecutorService filterExecutorService = Executors.newFixedThreadPool(128);
|
|
|
+ private static final ExecutorService fetchQueueExecutorService = Executors.newFixedThreadPool(128);
|
|
|
+
|
|
|
+ protected final ItemProvider<InMemoryItem> itemProvider;
|
|
|
+ private final QueueProvider<InMemoryItem> queueProvider;
|
|
|
+ private final FilterConfig filterConfig ;
|
|
|
+ private final long QUEUE_LOAD_TIMEOUT;
|
|
|
+
|
|
|
+ public BaseRecaller(ItemProvider<InMemoryItem> itemProvider, QueueProvider<InMemoryItem> queueProvider, FilterConfig filterConfig) {
|
|
|
+ this(itemProvider, queueProvider, filterConfig, DEFAULT_QUEUE_LOAD_TIMEOUT);
|
|
|
+ }
|
|
|
+
|
|
|
+ public BaseRecaller(ItemProvider<InMemoryItem> itemProvider, QueueProvider<InMemoryItem> queueProvider,
|
|
|
+ FilterConfig filterConfig, long queueLoadTimeout) {
|
|
|
+ this.itemProvider = itemProvider;
|
|
|
+ this.queueProvider = queueProvider;
|
|
|
+ this.filterConfig = filterConfig;
|
|
|
+ this.QUEUE_LOAD_TIMEOUT = queueLoadTimeout;
|
|
|
+ }
|
|
|
+
|
|
|
+ public String extractItemId(Entry<InMemoryItem> entry) {
|
|
|
+ return entry.id;
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ public boolean isValidItem(InMemoryItem item){
|
|
|
+ return item!= null;
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ public ItemProvider<InMemoryItem> getItemProvider() {
|
|
|
+ return itemProvider;
|
|
|
+ }
|
|
|
+
|
|
|
+ public QueueProvider<InMemoryItem> getQueueProvider() {
|
|
|
+ return queueProvider;
|
|
|
+ }
|
|
|
+
|
|
|
+ public Map<Candidate, Queue<InMemoryItem>> getQueue(String queueStr) throws Exception {
|
|
|
+ Candidate candidate = new Candidate();
|
|
|
+ candidate.setCandidateKey(queueStr);
|
|
|
+ return loadQueues(Arrays.asList(candidate));
|
|
|
+ }
|
|
|
+
|
|
|
+ public Optional<InMemoryItem> getItem(String itemId) throws Exception {
|
|
|
+ return itemProvider.get(itemId);
|
|
|
+ }
|
|
|
+
|
|
|
+ public long getItemDBSize() {
|
|
|
+ return this.itemProvider.dbSize();
|
|
|
+ }
|
|
|
+
|
|
|
+ public long getIndexDBSize() {
|
|
|
+ return this.queueProvider.dbSize();
|
|
|
+ }
|
|
|
+
|
|
|
+ public long getQueueTTL(String queueNameStr, Map<String, Long> cacheRules) {
|
|
|
+ QueueName name = QueueName.fromString(queueNameStr);
|
|
|
+
|
|
|
+ for (String match : name.getMatches()) {
|
|
|
+ if (cacheRules.containsKey(match)) {
|
|
|
+ return cacheRules.get(match);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return name.getTTL();
|
|
|
+ }
|
|
|
+
|
|
|
+ private Map<String, Double> listToMap(List<String> list) {
|
|
|
+ Map<String, Double> map = new HashMap<String, Double>();
|
|
|
+ if (list != null) {
|
|
|
+ for (String elem : list) {
|
|
|
+ map.put(elem, 1.0);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return map;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 把queue中的entry 放入RankItem中
|
|
|
+ *
|
|
|
+ * @param entries
|
|
|
+ * @param candidate
|
|
|
+ * @param user
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ private List<RankItem> toHits(final Iterable<Entry<InMemoryItem>> entries, final Candidate candidate, final User user) {
|
|
|
+
|
|
|
+ List<RankItem> result = new ArrayList<RankItem>();
|
|
|
+ for (Entry entry : entries) {
|
|
|
+ RankItem item = new RankItem();
|
|
|
+ item.setId(extractItemId(entry));
|
|
|
+ item.putRankerScore("L3Score", (Double) entry.scores.get("ordering"));
|
|
|
+ item.setQueue(candidate.getMergeQueueName()); // merge queue
|
|
|
+
|
|
|
+ CandidateInfo candidateInfo = new CandidateInfo();
|
|
|
+ candidateInfo.setCandidateQueueName(candidate.getCandidateKey());
|
|
|
+ candidateInfo.setCandidate(candidate);
|
|
|
+
|
|
|
+ result.add(item);
|
|
|
+ }
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ // 读取redis中的数据放入queue中
|
|
|
+ public Map<Candidate, Queue<InMemoryItem>> loadQueues(List<Candidate> candidates) {
|
|
|
+
|
|
|
+ // update queueName
|
|
|
+ // final Map<String, Long> cacheRules = getCacheRulesConfig();
|
|
|
+ Iterable<Candidate> updateCandidates = FluentIterable.from(candidates).transform(new Function<Candidate, Candidate>() {
|
|
|
+ @Override
|
|
|
+ public Candidate apply(Candidate candidate) {
|
|
|
+ try {
|
|
|
+ long ttl = QueueName.DEFAULT_LOCAL_CACHE_TTL;
|
|
|
+ candidate.setCandidateQueueName(QueueName.fromString(candidate.getCandidateKey(), ttl));
|
|
|
+ } catch (Exception e) {
|
|
|
+ candidate.setCandidateQueueName(null);
|
|
|
+ LOGGER.error("error parse QueueName [{}]", candidate.getCandidateKey());
|
|
|
+ }
|
|
|
+ return candidate;
|
|
|
+ }
|
|
|
+ });
|
|
|
+
|
|
|
+ // parse queues
|
|
|
+ Iterable<QueueName> queueNames = FluentIterable.from(updateCandidates).transform(new Function<Candidate, QueueName>() {
|
|
|
+ @Override
|
|
|
+ public QueueName apply(Candidate candidate) {
|
|
|
+ return candidate.getCandidateQueueName();
|
|
|
+ }
|
|
|
+ });
|
|
|
+
|
|
|
+ // parallel load queues
|
|
|
+ // redis 或者缓存获取index
|
|
|
+ Map<QueueName, Queue<InMemoryItem>> queues = Maps.newConcurrentMap();
|
|
|
+ try {
|
|
|
+ queues = queueProvider.loads(Lists.newArrayList(queueNames), QUEUE_LOAD_TIMEOUT, TimeUnit.MILLISECONDS);
|
|
|
+ } catch (Exception e) {
|
|
|
+ LOGGER.error("load queue occur error [{}]", ExceptionUtils.getFullStackTrace(e));
|
|
|
+ }
|
|
|
+
|
|
|
+ // parse candidate map
|
|
|
+ Map<Candidate, Queue<InMemoryItem>> candidateQueueMap = Maps.newConcurrentMap();
|
|
|
+ for (Candidate candidate : updateCandidates) {
|
|
|
+ QueueName name = candidate.getCandidateQueueName();
|
|
|
+ if (queues.containsKey(name) && queues.get(name) != null) {
|
|
|
+ candidateQueueMap.put(candidate, queues.get(name));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return candidateQueueMap;
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ /**
|
|
|
+ * recall
|
|
|
+ * 1. construct recall filter
|
|
|
+ * 2. Redis并行召回
|
|
|
+ * 3. do filter
|
|
|
+ *
|
|
|
+ * @param requestData
|
|
|
+ * @param user
|
|
|
+ * @param requestIndex
|
|
|
+ * @param recallCandidates
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ public List<RankItem> recalling(final RecommendRequest requestData, final User user, int requestIndex, List<Candidate> recallCandidates) {
|
|
|
+
|
|
|
+ long startTime = System.currentTimeMillis();
|
|
|
+ final RecallFilter<InMemoryItem> recallFilter = new RecallFilter<InMemoryItem>(this.filterConfig, requestData, user, requestIndex);
|
|
|
+
|
|
|
+ // load queue
|
|
|
+ long queueLoadStartTime = System.currentTimeMillis();
|
|
|
+
|
|
|
+ // load from redis
|
|
|
+ List<Callable<Map<Candidate, Queue<InMemoryItem>>>> fetchQueueCalls = Lists.newArrayList();
|
|
|
+ fetchQueueCalls.add(new Callable<Map<Candidate, Queue<InMemoryItem>>>() {
|
|
|
+ @Override
|
|
|
+ public Map<Candidate, Queue<InMemoryItem>> call() throws Exception {
|
|
|
+ boolean isFromRedis = true;
|
|
|
+ return obtainQueue(recallCandidates, requestData, user, isFromRedis);
|
|
|
+ }
|
|
|
+ });
|
|
|
+
|
|
|
+
|
|
|
+ List<Future<Map<Candidate, Queue<InMemoryItem>>>> fetchQueueFutures = null;
|
|
|
+ try {
|
|
|
+ fetchQueueFutures = fetchQueueExecutorService.invokeAll(fetchQueueCalls, DEFAULT_QUEUE_LOAD_TIMEOUT,
|
|
|
+ TimeUnit.MILLISECONDS);
|
|
|
+ } catch (InterruptedException e) {
|
|
|
+ LOGGER.error("[fetch queue error] inter fail: {}", ExceptionUtils.getStackTrace(e));
|
|
|
+ } catch (Exception e) {
|
|
|
+ LOGGER.error("[fetch queue error] ex", ExceptionUtils.getStackTrace(e));
|
|
|
+ }
|
|
|
+
|
|
|
+ Map<Candidate, Queue<InMemoryItem>> candidateQueueMap = Maps.newHashMap();
|
|
|
+ if (CollectionUtils.isNotEmpty(fetchQueueFutures)) {
|
|
|
+ for (Future<Map<Candidate, Queue<InMemoryItem>>> future : fetchQueueFutures) {
|
|
|
+ if (future.isDone() && !future.isCancelled()) {
|
|
|
+ Map<Candidate, Queue<InMemoryItem>> result = null;
|
|
|
+ try {
|
|
|
+ result = future.get();
|
|
|
+ } catch (InterruptedException e) {
|
|
|
+ LOGGER.error("[fetch queue error] InterruptedException {}", ExceptionUtils.getStackTrace(e));
|
|
|
+ } catch (ExecutionException e) {
|
|
|
+ LOGGER.error("[fetch queue error] ex {}", ExceptionUtils.getStackTrace(e));
|
|
|
+ }
|
|
|
+ if (result != null) {
|
|
|
+ candidateQueueMap.putAll(result);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // do filter
|
|
|
+ // 执行 recall filter配置文件中的方法
|
|
|
+ long filterStartTime = System.currentTimeMillis();
|
|
|
+
|
|
|
+ List<Map.Entry<Candidate, Queue<InMemoryItem>>> batch = new ArrayList<Map.Entry<Candidate, Queue<InMemoryItem>>>();
|
|
|
+ final List<Callable<List<RankItem>>> callables = new ArrayList<Callable<List<RankItem>>>();
|
|
|
+ int expectedRecallSum = 0;
|
|
|
+ for (final Map.Entry<Candidate, Queue<InMemoryItem>> entry : candidateQueueMap.entrySet()) {
|
|
|
+ callables.add(new Callable<List<RankItem>>() {
|
|
|
+ @Override
|
|
|
+ public List<RankItem> call() throws Exception {
|
|
|
+ List<RankItem> candidateHits = new ArrayList<RankItem>();
|
|
|
+ final Candidate candidate = entry.getKey();
|
|
|
+ try {
|
|
|
+ // 1. filter
|
|
|
+ Iterable<Entry<InMemoryItem>> entries = FluentIterable.from(entry.getValue()).filter(new Predicate<Entry<InMemoryItem>>() {
|
|
|
+ @Override
|
|
|
+ public boolean apply(Entry<InMemoryItem> entry) {
|
|
|
+ return isValidItem(entry.item) &&
|
|
|
+ recallFilter.predicate(candidate, entry.item);
|
|
|
+ }
|
|
|
+ }).limit(candidate.getCandidateNum());
|
|
|
+
|
|
|
+ // 2. toHits
|
|
|
+ candidateHits.addAll(toHits(entries, candidate, user));
|
|
|
+
|
|
|
+ // debug log for tracing
|
|
|
+ LOGGER.debug("recalled candidate [{}], queue length [{}], expected [{}], hit [{}]",
|
|
|
+ new Object[]{candidate.getCandidateKey(), entry.getValue().size(), candidate.getCandidateNum(), candidateHits.size()});
|
|
|
+ } catch (Exception e) {
|
|
|
+ LOGGER.error("recall filter queue occur error, queue [{}], error: [{}]", candidate.toString(), ExceptionUtils.getFullStackTrace(e));
|
|
|
+ }
|
|
|
+
|
|
|
+ return candidateHits;
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ Map<String, RankItem> hits = new HashMap<String, RankItem>();
|
|
|
+ try {
|
|
|
+ List<Future<List<RankItem>>> futures = filterExecutorService.invokeAll(callables, DEFAULT_PARALLEL_FILTER_TIMEOUT, TimeUnit.MILLISECONDS);
|
|
|
+
|
|
|
+ for (Future<List<RankItem>> future : futures) {
|
|
|
+ try {
|
|
|
+ if (future.isDone() && !future.isCancelled() && future.get() != null) {
|
|
|
+ List<RankItem> part = future.get();
|
|
|
+ // add to result
|
|
|
+ for (RankItem item : part) {
|
|
|
+ // merge candidate Info
|
|
|
+ if (hits.containsKey(item.getId())) {
|
|
|
+ hits.get(item.getId()).addToCandidateInfoList(item.getCandidateInfo());
|
|
|
+ } else {
|
|
|
+ hits.put(item.getId(), item);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ LOGGER.error("parallel recall filter Canceled {} ", requestData.getRequestId());
|
|
|
+ }
|
|
|
+ } catch (Exception e) {
|
|
|
+ LOGGER.error("parallel recall filter occur error, uid: [{}], Exception [{}]",
|
|
|
+ requestData.getRequestId(), ExceptionUtils.getFullStackTrace(e));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } catch (Exception e) {
|
|
|
+ LOGGER.error("parallel recall filter occur error, uid: [{}], Exception [{}]",
|
|
|
+ requestData.getRequestId(), ExceptionUtils.getFullStackTrace(e));
|
|
|
+ }
|
|
|
+
|
|
|
+ List<RankItem> result = new ArrayList<RankItem>(hits.values());
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+
|
|
|
+ private Map<Candidate, Queue<InMemoryItem>> obtainQueue(List<Candidate> refactorCandidates, RecommendRequest requestData, User user, boolean isFromRedis) {
|
|
|
+ if (isFromRedis) {
|
|
|
+ return loadQueues(refactorCandidates);
|
|
|
+ } else {
|
|
|
+ return fetchQueues(refactorCandidates, requestData, user);
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ protected Map<Candidate, Queue<InMemoryItem>> fetchQueues(List<Candidate> candidates, RecommendRequest requestData, User user) {
|
|
|
+
|
|
|
+ return null;
|
|
|
+ }
|
|
|
+}
|