17 changes: 17 additions & 0 deletions onnxruntime/core/providers/webgpu/compute_context.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "core/providers/webgpu/compute_context.h"
#include "core/framework/tensor.h"
#include "core/providers/webgpu/webgpu_execution_provider.h"

namespace onnxruntime {
@@ -19,6 +20,22 @@
return context.ep_.BufferManager();
}

Status ComputeContextBase::CreateUnmappedGPUTensor(AllocatorPtr alloc, MLDataType data_type, const TensorShape& shape, std::unique_ptr<Tensor>& tensor) const {
ORT_RETURN_IF_NOT(alloc != nullptr, "Allocator must not be null when creating GPU tensor.");

tensor = std::make_unique<Tensor>(data_type, shape, alloc);

ORT_RETURN_IF_NOT(tensor != nullptr, "Failed to allocate GPU tensor.");

void* data = tensor->MutableDataRaw();
ORT_RETURN_IF_NOT(data != nullptr, "Failed to get GPU tensor buffer.");

// On the WebGPU EP, a GPU tensor's raw data pointer is the WGPUBuffer handle.
// A buffer that is still mapped (e.g. mapped at creation) cannot be used in a
// GPU dispatch, so unmap it here if necessary.
auto buffer = reinterpret_cast<WGPUBuffer>(data);
if (wgpuBufferGetMapState(buffer) != WGPUBufferMapState_Unmapped) {
wgpuBufferUnmap(buffer);
}
return Status::OK();
}
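A minimal usage sketch for the new helper (editor's illustration, not code from this PR; `context`, `alloc`, the float element type, and the shape are assumptions for the example):

// Allocate a GPU tensor whose backing WGPUBuffer is guaranteed to be unmapped,
// so it can be bound to a compute program immediately.
std::unique_ptr<Tensor> staging;
ORT_RETURN_IF_ERROR(context.CreateUnmappedGPUTensor(
    alloc, DataTypeImpl::GetType<float>(), TensorShape({7, 7, 3, 64}), staging));
// staging->MutableDataRaw() now returns an unmapped WGPUBuffer handle.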

ComputeContext::ComputeContext(WebGpuContext& webgpu_context,
const WebGpuExecutionProvider& ep,
const OpKernel& op_kernel,
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/webgpu/compute_context.h
@@ -5,6 +5,7 @@

#include "core/providers/webgpu/webgpu_external_header.h"

#include <memory>

Check warning on line 8 in onnxruntime/core/providers/webgpu/compute_context.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: compute_context.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/webgpu/compute_context.h:8: Found C++ system header after other header. Should be: compute_context.h, c system, c++ system, other. [build/include_order] [4]
#include <utility>

#include "core/framework/data_transfer_manager.h"
@@ -56,6 +57,12 @@
return op_kernel_.Node().Name();
}

inline const onnxruntime::Node& GetNode() const {
return op_kernel_.Node();
}

// Create a GPU tensor through `alloc` and ensure its backing WGPUBuffer is
// unmapped, so the tensor can be used in GPU programs immediately.
Status CreateUnmappedGPUTensor(AllocatorPtr alloc, MLDataType data_type, const TensorShape& shape, std::unique_ptr<Tensor>& tensor) const;

//
// Get the operator type.
//
198 changes: 161 additions & 37 deletions onnxruntime/core/providers/webgpu/nn/conv.cc
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/webgpu/nn/conv.h"
#include "core/graph/node_arg.h"
#include "core/providers/webgpu/nn/conv2d_mm.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
@@ -29,10 +30,20 @@
Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context) const {
bool has_bias = context.InputCount() > 2;
const auto* input = context.Input<Tensor>(0);
const Tensor* kernel = nullptr;
bool kernel_is_prepacked = false;
if (transposed_kernel_) {
kernel = transposed_kernel_.get();
kernel_is_prepacked = true;
} else {
kernel = context.Input<Tensor>(1);
}
const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;
TensorShape input_shape = input->Shape();
ORT_ENFORCE(kernel != nullptr, "Conv kernel tensor is required.");
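// A prepacked kernel was stored transposed as [H, W, I, O] (perm {2, 3, 1, 0});
// indexing {3, 2, 0, 1} here recovers the original [O, I, H, W] shape that the
// pad/stride/output-shape math below expects.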
TensorShape kernel_shape = kernel_is_prepacked
? TensorShape(TensorShapeVector{kernel->Shape()[3], kernel->Shape()[2], kernel->Shape()[0], kernel->Shape()[1]})
: kernel->Shape();
ConvAttributes::ConvPadVector local_pads(conv_attrs_.pads.begin(), conv_attrs_.pads.end());
TensorShapeVector local_dilations(conv_attrs_.dilations.begin(), conv_attrs_.dilations.end());
TensorShapeVector local_strides(conv_attrs_.strides.begin(), conv_attrs_.strides.end());
@@ -106,9 +117,13 @@
if (conv_attrs_.group > 1) {
Tensor transposed_kernel;
if (is_channels_last) {
const Tensor* grouped_kernel = kernel;
if (!kernel_is_prepacked) {
ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
grouped_kernel = &transposed_kernel;
}
inputs[1] = grouped_kernel;
modified_input_output_shapes[1] = grouped_kernel->Shape();
}
auto output_channels_per_group = output_channels / conv_attrs_.group;
auto components = static_cast<int>(is_channels_last && output_channels_per_group >= 4 ? GetMaxComponents(output_channels) : 1);
@@ -146,9 +161,12 @@
std::vector<TensorShape> matmul_input_reshapes;
if (is_channels_last) {
// Transpose weights

const Tensor* matmul_kernel = kernel;
if (!kernel_is_prepacked) {
ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
matmul_kernel = &transposed_kernel;
}
inputs[1] = matmul_kernel;
if (same_size) {
const auto shared_dim = input_height * input_width * input_channels;
input_reshape = TensorShape({1, batch, shared_dim});
@@ -160,7 +178,7 @@
matmul_output_shape = TensorShape({batch, output_height * output_width, output_channels});
}
matmul_inputs.push_back(input);
matmul_inputs.push_back(&transposed_kernel);
matmul_inputs.push_back(matmul_kernel);
matmul_input_reshapes.push_back(input_reshape);
matmul_input_reshapes.push_back(kernel_reshape);
} else {
@@ -203,56 +221,162 @@
return ComputeMatMul(&context, activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]);
}
}
// Transpose weights when necessary
Tensor transposed_kernel;
const Tensor* conv_kernel = kernel;
if (!kernel_is_prepacked) {
ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
conv_kernel = &transposed_kernel;
}
auto dim_a_outer = static_cast<uint32_t>(is_channels_last ? output_height * output_width : output_channels);
auto dim_b_outer = static_cast<uint32_t>(is_channels_last ? output_channels : output_height * output_width);
auto dim_inner = static_cast<uint32_t>(kernel_height * kernel_width * input_channels);
inputs[1] = &transposed_kernel;
TensorShape transposed_kernel_shape = transposed_kernel.Shape();
modified_input_output_shapes[1] = transposed_kernel.Shape();
inputs[1] = conv_kernel;
TensorShape transposed_kernel_shape = conv_kernel->Shape();
modified_input_output_shapes[1] = transposed_kernel_shape;
Conv2dMMProgram conv2d_mm_program = CreateConv2dMMProgram(activation_, inputs, pads, strides, dilations, output, dim_a_outer, dim_b_outer, dim_inner, is_channels_last, modified_input_output_shapes);
return context.RunProgram(conv2d_mm_program);
}

template <bool is_channels_last, bool is_fused>
Status Conv<is_channels_last, is_fused>::PrePackInternal(ComputeContextBase& context,
const Tensor& tensor,
int input_idx,
AllocatorPtr alloc,
/*out*/ bool& is_packed) {
is_packed = false;

// Only prepack kernel weights (input_idx == 1)
if (input_idx != 1) {
return Status::OK();
}

const auto& kernel_shape = tensor.Shape();
const auto& dims = kernel_shape.GetDims();

// Conv kernels must be 4D: [O, I, H, W]
if (dims.size() != 4) {
return Status::OK();
}

// Get group attribute
int64_t group = conv_attrs_.group;

// Get kernel spatial dimensions
const int64_t kernel_height = dims[2];
const int64_t kernel_width = dims[3];

// Get input shape to check same_size condition
const auto& input_defs = context.GetNode().InputDefs();
const auto* input_arg = input_defs[0];
int64_t input_height = -1;
int64_t input_width = -1;
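// -1 means "unknown at prepack time": when the graph carries dynamic spatial
// dimensions, the same_size test below conservatively evaluates to false.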
if (input_arg && input_arg->Exists()) {
const auto* input_shape_proto = input_arg->Shape();
if (input_shape_proto && input_shape_proto->dim_size() >= 3) {
if constexpr (is_channels_last) {
// For channels_last: [N, H, W, C] or [N, W, C] for Conv1D
if (input_shape_proto->dim_size() == 4) {
if (input_shape_proto->dim(1).has_dim_value()) {
input_height = input_shape_proto->dim(1).dim_value();
}
if (input_shape_proto->dim(2).has_dim_value()) {
input_width = input_shape_proto->dim(2).dim_value();
}
}
} else {
// For channels_first: [N, C, H, W] or [N, C, W] for Conv1D
if (input_shape_proto->dim_size() == 4) {
if (input_shape_proto->dim(2).has_dim_value()) {
input_height = input_shape_proto->dim(2).dim_value();
}
if (input_shape_proto->dim(3).has_dim_value()) {
input_width = input_shape_proto->dim(3).dim_value();
}
}
}
}
}

// Get pads and strides
const auto& pads_vec = conv_attrs_.pads;
const auto& strides_vec = conv_attrs_.strides;

std::vector<int64_t> pads(pads_vec.begin(), pads_vec.end());
std::vector<int64_t> strides(strides_vec.begin(), strides_vec.end());

// Default pads and strides if not specified
if (pads.empty()) {
pads.resize(4, 0);
}
if (strides.empty()) {
strides.resize(2, 1);
}

// Analyze execution paths to determine if kernel needs pre-transformation:

// Path 1: Grouped convolution (group > 1)
// - Only transposes when is_channels_last
// - channels_first: no transpose
if (group > 1) {
if constexpr (!is_channels_last) {
// channels_first grouped conv doesn't transpose
return Status::OK();
}
// is_channels_last grouped conv transposes - proceed to transpose below
} else {
// Path 2: MatMul optimization (same_size or 1x1 conv conditions)
// - channels_last: same_size OR 1x1 -> transposes
// - channels_first: 1x1 only (same_size requires is_channels_last) -> does NOT transpose

// Note: same_size in ComputeInternal has `is_channels_last &&` prefix,
// so for channels_first it's always false regardless of dimensions.
const bool same_size = is_channels_last && (input_height > 0 && input_width > 0 &&
input_height == kernel_height && input_width == kernel_width &&
pads[0] == 0 && pads[1] == 0);

const bool is_1x1_conv =
(kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides.size() > 0 &&
strides[0] == 1 && (strides.size() == 1 || strides[1] == 1));

if (same_size || is_1x1_conv) {
if constexpr (!is_channels_last) {
// MatMul optimization for channels_first (1x1 only) does NOT transpose
return Status::OK();
}
// is_channels_last MatMul optimization transposes - proceed to transpose below
}

// Path 3: General convolution (fallback path)
// - ALWAYS transposes regardless of is_channels_last
// - For channels_first with dynamic input shapes, we still need to transpose
// because if we don't hit the 1x1 optimization, we'll hit this general path
// which always transposes. The only risk is if runtime dimensions turn out
// to match 1x1 conditions, but we can't know that at PrePack time.
}

// Perform the transpose using same logic as TransposeKernel
// For 4D: perm = {2, 3, 1, 0} transforms [O, I, H, W] -> [H, W, I, O]
const InlinedVector<size_t> perm = InlinedVector<size_t>{2, 3, 1, 0};
auto rank = kernel_shape.NumDimensions();

TensorShapeVector transposed_kernel_shape_vector(rank);
for (size_t i = 0; i < rank; ++i) {
transposed_kernel_shape_vector[i] = kernel_shape[perm[i]];
}
TensorShape transposed_kernel_shape(transposed_kernel_shape_vector);
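// Example: an OIHW kernel of shape [64, 3, 7, 7] becomes [7, 7, 3, 64] (HWIO).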

ORT_ENFORCE(alloc != nullptr, "Allocator must be provided for WebGPU pre-pack.");

// Create the transposed kernel tensor using the WebGPU allocator.
// Both input tensor and output tensor are GPU tensors, ready for GPU operations.
ORT_RETURN_IF_ERROR(context.CreateUnmappedGPUTensor(alloc, tensor.DataType(), transposed_kernel_shape, transposed_kernel_));

// Perform GPU-based transpose directly from the input GPU tensor
ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, tensor, *transposed_kernel_));

is_packed = true; // set this flag to true so that ORT will release the initializer tensor

return Status::OK();
}
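A standalone round-trip check of the permutation arithmetic above (editor's sketch, not part of this PR):

#include <array>
#include <cassert>
#include <cstdint>

int main() {
  // Original OIHW kernel dims, e.g. 64 output channels, 3 input channels, 7x7.
  const std::array<int64_t, 4> o{64, 3, 7, 7};
  // PrePackInternal applies perm {2, 3, 1, 0}: [O, I, H, W] -> [H, W, I, O].
  const std::array<int64_t, 4> p{o[2], o[3], o[1], o[0]};
  // ComputeInternal reconstructs the original shape via indices {3, 2, 0, 1}.
  const std::array<int64_t, 4> r{p[3], p[2], p[0], p[1]};
  assert(r == o);  // round-trips back to [O, I, H, W]
  return 0;
}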
