| 일 | 월 | 화 | 수 | 목 | 금 | 토 |
|---|---|---|---|---|---|---|
| 1 | 2 | 3 | ||||
| 4 | 5 | 6 | 7 | 8 | 9 | 10 |
| 11 | 12 | 13 | 14 | 15 | 16 | 17 |
| 18 | 19 | 20 | 21 | 22 | 23 | 24 |
| 25 | 26 | 27 | 28 | 29 | 30 | 31 |
- 딥러닝
- 의료정보
- multi gpu
- GaN
- netflix thumbnail
- causal ml
- nccl 업그레이드
- machine learning
- causal forest
- 분산 학습
- causal transformer
- doubleml
- NTMs
- pytorch
- irregularly sampled time series
- 인과추론
- causal inference
- 의료
- 불규칙적 샘플링
- gru-d
- ERD
- 토픽모델링
- causal reasoning
- Time Series
- nccl 업데이트
- 인과추론 의료
- Transformer
- 리뷰
- first pg on this rank that detected no heartbeat of its watchdog.
- causal machine learning
- Today
- Total
데알못정을
[Review] Causal Transformer for Estimating Counterfactual Outcomes 본문
[Review] Causal Transformer for Estimating Counterfactual Outcomes
쩡을이 2025. 4. 23. 15:01
Introduction
의료 의사결정은 서로 다른 치료를 적용한 후 시간에 따라 변화하는 개별 환자의 건강 결과에 대한 정확한 지식을 필요로한다. 이는 궁극적으로 치료 계획의 선택에 정보를 제공하며, 개별 환자에 맞춘 효과적인 치료 제공을 가능하게 한다.
전통적으로, randomized controlled trials (RCT)가 치료 효과 추정의 gold standard이지만, 이는 비현실적이며, 비윤리적이다. 이를 해결하기 위해 EHR 데이터와 같은 observational data로부터 치료 효과를 추정하고자하는 관심이 늘어나고 있다.
통계 모델이나 머신러닝의 경우 시간이 지남에 따라 변하는 교란 요인을 통제하지 못해 종종 예측에 일반화 오류가 있거나 편향이 존재한다. 이러한 지점을 공략하기 위해 최근에는 recurrent marginal structural networks (RMSNs), counterfactual recurrent network (CRN), G-Net 같은 방법론이 개발되었는데, 이러한 방법론의 근간 자체가 LSTM이다 보니, 복잡하고 긴 종속성을 해결하지 못한다는 문제를 지적한다.
따라서 해당 연구는 Transformer를 기본으로 한 반사실적 결과 추정 모델인 Causal Transformer를 제안한다.

Problem Formulation
$i$를 각 환자의 index라 했을 때 health trajectories는 time steps $t = 1, ...,T^(i)$에 걸쳐있다. 각 환자 $i$의 각 시간 $t$에는 $d_x$ 차원을 가진 시간에 따라 변하는 시계열(공변량) $X_t^(i) \in mathbb{R}^d_x$와 $d_a$차원을 가진 treatment 집합 $A_t^(i) \in \left\{ a_1, ..., a_d_{a}\right\}$ 가 있다. (치료가 각 시간 별로 다차원 벡터로 존재 - 원핫인코딩 처럼) $d_y$차원을 가진 outcomes Y_t^(i) \in mathbb{R}^d_y$. $V^(i)$는 성별, 나이 같은 환자의 static 변수 . 이를 기반으로 학습을 위해 다음이 준비되어 있다고 생각


$\tau$가 예측 윈도우이고, t부터 미래 $\tau$까지의 treatment를 주면서 outcome인 Y를 예측하는 것에 관심이 있음
그러나, 구체적인 treatment intervention이 전형적으로 한 환자에 대해 관측되지않으며, 추정되어야함. formally하게 주어진 관측 데이터 D를 바탕으로 portential counterfactual potential outcome이 indentifiable하기 위해 다음의 가정이 만족되어야 함
(1) consistency, (2) sequential ignorability, (3) sequential overlap -> 해당 내용은 appendix A에 있음

