YOU CAN REPLICATE THIS PROJECT -> https://github.com/Nagharjun17/Flash-Attention-Triton
- Implemented custom Flash Attention forward & backward kernels in Triton with causal mask support (a minimal forward-kernel sketch follows this list).
- Integrated with PyTorch using torch.autograd.Function for seamless gradient computation (wrapper pattern sketched below).
- Benchmarked the kernels against a plain PyTorch math baseline, measuring a ~2.5× speedup (combined forward + backward runtime dropped from 2.757 ms to 1.116 ms on an RTX GPU); a timing sketch follows below.
- Added handwritten notes with forward/backward derivations for better theoretical understanding (the standard blockwise recurrence they build on is reproduced below).
- Tested with PyTorch 2.3 and Triton 2.1 on RTX hardware.
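
As a concrete illustration of the first bullet, here is a minimal sketch of a causal Flash Attention forward kernel in Triton: one program handles one block of query rows for one (batch, head) slice and streams over key/value blocks with an online softmax, stopping at the diagonal block thanks to the causal mask. This is an assumption-laden sketch, not the repository's code: the tensor layout ([batch*heads, seq_len, head_dim], contiguous), block sizes, and all names (`_causal_flash_fwd`, `flash_attn_forward`) are illustrative.

```python
import math
import torch
import triton
import triton.language as tl


@triton.jit
def _causal_flash_fwd(
    Q, K, V, O,
    stride_b, stride_m, stride_d,   # shared strides: all four tensors are [batch*heads, seq_len, head_dim], contiguous
    seq_len, sm_scale,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr,
):
    pid_m = tl.program_id(0)        # which block of query rows
    pid_b = tl.program_id(1)        # which (batch, head) slice

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)
    base = pid_b * stride_b

    q_ptrs = Q + base + offs_m[:, None] * stride_m + offs_d[None, :] * stride_d
    q = tl.load(q_ptrs, mask=offs_m[:, None] < seq_len, other=0.0)

    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")   # running row max
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)                  # running row sum
    acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)         # running (unnormalized) output

    # Causality: query block pid_m only needs key blocks up to and including itself.
    for start_n in range(0, (pid_m + 1) * BLOCK_M, BLOCK_N):
        curr_n = start_n + offs_n
        # Load K transposed as (BLOCK_D, BLOCK_N) so qk = q @ k directly.
        k_ptrs = K + base + offs_d[:, None] * stride_d + curr_n[None, :] * stride_m
        v_ptrs = V + base + curr_n[:, None] * stride_m + offs_d[None, :] * stride_d
        k = tl.load(k_ptrs, mask=curr_n[None, :] < seq_len, other=0.0)
        v = tl.load(v_ptrs, mask=curr_n[:, None] < seq_len, other=0.0)

        qk = tl.dot(q, k) * sm_scale
        # Causal + bounds mask: query i may only attend to keys j <= i.
        mask = (offs_m[:, None] >= curr_n[None, :]) & (curr_n[None, :] < seq_len)
        qk = tl.where(mask, qk, float("-inf"))

        # Online softmax: fold this block into the running max / sum / output.
        m_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.exp(m_i - m_new)
        p = tl.exp(qk - m_new[:, None])
        l_i = l_i * alpha + tl.sum(p, 1)
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m_i = m_new

    acc = acc / l_i[:, None]
    o_ptrs = O + base + offs_m[:, None] * stride_m + offs_d[None, :] * stride_d
    tl.store(o_ptrs, acc.to(O.dtype.element_ty), mask=offs_m[:, None] < seq_len)


def flash_attn_forward(q, k, v):
    """Host-side launcher; q, k, v: [batch*heads, seq_len, head_dim], contiguous, head_dim a power of two >= 16."""
    bh, seq_len, head_dim = q.shape
    assert head_dim in (16, 32, 64, 128)
    o = torch.empty_like(q)
    sm_scale = 1.0 / math.sqrt(head_dim)
    BLOCK_M, BLOCK_N = 64, 64
    grid = (triton.cdiv(seq_len, BLOCK_M), bh)
    _causal_flash_fwd[grid](
        q, k, v, o,
        q.stride(0), q.stride(1), q.stride(2),
        seq_len, sm_scale,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_D=head_dim,
    )
    return o
```

Because of the causal mask, each query block only visits key blocks up to its own position, which roughly halves the work compared with materializing the dense score matrix; a backward kernel typically reuses the same tiling and recomputes the attention probabilities block by block rather than storing them.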
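
The autograd.Function integration follows PyTorch's standard custom-op pattern: forward saves what backward needs via ctx, and backward returns one gradient per forward input. The sketch below is runnable as-is but uses plain PyTorch math where the project launches its Triton kernels; the names FlashAttention and flash_attention are illustrative, not taken from the repository.

```python
import math
import torch


class FlashAttention(torch.autograd.Function):
    """Sketch of the autograd.Function wrapper; forward/backward bodies use plain
    PyTorch math as stand-ins for Triton kernel launches."""

    @staticmethod
    def forward(ctx, q, k, v):
        sm_scale = 1.0 / math.sqrt(q.shape[-1])
        # In the real project this would launch the Triton forward kernel.
        scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
        causal = torch.tril(torch.ones(q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device))
        scores = scores.masked_fill(~causal, float("-inf"))
        p = torch.softmax(scores, dim=-1)
        o = torch.matmul(p, v)
        ctx.save_for_backward(q, k, v, p)
        ctx.sm_scale = sm_scale
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, p = ctx.saved_tensors
        # In the real project this would launch the Triton backward kernel.
        dv = torch.matmul(p.transpose(-2, -1), do)                 # dV = P^T dO
        dp = torch.matmul(do, v.transpose(-2, -1))                 # dP = dO V^T
        ds = p * (dp - (dp * p).sum(dim=-1, keepdim=True))         # softmax Jacobian-vector product
        dq = torch.matmul(ds, k) * ctx.sm_scale
        dk = torch.matmul(ds.transpose(-2, -1), q) * ctx.sm_scale
        return dq, dk, dv


def flash_attention(q, k, v):
    """Convenience entry point: q, k, v are [batch, heads, seq_len, head_dim]."""
    return FlashAttention.apply(q, k, v)
```

Calling flash_attention on CUDA tensors created with requires_grad=True then lets a normal loss.backward() route gradients through the custom backward, which is what makes the Triton kernels drop-in replacements inside a larger PyTorch model.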
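
For the benchmark, a typical setup times one forward + backward pass of each implementation with triton.testing.do_bench, which handles warmup and CUDA synchronization. In the sketch below, torch.nn.functional.scaled_dot_product_attention stands in for the custom Triton kernel, and the shapes and dtype are illustrative; the 2.757 ms to 1.116 ms figures quoted above come from the project's own runs, not from this script.

```python
import math
import torch
import torch.nn.functional as F
import triton


def baseline_attention(q, k, v):
    """Plain PyTorch math: materializes the full [seq, seq] score matrix."""
    sm_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale
    causal = torch.tril(torch.ones(q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device))
    scores = scores.masked_fill(~causal, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), v)


def fwd_bwd(fn, q, k, v):
    """One forward + backward pass, the quantity reported in the bullet above."""
    out = fn(q, k, v)
    out.sum().backward()


batch, heads, seq_len, head_dim = 4, 16, 1024, 64        # illustrative sizes
shape = (batch, heads, seq_len, head_dim)
make = lambda: torch.randn(shape, device="cuda", dtype=torch.float16, requires_grad=True)

q1, k1, v1 = make(), make(), make()
q2, k2, v2 = make(), make(), make()

baseline_ms = triton.testing.do_bench(
    lambda: fwd_bwd(baseline_attention, q1, k1, v1), grad_to_none=[q1, k1, v1]
)
# Stand-in for the custom kernel; in the repo this would be the autograd.Function
# wrapper around the Triton forward/backward kernels.
fused_ms = triton.testing.do_bench(
    lambda: fwd_bwd(lambda q, k, v: F.scaled_dot_product_attention(q, k, v, is_causal=True), q2, k2, v2),
    grad_to_none=[q2, k2, v2],
)
print(f"baseline: {baseline_ms:.3f} ms, fused: {fused_ms:.3f} ms, speedup: {baseline_ms / fused_ms:.2f}x")
```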
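
For context, the forward derivation in such notes typically centers on the blockwise online-softmax recurrence below (standard Flash Attention algebra; the notes' exact notation may differ). With $S_j = Q K_j^\top / \sqrt{d}$, causally masked, for key/value block $j$:

```math
\begin{aligned}
m^{(j)} &= \max\!\bigl(m^{(j-1)},\ \mathrm{rowmax}(S_j)\bigr), \qquad
\tilde{P}_j = \exp\!\bigl(S_j - m^{(j)}\bigr) \\
\ell^{(j)} &= e^{\,m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \mathrm{rowsum}(\tilde{P}_j) \\
O^{(j)} &= e^{\,m^{(j-1)} - m^{(j)}}\,O^{(j-1)} + \tilde{P}_j V_j, \qquad
O = \mathrm{diag}\bigl(\ell^{(T)}\bigr)^{-1} O^{(T)}
\end{aligned}
```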