|
11 | 11 | import re |
12 | 12 | from openai import OpenAI |
13 | 13 | import base64 |
14 | | -from typing import List, Literal |
15 | | -from concurrent.futures import ThreadPoolExecutor, as_completed |
| 14 | +from typing import Literal, Union |
16 | 15 | from dataflow.core import LLMServingABC |
17 | 16 | from dataflow.serving import APIVLMServing_openai |
18 | | - |
19 | | - |
| 17 | +from dataflow.core.prompt import DIYPromptABC |
| 18 | +from dataflow.utils.storage import DataFlowStorage |
20 | 19 |
|
21 | 20 | @OPERATOR_REGISTRY.register() |
22 | 21 | class MathBookQuestionExtract(OperatorABC): |
23 | | - def __init__(self, llm_serving: APIVLMServing_openai, prompt_template = None): |
| 22 | + def __init__(self, |
| 23 | + llm_serving: APIVLMServing_openai, |
| 24 | + prompt_template: Union[MathbookQuestionExtractPrompt, DIYPromptABC] = MathbookQuestionExtractPrompt(), |
| 25 | + mineru_backend: str = "vlm-vllm-engine", |
| 26 | + dpi: int = 300, |
| 27 | + key_name_of_api_key: str = "DF_API_KEY", |
| 28 | + model_name: str = "o4-mini", |
| 29 | + max_workers: int = 20 |
| 30 | + ): |
24 | 31 | self.logger = get_logger() |
25 | 32 | self.llm_serving = llm_serving |
26 | | - if prompt_template: |
27 | | - self.prompt_template = prompt_template |
28 | | - else: |
29 | | - self.prompt_template = MathbookQuestionExtractPrompt() |
| 33 | + self.prompt_template = prompt_template |
| 34 | + |
| 35 | + self.mineru_backend = mineru_backend |
| 36 | + self.dpi = dpi |
| 37 | + self.key_name_of_api_key = key_name_of_api_key |
| 38 | + self.model_name = model_name |
| 39 | + self.max_workers = max_workers # 注意:这个参数在原逻辑中并未被使用,但仍按要求移入init |
30 | 40 |
|
31 | 41 | @staticmethod |
32 | 42 | def get_desc(lang: str = "zh"): |
@@ -80,7 +90,6 @@ def mineru2_runner(self, |
80 | 90 | # pipeline|vlm-transformers|vlm-vllm-engine|vlm-http-client |
81 | 91 | mineru_backend: Literal["pipeline", "vlm-transformers", "vlm-vllm-engine", "vlm-http-client"] = "pipeline" |
82 | 92 | ): |
83 | | - |
84 | 93 | try: |
85 | 94 | import mineru |
86 | 95 | except ImportError: |
@@ -227,6 +236,7 @@ def process_input(self, |
227 | 236 | return full_input_image_list,full_input_label_list |
228 | 237 |
|
229 | 238 | def analyze_and_save(self,result_list,save_folder,img_folder,output_file_name): |
| 239 | + # ... (analyze_and_save 方法保持不变) |
230 | 240 | # make save_folder if not exist |
231 | 241 | if not os.path.exists(save_folder): |
232 | 242 | os.makedirs(save_folder) |
@@ -281,42 +291,38 @@ def analyze_and_save(self,result_list,save_folder,img_folder,output_file_name): |
281 | 291 |
|
282 | 292 | def run( |
283 | 293 | self, |
284 | | - pdf_file_path: str, |
| 294 | + storage: DataFlowStorage, |
| 295 | + input_pdf_file_path: str, |
285 | 296 | output_file_name: str, |
286 | 297 | output_folder: str, |
287 | | - MinerU_Backend: str = "vlm-sglang-engine", |
288 | | - dpi: int = 300, |
289 | | - api_url: str = "http://123.129.219.111:3000/v1", |
290 | | - key_name_of_api_key: str = "DF_API_KEY", |
291 | | - model_name: str = "o4-mini", |
292 | | - max_workers: int = 20 |
293 | 298 | ): |
294 | | - api_key = os.environ.get(key_name_of_api_key) |
| 299 | + # get the configuration parameters from self |
| 300 | + api_key = os.environ.get(self.key_name_of_api_key) |
295 | 301 | if not api_key: |
296 | | - raise ValueError(f"API key not found in environment variable {key_name_of_api_key}") |
| 302 | + raise ValueError(f"API key not found in environment variable {self.key_name_of_api_key}") |
297 | 303 |
|
298 | 304 | # 1. convert pdf to images |
299 | | - pdf2images_folder_name = output_folder+"/pdfimages" |
300 | | - self.pdf2images(pdf_file_path, pdf2images_folder_name, dpi) |
| 305 | + pdf2images_folder_name = os.path.join(output_folder, "pdfimages") |
| 306 | + self.pdf2images(input_pdf_file_path, pdf2images_folder_name, self.dpi) |
301 | 307 |
|
302 | 308 | # 2. use mineru to extract content and pics |
303 | | - json_content_file, pic_folder = self.mineru2_runner(pdf_file_path, output_folder, MinerU_Backend) |
| 309 | + json_content_file, pic_folder = self.mineru2_runner(input_pdf_file_path, output_folder, self.mineru_backend) |
304 | 310 |
|
305 | 311 | # 3. organize_pics |
306 | | - output_image_folder = output_folder+"/organized_images" |
307 | | - output_json_file = output_folder+"/organized_images/organized_info.json" |
308 | | - self.organize_pics(json_content_file, pic_folder,output_json_file, output_image_folder) |
| 312 | + output_image_folder = os.path.join(output_folder, "organized_images") |
| 313 | + output_json_file = os.path.join(output_image_folder, "organized_info.json") |
| 314 | + self.organize_pics(json_content_file, pic_folder, output_json_file, output_image_folder) |
309 | 315 |
|
310 | 316 | # 4. process input |
311 | | - full_input_image_list,full_input_label_list = self.process_input(pdf2images_folder_name, output_json_file) |
| 317 | + full_input_image_list, full_input_label_list = self.process_input(pdf2images_folder_name, output_json_file) |
312 | 318 |
|
313 | 319 | # 5. init server and generate |
314 | 320 | system_prompt = self.prompt_template.build_prompt() |
315 | 321 | result_text_list = self.llm_serving.generate_from_input_multi_images( |
316 | 322 | list_of_image_paths=full_input_image_list, |
317 | 323 | list_of_image_labels=full_input_label_list, |
318 | 324 | system_prompt=system_prompt, |
319 | | - model=model_name, |
| 325 | + model=self.model_name, |
320 | 326 | timeout=1800 |
321 | 327 | ) |
322 | 328 |
|
|
0 commit comments