consistency의 경우 환자가 받은 치료와 같은 치료를 모델이 상상했다면 모델이 상상한 결과는 실제 환자의 결과와 같아야한다는 뜻. → 당연히 지켜져야함. (그렇지 않으면 똑같은 치료를 받아도 결과가 다르다는 걸 의미) 예를 들어, 실제로 항생제를 받은 환자가 WBC 수치 12로 나왔는데,
모델은 "이 환자가 항생제를 받았을 때 WBC는 8이었을 것"이라고 한다면 → Consistency가 깨진 것
Sequential overlab의 경우 모든 환자 상태에 대해 어떤 치료든 받을 확률이 0보다 크고 1보다 작아야 한다는 뜻 → 환자기 실제로는 치료 A를 받았지만, 치료 B를 받았더라면 어땠을 까를 추정하고싶은데, 데이터에 치료 B를 받은 기록이없다면 추정 자체가 불가능
Sequential Ignorability의 경우 지금 어떤 치료를 받을지는, 과거 기록만 보면 알 수 있고, 미래 결과와는 무관해야 한다는 뜻 → 지금 의사가 치료를 선택할 때 미래에 어떤 결과가 나올지를 미리 알고있어서 결정하면 안되고, 오직 과거 기록만 보고 결정해야 함. 관찰되지 않은 요인(예: 의사의 직관)이 치료와 결과 모두에 영향을 준다면 confounding이 생기고, 정확한 인과 추론이 불가능
본 연구의 task는 다음과 같이 요약 가능:
환자의 history $\overline{H}_t$가 주어졌을 때 treatment intervention $\overline{a}_{t:t+\tau -1}$을 적용한 counterfactual outcomes $Y_{t_\tau}$를 예측하는 것
그러나 정말 단순하게 $g(\tau, \overline{a}_{t:t+\tau-1}, \overline{H}_t)$를 예측하는 것은 편향을 유발할 수 있음 - (treatment interventions는 outcome 뿐 아니라, future covariate에도 영향을 미치기 때문)
예를들어 vasopressor라는 치료제는 future covariate인 MAP를 상승시킬 수 있음
Causal transformer는 이러한 지점을 고려한 맞춤화된 모델임
Causal Transformer
CT는 3가지 분리된 transformer subnetwork를 결합하는 multi input architecture인데, 각 subnetwork는 input으로서 다양한 sequence를 입력 받는다. (1) past time-varying covariates $\overline{X}_t$ ; (2) past outcomes $\overline{Y}_t$; past treatment before intervention \overline{A}_{t-1}
본 모델의 목표는 "어떤 미래 치료를 했을 때 결과가 어떻게 나올까"를 예측하는 것이기에, 반사실적 예측을 위해 future treatment assignment를 입력으로추가하고, 자기회귀방식으로 예측한 값을 입력으로 넣어줌. (모델이 한 스텝씩 예측한 값을 다시 입력으로 넣어줌)
따라서, 저자는 두 treatment sequence 와 outcome sequecne를 concat한 $\overline{A}_{t-1} \cup \bar{a}_{t:t+\tau-1},\quad \overline{Y}_t \cup \hat{Y}_{t+1:t+\tau-1}$를 input을 만듦 (4) 정적 변수 V를 모든 subnetwork에 입력

