November 18, 202511 min

Maîtriser l'Inline Assembly en Rust : Quand et Comment Optimiser en Sécurité

m
mayo

L'inline assembly en Rust, via la macro asm! ou les intrinsèques core::arch, est un outil puissant mais rare pour optimiser du code critique en performance quand le compilateur ou les bibliothèques standard tombent à court. Je vais exposer quand l'utiliser, fournir un exemple d'implémentation, et détailler les stratégies pour assurer sécurité et portabilité à travers les architectures.

Scénarios pour l'Inline Assembly

L'inline assembly est justifié dans ces cas :

  • Instructions CPU Uniques : Quand une tâche nécessite des instructions que Rust ne peut pas générer (ex : popcnt de x86 pour compter les bits, si on n'utilise pas count_ones()).
  • Optimisation Extrême : Quand l'usage de registres ajusté à la main ou le cycle shaving dans une boucle chaude surpasse les optimisations de LLVM.
  • Intégration Legacy : Quand on interface avec des routines matérielles assembly-only (ex : gestionnaires d'interruption custom).

Scénario d'Exemple : Boucle de Comptage de Bits

Considère optimiser une fonction de cryptographie qui compte les bits mis dans un tableau d'entiers 64-bit pour la distance de Hamming dans un système temps réel. Le u64::count_ones() de Rust utilise popcnt sur x86_64 si disponible, mais j'ai besoin d'une boucle custom avec unrolling manuel et pipelining pour un CPU spécifique (ex : Skylake avec AVX2 désactivé), où le profiling montre un goulot d'étranglement.

Implémentation avec asm!

Voici une boucle de comptage de bits pour x86_64 :

#[cfg(target_arch = "x86_64")]
unsafe fn count_bits(data: &[u64]) -> u64 {
    let mut total: u64 = 0;
    for chunk in data.chunks(4) { // Traite 4 éléments à la fois
        let mut sum: u64;
        asm!(
            "xor {sum}, {sum}         \n\t", // Met sum à zéro
            "popcnt {tmp}, {x0}       \n\t", // Compte les bits du premier élément
            "add {sum}, {tmp}         \n\t",
            "popcnt {tmp}, {x1}       \n\t", // Deuxième élément
            "add {sum}, {tmp}         \n\t",
            "popcnt {tmp}, {x2}       \n\t", // Troisième
            "add {sum}, {tmp}         \n\t",
            "popcnt {tmp}, {x3}       \n\t", // Quatrième
            "add {sum}, {tmp}         \n\t",
            sum = out(reg) sum,          // Sortie : total bits
            x0 = in(reg) chunk.get(0).copied().unwrap_or(0), // Entrées : 4 éléments
            x1 = in(reg) chunk.get(1).copied().unwrap_or(0),
            x2 = in(reg) chunk.get(2).copied().unwrap_or(0),
            x3 = in(reg) chunk.get(3).copied().unwrap_or(0),
            tmp = out(reg) _,            // Registre temp pour popcnt
            options(nostack, pure)       // Pas de stack, déterministe
        );
        total += sum;
    }
    total
}

Pourquoi asm! ? : L'unrolling manuel et le contrôle de registres maximisent l'efficacité du pipeline CPU, potentiellement surpassant count_ones() en évitant l'overhead d'appel de fonction et exploitant le parallélisme au niveau instruction.

Abstraction Sûre :

pub fn total_bits(data: &[u64]) -> u64 {
    if cfg!(target_arch = "x86_64") && is_x86_feature_detected!("popcnt") {
        unsafe { count_bits(data) }
    } else {
        data.iter().map(|x| x.count_ones() as u64).sum() // Fallback
    }
}

Assurer la Sécurité

  • Scope Unsafe : Le bloc asm! est confiné à une fonction unsafe, signalant clairement le risque. Je documenterais les invariants (ex : "data doit être de la mémoire valide").
  • Gestion de Registres : Utilise in(reg) pour les entrées, out(reg) pour les sorties, et clobber tmp pour éviter de corrompre l'état de l'appelant. options(nostack) empêche l'interférence de stack.
  • Pas de Comportement Indéfini : Évite l'accès mémoire en assembly ; s'appuie sur Rust pour les loads vérifiés en bounds. Teste les cas limites (ex : chunks vides ou courts).
  • Validation : Tests unitaires avec entrées connues (ex : 0xFFFF_FFFF_FFFF_FFFF → 64 bits) assurent la justesse contre la version scalaire.

Techniques Avancées d'Optimisation Assembly

Optimisation Multi-Architecture

// Support multi-architecture avec fallback intelligent
#[cfg(target_arch = "x86_64")]
mod x86_assembly {
    use std::arch::x86_64::*;
    
    #[target_feature(enable = "popcnt,bmi2")]
    pub unsafe fn count_bits_advanced(data: &[u64]) -> u64 {
        let mut total = 0u64;
        
        // Traitement vectorisé pour chunks larges
        for chunk in data.chunks_exact(8) {
            let mut sum: u64;
            asm!(
                // Unroll complet pour 8 éléments avec parallélisme maximal
                "xor {sum}, {sum}",
                "popcnt {tmp1}, {x0}",
                "popcnt {tmp2}, {x1}",
                "add {sum}, {tmp1}",
                "popcnt {tmp3}, {x2}",
                "add {sum}, {tmp2}",
                "popcnt {tmp4}, {x3}",
                "add {sum}, {tmp3}",
                "popcnt {tmp1}, {x4}",  // Réutilise tmp1
                "add {sum}, {tmp4}",
                "popcnt {tmp2}, {x5}",
                "add {sum}, {tmp1}",
                "popcnt {tmp3}, {x6}",
                "add {sum}, {tmp2}",
                "popcnt {tmp4}, {x7}",
                "add {sum}, {tmp3}",
                "add {sum}, {tmp4}",
                
                sum = out(reg) sum,
                x0 = in(reg) chunk[0],
                x1 = in(reg) chunk[1],
                x2 = in(reg) chunk[2],
                x3 = in(reg) chunk[3],
                x4 = in(reg) chunk[4],
                x5 = in(reg) chunk[5],
                x6 = in(reg) chunk[6],
                x7 = in(reg) chunk[7],
                tmp1 = out(reg) _,
                tmp2 = out(reg) _,
                tmp3 = out(reg) _,
                tmp4 = out(reg) _,
                options(nostack, pure)
            );
            total += sum;
        }
        
        // Traite les éléments restants
        for &value in data.chunks_exact(8).remainder() {
            total += value.count_ones() as u64;
        }
        
        total
    }
    
    // Version avec prefetching pour grandes données
    #[target_feature(enable = "popcnt")]
    pub unsafe fn count_bits_prefetch(data: &[u64]) -> u64 {
        let mut total = 0u64;
        const PREFETCH_DISTANCE: usize = 64; // Cache lines en avance
        
        for (i, chunk) in data.chunks_exact(4).enumerate() {
            // Prefetch les données futures
            if i * 4 + PREFETCH_DISTANCE < data.len() {
                let prefetch_addr = &data[i * 4 + PREFETCH_DISTANCE] as *const u64;
                asm!(
                    "prefetcht0 ({addr})",
                    addr = in(reg) prefetch_addr,
                    options(nostack, readonly)
                );
            }
            
            let mut sum: u64;
            asm!(
                "xor {sum}, {sum}",
                "popcnt {tmp}, {x0}",
                "add {sum}, {tmp}",
                "popcnt {tmp}, {x1}",
                "add {sum}, {tmp}",
                "popcnt {tmp}, {x2}",
                "add {sum}, {tmp}",
                "popcnt {tmp}, {x3}",
                "add {sum}, {tmp}",
                
                sum = out(reg) sum,
                x0 = in(reg) chunk[0],
                x1 = in(reg) chunk[1],
                x2 = in(reg) chunk[2],
                x3 = in(reg) chunk[3],
                tmp = out(reg) _,
                options(nostack, pure)
            );
            total += sum;
        }
        
        total
    }
}

#[cfg(target_arch = "aarch64")]
mod arm_assembly {
    use std::arch::aarch64::*;
    
    pub unsafe fn count_bits_neon(data: &[u64]) -> u64 {
        let mut total = 0u64;
        
        // NEON vectorisation pour ARM
        for chunk in data.chunks_exact(2) {
            let v = vld1q_u64(chunk.as_ptr());
            let count = vcnt_u8(vreinterpret_u8_u64(vget_low_u64(v)));
            let sum = vaddv_u8(count);
            total += sum as u64;
            
            let count_high = vcnt_u8(vreinterpret_u8_u64(vget_high_u64(v)));
            let sum_high = vaddv_u8(count_high);
            total += sum_high as u64;
        }
        
        // Fallback pour remainder
        for &value in data.chunks_exact(2).remainder() {
            total += value.count_ones() as u64;
        }
        
        total
    }
}

Optimisations Spécialisées par Domaine

// Crypto : Hamming distance optimisée
#[cfg(target_arch = "x86_64")]
pub unsafe fn hamming_distance_asm(a: &[u64], b: &[u64]) -> u64 {
    assert_eq!(a.len(), b.len());
    let mut total = 0u64;
    
    for (chunk_a, chunk_b) in a.chunks_exact(4).zip(b.chunks_exact(4)) {
        let mut sum: u64;
        asm!(
            "xor {sum}, {sum}",
            
            // XOR et POPCNT en pipeline
            "xor {tmp}, {a0}, {b0}",
            "popcnt {tmp}, {tmp}",
            "add {sum}, {tmp}",
            
            "xor {tmp}, {a1}, {b1}",
            "popcnt {tmp}, {tmp}",
            "add {sum}, {tmp}",
            
            "xor {tmp}, {a2}, {b2}",
            "popcnt {tmp}, {tmp}",
            "add {sum}, {tmp}",
            
            "xor {tmp}, {a3}, {b3}",
            "popcnt {tmp}, {tmp}",
            "add {sum}, {tmp}",
            
            sum = out(reg) sum,
            a0 = in(reg) chunk_a[0],
            a1 = in(reg) chunk_a[1],
            a2 = in(reg) chunk_a[2],
            a3 = in(reg) chunk_a[3],
            b0 = in(reg) chunk_b[0],
            b1 = in(reg) chunk_b[1],
            b2 = in(reg) chunk_b[2],
            b3 = in(reg) chunk_b[3],
            tmp = out(reg) _,
            options(nostack, pure)
        );
        total += sum;
    }
    
    total
}

// Traitement d'images : Seuillage optimisé
#[cfg(target_arch = "x86_64")]
pub unsafe fn threshold_asm(data: &[u8], threshold: u8, output: &mut [u8]) {
    assert_eq!(data.len(), output.len());
    
    for (chunk_in, chunk_out) in data.chunks_exact(8).zip(output.chunks_exact_mut(8)) {
        asm!(
            // Charge 8 octets en registre 64-bit
            "mov {input}, qword ptr [{input_ptr}]",
            "mov {thresh_expanded}, {thresh}",
            
            // Réplique threshold sur 8 octets
            "mov {tmp}, 0x0101010101010101",
            "imul {thresh_expanded}, {tmp}",
            
            // Compare et génère masque
            "xor {result}, {result}",
            "cmp {input}, {thresh_expanded}",
            "setae {result:l}",       // Set si above or equal
            "neg {result}",           // Étend le bit à tout l'octet
            
            // Stocke le résultat
            "mov qword ptr [{output_ptr}], {result}",
            
            input = out(reg) _,
            thresh_expanded = out(reg) _,
            tmp = out(reg) _,
            result = out(reg) _,
            input_ptr = in(reg) chunk_in.as_ptr(),
            output_ptr = in(reg) chunk_out.as_mut_ptr(),
            thresh = in(reg) threshold as u64,
            options(nostack)
        );
    }
}

Gestion d'Erreurs et Sécurité Robuste

use std::fmt;

#[derive(Debug)]
pub enum AssemblyError {
    UnsupportedArchitecture,
    MissingCpuFeature(String),
    InvalidInput(String),
    RuntimeError(String),
}

impl fmt::Display for AssemblyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AssemblyError::UnsupportedArchitecture => {
                write!(f, "Assembly optimizations not available on this architecture")
            }
            AssemblyError::MissingCpuFeature(feature) => {
                write!(f, "Required CPU feature not available: {}", feature)
            }
            AssemblyError::InvalidInput(msg) => {
                write!(f, "Invalid input: {}", msg)
            }
            AssemblyError::RuntimeError(msg) => {
                write!(f, "Runtime error: {}", msg)
            }
        }
    }
}

