# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: 2016-2024 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0
import torch


class ChatBotModel:
    def __init__(self):
        """
        Chat using AI generation
        """
        self.history = []

    def reset_chat(self):
        """
        Reset the chat by clearing the history
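
        A minimal usage sketch (the earlier turns are assumed to come from
        previous calls to chat()):

        :Example:
        ::

            from pythainlp.chat import ChatBotModel

            chatbot = ChatBotModel()
            # ...load a model and chat for a few turns first...
            chatbot.reset_chat()
            print(chatbot.history)
            # output: []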
"""
self.history = []

    def load_model(
        self,
        model_name: str = "wangchanglm",
        return_dict: bool = True,
        load_in_8bit: bool = False,
        device: str = "cuda",
        torch_dtype=torch.float16,
        offload_folder: str = "./",
        low_cpu_mem_usage: bool = True,
    ):
"""
Load model
:param str model_name: Model name (Now, we support wangchanglm only)
:param bool return_dict: return_dict
:param bool load_in_8bit: load model in 8bit
:param str device: device (cpu, cuda or other)
:param torch_dtype torch_dtype: torch_dtype
:param str offload_folder: offload folder
:param bool low_cpu_mem_usage: low cpu mem usage
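
        A minimal usage sketch (loading on CPU with bfloat16 weights so no GPU
        is required; expect the 7.5B-parameter model to need substantial RAM
        and download time):

        :Example:
        ::

            from pythainlp.chat import ChatBotModel
            import torch

            chatbot = ChatBotModel()
            # Load WangChanGLM on CPU with bfloat16 weights
            chatbot.load_model(device="cpu", torch_dtype=torch.bfloat16)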
"""
if model_name == "wangchanglm":
from pythainlp.generate.wangchanglm import WangChanGLM
self.model = WangChanGLM()
self.model.load_model(
model_path="pythainlp/wangchanglm-7.5B-sft-en-sharded",
return_dict=return_dict,
load_in_8bit=load_in_8bit,
offload_folder=offload_folder,
device=device,
torch_dtype=torch_dtype,
low_cpu_mem_usage=low_cpu_mem_usage,
)
else:
raise NotImplementedError(f"We doesn't support {model_name}.")

    def chat(self, text: str) -> str:
        """
        Chat with the chatbot

        :param str text: message to send to the chatbot
        :return: the chatbot's answer
        :rtype: str

        :Example:
        ::

            from pythainlp.chat import ChatBotModel
            import torch

            chatbot = ChatBotModel()
            chatbot.load_model(device="cpu", torch_dtype=torch.bfloat16)

            print(chatbot.chat("สวัสดี"))  # "hello"
            # output: ยินดีที่ได้รู้จัก ("nice to meet you")

            print(chatbot.history)
            # output: [('สวัสดี', 'ยินดีที่ได้รู้จัก')]
        """
_temp = ""
if self.history:
for h, b in self.history:
_temp += (
self.model.PROMPT_DICT["prompt_chatbot"].format_map(
{"human": h, "bot": b}
)
+ self.model.stop_token
)
_temp += self.model.PROMPT_DICT["prompt_chatbot"].format_map(
{"human": text, "bot": ""}
)
_bot = self.model.gen_instruct(_temp)
self.history.append((text, _bot))
return _bot