아키텍처
CT(Causal Transformer)는 치료에 영향을 받지 않는(균형 잡힌) 표현 시퀀스 $\bar{\Phi}_{t+\tau-1} = (\Phi_1, \ldots, \Phi_{t+\tau-1})$를 생성한다. 이를 위해 $B$개의 동일한 transformer 블록을 쌓는다. 첫 번째 transformer 블록은 세 개의 입력 시퀀스를 받는다. $B$번째 transformer 블록은 표현 시퀀스 $\bar{\Phi}_{t+\tau-1}$를 출력한다.
- Transformer block -
$b = 1, \ldots, B$는 서로 다른 transformer 블록 인덱스. 각 transformer 블록은 세 개의 입력 시퀀스 각각에 대해 병렬적인 hidden state 시퀀스를 입력으로 받는다. 시간 $t$에서 각 입력 시퀀스에 대응하는 hidden state는 $A_t^b$ 또는 $a_t^b$, $Y_t^b$ 또는 $\hat{Y}_t^b$, $X_t^b$로 표기한다. hidden state의 차원은 $d_h$로 정의한다. 추가적으로 각 transformer 블록은 정적 공변량의 표현 벡터 $\tilde{V}$를 추가 입력으로 받는다.
첫 번째 transformer 블록($b = 1$)에서는 다음과 같이 선형 변환된 시계열을 입력으로 사용한다:
$$A_t^0, a_t^0 = \text{Linear}_A(A_t, a_t), \quad
X_t^0 = \text{Linear}_X(X_t),$$
$$Y_t^0, \hat{Y}_t^0 = \text{Linear}_Y(Y_t, \hat{Y}_t), \quad
\tilde{V} = \text{Linear}_V(V)$$
fully-connected 선형 계층의 파라미터는 모든 시간 단계에 대해 공유된다. 모든 블록 $b \geq 2$는 이전 블록 $b - 1$의 출력 시퀀스를 입력으로 사용한다.
Transformer 블록 $b$ 이후의 hidden state 시퀀스는 다음의 세 텐서로 표현된다:
$A^b = \left( \bar{A}_{t-1}^b \cup \bar{a}_{t:t+\tau-1}^b \right)^\top$,
$X^b = (\bar{X}_t^b)^\top$,
$Y^b = \left( \bar{Y}_t^b \cup \hat{Y}_{t+1:t+\tau-1}^b \right)^\top$이다.
Dong et al. (2021), Lu et al. (2021)의 방식을 따라, 각 transformer 블록은 다음 세 가지 구성요소를 포함한다:
(i) multi-head self-/cross-attention,
(ii) feed-forward layer,
(iii) layer normalization이다.
(i) Multi-head self-/cross-attention은 여러 개의 병렬 attention head를 사용하는 scaled dot-product attention 기법이다.
각 attention head는 key, query, value로 구성된 세 개의 입력 $K, Q, V \in \mathbb{R}^{T \times d_{qkv}}$를 요구한다.
이들은 hidden state 시퀀스 $H^b = (h_1^b, \ldots, h_T^b)^\top \in \mathbb{R}^{T \times d_h}$에서 생성된다.
($H^b$는 $A^b$, $X^b$, 또는 $Y^b$ 중 하나이며 subnetwork에 따라 다르다.)
공식적으로 attention은 다음과 같이 계산된다:equation (3)
\[
\text{Attn}^{(i)}(Q^{(i)}, K^{(i)}, V^{(i)}) = \text{softmax}\left( \frac{Q^{(i)} (K^{(i)})^\top}{\sqrt{d_{qkv}}} \right) V^{(i)}
\]
key, query, value는 다음과 같이 정의된다:
\[
Q^{(i)} = Q^{(i)}(H^b) = H^b W_Q^{(i)} + \mathbf{1} b_Q^{(i)\top}
\]
\[
K^{(i)} = K^{(i)}(H^b) = H^b W_K^{(i)} + \mathbf{1} b_K^{(i)\top}
\]
\[
V^{(i)} = V^{(i)}(H^b) = H^b W_V^{(i)} + \mathbf{1} b_V^{(i)\top}
\]
여기서 $W_Q^{(i)}, W_K^{(i)}, W_V^{(i)} \in \mathbb{R}^{d_h \times d_{qkv}}$,
$b_Q^{(i)}, b_K^{(i)}, b_V^{(i)} \in \mathbb{R}^{d_{qkv}}$는 attention head $i$에 대한 학습 가능한 파라미터이다.
$\mathbf{1} \in \mathbb{R}^{d_{qkv}}$는 모든 원소가 1인 벡터이다. (bias를 더해주려는 의도) softmax 연산은 행(row) 단위로 독립적으로 적용된다. query와 key의 차원 $d_{qkv}$는 $d_{qkv} = d_h / n_h$이며, 여기서 $n_h$는 attention head의 개수이다.
저자는 Vaswani et al. (2017)의 원래 multi-head attention에서 마지막 출력 프로젝션 계층을 제거함으로써 overfitting의 위험을 줄이고 구조를 단순화하였다.
CT에서는 self-attention이 동일한 transformer subnetwork의 hidden state 시퀀스를 사용하여 key, query, value를 추론한다. 반면 cross-attention은 나머지 두 개의 transformer subnetwork에서 얻은 hidden state 시퀀스를 key 및 value로 사용한다. 저자는 병렬적인 hidden state 사이의 정보를 교환하기 위해 여러 개의 cross-attention을 사용한다. 이러한 cross-attention은 self-attention layer 위에 배치된다 (세부 구조는 아키텍처 참조). 서로 다른 cross-attention 출력들을 pooling할 때, 정적 공변량의 표현 벡터 $\tilde{V}$를 추가한다. 또한 self-attention과 cross-attention에서의 마스킹을 위해 Eq. (3)의 attention logit을 $-\infty$로 설정하여, 정보가 현재 입력에서 미래의 hidden state로만 흐르도록 한다 (반대 방향은 허용되지 않음).




