데알못정을

Causal Machine Learning에서 Balanced Representation의 의미 본문

Research

Causal Machine Learning에서 Balanced Representation의 의미

쩡을이 2025. 6. 3. 14:52
728x90

Balanced Representation in Causal Machine Learning

- 개념

Balanced representation인과 추론(causal inference) 또는 counterfactual prediction 분야에서 사용하는 개념으로,
처치(treatment)와 비처치(control) 그룹 간 confounding bias를 줄이기 위해 학습된 표현 공간(latent space)을 의미

쉽게 말하면:

처치 여부와 관계없이 비슷한 환자들이 비슷한 표현(embedding)을 갖도록 만드는 것.

- 배경

인과 추론에서 큰 문제는 confounding
예: 건강한 사람은 치료를 안 받고, 아픈 사람은 치료를 받는다면 → 치료 효과 자체를 구분하기 어려움

Randomized trial이 아니라면 관측 데이터에서 처치 여부 $T$와 공변량 $X$ 사이의 상관이 존재함

이를 해결하기 위해, representation learning을 통해
$T=1$과 $T=0$ 그룹이 공정하게 비교될 수 있는 공간 $\Phi(X)$을 학습합니다.

- 목표

Representation $\Phi(x)$를 학습할 때, 다음 두 가지 조건을 만족시켜야 함.

Treatment-independence (균형)

- $\Phi(x)$ 만으로는 이 샘플이 처치받았는지 알 수 없어야 함

- $\Rightarrow$ 처치와 비처치 샘플들이 $\Phi(x)$ 공간에서 비슷한 분포를 갖도록

Outcome-relevance (예측력)

- $\Phi(x)$는 결과 $Y$를 정확히 예측할 수 있어야 함

- 대표적인 objective 함수는 다음과 같다:

$$
\min_{\Phi, f} \; \mathcal{L}_{\text{pred}}(f(\Phi(x), t), y) + \lambda \cdot \text{Discrepancy}(\Phi(x) \mid t=1, \Phi(x) \mid t=0)
$$

- $\mathcal{L}_{\text{pred}}$: 결과 예측 손실 (예: MSE)

- $\text{Discrepancy}$: 처치 그룹 간 representation 분포 차이
(예: MMD, Wasserstein distance, adversarial loss 등)

- $\lambda$: 균형과 예측 정확도 간 trade-off 조절 하이퍼파라미터

예시)

원본 데이터

- 치료 받은 환자: 중증

- 치료 안 받은 환자: 경증

→ 평균 비교하면 치료 효과가 왜곡

$\Phi(x)$ 공간에서는

- 중증/경증이 균일하게 분포되도록 표현을 바꿈

→ 마치 무작위 실험처럼 공정하게 비교 가능하게 만듦

 

728x90
Comments