Paper-Weekly08-Learning To Retrieve Prompts for In-Context Learning

The examples used for in-context learning (prompts, demonstrations, templates, ...) have a large impact on task performance, so finding suitable examples is key. This NAACL paper addresses how to retrieve good prompts. It proposes Efficient Prompt Retrieval (EPR): a small LM scores and ranks retrieved candidate examples, positive/negative pairs are built from the ranking, and a retriever is then trained with contrastive learning. The notable parts are how the dataset is constructed and how training is done.

Dataset Construction

At training time, for each training example we first need to build positive/negative pairs to train the retriever. A sparse or dense retriever proposes candidate examples $\bar{e}_l$, which are then scored and ranked as follows: $$ s(\bar{e}_l) = \mathrm{Prob}_{\hat{g}}(y \mid \bar{e}_l, x) $$ i.e., the probability of the target sequence $y$ under the scoring LM $\hat{g}$, given the candidate example and the input sequence $x$; the higher the probability, the more helpful that example is for the current generation. The candidates are sorted by this score and the top-k and bottom-k are kept. Because the previous retrieval step already used the target $y$, the top-k are good prompts and the bottom-k are hard negative prompts.
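As a concrete (unofficial) sketch of this scoring step, the snippet below uses GPT-2 from HuggingFace transformers as a stand-in for the scoring LM $\hat{g}$; the way the candidate example and input are concatenated is a hypothetical choice, not the paper's exact format:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# GPT-2 as a stand-in for the small scoring LM \hat{g}.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

def score_candidate(example: str, x: str, y: str) -> float:
    """Return log Prob_g(y | example, x): the log-probability of the target
    sequence y when the candidate example is prepended to the input x."""
    # Hypothetical prompt format: candidate example, newline, then the input.
    context_ids = tokenizer(example + "\n" + x, return_tensors="pt").input_ids
    target_ids = tokenizer(y, return_tensors="pt").input_ids
    input_ids = torch.cat([context_ids, target_ids], dim=1)
    with torch.no_grad():
        logits = model(input_ids).logits          # (1, T, vocab)
    log_probs = torch.log_softmax(logits, dim=-1)
    offset = context_ids.shape[1]
    total = 0.0
    for i in range(target_ids.shape[1]):
        # The logit at position p predicts the token at position p + 1, so
        # target token i (at position offset + i) is scored at offset + i - 1.
        total += log_probs[0, offset + i - 1, target_ids[0, i]].item()
    return total

# Rank candidates by this score: top-k -> positives, bottom-k -> hard negatives.
```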

Training

An input encoder and a prompt encoder encode the input $x$ and the candidate prompts separately; each is a BERT model whose [CLS] embedding is used as the representation.
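A minimal sketch of the two encoders, assuming both are initialized from bert-base-uncased (the exact checkpoint is an assumption here):

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
input_encoder = AutoModel.from_pretrained("bert-base-uncased")   # encodes x
prompt_encoder = AutoModel.from_pretrained("bert-base-uncased")  # encodes candidates

def cls_embed(encoder, texts):
    """Encode a list of strings and return their [CLS] embeddings, (N, hidden)."""
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    return encoder(**batch).last_hidden_state[:, 0]  # position 0 is [CLS]
```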

For each training instance in a mini-batch of size $B$ there are $2B$ [CLS] embedding pairs: the input $x$ paired with its positive $e^+$, its hard negative $e^-$, the $B-1$ in-batch positives from the other instances, and their $B-1$ in-batch hard negatives, giving 1 positive pair and $2B-1$ negative pairs in total.

Our training instances are of the form $\langle x_i, e_i^+, e_{i,1}^-, \dots, e_{i,2B-1}^- \rangle$, where the positive example $e_i^+$ is sampled from the set $E_{\text{pos}}^{(i)}$, and our negative examples consist of one hard negative example sampled from $E_{\text{neg}}^{(i)}$, $B-1$ positive examples from the other instances in the same mini-batch, and the $B-1$ hard negatives from those instances.
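As a sketch, the sampled part of one training instance could be assembled as below, where `E_pos` and `E_neg` are hypothetical per-example lists holding the top-k and bottom-k candidates from the scoring step:

```python
import random

def build_instance(i: int, E_pos: list, E_neg: list):
    """Assemble the sampled part of training instance i.
    E_pos[i] / E_neg[i] hold the top-k / bottom-k scored candidates for x_i."""
    pos = random.choice(E_pos[i])       # e_i^+
    hard_neg = random.choice(E_neg[i])  # e_{i,1}^-
    # The remaining 2B-2 negatives come from the other instances in the
    # mini-batch (their positives and their hard negatives).
    return pos, hard_neg
```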

The training loss is: $$ \mathcal{L}(x_i, e_i^+, e_{i,1}^-, \dots, e_{i,2B-1}^-) = -\log \frac{e^{\mathrm{sim}(x_i, e_i^+)}}{e^{\mathrm{sim}(x_i, e_i^+)} + \sum_{j=1}^{2B-1} e^{\mathrm{sim}(x_i, e_{i,j}^-)}} $$
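A PyTorch sketch of this loss with in-batch negatives, taking $\mathrm{sim}(\cdot,\cdot)$ to be the inner product of the two [CLS] embeddings:

```python
import torch
import torch.nn.functional as F

def epr_contrastive_loss(x_emb: torch.Tensor,
                         pos_emb: torch.Tensor,
                         hard_neg_emb: torch.Tensor) -> torch.Tensor:
    """x_emb, pos_emb, hard_neg_emb: (B, hidden) [CLS] embeddings.
    For instance i the only positive is pos_emb[i]; the 2B-1 negatives are
    hard_neg_emb[i], the other B-1 positives, and the other B-1 hard negatives."""
    B = x_emb.shape[0]
    candidates = torch.cat([pos_emb, hard_neg_emb], dim=0)  # (2B, hidden)
    sim = x_emb @ candidates.T                              # (B, 2B)
    labels = torch.arange(B, device=x_emb.device)  # column i is sim(x_i, e_i^+)
    # Cross-entropy over the 2B columns is exactly -log softmax of the
    # positive pair against the 2B-1 negatives.
    return F.cross_entropy(sim, labels)
```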

Inference

The embeddings of all prompt candidates are precomputed and stored in a FAISS index (together with the index-to-example mapping). At test time, the embedding of the input text $x$ is computed, and maximum inner product (MIP) search returns the $L$ most similar prompts, where $L$ is determined by the maximum input length the LM can accept. How the retrieved examples should be ordered, and how many of them work best, is not explored further in the paper.
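A minimal sketch of the retrieval step with FAISS, using random vectors as stand-ins for the embeddings produced by the trained encoders:

```python
import faiss
import numpy as np

# Random placeholders for the real [CLS] embeddings.
N, d = 44000, 768
prompt_embs = np.random.rand(N, d).astype(np.float32)  # precomputed offline
x_emb = np.random.rand(1, d).astype(np.float32)        # test-input embedding

index = faiss.IndexFlatIP(d)   # exact maximum inner product search
index.add(prompt_embs)         # FAISS ids map back to training examples

L = 50  # in practice bounded by the LM's maximum input length
scores, ids = index.search(x_emb, L)  # top-L prompts by inner product
```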

Experimental Results

Evaluation is on three datasets:

  • BREAK: A dataset mapping complex natural language questions into a language-based meaning representation, where a question is decomposed into an ordered list of atomic steps. We use the low-level BREAK subset, containing 44K/7K/8K examples in its training/development/test sets.
  • MTOP: A semantic parsing dataset focused on task-oriented dialogue, where commands are mapped to complex nested queries across 11 domains. Similar to past work (Pasupat et al., 2021), we use the English subset of MTOP, containing 16K/2K/4K examples in its training/development/test sets.
  • SMCALFLOW: A large English-language task-oriented dataset that covers tasks such as calendar, weather, places, and people. The meaning representation is a dataflow program, which includes API calls, function composition and complex constraints. SMCALFLOW includes 15K development set examples and 134K training examples, from which we sample a random set of 44K examples for training.