Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dataflow_agent/agentroles/cores/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,4 +1535,4 @@ async def _execute_vlm(self, state: MainState, **kwargs) -> Dict[str, Any]:
log.info(f"parsed 内容: {parsed}")
log.info(f"additional_kwargs 内容: {response.additional_kwargs}")

return parsed
return parsed
64 changes: 54 additions & 10 deletions dataflow_agent/promptstemplates/prompts_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,31 +546,73 @@ class PipelineRefinePrompts:

# 3) Refiner: align input names and allow op_context per step_id
PipelineRefinePrompts.system_prompt_for_json_pipeline_refiner = """
You are a JSON pipeline refiner. Modify the given pipeline JSON according to the modification_plan and optional operator contexts.
Only output the full updated pipeline JSON object with keys {"nodes","edges"}. No comments.
Rules:
You are a JSON pipeline refiner with access to operator search tools. Modify the given pipeline JSON according to the modification_plan.

**CRITICAL RULES FOR ADDING NEW OPERATORS:**
1. **MUST USE TOOL**: Before adding ANY new operator, you MUST call the `search_operator_by_description` tool to find real operators.
2. **ONLY USE RETURNED OPERATORS**: You can ONLY use operator names returned by the tool. NEVER invent or guess operator names.
3. **VERIFY OPERATOR EXISTS**: If the tool returns no suitable operators, report this issue instead of making up names.
4. **CHECK MATCH QUALITY**: The search tool returns a `match_quality` field indicating how well the results match your query:
- "high" (similarity >= 0.5): Good match, safe to use
- "medium" (similarity 0.3-0.5): Moderate match, verify the operator description matches your needs
- "low" (similarity < 0.3): Poor match, the operators may NOT satisfy the requirement. You should report "未能找到满足XXX需求的算子" in this case.

**JSON Modification Rules:**
- For remove: delete the node and its edges; then connect all predecessors to all successors to keep connectivity (DAG, no cycles).
- For insert_between(a,b): replace edge a→b with a→new and new→b.
- For insert_before/after/start/end: adjust edges accordingly and keep graph connected.
- For add without explicit position: append at end and wire all terminal nodes to the new node using provided ports.
- Edge fields: {"source","target","source_port","target_port"}.
- Node fields: {"id","name","type","config":{"run":{...},"init":{...}}}.
- Always apply ALL steps in modification_plan sequentially. Do not skip steps.
- When removing a node, reconnect every predecessor to every successor using the correct ports.
- Ensure newly created node ids are unique.
- Always apply ALL steps in modification_plan sequentially. Do not skip steps.
- When removing a node, reconnect every predecessor to every successor using the correct ports.
- Ensure newly created node ids are unique.

**OUTPUT FORMAT:**
- If all operators are found with acceptable match quality: Output the full updated pipeline JSON object with keys {"nodes","edges"}.
- If any required operator has low match quality and cannot satisfy the requirement: Output a JSON object with:
{
"status": "partial_failure",
"message": "未能找到满足「XXX」需求的算子。当前算子库中最相似的是 YYY(功能:ZZZ),但其功能与需求不匹配。",
"matched_operators_info": [...], // 搜索到的算子信息
"pipeline": {...} // 尽可能完成其他修改后的 pipeline,或原始 pipeline
}

No comments in output.
"""
PipelineRefinePrompts.task_prompt_for_json_pipeline_refiner = """
[TASK]
1. 理解当前 pipeline_json 与 modification_plan。
2. 如 op_context 提供了针对某个 step_id 的 operator 代码/端口/配置提示,请据此填写新节点的 type、config.run(input_key/output_key) 与必要的 init。
3. 严格保持 JSON 结构、DAG 连通性与有向无环属性,禁止输出注释或解释性文字。
2. **重要**:在添加新算子之前,必须先调用 `search_operator_by_description` 工具搜索真实存在的算子。
3. **禁止**使用工具返回结果之外的算子名称。如果需要"情感分析"功能,先搜索"情感分析",然后从返回的算子列表中选择最合适的。
4. **关键**:检查工具返回的 `match_quality` 字段:
- 如果是 "high":可以放心使用该算子
- 如果是 "medium":仔细阅读算子描述,确认功能是否匹配
- 如果是 "low":说明没有找到合适的算子!此时应该在输出中明确说明"未能找到满足「XXX」需求的算子",并给出搜索到的最相似算子及其功能描述,让用户了解当前算子库的能力边界。
5. 如需了解算子的详细参数,可调用 `get_operator_code_by_name` 工具获取算子源代码。
6. 根据工具返回的算子信息,填写新节点的 name、type、config.run(input_key/output_key) 与必要的 init。
7. 严格保持 JSON 结构、DAG 连通性与有向无环属性,禁止输出注释或解释性文字。

[WORKFLOW]
1. 分析 modification_plan 中需要添加的算子
2. 对每个需要添加的算子,调用 search_operator_by_description 工具搜索
3. **检查返回结果的 match_quality 字段**:
- 如果 match_quality 为 "high" 或 "medium"(且描述匹配):从 matched_operators 中选择最合适的算子
- 如果 match_quality 为 "low":记录下来,准备在最终输出中报告此问题
4. 如需要,调用 get_operator_code_by_name 获取算子详细参数
5. 生成最终输出:
- 如果所有需要的算子都找到了:输出完整的 pipeline JSON
- 如果有算子未找到(match_quality 为 low):输出包含 status, message, pipeline 的 JSON,明确说明哪些需求无法满足

[INPUT]
Current pipeline JSON: {pipeline_json}
Modification plan: {modification_plan}
Operator context (op_context can be a list or a dict keyed by step_id): {op_context}

Output the UPDATED pipeline JSON ONLY.
[OUTPUT]
根据搜索结果的 match_quality 决定输出格式:
- 全部找到:直接输出更新后的 pipeline JSON(包含 nodes 和 edges)
- 部分未找到:输出 {{"status": "partial_failure", "message": "...", "pipeline": {{...}}}}
"""


