본문 바로가기
논문

[리뷰] Mamba: Linear-Time Sequence Modeling with Selective State Spaces

by mhiiii 2025. 3. 16.
728x90

Mamba란?


기존 Transformer과 RNN이 가진 연산량 문제와 길이 의존성 문제를 해결하는 새로운 모델

 

State Space Model(SSM)에 기반하여 만들어진 딥러닝 모델

 

➡️ Self-Attention 없이 Transformer 수준의 성능을 달성하면서도 연산 속도 및 메모리 사용량을 개선한 것이 특징


 

✅ 기존 Transformer는 어떤 문제를 가지고 있을까?

 

먼저, 트랜스포머는 최고의 성능을 가지는 시퀀스 모델.

 

입력 값이 무엇이든 시퀀스의 이전 토큰을 참고할 수 있어서 그 표현을 도출 가능 

또한, Attention 매커니즘을 활용하여 복잡한 문맥 정보를 효과적으로 학습할 수 있음 

 

하지만, Transformer 구조에는 몇 가지 치명적인 한계점이 존재

  1. 복잡도가 입력 길이에 대해 이차 복잡도 (O(n²))
    Transformer의 Self-Attention은 입력 시퀀스의 모든 토큰 쌍에 대해 attention 계산 필요
  2. 어텐션 행렬 (Attention Matrix)의 메모리 사용량 증가
  3. 장기 의존성(Long-Term Dependency) 문제
  4. 자기회귀 구조로 병렬 생성 불가
  5. 입력 길이 고정

이러한 단점을 극복하기 위해서 attention에 대한 연구가 많이 진행됨.

 

→ 시퀀스 처리에 유망한 SSM이 급부상


✅ SSM이 뭐지?

 

State Space Model은 상태 공간(State Space)를 기반으로 데이터를 모델링하는 구조

 

여기서는 S4에 대해서 다루었습니다.

 

시간에 따라 변화하는 시스템이나 시퀀스 데이터를 모델링하는 데 자주 사용

 

State란?

A minimum set of variables, known as state variables, that fully describe the system and its response to any given set of inputs

 

쉽게 말하자면 시스템을 완전히 설명할 수 있는 최소한의 변수들을 의미 

 

상태 공간이 필요한 이유

  • 동적 시스템의 내부 상태를 나타낼 수 있음
  • 시간에 따라 변하는 시스템의 동작을 이해하고 예측하는 데 필수적
  • 시스템의 내부 상태 변수들을 사용하여 시스템의 현재 상태와 입력으로부터 미래의 출력을 예측할 수 있음
  • 현재와 과거의 상태를 기반으로 시스템의 미래 동작을 예측 가능, 현재의 입력만 고려하는 것보다 더 정확한 예측 가능

 

상태 공간 모델은 다음과 같은 두 개의 방정식으로 정의됨 

 

(1) 상태 업데이트

현재 상태에서 다음 상태를 계산하는 과정

 

$$h_{t+1} = Ah_t + Bx_t $$

 

 

  • = 현재 상태 벡터
  • = 상태 전이 행렬 (State Transition Matrix)
  • = 현재 입력 값
  • = 입력 변환 행렬 (Input Transformation Matrix)

 

 

(2) 출력 계산 

 

현재 상태에서 출력 값을 계산하는 과정

 

$$y_t = Ch_t + Dx_t$$

 

 

  • $y_t$ = 출력 값
  • $C$= 출력 변환 행렬 (Output Transformation Matrix)
  • $D$ = 직접 변환 행렬 (Direct Transformation Matrix)

 

 

➡️ SSM의 핵심

 

  • 상태 벡터 $h_t$는 시스템의 내부 상태를 나타냄
  • 상태 벡터가 시간에 따라 변화하면서 입력 데이터를 반영
  • 출력 벡터 $y_t$는 현재 상태와 입력 값의 선형 조합으로 계산됨

 

딥러닝의 경우 변수가 전부 이산형 변수이기 때문에 이산화 과정을 거쳐야 함

 

 

그러면 다음과 같은 Discrete time state space model을 구현할 수 있음 

 

이 Discrete SSM은 

 

$h_0 = B̅x_0$

$y₀ = Ch₀$

$h_1 = A̅h₀ + B̅x₁ = A̅B̅x₀ + B̅x₁$

