Source code for pythainlp.wangchanberta.core
# SPDX-FileCopyrightText: 2016-2026 PyThaiNLP Project
# SPDX-FileType: SOURCE
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import re
import warnings
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
    from transformers import (
        CamembertTokenizer,
        PreTrainedModel,
        PreTrainedTokenizerBase,
    )
    from transformers.pipelines import TokenClassificationPipeline
from pythainlp.tokenize import word_tokenize
_model_name: str = "wangchanberta-base-att-spm-uncased"
_tokenizer: Optional["CamembertTokenizer"] = None


def _get_tokenizer() -> CamembertTokenizer:
    """Get the tokenizer, initializing it if necessary."""
    global _tokenizer
    if _tokenizer is None:
        from transformers import CamembertTokenizer

        _tokenizer = CamembertTokenizer.from_pretrained(
            f"airesearch/{_model_name}", revision="main"
        )
        if _model_name == "wangchanberta-base-att-spm-uncased":
            _tokenizer.additional_special_tokens = [
                "<s>NOTUSED",
                "</s>NOTUSED",
                "<_>",
            ]
    return _tokenizer
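

# Illustrative sketch (assumption: example added for documentation, not
# part of the original module). It shows the lazy-loading behaviour of
# _get_tokenizer(): the first call fetches the CamembertTokenizer from the
# Hugging Face Hub and caches it in the module-level _tokenizer global;
# later calls return the same cached instance.
def _example_tokenizer_cache() -> None:
    first = _get_tokenizer()   # loads (or downloads) the tokenizer
    second = _get_tokenizer()  # reuses the cached instance
    assert first is second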


class ThaiNameTagger:
    dataset_name: str
    grouped_entities: bool
    classify_tokens: TokenClassificationPipeline
    json_ner: list[dict[str, str]]
    output: str
    sent_ner: list[tuple[str, str]]
    def __init__(
        self, dataset_name: str = "thainer", grouped_entities: bool = True
    ) -> None:
        """Create a tagger that marks named entities in Thai text in IOB
        format. Powered by wangchanberta from the VISTEC-depa AI Research
        Institute of Thailand.

        :param str dataset_name:
            * *thainer* - ThaiNER dataset
        :param bool grouped_entities: if True, merge consecutive tokens of
            the same entity into a single group
        """
        from transformers import pipeline

        self.dataset_name = dataset_name
        self.grouped_entities = grouped_entities
        self.classify_tokens = pipeline(
            task="ner",
            tokenizer=_get_tokenizer(),
            model=f"airesearch/{_model_name}",
            revision=f"finetuned@{self.dataset_name}-ner",
            ignore_labels=[],
            grouped_entities=self.grouped_entities,
        )

    def _IOB(self, tag: str) -> str:
        """Convert a bare entity tag to IOB form (e.g. PERSON -> B-PERSON)."""
        if tag != "O":
            return "B-" + tag
        return "O"

    def _clear_tag(self, tag: str) -> str:
        """Strip the IOB prefix from a tag (e.g. B-PERSON -> PERSON)."""
        return tag.replace("B-", "").replace("I-", "")
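
    # Worked example (assumption: comments added for documentation, not in
    # the original module). get_ner() below uses these helpers to merge
    # consecutive entity spans: given the pair sequence
    #   [("ก", "B-DATE"), ("ข", "B-DATE")]
    # _clear_tag("B-DATE") equals _clear_tag("B-DATE"), so the second tag
    # is rewritten to "I-DATE", yielding [("ก", "B-DATE"), ("ข", "I-DATE")].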

    def get_ner(
        self, text: str, pos: bool = False, tag: bool = False
    ) -> Union[list[tuple[str, str]], str]:
        """Tag named entities in Thai text in IOB format. Powered by
        wangchanberta from the VISTEC-depa AI Research Institute of
        Thailand.

        :param str text: text in Thai to be tagged
        :param bool pos: not supported; this model does not produce POS tags
        :param bool tag: output HTML-like tags
        :return: an HTML-like tagged string if the parameter `tag` is
            `True`; otherwise, a list of tuples of tokenized words (or
            word groups) and their NER tags
        :rtype: Union[list[tuple[str, str]], str]
        """
        if pos:
            warnings.warn(
                "This model does not support POS tagging;"
                " no POS tags will be returned.",
                stacklevel=2,
            )
        text = re.sub(" ", "<_>", text)
        self.json_ner: list[dict[str, str]] = self.classify_tokens(text)
        self.output: str = ""
        if self.grouped_entities and self.dataset_name == "thainer":
            self.sent_ner: list[tuple[str, str]] = [
                (
                    i["word"].replace("<_>", " ").replace("▁", ""),
                    self._IOB(i["entity_group"]),
                )
                for i in self.json_ner
            ]
        elif self.dataset_name == "thainer":
            self.sent_ner = [
                (i["word"].replace("<_>", " ").replace("▁", ""), i["entity"])
                for i in self.json_ner
                if i["word"] != "▁"
            ]
        else:
            self.sent_ner = [
                (
                    i["word"].replace("<_>", " ").replace("▁", ""),
                    i["entity"].replace("_", "-").replace("E-", "I-"),
                )
                for i in self.json_ner
            ]
        # Drop a leading empty token left over from subword cleanup.
        if (
            self.sent_ner
            and self.sent_ner[0][0] == ""
            and len(self.sent_ner) > 1
        ):
            self.sent_ner = self.sent_ner[1:]
        # Rewrite a B- tag as I- when it continues the entity type of the
        # previous token.
        for idx, (word, ner) in enumerate(self.sent_ner):
            if idx > 0 and ner.startswith("B-"):
                if self._clear_tag(ner) == self._clear_tag(
                    self.sent_ner[idx - 1][1]
                ):
                    self.sent_ner[idx] = (word, ner.replace("B-", "I-"))
        if tag:
            temp = ""
            sent = ""
            for idx, (word, ner) in enumerate(self.sent_ner):
                if ner.startswith("B-") and temp != "":
                    sent += "</" + temp + ">"
                    temp = ner[2:]
                    sent += "<" + temp + ">"
                elif ner.startswith("B-"):
                    temp = ner[2:]
                    sent += "<" + temp + ">"
                elif ner == "O" and temp != "":
                    sent += "</" + temp + ">"
                    temp = ""
                sent += word
                if idx == len(self.sent_ner) - 1 and temp != "":
                    sent += "</" + temp + ">"
            return sent
        return self.sent_ner
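

# Illustrative usage sketch (assumption: example added for documentation,
# not part of the original module). Instantiating ThaiNameTagger downloads
# the fine-tuned airesearch/wangchanberta-base-att-spm-uncased NER weights
# from the Hugging Face Hub; the sample string is arbitrary Thai text and
# the exact tags depend on the model.
def _example_thai_name_tagger() -> None:
    tagger = ThaiNameTagger(dataset_name="thainer", grouped_entities=True)
    # A list of (word group, IOB tag) tuples.
    print(tagger.get_ner("ทดสอบระบบ"))
    # The same entities wrapped in HTML-like tags, e.g. <PERSON>...</PERSON>.
    print(tagger.get_ner("ทดสอบระบบ", tag=True))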


class NamedEntityRecognition:
    tokenizer: PreTrainedTokenizerBase
    model: PreTrainedModel
    def __init__(
        self, model: str = "pythainlp/thainer-corpus-v2-base-model"
    ) -> None:
        """Create a tagger that marks named entities in Thai text in IOB
        format. Powered by wangchanberta from the VISTEC-depa AI Research
        Institute of Thailand.

        :param str model: the name of a model fine-tuned from a
            wangchanberta pretrained checkpoint
        """
        from transformers import AutoModelForTokenClassification, AutoTokenizer

        self.tokenizer: PreTrainedTokenizerBase = (
            AutoTokenizer.from_pretrained(model)
        )
        self.model: PreTrainedModel = (
            AutoModelForTokenClassification.from_pretrained(model)
        )

    def _fix_span_error(
        self, words: list[int], ner: list[str]
    ) -> list[tuple[str, str]]:
        """Align decoded subword tokens with their predicted NER tags."""
        _new_tag = []
        for i, j in zip(words, ner):
            i_decoded = self.tokenizer.decode(i)
            # A whitespace-only piece cannot open an entity span.
            if i_decoded.isspace() and j.startswith("B-"):
                j = "O"
            # Skip special and empty tokens.
            if i_decoded in ("", "<s>", "</s>"):
                continue
            # Restore the space placeholder.
            if i_decoded == "<_>":
                i_decoded = " "
            _new_tag.append((i_decoded, j))
        return _new_tag
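
    # Worked example of _fix_span_error (assumption: comments added for
    # documentation, not in the original module). If the ids decode to
    # ["<s>", "ไทย", " ", "</s>"] with tags ["O", "B-LOCATION", "B-DATE",
    # "O"], then "<s>" and "</s>" are dropped, the whitespace piece loses
    # its spurious "B-DATE" tag, and the result is
    # [("ไทย", "B-LOCATION"), (" ", "O")].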

    def get_ner(
        self, text: str, pos: bool = False, tag: bool = False
    ) -> Union[list[tuple[str, str]], str]:
        """Tag named entities in Thai text in IOB format. Powered by
        wangchanberta from the VISTEC-depa AI Research Institute of
        Thailand.

        :param str text: text in Thai to be tagged
        :param bool pos: not supported; this model does not produce POS tags
        :param bool tag: output HTML-like tags
        :return: an HTML-like tagged string if the parameter `tag` is
            `True`; otherwise, a list of tuples of tokenized words and
            their NER tags
        :rtype: Union[list[tuple[str, str]], str]
        """
        import torch

        if pos:
            warnings.warn(
                "This model does not support POS tagging;"
                " no POS tags will be returned.",
                stacklevel=2,
            )
        words_token = word_tokenize(text.replace(" ", "<_>"))
        inputs = self.tokenizer(
            words_token, is_split_into_words=True, return_tensors="pt"
        )
        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]
        # Forward pass
        outputs = self.model(ids, attention_mask=mask)
        logits = outputs[0]
        predictions = torch.argmax(logits, dim=2)
        predicted_token_class = [
            self.model.config.id2label[t.item()] for t in predictions[0]
        ]
        ner_tag = self._fix_span_error(
            inputs["input_ids"][0], predicted_token_class
        )
        if tag:
            temp = ""
            sent = ""
            for idx, (word, ner) in enumerate(ner_tag):
                if ner.startswith("B-") and temp != "":
                    sent += "</" + temp + ">"
                    temp = ner[2:]
                    sent += "<" + temp + ">"
                elif ner.startswith("B-"):
                    temp = ner[2:]
                    sent += "<" + temp + ">"
                elif ner == "O" and temp != "":
                    sent += "</" + temp + ">"
                    temp = ""
                sent += word
                if idx == len(ner_tag) - 1 and temp != "":
                    sent += "</" + temp + ">"
            return sent
        return ner_tag
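

# Illustrative usage sketch (assumption: example added for documentation,
# not part of the original module). Instantiating NamedEntityRecognition
# downloads the pythainlp/thainer-corpus-v2-base-model checkpoint from the
# Hugging Face Hub; the sample string is arbitrary Thai text and the exact
# tags depend on the model.
def _example_named_entity_recognition() -> None:
    ner = NamedEntityRecognition()
    # A list of (word, IOB tag) tuples.
    print(ner.get_ner("ทดสอบระบบ"))
    # The same entities wrapped in HTML-like tags.
    print(ner.get_ner("ทดสอบระบบ", tag=True))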


def segment(text: str) -> list[str]:
    """Subword tokenization with the SentencePiece tokenizer from the
    wangchanberta model.

    :param str text: text to be tokenized
    :return: list of subwords
    :rtype: list[str]
    """
    if not text or not isinstance(text, str):
        return []

    return _get_tokenizer().tokenize(text)  # type: ignore[no-any-return]
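

# Illustrative demo (assumption: example added for documentation, not part
# of the original module). Running the file directly prints SentencePiece
# subwords for a sample Thai string; the exact pieces depend on the
# downloaded tokenizer, so none are asserted here.
if __name__ == "__main__":
    print(segment("ทดสอบระบบ"))  # e.g. a list of "▁"-prefixed subwords
    print(segment(""))           # -> [] for empty input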