Skip to content

[Bug] Segfault in TVM when building TIR module with pragma_unroll_explicit annotations #18393

@Cookiee235

Description

@Cookiee235

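When building a TIR module in which a loop carries the annotation `pragma_unroll_explicit` with a `None` value, TVM crashes with a segmentation fault during the `FlattenBuffer` pass. The backtrace below shows the crash in `StmtExprMutator::VisitExpr`, reached from the visitor for `AttrStmtNode`. This suggests that the `None` annotation value ends up as a null expression on an `AttrStmt`.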
Expected behavior: `pragma_unroll_explicit=None` should mean "use the compiler's default unrolling strategy", so unrolling is neither forced nor prohibited.

Actual behavior

!!!!!!! Segfault encountered !!!!!!!
  File "/build/glibc-LcI20x/glibc-2.31/signal/../sysdeps/unix/sysv/linux/x86_64/sigaction.c", line 0, in 0x00007f03f92df08f
  File "<unknown>", line 0, in tvm::tir::StmtExprMutator::VisitExpr(tvm::PrimExpr const&)
  File "<unknown>", line 0, in tvm::tir::StmtMutator::VisitStmt_(tvm::tir::AttrStmtNode const*)
  File "<unknown>", line 0, in tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::AttrStmtNode const*)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)#2}::_FUN(tvm::ffi::ObjectRef const&, tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>*)
  File "<unknown>", line 0, in tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  File "<unknown>", line 0, in tvm::tir::FlattenBuffer(tvm::tir::PrimFunc)
  File "<unknown>", line 0, in std::_Function_handler<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext), tvm::tir::transform::FlattenBuffer()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}>::_M_invoke(std::_Any_data const&, tvm::tir::PrimFunc&&, tvm::IRModule&&, tvm::transform::PassContext&&)
  File "<unknown>", line 0, in tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule) const
  File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 240, in _PyObject_MakeTpCall
  File "Python/bytecodes.c", line 2706, in _PyEval_EvalFrameDefault
  File "/usr/local/src/conda/python-3.12.3/Include/internal/pycore_ceval.h", line 89, in _PyEval_EvalFrame
  File "/usr/local/src/conda/python-3.12.3/Python/ceval.c", line 1683, in _PyEval_Vector
  File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 419, in _PyFunction_Vectorcall
  File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 133, in _PyObject_FastCallDictTstate
  File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 508, in _PyObject_Call_Prepend
  File "/usr/local/src/conda/python-3.12.3/Objects/typeobject.c", line 8770, in slot_tp_call
  File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 240, in _PyObject_MakeTpCall
  File "Python/bytecodes.c", line 2706, in _PyEval_EvalFrameDefault
  File "<unknown>", line 0, in tvm::transform::__TVMFFIStaticInitFunc4()::{lambda(tvm::ffi::TypedFunction<tvm::IRModule (tvm::ffi::RValueRef<tvm::IRModule, void>, tvm::transform::PassContext)>, tvm::transform::PassInfo)#1}::operator()(tvm::ffi::TypedFunction<tvm::IRModule (tvm::ffi::RValueRef<tvm::IRModule, void>, tvm::transform::PassContext)>, tvm::transform::PassInfo) const::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}::operator()(tvm::IRModule, tvm::transform::PassContext) const
  File "<unknown>", line 0, in std::_Function_handler<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext), tvm::transform::__TVMFFIStaticInitFunc4()::{lambda(tvm::ffi::TypedFunction<tvm::IRModule (tvm::ffi::RValueRef<tvm::IRModule, void>, tvm::transform::PassContext)>, tvm::transform::PassInfo)#1}::operator()(tvm::ffi::TypedFunction<tvm::IRModule (tvm::ffi::RValueRef<tvm::IRModule, void>, tvm::transform::PassContext)>, tvm::transform::PassInfo) const::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>::_M_invoke(std::_Any_data const&, tvm::IRModule&&, tvm::transform::PassContext&&)
  File "<unknown>", line 0, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule) const
  File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 240, in _PyObject_MakeTpCall
  File "Python/bytecodes.c", line 2706, in _PyEval_EvalFrameDefault
  File "/usr/local/src/conda/python-3.12.3/Include/internal/pycore_ceval.h", line 89, in _PyEval_EvalFrame
  File "/usr/local/src/conda/python-3.12.3/Python/ceval.c", line 1683, in _PyEval_Vector
  File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 419, in _PyFunction_Vectorcall
  File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 133, in _PyObject_FastCallDictTstate
  File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 508, in _PyObject_Call_Prepend
  File "/usr/local/src/conda/python-3.12.3/Objects/typeobject.c", line 8770, in slot_tp_call
  File "/usr/local/src/conda/python-3.12.3/Objects/call.c", line 240, in _PyObject_MakeTpCall
  File "Python/bytecodes.c", line 2706, in _PyEval_EvalFrameDefault
  File "/usr/local/src/conda/python-3.12.3/Python/ceval.c", line 578, in PyEval_EvalCode
  File "/usr/local/src/conda/python-3.12.3/Python/pythonrun.c", line 1722, in run_eval_code_obj
  File "/usr/local/src/conda/python-3.12.3/Python/pythonrun.c", line 1743, in run_mod
  File "/usr/local/src/conda/python-3.12.3/Python/pythonrun.c", line 1643, in pyrun_file
  File "/usr/local/src/conda/python-3.12.3/Python/pythonrun.c", line 433, in _PyRun_SimpleFileObject
  File "/usr/local/src/conda/python-3.12.3/Python/pythonrun.c", line 78, in _PyRun_AnyFileObject
  File "/usr/local/src/conda/python-3.12.3/Modules/main.c", line 360, in pymain_run_file_obj
  File "/usr/local/src/conda/python-3.12.3/Modules/main.c", line 379, in pymain_run_file
  File "/usr/local/src/conda/python-3.12.3/Modules/main.c", line 629, in pymain_run_python
  File "/usr/local/src/conda/python-3.12.3/Modules/main.c", line 709, in Py_RunMain
  File "/usr/local/src/conda/python-3.12.3/Modules/main.c", line 763, in Py_BytesMain
  File "<unknown>", line 0, in 0xffffffffffffffff

Environment

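Latest TVM as of the day this issue was filed. Python 3.12.3 on x86_64 Linux (glibc 2.31), as shown in the backtrace; target triple `x86_64-unknown-linux-gnu`.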

Steps to reproduce


import tvm

# Minimal TVMScript reproducer. The only unusual construct is the
# {"pragma_unroll_explicit": None} annotation on the outermost loop `ax0`;
# the rest is a plain broadcast add of `rhs[0]` onto a (4, 5, 6) int16 buffer.
tir_source = """# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(lhs: T.Buffer((4, 5, 6), "int16"), rhs: T.Buffer((1,), "int16"), T_add: T.Buffer((4, 5, 6), "int16")):
        T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "llvm", "mtriple": "x86_64-unknown-linux-gnu", "tag": ""}), "tir.noalias": True})
        # with T.block("root"):
        for ax0 in T.serial(4, annotations={"pragma_unroll_explicit": None}):
            for ax1 in T.serial(5):
                for ax2 in T.serial(6):
                    with T.block("T_add"):
                        v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
                        T.reads(lhs[v_ax0, v_ax1, v_ax2], rhs[0])
                        T.writes(T_add[v_ax0, v_ax1, v_ax2])
                        T_add[v_ax0, v_ax1, v_ax2] = lhs[v_ax0, v_ax1, v_ax2] + rhs[0]
"""

# Parse the script into an IRModule. Printing it first shows that parsing
# itself succeeds and lets the annotation be inspected before the crash.
parsed_module = tvm.script.from_source(tir_source)
parsed_module.show()

# Compile the module. Per this report, this is where the process segfaults
# (inside the FlattenBuffer pass).
tvm.build(parsed_module)

Triage

  • needs-triage

Metadata

Assignees

No one assigned

    Labels

    needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions