Parcourir la source

feat:VOV模型

zhaohaipeng il y a 9 mois
Parent
commit
6f862222ad
2 fichiers modifiés avec 11 ajouts et 3 suppressions
  1. 8 3
      XGB/vov_xgboost_train.py
  2. 3 0
      client/ODPSClient.py

+ 8 - 3
XGB/vov_xgboost_train.py

@@ -34,11 +34,16 @@ def get_partition_df(table, dt):
     logger.info(f"开始下载: {table} -- {dt} 的数据")
     df = pd.DataFrame()
     try:
+        table_info = odps_client.get_table(table)
+        col_names = [col.name for col in table_info.table_schema.columns]
         download_session = odps_client.get_download_session(table, dt)
         logger.info(f"表: {table} 中的分区 {dt}, 共有 {download_session.count} 条数据")
-        with download_session.open_arrow_reader(0, download_session.count) as reader:
-            # 将所有数据加载到 DataFrame 中
-            df = pd.concat([batch.to_pandas() for batch in reader])
+        with download_session.open_record_reader(0, download_session.count) as reader:
+            records = []
+            for record in reader:
+                records.append(record.values)  # 获取每一行的值
+            # 使用元数据中的列名
+            df = pd.DataFrame(records, columns=col_names)
     except Exception as e:
         logger.error(f"下载 {table} -- {dt} 的数据异常: ", e)
 

+ 3 - 0
client/ODPSClient.py

@@ -33,6 +33,9 @@ class ODPSClient(object):
                 result.append(record)
         return result
 
+    def get_table(self, table: str):
+        return self.odps.get_table(table)
+
     def get_download_session(self, table: str, dt: str):
         tunnel = TableTunnel(self.odps)