 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 #include "core/providers/webgpu/nn/conv.h"
+#include "core/graph/node_arg.h"
 #include "core/providers/webgpu/nn/conv2d_mm.h"
 #include "core/providers/webgpu/shader_helper.h"
 #include "core/providers/webgpu/webgpu_supported_types.h"
@@ -29,10 +30,20 @@ template <bool is_channels_last, bool is_fused>
 Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context) const {
   bool has_bias = context.InputCount() > 2;
   const auto* input = context.Input<Tensor>(0);
-  const auto* kernel = context.Input<Tensor>(1);
+  const Tensor* kernel = nullptr;
+  bool kernel_is_prepacked = false;
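+  // If the weight was transposed at session-initialization time by PrePackInternal,
+  // the original initializer has been released and input 1 is no longer available;
+  // use the cached transposed_kernel_ instead.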
+  if (transposed_kernel_) {
+    kernel = transposed_kernel_.get();
+    kernel_is_prepacked = true;
+  } else {
+    kernel = context.Input<Tensor>(1);
+  }
   const auto* bias = has_bias ? context.Input<Tensor>(2) : nullptr;
   TensorShape input_shape = input->Shape();
-  TensorShape kernel_shape = kernel->Shape();
+  ORT_ENFORCE(kernel != nullptr, "Conv kernel tensor is required.");
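+  // A prepacked kernel is already in [H, W, I, O] layout; rebuild the original
+  // [O, I, H, W] view (indices {3, 2, 0, 1}) so the shape math below is unchanged.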
+  TensorShape kernel_shape = kernel_is_prepacked
+                                 ? TensorShape(TensorShapeVector{kernel->Shape()[3], kernel->Shape()[2], kernel->Shape()[0], kernel->Shape()[1]})
+                                 : kernel->Shape();
   ConvAttributes::ConvPadVector local_pads(conv_attrs_.pads.begin(), conv_attrs_.pads.end());
   TensorShapeVector local_dilations(conv_attrs_.dilations.begin(), conv_attrs_.dilations.end());
   TensorShapeVector local_strides(conv_attrs_.strides.begin(), conv_attrs_.strides.end());
@@ -106,9 +117,13 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
   if (conv_attrs_.group > 1) {
     Tensor transposed_kernel;
     if (is_channels_last) {
-      ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
-      inputs[1] = &transposed_kernel;
-      modified_input_output_shapes[1] = transposed_kernel.Shape();
+      const Tensor* grouped_kernel = kernel;
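+      // A prepacked kernel was already transposed at session initialization;
+      // only transpose on the fly when no prepack happened.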
+      if (!kernel_is_prepacked) {
+        ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
+        grouped_kernel = &transposed_kernel;
+      }
+      inputs[1] = grouped_kernel;
+      modified_input_output_shapes[1] = grouped_kernel->Shape();
     }
     auto output_channels_per_group = output_channels / conv_attrs_.group;
     auto components = static_cast<int>(is_channels_last && output_channels_per_group >= 4 ? GetMaxComponents(output_channels) : 1);
@@ -146,9 +161,12 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
     std::vector<TensorShape> matmul_input_reshapes;
     if (is_channels_last) {
       // Transpose weights
-
-      ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
-      inputs[1] = &transposed_kernel;
+      const Tensor* matmul_kernel = kernel;
+      if (!kernel_is_prepacked) {
+        ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
+        matmul_kernel = &transposed_kernel;
+      }
+      inputs[1] = matmul_kernel;
       if (same_size) {
         const auto shared_dim = input_height * input_width * input_channels;
         input_reshape = TensorShape({1, batch, shared_dim});
@@ -160,7 +178,7 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
         matmul_output_shape = TensorShape({batch, output_height * output_width, output_channels});
       }
       matmul_inputs.push_back(input);
-      matmul_inputs.push_back(&transposed_kernel);
+      matmul_inputs.push_back(matmul_kernel);
       matmul_input_reshapes.push_back(input_reshape);
       matmul_input_reshapes.push_back(kernel_reshape);
     } else {
@@ -203,56 +221,163 @@ Status Conv<is_channels_last, is_fused>::ComputeInternal(ComputeContext& context
       return ComputeMatMul(&context, activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]);
     }
   }
-  // Transpose weights
+  // Transpose weights when necessary
   Tensor transposed_kernel;
-  ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
+  const Tensor* conv_kernel = kernel;
+  if (!kernel_is_prepacked) {
+    ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm));
+    conv_kernel = &transposed_kernel;
+  }
   auto dim_a_outer = static_cast<uint32_t>(is_channels_last ? output_height * output_width : output_channels);
   auto dim_b_outer = static_cast<uint32_t>(is_channels_last ? output_channels : output_height * output_width);
   auto dim_inner = static_cast<uint32_t>(kernel_height * kernel_width * input_channels);
-  inputs[1] = &transposed_kernel;
-  TensorShape transposed_kernel_shape = transposed_kernel.Shape();
-  modified_input_output_shapes[1] = transposed_kernel.Shape();
+  inputs[1] = conv_kernel;
+  TensorShape transposed_kernel_shape = conv_kernel->Shape();
+  modified_input_output_shapes[1] = transposed_kernel_shape;
   Conv2dMMProgram conv2d_mm_program = CreateConv2dMMProgram(activation_, inputs, pads, strides, dilations, output, dim_a_outer, dim_b_outer, dim_inner, is_channels_last, modified_input_output_shapes);
   return context.RunProgram(conv2d_mm_program);
 }

 template <bool is_channels_last, bool is_fused>
-Status Conv<is_channels_last, is_fused>::PrePackInternal(ComputeContextBase& /* context */,
+Status Conv<is_channels_last, is_fused>::PrePackInternal(ComputeContextBase& context,
                                                          const Tensor& tensor,
                                                          int input_idx,
-                                                         AllocatorPtr /* alloc */,
+                                                         AllocatorPtr alloc,
                                                          /*out*/ bool& is_packed) {
   is_packed = false;

-  if constexpr (is_channels_last) {
-    if (input_idx == 1 && tensor.Shape().NumDimensions() == 4) {
-      // only deal with 4D NHWC weights
+  // Only prepack kernel weights (input_idx == 1).
+  if (input_idx != 1) {
+    return Status::OK();
+  }

-      // TODO: implement weight transpose for pre-pack here
-      // Conv::ComputeInternal() should be updated to reflect the change:
-      // - if the initializer is packed, `context.Input<Tensor>(1)` will be nullptr.
-      // - in this case, use `transposed_kernel_` instead.
+  const auto& kernel_shape = tensor.Shape();
+  const auto& dims = kernel_shape.GetDims();

-      // // Step.1 - calculate transposed weight shape
-      // TensorShape transposed_kernel_shape{tensor.Shape()[2],
-      //                                     tensor.Shape()[3],
-      //                                     tensor.Shape()[1],
-      //                                     tensor.Shape()[0]};
+  // Conv kernels must be 4D: [O, I, H, W].
+  if (dims.size() != 4) {
+    return Status::OK();
+  }

-      // // Step.2 - create transposed weight tensor
-      // transposed_kernel_ = std::make_unique<Tensor>(tensor.DataType(), transposed_kernel_shape, alloc);
+  // Get the group attribute.
+  int64_t group = conv_attrs_.group;

-      // // Step.3 - do transpose
-      // size_t perm[] = {2, 3, 1, 0};
-      // ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context,
-      //                                            perm,
-      //                                            tensor,
-      //                                            *transposed_kernel_));
+  // Get the kernel spatial dimensions.
+  const int64_t kernel_height = dims[2];
+  const int64_t kernel_width = dims[3];

-      // is_packed = true; // set this flag to true so that ORT will release the initializer tensor
+  // Get the input shape to check the same_size condition.
+  const auto& input_defs = context.Node().InputDefs();
+  const auto* input_arg = input_defs[0];
+  int64_t input_height = -1;
+  int64_t input_width = -1;
+  if (input_arg && input_arg->Exists()) {
+    const auto* input_shape_proto = input_arg->Shape();
+    if (input_shape_proto && input_shape_proto->dim_size() >= 3) {
+      if constexpr (is_channels_last) {
+        // For channels_last: [N, H, W, C], or [N, W, C] for Conv1D.
+        if (input_shape_proto->dim_size() == 4) {
+          if (input_shape_proto->dim(1).has_dim_value()) {
+            input_height = input_shape_proto->dim(1).dim_value();
+          }
+          if (input_shape_proto->dim(2).has_dim_value()) {
+            input_width = input_shape_proto->dim(2).dim_value();
+          }
+        }
+      } else {
+        // For channels_first: [N, C, H, W], or [N, C, W] for Conv1D.
+        if (input_shape_proto->dim_size() == 4) {
+          if (input_shape_proto->dim(2).has_dim_value()) {
+            input_height = input_shape_proto->dim(2).dim_value();
+          }
+          if (input_shape_proto->dim(3).has_dim_value()) {
+            input_width = input_shape_proto->dim(3).dim_value();
+          }
+        }
+      }
     }
   }

+  // Get pads and strides.
+  const auto& pads_vec = conv_attrs_.pads;
+  const auto& strides_vec = conv_attrs_.strides;
+
+  std::vector<int64_t> pads(pads_vec.begin(), pads_vec.end());
+  std::vector<int64_t> strides(strides_vec.begin(), strides_vec.end());
+
+  // Default pads and strides if not specified.
+  if (pads.empty()) {
+    pads.resize(4, 0);
+  }
+  if (strides.empty()) {
+    strides.resize(2, 1);
+  }
+
+  // Analyze the execution paths to determine whether the kernel needs pre-transformation:
+
+  // Path 1: Grouped convolution (group > 1)
+  // - only transposes when is_channels_last
+  // - channels_first: no transpose
+  if (group > 1) {
+    if constexpr (!is_channels_last) {
+      // A channels_first grouped conv doesn't transpose.
+      return Status::OK();
+    }
+    // A channels_last grouped conv transposes - proceed to the transpose below.
+  } else {
+    // Path 2: MatMul optimization (same_size or 1x1 conv conditions)
+    // - channels_last: same_size OR 1x1 -> transposes
+    // - channels_first: 1x1 only (same_size requires is_channels_last) -> does NOT transpose
+
+    // Note: same_size in ComputeInternal has an `is_channels_last &&` prefix,
+    // so for channels_first it is always false regardless of dimensions.
+    const bool same_size = is_channels_last && (input_height > 0 && input_width > 0 &&
+                                                input_height == kernel_height && input_width == kernel_width &&
+                                                pads[0] == 0 && pads[1] == 0);
+
+    const bool is_1x1_conv =
+        (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides.size() > 0 &&
+         strides[0] == 1 && (strides.size() == 1 || strides[1] == 1));
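+    // A 1x1 stride-1 kernel with zero padding makes the convolution a per-pixel
+    // matrix multiply, which is why ComputeInternal can route it to the MatMul path.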
+
+    if (same_size || is_1x1_conv) {
+      if constexpr (!is_channels_last) {
+        // The MatMul optimization for channels_first (1x1 only) does NOT transpose.
+        return Status::OK();
+      }
+      // The channels_last MatMul optimization transposes - proceed to the transpose below.
+    }
+
+    // Path 3: General convolution (fallback path)
+    // - ALWAYS transposes regardless of is_channels_last.
+    // - For channels_first with dynamic input shapes we still need to transpose:
+    //   if we don't hit the 1x1 optimization we fall through to this general path,
+    //   which always transposes. The only risk is that runtime dimensions turn out
+    //   to match the 1x1 conditions, but we can't know that at PrePack time.
+  }
+
+  // Perform the transpose using the same logic as TransposeKernel.
+  // For 4D: perm = {2, 3, 1, 0} transforms [O, I, H, W] -> [H, W, I, O].
+  const InlinedVector<size_t> perm = InlinedVector<size_t>{2, 3, 1, 0};
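+  // e.g. a [64, 3, 7, 7] OIHW kernel becomes a [7, 7, 3, 64] HWIO tensor.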
+  auto rank = kernel_shape.NumDimensions();
+
+  TensorShapeVector transposed_kernel_shape_vector(rank);
+  for (size_t i = 0; i < rank; ++i) {
+    transposed_kernel_shape_vector[i] = kernel_shape[perm[i]];
+  }
+  TensorShape transposed_kernel_shape(transposed_kernel_shape_vector);
+
+  ORT_ENFORCE(alloc != nullptr, "Allocator must be provided for WebGPU pre-pack.");
+
+  // Create the transposed kernel tensor using the WebGPU allocator.
+  // Both the input tensor and the output tensor are GPU tensors, ready for GPU operations.
+  transposed_kernel_ = std::make_unique<Tensor>(tensor.DataType(), transposed_kernel_shape, alloc);
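+  // transposed_kernel_ is a member, so the packed weight outlives this call and is
+  // picked up by ComputeInternal on every subsequent Run.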
+  context.EnsureGpuBufferUnmapped(*transposed_kernel_);
+
+  // Perform the GPU-based transpose directly from the input GPU tensor.
+  ORT_RETURN_IF_ERROR(Transpose::DoTranspose(context, perm, tensor, *transposed_kernel_));
+
+  is_packed = true;  // set this flag to true so that ORT will release the initializer tensor
+
   return Status::OK();
 }

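As a sanity check of the layout math (an illustrative standalone sketch, not part of this change): perm = {2, 3, 1, 0} sends an OIHW shape to HWIO, and the {3, 2, 0, 1} index order used in ComputeInternal recovers the original OIHW view.

#include <array>
#include <cassert>
#include <cstdint>

int main() {
  const std::array<int64_t, 4> oihw{64, 3, 7, 7};  // e.g. 64 output channels, 3 input channels, 7x7 kernel
  const std::array<size_t, 4> perm{2, 3, 1, 0};    // OIHW -> HWIO, as in TransposeKernel / PrePackInternal
  std::array<int64_t, 4> hwio{};
  for (size_t i = 0; i < 4; ++i) {
    hwio[i] = oihw[perm[i]];  // same loop as transposed_kernel_shape_vector above
  }
  assert((hwio == std::array<int64_t, 4>{7, 7, 3, 64}));
  // ComputeInternal rebuilds the OIHW view from the prepacked HWIO shape with {3, 2, 0, 1}:
  const std::array<int64_t, 4> rebuilt{hwio[3], hwio[2], hwio[0], hwio[1]};
  assert(rebuilt == oihw);
  return 0;
}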