기본 콘텐츠로 건너뛰기

IMDb 데이터셋의 label 열이 DistilBERT 모델 forward() 함수의 labels 인자로 전달되는 과정

IMDb 데이터셋의 label 열이 DistilBERT 모델 forward() 함수의 labels 인자로 전달되는 과정

1. 개요

허깅 페이스의 Transformers 라이브러리를 사용하여 모델을 훈련할 때 데이터셋의 label 항목이 어떤 과정을 거쳐 모델의 forward(..., labels, ...) 메소드로 전달되는지 설명합니다.

  • IMDb 데이터셋에서 추출한 한 개의 데이터 예시

    {
        "text": "I love sci-fi...",
        "label": 0
    }
    
  • DistilBertForSequenceClassification 클래스의 forward() 메소드

    transformers/models/distilbert/modeling_distilbert.py

    def forward(
        ...,
        labels: Optional[torch.LongTensor] = None,
        ...
    )-> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
        ...
    

이 문서에서는 IMDb 데이터셋과 DistilBertForSequenceClassification 모델을 사용하여 설명하지만 특정 데이터셋과 모델에만 해당하는 것은 아닙니다.

이 문서에서 등장하는 주요 API는 다음과 같습니다.

  • datasets.default_data_collator
  • class datasets.Dataset
  • class datasets.DatasetDict
  • class transformers.DataCollatorWithPadding
  • class transformers.DistilBertForSequenceClassification
  • class transformers.Trainer

2. 트레이너 (Trainer)

2.1. Trainer.__init()__

  • 데이터 콜레이터를 인자로 받아들이는 Trainer 객체 생성

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_imdb["train"],
        eval_dataset=tokenized_imdb["test"],
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    
  • Trainer 클래스의 생성자에서 사용자 지정 데이터 콜레이터를 쓸 것인지, 아니면 기본 데이터 콜레이터를 쓸 것인지 결정

    transformers/trainer.py

    class Trainer:
        def __init__(self, ..., data_collator, ...):
            ...
            default_collator = (
                DataCollatorWithPadding(processing_class)
                if processing_class is not None
                and isinstance(processing_class, (PreTrainedTokenizerBase, SequenceFeatureExtractor))
                else default_data_collator
            )
            self.data_collator = data_collator if data_collator is not None else default_collator
            ...
    

2.2. Trainer.train()

  • 훈련 과정에서 배치 데이터를 얻기 위하여 __init__() 메소드 내에서 정했던 데이터 콜레이터를 이용

    transformers/trainer.py

    class Trainer:
        ...
        def train(self, ...):
            ...
            find_executable_batch_size(self._inner_training_loop, ...)
            ...
        
        def _inner_training_loop(self, batch_size=None, ...):
            ...
            train_dataloader = self.get_train_dataloader()
            ...
            for epoch in range(epochs_trained, num_train_epochs):
                epoch_dataloader = train_dataloader
                ...
                epoch_iterator = iter(epoch_dataloader)
                ...
                for _ in range(total_updates):
                    batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
                    ...
                    for i, inputs in enumerate(batch_samples):
                        ...
                        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                        ...
            ...
        
        def get_train_dataloader(self) -> DataLoader:
            ...
            data_collator = self.data_collator
            ...
            
        def training_step(self, model, inputs, ...) -> torch.Tensor:
            ...
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
            ...
    
        def compute_loss(self, model, inputs, ...)
            ...
            outputs = model(**inputs)
            ...
            return (loss, outputs) if return_outputs else loss
    

3. 데이터 콜레이터 (Data Collator)

데이터 콜레이터는 데이터셋으로부터 배치 크기의 데이터를 추출하여 반환하는 역할을 수행합니다. 트레이너는 훈련 과정에서 데이터를 공급받기 위하여 응용 프로그램에서 직접 생성하여 지정한 데이터 콜레이터를 사용하거나, 그렇지 않으면 기본 데이터 콜레이터를 사용합니다.

Transformers 라이브러리는 다음 세 종류의 기본 데이터 콜레이터를 구현하고 있습니다.

  • PyTorch - torch_default_data_collator
  • TensorFlow - tf_default_data_collator
  • NumPy - numpy_default_data_collator

3.1. 데이터 콜레이터를 직접 생성하여 지정하는 경우

  • DataCollatorWithPadding 객체를 생성하여 Trainer 객체 생성 시 인자로 전달

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    
    trainer = Trainer(
        ...
        data_collator=data_collator,
        ...
    )
    
    trainer.train()
    
  • DataCollatorWithPadding 객체 호출 시 label 항목을 labels로 변경

    transformers/data/data_collator.py

    class DataCollatorWithPadding:
        def __call__(self, ...):
            batch = pad_without_fast_tokenizer_warning(...)
            if "label" in batch:
                batch["labels"] = batch["label"]
    	        del batch["label"]
            if "label_ids" in batch:
                batch["labels"] = batch["label_ids"]
                del batch["label_ids"]
    	    return batch
    

3.2. 데이터 콜레이터를 지정하지 않는 경우

  • Trainer 객체 생성 시 data_collator 파라미터 지정하지 않음

    trainer = Trainer(
        ...
        data_collator=None,
        ...
    )
    
  • PyTorch의 경우 torch_default_data_collator() 함수 호출 시 label 항목을 labels로 변경

    def torch_default_data_collator():
        ...
        if "label" in ...:
    	    ...
            batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
    

