在Transformers 中使用約束波束搜索引導(dǎo)文本生成

引言
本文假設(shè)讀者已經(jīng)熟悉文本生成領(lǐng)域波束搜索相關(guān)的背景知識,具體可參見博文?如何生成文本: 通過 Transformers 用不同的解碼方法生成文本。
與普通的波束搜索不同,約束?波束搜索允許我們控制所生成的文本。這很有用,因?yàn)橛袝r(shí)我們確切地知道輸出中需要包含什么。例如,在機(jī)器翻譯任務(wù)中,我們可能通過查字典已經(jīng)知道哪些詞必須包含在最終的譯文中; 而在某些特定的場合中,雖然某幾個(gè)詞對于語言模型而言差不多,但對最終用戶而言可能卻相差很大。這兩種情況都可以通過允許用戶告訴模型最終輸出中必須包含哪些詞來解決。
這事兒為什么這么難
然而,這個(gè)事情操作起來并不容易,它要求我們在生成過程中的?某個(gè)時(shí)刻?在輸出文本的?某個(gè)位置?強(qiáng)制生成某些特定子序列。
假設(shè)我們要生成一個(gè)句子?S
,它必須按照先?t1?再?t2?的順序包含短語?p1=t1,t2。以下定義了我們希望生成的句子?S:
S期望=s1,s2,…,sk,t1,t2,sk+1,…,sn
問題是波束搜索是逐詞輸出文本的。我們可以大致將波束搜索視為函數(shù)?B(s0:i)=si+1,它根據(jù)當(dāng)前生成的序列?s0:i?預(yù)測下一時(shí)刻?i+1?的輸出。但是這個(gè)函數(shù)在任意時(shí)刻?i<k?怎么知道,未來的某個(gè)時(shí)刻?k?必須生成某個(gè)指定詞?或者當(dāng)它在時(shí)刻?i=k?時(shí),它如何確定當(dāng)前那個(gè)指定詞的最佳位置,而不是未來的某一時(shí)刻?i>k?

