
Commit 3e7f6b0

[Fix Playground] Fix errors of filepath and import path (#389)
* [Fix Playground] Fix errors of filepath and import path
* [Fix Bug] fix bug of prompt template and import error
* [Fix Bug] Fix parameters name in run()
* [Fix Bug] Fix parameters name in run()
* [Fix Bug] Fix parameters name in run()
1 parent c520b7b commit 3e7f6b0

File tree: 5 files changed (+60, −52 lines changed)


dataflow/operators/knowledge_cleaning/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 from .generate.file_or_url_to_markdown_converter_batch import FileOrURLToMarkdownConverterBatch
 from .generate.kbc_text_cleaner import KBCTextCleaner
 from .generate.kbc_text_cleaner_batch import KBCTextCleanerBatch
-# from .generate.mathbook_question_extract import MathBookQuestionExtract
+from .generate.mathbook_question_extract import MathBookQuestionExtract
 # from .generate.kbc_multihop_qa_generator import KBCMultiHopQAGenerator
 from .generate.kbc_multihop_qa_generator_batch import KBCMultiHopQAGeneratorBatch
 from .generate.qa_extract import QAExtractor

dataflow/operators/knowledge_cleaning/generate/mathbook_question_extract.py

Lines changed: 33 additions & 27 deletions
@@ -11,22 +11,32 @@
 import re
 from openai import OpenAI
 import base64
-from typing import List, Literal
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Literal, Union
 from dataflow.core import LLMServingABC
 from dataflow.serving import APIVLMServing_openai
-
-
+from dataflow.core.prompt import DIYPromptABC
+from dataflow.utils.storage import DataFlowStorage
 
 @OPERATOR_REGISTRY.register()
 class MathBookQuestionExtract(OperatorABC):
-    def __init__(self, llm_serving: APIVLMServing_openai, prompt_template = None):
+    def __init__(self,
+                 llm_serving: APIVLMServing_openai,
+                 prompt_template: Union[MathbookQuestionExtractPrompt, DIYPromptABC] = MathbookQuestionExtractPrompt(),
+                 mineru_backend: str = "vlm-vllm-engine",
+                 dpi: int = 300,
+                 key_name_of_api_key: str = "DF_API_KEY",
+                 model_name: str = "o4-mini",
+                 max_workers: int = 20
+                 ):
         self.logger = get_logger()
         self.llm_serving = llm_serving
-        if prompt_template:
-            self.prompt_template = prompt_template
-        else:
-            self.prompt_template = MathbookQuestionExtractPrompt()
+        self.prompt_template = prompt_template
+
+        self.mineru_backend = mineru_backend
+        self.dpi = dpi
+        self.key_name_of_api_key = key_name_of_api_key
+        self.model_name = model_name
+        self.max_workers = max_workers  # Note: this parameter is not used in the original logic, but is moved into __init__ as required
 
     @staticmethod
     def get_desc(lang: str = "zh"):
@@ -80,7 +90,6 @@ def mineru2_runner(self,
                        # pipeline|vlm-transformers|vlm-vllm-engine|vlm-http-client
                        mineru_backend: Literal["pipeline", "vlm-transformers", "vlm-vllm-engine", "vlm-http-client"] = "pipeline"
                        ):
-
         try:
             import mineru
         except ImportError:
@@ -227,6 +236,7 @@ def process_input(self,
         return full_input_image_list,full_input_label_list
 
     def analyze_and_save(self,result_list,save_folder,img_folder,output_file_name):
+        # ... (the analyze_and_save method itself is unchanged)
         # make save_folder if not exist
         if not os.path.exists(save_folder):
             os.makedirs(save_folder)
@@ -281,42 +291,38 @@ def analyze_and_save(self,result_list,save_folder,img_folder,output_file_name):
 
     def run(
         self,
-        pdf_file_path: str,
+        storage: DataFlowStorage,
+        input_pdf_file_path: str,
         output_file_name: str,
         output_folder: str,
-        MinerU_Backend: str = "vlm-sglang-engine",
-        dpi: int = 300,
-        api_url: str = "http://123.129.219.111:3000/v1",
-        key_name_of_api_key: str = "DF_API_KEY",
-        model_name: str = "o4-mini",
-        max_workers: int = 20
     ):
-        api_key = os.environ.get(key_name_of_api_key)
+        # get the configuration parameters from self
+        api_key = os.environ.get(self.key_name_of_api_key)
         if not api_key:
-            raise ValueError(f"API key not found in environment variable {key_name_of_api_key}")
+            raise ValueError(f"API key not found in environment variable {self.key_name_of_api_key}")
 
         # 1. convert pdf to images
-        pdf2images_folder_name = output_folder+"/pdfimages"
-        self.pdf2images(pdf_file_path, pdf2images_folder_name, dpi)
+        pdf2images_folder_name = os.path.join(output_folder, "pdfimages")
+        self.pdf2images(input_pdf_file_path, pdf2images_folder_name, self.dpi)
 
         # 2. use mineru to extract content and pics
-        json_content_file, pic_folder = self.mineru2_runner(pdf_file_path, output_folder, MinerU_Backend)
+        json_content_file, pic_folder = self.mineru2_runner(input_pdf_file_path, output_folder, self.mineru_backend)
 
         # 3. organize_pics
-        output_image_folder = output_folder+"/organized_images"
-        output_json_file = output_folder+"/organized_images/organized_info.json"
-        self.organize_pics(json_content_file, pic_folder,output_json_file, output_image_folder)
+        output_image_folder = os.path.join(output_folder, "organized_images")
+        output_json_file = os.path.join(output_image_folder, "organized_info.json")
+        self.organize_pics(json_content_file, pic_folder, output_json_file, output_image_folder)
 
         # 4. process input
-        full_input_image_list,full_input_label_list = self.process_input(pdf2images_folder_name, output_json_file)
+        full_input_image_list, full_input_label_list = self.process_input(pdf2images_folder_name, output_json_file)
 
         # 5. init server and generate
         system_prompt = self.prompt_template.build_prompt()
         result_text_list = self.llm_serving.generate_from_input_multi_images(
             list_of_image_paths=full_input_image_list,
             list_of_image_labels=full_input_label_list,
             system_prompt=system_prompt,
-            model=model_name,
+            model=self.model_name,
             timeout=1800
         )
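For reference, a minimal sketch of driving the operator after this change: backend, DPI, API-key variable name, model, and worker count are fixed at construction time, and run() only receives a storage handle plus input/output paths. The PDF path and output folder below are placeholders, and storage=None mirrors the updated playground script.

from dataflow.operators.knowledge_cleaning import MathBookQuestionExtract
from dataflow.serving import APIVLMServing_openai

# export DF_API_KEY="your_openai_api_key" beforehand; run() reads it from the environment
llm_serving = APIVLMServing_openai(
    api_url="https://api.openai.com/v1",
    model_name="o4-mini",
    max_workers=20
)

# model/backend configuration now lives in __init__ ...
extractor = MathBookQuestionExtract(
    llm_serving=llm_serving,
    mineru_backend="vlm-vllm-engine",
    dpi=300,
    key_name_of_api_key="DF_API_KEY",
    model_name="o4-mini",
    max_workers=20
)

# ... while run() only takes the storage handle and I/O paths
extractor.run(
    storage=None,                                        # the playground example passes None here
    input_pdf_file_path="./questionextract_test.pdf",    # placeholder input path
    output_file_name="test_question_extract",
    output_folder="./output"                             # placeholder output folder
)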

dataflow/prompts/kbcleaning.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 from dataflow.utils.registry import PROMPT_REGISTRY
 from dataflow.core.prompt import PromptABC
-
+import re
 @PROMPT_REGISTRY.register()
 class KnowledgeCleanerPrompt(PromptABC):
     '''
@@ -220,7 +220,7 @@ def _post_process(self, cleaned_text: str) -> str:
 
 
 @PROMPT_REGISTRY.register()
-class MathbookQuestionExtractPrompt:
+class MathbookQuestionExtractPrompt(PromptABC):
     def __init__(self):
         pass
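The new prompt_template type hint in MathBookQuestionExtract (Union[MathbookQuestionExtractPrompt, DIYPromptABC]) suggests a custom prompt can be swapped in for the default. A hypothetical sketch, assuming DIYPromptABC only requires build_prompt() — the single method the operator actually calls — with an illustrative class name and prompt text:

from dataflow.core.prompt import DIYPromptABC
from dataflow.operators.knowledge_cleaning import MathBookQuestionExtract
from dataflow.serving import APIVLMServing_openai

class MyQuestionExtractPrompt(DIYPromptABC):  # hypothetical custom prompt
    def build_prompt(self) -> str:
        # the operator passes this string as the system prompt to the VLM serving backend
        return "Extract every exercise question from the given textbook page images."

llm_serving = APIVLMServing_openai(
    api_url="https://api.openai.com/v1",
    model_name="o4-mini",
    max_workers=20
)

extractor = MathBookQuestionExtract(
    llm_serving=llm_serving,
    prompt_template=MyQuestionExtractPrompt()
)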

dataflow/statics/playground/playground/mathbook_extract.py

Lines changed: 23 additions & 21 deletions
@@ -1,41 +1,43 @@
-from dataflow.operators.knowledge_cleaning.generate.mathbook_question_extract import MathBookQuestionExtract
-from dataflow.serving.api_vlm_serving_openai import APIVLMServing_openai
+from dataflow.operators.knowledge_cleaning import MathBookQuestionExtract
+from dataflow.serving import APIVLMServing_openai
 
 class QuestionExtractPipeline:
-    def __init__(self, llm_serving: APIVLMServing_openai):
-        self.extractor = MathBookQuestionExtract(llm_serving)
-        self.test_pdf = "../example_data/KBCleaningPipeline/questionextract_test.pdf"
+    def __init__(self,
+                 llm_serving: APIVLMServing_openai,
+                 api_url: str = "https://api.openai.com/v1",  # end with /v1
+                 key_name_of_api_key: str = "DF_API_KEY",  # set in the environment first: export DF_API_KEY="your_openai_api_key"
+                 model_name: str = "o4-mini",
+                 max_workers: int = 20
+                 ):
+        self.extractor = MathBookQuestionExtract(
+            llm_serving=llm_serving,
+            key_name_of_api_key=key_name_of_api_key,
+            model_name=model_name,
+            max_workers=max_workers
+        )
+        self.test_pdf = "../example_data/PDF2VQAPipeline/questionextract_test.pdf"
 
     def forward(
         self,
         pdf_path: str,
         output_name: str,
         output_dir: str,
-        api_url: str = "https://api.openai.com/v1/chat/completions",
-        key_name_of_api_key: str = "DF_API_KEY",
-        model_name: str = "o4-mini",
-        max_workers: int = 20
     ):
         self.extractor.run(
-            pdf_file_path=pdf_path,
+            storage=None,
+            input_pdf_file_path=pdf_path,
             output_file_name=output_name,
-            output_folder=output_dir,
-            api_url=api_url,
-            key_name_of_api_key=key_name_of_api_key,
-            model_name=model_name,
-            max_workers=max_workers
+            output_folder=output_dir
         )
 
 if __name__ == "__main__":
-    # 1. initialize LLM Serving
     llm_serving = APIVLMServing_openai(
-        api_url="https://api.openai.com/v1",  # end with /v1, DO NOT add /chat/completions
-        model_name="o4-mini",  # recommend using strong reasoning model
-        max_workers=20  # number of concurrent requests
+        api_url="https://api.openai.com/v1",
+        model_name="o4-mini",
+        max_workers=20
     )
 
-    # 2. construct and run pipeline
-    pipeline = QuestionExtractPipeline(llm_serving)
+    pipeline = QuestionExtractPipeline(llm_serving=llm_serving)
     pipeline.forward(
         pdf_path=pipeline.test_pdf,
         output_name="test_question_extract",

dataflow/statics/playground/playground/vqa.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from dataflow.operators.core_vision import PromptedVQAGenerator
-from dataflow.serving.APIVLMServing_openai import APIVLMServing_openai
+from dataflow.serving import APIVLMServing_openai
 from dataflow.utils.storage import FileStorage
 
 class VQA_generator():