(ii) Feed-forward layer (FF)는 ReLU 활성화 함수를 갖는 완전 연결 층으로 구성되며, hidden state 시퀀스에 대해 시간 축 기준으로 독립적으로 적용된다. 이는 다음과 같이 표현된다:
\[
\text{FF}(h_t) = \text{Linear} \left( \text{ReLU}(\text{Linear}(h_t)) \right),
\]
여기서 선형 계층은 dropout을 뒤따른다.
(iii) Layer normalization (LN)은 Ba et al. (2016)의 방법을 따른다. 각 self-attention과 cross-attention 이후에 residual connection이 더해지고 layer normalization이 적용된다. layer normalization은 다음과 같이 계산된다:
\[
\text{LN}(h_t) = \frac{\gamma}{\sigma} \odot (h_t - \mu) + \beta,
\]
\[
\mu = \frac{1}{d_h} \sum_{j=1}^{d_h} (h_t)_j, \quad
\sigma = \sqrt{ \frac{1}{d_h} \sum_{j=1}^{d_h} \left( (h_t)_j - \mu \right)^2 },
\]
여기서 $\gamma, \beta \in \mathbb{R}^{d_h}$는 scale 및 shift 파라미터이고, $\odot$는 element-wise 곱을 의미한다.
여기서 질문.... chatgpt 도와줘,,,!!
Q. 왜 treatment, outcome, covariate을 따로 처리하고 cross attention으로 연결하나?