如果你同時(shí)有多個(gè)不同的約束怎么辦?如果你想同時(shí)指定使用短語?p1=t1,t2?和?短語?p2=t3,t4,t5,t6?怎么辦?如果你希望模型在兩個(gè)短語之間?任選一個(gè)?怎么辦?如果你想同時(shí)指定使用短語?p1?以及短語列表?p21,p22,p23?中的任一短語怎么辦?
上述需求在實(shí)際場景中是很合理的需求,下文介紹的新的約束波束搜索功能可以滿足所有這些需求!
我們會先簡要介紹一下新的?約束波束搜索?可以做些什么,然后再深入介紹其原理。
例 1: 指定包含某詞
假設(shè)我們要將?"How old are you?"
?翻譯成德語。它對應(yīng)兩種德語表達(dá),其中?"Wie alt bist du?"
?是非正式場合的表達(dá),而?"Wie alt sind Sie?"
?是正式場合的表達(dá)。
不同的場合,我們可能傾向于不同的表達(dá),但我們?nèi)绾胃嬖V模型呢?
使用傳統(tǒng)波束搜索
我們先看下如何使用?傳統(tǒng)波束搜索?來完成翻譯。
!pip install -q git+https://github.com/huggingface/transformers.git
python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLMtokenizer = AutoTokenizer.from_pretrained("t5-base")model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")encoder_input_str = "translate English to German: How old are you?"input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_idsoutputs = model.generate( ? ?input_ids, ? ?num_beams=10, ? ?num_return_sequences=1, ? ?no_repeat_ngram_size=1, ? ?remove_invalid_values=True,)print("Output:\n" + 100 *'-')print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Output:----------------------------------------------------------------------------------------------------Wie alt bist du?
使用約束波束搜索
但是如果我們想要一個(gè)正式的表達(dá)而不是非正式的表達(dá)呢?如果我們已經(jīng)先驗(yàn)地知道輸出中必須包含什么,我們該如何?將其?注入到輸出中呢?
我們可以通過?model.generate()
?的?force_words_ids
?參數(shù)來實(shí)現(xiàn)這一功能,代碼如下:
python
tokenizer = AutoTokenizer.from_pretrained("t5-base")model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")encoder_input_str = "translate English to German: How old are you?"force_words = ["Sie"]input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_idsforce_words_ids = tokenizer(force_words, add_special_tokens=False).input_idsoutputs = model.generate( ? ?input_ids, ? ?force_words_ids=force_words_ids, ? ?num_beams=5, ? ?num_return_sequences=1, ? ?no_repeat_ngram_size=1, ? ?remove_invalid_values=True,)print("Output:\n" + 100 *'-')print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Output:----------------------------------------------------------------------------------------------------Wie alt sind Sie?
如你所見,現(xiàn)在我們能用我們對輸出的先驗(yàn)知識來指導(dǎo)文本的生成。以前我們必須先生成一堆候選輸出,然后手動(dòng)從中挑選出符合我們要求的輸出。現(xiàn)在我們可以直接在生成階段做到這一點(diǎn)。
例 2: 析取式約束
在上面的例子中,我們知道需要在最終輸出中包含哪些單詞。這方面的一個(gè)例子可能是在神經(jīng)機(jī)器翻譯過程中結(jié)合使用字典。
但是,如果我們不知道要使用哪種 _詞形_呢,我們可能希望使用單詞?rain
?但對其不同的詞性沒有偏好,即?["raining", "rained", "rains", ...]
?是等概的。更一般地,很多情況下,我們可能并不刻板地希望?逐字母一致?,此時(shí)我們希望劃定一個(gè)范圍由模型去從中選擇最合適的。
支持這種行為的約束叫?析取式約束 (Disjunctive Constraints)?,其允許用戶輸入一個(gè)單詞列表來引導(dǎo)文本生成,最終輸出中僅須包含該列表中的?至少一個(gè)?詞即可。
下面是一個(gè)混合使用上述兩類約束的例子:
python
from transformers import GPT2LMHeadModel, GPT2Tokenizermodel = GPT2LMHeadModel.from_pretrained("gpt2")tokenizer = GPT2Tokenizer.from_pretrained("gpt2")force_word = "scared"force_flexible = ["scream", "screams", "screaming", "screamed"]force_words_ids = [ ? ?tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids, ? ?tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,]starting_text = ["The soldiers", "The child"]input_ids = tokenizer(starting_text, return_tensors="pt").input_idsoutputs = model.generate( ? ?input_ids, ? ?force_words_ids=force_words_ids, ? ?num_beams=10, ? ?num_return_sequences=1, ? ?no_repeat_ngram_size=1, ? ?remove_invalid_values=True,)print("Output:\n" + 100 *'-')print(tokenizer.decode(outputs[0], skip_special_tokens=True))print(tokenizer.decode(outputs[1], skip_special_tokens=True))
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.Output:----------------------------------------------------------------------------------------------------The soldiers, who were all scared and screaming at each other as they tried to get out of theThe child was taken to a local hospital where she screamed and scared for her life, police said.
如你所見,第一個(gè)輸出里有?"screaming"
?,第二個(gè)輸出里有?"screamed"
?,同時(shí)它們都原原本本地包含了?"scared"
?。注意,其實(shí)?["screaming", "screamed", ...]
?列表中不必一定是同一單詞的不同詞形,它可以是任何單詞。使用這種方式,可以滿足我們只需要從候選單詞列表中選擇一個(gè)單詞的應(yīng)用場景。
傳統(tǒng)波束搜索
以下是傳統(tǒng)?波束搜索?的一個(gè)例子,摘自之前的?博文:

與貪心搜索不同,波束搜索會保留更多的候選詞。上圖中,我們每一步都展示了 3 個(gè)最可能的預(yù)測詞。
在?num_beams=3
?時(shí),我們可以將第 1 步波束搜索表示成下圖:

波束搜索不像貪心搜索那樣只選擇?"The dog"
?,而是允許將?"The nice"
?和?"The car"
?留待進(jìn)一步考慮?。
下一步,我們會為上一步創(chuàng)建的三個(gè)分支分別預(yù)測可能的下一個(gè)詞。

雖然我們?考查?了明顯多于?num_beams
?個(gè)候選詞,但在每步結(jié)束時(shí),我們只會輸出?num_beams
?個(gè)最終候選詞。我們不能一直分叉,那樣的話,?beams
?的數(shù)目將在?n?步后變成?beamsn?個(gè),最終變成指數(shù)級的增長 (當(dāng)波束數(shù)為?10?時(shí),在?10?步之后就會變成?10,000,000,000?個(gè)分支!)。
接著,我們重復(fù)上述步驟,直到滿足中止條件,如生成?<eos>
?標(biāo)記或達(dá)到?max_length
?。整個(gè)過程可以總結(jié)為: 分叉、排序、剪枝,如此往復(fù)。
約束波束搜索
約束波束搜索試圖通過在每一步生成過程中 _注入_所需詞來滿足約束。
假設(shè)我們試圖指定輸出中須包含短語?"is fast"
?。
在傳統(tǒng)波束搜索中,我們在每個(gè)分支中找到?k
?個(gè)概率最高的候選詞,以供下一步使用。在約束波束搜索中,除了執(zhí)行與傳統(tǒng)波束搜索相同的操作外,我們還會試著把約束詞加進(jìn)去,以?看看我們是否能盡量滿足約束。圖示如下:

上圖中,我們最終候選詞除了包括像?"dog"
?和?"nice"
?這樣的高概率詞之外,我們還把?"is"
?塞了進(jìn)去,以盡量滿足生成的句子中須含?"is fast"
?的約束。
第二步,每個(gè)分支的候選詞選擇與傳統(tǒng)的波束搜索大部分類似。唯一的不同是,與上面第一步一樣,約束波束搜索會在每個(gè)新分叉上繼續(xù)強(qiáng)加約束,把滿足約束的候選詞強(qiáng)加進(jìn)來,如下圖所示:

組 (Banks)
在討論下一步之前,我們停下來思考一下上述方法的缺陷。
在輸出中野蠻地強(qiáng)制插入約束短語?is fast
?的問題在于,大多數(shù)情況下,你最終會得到像上面的?The is fast
?這樣的無意義輸出。我們需要解決這個(gè)問題。你可以從?huggingface/transformers
?代碼庫中的這個(gè)?問題?中了解更多有關(guān)這個(gè)問題及其復(fù)雜性的深入討論。
組方法通過在滿足約束和產(chǎn)生合理輸出兩者之間取得平衡來解決這個(gè)問題。
我們把所有候選波束按照其?滿足了多少步約束
分到不同的組中,其中組?n?里包含的是?滿足了?n?步約束的波束列表?。然后我們按照順序輪流選擇各組的候選波束。在上圖中,我們先從組 2 (Bank 2) 中選擇概率最大的輸出,然后從組 1 (Bank 1) 中選擇概率最大的輸出,最后從組 0 (Bank 0) 中選擇最大的輸出; 接著我們從組 2 (Bank 2) 中選擇概率次大的輸出,從組 1 (Bank 1) 中選擇概率次大的輸出,依此類推。因?yàn)槲覀兪褂玫氖?num_beams=3
,所以我們只需執(zhí)行上述過程三次,就可以得到?["The is fast", "The dog is", "The dog and"]
。
這樣,即使我們?強(qiáng)制?模型考慮我們手動(dòng)添加的約束詞分支,我們依然會跟蹤其他可能更有意義的高概率序列。盡管?The is fast
?完全滿足約束,但這并不是一個(gè)有意義的短語。幸運(yùn)的是,我們有?"The dog is"
?和?"The dog and"
?可以在未來的步驟中使用,希望在將來這會產(chǎn)生更有意義的輸出。
圖示如下 (以上例的第 3 步為例):