Expand Down Expand Up @@ -723,7 +765,9 @@ class WriteOperator:
1. Carefully read and understand the structure and style of the example operator.
2. Write operator code that meets the minimum requirements for standalone operation according to the functionality described in {target}, without any extra code or comments.
3. Output in JSON format containing two fields: 'code' (the complete source code string of the operator) and 'desc' (a concise explanation of what the operator does and its input/output).
4. If the operator requires using an LLM, the llm_serving field must be included in the __init__ method.
4. If the operator requires using an LLM, do NOT initialize llm_serving in __init__. Instead, accept llm_serving as a parameter: def __init__(self, llm_serving=None) and assign self.llm_serving = llm_serving. The llm_serving will be injected externally.
5. IMPORTANT: Do NOT import 'LLMServing' from dataflow.serving (it does not exist). Only use 'APILLMServing_request' or 'LocalModelLLMServing_vllm'. Correct import: from dataflow.serving import APILLMServing_request
6. APILLMServing_request API usage: Call self.llm_serving.generate_from_input(list_of_strings) which takes a list of input strings and returns a list of output strings. Do NOT use .request() or .call() methods - they do not exist.
"""

# --------------------------------------------------------------------------- #
Expand Down
Binary file removed dataflow_agent/resources/faiss_cache/all_ops.index
Binary file not shown.
Binary file modified dataflow_agent/resources/faiss_cache/all_ops.index.meta
Binary file not shown.
231 changes: 223 additions & 8 deletions dataflow_agent/toolkits/optool/op_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from typing import Any, Dict, List, Tuple

from dataflow.utils.registry import OPERATOR_REGISTRY
from langchain_core.tools import tool
from dataflow_agent.logger import get_logger

log = get_logger(__name__)
Expand Down Expand Up @@ -562,18 +563,24 @@ def _load_or_build_index(self):
def search(
self,
queries: Union[str, List[str]],
top_k: int = 5
) -> Union[List[str], List[List[str]]]:
top_k: int = 5,
return_scores: bool = False
) -> Union[List[str], List[List[str]], List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
"""
检索最相关的算子