이어서 ..
(balanced) 표현은 B번째 transformer 블록의 두 개 또는 세 개의 병렬 hidden state에 대해 평균 풀링(average pooling)을 수행하여 구성된다.
이때 fully-connected linear layer와 exponential linear unit (ELU) 비선형 함수가 사용된다. 즉,
\[
\tilde{\Phi}_i =
\begin{cases}
\frac{1}{3} (A^B_{i-1} + X^B_i + Y^B_i), & i \in \{1, \ldots, t\}, \\
\frac{1}{2} (a^B_{i-1} + \hat{Y}^B_i), & i \in \{t+1, \ldots, t+\tau-1\},
\end{cases}
\]
\[
\Phi_t = \text{ELU}(\text{Linear}(\tilde{\Phi}_t))
\]
여기서 fully-connected linear layer는 dropout과 함께 사용된다. 최종 balanced representation인 $\Phi_t$는 $\mathbb{R}^{d_r}$ 차원의 벡터이며, $d_r$은 balanced representation의 차원을 나타낸다.
- Positional encoding -
hidden state 들의 순서에 대한 정보를 보존하기위해 postion encoding을 사용했다. 이는 우리가 Treatment A -> side effect S -> Treatment B를 Treatment A -> Treatment B -> side effect S와 구분할수 있게 해주기 때문에 임상 현장과 특히 관련이 있다.
Transformer는 기본저으로 위 두 시퀀스를 구분하지 못하기 때문에 시점 간 상대적인 순서를 넣어줌
해당 논문에서는 Relative Positional Enoding을 사용한다. -> "지금 i번째 시간인데, j번째 과거 정보와의 거리 = 몇 칸 차이냐?"
이런 정보를 직접 attention score 계산에 더해준다. (시간 i에서 attention을 계산할 때, 과거 j와의 거리 (j - i) 정보를 포함시키자.)
\[
a_{ij}^V = w^V_{\text{clip}(j - i, l_{\text{max}})}, \quad
a_{ij}^K = w^K_{\text{clip}(j - i, l_{\text{max}})}
\]
$j$: 과거 정보 인덱스 (과거 치료, 결과 등)
$i$: 현재 시점
$j-i$: 시점차 차이
clip은 너무 먼과거는 무시하기 위해 아래처럼 자른다
\[
\text{clip}(x, l_{\text{max}}) = \max(-l_{\text{max}}, \min(l_{\text{max}}, x))
\]
즉, 너무 멀리 떨어진 과거는 "그냥 먼 과거"로 통일
이 인코딩 벡터는 attention score에 더해진다. 기존 attention socre는 $\alpha_{ij} = \frac{Q_i^\top K_j}{\sqrt{d_{qkv}}}$인데, 여기에 relative encoding을 추가하면 $\alpha_{ij} = \frac{Q_i^\top (K_j + a^K_{ij})}{\sqrt{d_{qkv}}}$이고, attention socre도 다음과 같이 보정한다.
\[
\text{Attn}_i = \sum_{j=1}^t \alpha_{ij} \cdot (V_j + a^V_{ij})
\]
이렇게만 보니까 이해가 잘 안되서 chatgpt에 예제를 부탁했다.
| 시간 t | A(treatment) | X(covariate: 맥박) | Y (Outcome: 백혈구 수) |
| 1 | 0 (미투여) | 95 | 8.1 |
| 2 | 1 (투여) | 97 | 8.3 |
| 3 | 1 | 100 | 8.8 |
| 4 | 0 | 102 | 9.1 |
| 5 | 1 | 105 | 9.5 |
| 6 | ? | ? | ? |
-> 현재 시점 t = 6에서 $a_6$을 했을 때 $Y_6$, $Y_7$, ...를 예측하고 싶은 상황
Transformer 내부에서는 self attention을 계산하는데, i = 6, j = 1,...5 인 상황이고, 문제는 $Y_5$와 $Y_1$을 보고 둘다 그냥 과거 outcome으로 인식하는 것임. 따라서얼마나 최근의 값인가에 대한 정보가 없음. -> $Y_1$의 영향을 $Y_5$보다 더 크게 볼 수도 있음 (임상적으로 말이 안됨)
핵심 아이디어는 i = 6 시점에서 j = 1~5 까지의 과거 시점에 대해 거리 j-i를 계산하고, 이 거리를 인덱스로 한 서로 다른 학습 가능한 고정 벡터를 만든다. ($w_{-5}$, ..., $w_{-1}$)
그래서 이걸 어떻게 활용하냐면
기존 "$Q_i \cdot K_j$ :i 시점 쿼리와 j 시점 키의 내적" 을 "$Q_i \cdot (K_j+w_{j-i})$ :i 시점 쿼리와 j 시점 키의 내적에 i 번째시점에서 j번째 시점이 한칸 앞이다를 추가" 하는 느낌이다.
| j | 원래 $K_j$ | 상대거리 j - i | 더해주는 벡터 $w_{j-i}^K$ | 최종 key |
| 5 | [0.1,−0.2,...][0.1, -0.2, ...] | -1 | [+0.05,−0.01,...][+0.05, -0.01, ...] | [0.15,−0.21,...][0.15, -0.21, ...] |
| 4 | [0.3,+0.0,...][0.3, +0.0, ...] | -2 | [−0.02,+0.01,...][-0.02, +0.01, ...] | [0.28,+0.01,...][0.28, +0.01, ...] |
| 3 | [0.0,+0.4,...][0.0, +0.4, ...] | -3 | [+0.10,−0.03,...][+0.10, -0.03, ...] | [0.10,+0.37,...][0.10, +0.37, ...] |
이런식으로, '아, 얘가1시간 전이면 이 정도 가중치를더 줘야지', '얘는 너무 오래 전이라 좀 무시하자' 가 학습이 잘 되는 쪽으로 수행됨
정리하자면, 여기서 말하는 position encoding이란, key 벡터에 더해지는 일종의 상대 거리를 반영한 bias임. (j-i)를 전체 시퀀스 길이에서 얼만큼 통제할 것인지가 $l_{max}$임
이 범위를 넘어가면 그냥 무시 -> 너무 먼 시점은 고려하지 않겠다. (최대 weight 생성 인덱스: 2*l_max +1)
| $w_{j-i}의 범위 | $−l_{max}<=j−i<=l_{max} |
| 벡터 개수 | $2l_{max}+1개 |
| clip 이유 | 파라미터 수 제한 + 먼 과거는 비슷하게 취급 |
- Training of our Causal Transformer -
본 연구에서 학습하려는 balanced representation은 두 가지 조건을 만족해야 한다.
(1) 미래 결과를 잘 예측해야 한다.
(2) 현재 치료가 무엇이었는지는 예측 못하게 해야한다.
-> 즉, representation만 보고서는 항생제를 줬는지, 안줬는지 구분이 안되도록 만들어야하는데, 왜냐하면 treatment를 예측할 수 있으면 confounding bias가 여전히 들어 있는 것이기 때문
(2)번은 이게 뭔소리여 대체...
ChatGPT와의 수 차례 대화를 통해 아래 결론을 얻었다.
우리는 치료의 효과를 보고 싶음 - 어떤 치료를 했을 때 어떻게 Y가 변하는지?
그런데, 종종 치료는 X를 보고서 의사의 판단에 의해 수행됨 (예: MAP가 65mmHg 이하니까, 승압제를 투여한다.)
인과 추론에서 치료 효과를 추정하기 위한 gold standard가 RCT인 것을 감안하면, 치료 효과를 공정하게 추정하기 위해서는 의사는 랜덤하게 치료를 정해야하고, 이에 따른 결과로 효과를 추정하는 것이 공정하다. 하지만 현재 MAP를 보고 승압제를 주는 이런 경우 말 그대로 위에서 언급했던 representation 만 보고서 치료를 줬는지 안줬는지 알 수 있는 상황이다.
이때의 인과 그래프
X ─► A ─► Y
│ ▲
└───────┘
즉 A->Y 인과 영항을 보고 싶은데, A는 X의 영향을 받으므로, confounding bias가 생긴 것임 ( 이때 인과 그래프는 X -> A -> Y)
("모델이 A 때문에 Y가 변했다고 착각하게 되지만, 실제로는 X가 A도 결정하고, Y도 결정하는 상황" -> "치료 A가 회복 Y에 진짜 영향을 줬는가?" 이걸 알고싶으면 X가 A에게 주는 영향을 제거하면서 동시에 X가 Y에 주는 정보는 남겨야 함)
그래서 본 논문에서 유도하고 싶은 바람직한 인과 그래프는 A -> Y <- X 이다.
따라서 Representation에서 현재 치료가 무엇이었는지는 예측 못하게 한다는 의미는 X -> A로의 인과 정보를 지워버리겠다는 의미이다.
| 환자 | X(나이) | 치료 여부 A | 회복 여부 Y |
| 1 | 90 | 1 (치료함) | 0 (미회복) |
| 2 | 25 | 0 (치료 안 함) | 1 (회복) |
| 3 | 85 | 1 | 0 |
| 4 | 30 | 0 | 1 |
| 5 | 70 | 1 | 0 |
| 6 | 28 | ? | ? |
이 경우 X만 보고도 A를 예측할 수 있음 (나이가 많은 환자는 항생제를 투여받는다)
이 상태로 representation을 만들면 모델은 A가 무엇인지 대충 알고 있게되고, Y 예측 시 A의 영향이 개입된다. 이렇게 되면 치료를 해서 회복이 안되는 건지, 나이가 많아서 회복이 안되는 건지 구분을 못한다.
그래서 하는 것이 adversarial loss의 도입이고, 이놈이 하는 것이 representation은 Y를 예측하는데 유용한 정보인 X와 과거 Y를 남기고, A를 예측하는 데 필요한 신호인 X -> A의 경로를 지워버린다.
- Adversarial balanced representations. -
목적을 다시 요약하자면
- Y 예측기 : representation $tilde{\Phi}_i$로부터 결과 $Y_{t+\tau}$를 잘 맞추게 만들기
- A 예측기 : representation $tilde{\Phi}_i$로부터 결과 $A_{t}$를 맞추지 못하게 만들기
논문에서는 이를 위해 두 개의 feed-forward network (MLP)를 사용했다.