impl std::error::Error for AssemblyError {}

// API sûre avec validation complète
pub fn count_bits_safe(data: &[u64]) -> Result<u64, AssemblyError> {
    if data.is_empty() {
        return Ok(0);
    }
    
    if data.len() > 1_000_000 {
        return Err(AssemblyError::InvalidInput(
            "Data too large (max 1M elements)".to_string()
        ));
    }
    
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("popcnt") {
            // Validation additionnelle pour l'alignement
            let data_ptr = data.as_ptr() as usize;
            if data_ptr % 8 != 0 {
                return Err(AssemblyError::InvalidInput(
                    "Data must be 8-byte aligned".to_string()
                ));
            }
            
            let result = unsafe { 
                std::panic::catch_unwind(|| count_bits(data))
                    .map_err(|_| AssemblyError::RuntimeError(
                        "Assembly routine panicked".to_string()
                    ))?
            };
            
            Ok(result)
        } else {
            Err(AssemblyError::MissingCpuFeature("popcnt".to_string()))
        }
    }
    
    #[cfg(not(target_arch = "x86_64"))]
    {
        Err(AssemblyError::UnsupportedArchitecture)
    }
}

// Fallback automatique avec métriques
pub fn count_bits_with_fallback(data: &[u64]) -> (u64, &'static str) {
    match count_bits_safe(data) {
        Ok(result) => (result, "assembly"),
        Err(_) => {
            let result = data.iter().map(|x| x.count_ones() as u64).sum();
            (result, "fallback")
        }
    }
}

