본문 바로가기

AI/LLM

LLM Tuning - DPO (Direct Peference Optimization)

LLM Tuning

최근의 언어 모델은 이미 학습이 된 Pre-trained model을 사용자의 취향과 목적에 맞춰 튜닝하는 쪽으로 진행되고 있다. 거기에 필요한 기술들은 굉장히 다양하므로 언제 한 번 정리할 생각인데, 일단 업무상 필요한 부분만 정리해 보려고 한다.

LLM 튜닝의 목적에 따라서 Instruction Tuning과 AI Alignment로 분류할 수 있다. 전자는 모델의 기능을 확보하는 것으로, 어떠한 작업을 수행할 수 있도록 그 예시에 대한 데이터를 학습시키는 것이다. 후자는 AI 가 사용자가 선호하는 방식으로 결과를 출력하게 하는 것이다. 사실 그 두 가지는 경계가 모호한 면이 있고, 실제 데이터셋에서도 이 두가지가 명확하게 구분되지는 않는 것 같다.

다만 아래 설명할 DPO의 경우 AI Alignment 에 맞춘 튜닝 방법으로 데이터셋에서부터 '선호 답변' 과 '비선호 답변'으로 구분하여 선호 답변 쪽으로 모델이 출력하게 유도하는 방식으로 학습된다.

DPO (Direct Peference Optimization)

참고 : Direct Preference Optimization: Your Language Model is Secretly a Reward Model


출처 : Direct Preference Optimization: Your Language Model is Secretly a Reward Model

DPO는 RLHF (Reinforcement Learning from Human Feedback) 에서부터 나온 LLM 튜닝 방식이다. RLHF 는 사람의 피드백을 통해 답변의 선호 여부를 받아 학습시키는 방법이다. 이 과정에서 사람의 선호 여부를 판단하는 Reward 모델이라는 것을 학습하여 이를 평가자로 삼고, 모델이 이런저런 답변들을 내면서 Reward 모델이 평가하도록 하여 학습을 진행한다. 하지만 DPO는 이러한 과정을 간략화하여 더 직접적으로 선호도 (Preference)를 학습할 수 있다고 한다.

사실 이 내용은 나로서는 잘 이해가 가지 않았다. 애초에 강화학습 (RL : Reinforcement Learning) 자체가 직접적으로 학습될 수 없는 문제에 대해서 학습하기 위해 나온 방법론으로, 직접적으로 학습이 가능하다면 강화학습 자체가 필요 없다. 따라서 내 기준으론 DPO가 나온 시점에서 애초에 강화학습이란 개념에 얽매일 필요가 없어보이는 것이다. 그런데 DPO에서는 굳이 강화학습의 개념을 가져와 설명하는데 이것이 어떤 의미를 가지는지 솔직히 잘 모르겠다. 오히려 이해하기 더 복잡해지는 느낌.

그래서 나는 DPO를 단순하게 선호/비선호의 pair data 를 가지고 모델을 Fine-tuning하는 개념으로 접근하여 보고 있다.

Loss

DPO의 Loss 함수는 아래와 같다.


출처 : Direct Preference Optimization: Your Language Model is Secretly a Reward Model

실제 loss 함수에 대한 코드는 trl 라이브러리 에서 확인할 수 있다. 해당 코드 내용은 아래와 같다. (군더더기는 정리했다.)

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:

        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        logits = pi_logratios - ref_logratios

        losses = (
            - F.logsigmoid( self.beta * logits) * (1 - self.label_smoothing)
            - F.logsigmoid(-self.beta * logits) * self.label_smoothing
        )

        chosen_rewards   = self.beta * ( policy_chosen_logps - reference_chosen_logps )
        rejected_rewards = self.beta * ( policy_rejected_logps - reference_rejected_logps )

        return losses, chosen_rewards, rejected_rewards

trl 라이브러리 : trl/trainer/dpo_trainer.py

실제로는 후속 연구들로 추가된 여러 loss 함수들이 구현되어 있으나 default 로 sigmoid 기반 loss만 남겨두었다. 또한 torch 관련 함수들 (데이터 값의 디바이스 설정 등) 은 가독성을 위해 제외하였다.

위 식에는 2가지 Hyper Parameter 가 들어간다.

이름 기본 값 값 범위 설명
Beta 0.1 0.1 ~ 0.5 temperature parameter로 확률이 더 치우쳐지거나 완만하게 조절 (참고)
label
smoothing
0.0 0.0 ~ 0.5 라벨이 부정확할 때 보정하는 변수로 모델이 데이터에 과도한 확신을 가지지 않도록 함

입력값에 대한 설명은 아래와 같다.

이름 설명
logps 해당 데이터를 모델이 생성한다고 할 때 생성 확률에 log를 적용한 것
chosen/rejected 데이터에서 선호/비선호 항목에 대한 값
policy/reference 학습 중인 모델/참조 모델 (학습 모델의 초기값) 으로부터 계산한 값

logps 는 간단히 말하자면 모델이 해당 데이터를 생성할 확률이다. 이 부분을 이해하기 위해서는 Perplexity 개념을 살펴보면 좋다.
참고 : [NLP] 언어모델의 평가지표 'Perplexity' 개념 및 계산방법

모델로부터 어떤 시퀀스가 나올 확률은, 시퀀스의 각 원소들이 생성될 확률의 곱과 같다. 따라서 어떤 문장이 생성될 확률은 각 단어 (token)들이 생성될 확률들의 곱이 된다. 그리고 일반적으로 이렇게 확률에 대해 곱연산이 반복적으로 계산될 때 한없이 작은 값이 되므로 log를 취해주기 마련이다.

기본적으로 DPO에서는 학습 시 2개의 모델을 로드한다. 그 중 하나에 대해서 학습이 진행되며, 다른 모델을 참조용(reference)으로 사용한다. 이는 DPO가 기존에 학습 된 언어의 자연스러움에서 과하게 벗어나지 않도록 보정하기 위한 장치이다. 실제로 loss 함수에서 보면 참조 모델에서의 생성 확률로 normalize하는 것을 볼 수 있다.



또한 DPO에서는 선호/비선호 쌍으로 된 데이터를 받아 학습한다. 전반적으로 생성 확률에 있어서 선호 데이터 형태가 비선호 데이터 형태보다 높게 나오도록 설정하는 것이라고 볼 수 있다. 이 부분은 loss 계산시의 logsigmoid 함수의 형태를 보면 더 확실하게 알 수 있다.


출처 : PyTorch:LOGSIGMOID

logsigmoid 함수는 sigmoid 함수에 log를 취한 것으로, 음수 값에 대해서는 선형에 가깝게 동작하지만 양수 값에 대해서는 0이 되며, 0 인근에서 연속적이 되도록 smoothing 되는 함수다. 위 loss 함수를 간략하게 만들면 log({선호 생성확률}) - log({비선호 생성확률}) 이므로 비선호 생성확률이 선호 생성확률보다 클 경우 음수가 된다. 따라서 선호 생성확률이 비선호 생성확률보다 충분히 클 경우 loss 는 0이 되고, 반대의 경우만 loss가 발생하게 된다.

한편 여기서 계산되는 chosen/rejected reward는 코드상에서 학습 진행 정도를 보기 위한 metric으로 사용되는 것으로 보인다.

'AI > LLM' 카테고리의 다른 글

LLM Fine Tuning & Catastrophic Forgetting  (0) 2024.08.10
OpenAI API 사용법 정리 (Python)  (0) 2024.03.13
LLM Text Generation 및 Generation Parameters  (2) 2024.02.17
LLM Evaluation  (0) 2024.02.02