Temporal Transformer for Longitudinal Data

Self-Attention으로 시점 간 관계를 직접 학습

Attention 메커니즘의 직관부터 Scaled Dot-Product Attention, Multi-Head Attention, Positional Encoding, Causal Mask까지 상세히 다룬다. Temporal Transformer의 PyTorch 구현과 Attention Weight 시각화를 통한 해석 가능성을 실무 예시와 함께 설명한다.

Statistics
Deep Learning
저자

Kwangmin Kim

공개

2026년 03월 08일

1 Temporal Transformer for Longitudinal Data

1.1 Attention 메커니즘 직관

1.1.1 LSTM/TCN의 공통 한계

LSTM과 TCN은 모두 시점 간 관계를 간접적으로 학습한다:

  • LSTM: \(h_1 \to h_2 \to ... \to h_T\) — 먼 과거는 여러 단계를 거쳐 압축됨
  • TCN: dilation으로 넓히지만 여전히 로컬 패턴의 계층적 조합

만약 Week 2의 만족도가 Week 8의 이탈에 직접 영향을 준다면, LSTM은 6단계를 거쳐야 그 관계를 학습한다.

1.1.2 Attention의 핵심

Attention은 모든 시점 쌍 \((t_i, t_j)\)의 관계를 직접 계산한다.

LSTM:   t₁ → t₂ → t₃ → t₄ → t₅ → t₆ → t₇ → t₈
        (t₂와 t₈ 사이에 6단계)

Attention:
        t₁ ←→ t₂ ←→ t₃ ←→ t₄ ←→ t₅ ←→ t₆ ←→ t₇ ←→ t₈
        (모든 쌍이 직접 연결, 1단계)

직관: 시험 기간에 특정 과목 점수가 떨어졌다면, Attention은 “어느 시기의 무엇이 지금에 영향을 주는가”를 직접 계산한다.


1.2 Scaled Dot-Product Attention

1.2.1 Query, Key, Value

입력 시퀀스 \(\mathbf{X} \in \mathbb{R}^{T \times d}\)에서 세 행렬을 생성한다:

\[Q = XW_Q, \quad K = XW_K, \quad V = XW_V\]

  • Query (\(Q\)): “나는 어떤 정보를 찾고 있는가”
  • Key (\(K\)): “나는 어떤 정보를 가지고 있는가”
  • Value (\(V\)): “실제로 전달할 정보”

1.2.2 Attention 계산

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V\]

단계별:

  1. 유사도 계산: \(S = QK^T \in \mathbb{R}^{T \times T}\) — 모든 시점 쌍의 유사도
  2. 스케일링: \(S / \sqrt{d_k}\) — 차원이 클수록 내적 값이 커지는 것을 보정
  3. 정규화: \(\text{softmax}(S)\) — 각 Query에 대해 Key들의 가중치를 확률로 변환
  4. 가중 합산: 정규화된 가중치로 Value를 합산

1.2.3 스케일링의 필요성

\(d_k\) 차원의 벡터 내적은 평균 0, 분산 \(d_k\)를 가진다. \(d_k\)가 크면 내적 값이 커지고 softmax가 거의 one-hot이 되어 기울기가 소실된다.

\[\text{Var}(q \cdot k) = d_k \quad \Rightarrow \quad \text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = 1\]

1.2.4 예시: T=4 시점

Q·K^T (스케일 후):
         Week1  Week2  Week3  Week4
Week1  [  1.0   0.3    0.1    0.0 ]
Week2  [  0.4   1.0    0.5    0.1 ]
Week3  [  0.1   0.6    1.0    0.7 ]
Week4  [  0.0   0.2    0.8    1.0 ]

softmax 후 (각 행의 합 = 1):
         Week1  Week2  Week3  Week4
Week1  [  0.50  0.25   0.15   0.10 ]
Week2  [  0.20  0.35   0.30   0.15 ]
Week3  [  0.05  0.20   0.35   0.40 ]
Week4  [  0.03  0.10   0.37   0.50 ]

→ Week 4의 표현을 계산할 때:
  V₄ = 0.03·V₁ + 0.10·V₂ + 0.37·V₃ + 0.50·V₄
  → Week 3과 Week 4의 정보에 가장 집중

1.3 Multi-Head Attention

1.3.1 왜 필요한가

단일 Attention은 하나의 “관점”에서만 시점 간 관계를 본다. 종단 데이터에서는 여러 관점이 필요하다:

  • 만족도 트렌드: 어느 시점에서 하락이 시작되었나?
  • 행동 패턴: 대화 횟수가 급변한 시점은?
  • 개인화 효과: 개인화 적용 전후 차이는?

1.3.2 구조