請注意,上圖中不需要強(qiáng)制添加?"The is fast"
,因?yàn)樗呀?jīng)被包含在概率排序中了。另外,請注意像?"The dog is slow"
?或?"The dog is mad"
?這樣的波束實(shí)際上是屬于組 0 (Bank 0) 的,為什么呢?因?yàn)楸M管它包含詞?"is"
?,但它不可用于生成?"is fast"
?,因?yàn)?fast
?的位子已經(jīng)被?slow
?或?mad
?占掉了,也就杜絕了后續(xù)能生成?"is fast"
?的可能性。從另一個(gè)角度講,因?yàn)?slow
?這樣的詞的加入,該分支?滿足約束的進(jìn)度?被重置成了 0。
最后請注意,我們最終生成了包含約束短語的合理輸出:?"The dog is fast"
?!
起初我們很擔(dān)心,因?yàn)槊つ康靥砑蛹s束詞會導(dǎo)致出現(xiàn)諸如?"The is fast"
?之類的無意義短語。然而,使用基于組的輪流選擇方法,我們最終隱式地?cái)[脫了無意義的輸出,優(yōu)先選擇了更合理的輸出。
關(guān)于?Constraint
?類的更多信息及自定義約束
我們總結(jié)下要點(diǎn)。每一步,我們都不斷地糾纏模型,強(qiáng)制添加約束詞,同時(shí)也跟蹤不滿足約束的分支,直到最終生成包含所需短語的合理的高概率序列。
在實(shí)現(xiàn)時(shí),我們的主要方法是將每個(gè)約束表示為一個(gè)?Constraint
?對象,其目的是跟蹤滿足約束的進(jìn)度并告訴波束搜索接下來要生成哪些詞。盡管我們可以使用?model.generate()
?的關(guān)鍵字參數(shù)?force_words_ids
?,但使用該參數(shù)時(shí)后端實(shí)際發(fā)生的情況如下:
python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PhrasalConstrainttokenizer = AutoTokenizer.from_pretrained("t5-base")model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")encoder_input_str = "translate English to German: How old are you?"constraints = [ ? ?PhrasalConstraint( ? ? ? ?tokenizer("Sie", add_special_tokens=False).input_ids ? ?)]input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_idsoutputs = model.generate( ? ?input_ids, ? ?constraints=constraints, ? ?num_beams=10, ? ?num_return_sequences=1, ? ?no_repeat_ngram_size=1, ? ?remove_invalid_values=True,)print("Output:\n" + 100 *'-')print(tokenizer.decode(outputs[0], skip_special_tokens=True))
Output:----------------------------------------------------------------------------------------------------Wie alt sind Sie?
你甚至可以定義一個(gè)自己的約束并將其通過?constraints
?參數(shù)輸入給?model.generate()
?。此時(shí),你只需要?jiǎng)?chuàng)建?Constraint
?抽象接口類的子類并遵循其要求即可。你可以在?此處?的?Constraint
?定義中找到更多信息。
我們還可以嘗試其他一些有意思的約束 (尚未實(shí)現(xiàn),也許你可以試一試!) 如?OrderedConstraints
?、?TemplateConstraints
?等。目前,在最終輸出中約束短語間是無序的。例如,前面的例子一個(gè)輸出中的約束短語順序?yàn)?scared -> screaming
?,而另一個(gè)輸出中的約束短語順序?yàn)?screamed -> scared
?。 如果有了?OrderedConstraints
, 我們就可以允許用戶指定約束短語的順序。?TemplateConstraints
?的功能更小眾,其約束可以像這樣:
python
starting_text = "The woman"template = ["the", "", "School of", "", "in"]possible_outputs == [ ? "The woman attended the Ross School of Business in Michigan.", ? "The woman was the administrator for the Harvard School of Business in MA."]
或是這樣:
python
starting_text = "The woman"template = ["the", "", "", "University", "", "in"]possible_outputs == [ ? "The woman attended the Carnegie Mellon University in Pittsburgh.",]impossible_outputs == [ ?"The woman attended the Harvard University in MA."]
或者,如果用戶不關(guān)心兩個(gè)詞之間應(yīng)該隔多少個(gè)詞,那僅用?OrderedConstraint
?就可以了。
總結(jié)
約束波束搜索為我們提供了一種將外部知識和需求注入文本生成過程的靈活方法。以前,沒有一個(gè)簡單的方法可用于告訴模型 1. 輸出中需要包含某列表中的詞或短語,其中 2. 其中有一些是可選的,有些必須包含的,這樣 3. 它們可以最終生成至在合理的位置。現(xiàn)在,我們可以通過綜合使用?Constraint
?的不同子類來完全控制我們的生成!
該新特性主要基于以下論文:
Guided Open Vocabulary Image Captioning with Constrained Beam Search
Fast Lexically Constrained Decoding with Dynamic Beam Allocation for Neural Machine Translation
Improved Lexically Constrained Decoding for Translation and Monolingual Rewriting
Guided Generation of Cause and Effect
與上述這些工作一樣,還有許多新的研究正在探索如何使用外部知識 (例如 KG (Knowledge Graph) 、KB (Knowledge Base) ) 來指導(dǎo)大型深度學(xué)習(xí)模型輸出。我們希望約束波束搜索功能成為實(shí)現(xiàn)此目的的有效方法之一。
感謝所有為此功能提供指導(dǎo)的人: Patrick von Platen 參與了從?初始問題?討論到?最終 PR?的全過程,還有 Narsil Patry,他們二位對代碼進(jìn)行了詳細(xì)的反饋。
本文使用的圖標(biāo)來自于?Freepik - Flaticon。
英文原文:?https://hf.co/blog/constrained-beam-search
原文作者: Chan Woo Kim
譯者: Matrix Yao (姚偉峰),英特爾深度學(xué)習(xí)工程師,工作方向?yàn)?transformer-family 模型在各模態(tài)數(shù)據(jù)上的應(yīng)用及大規(guī)模模型的訓(xùn)練推理。
審校/排版: zhongdongy (阿東)