Browse Source

add infer code

often 5 months ago
parent
commit
c94fe349e7

+ 15 - 1
recommend-model-produce/src/main/python/tools/static_ps_infer_v2.py

@@ -48,7 +48,6 @@ logger = logging.getLogger(__name__)
 
 
 import json
 import json
 
 
-# 创建推理结果处理器类
 class InferenceFetchHandler(object):
 class InferenceFetchHandler(object):
     def __init__(self, output_file, batch_size=1000):
     def __init__(self, output_file, batch_size=1000):
         self.output_file = output_file
         self.output_file = output_file
@@ -56,6 +55,11 @@ class InferenceFetchHandler(object):
         self.current_batch = []
         self.current_batch = []
         self.total_samples = 0
         self.total_samples = 0
         
         
+        # 添加 Paddle 需要的属性
+        self.period_secs = 60  # 设置默认的周期时间(秒)
+        self.done_event = threading.Event()  # 添加完成事件
+        self.terminal_event = threading.Event()  # 添加终止事件
+        
         # 创建输出目录(如果不存在)
         # 创建输出目录(如果不存在)
         output_dir = os.path.dirname(output_file)
         output_dir = os.path.dirname(output_file)
         if not os.path.exists(output_dir):
         if not os.path.exists(output_dir):
@@ -96,6 +100,16 @@ class InferenceFetchHandler(object):
         if self.current_batch:
         if self.current_batch:
             self._write_batch()
             self._write_batch()
             logger.info(f"Final save: total {self.total_samples} samples saved to {self.output_file}")
             logger.info(f"Final save: total {self.total_samples} samples saved to {self.output_file}")
+        self.done_event.set()
+    
+    def stop(self):
+        """停止处理"""
+        self.terminal_event.set()
+        self.finish()
+
+    def wait(self):
+        """等待处理完成"""
+        self.done_event.wait()
 
 
 
 
 def parse_args():
 def parse_args():