SIMD en Rust : Optimiser la Multiplication de Matrices
Table des matières
Les capacités SIMD (Single Instruction, Multiple Data) de Rust permettent le traitement parallèle de plusieurs éléments de données en une seule instruction CPU, idéal pour des tâches computationnellement intensives comme la multiplication de matrices. Je vais expliquer comment exploiter std::arch pour un débit maximum, adresser la portabilité à travers les architectures (ex : x86_64 avec SSE/AVX vs ARM avec NEON), et souligner les défis et solutions pour assurer justesse et performance.
Vectoriser la Multiplication de Matrices avec SIMD
La multiplication de matrices (ex : ( C = A \times B ), où ( A ) est ( m \times n ), ( B ) est ( n \times p ), et ( C ) est ( m \times p )) implique de calculer des produits scalaires de lignes et colonnes. Une implémentation scalaire naïve pour une matrice 4x4 est :
fn matrix_mult_scalar(a: &[[f32; 4]; 4], b: &[[f32; 4]; 4], c: &mut [[f32; 4]; 4]) {
for i in 0..4 {
for j in 0..4 {
c[i][j] = 0.0;
for k in 0..4 {
c[i][j] += a[i][k] * b[k][j];
}
}
}
}
Cela traite un f32 à la fois, ce qui est lent. SIMD peut calculer plusieurs éléments simultanément (ex : 8 f32 avec AVX sur x86_64). Voici comment le vectoriser en utilisant std::arch :
Sélectionner les Instructions SIMD
Sur x86_64 avec AVX (registres 256-bit), utilise :
_mm256_loadu_ps: Charge 8f32dans un registre 256-bit._mm256_mul_ps: Multiplie deux vecteurs 256-bit._mm256_add_ps: Additionne deux vecteurs 256-bit._mm256_storeu_ps: Stocke les résultats en mémoire.
Implémentation Vectorisée
En supposant que ( p ) est un multiple de 8 (padding si nécessaire), vectorise la boucle interne :
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
unsafe fn matrix_mult_simd(a: &[[f32; 8]; 8], b: &[[f32; 8]; 8], c: &mut [[f32; 8]; 8]) {
for i in 0..8 {
for j in (0..8).step_by(8) { // Traite 8 éléments de C[i][j..j+8]
let mut sum = _mm256_setzero_ps(); // Registre 256-bit à zéro
for k in 0..8 {
let a_vec = _mm256_set1_ps(a[i][k]); // Broadcast a[i][k]
let b_ptr = b[k][j..].as_ptr();
let b_vec = _mm256_loadu_ps(b_ptr); // Charge 8 éléments de B
let prod = _mm256_mul_ps(a_vec, b_vec);
sum = _mm256_add_ps(sum, prod); // Accumule
}
_mm256_storeu_ps(c[i][j..].as_mut_ptr(), sum); // Stocke 8 résultats
}
}
}
Cela calcule 8 termes de produit scalaire par itération, réduisant les itérations de boucle par 8x. Enveloppe cela dans des boucles externes, optionnellement avec unrolling ou tiling (ex : blocs 8x8) pour un meilleur usage cache.
Utiliser les Outils SIMD de Rust
std::arch: Fournit des intrinsiques brutes, nécessitantunsafeet un ciblage d'architecture manuel (ex :#[cfg(target_arch = "x86_64")]). Active AVX avec--features avx2dansCargo.toml.- Crates comme
packed_simd: Offre des abstractions portables :
Cela cache les spécificités d'architecture, se rabattant sur du code scalaire si SIMD n'est pas disponible.use packed_simd::f32x8; fn matrix_mult_simd_portable(a: &[[f32; 8]; 8], b: &[[f32; 8]; 8], c: &mut [[f32; 8]; 8]) { for i in 0..8 { for j in (0..8).step_by(8) { let mut sum = f32x8::splat(0.0); for k in 0..8 { let a_vec = f32x8::splat(a[i][k]); let b_vec = f32x8::from_slice_unaligned(&b[k][j..]); let prod = a_vec * b_vec; sum = sum + prod; } sum.write_unaligned(&mut c[i][j..]); } } }
Défis à Travers les Architectures
- Disponibilité des Jeux d'Instructions : AVX est spécifique à x86_64 ; ARM utilise NEON (128-bit, 4x
f32). Le code AVX échoue sur ARM ou des CPUs x86 plus anciens sans AVX.- Solution : Utilise
#[cfg]pour la compilation conditionnelle ou la détection de fonctionnalités à l'exécution avecstd::is_x86_feature_detected!("avx2"). Fallback vers scalaire ou SIMD plus étroit (ex : SSE2).
- Solution : Utilise
- Alignement : AVX préfère la mémoire alignée sur 32 octets. Les chargements non-alignés (
_mm256_loadu_ps) sont plus lents.- Solution : Aligne les données avec
#[repr(align(32))]ou pad les tableaux, échangeant mémoire contre vitesse.
- Solution : Aligne les données avec
- Portabilité : Hardcoder AVX te verrouille sur x86_64.
packed_simdaide, mais les performances varient (ex : NEON 4-wide vs AVX 8-wide).- Solution : Abstrais avec des crates ou écris plusieurs implémentations, sélectionnant à l'exécution.
- Justesse : L'associativité des nombres flottants change avec l'ordre de sommation SIMD, risquant une dérive numérique.
- Solution : Teste contre les résultats scalaires avec des entrées connues ; utilise
fsumou réduction par paires pour la précision.
- Solution : Teste contre les résultats scalaires avec des entrées connues ; utilise
Optimisations Avancées
Implémentation Multi-Architecture
// Support multi-architecture avec détection runtime
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
mod x86_simd {
use std::arch::x86_64::*;
#[target_feature(enable = "avx2")]
pub unsafe fn matrix_mult_avx2(
a: &[[f32; 8]; 8],
b: &[[f32; 8]; 8],
c: &mut [[f32; 8]; 8]
) {
// Implémentation AVX2 optimisée
for i in 0..8 {
for j in (0..8).step_by(8) {
let mut sum = _mm256_setzero_ps();
for k in 0..8 {
let a_vec = _mm256_set1_ps(a[i][k]);
let b_vec = _mm256_loadu_ps(b[k][j..].as_ptr());
sum = _mm256_fmadd_ps(a_vec, b_vec, sum); // Fused multiply-add
}
_mm256_storeu_ps(c[i][j..].as_mut_ptr(), sum);
}
}
}
#[target_feature(enable = "sse2")]
pub unsafe fn matrix_mult_sse2(
a: &[[f32; 4]; 4],
b: &[[f32; 4]; 4],
c: &mut [[f32; 4]; 4]
) {
// Implémentation SSE2 pour CPUs plus anciens
for i in 0..4 {
let mut sum = _mm_setzero_ps();
for k in 0..4 {
let a_vec = _mm_set1_ps(a[i][k]);
let b_vec = _mm_loadu_ps(b[k].as_ptr());
let prod = _mm_mul_ps(a_vec, b_vec);
sum = _mm_add_ps(sum, prod);
}
_mm_storeu_ps(c[i].as_mut_ptr(), sum);
}
}
}
#[cfg(target_arch = "aarch64")]
mod arm_simd {
use std::arch::aarch64::*;
pub unsafe fn matrix_mult_neon(
a: &[[f32; 4]; 4],
b: &[[f32; 4]; 4],
c: &mut [[f32; 4]; 4]
) {
// Implémentation NEON pour ARM
for i in 0..4 {
let mut sum = vdupq_n_f32(0.0);
for k in 0..4 {
let a_vec = vdupq_n_f32(a[i][k]);
let b_vec = vld1q_f32(b[k].as_ptr());
sum = vfmaq_f32(sum, a_vec, b_vec); // Fused multiply-add
}
vst1q_f32(c[i].as_mut_ptr(), sum);
}
}
}
// Dispatcher runtime
pub fn matrix_mult_optimized(
a: &[[f32; 8]; 8],
b: &[[f32; 8]; 8],
c: &mut [[f32; 8]; 8]
) {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
unsafe { x86_simd::matrix_mult_avx2(a, b, c) }
} else if is_x86_feature_detected!("sse2") {
// Conversion vers format 4x4 pour SSE2
matrix_mult_sse2_fallback(a, b, c)
} else {
matrix_mult_scalar_fallback(a, b, c)
}
}
#[cfg(target_arch = "aarch64")]
{
// Conversion vers format 4x4 pour NEON
matrix_mult_neon_fallback(a, b, c)
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
matrix_mult_scalar_fallback(a, b, c)
}
}
Techniques d'Optimisation Cache
// Blocking/Tiling pour optimiser les accès cache
const BLOCK_SIZE: usize = 64; // Optimisé pour L1 cache
pub fn matrix_mult_blocked<const N: usize>(
a: &[[f32; N]; N],
b: &[[f32; N]; N],
c: &mut [[f32; N]; N]
) {
// Initialiser C à zéro
for row in c.iter_mut() {
row.fill(0.0);
}
// Blocking triple-nested
for ii in (0..N).step_by(BLOCK_SIZE) {
for jj in (0..N).step_by(BLOCK_SIZE) {
for kk in (0..N).step_by(BLOCK_SIZE) {
// Bloc interne SIMD-optimisé
let i_end = (ii + BLOCK_SIZE).min(N);
let j_end = (jj + BLOCK_SIZE).min(N);
let k_end = (kk + BLOCK_SIZE).min(N);
matrix_mult_block_simd(
a, b, c,
ii..i_end,
jj..j_end,
kk..k_end
);
}
}
}
}
fn matrix_mult_block_simd<const N: usize>(
a: &[[f32; N]; N],
b: &[[f32; N]; N],
c: &mut [[f32; N]; N],
i_range: std::ops::Range<usize>,
j_range: std::ops::Range<usize>,
k_range: std::ops::Range<usize>
) {
#[cfg(target_arch = "x86_64")]
unsafe {
if is_x86_feature_detected!("avx2") {
for i in i_range {
for j in j_range.clone().step_by(8) {
if j + 8 <= j_range.end {
let mut sum = _mm256_loadu_ps(c[i][j..].as_ptr());
for k in k_range.clone() {
let a_vec = _mm256_set1_ps(a[i][k]);
let b_vec = _mm256_loadu_ps(b[k][j..].as_ptr());
sum = _mm256_fmadd_ps(a_vec, b_vec, sum);
}
_mm256_storeu_ps(c[i][j..].as_mut_ptr(), sum);
} else {
// Fallback scalaire pour les éléments restants
for jj in j..j_range.end {
for k in k_range.clone() {
c[i][jj] += a[i][k] * b[k][jj];
}
}
}
}
}
return;
}
}
// Fallback scalaire
for i in i_range {
for j in j_range.clone() {
for k in k_range.clone() {
c[i][j] += a[i][k] * b[k][j];
}
}
}
}
Benchmarking Complet
use criterion::{BenchmarkId, Criterion, Throughput, black_box};
fn comprehensive_matrix_bench(c: &mut Criterion) {
let sizes = [64, 128, 256, 512];
let mut group = c.benchmark_group("matrix_multiplication");
for size in sizes {
group.throughput(Throughput::Elements((size * size * size) as u64));
// Matrices alignées pour SIMD
let mut a = vec![vec![1.0f32; size]; size];
let mut b = vec![vec![2.0f32; size]; size];
let mut c = vec![vec![0.0f32; size]; size];
// Initialisation avec données aléatoires
for i in 0..size {
for j in 0..size {
a[i][j] = (i + j) as f32;
b[i][j] = (i * j) as f32;
}
}
group.bench_with_input(
BenchmarkId::new("scalar", size),
&size,
|bench, _| {
bench.iter(|| {
matrix_mult_scalar_generic(
black_box(&a),
black_box(&b),
black_box(&mut c)
)
})
}
);
#[cfg(target_arch = "x86_64")]
group.bench_with_input(
BenchmarkId::new("simd_avx2", size),
&size,
|bench, _| {
bench.iter(|| {
matrix_mult_optimized_generic(
black_box(&a),
black_box(&b),
black_box(&mut c)
)
})
}
);
group.bench_with_input(
BenchmarkId::new("blocked", size),
&size,
|bench, _| {
bench.iter(|| {
matrix_mult_blocked_generic(
black_box(&a),
black_box(&b),
black_box(&mut c)
)
})
}
);
}
group.finish();
}
// Fonctions génériques pour vecteurs de taille dynamique
fn matrix_mult_scalar_generic(
a: &[Vec<f32>],
b: &[Vec<f32>],
c: &mut [Vec<f32>]
) {
let n = a.len();
for i in 0..n {
for j in 0..n {
c[i][j] = 0.0;
for k in 0..n {
c[i][j] += a[i][k] * b[k][j];
}
}
}
}
Vérification
- Benchmarking : Utilise
criterionpour comparer SIMD vs scalaire :
Attends-toi à ce que SIMD soit 4-8x plus rapide pour de grandes matrices.use criterion::{black_box, Criterion}; fn bench(c: &mut Criterion) { let a = [[1.0_f32; 8]; 8]; let b = [[2.0_f32; 8]; 8]; let mut c = [[0.0_f32; 8]; 8]; c.bench_function("simd", |b| b.iter(|| unsafe { matrix_mult_simd(black_box(&a), black_box(&b), black_box(&mut c)) })); c.bench_function("scalar", |b| b.iter(|| matrix_mult_scalar(black_box(&a), black_box(&b), black_box(&mut c)))); } - Profiling : Utilise
perfsur Linux (perf stat -e cycles,instructions) pour confirmer la réduction d'instructions (ex : 8x moins de multiplications). - Inspection d'Assembleur : Lance
cargo rustc --release -- --emit asmou utilisegodbolt.orgpour vérifier des boucles serrées avec instructions SIMD (ex :vmulps,vaddps).
Tests de Justesse
#[cfg(test)]
mod correctness_tests {
use super::*;
#[test]
fn test_simd_vs_scalar() {
let a = [[1.0, 2.0, 3.0, 4.0]; 4];
let b = [[5.0, 6.0, 7.0, 8.0]; 4];
let mut c_scalar = [[0.0; 4]; 4];
let mut c_simd = [[0.0; 4]; 4];
matrix_mult_scalar(&a, &b, &mut c_scalar);
unsafe { matrix_mult_simd_4x4(&a, &b, &mut c_simd); }
// Comparaison avec tolérance pour erreurs d'arrondi
for i in 0..4 {
for j in 0..4 {
assert!(
(c_scalar[i][j] - c_simd[i][j]).abs() < 1e-6,
"Mismatch at [{}, {}]: scalar={}, simd={}",
i, j, c_scalar[i][j], c_simd[i][j]
);
}
}
}
#[test]
fn test_numerical_stability() {
// Test avec valeurs extrêmes
let mut a = [[0.0; 4]; 4];
let mut b = [[0.0; 4]; 4];
// Matrice avec grandes valeurs
for i in 0..4 {
for j in 0..4 {
a[i][j] = 1e6;
b[i][j] = 1e-6;
}
}
let mut c_scalar = [[0.0; 4]; 4];
let mut c_simd = [[0.0; 4]; 4];
matrix_mult_scalar(&a, &b, &mut c_scalar);
unsafe { matrix_mult_simd_4x4(&a, &b, &mut c_simd); }
// Vérifier que les résultats sont raisonnables
for i in 0..4 {
for j in 0..4 {
assert!(c_scalar[i][j].is_finite());
assert!(c_simd[i][j].is_finite());
assert!((c_scalar[i][j] - c_simd[i][j]).abs() / c_scalar[i][j] < 1e-3);
}
}
}
}
Exemple de Résultat Pratique
Pour une matrice 1024x1024, AVX pourrait réduire le runtime de secondes à millisecondes sur un CPU moderne, en supposant une bonne localité de données. Le profiling devrait montrer une réduction d'instructions 8x dans la boucle interne, avec des benchmarks confirmant des speedups significatifs.
Considérations Pratiques
Gestion Mémoire Optimisée
// Allocation alignée pour performance SIMD optimale
use std::alloc::{alloc, dealloc, Layout};
struct AlignedMatrix<const N: usize> {
data: *mut f32,
layout: Layout,
}
impl<const N: usize> AlignedMatrix<N> {
fn new() -> Self {
let layout = Layout::from_size_align(
N * N * std::mem::size_of::<f32>(),
32 // Alignement AVX
).unwrap();
let data = unsafe { alloc(layout) as *mut f32 };
Self { data, layout }
}
fn as_slice(&self) -> &[f32] {
unsafe { std::slice::from_raw_parts(self.data, N * N) }
}
fn as_mut_slice(&mut self) -> &mut [f32] {
unsafe { std::slice::from_raw_parts_mut(self.data, N * N) }
}
}
impl<const N: usize> Drop for AlignedMatrix<N> {
fn drop(&mut self) {
unsafe { dealloc(self.data as *mut u8, self.layout) }
}
}
Conclusion
Pour un débit maximum sur une architecture connue (ex : x86_64 avec AVX), utilise std::arch pour vectoriser la boucle interne de multiplication de matrices, avec tiling pour l'efficacité cache. Pour la portabilité, passe à packed_simd, acceptant un certain overhead. Adresse les défis comme l'alignement et la détection de fonctionnalités avec compilation conditionnelle et vérifications runtime, assurant à la fois vitesse et justesse dans un système de production.