diff --git a/README.md b/README.md index 8af19606..4f17888e 100644 --- a/README.md +++ b/README.md @@ -10,11 +10,11 @@ Below shows the generation speed gain by using FastSeq. | Model | W/O FastSeq (in samples/s) | W/ FastSeq (in samples/s) | Speedup | |------------------|:--------------------------:|:-------------------------:|:-----:| | [ProphetNet](examples/prophetnet/README.md) | 2.7 | 10.3 | 3.8x | -| [Bart (`fs`)](examples/bart/README.md) | 2.7 | 12.5 | 4.6x | -| [Bart (`hf`)](examples/bart/README.md#speedup-bart-huggingface-transformers-version-by-using-fastseq) | 3.4 | 8.1 | 2.4x | -| [DistilBart (`hf`)](examples/distilbart/README.md) | 4.0 | 8.5 | 2.1x | -| [T5 (`hf`)](examples/t5/README.md) | 4.8 | 7.5 | 1.6x | -| [WMT16 En-De (`fs`)](examples/wmt/README.md) | 84.0 | 122.0 | 1.5x | +| [Bart (`fs`)](examples/bart/README.md) | 2.7 | 13.3 | 5x | +| [Bart (`hf`)](examples/bart/README.md#speedup-bart-huggingface-transformers-version-by-using-fastseq) | 3.4 | 9.9 | 2.9x | +| [DistilBart (`hf`)](examples/distilbart/README.md) | 4.0 | 11.9 | 3x | +| [T5 (`hf`)](examples/t5/README.md) | 4.8 | 11.0 | 2.3x | +| [WMT16 En-De (`fs`)](examples/wmt/README.md) | 84.0 | 124.0 | 1.5x | - All benchmarking experiments run on NVIDIA-V100-16GB with [docker](docker/Dockerfile). Highest speed recorded for each model by tuning batch size. For parameter setting details, click link of corresponding model. - `fs` stands for [Fairseq](https://github.com/pytorch/fairseq) 0.9.0 version, `hf` stands for [Huggingface Transformers](https://github.com/huggingface/transformers) 3.0.2 version. diff --git a/benchmarks/models/fs_bart.sh b/benchmarks/models/fs_bart.sh index b4985cfd..f926e07f 100755 --- a/benchmarks/models/fs_bart.sh +++ b/benchmarks/models/fs_bart.sh @@ -20,9 +20,9 @@ source utils.sh grep "bart.large.cnn cnn_dm.1k/len-1024.bin valid " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 10.4 10.6 # Speed on V100 16GB 250W grep -E "fairseq_v0.9.0 bart.large.cnn cnn_dm.1k/len-1024.bin valid 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 2.3 2.8 -grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm.1k/len-1024.bin valid 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 8.1 100 -grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm.1k/len-1024.bin valid 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 10.9 100 -grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm.1k/len-1024.bin valid 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 12.5 100 +grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm.1k/len-1024.bin valid 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 8.3 100 +grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm.1k/len-1024.bin valid 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 11.4 100 +grep -E "fairseq_v0.9.0\+fastseq_v.* bart.large.cnn cnn_dm.1k/len-1024.bin valid 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 13.3 100 ## Accuracy #grep "bart.large.cnn cnn_dm/len-1024.bin valid " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 17.9 18 diff --git a/benchmarks/models/fs_wmt.sh b/benchmarks/models/fs_wmt.sh index aff1274e..067d1add 100755 --- a/benchmarks/models/fs_wmt.sh +++ b/benchmarks/models/fs_wmt.sh @@ -15,6 +15,6 @@ source utils.sh grep " wmt16.en.de.32k wmt16_en_de_bpe32k/bin test " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 0.019 0.021 # Speed on V100 16GB 250W grep -E "fairseq_v0.9.0 wmt16.en.de.32k 
wmt16_en_de_bpe32k/bin test 256 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 82 85 -grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin test 256 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 116.7 1000 -grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin test 512 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 120 1000 -grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin test 1024 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 121 1000 +grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin test 256 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 117 1000 +grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin test 512 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 123 1000 +grep -E "fairseq_v0.9.0\+fastseq_v.* wmt16.en.de.32k wmt16_en_de_bpe32k/bin test 1024 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 123 1000 diff --git a/benchmarks/models/hf_bart.sh b/benchmarks/models/hf_bart.sh index b06045a6..31836210 100755 --- a/benchmarks/models/hf_bart.sh +++ b/benchmarks/models/hf_bart.sh @@ -20,9 +20,9 @@ source utils.sh grep "facebook/bart-large-cnn cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 34.8 35 # Speed on V100 16GB 250W grep -E "transformers_v3.0.2 facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.2 3.4 -grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 6.2 100 -grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.8 100 -grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 8.0 100 +grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 32 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.3 100 +grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.6 100 +grep -E "transformers_v3.0.2\+fastseq_v.* facebook/bart-large-cnn cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.9 100 ## Accuracy #grep "facebook/bart-large-cnn cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 44.78 44.82 diff --git a/benchmarks/models/hf_distibart.sh b/benchmarks/models/hf_distibart.sh index dea7d077..35184213 100755 --- a/benchmarks/models/hf_distibart.sh +++ b/benchmarks/models/hf_distibart.sh @@ -20,9 +20,9 @@ source utils.sh grep "sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 35.1 35.3 # Speed on V100 16GB 250W grep -E "transformers_v3.0.2 sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 3.9 4.2 -grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 8.5 100 +grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 11.5 100 # todo: bigger bs doesn't increase speed -grep -E "transformers_v3.0.2\+fastseq_v.* 
sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 8.5 100
+grep -E "transformers_v3.0.2\+fastseq_v.* sshleifer/distilbart-cnn-12-6 cnn_dm.1k/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 11.9 100
 
 ## Accuracy
 #grep "sshleifer/distilbart-cnn-12-6 cnn_dm/raw val " perf | awk '{print $9}' | awk -F'|' '{if($1!="NA"){c+=1;s+=$1}}END{print s/c}' | bash range.sh 45 45.1
diff --git a/benchmarks/models/hf_t5.sh b/benchmarks/models/hf_t5.sh
index b99fdfde..7b1ae5f3 100755
--- a/benchmarks/models/hf_t5.sh
+++ b/benchmarks/models/hf_t5.sh
@@ -14,6 +14,5 @@ source utils.sh
 grep "t5-base wmt_en_ro/raw val " perf | awk '{if($8!="NA"){c+=1;s+=$8}}END{print s/c}' | bash range.sh 27.42 27.44
 # Speed on V100 16GB 250W
 grep -E "transformers_v3.0.2 t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 4.6 5.2
-grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.0 7.1
-grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 7.5 7.8
-
+grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 64 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 9.0 9.2
+grep -E "transformers_v3.0.2\+fastseq_v.* t5-base wmt_en_ro/raw val 128 " perf | awk '{s+=$13}END{print s/NR}' | bash range.sh 10.9 11.1
diff --git a/fastseq/clib/cuda/ngram_repeat_block_cuda.cpp b/fastseq/clib/cuda/ngram_repeat_block_cuda.cpp
new file mode 100644
index 00000000..4199cd6e
--- /dev/null
+++ b/fastseq/clib/cuda/ngram_repeat_block_cuda.cpp
@@ -0,0 +1,47 @@
+/*
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT License.
+*/
+
+#include <torch/extension.h>
+#include <vector>
+
+/*
+CPP Binding for CUDA OP
+*/
+
+// CUDA forward declarations
+torch::Tensor ngram_repeat_block_cuda_forward(torch::Tensor tokens,
+                                              torch::Tensor lprobs, int bsz,
+                                              int step, int beam_size,
+                                              int no_repeat_ngram_size);
+
+#define CHECK_CUDA(x) \
+  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) \
+  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x); \
+  CHECK_CONTIGUOUS(x)
+
+// Input check and call to CUDA OP
+// Backward method not required
+torch::Tensor ngram_repeat_block_forward(torch::Tensor tokens,
+                                         torch::Tensor lprobs, int bsz,
+                                         int step, int beam_size,
+                                         int no_repeat_ngram_size) {
+  CHECK_INPUT(tokens);
+  CHECK_INPUT(lprobs);
+  assert(bsz > 0);
+  assert(step >= 0);
+  assert(beam_size > 0);
+  assert(no_repeat_ngram_size > 0);
+
+  return ngram_repeat_block_cuda_forward(tokens, lprobs, bsz, step, beam_size,
+                                         no_repeat_ngram_size);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &ngram_repeat_block_forward,
+        "No Repeat Ngram Block forward (CUDA)");
+}
diff --git a/fastseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu b/fastseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu
new file mode 100644
index 00000000..b458b091
--- /dev/null
+++ b/fastseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu
@@ -0,0 +1,76 @@
+/*
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT License.
+*/
+
+/*
+Kernel implementation for blocking repeated n-grams.
+*/
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <math.h>
+#include <torch/extension.h>
+#include <vector>
+
+// Ban repeated ngrams of length = 'no_repeat_ngram_size'
+__global__ void banRepeatedTokens(long* __restrict__ tokens,
+                                  float* __restrict__ lprobs,
+                                  int max_predict_len, int vocab_size,
+                                  int no_repeat_ngram_size) {
+  auto row = blockIdx.x;
+  auto col = threadIdx.x;
+  auto start = row * (max_predict_len) + col;
+  // Each thread compares ngram starting from
+  // thread index with final ngram starting from
+  // step - no_repeat_ngram_size +2
+  auto check_start_pos = blockDim.x;
+  auto lprob_start = row * vocab_size;
+  bool is_banned = true;
+  extern __shared__ long tokens_shm[];
+  tokens_shm[col] = tokens[start];
+  if (col == blockDim.x - 1) {
+    for (int i = 1; i < no_repeat_ngram_size; i++) {
+      if (col + i < max_predict_len) {
+        tokens_shm[col + i] = tokens[start + i];
+      }
+    }
+  }
+  __syncthreads();
+
+  for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
+    if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
+      is_banned = false;
+    }
+  }
+  if (is_banned == true) {
+    auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
+    lprobs[lprob_start + token_to_be_banned] = -INFINITY;
+  }
+}
+
+// Allocate blocks and threads based on
+// batch size and sequence length and launch
+// kernel
+torch::Tensor ngram_repeat_block_cuda_forward(torch::Tensor tokens,
+                                              torch::Tensor lprobs, int bsz,
+                                              int step, int beam_size,
+                                              int no_repeat_ngram_size) {
+  int threads = step - no_repeat_ngram_size + 2;
+  if (threads <= 0) return lprobs;
+  int max_predict_len = tokens.size(1);
+  int vocab_size = lprobs.size(1);
+  auto token_ptr = tokens.data_ptr<long>();
+  auto lprob_ptr = lprobs.data_ptr<float>();
+  int blocks = bsz * beam_size;
+  int shared_mem_size = (step + 1) * sizeof(long);
+
+  // Launching N blocks where N is number of samples in a batch (beams*bsz)
+  // Launching T threads where T is number of previous ngrams in a sample
+  // Allocating shared mem per block for faster access of input tokens since
+  // each token will be accessed N times to compare with current Ngram where
+  // N is Ngram size.
+  banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
+      token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
+  return lprobs;
+}
diff --git a/fastseq/ops/ngram_repeat_block.py b/fastseq/ops/ngram_repeat_block.py
new file mode 100644
index 00000000..cb3aebca
--- /dev/null
+++ b/fastseq/ops/ngram_repeat_block.py
@@ -0,0 +1,58 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+""" Wrapper for ngram_repeat_block cuda extension """
+from torch import nn
+from torch.autograd import Function
+import ngram_repeat_block_cuda
+
+class NGramRepeatBlockFunction(Function):
+    """
+    forward inputs to ngram_repeat_block cuda extension
+    backward method not needed.
+ + """ + def forward(self, tokens, lprobs, bsz, + step, beam_size, no_repeat_ngram_size): + """ + Args: + tokens(Tensor): Input tokens(Bsz*beam, seq_len) + lprobs(Tensor): likelihood probability + Expected to be updated in place.(Bsz*beam, vocab_size) + bsz(int): batch size + step(int): current step + beam_size(int): beam size + no_repeat_ngram_size(int): Ngram size + """ + outputs = ngram_repeat_block_cuda.forward(tokens, + lprobs, bsz, step, beam_size, no_repeat_ngram_size) + return outputs + + def backward (*args): + raise NotImplementedError + +class NGramRepeatBlock(nn.Module): + """ Wrapper class for calling ngram_repeat_block cuda extension """ + def __init__(self): + super(NGramRepeatBlock, self).__init__() + + def reset_parameters(self): + pass + + def forward(self, tokens, lprobs, bsz, + step, beam_size, no_repeat_ngram_size): + """ + Args: + tokens(Tensor): Input tokens(Bsz*beam, seq_len) + lprobs(Tensor): likelihood probability, + Expected to be updated in place.(Bsz*beam, vocab_size) + bsz(int): batch size + step(int): current step + beam_size(int): beam size + no_repeat_ngram_size(int): Ngram size + """ + assert tokens.size(0)== bsz*beam_size + assert lprobs.size(0)== bsz*beam_size + + return NGramRepeatBlockFunction.apply(tokens, lprobs, + bsz, step, beam_size, no_repeat_ngram_size) diff --git a/fastseq/optimizer/fairseq/beam_search_optimizer.py b/fastseq/optimizer/fairseq/beam_search_optimizer.py index 00a0a29a..69fc1788 100644 --- a/fastseq/optimizer/fairseq/beam_search_optimizer.py +++ b/fastseq/optimizer/fairseq/beam_search_optimizer.py @@ -14,7 +14,7 @@ from fairseq.models.transformer import TransformerEncoder, TransformerModel from fairseq.modules.multihead_attention import MultiheadAttention from fairseq.sequence_generator import SequenceGenerator - +from fastseq.ops.ngram_repeat_block import NGramRepeatBlock from fastseq.utils.api_decorator import replace @replace(TransformerEncoder) @@ -429,6 +429,7 @@ def _generate(self, bsz = input_size[0] src_len = input_size[1] beam_size = self.beam_size + self.no_repeat_ngram_op = NGramRepeatBlock() if self.match_source_len: max_len = src_lengths.max().item() @@ -640,24 +641,6 @@ def replicate_first_beam(tensor, mask): # minimum length constraint (does not apply if using prefix_tokens) lprobs[:, self.eos] = -math.inf - if self.no_repeat_ngram_size > 0: - # for each beam and batch sentence, generate a list of previous ngrams - banned_list = [[] for bbsz_idx in range(bsz * beam_size)] - cpu_tokens = tokens.cpu()[:, :step + 1].numpy() - check_start_pos = step + 2 - self.no_repeat_ngram_size - for bbsz_idx in range(bsz * beam_size): - for i in range(check_start_pos): - is_banned = True - for k in range(self.no_repeat_ngram_size - 1): - if cpu_tokens[bbsz_idx, i + k] != cpu_tokens[ - bbsz_idx, check_start_pos + k]: - is_banned = False - break - if is_banned: - banned_list[bbsz_idx].append( - cpu_tokens[bbsz_idx, - i + self.no_repeat_ngram_size - 1]) - # Record attention scores if avg_attn_scores is not None: if attn is None: @@ -674,24 +657,9 @@ def replicate_first_beam(tensor, mask): self.search.set_src_lengths(src_lengths) if self.no_repeat_ngram_size > 0: - - def calculate_banned_tokens(bbsz_idx): - # before decoding the next token, prevent decoding of ngrams that have already appeared - banned_tokens_per_sample = [ - (bbsz_idx, t) for t in banned_list[bbsz_idx] - ] - return banned_tokens_per_sample - - banned_tokens = [] - if step + 2 - self.no_repeat_ngram_size >= 0: - for bbsz_idx in range(bsz * beam_size): - 
banned_tokens.extend(calculate_banned_tokens(bbsz_idx)) - - if banned_tokens: - banned_tokens = torch.LongTensor(banned_tokens) - lprobs.index_put_( - tuple(banned_tokens.t()), - lprobs.new_tensor([-math.inf] * len(banned_tokens))) + #Applying Cuda Op for NGram repeat Blocking + lprobs = self.no_repeat_ngram_op(tokens,lprobs, bsz, step, + beam_size, self.no_repeat_ngram_size) cand_scores, cand_indices, cand_beams = self.search.step( step, diff --git a/fastseq/optimizer/transformers/beam_search_optimizer.py b/fastseq/optimizer/transformers/beam_search_optimizer.py index 99d2d65e..b9102801 100644 --- a/fastseq/optimizer/transformers/beam_search_optimizer.py +++ b/fastseq/optimizer/transformers/beam_search_optimizer.py @@ -17,6 +17,7 @@ from transformers.modeling_bart import BartForConditionalGeneration from transformers.modeling_t5 import T5ForConditionalGeneration +from fastseq.ops.ngram_repeat_block import NGramRepeatBlock from fastseq.logging import get_logger from fastseq.utils.api_decorator import replace @@ -648,17 +649,9 @@ def _update_scores(banned_tokens): cpu_input_ids = input_ids.cpu() if no_repeat_ngram_size > 0: - # calculate a list of banned tokens to prevent repetitively - # generating the same ngrams - num_batch_hypotheses = batch_size * num_beams - # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345 - banned_ngram_tokens = calc_banned_ngram_tokens_v2( - cpu_input_ids, - num_batch_hypotheses, - no_repeat_ngram_size, - cur_len, - self.config.pad_token_id) - _update_scores(banned_ngram_tokens) + #custom op for Ngram repeat blocking + scores = self.no_repeat_ngram_op(input_ids,scores.float(), + batch_size, cur_len-1, num_beams, no_repeat_ngram_size) if bad_words_ids is not None: # calculate a list of banned tokens according to bad words @@ -721,6 +714,9 @@ def _generate_beam_search( # done sentences done = [False for _ in range(batch_size)] + #NGram Repeat block Op + self.no_repeat_ngram_op = NGramRepeatBlock()#.to('cuda', torch.float32) + while cur_len < max_length: model_inputs = self.prepare_inputs_for_generation( input_ids, past=past, attention_mask=attention_mask, diff --git a/setup.py b/setup.py index 98ce050c..6f80569b 100644 --- a/setup.py +++ b/setup.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. 
from setuptools import find_packages, setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension from fastseq.config import FASTSEQ_VERSION @@ -14,6 +15,13 @@ def get_fastseq_version(): extras["fairseq"] = ["fairseq>=0.9.0"] extras["transformers"] = ["transformers>=3.0.2"] +extensions = [ + CUDAExtension('ngram_repeat_block_cuda', [ + 'fastseq/clib/cuda/ngram_repeat_block_cuda.cpp', + 'fastseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu', + ]), + ] + setup( name="fastseq", version=get_fastseq_version(), @@ -51,6 +59,7 @@ def get_fastseq_version(): "Programming Language :: Python :: 3.7", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], + ext_modules=extensions, entry_points={ 'console_scripts': [ 'fastseq-generate-for-fairseq = fastseq_cli.generate:cli_main', @@ -58,4 +67,7 @@ def get_fastseq_version(): 'fastseq-eval-lm-for-fairseq = fastseq_cli.eval_lm:cli_main', ], }, + cmdclass={ + 'build_ext': BuildExtension + }, ) diff --git a/tests/ops/test_ngram_repeat_block.py b/tests/ops/test_ngram_repeat_block.py new file mode 100644 index 00000000..5bb739f3 --- /dev/null +++ b/tests/ops/test_ngram_repeat_block.py @@ -0,0 +1,173 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" Unit test for Ngram repeat block cuda op """ + +import math +import torch +from fastseq.ops.ngram_repeat_block import NGramRepeatBlock +from fastseq.utils.test_utils import TestCaseBase +from absl.testing import absltest, parameterized + +class NgramRepeatBlockTest(TestCaseBase): + """ check to ensure cuda implementation output + of this op matches with original fairseq + implememntation. + """ + + def apply_no_repeat_ngram(self, tokens,lprobs, bsz,step, + beam_size, no_repeat_ngram_size): + """ Fairseq implementation of blocking + repeated ngrams + """ + banned_list = [[] for bbsz_idx in range(bsz * beam_size)] + cpu_tokens = tokens.cpu()[:, :step + 1].numpy() + check_start_pos = step + 2 - no_repeat_ngram_size + for bbsz_idx in range(bsz * beam_size): + for i in range(check_start_pos): + is_banned = True + for k in range(no_repeat_ngram_size - 1): + if cpu_tokens[bbsz_idx, i + k] != cpu_tokens[ + bbsz_idx, check_start_pos + k]: + is_banned = False + break + if is_banned: + banned_list[bbsz_idx].append( + cpu_tokens[bbsz_idx, + i + no_repeat_ngram_size - 1]) + + def calculate_banned_tokens(bbsz_idx): + """before decoding the next token, prevent decoding + of ngrams that have already appeared + """ + banned_tokens_per_sample = [ + (bbsz_idx, t) for t in banned_list[bbsz_idx] + ] + return banned_tokens_per_sample + + banned_tokens = [] + if step + 2 - no_repeat_ngram_size >= 0: + for bbsz_idx in range(bsz * beam_size): + banned_tokens.extend(calculate_banned_tokens(bbsz_idx)) + + if banned_tokens: + banned_tokens = torch.LongTensor(banned_tokens) + lprobs.index_put_( + tuple(banned_tokens.t()), + lprobs.new_tensor([-math.inf] * len(banned_tokens))) + + return lprobs + + @parameterized.named_parameters({ + 'testcase_name': 'Normal', + 'vocab_size': 10, + 'bsz': 256, + 'beam_size': 1, + 'step': 6, + 'ngram_repeat_block_size': 3, + 'sequence_length':2048, + 'pos1':0, + }, + { + 'testcase_name': 'overlapping_ngrams', + 'vocab_size': 10, + 'bsz': 256, + 'beam_size': 1, + 'step': 4, + 'ngram_repeat_block_size': 3, + 'sequence_length':2048, + 'pos1':0, + }, + { + 'testcase_name': 'min_step', + 'vocab_size': 10, + 'bsz': 256, + 'beam_size': 1, + 'step': 3, + 'ngram_repeat_block_size': 3, + 'sequence_length':2048, + 'pos1':0, + }, + { + 'testcase_name': 
'higher_beam_size', + 'vocab_size': 10, + 'bsz': 256, + 'beam_size': 2, + 'step': 6, + 'ngram_repeat_block_size': 3, + 'sequence_length':2048, + 'pos1':0, + }, + { + 'testcase_name': 'higher_ngram_size', + 'vocab_size': 10, + 'bsz': 256, + 'beam_size': 1, + 'step': 12, + 'ngram_repeat_block_size': 5, + 'sequence_length':2048, + 'pos1':0, + }, + { + 'testcase_name': 'higher_vocab_size', + 'vocab_size': 1000, + 'bsz': 256, + 'beam_size': 1, + 'step': 6, + 'ngram_repeat_block_size': 3, + 'sequence_length':2048, + 'pos1':0, + } + ) + def test_ngram_repeat_block_kernel(self, bsz, beam_size, vocab_size, + step, ngram_repeat_block_size, sequence_length, pos1): + + """ Use random input with repeated ngram to check + whether corresponding token in vocabulary is blocked (-Inf score) + + Args: + bsz (int): batch size + beam_size (int): beam size + vocab_size (int): vocab size + step (int): current decoding step + ngram_repeat_block_size (int): size of ngram + sequence_length (int): sequence length + pos1 (int) first position where repeated ngram occurs + within a sentence. + """ + + lprobs_fairseq = torch.zeros(bsz*beam_size, + vocab_size).type(torch.FloatTensor) + lprobs_fastseq = torch.zeros(bsz*beam_size, + vocab_size).type(torch.FloatTensor) + repeated_ngram = torch.randint(0,10, (1,2)) + #second place where ngram is repeated + pos2 = step-ngram_repeat_block_size+2 + #Dummy input with repeated ngram + inp = torch.cat((torch.randint(0,10, (1,pos1)), + repeated_ngram, torch.randint(0,10, + (1,pos2-pos1-ngram_repeat_block_size+1)), + repeated_ngram, torch.randint(0,10, + (1, sequence_length - + pos2-ngram_repeat_block_size+1))), 1) + tokens=inp.repeat( (bsz*beam_size,1)) + #CUDA kernel initialization + rnn = NGramRepeatBlock() + lprobs_fastseq = lprobs_fastseq.cuda() + lprobs_fairseq = lprobs_fairseq.cuda() + tokens = tokens.cuda() + #Cuda opt implementation + lprobs_fastseq = rnn( tokens,lprobs_fastseq, bsz, step, beam_size, + ngram_repeat_block_size) + #Original implementation + lprobs_fairseq = self.apply_no_repeat_ngram(tokens, lprobs_fairseq, bsz, + step, beam_size, ngram_repeat_block_size) + err_msg = ''' + ngram repeat block kernel implementation output + doesn't match with output of original implementation + ''' + assert torch.all(torch.eq(lprobs_fairseq, + lprobs_fastseq)).cpu().numpy(), err_msg + +if __name__ == "__main__": + absltest.main() diff --git a/tests/run_fairseq_tests.sh b/tests/run_fairseq_tests.sh index 74236a18..23b12d88 100755 --- a/tests/run_fairseq_tests.sh +++ b/tests/run_fairseq_tests.sh @@ -7,8 +7,8 @@ pip install gitpython pip install absl-py pip install packaging cd ${FASTSEQ_TEST_PATH}/../ -pip install --editable . pip install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html +pip install --editable . cd tests python run_fairseq_tests.py deactivate
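
Reviewer note (not part of the patch): a minimal usage sketch of the new `NGramRepeatBlock` op introduced above. It assumes the `ngram_repeat_block_cuda` extension built successfully via `pip install --editable .` and that a CUDA device is available; shapes follow the docstrings in `fastseq/ops/ngram_repeat_block.py`, and the sizes below are illustrative only.

```python
import torch
from fastseq.ops.ngram_repeat_block import NGramRepeatBlock

# Illustrative sizes (not taken from the benchmarks above).
bsz, beam_size, vocab_size = 2, 4, 1000
step, no_repeat_ngram_size = 6, 3

# Tokens generated so far: (bsz * beam_size, step + 1), dtype long, on GPU.
tokens = torch.randint(0, vocab_size, (bsz * beam_size, step + 1)).cuda()
# Next-token log-probabilities: (bsz * beam_size, vocab_size), float32, on GPU.
lprobs = torch.zeros(bsz * beam_size, vocab_size).cuda()

op = NGramRepeatBlock()
# Sets lprobs to -inf for any token that would complete an n-gram of length
# no_repeat_ngram_size already present in tokens; lprobs is updated in place
# and returned.
lprobs = op(tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size)
```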