Args:
queries: 单个查询字符串或查询列表
top_k: 返回top-k个结果
return_scores: 是否返回相似度分数

Returns:
如果输入是字符串,返回List[str]
如果输入是列表,返回List[List[str]]
如果 return_scores=False:
如果输入是字符串,返回List[str]
如果输入是列表,返回List[List[str]]
如果 return_scores=True:
如果输入是字符串,返回List[Dict],每个Dict包含 name, description, similarity_score
如果输入是列表,返回List[List[Dict]]
"""
# 统一处理为列表
is_single = isinstance(queries, str)
Expand All @@ -593,10 +600,24 @@ def search(

# 组织结果
results = []
for i, indices in enumerate(I):
matched_ops = [self.ops_list[idx]["name"] for idx in indices]
results.append(matched_ops)
log.info(f"Query {i+1}: '{queries[i][:50]}...' -> {matched_ops}")
for i, (indices, scores) in enumerate(zip(I, D)):
if return_scores:
# 返回包含分数的详细信息
matched_ops = []
for idx, score in zip(indices, scores):
op_info = self.ops_list[idx]
matched_ops.append({
"name": op_info["name"],
"description": op_info.get("description", ""),
"similarity_score": float(score) # FAISS cosine similarity score
})
results.append(matched_ops)
log.info(f"Query {i+1}: '{queries[i][:50]}...' -> {[(op['name'], round(op['similarity_score'], 3)) for op in matched_ops]}")
else:
# 原有逻辑,只返回名称列表
matched_ops = [self.ops_list[idx]["name"] for idx in indices]
results.append(matched_ops)
log.info(f"Query {i+1}: '{queries[i][:50]}...' -> {matched_ops}")

# 如果是单查询,返回单个列表
return results[0] if is_single else results
Expand Down Expand Up @@ -685,6 +706,200 @@ def local_tool_for_get_match_operator_code(pre_task_result):
return "\n\n".join(blocks)


# =================================================================== LangChain Tool 封装的 RAG 工具:

# Match-quality thresholds applied to the best cosine-similarity score
# returned by the FAISS search (see _determine_match_quality).
MATCH_QUALITY_THRESHOLDS = {
    "high": 0.5,  # >= 0.5: strong match
    "medium": 0.3,  # >= 0.3: moderate match
    # < 0.3: poor match ("low")
}


# Default on-disk location of the cached FAISS index for operator embeddings.
DEFAULT_FAISS_INDEX_PATH = str(utils.get_project_root() / "dataflow_agent/toolkits/resources/faiss_ops.index")


def _get_operators_by_rag_with_scores(
    search_query: str,
    top_k: int = 4,
    ops_json_path: str = None,
    faiss_index_path: str = None,
    model_name: str = "text-embedding-3-small",
    base_url: str = "http://123.129.219.111:3000/v1/embeddings",
    api_key: str = None,
) -> List[Dict[str, Any]]:
    """Run a RAG lookup over the operator library, keeping similarity scores.

    Args:
        search_query: free-text description of the desired operator.
        top_k: number of candidates to return.
        ops_json_path: path to the operator catalogue JSON; defaults to the
            bundled ``resources/ops.json``.
        faiss_index_path: FAISS index file; reused when present, otherwise
            built and persisted (defaults to ``DEFAULT_FAISS_INDEX_PATH``).
        model_name: embedding model identifier.
        base_url: embedding API endpoint.
        api_key: API key; falls back to the ``DF_API_KEY`` env var.

    Returns:
        A list of dicts, each with ``name``, ``description`` and
        ``similarity_score``.
    """
    # Resolve optional arguments against their project-level defaults.
    resolved_ops_json = (
        utils.get_project_root() / "dataflow_agent/toolkits/resources/ops.json"
        if ops_json_path is None
        else ops_json_path
    )
    resolved_index = DEFAULT_FAISS_INDEX_PATH if faiss_index_path is None else faiss_index_path
    resolved_key = os.getenv("DF_API_KEY") if api_key is None else api_key

    searcher = RAGOperatorSearch(
        ops_json_path=str(resolved_ops_json),
        category=None,
        faiss_index_path=resolved_index,
        model_name=model_name,
        base_url=base_url,
        api_key=resolved_key,
    )

    # return_scores=True yields dicts with similarity info instead of bare names.
    return searcher.search(search_query, top_k=top_k, return_scores=True)


def _determine_match_quality(max_score: float) -> str:
    """Bucket the best similarity score into "high", "medium" or "low"."""
    # Walk the thresholds from strictest to loosest; first one satisfied wins.
    for level in ("high", "medium"):
        if max_score >= MATCH_QUALITY_THRESHOLDS[level]:
            return level
    return "low"


def _generate_match_warning(query: str, max_score: float, match_quality: str) -> Optional[str]:
    """Return a user-facing warning for medium/low matches, or None for high.

    NOTE(review): ``Optional`` must be imported from ``typing`` at file level;
    the visible import hunk only shows Any/Dict/List/Tuple — confirm.
    """
    # High-quality matches need no caveat.
    if match_quality == "high":
        return None

    if match_quality == "medium":
        return (
            f"提示:与'{query}'相关的算子匹配度为中等(最高相似度: {max_score:.3f})。"
            f"请仔细阅读算子描述,确认是否满足您的需求。"
        )

    # Remaining case: low-quality match.
    return (
        f"警告:未找到与'{query}'高度匹配的算子。最高相似度仅为{max_score:.3f},"
        f"低于推荐阈值{MATCH_QUALITY_THRESHOLDS['medium']}。"
        f"当前返回的算子可能无法满足您的需求。如果没有合适的算子,"
        f"请在回复中说明'未能找到满足{query}需求的算子'。"
    )


@tool
def search_operator_by_description(query: str, top_k: int = 4) -> str:
    # NOTE(review): this Chinese docstring is consumed at runtime — LangChain's
    # @tool uses __doc__ as the tool description shown to the LLM. Do not
    # translate or reword it without intending to change the agent contract.
    """
    根据功能描述搜索最匹配的数据处理算子。

    当你需要在 pipeline 中添加新算子时,必须先调用此工具搜索真实存在的算子。
    禁止使用此工具返回结果之外的算子名称。

    **重要**:该工具会返回匹配质量评估(match_quality):
    - "high": 高度匹配(相似度>=0.5),可以放心使用
    - "medium": 中等匹配(相似度0.3-0.5),请仔细确认是否满足需求
    - "low": 低匹配(相似度<0.3),可能无法满足需求,请考虑说明"未能找到满足需求的算子"

    Args:
        query: 算子功能描述,例如 "情感分析"、"数据清洗"、"文本分类"、"去重"、"数据增强" 等
        top_k: 返回的候选算子数量,默认为4

    Returns:
        JSON 格式的搜索结果,包含匹配的算子名称、描述、相似度分数和匹配质量评估

    Examples:
        >>> search_operator_by_description("情感分析")
        >>> search_operator_by_description("数据去重", top_k=3)
    """
    try:
        # RAG retrieval: each hit is a dict with name/description/similarity_score.
        matched_operators = _get_operators_by_rag_with_scores(query, top_k=top_k)

        # Best similarity among the hits; stays 0.0 when nothing was returned,
        # which deliberately falls into the "low" quality bucket below.
        max_score = 0.0
        if matched_operators:
            max_score = max(op.get("similarity_score", 0.0) for op in matched_operators)

        # Bucket the best score into high / medium / low.
        match_quality = _determine_match_quality(max_score)

        # Optional human-readable caveat for medium/low quality (None for high).
        warning = _generate_match_warning(query, max_score, match_quality)

        # Assemble the JSON payload handed back to the calling LLM.
        result = {
            "query": query,
            "matched_operators": matched_operators,
            "max_similarity_score": round(max_score, 4),
            "match_quality": match_quality,
        }

        # Attach the warning only when one was generated.
        if warning:
            result["warning"] = warning

        # Quality-dependent guidance telling the LLM how to use the results.
        if match_quality == "high":
            result["instruction"] = (
                "请从 matched_operators 中选择最合适的算子名称(name字段)。"
                "匹配质量高,可以放心使用。"
            )
        elif match_quality == "medium":
            result["instruction"] = (
                "请从 matched_operators 中选择最合适的算子名称(name字段)。"
                "注意:匹配质量为中等,请仔细阅读算子描述(description)确认是否满足需求。"
            )
        else:  # low
            result["instruction"] = (
                "注意:当前匹配质量较低!请仔细评估 matched_operators 中的算子是否能满足需求。"
                f"如果没有合适的算子,请在回复中明确说明'未能找到满足「{query}」需求的算子',"
                "并给出建议(如:建议用户自定义算子,或使用其他方式实现该功能)。"
            )

        log.info(
            f"[search_operator_by_description] 查询: '{query}' -> "
            f"匹配到 {len(matched_operators)} 个算子, "
            f"最高相似度: {max_score:.3f}, 匹配质量: {match_quality}"
        )
        return json.dumps(result, ensure_ascii=False, indent=2)

    except Exception as e:
        # Search/embedding failures are reported in-band so the agent loop can
        # continue; match_quality "error" distinguishes this from a real "low".
        log.error(f"[search_operator_by_description] 搜索失败: {e}")
        return json.dumps({
            "error": str(e),
            "query": query,
            "matched_operators": [],
            "match_quality": "error"
        }, ensure_ascii=False)


@tool
def get_operator_code_by_name(operator_name: str) -> str:
    # NOTE(review): as with search_operator_by_description, this docstring is
    # the runtime LangChain tool description — leave the Chinese text intact.
    """
    根据算子名称获取算子的源代码。

    在选择了要使用的算子后,可以调用此工具获取算子的源代码,
    以便了解算子的 init 参数和 run 参数的具体用法。

    Args:
        operator_name: 算子名称,必须是 search_operator_by_description 返回的算子名称

    Returns:
        算子的源代码字符串
    """
    try:
        # Delegates to the module-level lookup helper; presumably resolves the
        # name via OPERATOR_REGISTRY — confirm against its definition.
        code = get_operator_source_by_name(operator_name)
        log.info(f"[get_operator_code_by_name] 获取算子 '{operator_name}' 的源代码成功")
        return code
    except Exception as e:
        # Failures are returned as a commented string (not raised) so the
        # agent sees an explanation instead of a hard tool error.
        log.error(f"[get_operator_code_by_name] 获取失败: {e}")
        return f"# 获取算子 '{operator_name}' 源代码失败: {e}"


if __name__ == "__main__":
# ============ 示例1: 单个查询 + 指定category + 持久化索引 ============
# log.info("\n" + "="*70)
Expand Down
5 changes: 4 additions & 1 deletion dataflow_agent/toolkits/pipetool/pipe_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,8 +825,11 @@ def render_operator_blocks_with_full_params(
run_args = []
if run_has_storage:
run_args.append("storage=self.storage.step()")

for k, v in custom_run_params.items():
# 跳过 storage 参数,因为已经在上面自动添加了
if k == "storage":
continue
if isinstance(v, str):
run_args.append(f"{k}={repr(v)}")
else:
Expand Down
4 changes: 2 additions & 2 deletions dataflow_agent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def _remove_leading_json_word(src: str) -> str:
def _strip_json_comments(src: str) -> str:
# /* ... */ 块注释
src = re.sub(r'/\*[\s\S]*?\*/', '', src)
# // ... 行注释
src = re.sub(r'//.*', '', src)
# // ... 行注释,排除 URL 中的 :// 和字符串内的 //)
src = re.sub(r'(?<![:\"\'])//.*', '', src)
# 尾逗号 ,}
src = re.sub(r',\s*([}\]])', r'\1', src)
return src.strip()
Expand Down
Loading