Paper-Weekly08-Learning To Retrieve Prompts for In-Context Learning
The examples provided for in-context learning (prompts, demonstrations, templates, ...) have a large impact on task performance, so finding suitable examples is the key problem. This NAACL paper addresses how to retrieve good prompts. It proposes Efficient Prompt Retrieval (EPR): a small LM scores and ranks retrieved candidate examples, positive/negative pairs are built from the ranking, and a retriever is trained with contrastive learning. The notable parts are how the dataset is constructed and how the retriever is trained.
How the dataset is constructed
During training, for a given training pair (x, y), we first need to construct positive and negative examples to train the retriever. A sparse or dense unsupervised retriever first recalls a set of candidate examples, which are then ranked by the following score: $$ s(\bar{e}_l) = \mathrm{Prob}_{\hat{g}}(y \mid \bar{e}_l, x) $$ i.e., the probability that the scoring LM $\hat{g}$ assigns to the target sequence y given the candidate example $\bar{e}_l$ and the input sequence x; the higher the probability, the more helpful the candidate is for the generation. After scoring, the candidates are sorted and the top-k and bottom-k are taken. Because the previous retrieval step was already conditioned on the target y, the selected examples are guaranteed to be either good prompts (top-k) or hard negatives (bottom-k).
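A minimal sketch of this scoring step, assuming a HuggingFace causal LM stands in for the scorer $\hat{g}$; the checkpoint name, the "example \n x \n y" prompt format, and the function names are illustrative assumptions, not the paper's exact setup:

```python
# Sketch: rank candidate examples by Prob_g(y | e, x) with a small causal LM.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")          # assumed scoring LM
scorer = AutoModelForCausalLM.from_pretrained("gpt2").eval()

@torch.no_grad()
def score_candidate(example: str, x: str, y: str) -> float:
    """Log-probability of the target y given (candidate example, input x)."""
    prompt_ids = tokenizer(example + "\n" + x + "\n", return_tensors="pt").input_ids
    target_ids = tokenizer(y, return_tensors="pt").input_ids
    input_ids = torch.cat([prompt_ids, target_ids], dim=1)
    log_probs = scorer(input_ids).logits.log_softmax(dim=-1)
    score = 0.0
    for i in range(prompt_ids.size(1), input_ids.size(1)):
        # the token at position i is predicted by the logits at position i-1
        score += log_probs[0, i - 1, input_ids[0, i]].item()
    return score

def split_candidates(candidates, x, y, k):
    """Rank retrieved candidates; top-k become positives, bottom-k hard negatives."""
    ranked = sorted(candidates, key=lambda e: score_candidate(e, x, y), reverse=True)
    return ranked[:k], ranked[-k:]  # (E_pos, E_neg)
```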
Training procedure
An input encoder and a prompt encoder encode the input and the candidate example respectively, both using BERT's [CLS] embedding as the representation.
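As an illustration, both encoders can be instantiated from a BERT checkpoint, taking the hidden state of the first token ([CLS]) as the sentence embedding; the checkpoint name below is an assumption:

```python
# Sketch of the two BERT encoders; "bert-base-uncased" is an assumed checkpoint.
from transformers import AutoTokenizer, AutoModel

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
input_encoder = AutoModel.from_pretrained("bert-base-uncased")   # encodes x
prompt_encoder = AutoModel.from_pretrained("bert-base-uncased")  # encodes candidate examples

def encode(encoder, texts):
    batch = tok(texts, padding=True, truncation=True, return_tensors="pt")
    # [CLS] embedding = hidden state of the first token, shape [B, d]
    return encoder(**batch).last_hidden_state[:, 0]
```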
Each training instance involves 2B [CLS]-embedding pairs: for one x there is one e+, one hard negative e-, the B-1 in-batch positives (from the other instances in the mini-batch), and their B-1 in-batch hard negatives, giving 1 positive pair and 2B-1 negative pairs in total.
Our training instances are of the form $\langle x_i, e_i^+, e_{i,1}^-, \ldots, e_{i,2B-1}^- \rangle$, where the positive example $e_i^+$ is sampled from the set $E_{pos}^{(i)}$, and our negative examples consist of one hard negative example sampled from $E_{neg}^{(i)}$, the B − 1 positive examples from the other instances in the same mini-batch, and the B − 1 hard negatives from those instances.
The training loss is: $$ L(x_i, e_i^+, e_{i,1}^-, \ldots, e_{i,2B-1}^-) = -\log \frac{e^{\mathrm{sim}(x_i, e_i^+)}}{e^{\mathrm{sim}(x_i, e_i^+)} + \sum_{j=1}^{2B-1} e^{\mathrm{sim}(x_i, e_{i,j}^-)}} $$
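A sketch of this loss with in-batch negatives, assuming the [CLS] embeddings are already computed and similarity is the inner product; the function name is hypothetical:

```python
import torch
import torch.nn.functional as F

def epr_loss(x_emb, pos_emb, neg_emb):
    """x_emb, pos_emb, neg_emb: [B, d] embeddings of x_i, e_i^+, and the hard
    negative e_i^-. For row i, the positive is pos_emb[i]; the 2B-1 negatives
    are neg_emb[i] plus the positives and hard negatives of the other B-1
    instances in the batch."""
    cands = torch.cat([pos_emb, neg_emb], dim=0)  # [2B, d]
    sim = x_emb @ cands.t()                       # inner-product similarity, [B, 2B]
    labels = torch.arange(x_emb.size(0), device=x_emb.device)
    # cross-entropy over 1 positive + 2B-1 negatives = the contrastive loss above
    return F.cross_entropy(sim, labels)
```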
Inference procedure
使用FAISS提前储存好index to embedding映射,计算好x text的embedding,再使用MIP(maximum inner product)搜索最相似的L个prompt,这个L由LM可以接受的最长输入决定。至于不同的例子如何放、放多少效果最好,文章没有进一步探索。
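A sketch of the retrieval step with FAISS; the dimension, corpus size, and random arrays below stand in for the pre-computed embeddings:

```python
import faiss
import numpy as np

d = 768                                         # BERT hidden size
index = faiss.IndexFlatIP(d)                    # exact maximum-inner-product search
prompt_embs = np.random.rand(44000, d).astype("float32")  # stand-in for pre-computed prompt embeddings
index.add(prompt_embs)

L = 50                                          # bounded by the LM's maximum input length
x_emb = np.random.rand(1, d).astype("float32")  # stand-in for the encoded test input
scores, ids = index.search(x_emb, L)            # ids[0]: indices of the L most similar prompts
```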
Experimental results
Evaluation is on three datasets:
- BREAK: A dataset mapping complex natural language questions into a language-based meaning representation, where a question is decomposed into an ordered list of atomic steps. The authors use the low-level BREAK subset, containing 44K/7K/8K examples in its training/development/test sets.
- MTOP: A semantic parsing dataset focused on task-oriented dialogue, where commands are mapped to complex nested queries across 11 domains. Similar to past work (Pasupat et al., 2021), the authors use the English subset of MTOP, containing 16K/2K/4K examples in its training/development/test sets.
- SMCALFLOW: A large English-language task-oriented dataset that covers tasks such as calendar, weather, places, and people. The meaning representation is a dataflow program, which includes API calls, function composition, and complex constraints. SMCALFLOW includes 15K development set examples and 134K training examples, from which the authors sample a random set of 44K for training.