Source code for pythainlp.chat.core

# SPDX-FileCopyrightText: 2016-2026 PyThaiNLP Project
# SPDX-FileType: SOURCE
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    import torch

    from pythainlp.generate.wangchanglm import WangChanGLM


class ChatBotModel:
    history: list[tuple[str, str]]
    model: "WangChanGLM"

    def __init__(self) -> None:
        """Chat using AI generation."""
        self.history = []

    def reset_chat(self) -> None:
        """Reset the chat by clearing its history."""
        self.history = []

    def load_model(
        self,
        model_name: str = "wangchanglm",
        return_dict: bool = True,
        load_in_8bit: bool = False,
        device: str = "cuda",
        torch_dtype: Optional["torch.dtype"] = None,
        offload_folder: str = "./",
        low_cpu_mem_usage: bool = True,
    ) -> None:
        """Load a chat model.

        :param str model_name: model name (currently, only "wangchanglm"
            is supported)
        :param bool return_dict: whether the model should return a dict
        :param bool load_in_8bit: load the model in 8-bit precision
        :param str device: device to run the model on (cpu, cuda, or other)
        :param Optional[torch.dtype] torch_dtype: torch dtype for the model
            (defaults to torch.float16)
        :param str offload_folder: folder to offload model weights to
        :param bool low_cpu_mem_usage: use low CPU memory while loading
        """
        import torch

        if torch_dtype is None:
            torch_dtype = torch.float16
        if model_name == "wangchanglm":
            from pythainlp.generate.wangchanglm import WangChanGLM

            self.model = WangChanGLM()
            self.model.load_model(
                model_path="pythainlp/wangchanglm-7.5B-sft-en-sharded",
                return_dict=return_dict,
                load_in_8bit=load_in_8bit,
                offload_folder=offload_folder,
                device=device,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
        else:
            raise NotImplementedError(f"We don't support {model_name}.")

    def chat(self, text: str) -> str:
        """Chat with the chatbot.

        :param str text: text to ask the chatbot
        :return: answer from the chatbot
        :rtype: str

        :Example:
        ::

            from pythainlp.chat import ChatBotModel
            import torch

            chatbot = ChatBotModel()
            chatbot.load_model(device="cpu", torch_dtype=torch.bfloat16)

            print(chatbot.chat("สวัสดี"))
            # output: ยินดีที่ได้รู้จัก

            print(chatbot.history)
            # output: [('สวัสดี', 'ยินดีที่ได้รู้จัก')]
        """
        # Rebuild the full conversation as a single prompt: every past
        # (human, bot) turn is rendered with the chatbot prompt template
        # and terminated with the model's stop token.
        _temp = ""
        if self.history:
            for h, b in self.history:
                _temp += (
                    self.model.PROMPT_DICT["prompt_chatbot"].format_map(
                        {"human": h, "bot": b}
                    )
                    + self.model.stop_token
                )
        # Append the new question with an empty bot slot for the model to fill.
        _temp += self.model.PROMPT_DICT["prompt_chatbot"].format_map(
            {"human": text, "bot": ""}
        )
        _bot = self.model.gen_instruct(_temp)
        self.history.append((text, _bot))
        return _bot
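

if __name__ == "__main__":
    # Minimal multi-turn usage sketch (an illustration, not part of the
    # original module; assumes the wangchanglm weights can be downloaded
    # or are available locally, as in the chat() docstring example).
    import torch

    chatbot = ChatBotModel()
    chatbot.load_model(device="cpu", torch_dtype=torch.bfloat16)

    # Each turn is appended to chatbot.history, so follow-up messages are
    # answered with the whole conversation as context.
    print(chatbot.chat("สวัสดี"))
    print(chatbot.chat("ขอบคุณ"))

    # Start a fresh conversation by clearing the accumulated history.
    chatbot.reset_chat()
    print(chatbot.history)  # []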