본문 바로가기
  • 데이터에 가치를 더하다, 서영석입니다.
카테고리 없음

[논문 리뷰] L-VAE – Learnable β를 가진 Variational Autoencoder

by 꿀먹은데이터 2025. 11. 30.

1. 배경: 왜 아직도 VAE & β-VAE인가

1-1. 기본 VAE 복습

VAE(Variational Autoencoder)는 입력 \(x\)를 잠재변수 \(z\)로 인코딩했다가 다시 복원된 \(\hat{x}\)로 디코딩하는 구조의 오토인코더다. 단순 Autoencoder와 다른 점은, 잠재공간(latent space)에 확률 분포를 얹는다는 것이다.

학습 목표는 ELBO(Evidence Lower Bound)를 최대화하는 것이고 보통 “reconstruction + KL” 구조의 loss로 쓴다.

  • 첫 번째 항: 재구성 오차(reconstruction loss)
  • 두 번째 항: 잠재 분포를 prior (p(z)) (보통 (N(0, I))에 가깝게 만드는 KL divergence

1-2. β-VAE: KL에 가중치를 주면 disentanglement가 생긴다

β-VAE는 여기서 KL 항에 스칼라 β를 곱한다.

 

  • β>1 : : KL을 강하게 → 잠재공간을 더 “정돈/압축” → disentanglement 증가
  • 하지만 reconstruction은 망가질 수 있다 (너무 강한 규제)

문제는 여기서 β를 사람이 손으로 튜닝해야 한다는 것

  • 데이터/아키텍처/학습 스케줄마다 최적 β가 다름
  • grid search 하려면 비용이 꽤 크다
  • β를 키우면 재구성이 나빠지고, 줄이면 disentanglement가 안 생김

L-VAE 논문은 바로 이 “β 튜닝 지옥”에 문제의식을 둔다.

 

2. β-VAE에 대한 비판: β 하나로는 부족하다

논문은 먼저 β-VAE의 학습 동작을 네 가지 관찰로 정리한다.

  1. Observation 1 – scale mismatch
    vanilla VAE(β=1)에서는 reconstruction loss가 KL보다 수치적으로 훨씬 커서, 전체 loss를 사실상 reconstruction이 지배한다.
    학습이 진행되어도 KL 항이 충분히 줄어들지 않고, 오히려 증가하기도 한다.
  2. Observation 2 – β를 키우면 재구성 품질이 떨어짐
    β를 키우면 KL은 잘 제어되고 disentanglement는 좋아지지만, 그만큼 reconstruction loss가 커져서 이미지/샘플 품질이 나빠진다.
  3. Observation 3 – “좋은 β”는 다른 하이퍼파라미터에도 의존
    같은 β라도 learning rate, batch size, iteration 수에 따라 disentanglement vs reconstruction 트레이드오프가 크게 달라진다.
  4. Observation 4 – 항상 β > 1이 좋은 것도 아님
    일부 데이터셋(MPI3D, Isaac3D)에서는 오히려 β < 1일 때 더 나은 disentanglement + reconstruction 조합이 나왔다.

이 네 가지 관찰로부터 논문은 다음과 같은 결론에 도달한다.

“그냥 β 하나를 고정해두고 쓰는 방식 자체가 한계다.
→ loss 항 사이의 trade-off를 모델이 직접 학습하도록 만들자.”

3. L-VAE 아이디어: β를 ‘배우는’ VAE

3-1. Multi-task learning에서 가져온 아이디어

저자들은 multi-task learning에서 쓰이는 loss weight 학습 기법을 가져온다.

대표적인 것이 Kendall et al.(2018)의 uncertainty weighting이다.

multi-task setting에서 전체 loss가 L1,L2L_1, L_2 두 항으로 나뉜다고 하면, 다음처럼 각 task의 “불확실성 σi\sigma_i”를 학습 가능한 파라미터로 두고,

이 전체를 같이 최적화한다.
log⁡ σi 항σi→0으로 붕괴하는 것을 막는 regularizer 역할을 한다.

3-2. 이걸 VAE loss에 그대로 꽂으면?

VAE에서는

  • L0=L_0 = reconstruction loss ( MSE 등)
  • L1=L_1 = KL divergence DKL

라고 두고, 위 구조를 그대로 사용한다. 그러면 L-VAE loss는 다음처럼 쓸 수 있다.

여기서 σ0,σ1\sigma_0, \sigma_1네트워크 파라미터와 함께 gradient descent로 학습되는 스칼라다.

유효한 “동적 β”는 다음 비율로 해석할 수 있다.

즉, L-VAE는 β를 고정하지 않고, 학습 과정에서 σ0,σ1\sigma_0, \sigma_1을 조정하면서
reconstruction vs disentanglement의 trade-off를 자동으로 탐색
한다.

3-3. 구조 자체는 β-VAE와 거의 동일

흥미로운 점은 네트워크 아키텍처는 β-VAE와 완전히 동일하다는 것.
변하는 건 loss의 상대 가중치뿐이다.

  • 인코더/디코더는 MLP 또는 CNN (dSprites에는 MLP, 나머지는 CNN)
  • latent dimension도 각 데이터셋의 factor 개수에 맞춰 5/7/9 등으로 설정
  • Optimizer는 Adam + OneCycleLR 스케줄 사용

즉, “β를 언제, 어디까지 줄지/늘릴지”를 사람이 아니라 모델이 결정하게 만든 버전의 β-VAE라고 보면 된다.

4. 실험 세팅: 어떤 데이터와 비교군을 썼나

논문은 총 다섯 개의 데이터셋을 사용한다.

  • dSprites 흑백 2D 심볼 (square, heart, ellipse 등) / factor: shape, scale, orientation, x/y 위치 (총 5개)
  • MPI3D-complex : 로봇 팔이 다양한 물체를 움직이는 3D 이미지 / factor: 색/형태/크기/카메라 높이/배경색/가로/세로 위치 (7개)
  • Falcor3D : 거실 씬, 조명/카메라 위치/강도 등이 바뀌는 3D 렌더 이미지 (7개 factor)
  • Isaac3D : 주방에서 로봇 팔이 물체를 잡고 있는 씬 (9개 factor)
  • CelebA (정량 평가는 아니고, latent traversal 시각화 위주)

비교하는 모델은 다음과 같다.

  • VAE (β=1)
  • β-VAE (고정 β, 여러 값 grid search)
  • ControlVAE (PID 컨트롤러로 β를 동적으로 조정)
  • DynamicVAE (ControlVAE 변형)
  • σ-VAE (recon 쪽 weight만 학습하는 변형)
  • L-VAE (recon + KL weight 모두 학습)

4-1. 평가 지표: disentanglement를 여러 각도에서 본다

논문은 disentanglement를 세 가지 속성으로 정리한다.

  • Explicitness: representation에서 원래 factor를 얼마나 쉽게 복원할 수 있는가
  • Compactness: 한 factor를 적은 차원(이상적으로는 한 차원)에 담을 수 있는가
  • Modularity: 한 factor를 바꿔도 representation의 다른 차원은 덜 영향을 받는가

이를 위해 다음 지표들을 사용:

  • Explicitness score (explicitness)
  • SAP, MIG (compactness)
  • β-VAE score, FactorVAE score (modularity)
  • IRS (Interventional Robustness Score, holistic: modularity + explicitness)

각 모델에 대해 reconstruction loss + 위 disentanglement 지표 6개를 측정해서 비교한다.


5. 결과: L-VAE가 실제로 더 좋은가?

테이블 전체를 가져오진 않고 핵심만 요약하면:

5-1. 성능 정리

  • dSprites
    • L-VAE가 β-VAE, ControlVAE, DynamicVAE, σ-VAE보다 대부분의 disentanglement 지표에서 best 혹은 second best를 기록
    • reconstruction loss도 β-VAE보다는 낮고, VAE보다는 다소 높지만 양호한 수준
  • MPI3D / Falcor3D / Isaac3D
    • 데이터셋에 따라 약간 차이는 있지만, L-VAE는 거의 항상 상위권(1~2등) 성능을 유지하면서 reconstruction과 disentanglement의 균형이 잘 잡혀 있다.

요약하면:

“β를 따로 튜닝하지 않아도, β-VAE 계열 중에서 최상위권 성능을 내는 안전한 선택지

라는 느낌이다.

5-2. L-VAE가 학습한 β는 얼마인가?

논문에서는 L-VAE가 학습한 weight 비율(=유효 β)을 분석한다. 

  • dSprites, Falcor3D에서는 β^>1\hat{\beta} > 1 근처 값이 나오고,
  • Isaac3D에서는 β^≈0.95\hat{\beta} \approx 0.95 같이 1보다 작은 값이 가장 좋은 disentanglement를 만든다.
  • 즉, 데이터셋에 따라 “좋은 β”가 제각각이라는 것을 L-VAE가 스스로 학습한다.

또한, L-VAE가 최종적으로 학습한 β^\hat{\beta}β-VAE의 고정 β로 그대로 넣어서 다시 학습시켜 보면,
동일한 β값인데도 동적으로 학습된 L-VAE가 더 좋은 성능을 보이는 경우가 많다.

  • 단순히 “최종 β 값”이 중요한 게 아니라,
  • 학습 과정 전체에서 β가 어떻게 변해왔는지가 disentanglement에 영향을 준다는 의미다.

6. L-VAE를 수식 말고 코드로 한 줄 요약하면

논문에 나오는 수식 구조를 PyTorch 스타일로 써보면 대략 이런 느낌이다
(실제 구현과 1:1로 같진 않고, 개념적인 skeleton):

# recon_loss: MSE or BCE (batch mean)
# kl_loss: KL divergence (batch mean)

# log_sigma0, log_sigma1: learnable scalar parameters
# (초기값 0 -> sigma ~ 1)
sigma0 = torch.exp(log_sigma0)
sigma1 = torch.exp(log_sigma1)

loss = (1.0 / (sigma0 ** 2)) * recon_loss \
     + (1.0 / (sigma1 ** 2)) * kl_loss \
     + torch.log(sigma0) + torch.log(sigma1)

loss.backward()
optimizer.step()


-->>>  beta_hat = (sigma0 ** 2) / (sigma1 ** 2)

7. 개인적인 정리: 이 논문에서 얻은 포인트

연구/아이디어 관점에서 인상적이었던 포인트만 정리해본다.

  1. “β를 고정할 필요가 없다”는 것을 아주 정공법으로 보여줌
    • β-VAE의 핵심 가정(β>1이 disentanglement에 좋다)을 실제로 여러 데이터셋에서 다시 검증하고
      그 가정이 항상 맞는 게 아님을 데이터로 설득한다.
  2. Loss weighting을 별도 장치 없이 “그 자체”로 학습하는 구조
    • Kendall의 uncertainty weighting을 VAE loss에 그대로 적용했을 뿐인데
      기존 ControlVAE/DynamicVAE 같은 복잡한 제어 이론 기반 방식보다
      깔끔하면서도 잘 동작한다는 점이 꽤 매력적이다.
  3. Disentanglement 연구에서 “하이퍼파라미터 선택”이 얼마나 큰 변수인지 다시 확인
    • Locatello et al.의 대규모 실험(12k 모델)처럼
      이 논문도 “하이퍼와 β 튜닝을 어떻게 하느냐에 따라 결과가 확 뒤집힌다”는 걸 다시 강조한다.

8. 마무리

요약하면, L-VAE는

  • 구조는 β-VAE와 거의 동일하지만
  • KL과 reconstruction의 상대 weight를 학습 가능한 파라미터로 승격시키고
  • 이 덕분에 β 튜닝을 사람이 안 해도, disentanglement 관점에서 항상 상위권 성능을 내는 모델이다.

 

ps. 요즘 기술적인 공부를 많이 안하고 있어서, 기술적 공부도 할겸 찾으면서 작성했다. 앞으로는 더더 기술적이면서도 기초가 되는 내용으로 블로그를 채우고 싶다..