Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions onnxruntime/core/mlas/lib/convolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -938,8 +938,9 @@
--*/
{
// Override
if(GetMlasPlatform().MlasConvOverride != nullptr &&
GetMlasPlatform().MlasConvOverride(Parameters,Input,Filter,Bias,WorkingBuffer,Output,ThreadPool)){
if(SMEInfo::IsSMEAvailable &&

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release

'SMEInfo': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_vitisai

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_vitisai

'SMEInfo': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_ep_generic_interface

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_ep_generic_interface

'SMEInfo': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_debug

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_debug

'SMEInfo': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x86_release

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x86_release

'SMEInfo': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_xnnpack

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_xnnpack

'SMEInfo': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU TensorRT CI Pipeline

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU TensorRT CI Pipeline

'SMEInfo': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU DML CI Pipeline

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU DML CI Pipeline

'SMEInfo': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU CUDA CI Pipeline

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU CUDA CI Pipeline

'SMEInfo': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, dynamic)

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, dynamic)

'SMEInfo': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, static)

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, static)

'SMEInfo': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, dynamic)

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, dynamic)

'SMEInfo': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, static)

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, static)

'SMEInfo': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'IsSMEAvailable': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'SMEInfo': is not a class or namespace name
GetMlasPlatform().MlasConvOverride != nullptr &&
GetMlasPlatform().MlasConvOverride(Parameters,Input,Filter,Bias,WorkingBuffer,Output,ThreadPool)) {
return;
}

Expand Down Expand Up @@ -1201,7 +1202,8 @@
--*/
{
// Override
if (GetMlasPlatform().MlasConvPrepareOverride != nullptr &&
if (SMEInfo::IsSMEAvailable &&

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release

'SMEInfo': is not a class or namespace name

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_vitisai

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_vitisai

'SMEInfo': is not a class or namespace name

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_ep_generic_interface

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_ep_generic_interface

'SMEInfo': is not a class or namespace name

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_debug

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_debug

'SMEInfo': is not a class or namespace name

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x86_release

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x86_release

'SMEInfo': is not a class or namespace name

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_xnnpack

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_xnnpack

'SMEInfo': is not a class or namespace name

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU TensorRT CI Pipeline

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU TensorRT CI Pipeline

'SMEInfo': is not a class or namespace name

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU DML CI Pipeline

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU DML CI Pipeline

'SMEInfo': is not a class or namespace name

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU CUDA CI Pipeline

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU CUDA CI Pipeline

'SMEInfo': is not a class or namespace name

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, dynamic)

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, dynamic)

'SMEInfo': is not a class or namespace name

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, static)

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, static)

'SMEInfo': is not a class or namespace name

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, dynamic)

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, dynamic)

'SMEInfo': is not a class or namespace name

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, static)

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, static)

'SMEInfo': is not a class or namespace name

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'IsSMEAvailable': undeclared identifier

Check failure on line 1205 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'SMEInfo': is not a class or namespace name
GetMlasPlatform().MlasConvPrepareOverride != nullptr &&
GetMlasPlatform().MlasConvPrepareOverride(Parameters, Dimensions, BatchCount, GroupCount, InputChannels,
InputShape,KernelShape,DilationShape, Padding, StrideShape, OutputShape, FilterCount,
Activation, WorkingBufferSize, Beta, ThreadPool)){
Expand Down
24 changes: 12 additions & 12 deletions onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci
const size_t kw, const void * const* lhs_ptrs, std::byte* lhs_data,
const float* in_data,
const float* pad_ptr) {
size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
: kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
size_t m_step = SMEInfo::CanUseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
: kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

// Minimize the kernel call count for the number of available threads
auto RequiredTiles = MlasDivRoundup(m, m_step);
Expand Down Expand Up @@ -383,7 +383,7 @@ static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t i

const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw);

const auto m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
const auto m_step = SMEInfo::CanUseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
: kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

const auto lhs_ptrs_k = kh * kw;
Expand Down Expand Up @@ -518,10 +518,10 @@ static void ConvolveSme(const size_t co, //channels out
const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) *
ComputeConvOutSize(iw, d_kw, padding, sw);

size_t n_step = ArmKleidiAI::UseSME2 ? kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
: kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
size_t m_step = ArmKleidiAI::UseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
: kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
size_t n_step = SMEInfo::CanUseSME2 ? kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
: kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
size_t m_step = SMEInfo::CanUseSME2 ? kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa()
: kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

// tile iteration dimensions
std::array<size_t,3> dim;
Expand Down Expand Up @@ -566,16 +566,16 @@ static void ConvolveSme(const size_t co, //channels out
ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];

// Get rhs tile, B
const size_t rhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx * n_step, d_kh * d_kw, ci)
: kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, d_kh * d_kw, ci);
const size_t rhs_packed_offset = SMEInfo::CanUseSME2 ? kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx * n_step, d_kh * d_kw, ci)
: kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, d_kh * d_kw, ci);

auto BTile = reinterpret_cast<const void*>(
reinterpret_cast<const std::byte*>(rhs.get()) + rhs_packed_offset
);

// Get lhs tile, A
const size_t lhs_packed_offset = ArmKleidiAI::UseSME2 ? kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx * m_step, d_kh * d_kw, ci)
: kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, d_kh * d_kw, ci);
const size_t lhs_packed_offset = SMEInfo::CanUseSME2 ? kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx * m_step, d_kh * d_kw, ci)
: kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, d_kh * d_kw, ci);

