Commit d44897d
WebGPU: Transpose Conv kernels in Prepack
Pre-pack Conv kernel weights with path-aware transpose decisions, store the transposed kernels for reuse, and add ComputeContextBase helpers for node access and GPU buffer unmapping.
1 parent 4c43c66 commit d44897d

File tree: 3 files changed (+183 lines, -37 lines)

onnxruntime/core/providers/webgpu/compute_context.cc

Lines changed: 13 additions & 0 deletions
@@ -2,6 +2,7 @@
 // Licensed under the MIT License.
 
 #include "core/providers/webgpu/compute_context.h"
+#include "core/framework/tensor.h"
 #include "core/providers/webgpu/webgpu_execution_provider.h"
 
 namespace onnxruntime {
@@ -19,6 +20,18 @@ const webgpu::BufferManager& ComputeContextBase::BufferManagerAccessor::Get(cons
   return context.ep_.BufferManager();
 }
 
+void ComputeContextBase::EnsureGpuBufferUnmapped(Tensor& tensor) const {
+  void* data = tensor.MutableDataRaw();
+  if (!data) {
+    return;
+  }
+
+  auto buffer = reinterpret_cast<WGPUBuffer>(data);
+  if (buffer != nullptr && wgpuBufferGetMapState(buffer) != WGPUBufferMapState_Unmapped) {
+    wgpuBufferUnmap(buffer);
+  }
+}
+
 ComputeContext::ComputeContext(WebGpuContext& webgpu_context,
                                const WebGpuExecutionProvider& ep,
                                const OpKernel& op_kernel,

onnxruntime/core/providers/webgpu/compute_context.h

Lines changed: 8 additions & 0 deletions
@@ -56,6 +56,14 @@ class ComputeContextBase {
     return op_kernel_.Node().Name();
   }
 
+  inline const Node& Node() const {
+    return op_kernel_.Node();
+  }
+
+  // Some read-only WebGPU allocators map buffers for CPU initialization. Before binding a
+  // tensor to a GPU program ensure the buffer is unmapped.
+  void EnsureGpuBufferUnmapped(Tensor& tensor) const;
+
   //
   // Get the operator type.
   //
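Note (not part of the diff): a minimal sketch of how the two new ComputeContextBase helpers are meant to be used together from a kernel's PrePackInternal override. MyKernel and its transposed_kernel_ member are hypothetical placeholders; Node(), EnsureGpuBufferUnmapped(), Tensor, AllocatorPtr, and Status are existing ORT/WebGPU types plus the additions above.

// Sketch only: a hypothetical kernel pre-packing its weight initializer on the GPU.
Status MyKernel::PrePackInternal(ComputeContextBase& context, const Tensor& tensor,
                                 int input_idx, AllocatorPtr alloc, /*out*/ bool& is_packed) {
  is_packed = false;
  if (input_idx != 1) {  // only the weight input is pre-packed
    return Status::OK();
  }

  // Node() gives access to the graph node, e.g. to read statically known input shapes.
  const auto* input_arg = context.Node().InputDefs()[0];
  if (input_arg == nullptr || !input_arg->Exists()) {
    return Status::OK();
  }

  // Allocate the packed tensor with the (GPU) allocator, then make sure its buffer is
  // unmapped before any GPU program binds it.
  transposed_kernel_ = std::make_unique<Tensor>(tensor.DataType(), tensor.Shape(), alloc);
  context.EnsureGpuBufferUnmapped(*transposed_kernel_);

  // ... run a GPU transform from `tensor` into `*transposed_kernel_` ...
  is_packed = true;  // let ORT release the original initializer
  return Status::OK();
}

The conv.cc changes below follow this same pattern, with the actual transpose performed by Transpose::DoTranspose.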

onnxruntime/core/providers/webgpu/nn/conv.cc

Lines changed: 162 additions & 37 deletions
@@ -1,6 +1,7 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 #include "core/providers/webgpu/nn/conv.h"
+#include "core/graph/node_arg.h"
 #include "core/providers/webgpu/nn/conv2d_mm.h"
 #include "core/providers/webgpu/shader_helper.h"
 #include "core/providers/webgpu/webgpu_supported_types.h"
@@ -29,10 +30,20 @@ template <bool is_channels_last, bool is_fused>
 Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context) const {
   bool has_bias = context.InputCount() > 2;
   const auto* input = context.Input<Tensor>(0);
-  const auto* kernel = context.Input<Tensor>(1);
+  const Tensor* kernel = nullptr;
+  bool kernel_is_prepacked = false;
+  if (transposed_kernel_) {
+    kernel = transposed_kernel_.get();
+    kernel_is_prepacked = true;
+  } else {
+    kernel = context.Input<Tensor>(1);
+  }
   const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;
   TensorShape input_shape = input->Shape();
-  TensorShape kernel_shape = kernel->Shape();
+  ORT_ENFORCE(kernel != nullptr, "Conv kernel tensor is required.");
+  TensorShape kernel_shape = kernel_is_prepacked
+                                 ? TensorShape(TensorShapeVector{kernel->Shape()[3], kernel->Shape()[2], kernel->Shape()[0], kernel->Shape()[1]})
+                                 : kernel->Shape();
   ConvAttributes::ConvPadVector local_pads(conv_attrs_.pads.begin(), conv_attrs_.pads.end());
   TensorShapeVector local_dilations(conv_attrs_.dilations.begin(), conv_attrs_.dilations.end());
   TensorShapeVector local_strides(conv_attrs_.strides.begin(), conv_attrs_.strides.end());
@@ -106,9 +117,13 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
   if (conv_attrs_.group > 1) {
     Tensor transposed_kernel;
     if (is_channels_last) {
-      ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
-      inputs[1] = &transposed_kernel;
-      modified_input_output_shapes[1] = transposed_kernel.Shape();
+      const Tensor* grouped_kernel = kernel;
+      if (!kernel_is_prepacked) {
+        ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
+        grouped_kernel = &transposed_kernel;
+      }
+      inputs[1] = grouped_kernel;
+      modified_input_output_shapes[1] = grouped_kernel->Shape();
     }
     auto output_channels_per_group = output_channels / conv_attrs_.group;
     auto components = static_cast<int>(is_channels_last && output_channels_per_group >= 4 ? GetMaxComponents(output_channels) : 1);
@@ -146,9 +161,12 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
   std::vector<TensorShape> matmul_input_reshapes;
   if (is_channels_last) {
     // Transpose weights
-
-    ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
-    inputs[1] = &transposed_kernel;
+    const Tensor* matmul_kernel = kernel;
+    if (!kernel_is_prepacked) {
+      ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
+      matmul_kernel = &transposed_kernel;
+    }
+    inputs[1] = matmul_kernel;
     if (same_size) {
       const auto shared_dim = input_height * input_width * input_channels;
       input_reshape = TensorShape({1, batch, shared_dim});
@@ -160,7 +178,7 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
       matmul_output_shape = TensorShape({batch, output_height * output_width, output_channels});
     }
     matmul_inputs.push_back(input);
-    matmul_inputs.push_back(&transposed_kernel);
+    matmul_inputs.push_back(matmul_kernel);
     matmul_input_reshapes.push_back(input_reshape);
    matmul_input_reshapes.push_back(kernel_reshape);
   } else {
@@ -203,56 +221,163 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
       return ComputeMatMul(&context, activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]);
     }
   }
-  // Transpose weights
+  // Transpose weights when necessary
   Tensor transposed_kernel;
-  ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
+  const Tensor* conv_kernel = kernel;
+  if (!kernel_is_prepacked) {
+    ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
+    conv_kernel = &transposed_kernel;
+  }
   auto dim_a_outer = static_cast<uint32_t>(is_channels_last ? output_height * output_width : output_channels);
   auto dim_b_outer = static_cast<uint32_t>(is_channels_last ? output_channels : output_height * output_width);
   auto dim_inner = static_cast<uint32_t>(kernel_height * kernel_width * input_channels);
-  inputs[1] = &transposed_kernel;
-  TensorShape transposed_kernel_shape = transposed_kernel.Shape();
-  modified_input_output_shapes[1] = transposed_kernel.Shape();
+  inputs[1] = conv_kernel;
+  TensorShape transposed_kernel_shape = conv_kernel->Shape();
+  modified_input_output_shapes[1] = transposed_kernel_shape;
   Conv2dMMProgram conv2d_mm_program = CreateConv2dMMProgram(activation_, inputs, pads, strides, dilations, output, dim_a_outer, dim_b_outer, dim_inner, is_channels_last, modified_input_output_shapes);
   return context.RunProgram(conv2d_mm_program);
 }
 
 template <bool is_channels_last, bool is_fused>
-Status Conv<is_channels_last, is_fused>::PrePackInternal(ComputeContextBase& /* context */,
+Status Conv<is_channels_last, is_fused>::PrePackInternal(ComputeContextBase& context,
                                                          const Tensor& tensor,
                                                          int input_idx,
-                                                         AllocatorPtr /* alloc */,
+                                                         AllocatorPtr alloc,
                                                          /*out*/ bool& is_packed) {
   is_packed = false;
 
-  if constexpr (is_channels_last) {
-    if (input_idx == 1 && tensor.Shape().NumDimensions() == 4) {
-      // only deal with 4D NHWC weights
+  // Only prepack kernel weights (input_idx == 1)
+  if (input_idx != 1) {
+    return Status::OK();
+  }
 
-      // TODO: implement weight transpose for pre-pack here
-      // Conv::ComputeInternal() should be updated to reflect the change:
-      // - if the initializer is packed, `context.Input<Tensor>(1)` will be nullptr.
-      // - in this case, use `transposed_kernel_` instead.
+  const auto& kernel_shape = tensor.Shape();
+  const auto& dims = kernel_shape.GetDims();
 
-      // // Step.1 - calculate transposed weight shape
-      // TensorShape transposed_kernel_shape{tensor.Shape()[2],
-      //                                     tensor.Shape()[3],
-      //                                     tensor.Shape()[1],
-      //                                     tensor.Shape()[0]};
+  // Conv kernels must be 4D: [O, I, H, W]
+  if (dims.size() != 4) {
+    return Status::OK();
+  }
 
-      // // Step.2 - create transposed weight tensor
-      // transposed_kernel_ = std::make_unique<Tensor>(tensor.DataType(), transposed_kernel_shape, alloc);
+  // Get group attribute
+  int64_t group = conv_attrs_.group;
 
-      // // Step.3 - do transpose
-      // size_t perm[] = {2, 3, 1, 0};
-      // ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context,
-      //                                            perm,
-      //                                            tensor,
-      //                                            *transposed_kernel_));
+  // Get kernel spatial dimensions
+  const int64_t kernel_height = dims[2];
+  const int64_t kernel_width = dims[3];
 
-      // is_packed = true; // set this flag to true so that ORT will release the initializer tensor
+  // Get input shape to check same_size condition
+  const auto& input_defs = context.Node().InputDefs();
+  const auto* input_arg = input_defs[0];
+  int64_t input_height = -1;
+  int64_t input_width = -1;
+  if (input_arg && input_arg->Exists()) {
+    const auto* input_shape_proto = input_arg->Shape();
+    if (input_shape_proto && input_shape_proto->dim_size() >= 3) {
+      if constexpr (is_channels_last) {
+        // For channels_last: [N, H, W, C] or [N, W, C] for Conv1D
+        if (input_shape_proto->dim_size() == 4) {
+          if (input_shape_proto->dim(1).has_dim_value()) {
+            input_height = input_shape_proto->dim(1).dim_value();
+          }
+          if (input_shape_proto->dim(2).has_dim_value()) {
+            input_width = input_shape_proto->dim(2).dim_value();
+          }
+        }
+      } else {
+        // For channels_first: [N, C, H, W] or [N, C, W] for Conv1D
+        if (input_shape_proto->dim_size() == 4) {
+          if (input_shape_proto->dim(2).has_dim_value()) {
+            input_height = input_shape_proto->dim(2).dim_value();
+          }
+          if (input_shape_proto->dim(3).has_dim_value()) {
+            input_width = input_shape_proto->dim(3).dim_value();
+          }
+        }
+      }
     }
   }
 
+  // Get pads and strides
+  const auto& pads_vec = conv_attrs_.pads;
+  const auto& strides_vec = conv_attrs_.strides;
+
+  std::vector<int64_t> pads(pads_vec.begin(), pads_vec.end());
+  std::vector<int64_t> strides(strides_vec.begin(), strides_vec.end());
+
+  // Default pads and strides if not specified
+  if (pads.empty()) {
+    pads.resize(4, 0);
+  }
+  if (strides.empty()) {
+    strides.resize(2, 1);
+  }
+
+  // Analyze execution paths to determine if kernel needs pre-transformation:
+
+  // Path 1: Grouped convolution (group > 1)
+  // - Only transposes when is_channels_last
+  // - channels_first: no transpose
+  if (group > 1) {
+    if constexpr (!is_channels_last) {
+      // channels_first grouped conv doesn't transpose
+      return Status::OK();
+    }
+    // is_channels_last grouped conv transposes - proceed to transpose below
+  } else {
+    // Path 2: MatMul optimization (same_size or 1x1 conv conditions)
+    // - channels_last: same_size OR 1x1 -> transposes
+    // - channels_first: 1x1 only (same_size requires is_channels_last) -> does NOT transpose
+
+    // Note: same_size in ComputeInternal has `is_channels_last &&` prefix,
+    // so for channels_first it's always false regardless of dimensions.
+    const bool same_size = is_channels_last && (input_height > 0 && input_width > 0 &&
+                                                input_height == kernel_height && input_width == kernel_width &&
+                                                pads[0] == 0 && pads[1] == 0);
+
+    const bool is_1x1_conv =
+        (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides.size() > 0 &&
+         strides[0] == 1 && (strides.size() == 1 || strides[1] == 1));
+
+    if (same_size || is_1x1_conv) {
+      if constexpr (!is_channels_last) {
+        // MatMul optimization for channels_first (1x1 only) does NOT transpose
+        return Status::OK();
+      }
+      // is_channels_last MatMul optimization transposes - proceed to transpose below
+    }
+
+    // Path 3: General convolution (fallback path)
+    // - ALWAYS transposes regardless of is_channels_last
+    // - For channels_first with dynamic input shapes, we still need to transpose
+    //   because if we don't hit the 1x1 optimization, we'll hit this general path
+    //   which always transposes. The only risk is if runtime dimensions turn out
+    //   to match 1x1 conditions, but we can't know that at PrePack time.
+  }
+
+  // Perform the transpose using the same logic as TransposeKernel.
+  // For 4D: perm = {2, 3, 1, 0} transforms [O, I, H, W] -> [H, W, I, O]
+  const InlinedVector<size_t> perm = InlinedVector<size_t>{2, 3, 1, 0};
+  auto rank = kernel_shape.NumDimensions();
+
+  TensorShapeVector transposed_kernel_shape_vector(rank);
+  for (size_t i = 0; i < rank; ++i) {
+    transposed_kernel_shape_vector[i] = kernel_shape[perm[i]];
+  }
+  TensorShape transposed_kernel_shape(transposed_kernel_shape_vector);
+
+  ORT_ENFORCE(alloc != nullptr, "Allocator must be provided for WebGPU pre-pack.");
+
+  // Create the transposed kernel tensor using the WebGPU allocator.
+  // Both input and output tensors are GPU tensors, ready for GPU operations.
+  transposed_kernel_ = std::make_unique<Tensor>(tensor.DataType(), transposed_kernel_shape, alloc);
+  context.EnsureGpuBufferUnmapped(*transposed_kernel_);
+
+  // Perform GPU-based transpose directly from the input GPU tensor
+  ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, tensor, *transposed_kernel_));
+
+  is_packed = true;  // set this flag to true so that ORT will release the initializer tensor
+
   return Status::OK();
 }
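Note (not part of the diff): a standalone sanity check of the kernel-layout round-trip used above. PrePackInternal stores the weight as [H, W, I, O] via perm {2, 3, 1, 0}, and ComputeInternal rebuilds the logical [O, I, H, W] shape from the stored one as {s[3], s[2], s[0], s[1]}. The dimensions below are made up for illustration.

#include <array>
#include <cassert>
#include <cstddef>

int main() {
  const std::array<long long, 4> oihw = {64, 3, 7, 7};   // hypothetical [O, I, H, W] kernel
  const std::array<std::size_t, 4> perm = {2, 3, 1, 0};  // same perm as the PrePack path

  // What PrePackInternal stores: transposed[i] = original[perm[i]] -> [H, W, I, O]
  std::array<long long, 4> hwio{};
  for (std::size_t i = 0; i < 4; ++i) {
    hwio[i] = oihw[perm[i]];
  }

  // What ComputeInternal reconstructs when kernel_is_prepacked is true.
  const std::array<long long, 4> rebuilt = {hwio[3], hwio[2], hwio[0], hwio[1]};
  assert(rebuilt == oihw);  // the logical kernel shape survives the round-trip
  return 0;
}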