\(h\)개의 헤드가 각각 독립적으로 Attention을 계산한다:

\[\text{head}_i = \text{Attention}(XW_Q^i, XW_K^i, XW_V^i)\]

\[\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W_O\]

  • 각 헤드의 차원: \(d_k = d_v = d_{\text{model}} / h\)
  • 출력 투영 \(W_O\)로 원래 차원으로 복원

파라미터 수: 단일 Attention과 거의 동일 (차원을 나눠 쓰므로).

d_model=64, n_heads=4:
  head 1: d_k=16 — 만족도 단기 변화에 집중
  head 2: d_k=16 — 행동 패턴 장기 추세에 집중
  head 3: d_k=16 — 개인화 전환점에 집중
  head 4: d_k=16 — 감정 변동에 집중

1.4 Positional Encoding

1.4.1 문제: Attention은 순서를 모른다

Attention은 집합(set) 연산이다. \(\{x_1, x_2, x_3\}\)\(\{x_3, x_1, x_2\}\)에 같은 결과를 낸다. 시계열에서 시간 순서는 핵심 정보이므로, 별도로 주입해야 한다.

1.4.2 Sinusoidal Positional Encoding

위치 \(\text{pos}\), 차원 \(i\)에 대해:

\[PE_{(\text{pos}, 2i)} = \sin\left(\frac{\text{pos}}{10000^{2i/d_{\text{model}}}}\right)\]

\[PE_{(\text{pos}, 2i+1)} = \cos\left(\frac{\text{pos}}{10000^{2i/d_{\text{model}}}}\right)\]

특성:

  • 고유성: 각 위치가 고유한 인코딩을 가짐
  • 상대적 관계: \(PE_{\text{pos}+k}\)\(PE_{\text{pos}}\)의 선형 변환으로 표현 가능
  • 학습 불필요: 고정된 함수로 계산 (학습 가능한 PE도 사용 가능)
import torch
import numpy as np

def get_positional_encoding(max_len, d_model):
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(max_len).unsqueeze(1).float()
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
    )
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # (max_len, d_model)

# 시각화
pe = get_positional_encoding(max_len=20, d_model=64)
# pe[0] ≠ pe[1] ≠ ... ≠ pe[19] — 각 위치가 고유

1.4.3 종단 데이터에서의 시간 인코딩

주간 데이터인 경우 Week 1~8이 위치 0~7에 대응한다. 불규칙 시점인 경우 실제 시간값을 위치로 사용하거나, 학습 가능한 Temporal Encoding을 적용한다.


1.5 Causal Mask: 미래 정보 차단

1.5.1 왜 필요한가

자기회귀(autoregressive) 예측에서는 시점 \(t\)의 출력이 \(t+1\) 이후의 정보를 참조하면 안 된다.

1.5.2 하삼각 마스크

def generate_causal_mask(T):
    """하삼각 마스크 — True인 위치는 -inf로 마스킹"""
    mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
    return mask

# T=5 예시
mask = generate_causal_mask(5)
# tensor([[False,  True,  True,  True,  True],
#         [False, False,  True,  True,  True],
#         [False, False, False,  True,  True],
#         [False, False, False, False,  True],
#         [False, False, False, False, False]])

Attention 스코어에 마스크를 적용:

\[\text{score}_{ij} = \begin{cases} \frac{q_i \cdot k_j}{\sqrt{d_k}} & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}\]

\(-\infty\)는 softmax 후 0이 되어 미래 정보가 완전히 차단된다.

1.5.3 분류 vs 자기회귀

태스크 Causal Mask 이유
시퀀스 분류 (이탈 예측) 선택적 전체 시퀀스를 보고 판단해도 됨
다음 시점 예측 필수 미래 정보 누수 방지
실시간 모니터링 필수 현재까지의 정보만 사용

1.6 Temporal Transformer: PyTorch 구현

import torch
import torch.nn as nn
import numpy as np