$y₁ = C(A̅B̅x₀ + B̅x₁) = CA̅B̅x₀ + CB̅x₁$

$yₖ = C(A^kB̅x₀ + A^{k-1}B̅x₁ + ... + C̅B̅xₖ)$

 

위와 같은 연산이 이루어지게 될 것임

 

그렇게 되면, 입력 $x$에 곱해진 $ABC$ 행렬은 결국 컨볼루션의 커널과 같은 역할을 하게됨

 

컨볼루션 커널

$K = (CB, CAB, ..., CA^KB)$

$y = x * K$

 

 

 

이 커널은 병렬 연산이 가능하게 하여 효율적으로 계산할 수 있음

 

기존 SSM은 이전 값이 계산되어야지 현재 값을 계산할 수 있었지만,

컨볼루션 커널과 같은 방식을 사용해 각 입력 데이터에 대해 동시에 계산이 가능하게 됨

 

 

📌 그러나, SSM도 두가지 문제점이 존재


SSM은 시간 불변성(Linear Time In-variance, LTI)을 가짐 


즉, 모든 시간 step에 대해서 A,B,C와 같은 Matrix가 고정되어 있음
→ 따라서, 어떠한 입력이 들어가도 transformer와 달리 LTI 특성으로 인해 입력 값에 따른 연산이 유동적으로 변화하기 힘듦.

  1. Selective Copying
    기존 SSM은 특정 입력을 무시하거나 집중하는 능력이 없음

    바닐라 SSM은 시간 불변(Time invariant)하기 때문에 내용 인식 추론(content-aware reasoning)을 할 수 없음
    → 시간 불변이란 시간에 따라 시스템의 특성이 변하지 않는 것. 따라서 시간에 의존하지 않음 
         생성하는 모든 토큰에 대해 매개변수 A,B,C가 동일하다는 것을 의미

    SSM의 고정된 A, B, C 행렬로는 입력 데이터에서 중요한 정보를 선택적으로 추출하기 어려움

    시간 불변 특성으로 인해 과거 정보 중 어떤 부분을 더 중요하게 기억해야 하는지를 조정할 수 없음.

  2. Induction heads
    입력에서 발견된 패턴을 재현하는 것이 불가


SSM이 이러한 작업에서 성능이 떨어지는 것은 시간 불변성 SSM의 근본적인 문제,
즉 A, B, C 행렬의 정적인 특성으로 인한 문제

 

 

그래서 S4(Selective SSM)에서는 선택 메커니즘을 추가함 

→ 시간에 따라 변하는 특성을 가지게 됨 (Time-variant)

→ 컨볼루션 사용 불가 (고정된 커널이 아니게 되었기 때문) 

→ 병렬화 불가 (컨볼루션 커널 사용 불가로) 

 

따라서, 이제 SSM은 전체적인 기록을 압축하는 효율적인 작은 상태를 생성할 수 있게 됨

그러나, 어텐션 행렬을 사용하는 Transformer와 비교하면 훨씬 안 좋은 상태

 

목표 : Transformer만큼의 성능을 내면서 작은 크기를 유지하게 하자!!


728x90

 

Mamba

단순히 과거 모든 정보를 기억하는 것이 아니라, 입력 데이터에 따라 중요한 정보를 선택적으로 추출하는 방식을 도입

 

아래 그림 1에서 기존 SSM과 발전된 S6(Mamba)의 차이를 확인할 수 있음

 

Algorithm 2에서 확인할 수 있듯이, A는 기존과 같이 고정으로 두고 B, C  행렬은 입력에 영향을 받도록 함을 알 수 있음 

그림 1 : S4에서 S6(Mamba)으로 어떻게 변화했는지 의사코드로 표현

 

그림 1의 설명

 

맘바는 행렬 B와 C, 스텝 크기 ∆를 입력에 의존하도록 하여 입력의 시퀀스 길이와 배치 크기를 포함

→ 모든 입력 토큰에 대해 서로 다른 B와 C 행렬을 가지고 있다는 것을 의미

 

 

Δ(step size) : 입력의 이산화(discretization) 매개변수로서의 해상도,

이산화된 상태 공간 모델에서 입력 데이터를 이산적으로 처리하는 단위

 

  • 작은 Step size는 특정 단어를 무시하고 대신 이전 맥락을 더 사용하는 데 초점
  • 큰 Step size는 맥락보다 입력 단어에 더욱 집중하도록 학습

 