4. 정리

  • 데이터 콜레이터가 데이터셋의 label 항목을 labels 항목으로 변경하고 트레이너가 labels 항목을 모델의 forward(..., labels, ...) 메소드 인자로 전달합니다.

Written with StackEdit.

댓글

이 블로그의 인기 게시물

Windows에 AMP와 MediaWiki 설치하기

1. 들어가기     AMP는 Apache + MySQL +  Perl/PHP/Python에 대한 줄임말이다. LAMP (Linux + AMP)라고 하여 Linux에 설치하는 것으로 많이 소개하고 있지만 Windows에서도 간편하게 설치하여 사용할 수 있다.       이 글은 Windows 7에 Apache + MySQL + PHP를 설치하고 그 기반에서 MediaWiki를 설치하여 실행하는 과정을 간략히 정리한 것이다. 2. MySQL     * 버전 5.6.12     1) 다운로드         http://dev.mysql.com/downloads/installer/         MySQL Installer 5.6.12         Windows (x86, 32-bit), MSI Installer         (mysql-installer-web-community-5.6.12.0.msi)     2) 다운로드한 MSI 파일을 더블클릭하여 설치를 진행한다.           설치 위치:                   C:\Program Files\MySQL               선택 사항:                       Install MySQL Products             Choosing a Se...

MATLAB Rutime 설치하기

MATLAB Rutime 설치하기 미설치시 에러 MATLAB Runtime 을 설치하지 않은 환경에서 MATLAB 응용프로그램이나 공유 라이브러리를 사용하려고 하면 아래와 같은 에러 메시지가 표시될 것입니다. 처리되지 않은 예외: System.TypeInitializationException: 'MathWorks.MATLAB.NET.Utility.MWMCR'의 형식 이니셜라이저에서 예 외를 Throw했습니다. ---> System.TypeInitializationException: 'MathWorks.MATLAB.NET.Arrays.MWArray'의 형식 이니셜라이저에서 예외를 Throw했습니다. ---> System.DllNotFoundException: DLL 'mclmcrrt9_3.dll'을(를) 로드할 수 없습니다. 지정된 모듈을 찾을 수 없습니다. (예외가 발생한 HRESULT: 0x8007007E) 위치: MathWorks.MATLAB.NET.Arrays.MWArray.mclmcrInitialize2(Int32 primaryMode) 위치: MathWorks.MATLAB.NET.Arrays.MWArray..cctor() --- 내부 예외 스택 추적의 끝 --- 위치: MathWorks.MATLAB.NET.Utility.MWMCR..cctor() --- 내부 예외 스택 추적의 끝 --- 위치: MathWorks.MATLAB.NET.Utility.MWMCR.processExiting(Exception exception) 해결 방법 이 문제를 해결하기 위해서는 MATLAB Runtime 을 설치해야 합니다. 여러 가지 방법으로 MATLAB Runtime 을 설치할 수 있습니다. MATLAB 이 설치되어 있는 경우에는 MATLAB 설치 폴더 아래에 있는 MATLAB Runtime 설치 프로그램을 실행하여 설치합니다. ...

Wi-Fi 카드 2.4GHz로만 동작시키기

Wi-Fi 카드 2.4GHz로만 동작시키기 별도의 Wi-Fi AP 장치를 두지 않고 아래와 같은 기기들로만 Wi-Fi 네트워크를 구성하고자 할 때 주변 기기들이 2.4GHz만 지원하기 때문에 PC에서 실행하는 AP가 항상 2.4GHz를 사용하도록 Wi-Fi 카드를 설정해 주어야 합니다. 기기 Wi-Fi 카드 주파수 대역 Wi-Fi Direct 지원 PC (Windows 10) 2.4GHz, 5GHz O 주변 기기들 2.4GHz X Wi-Fi 카드별 주파수 대역 선택 방법 Windows 시작 메뉴에서 설정 을 클릭합니다. Windows 설정 화면에서 네트워크 및 인터넷 을 클릭합니다. 설정 화면의 왼쪽 메뉴바에서 Wi-Fi 를 클릭합니다. 화면 오른쪽 관련 설정 구역에 있는 어댑터 옵션 변경 을 클릭합니다. 설정을 바꾸고자 하는 Wi-Fi 카드 항목을 선택하고 마우스 오른쪽을 누른 다음 속성 메뉴를 클릭합니다. 대화상자의 네트워킹 탭 화면에 있는 구성 버튼을 클릭합니다. 장치 속성 대화상자의 고급 탭 화면으로 이동합니다. 제시되는 속성 항목들은 제품별로 다르며 자세한 사항은 아래의 제품별 설명을 참고하여 값을 설정하시기 바랍니다. Intel Dual Band Wireless-AC 7265 기술 사양 주파수 대역: 2.4GHz, 5GHz 무선 표준: 802.11ac 주파수 대역 선택 장치 속성 대화상자에서 아래와 같이 선택합니다. Wireless Mode 1. 802.11a => 5GHz 4. 802.11b/g => 2.4GHz (이 항목 선택) 6. 802.11a/b/g => 2.4GHz, 5GHz Intel Dual Band Wireless-AC 8265 기술 사양 주파수 대역: 2.4GHz, 5GHz 무선 표준: 802.11ac 주파수 대역 선택 장치 속성 대화상자에서 아래와 같이 ...