Benchmarking et Validation Complets

use criterion::{BenchmarkId, Criterion, Throughput, black_box};

fn comprehensive_assembly_bench(c: &mut Criterion) {
    let sizes = [100, 1_000, 10_000, 100_000];
    let mut group = c.benchmark_group("bit_counting");
    
    for size in sizes {
        let data: Vec<u64> = (0..size).map(|i| {
            // Pattern qui challenge les optimisations
            if i % 3 == 0 { 0 } else { !0u64 >> (i % 64) }
        }).collect();
        
        group.throughput(Throughput::Elements(size as u64));
        
        // Rust standard
        group.bench_with_input(
            BenchmarkId::new("rust_std", size),
            &data,
            |b, data| {
                b.iter(|| {
                    black_box(data.iter().map(|x| x.count_ones() as u64).sum::<u64>())
                })
            }
        );
        
        // Assembly optimisé
        #[cfg(target_arch = "x86_64")]
        if is_x86_feature_detected!("popcnt") {
            group.bench_with_input(
                BenchmarkId::new("assembly_basic", size),
                &data,
                |b, data| {
                    b.iter(|| {
                        black_box(unsafe { count_bits(data) })
                    })
                }
            );
            
            if is_x86_feature_detected!("bmi2") {
                group.bench_with_input(
                    BenchmarkId::new("assembly_advanced", size),
                    &data,
                    |b, data| {
                        b.iter(|| {
                            black_box(unsafe { x86_assembly::count_bits_advanced(data) })
                        })
                    }
                );
            }
        }
        
        // API sûre avec fallback
        group.bench_with_input(
            BenchmarkId::new("safe_api", size),
            &data,
            |b, data| {
                b.iter(|| {
                    black_box(count_bits_with_fallback(data).0)
                })
            }
        );
    }
    
    group.finish();
}