auto ATile = reinterpret_cast<const float*>(
reinterpret_cast<const std::byte*>(lhs.get()) + lhs_packed_offset
Expand All @@ -589,7 +589,7 @@ static void ConvolveSme(const size_t co, //channels out
MIdx * m_step * co * sizeof(float) +
NIdx * n_step * sizeof(float)];

if (ArmKleidiAI::UseSME2) {
if (SMEInfo::CanUseSME2) {
KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci);
kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
Expand Down
3 changes: 0 additions & 3 deletions onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@

namespace ArmKleidiAI {

// By default we should try for SME2 first before falling back to SME.
inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();

// Buffer packing routines.
//
size_t
Expand Down
26 changes: 13 additions & 13 deletions onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ Return Value:
}

if (TransA == CblasNoTrans) {
const size_t nr = UseSME2 ? kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
const size_t nr = SMEInfo::CanUseSME2 ? kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
const size_t kr = SMEInfo::CanUseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
const size_t sr = SMEInfo::CanUseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

// Ensure size and zero the used span.
Expand Down Expand Up @@ -226,16 +226,16 @@ Return Value:
return true;
}

const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
const size_t mr = SMEInfo::CanUseSME2 ? kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
const size_t kr = SMEInfo::CanUseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
const size_t sr = SMEInfo::CanUseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
size_t m_step = SMEInfo::CanUseSME2 ? kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
size_t n_step = SMEInfo::CanUseSME2 ? kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
: kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

if ((M < m_step || N < n_step) && !Data->BIsPacked) {
Expand Down Expand Up @@ -336,8 +336,8 @@ Return Value:

// Get rhs tile, B
const size_t rhs_packed_offset =
UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K)
: kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, K);
SMEInfo::CanUseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K)
: kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, K);

const std::byte* B_base = Data[0].BIsPacked
? reinterpret_cast<const std::byte*>(Data[BIdx].B)
Expand All @@ -346,8 +346,8 @@ Return Value:

// Get lhs tile, A
const size_t lhs_packed_offset =
UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K)
: kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, K);
SMEInfo::CanUseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K)
: kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, K);

const std::byte* A_base = LhsPackedData + LhsPackedStride * BIdx;
auto ATile = reinterpret_cast<const float*>(A_base + lhs_packed_offset);
Expand All @@ -370,7 +370,7 @@ Return Value:
float* temp_tile = g_kai_tls.output_tile.data();
std::fill_n(temp_tile, tile_elems, 0.0f);

if (UseSME2) {
if (SMEInfo::CanUseSME2) {
KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"
<< " M=" << TileSizeM << " << N=" << TileSizeN << " K=" << K);
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(
Expand Down
24 changes: 24 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ class MLASCPUIDInfo

bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; }

bool HasArm_SME() const { return has_arm_sme_; }

bool HasArm_SME2() const { return has_arm_sme2_; }

private:
MLASCPUIDInfo();

Expand All @@ -210,6 +214,8 @@ class MLASCPUIDInfo
bool has_arm_sve_{false};
bool has_arm_sve_i8mm_{false};
bool has_arm_neon_bf16_{false};
bool has_arm_sme_{false};
bool has_arm_sme2_{false};
};
using MLAS_CPUIDINFO = MLASCPUIDInfo;

Expand Down Expand Up @@ -311,6 +317,24 @@ operator!=(const MLFloat16& left, const MLFloat16& right)

#endif // BUILD_MLAS_NO_ONNXRUNTIME

#if defined(MLAS_TARGET_ARM64)

struct SMEInfo {
static const bool CanUseSME2;
static const bool CanUseSME;
static const bool IsSMEAvailable;
};

// Boolean condition to determine if we can use SME2
// By default we should try for SME2 first before falling back to SME.
inline const bool SMEInfo::CanUseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CPUIDInfo::HasArm_SME2() already returns a cached bool. is it worth having another constant for the same value? admittedly, SMEInfo::CanUseSME2 is less to type, but it may be simpler to just have a single way of getting this information.

// Boolean condition to determine if we can use SME
inline const bool SMEInfo::CanUseSME = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME();
// Boolean condition to tell us if SME is enabled on this system
inline const bool SMEInfo::IsSMEAvailable = SMEInfo::CanUseSME2 || SMEInfo::CanUseSME;

#endif // MLAS_TARGET_ARM64

static_assert(sizeof(MLAS_FP16) == FP16_SIZE);


Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ Return Value:
}

#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){
if (SMEInfo::IsSMEAvailable) {
this->MlasGemmBatchOverride = ArmKleidiAI::MlasGemmBatch;
this->MlasGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize;
this->MlasGemmPackBOverride = ArmKleidiAI::MlasGemmPackB;
Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/core/mlas/lib/qgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ MlasDynamicQGemmBatch (
) {
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
//No fallback and putting in guards. This implementation is SME2 specific.
if(ArmKleidiAI::UseSME2){
if(SMEInfo::CanUseSME2){
ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool);
}
#endif
Expand Down Expand Up @@ -336,7 +336,7 @@ MlasDynamicQgemmPackBSize(
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
//No fallback available
//TODO: Insert Override
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()){//Still require this since no override
bytes = ArmKleidiAI::MlasDynamicQgemmPackBSize(N, K);
}
#endif
Expand Down Expand Up @@ -407,7 +407,7 @@ Return Value:
~(BufferAlignment - 1);
// If this gemm B argument is used in a dynamically quantization gemm operation we can optimize for
// this use case. Concat both packed representations for later decision. This allows for cases later
// where we still have the prepack at the cost of some memory otherwise we can use the qgemm quantization
// where we still have the prepack at the cost of some memory otherwise we can use the qgemm quantization
// for better performance
return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K);
}
Expand All @@ -425,7 +425,7 @@ MlasDynamicQgemmPackB(
{
#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
//No fallback
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override
if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()){//Still require this since no override
ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB);
}
#endif
Expand Down
Loading