outcome prediction network $G_Y$가 representation $tilde{\Phi}_i$와 현재 treatment $A_t$를 추가로 입력 받음
$G_A$는 representation 만 입력 받음 -> A 예측기
Y를 잘 맞추게 하기 위한 objective function:

“Representation에서는 A를 예측 못 하게 만든다면서, 왜 Y 예측할 때는 A를 집어넣는가?”
-> 미래 시점 Y예측에 도움이 되니까. A가 실제로 Y에 영향을 주는 treatment이기 때문에
representation에서 A를 없애고, treatment를 독립적으로 입력함으로써, Y예측에 confounding 정보로부터 독립적으로 들어와야 하게 만든 것임
전체 수식은 다음의 의미를 가지고 있음 : "representation $tilde{\Phi}_i$가 과거 정보만 압축해 담고 있고, 현재 시점에 어떤 치료 $A_t$를 했는지 명시적으로 따로 알려준 상태에서 미래 Y_{t+1}을 예측하라"
학습구조:
X ───┐
▼
Φ_t (rep) ────► G_Y ───► Y_{t+1}
▲
│
A_t
A를 못 맞추게 하기 위한 objective function
-> Representation이 Y는 예측 가능하지만, A는 예측하지 못하게 만드는 것이 목표

이 형태는 entropy 형태로 이를 minimize 하는 것은 A를 얼마나 잘 맞추는가를 의미함. 하지만 여기서는 잘 못맞추도록 하는 것이 목표
그래서 domain confusion loss를 도입. 여기서 domain은 치료 종류 (class), confusion은 치료 클래스를 구별하지 못하게 만드는 것