#[cfg(test)]
mod tests {
    use super::*;
    
    #[test]
    fn test_assembly_correctness() {
        let test_cases = vec![
            vec![0u64],
            vec![!0u64],
            vec![0x5555555555555555u64],
            vec![1, 2, 4, 8, 16, 32, 64, 128],
            (0..100).map(|i| i * 0x0123456789ABCDEFu64).collect(),
        ];
        
        for case in test_cases {
            let expected: u64 = case.iter().map(|x| x.count_ones() as u64).sum();
            
            #[cfg(target_arch = "x86_64")]
            if is_x86_feature_detected!("popcnt") {
                let result = unsafe { count_bits(&case) };
                assert_eq!(result, expected, "Assembly result mismatch for {:?}", case);
            }
            
            let (safe_result, method) = count_bits_with_fallback(&case);
            assert_eq!(safe_result, expected, "Safe API result mismatch for {:?} (method: {})", case, method);
        }
    }
    
    #[test]
    fn test_edge_cases() {
        // Test vecteur vide
        assert_eq!(count_bits_with_fallback(&[]).0, 0);
        
        // Test grands vecteurs
        let large_data: Vec<u64> = (0..10000).map(|i| i as u64).collect();
        let expected: u64 = large_data.iter().map(|x| x.count_ones() as u64).sum();
        let (result, _) = count_bits_with_fallback(&large_data);
        assert_eq!(result, expected);
    }
    
