|
@@ -25,6 +25,7 @@ import sys
|
|
|
import paddle.distributed.fleet as fleet
|
|
|
import paddle.distributed.fleet.base.role_maker as role_maker
|
|
|
import paddle
|
|
|
+from paddle.base.executor import FetchHandler
|
|
|
import threading
|
|
|
|
|
|
import warnings
|
|
@@ -49,7 +50,7 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
import json
|
|
|
|
|
|
-class InferenceFetchHandler(object):
|
|
|
+class InferenceFetchHandler(FetchHandler):
|
|
|
def __init__(self, output_file, batch_size=1000):
|
|
|
self.output_file = output_file
|
|
|
self.batch_size = batch_size
|
|
@@ -327,7 +328,7 @@ class Main(object):
|
|
|
debug=debug,
|
|
|
fetch_handler=fetch_handler)
|
|
|
fetch_handler.finish()
|
|
|
-
|
|
|
+
|
|
|
|
|
|
|
|
|
def heter_train_loop(self, epoch):
|