ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Toolformer: Language Models Can Teach Themselves to Use Tools 논문 리뷰
    LLM papers 2024. 4. 2. 15:50
    728x90

    [2302.04761] Toolformer: Language Models Can Teach Themselves to Use Tools (arxiv.org)

     

    Toolformer: Language Models Can Teach Themselves to Use Tools

    Language models (LMs) exhibit remarkable abilities to solve new tasks from just a few examples or textual instructions, especially at scale. They also, paradoxically, struggle with basic functionality, such as arithmetic or factual lookup, where much simpl

    arxiv.org

     

    Abstract

    언어 모델이 몇 가지 예시나 텍스트의 지시사항만으로 많은 일들을 해나가고 있지만, 산술연산, 사실 조회에서는 고전하고 있다. 본 논문은 언어 모델이 간단한 API를 통해서 외부도구를 사용할 수 있는 방법에 대해서 다룬다.

     

    Toolformer : 어떤 API를 호출할지, 언제 호출할지, 어떤 arguments를 전달할지, 결과를 미래 토큰 예측에 어떻게 최선으로 통합할지를 결정하는 모델

     

    이 모델을 통해서, calculator, Q&A system, search engine, 번역, 달력 등을 포함한 도구들을 통합한다. 본 논문에서 제시한 Toolformer는 다양한 downstream task에서 획기적으로 개선된 zero-shot 성능을 보여주며, 더 큰 언어모델과의 경쟁에서도 핵심 언어 모델링 능력을 보유한다.

     

    1. Introduction

    언어 모델의 산술연산, 사실 조회와 같은 한계점을 극복하기 위해 본 논문에서는 Toolformer를 제안한다.

    Toolformer는 LM의 강점과 외부 도구의 기능을 결합하는 것을 목표로 API를 통해서 외부 도구를 사용하도록 학습한 모델이다.

     

    본 논문의 introduction에 소개된 toolformer의 주요 목표를 살펴보자.

    1. 도구의 사용은 인간의 annotations가 아닌 self-supervised 방식으로 학습되어야한다.

    -> 주석에 대한 cost 뿐 아니라, 인간이 유용하다고 여기는 것이 모델이 유용하다고 여기는 것과 다를 수 있기 때문에 중요하다는 것이다.

    2. LM은 generality를 잃지 않고, 스스로 언제 어떤 도구를 사용할지 결정해야 한다.

     

    이러한 목표들을 통해서 calculator, Q&A system, search engine, 번역, 달력 모두를 포함하게 되며, 이러한 도구들을 통합함으로써, 핵심 언어 모델링 기능을 손상시키지 않으면서 다양한 downstream task에서 zero-shot 성능을 크게 향상키기게 된다.

     

    GPT-J 모델을 사전학습시키고, 6.78B parameters를 가진 것이 더 큰 GPT-3 model의 성능을 뛰어넘는 것을 보여준다.

     

    2. Approach

    본 논문에서 LM이 간단한 API를 통해서 외부 도구를 사용하도록 교육하는 것이 중점이라고 하는데, 어떤식으로 접근하는지를 살펴보자.

     

     

    API call은 tuple c = (a_c, i_c)로 표현된다.

    a_c는 API의 이름, i_c 는 input과 일치하며, API가 c를 call 했을 때, r 이라는 결과를 얻을 수 있다.

     

     

    선형적인 API call의 예시들을 이 시각자료에서 볼 수 있다. 텍스트 시퀀스에 같은 형태가 포함되는 것을 확인해볼 수 있다.

     

    이제 C = {x_1, ..., x_|c|} 라는 dataset이 주어졌을 때를 살펴보자.

    먼저 API call을 위해 C* 이라고 하는 증강 dataset으로 만들고 세 가지 step을 따른다.

     

    시각자료를 통해서 단계를 보다 쉽게 확인해보자.

    1. M의 in-context learning 능력을 활용해서 다양한 potential API call을 샘플링한다.

    2. API call을 실행하고, 얻은 응답이 미래 토큰 예측에 도움이 되는지 확인. -> 필터링 기준

    3. 필터링 후에 다른 도구들의 API 호출을 병합해서 확장 데이터 셋 C*를 만들고 M을 fine-tuning 한다.

     

    이 과정을 자세하게 살펴보자.

     

    Sampling API Calls

    각 API에서 LM이 x = x_1, ..., x_n을 API 호출로 주석을 달도록 만든 prompt P(x)를 작성한다.

    PM(z_n+1 | z_1, ..., z_n)을 M이 시퀀스 z_1, ..., z_n에 대해 연속으로 토큰 z_n+1에 할당하는 확률이라고 하면, 각 i에 대해서 p_i = PM(<API> | P(x), x1; i-1) 식이 위치 i에서 API 호출을 시작하는 데 M이 할당하는 확률을 계산해서 API 호출을 위한 최대 k개의 후보 위치를 샘플링한다.

    [P(x), x1, ..., x_i-1 < API>]를 접두사로 하고, </API>를 시퀀스 종료 토큰으로 사용해서 M에서 최대 m개의 API 호출인 ci_1, ..., ci_m을 얻게 된다.

     

    generate API calls를 QA에서 사용한 prompt P(x)의 예시

     

    Executing API Calls

    M이 생성한 API call을 실행해서 결과를 얻는다.

    API call c_i에 대한 응답은 단일 텍스트 시퀀스 r_i 여야한다.

     

    Filtering API Calls

    x = x_1, ..., x_n 시퀀스에 대해 c_i를 API call 하면 r_i가 API에 대한 응답이 된다.

    M에 대한 weighted cross entropy loss에 대한 식이다. 이 식은 모델이 z 접두사로 이루어진 경우이다.

     

     

    Li_+ : API call과  그 결과가 M에 접두사로 주어졌을 때, 모든 토큰 x_i, ..., x_n에 대한 weighted loss

    Li_- : API call을 전혀 하지 않거나 API call은 하지만, 응답을 제공하지 않는 경우 loss의 최솟값

     

     

    filtering을 위한 식으로, 필터링 임계값 r_f가 주어졌을 때, 이 식을 만족하는 API call만 유지한다. API call과 결과를 추가해서 어떤 API 호출도 하지 않거나, 결과를 얻지 못한 경우에 비해 loss를 적어도 f 만큼은 줄이는 것이다.

     

    Model Finetuning

    샘플링과 필터링이 끝난 후에 남은 API call 들을 원본 입력과 결합한다. 

    input text x = x_1, ..., x_n에 해당하는 API call의 결과를 (c_i, r_i)라고 가지고 있을 때, 새로운 시퀀스 x = x_1; i-1, e(c_i, r_i), x_i;를 구성하게 된다.

    이 작업을 모든 x에 대해 수행함으로써, API call이 추가된 새로운 데이터 셋 C*을 얻게 된다. 추가된 API call을 제외하고는 C*과 C가 같은 내용을 가지게 되므로 C*로 M을 파인튜닝한다.

    이렇게 파인튜닝을 시키면 API call이 정확히 M이 미래 토큰을 예측하는데 도움이 되는 위치와 입력에 삽입될 수 있으며, 언어모델이 자체적인 피드백 기반으로 언제 어떻게 도구를 사용할지를 결정할 수 있게 된다.

     

    Inference

    파인튜닝 후에 Inference 과정을 실시한다. M이 "->" 토큰을 생성할 때까지 정규 디코딩을 수행하는데, 이 토큰은 API call에 대한 응답을 기대한다는 것을 의미한다. 이 시점에서 디코딩 과정을 잠시 중단하고, 적절한 API call을 통해 응답을 얻은 후, 응답과 </API> 토큰을 모두 삽입한 후 디코딩 과정을 지속한다.

     

    3. Tools

    1. 입출력이 텍스트 시퀀스로 표현되는가

    2. 의도된 사용의 몇 가지 증명을 얻을 수 있는가

    오로지 두 가지의 제약만 가지고 QA, Calculator, Wikipedia Search, Machine Translation System, Calendar 라는 다섯가지 Tool을 평가한다.

     

    4. Experiments

    - 한 가지 고려된 Tool이 유용하다고 가정한 다양한 downstream 작업 선택

    - zero-shot에서 성능 평가

     

    4.1 Experimental Setup

    Dataset Generation

    - 데이터셋 C로 CCNet subset 사용

    - 언어모델 M으로 GPT-J 사용

    - C 주석 처리 cost 줄이기 위해, A에 대한 휴리스틱 정의

     

    Model Finetuning

    - C*을 batch size : 128, learning rate : 1* 10(-5) 사용해서 M finetuning

     

    Baseline Models

     

    4.2 Downstream Tasks

    4.2.1 LAMA

     

    4.2.2 Math Datasets

     

    4.2.3 Question Answering

     

     

    4.2.4 Multilingual QA

     

    De, Zh, Ar 언어에서 Toolformer는 GPT-J를 넘어서지 못하는 것을 볼 수 있는데, 본 논문에서는 이 점을 CCNet에서의 파이튜닝이 성능을 저하시키기 때문이며, GPT-J의 원래 사전 훈련 데이터와 비교했을 때의 분포 이동 때문일 수 있다고 한다.

    -> 분포 이동은 모델이 학습한 데이터의 기본 분포가 테스트 or 배포 중에 발생하는 데이터와 비교해서 달라지는 것을 말한다. CCNet의 파인튜닝에 사용되는 훈련 데이터가 GPT-J의 원래 사전 학습 데이터와 크게 다른 경우에 분포이동이 발생하게 되는데, 이러한 데이터 분포 차이로 인해 모델을 일반화하고 새로운 작업이나 데이터셋을 수행하는 능력에 영향을 미칠 수 있게 되는 것이다.

     

    4.2.5 Temporal Datasets

    달력을 실험하기 위해서 TEMPLAMA 라는 Temporal Datasets를 구축

     

    4.3 Language Modeling

     

    CCNet 파인튜닝을 통해서 혼란도가 낮아지는 성능을 보이지만, WikiText와 비교했을 때는 떨어지는 성능을 보인다. 이걸 보면 CCNet 데이터셋이 무작위로 10,000개 문서를 뽑았다는데, 확실히 데이터셋이 WikiText와 비교해서 좋지 않은 것 같다.

     

    4.4 Scaling Laws

     

     

    시각자료를 살펴보면, Tool을 제대로 사용할 수 있는 능력은 775M parameter부터 보이기 시작한다.

    QA 벤치마크의 Wikipedia 검색엔진에서 예외가 발생한느데, 이는 API 사용이 상대적으로 쉽기 때문이라고 한다.

    기본적으로 모델이 크기가 커질수록 API call 없이도 성능이 뛰어나지만, API를 활용하는 능력도 향상된다. 이를 통해서 API call을 사용하거나 사용하지 않는 상황의 격차는 크게 나타나게 된다.

     

    5. Analysis

    Decoding Strategy

    <API> 토큰이 k개의 가장 가능성이 높은 토큰 중 하나일 경우 이를 생성한다. 시각자료를 살펴보자.

     

    시각자료에서는 LAMA의 T-Rex 부분과 WebQS 부분의 성능을 다양한 k값에 대해서 보여주고 있다. k를 증가시키면 모델이 더 많은 예제에서 API call을 수행하게 되는데, k=1 일 때와 k=10 일 때의 차이를 살펴보면 엄청난 증가를 보여주는 것을 볼 수 있다. WebQS에서는 k를 1에서 3으로 살짝 증가시키더라도 엄청나게 많은 APi call이 이루어지는 것을 볼 수 있다.

     

    Data Quality

     

    Li_(-) - Li_(+) 점수를 통해서 API call의 질적인 부분을 살펴보자. 앞서 API call을 필터링하는 부분에서 나왔던 식인데, 이 점수가 높을수록 유용한 API call이고, 낮을수록 미래 토큰 예측에 유용한 정보를 제공하지 않는 API call이라는 것을 볼 수 있다. 

     

    6. Related Work

    - LM Pretraining

    - Tool Use

    - Bootstrapping

     

    7. Limitations

    - Toolformer가 Tool을 연쇄적으로 사용할 능력이 없다는 것.

    연쇄적으로 사용한다는 것은 하나의 도구 출력을 다른 도구의 입력으로 사용하는 것을 말하는데, 현재까지는 API call 자체가 독립적으로 생성되기 때문에 파인튜닝 데이터셋에서도 연쇄적으로 도구를 사용할 수는 없다.

    - LM이 Tool을 상호작용적으로 사용할 수 없다는 것.

    - API call을 결정할 때 input의 표현에 민감하다. (프롬프트에 LM이 민감하기 때문)

    - API call을 결정할 때, 현재 API call로 인해 발생하는 Tool 별 cost를 고려하지 않는다.

     

    8. Conclusion

    본 논문은 어떻게 검색 엔진, 달력, 번역 시스템들을 API call 만으로 간단하게 해결할 수 있는 이러한 도구들을 사용하는지를 self-supervised 방식으로 LM이 학습하도록 하는 Toolformer에 대해서 소개했다. API call을 단순히 파인튜닝함으로써 해결할 수 있었으며, 6.78B GPT-J 모델의 zero-shot 성능을 향상시켜서 다양한 downstream task에서 훨씬 큰 GPT-3 모델의 성능을 능가할 수 있게 되었다.

     

     

    논문을 읽으면서 API call에 대한 GPT finetuning의 위대함을 다시 볼 수 있었다. 최근 GPT api를 사용한 파인튜닝으로 수많은 LLM 플랫폼과 앱들이 출시되고 있는 만큼 이 논문에서 제시한 downstream task 5개 뿐만 아니라, 더 많은 곳에서 파인튜닝만으로 다양한 도구들을 활용할 수 있는 모델의 발전을 희망한다.

     

    728x90
Designed by Tistory.