    #[test]
    fn test_alignment_requirements() {
        // Test avec données non-alignées
        let mut data = vec![0u64; 100];
        let ptr = data.as_mut_ptr();
        
        // Force un décalage pour tester l'alignement
        let misaligned_data = unsafe {
            std::slice::from_raw_parts(
                (ptr as *const u8).add(4) as *const u64,
                50
            )
        };
        
        #[cfg(target_arch = "x86_64")]
        if is_x86_feature_detected!("popcnt") {
            let result = count_bits_safe(misaligned_data);
            // Devrait échouer sur les données non-alignées
            assert!(result.is_err());
        }
    }
}

Assurer la Portabilité

Abstraction Multi-Architecture

// Trait unifié pour toutes les implémentations
pub trait BitCounter {
    fn count_bits(&self, data: &[u64]) -> u64;
    fn architecture(&self) -> &'static str;
    fn features_required(&self) -> &[&'static str];
}

pub struct OptimalBitCounter;

impl BitCounter for OptimalBitCounter {
    fn count_bits(&self, data: &[u64]) -> u64 {
        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("popcnt") && is_x86_feature_detected!("bmi2") {
                return unsafe { x86_assembly::count_bits_advanced(data) };
            } else if is_x86_feature_detected!("popcnt") {
                return unsafe { count_bits(data) };
            }
        }
        
        #[cfg(target_arch = "aarch64")]
        {
            return unsafe { arm_assembly::count_bits_neon(data) };
        }
        
        // Fallback universel
        data.iter().map(|x| x.count_ones() as u64).sum()
    }
    
    fn architecture(&self) -> &'static str {
        cfg_if::cfg_if! {
            if #[cfg(target_arch = "x86_64")] {
                "x86_64"
            } else if #[cfg(target_arch = "aarch64")] {
                "aarch64"
            } else {
                "generic"
            }
        }
    }
    
    fn features_required(&self) -> &[&'static str] {
        cfg_if::cfg_if! {
            if #[cfg(target_arch = "x86_64")] {
                &["popcnt"]
            } else if #[cfg(target_arch = "aarch64")] {
                &["neon"]
            } else {
                &[]
            }
        }
    }
}

// Factory pour créer l'implémentation optimale
pub fn create_optimal_counter() -> Box<dyn BitCounter> {
    Box::new(OptimalBitCounter)
}

Conclusion

L'inline assembly en Rust est justifié pour l'optimisation extrême ou les instructions CPU uniques, comme montré avec le comptage de bits optimisé. J'assurerais la sécurité en confinant l'unsafe, gérant les registres soigneusement, et validant avec des tests exhaustifs. Pour la portabilité, j'utiliserais la compilation conditionnelle avec fallbacks, créant des abstractions qui cachent les détails d'architecture tout en livrant des performances maximales sur le matériel cible. L'assembly doit être le dernier recours après avoir épuisé les optimisations de haut niveau.

Retour au blog
Partager ::