이 꼴은 사실 KL divergence (uniform | G_a) 임

Q “어차피 식 (17) 로 A 예측기 $G_A를 망가뜨릴 건데, 식 (16)처럼 A를 잘 예측하게 만들 필요가 있나?”
-> 일단 잘 맞추도록 만들어야 G_A가 높은 정확도를 가진 모델이되는데, 그래야 $tilde{\Phi}_i$가 식 17에 따라 G_A가 잘 못맞추도록 자기 자식의 weight $\theta$를 업데이트할거임. 애초에 G_A가 잘 못맞추는 모델이 되면, $tilde{\Phi}_i$ 굳이 자기 안에 있는 confounding을 제거할 필요가 없음
Overall adversarial objective

Experiments
counterfacture outcome 추론을 위한 benchmark task가 있는데, 이 절차에 따라 CT를 평가했다고 함. 실험에서는 (semi-)syntheric dataset을 기본적으로 사용한다. (real data는 counterfactual outcome의 ground truth가 없음)
Baseline model 선정
해당 연구 분야에서 제안된 state of the art model의 paper에서 사용한것과 동일한 baseline들을 사용 (MSMs, RMSNs, CRN, G-Net)
Experiments with fully-synthetic data
- data -
본 연구에서는 폐암 치료의 효과를 시간에 따라 시뮬레이션 할 수 있는 수학적 모델을 평가 데이터로 사용했다. 이는 실제 환자의 데이터를 쓰는 것이 아니라, 환자가 어떤 치료를 받았을 때 종양이 어떻게 반응할지를 수학적으로 예측해주는 도구이다. ( Prediction of Treatment Response for Combined Chemo- and Radiation Therapy for Non-Small Cell Lung Cancer Patients Using a Bio-Mathematical Model 에서 제안됨)
이 모델은 약물의 체내 흡수와 제거 속도, 방사선의 영향, 종양의 세포 성장률 등을 수학적으로 모델링한 것. 본 연구에서는 실험에서 가짜 환자 데이터를 만들기 위해 사용한다.
본 연구에서는 이 모델을 바탕으로 다음 두 가지 실험을 수행한다.:
1. Single sliding treatment: 한 환자가 한 가지 치료만 계속해서 받는 상황
2. Random trajectories: 환자가 다양한 치료를 섞어서 받는 상황
- Result -

가상의 데이터이기 때문에 공변량이 치료 A 선택에 얼마나 영향을 줄지를 시뮬레이션 할 수 있음 (실제 데이터로는 불가능, 이미 데이터에 내제되어 있음)
1. Z 생성: 환자의 나이, 체온, 산소포화도 등
2. T 결정: T = P(약 B | Z) = sigmoid( $\gamma$ * wᵀZ)
→ $\gamma$로 confounding 정도 조절
3. $\gamma$ 생성: bio-math model 사용 (T와 Z 입력)
→ 예: 약물이 박테리아를 줄이는 ODE 기반 모델
CT가 기존 최신 모델보다 성능이 좋았는데, 특히 confounding $\gamma$ 값이 클 수록, 예측해야 할 시간 간격이 클수록 그 차이가 더 크게 나타났다. CDC loss를 활성화하지 않은 파란색의 경우, 단기 예측엔 큰 영향이 없지만, 장기 예측에는 CDC loss가 중요했다.

이 plot은 balanced representation이 실제로 treatment를 예측하지 않고, outcome 예측만 잘 할 수 있는지를 보기 위한 plot. 좌측에 treatment 별 색상을 입혀보면, 무작위적인 모습을 볼 수 있으나, 우측 outcome에 대한 결과를 보면, 나름 패턴이 있음을 알 수 있음 -> 의도한대로 학습이 되었다
Experiments with semi-synthetic data
해당 실험은 실제 환자 데이터를 기반으로 하되, 결과(Y)는 연구자가 알고 있는 방식으로 생성해서 정답 (counterfactual outcome)을 가진 데이터를 만들고자 함. 이러한 데이터를 semi-synthetic data라 부르며, 이 실험은 CT 모델이 실제 환경에서도 잘 작동하는지를 검증하기 위한 중간 단계 실험이다.
Details on Experiments with Semi-Synthetic Data
사용된 데이터
- MIMIC-extract (MIMIC-III): 데이터 전처리 파이프라인. MIMIC-III을 시간 단위로 정규화 및 보간처리 한 정제된 데이터셋
- 선택된 변수: vital sign 25개, static 3개 (one hot encoding를 사용하여 벡터화) 등 총 44개 변수를 사용
1. Untreated outcome $Z_{j, t}^(i)$ 생성
z는 아직 치료가 적용되지 않았을 때의 결과 값임. 이 값은 아래 3개의 요소가 합쳐져서 만들어진다. (데이터 내에 치료를 수행하지 않은 환자는 outcome을 그대로 사용)