📌 행렬 A는 왜 고정일까?

더보기

A는 상태 전이 행렬

 

  • A행렬은 이전 상태의 정보를 포착할 수 있음

상태 간의 전이 패턴은 입력과 무관해야 함

  • 상태 전이 행렬 는 시스템의 기본 동역학(dynamics)을 정의
  • 즉, 상태 간의 관계는 시스템의 내부적인 특성에 의해 결정됨
  • 입력마다 상태 전이 행렬이 바뀌면 시스템 자체가 불안정해짐

👉 상태 간 전이는 시스템의 물리적, 구조적 특성을 반영하므로 입력에 의해 바뀌면 안 됨

SSM에서는 시간이 지남에 따라 상태가 변화하는데,
A는 이 전이 과정에서 핵심적인 역할을 하므로, 이전 상태를 포착하는 것이 중요

 

그러나, 행렬의 동적 변화(Time-varying)로 인해 컨볼루션을 사용할 수 없게 되어 병렬화가 불가능해짐 

 

Mamba에서는 이를 scan 방식으로 해결하고자 함 

 

Scan은 

 

하드웨어의 메모리 계층을 최적화하여 선택적 상태 공간 모델의 성능을 극대화하는 방법

 

 

GPU의 단점 중 하나는 작지만 매우 효율적인 SRAM과 크지만 약간 덜 효율적인 DRAM 사이의 전송(IO) 속도가 제한되어 있음

 

Kernal Fusion

여러 개의 CUDA 커널을 하나의 커스텀 CUDA 커널로 융합하여 중간 결과를 HBM에 복사하지 않고 연산을 수행

→ 속도가 빠른 SRAM에서만 연산을 진행하고 최종 결과만 용량이 큰 HBM에 저장

 

 

또한, 저자는 Scan 계산 병렬로 처리할 수 있도록 설계

 

입력 시퀀스를 동시에 처리하면서, 선택적으로 필요한 정보를 처리하고 나머지는 건너뛰는 방식

 

기본적인 상태 업데이트는 다음과 같은 순차적 의존성을 가짐:

$H_t = A H_{t-1} + B x_t$

👉 즉, $H_t$는 반드시 $H_{t-1}$이 필요하기 때문에 병렬화가 어려운 구조

 

상태 업데이트를 다음과 같이 Sweep-downSweep-up으로 나눠서 병렬화가 가능하게 함 

 

✅ 1. Sweep-down (상향 전파)

  • 먼저 입력 $x_t$에 대해 병렬로 B$x_t$를 계산
  • 상태 업데이트 항목을 먼저 계산 → 병렬 처리 가능

$$Bx = Bx_0,Bx_1,Bx_2,Bx_3$$

  • 이후 상태 항목끼리의 의존성을 없애기 위해 상향 전파 수행

 

✅ 2. Sweep-up (하향 전파),  Prefix Sum 적용

  • 상향 전파 결과를 다시 병렬적으로 처리
  • 이전 상태에서 얻어진 값을 기반으로 동시에 상태 업데이트

 

기본 상태 업데이트:

 

이걸 prefix sum으로 변환하면:

 

즉, prefix sum 형태가 됨:

$H = [Bx_0, Bx_1+Bx_0, Bx_2 +Bx_1+Bx_0, Bx_3+Bx_2 + Bx_1+Bx_0 ]$

 

👉 prefix sum은 병렬로 계산 가능

 

 

마지막 성능에 대한 부분은 이 블로그를 참고해주세요

(잘 되어있음) 

 

 

 

 

References

https://velog.io/@euisuk-chung/Paper-Review-Mamba-Linear-Time-Sequence-Modeling-with-Selective-State-Spaces

 

[Paper Review] Mamba: Linear-Time Sequence Modeling with Selective State Spaces

https://arxiv.org/pdf/2312.00752https://youtu.be/JjxBNBzDbNk

velog.io

https://minyoungxi.tistory.com/118

 

[논문리뷰] - ⭐️Mamba: Linear-Time Sequence Modeling with Selective State Spaces⭐️ - 맘바 ! Transformer의 대체

1. Interesting Point"We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addr

minyoungxi.tistory.com

 

728x90