Source code for pythainlp.augment.lm.wangchanberta

# -*- coding: utf-8 -*-
from transformers import (
    CamembertTokenizer,
    pipeline,
)
import random
from typing import List

model_name = "airesearch/wangchanberta-base-att-spm-uncased"


[docs]class Thai2transformersAug: def __init__(self): self.model_name = "airesearch/wangchanberta-base-att-spm-uncased" self.target_tokenizer = CamembertTokenizer self.tokenizer = CamembertTokenizer.from_pretrained( self.model_name, revision='main') self.tokenizer.additional_special_tokens = [ '<s>NOTUSED', '</s>NOTUSED', '<_>' ] self.fill_mask = pipeline( task='fill-mask', tokenizer=self.tokenizer, model=f'{self.model_name}', revision='main' ) self.MASK_TOKEN = self.tokenizer.mask_token def generate(self, sentence: str, num_replace_tokens: int = 3): self.sent2 = [] self.input_text = sentence sent = [ i for i in self.tokenizer.tokenize(self.input_text) if i != '▁' ] if len(sent) < num_replace_tokens: num_replace_tokens = len(sent) masked_text = self.input_text for i in range(num_replace_tokens): replace_token = [ sent.pop(random.randrange(len(sent))) for _ in range(1) ][0] masked_text = masked_text+self.MASK_TOKEN self.sent2 += [ str(j['sequence']).replace('<s> ', '').replace('</s>', '') for j in self.fill_mask(masked_text) if j['sequence'] not in self.sent2 ] masked_text = self.input_text return self.sent2
[docs] def augment( self, sentence: str, num_replace_tokens: int = 3 ) -> List[str]: """ Text Augment from wangchanberta :param str sentence: thai sentence :param int num_replace_tokens: number replace tokens :return: list of text augment :rtype: List[str] :Example: :: from pythainlp.augment.lm import Thai2transformersAug aug=Thai2transformersAug() aug.augment("ช้างมีทั้งหมด 50 ตัว บน") # output: ['ช้างมีทั้งหมด 50 ตัว บนโลกใบนี้', 'ช้างมีทั้งหมด 50 ตัว บนสุด', 'ช้างมีทั้งหมด 50 ตัว บนบก', 'ช้างมีทั้งหมด 50 ตัว บนนั้น', 'ช้างมีทั้งหมด 50 ตัว บนหัว'] """ self.sent2 = [] self.sent2 = self.generate(sentence, num_replace_tokens) return self.sent2