Attention Is All You Need 代码调试

    xiaoxiao2025-07-09  17

    《Attention Is All You Need》是谷歌(Google)首次提出 Transformer 模型的论文,该模型摒弃了 RNN 与 CNN,完全基于注意力机制(attention)构建,网络结构如下所示: 本次代码调试使用 PyTorch,要求 Python 3 环境,共有三个 Python 文件,具体步骤为:下载数据、数据预处理、训练模型、测试模型。 数据预处理阶段需要加载数据、构建词汇索引、将词转换为索引序列等,为此分别定义了 read_instances_from_file、build_vocab_index、convert_instance_to_idx_seq、main 等函数。部分代码如下所示:

    def read_instances_from_file(inst_file, max_sent_len, keep_case, encoding='utf-8'):
        """Convert a one-sentence-per-line text file into word-sequence lists.

        Each line is optionally lower-cased, split on whitespace, cut to at most
        ``max_sent_len`` tokens and wrapped with ``Constants.BOS_WORD`` /
        ``Constants.EOS_WORD``.

        Args:
            inst_file: Path of the text file to read (one sentence per line).
            max_sent_len: Maximum number of tokens kept per sentence, not
                counting the BOS/EOS markers that are added afterwards.
            keep_case: If False, every line is lower-cased before splitting.
            encoding: Text encoding used to open ``inst_file`` (default UTF-8).
                Pass ``None`` to fall back to the platform's locale default.

        Returns:
            A list with one entry per input line: ``[BOS] + tokens + [EOS]``
            for lines that still have tokens, or ``None`` for lines that end up
            empty (presumably kept so indices stay aligned with a parallel
            corpus -- confirm against callers).
        """
        word_insts = []
        trimmed_sent_count = 0
        # An explicit encoding keeps decoding identical on every platform; the
        # locale default (e.g. cp1252 / GBK on Windows) may fail on non-ASCII text.
        with open(inst_file, encoding=encoding) as f:
            for sent in f:
                if not keep_case:
                    sent = sent.lower()
                words = sent.split()
                if len(words) > max_sent_len:
                    trimmed_sent_count += 1
                word_inst = words[:max_sent_len]

                if word_inst:
                    word_insts.append([Constants.BOS_WORD] + word_inst + [Constants.EOS_WORD])
                else:
                    # Lines without tokens are kept as a None placeholder.
                    word_insts.append(None)

        print('[Info] Get {} instances from {}'.format(len(word_insts), inst_file))

        if trimmed_sent_count > 0:
            print('[Warning] {} instances are trimmed to the max sentence length {}.'
                  .format(trimmed_sent_count, max_sent_len))

        return word_insts

    模型训练阶段首先要构建网络。按照论文所述,encoder 和 decoder 各由 6 层堆叠而成,每一层都包含注意力(attention)子层;同时还需要定义损失函数和评估函数。与其他神经网络一样,模型训练时需要在训练数据上迭代多个 epoch(每个 epoch 内再按 batch 划分数据)。训练阶段代码如下:

    """Train a Transformer translation model on a preprocessed dataset.

    The file given by ``-data`` must provide the ``'settings'``, ``'dict'``,
    ``'train'`` and ``'valid'`` entries read in ``main`` and
    ``prepare_dataloaders``.
    """
    import argparse
    import inspect
    import math
    import time

    from tqdm import tqdm
    import torch
    import torch.nn.functional as F
    import torch.optim as optim
    import torch.utils.data

    import transformer.Constants as Constants
    from dataset import TranslationDataset, paired_collate_fn
    from transformer.Models import Transformer
    from transformer.Optim import ScheduledOptim


    def _torch_load_trusted(path):
        """Load a pickled file produced by this project's own scripts.

        PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
        rejects non-tensor objects such as the ``'settings'`` entry (accessed by
        attribute below, presumably an ``argparse.Namespace``). Full unpickling
        is therefore requested explicitly. Releases without the ``weights_only``
        parameter already unpickle fully.

        SECURITY: full unpickling can execute arbitrary code -- only load files
        you generated yourself.
        """
        if 'weights_only' in inspect.signature(torch.load).parameters:
            return torch.load(path, weights_only=False)
        return torch.load(path)


    def cal_performance(pred, gold, smoothing=False):
        """Return the summed loss and the number of correctly predicted tokens.

        Args:
            pred: Logits of shape ``(n_tokens, n_class)``; rows line up with the
                flattened ``gold`` ids.
            gold: Gold token ids; flattened to 1-D here.
            smoothing: Apply label smoothing to the loss (see ``cal_loss``).

        Returns:
            ``(loss, n_correct)``: ``loss`` is a scalar tensor summed over
            non-padding tokens; ``n_correct`` is a Python int that ignores
            ``Constants.PAD`` positions.
        """
        loss = cal_loss(pred, gold, smoothing)

        pred = pred.max(1)[1]  # arg-max class id per row
        gold = gold.contiguous().view(-1)
        non_pad_mask = gold.ne(Constants.PAD)
        n_correct = pred.eq(gold)
        n_correct = n_correct.masked_select(non_pad_mask).sum().item()

        return loss, n_correct


    def cal_loss(pred, gold, smoothing):
        """Return the cross-entropy loss summed over non-padding tokens.

        With ``smoothing`` the one-hot target is replaced by ``1 - eps`` on the
        gold class and ``eps / (n_class - 1)`` on every other class
        (``eps = 0.1``). Padding positions are excluded in both branches.
        """
        gold = gold.contiguous().view(-1)

        if smoothing:
            eps = 0.1
            n_class = pred.size(1)

            one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
            one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
            log_prb = F.log_softmax(pred, dim=1)

            non_pad_mask = gold.ne(Constants.PAD)
            loss = -(one_hot * log_prb).sum(dim=1)
            loss = loss.masked_select(non_pad_mask).sum()  # average later
        else:
            loss = F.cross_entropy(pred, gold, ignore_index=Constants.PAD, reduction='sum')

        return loss


    def train_epoch(model, training_data, optimizer, device, smoothing):
        """Run one training epoch.

        Returns:
            ``(loss_per_word, accuracy)`` averaged over non-padding target tokens.
        """
        model.train()

        total_loss = 0
        n_word_total = 0
        n_word_correct = 0

        for batch in tqdm(
                training_data, mininterval=2,
                desc=' - (Training) ', leave=False):

            # prepare data
            src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch)
            gold = tgt_seq[:, 1:]  # targets are the target sequence shifted left by one

            # forward
            optimizer.zero_grad()
            pred = model(src_seq, src_pos, tgt_seq, tgt_pos)

            # backward
            loss, n_correct = cal_performance(pred, gold, smoothing=smoothing)
            loss.backward()

            # update parameters (ScheduledOptim also adjusts the learning rate)
            optimizer.step_and_update_lr()

            # note keeping
            total_loss += loss.item()

            non_pad_mask = gold.ne(Constants.PAD)
            n_word = non_pad_mask.sum().item()
            n_word_total += n_word
            n_word_correct += n_correct

        loss_per_word = total_loss/n_word_total
        accuracy = n_word_correct/n_word_total
        return loss_per_word, accuracy


    def eval_epoch(model, validation_data, device):
        """Evaluate the model on the validation set without gradient tracking.

        Returns:
            ``(loss_per_word, accuracy)`` averaged over non-padding target
            tokens; the loss is computed without label smoothing.
        """
        model.eval()

        total_loss = 0
        n_word_total = 0
        n_word_correct = 0

        with torch.no_grad():
            for batch in tqdm(
                    validation_data, mininterval=2,
                    desc=' - (Validation) ', leave=False):

                # prepare data
                src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch)
                gold = tgt_seq[:, 1:]

                # forward
                pred = model(src_seq, src_pos, tgt_seq, tgt_pos)
                loss, n_correct = cal_performance(pred, gold, smoothing=False)

                # note keeping
                total_loss += loss.item()

                non_pad_mask = gold.ne(Constants.PAD)
                n_word = non_pad_mask.sum().item()
                n_word_total += n_word
                n_word_correct += n_correct

        loss_per_word = total_loss/n_word_total
        accuracy = n_word_correct/n_word_total
        return loss_per_word, accuracy


    def train(model, training_data, validation_data, optimizer, device, opt):
        """Train for ``opt.epoch`` epochs, validating, checkpointing and logging.

        Checkpoints are written when ``opt.save_model`` is set: one per epoch
        for ``save_mode == 'all'``, or only on a new best validation accuracy
        for ``'best'``. CSV logs are written when ``opt.log`` is set.
        """
        log_train_file = None
        log_valid_file = None

        if opt.log:
            log_train_file = opt.log + '.train.log'
            log_valid_file = opt.log + '.valid.log'

            print('[Info] Training performance will be written to file: {} and {}'.format(
                log_train_file, log_valid_file))

            with open(log_train_file, 'w') as log_tf, open(log_valid_file, 'w') as log_vf:
                log_tf.write('epoch,loss,ppl,accuracy\n')
                log_vf.write('epoch,loss,ppl,accuracy\n')

        valid_accus = []
        for epoch_i in range(opt.epoch):
            print('[ Epoch', epoch_i, ']')

            start = time.time()
            train_loss, train_accu = train_epoch(
                model, training_data, optimizer, device, smoothing=opt.label_smoothing)
            # Perplexity = exp(loss per word); the loss is capped at 100 so
            # math.exp cannot overflow.
            print(' - (Training) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '
                  'elapse: {elapse:3.3f} min'.format(
                      ppl=math.exp(min(train_loss, 100)), accu=100*train_accu,
                      elapse=(time.time()-start)/60))

            start = time.time()
            valid_loss, valid_accu = eval_epoch(model, validation_data, device)
            print(' - (Validation) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '
                  'elapse: {elapse:3.3f} min'.format(
                      ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu,
                      elapse=(time.time()-start)/60))

            valid_accus += [valid_accu]

            model_state_dict = model.state_dict()
            checkpoint = {
                'model': model_state_dict,
                'settings': opt,
                'epoch': epoch_i}

            if opt.save_model:
                if opt.save_mode == 'all':
                    model_name = opt.save_model + '_accu_{accu:3.3f}.chkpt'.format(accu=100*valid_accu)
                    torch.save(checkpoint, model_name)
                elif opt.save_mode == 'best':
                    model_name = opt.save_model + '.chkpt'
                    if valid_accu >= max(valid_accus):
                        torch.save(checkpoint, model_name)
                        print(' - [Info] The checkpoint file has been updated.')

            if log_train_file and log_valid_file:
                with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf:
                    log_tf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
                        epoch=epoch_i, loss=train_loss,
                        ppl=math.exp(min(train_loss, 100)), accu=100*train_accu))
                    log_vf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
                        epoch=epoch_i, loss=valid_loss,
                        ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu))


    def main():
        """Parse command-line options, build the model and optimizer, and train."""
        parser = argparse.ArgumentParser()

        parser.add_argument('-data', required=True)

        parser.add_argument('-epoch', type=int, default=10)
        parser.add_argument('-batch_size', type=int, default=64)

        # The word-embedding size is not a separate flag: it is tied to
        # d_model below (opt.d_word_vec = opt.d_model).
        parser.add_argument('-d_model', type=int, default=512)
        parser.add_argument('-d_inner_hid', type=int, default=2048)
        parser.add_argument('-d_k', type=int, default=64)
        parser.add_argument('-d_v', type=int, default=64)

        parser.add_argument('-n_head', type=int, default=8)
        parser.add_argument('-n_layers', type=int, default=6)
        parser.add_argument('-n_warmup_steps', type=int, default=4000)

        parser.add_argument('-dropout', type=float, default=0.1)
        parser.add_argument('-embs_share_weight', action='store_true')
        parser.add_argument('-proj_share_weight', action='store_true')

        parser.add_argument('-log', default=None)
        parser.add_argument('-save_model', default=None)
        parser.add_argument('-save_mode', type=str, choices=['all', 'best'], default='best')

        parser.add_argument('-no_cuda', action='store_true')

        parser.add_argument('-label_smoothing', action='store_true')

        opt = parser.parse_args()
        opt.cuda = not opt.no_cuda
        if opt.cuda and not torch.cuda.is_available():
            # Without this, .to(torch.device('cuda')) fails on CPU-only machines.
            print('[Warning] CUDA is not available, falling back to CPU.')
            opt.cuda = False
        opt.d_word_vec = opt.d_model

        #========= Loading Dataset =========#
        data = _torch_load_trusted(opt.data)
        opt.max_token_seq_len = data['settings'].max_token_seq_len

        training_data, validation_data = prepare_dataloaders(data, opt)

        opt.src_vocab_size = training_data.dataset.src_vocab_size
        opt.tgt_vocab_size = training_data.dataset.tgt_vocab_size

        #========= Preparing Model =========#
        if opt.embs_share_weight:
            assert training_data.dataset.src_word2idx == training_data.dataset.tgt_word2idx, \
                'The src/tgt word2idx table are different but asked to share word embedding.'

        print(opt)

        device = torch.device('cuda' if opt.cuda else 'cpu')
        transformer = Transformer(
            opt.src_vocab_size,
            opt.tgt_vocab_size,
            opt.max_token_seq_len,
            tgt_emb_prj_weight_sharing=opt.proj_share_weight,
            emb_src_tgt_weight_sharing=opt.embs_share_weight,
            d_k=opt.d_k,
            d_v=opt.d_v,
            d_model=opt.d_model,
            d_word_vec=opt.d_word_vec,
            d_inner=opt.d_inner_hid,
            n_layers=opt.n_layers,
            n_head=opt.n_head,
            dropout=opt.dropout).to(device)

        # Adam settings (betas, eps) and warm-up schedule as in the paper.
        optimizer = ScheduledOptim(
            optim.Adam(
                filter(lambda x: x.requires_grad, transformer.parameters()),
                betas=(0.9, 0.98), eps=1e-09),
            opt.d_model, opt.n_warmup_steps)

        train(transformer, training_data, validation_data, optimizer, device, opt)


    def prepare_dataloaders(data, opt):
        """Build the shuffled training loader and the validation loader."""
        # ========= Preparing DataLoader =========#
        train_loader = torch.utils.data.DataLoader(
            TranslationDataset(
                src_word2idx=data['dict']['src'],
                tgt_word2idx=data['dict']['tgt'],
                src_insts=data['train']['src'],
                tgt_insts=data['train']['tgt']),
            num_workers=2,
            batch_size=opt.batch_size,
            collate_fn=paired_collate_fn,
            shuffle=True)

        valid_loader = torch.utils.data.DataLoader(
            TranslationDataset(
                src_word2idx=data['dict']['src'],
                tgt_word2idx=data['dict']['tgt'],
                src_insts=data['valid']['src'],
                tgt_insts=data['valid']['tgt']),
            num_workers=2,
            batch_size=opt.batch_size,
            collate_fn=paired_collate_fn)
        return train_loader, valid_loader


    if __name__ == '__main__':
        main()

    模型训练好后,需要对其进行测试。测试阶段的数据加载与训练大致相同,区别在于使用 Translator 通过束搜索(beam search)生成译文,代码如下所示:

    """Translate a source-language file with a trained Transformer model.

    Each line of ``-src`` is tokenised with the preprocessing settings stored
    in ``-vocab``, decoded by ``Translator`` (beam search, see ``-beam_size``),
    and every returned hypothesis is written to ``-output`` as one line.
    """
    import argparse
    import inspect

    import torch
    import torch.utils.data
    from tqdm import tqdm

    from dataset import collate_fn, TranslationDataset
    from transformer.Translator import Translator
    from preprocess import read_instances_from_file, convert_instance_to_idx_seq


    def _torch_load_trusted(path):
        """Load a pickled file produced by this project's own scripts.

        PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``, which
        rejects non-tensor objects such as the ``'settings'`` entry (accessed by
        attribute below). Full unpickling is therefore requested explicitly.
        Releases without the ``weights_only`` parameter already unpickle fully.

        SECURITY: full unpickling can execute arbitrary code -- only load files
        you generated yourself.
        """
        if 'weights_only' in inspect.signature(torch.load).parameters:
            return torch.load(path, weights_only=False)
        return torch.load(path)


    def main():
        """Parse options, build the test loader and write the translations."""
        parser = argparse.ArgumentParser(description='translate.py')

        parser.add_argument('-model', required=True,
                            help='Path to the trained model checkpoint file')
        parser.add_argument('-src', required=True,
                            help='Source sequence to decode (one line per sequence)')
        parser.add_argument('-vocab', required=True,
                            help='Path to the preprocessed data file holding the vocabularies and settings')
        parser.add_argument('-output', default='pred.txt',
                            help='Path to output the predictions (each line will be the decoded sequence)')
        parser.add_argument('-beam_size', type=int, default=5,
                            help='Beam size')
        parser.add_argument('-batch_size', type=int, default=30,
                            help='Batch size')
        parser.add_argument('-n_best', type=int, default=1,
                            help='Number of best decoded sentences to output for each source sentence')
        parser.add_argument('-no_cuda', action='store_true')

        opt = parser.parse_args()
        opt.cuda = not opt.no_cuda
        if opt.cuda and not torch.cuda.is_available():
            # Avoid a hard failure on CPU-only machines when -no_cuda was omitted.
            print('[Warning] CUDA is not available, falling back to CPU.')
            opt.cuda = False

        # Prepare DataLoader
        preprocess_data = _torch_load_trusted(opt.vocab)
        preprocess_settings = preprocess_data['settings']
        test_src_word_insts = read_instances_from_file(
            opt.src,
            preprocess_settings.max_word_seq_len,
            preprocess_settings.keep_case)
        test_src_insts = convert_instance_to_idx_seq(
            test_src_word_insts, preprocess_data['dict']['src'])

        test_loader = torch.utils.data.DataLoader(
            TranslationDataset(
                src_word2idx=preprocess_data['dict']['src'],
                tgt_word2idx=preprocess_data['dict']['tgt'],
                src_insts=test_src_insts),
            num_workers=2,
            batch_size=opt.batch_size,
            collate_fn=collate_fn)

        # NOTE(review): Translator presumably loads the checkpoint at opt.model
        # itself; on PyTorch >= 2.6 that torch.load call may also need
        # weights_only=False -- confirm in transformer/Translator.py.
        translator = Translator(opt)

        # UTF-8 so non-ASCII target words are written identically on every platform.
        with open(opt.output, 'w', encoding='utf-8') as f:
            for batch in tqdm(test_loader, mininterval=2, desc=' - (Test)', leave=False):
                all_hyp, all_scores = translator.translate_batch(*batch)
                for idx_seqs in all_hyp:
                    for idx_seq in idx_seqs:
                        pred_line = ' '.join([test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])
                        f.write(pred_line + '\n')
        print('[Info] Finished.')


    if __name__ == "__main__":
        main()

    首次运行该代码,效果并不太理想,可能与随机种子的设置有关,也可能是代码未完全按照论文的逻辑实现,需要进一步调试。

    最新回复(0)