| B-spline(t) | 시간의 흐름에 따라 부드럽게 변하는 베이지안 기저함수. 어떤 변수는 시간이 지나며 자연적으로 변화할 수 있기 때문에 도입함 |
| $g_j^{(i)}(t)$ | 각 환자마다 고유하게 가지는 무작위적인 변화. Gaussian Process(GP) 기반으로 생성됨. → 개인차 반영 |
| $f_j^Z(X_t^{(i)})$ | 시계열 공변량 $Xt(i)X_t^{(i)}에 비선형 함수를 적용해 만든 외생적 영향. RFF 기반 GP 사용 |
| $ε_t$ | 잡음 (Gaussian noise, ε ~ N(0, 0.005²)) |
즉, 이 Z는 환자가 치료를 안했다면 원래 이렇게 진행되었을거야를 시뮬레이션한 베이스라인 trajectory
2. Treatment $A_{t}^(l)$ 생성
환자는 n가지 binary treatment를 받을 수 있다. 각 treatment는 시계열 변수에 따라 다르게 할당되며, confounding을 의도적으로 삽입한다.
아래 수식은 시점t에 대해 treatment l을 받을지 여부를 결정하는 확률 함

| $\bar{A}_{T_{l}} | 지금까지 받은 해당 treatment의 이력 (e.g., 그동안 몇 번이나 약을 맞았는가) |
| $f_{Y}^l(X_t) | 현재 시점의 시계열 변수 X를 기반으로 한 비선형 함수 (treatment 선택 결정기) |
| $\gamma_A, \gamma_X$ | 각각 과거 치료 이력과 공변량 X가 treatment 결정에 미치는 영향력 |
| $b_l$ | 각 treatment별 편향값 (bias) |
| $\sigma | 시그모이드 함수로 확률로 변환 |
treatment 할당은 공변량에 따라, 과거 치료 이력에 따라 다르게 주어지므로, confounding을 의도적으로 삽입한 구성 (논문에서는 총 3가지 치료를 고려함 -> treatment 조합 총 8개)
3. Treatment 효과 모델링 $E_{j}(t)$
치료를 받은 후에는 종속 변수 Y가 변화하게 된다. 여기서는 각 treatment가 특정 결과에 영향을 주며, 그 영향이 시간에 따라 감소한다고 가정한다.

| $\beta_{l,j}$ | treatment l이 outcome j에 미치는 영향 정도 (양수 또는 음수 가능) |
| $w_l$ | treatment 효과가 남아있는 시간 윈도우 |
| $\mathbf{1}_{A_i^{l} = 1}$ | 그 시점에 해당 treatment를 받았는지 여부 |
| $p(A_i^l)$ | 앞서 계산된 treatment 확률 |
| $(w^{l} - i)^2$ | 시간이 지날수록 효과가 역제곱법칙으로 빠르게 감소하는 것 구현 |
4. 최종 Outcome
이제, 치료가주어진 후의 진짜 결과는 아래 처럼 계산됨

즉, Z는 원래 결과, E는 치료효과, Y는 실제 관측된 결과
실험은 ICU 체류 시간이 20시간 이상인 1000명에 대해 20~100시간의 시퀀스를 가진 시계열 데이터를 Trn 60%, vld 20%, test 20%로 분할하여 수행했고, 예측 윈도우 $\tau$는 1부터 10까지 변화시키면서 수행했다.
또한, 각 시점마다 treatment 조합을 바꾸어 8개의 counterfactual를 생성했다.
아무튼 이렇게 함으로, 각 환자에 대해 다양한 counterfacutal outcome을 만들어낼 수 있다. 현실 의료 데이터를 기반으로 결과만 시뮬레이션 했기 때문에 semi-synthetic data가 됨.