class TemporalTransformer(nn.Module):
    """종단 데이터 분류/예측용 Temporal Transformer"""

    def __init__(self, input_size, d_model=64, n_heads=4,
                 n_layers=2, dim_feedforward=128, dropout=0.1,
                 max_len=50, use_causal_mask=False):
        """
        Args:
            input_size: 입력 피처 수
            d_model: Transformer 내부 차원
            n_heads: Multi-head attention 헤드 수
            n_layers: Transformer encoder 층 수
            dim_feedforward: FFN 히든 차원
            dropout: 드롭아웃 비율
            max_len: 최대 시퀀스 길이
            use_causal_mask: 미래 마스킹 여부
        """
        super().__init__()
        self.d_model = d_model
        self.use_causal_mask = use_causal_mask

        # 입력 투영
        self.input_proj = nn.Sequential(
            nn.Linear(input_size, d_model),
            nn.LayerNorm(d_model),
            nn.Dropout(dropout)
        )

        # Positional Encoding
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

        # Transformer Encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            activation="gelu"
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layers
        )

        # 분류 헤드
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x, mask):
        """
        Args:
            x: (batch, T, features)
            mask: (batch, T) — 1=유효, 0=패딩
        Returns:
            (batch, 1)
        """
        batch_size, T, _ = x.shape

        # 입력 투영 + Positional Encoding
        x = self.input_proj(x) + self.pe[:, :T, :]

        # 패딩 마스크 (True = 무시할 위치)
        key_padding_mask = (mask == 0)

        # Causal 마스크 (선택적)
        attn_mask = None
        if self.use_causal_mask:
            attn_mask = torch.triu(
                torch.ones(T, T, device=x.device), diagonal=1
            ).bool()

        # Transformer Encoder
        out = self.transformer(
            x,
            mask=attn_mask,
            src_key_padding_mask=key_padding_mask
        )

        # 마지막 유효 시점의 representation 추출
        seq_lengths = mask.sum(dim=1).long()
        last = out[torch.arange(batch_size), (seq_lengths - 1).clamp(min=0), :]

        return self.classifier(last)


# 사용 예
model = TemporalTransformer(
    input_size=4,       # satisfaction, turn_count, emotion_score, personalized
    d_model=64,
    n_heads=4,
    n_layers=3,
    dim_feedforward=128,
    dropout=0.1,
    max_len=50,
    use_causal_mask=False  # 분류 태스크 → causal 불필요
)

1.7 Attention Weight 시각화

1.7.1 Attention 추출

def extract_attention_weights(model, seq, mask):
    """Transformer의 attention weight를 추출"""
    model.eval()
    attention_weights = []
    hooks = []

    def hook_fn(module, input, output):
        """MultiheadAttention의 attention weight 캡처"""
        # need_weights=True일 때 output = (attn_output, attn_weights)
        if isinstance(output, tuple) and len(output) == 2:
            attention_weights.append(output[1].detach())

    # Attention 층에 훅 등록
    for layer in model.transformer.layers:
        hook = layer.self_attn.register_forward_hook(hook_fn)
        hooks.append(hook)

    # Forward (need_weights=True 설정 필요)
    # PyTorch TransformerEncoderLayer는 기본적으로 average_attn_weights=True
    with torch.no_grad():
        _ = model(seq.unsqueeze(0), mask.unsqueeze(0))

    # 훅 제거
    for hook in hooks:
        hook.remove()

    return attention_weights  # list of (1, T, T) per layer

1.7.2 시각화 코드

import matplotlib.pyplot as plt
import seaborn as sns

def plot_attention_heatmap(attention_weights, layer_idx=-1, title="Attention Weights"):
    """Attention heatmap 시각화"""
    attn = attention_weights[layer_idx][0].numpy()  # (T, T)
    T = attn.shape[0]

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        attn,
        xticklabels=[f"W{i+1}" for i in range(T)],
        yticklabels=[f"W{i+1}" for i in range(T)],
        cmap="YlOrRd",
        annot=True,
        fmt=".2f",
        ax=ax
    )
    ax.set_xlabel("Key (참조하는 시점)")
    ax.set_ylabel("Query (현재 시점)")
    ax.set_title(title)
    plt.tight_layout()
    plt.show()


def plot_importance_bar(attention_weights, target_step=-1, layer_idx=-1):
    """특정 시점의 attention 분포 — 어느 시점이 중요한가?"""
    attn = attention_weights[layer_idx][0].numpy()  # (T, T)
    T = attn.shape[0]

    # target_step이 -1이면 마지막 유효 시점
    if target_step == -1:
        target_step = T - 1

    weights = attn[target_step, :]

    fig, ax = plt.subplots(figsize=(8, 3))
    bars = ax.bar(range(1, T + 1), weights, color="steelblue")
    # 가장 중요한 시점 강조
    max_idx = weights.argmax()
    bars[max_idx].set_color("crimson")

    ax.set_xlabel("Week")
    ax.set_ylabel("Attention Weight")
    ax.set_title(f"Week {target_step+1} 예측에 기여하는 각 시점의 중요도")
    plt.tight_layout()
    plt.show()


# 사용
attn_weights = extract_attention_weights(model, test_seq, test_mask)
plot_attention_heatmap(attn_weights, layer_idx=-1, title="마지막 Encoder Layer의 Attention")
plot_importance_bar(attn_weights, target_step=-1, layer_idx=-1)

1.7.3 해석 예시

AI Agent 이탈 예측 — 8주 시퀀스:

마지막 시점(Week 8)의 Attention 분포:
  Week 1: 0.05  (초기 온보딩 — 영향 적음)
  Week 2: 0.08
  Week 3: 0.10
  Week 4: 0.22  ← 개인화 시작 시점 (높은 가중치)
  Week 5: 0.25  ← 개인화 적응 시점 (최고 가중치)
  Week 6: 0.15
  Week 7: 0.10
  Week 8: 0.05

→ 비즈니스 인사이트:
  "Week 4~5의 사용자 경험이 장기 리텐션을 결정한다."
  "개인화 초기 2주가 핵심 전환 기간(critical window)이다."
  → 이 시기에 만족도가 떨어지면 집중 개입 필요

1.8 해석 가능성: Attention Heatmap의 가치

1.8.1 통계 모델 vs Transformer 해석

접근법 해석 방법 시점 간 관계
LMM 고정/랜덤 효과 계수 시점 간 직접 비교 어려움
LSTM SHAP (사후 해석) 비용 높음, 근사적
Transformer Attention weight 모델 내장, 시점 쌍별 관계

1.8.2 주의사항

Attention weight가 항상 “인과적 중요도”를 의미하지는 않는다:

  • Attention은 예측에 유용한 정보의 흐름을 보여줌
  • 인과적 효과와는 다를 수 있음 (예: confounding)
  • 해석 시 도메인 지식과 결합해야 함

1.9 실무 예시: AI Agent — 8주 만족도 시퀀스에서 핵심 시점 발견

# 전체 파이프라인

# 1. 데이터 준비
features = ["satisfaction", "turn_count", "emotion_score", "personalized"]
train_dataset = LongitudinalDataset(train_df, features, "will_churn", max_len=12)
val_dataset = LongitudinalDataset(val_df, features, "will_churn", max_len=12)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)

# 2. 모델 학습
model = TemporalTransformer(
    input_size=4, d_model=64, n_heads=4, n_layers=3, max_len=12
)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)
criterion = nn.BCELoss()

best_val_auc = 0
for epoch in range(100):
    # 학습
    model.train()
    for seq, mask, target, lengths in train_loader:
        pred = model(seq, mask)
        loss = criterion(pred, target)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

    # 검증
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for seq, mask, target, lengths in val_loader:
            pred = model(seq, mask)
            preds.extend(pred.numpy().flatten())
            labels.extend(target.numpy().flatten())

    from sklearn.metrics import roc_auc_score
    val_auc = roc_auc_score(labels, preds)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save(model.state_dict(), "best_transformer.pt")

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: val_AUC={val_auc:.4f}")

# 3. Attention 분석
model.load_state_dict(torch.load("best_transformer.pt"))
attn = extract_attention_weights(model, test_seq, test_mask)
plot_attention_heatmap(attn, layer_idx=-1)
plot_importance_bar(attn, target_step=-1)

1.10 LSTM vs TCN vs Transformer 비교

항목 LSTM TCN Transformer
시퀀스 처리 순차적 병렬 병렬
장기 의존성 보통 (게이트) 우수 (dilated) 매우 우수 (직접 참조)
수용 범위 이론적 무한 설계로 결정 전체 시퀀스
병렬 연산 불가 가능 가능
메모리 \(O(T)\) \(O(T)\) \(O(T^2)\) (Attention)
해석 가능성 낮음 낮음 Attention heatmap
가변 길이 자연스러움 패딩 필요 패딩 + 마스크
데이터 요구량 적음 중간 많음
구현 복잡도 낮음 중간 중간~높음
불규칙 시점 불가 불가 부분적 가능

1.10.1 선택 가이드

시퀀스 길이 < 20, 데이터 소규모    → LSTM/GRU
시퀀스 길이 > 50, 속도 중요        → TCN
시점 간 관계 해석 필요             → Transformer
불규칙 측정 시점                   → Neural ODE (다음 파일)
데이터 < 500명                    → 통계 모델 (LMM/GAMM) 우선

1.11 요약

항목 내용
Scaled Dot-Product Attention \(\text{softmax}(QK^T / \sqrt{d_k}) V\) — 모든 시점 쌍의 관계 직접 계산
Multi-Head Attention \(h\)개 헤드가 각각 다른 관점에서 관계 학습
Positional Encoding Sinusoidal 함수로 시간 순서 주입
Causal Mask 하삼각 마스크로 미래 정보 차단 (자기회귀 시)
핵심 장점 장기 의존성 + 해석 가능성 (Attention heatmap)
핵심 한계 \(O(T^2)\) 메모리, 데이터 요구량 높음

다음: 28 — Neural ODE: 연속 시간 역학으로 불규칙 측정 데이터 모델링

Subscribe

Enjoy this blog? Get notified of new posts by email: