Source code for quackamollie.model_manager.ollama.ollama_model_manager

# -*- coding: utf-8 -*-
__all__ = ["OllamaQuackamollieModelManager"]
__author__ = "QuacktorAI"
__copyright__ = "Copyright 2024, Forge of Absurd Ducks"
__credits__ = ["QuacktorAI"]

import aiohttp
import json

from quackamollie.core.cli.settings import get_settings_from_context, QuackamollieSettings
from quackamollie.core.database.model import ChatMessage
from quackamollie.core.enum.model_family_icon import ModelFamilyIcon
from quackamollie.core.enum.user_type import UserType
from quackamollie.core.meta.model.meta_quackamollie_model import MetaQuackamollieModel
from quackamollie.core.meta.model_manager.meta_quackamollie_model_manager import MetaQuackamollieModelManager
from typing import AsyncIterable, Dict, List, Optional, Tuple, Type


class OllamaQuackamollieModelManager(MetaQuackamollieModelManager, MetaQuackamollieModel):
    """ Model manager managed by the `QuackamollieModelManagerRegistry` and serving models from the Ollama API """

    families: List[ModelFamilyIcon] = [ModelFamilyIcon.LLAMA]
    _raw_model_list: None | List = None

    @classmethod
    async def get_raw_model_list(cls) -> List:
        """ Get and cache the raw list of models exposed by the Ollama API at `/api/tags`

            :return: The raw list of models as returned by the Ollama API, or an empty list on HTTP error
            :rtype: List
        """
        if cls._raw_model_list is None:
            quackamollie_settings: QuackamollieSettings = get_settings_from_context()
            url = f"{quackamollie_settings.ollama_base_url}/api/tags"
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    if response.status == 200:
                        data = await response.json()
                        cls._raw_model_list = data["models"]
                    else:
                        cls._raw_model_list = []
        return cls._raw_model_list
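
    # Hedged note (illustrative addition, not in the original source): the parsing in
    # `get_raw_model_list` and the methods below assumes the usual shape of Ollama's
    # `GET /api/tags` response, abridged here with hypothetical values:
    #
    #   {
    #     "models": [
    #       {
    #         "name": "llama3:latest",
    #         "details": {"family": "llama", "families": ["llama"], ...}
    #       }
    #     ]
    #   }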

    @classmethod
    async def get_model_list(cls) -> List[str]:
        """ Discover the models available for the model manager at runtime asynchronously

            :return: A list of available models for the model manager
            :rtype: List[str]
        """
        if cls._model_list is None:
            raw_model_list = await cls.get_raw_model_list()
            cls._model_list = [model["name"] for model in raw_model_list]
        return cls._model_list

    @classmethod
    async def get_model_families(cls) -> Dict[str, List[ModelFamilyIcon]]:
        """ Discover the model families available for the model manager at runtime asynchronously

            :return: A dict with values the list of families indexed by model name
            :rtype: Dict[str, List[ModelFamilyIcon]]
        """
        if cls._model_families is None:
            raw_model_list = await cls.get_raw_model_list()
            cls._model_families = {}
            for model in raw_model_list:
                model_name = model["name"]
                model_families: List[ModelFamilyIcon] = [ModelFamilyIcon.DEFAULT]
                if model["details"]["families"]:
                    model_icon = {"llama": ModelFamilyIcon.LLAMA, "clip": ModelFamilyIcon.MULTIMODAL}
                    try:
                        model_families = [model_icon[family] for family in model["details"]["families"]]
                    except KeyError:
                        # Use a default value when the key is not found
                        model_families = [ModelFamilyIcon.DEFAULT]
                cls._model_families[model_name] = model_families
        return cls._model_families
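
    # For illustration (hypothetical model names): a multimodal model whose details
    # list both "llama" and "clip" families would map to
    #   {"llava:latest": [ModelFamilyIcon.LLAMA, ModelFamilyIcon.MULTIMODAL]}
    # while a model with a family outside `model_icon` falls back to [ModelFamilyIcon.DEFAULT].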

    @classmethod
    def parse_chat_history(cls, chat_messages: Optional[List[ChatMessage]]) -> List[Dict]:
        """ Parse the chat history, given as a list of `ChatMessage` from the database model, to a list
            compatible with the model manager's models

            :param chat_messages: A list of `ChatMessage` from the database model
            :type chat_messages: Optional[List[ChatMessage]]

            :return: A list of messages formatted to be compatible with the model manager's models
            :rtype: List[Dict]
        """
        chat_history: List[Dict] = []

        # Construct the list of messages in a format supported by Ollama
        if chat_messages:
            for past_msg in chat_messages:
                chat_history.append({
                    "role": past_msg.user.user_type.value,
                    "content": past_msg.content,
                    "images": []
                })
        return chat_history
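
    # For illustration (hypothetical values, assuming `UserType.user.value == "user"` and
    # an assistant member with value "assistant"): two stored messages would be parsed
    # into an Ollama-compatible history such as
    #   [{"role": "user", "content": "Hello", "images": []},
    #    {"role": "assistant", "content": "Hi there!", "images": []}]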

    @classmethod
    async def get_model_class(cls, model_name: str) -> Optional[Type[MetaQuackamollieModel]]:
        """ Get the model class from the model name

            :param model_name: Name of the model as listed by `cls.get_model_list`
            :type model_name: str

            :return: A subclass of MetaQuackamollieModel
            :rtype: Optional[Type[MetaQuackamollieModel]]
        """
        if model_name in await cls.get_model_list():
            return cls
        else:
            return None

    @classmethod
    def reset(cls):
        """ Reset the model manager dynamic fields to force reloading models. Be careful if used asynchronously. """
        cls._raw_model_list = None
        cls._model_list = None
        cls._model_families = None

    @classmethod
    async def astream_answer(cls, content: str, chat_history: List, model_config: Optional[str] = None,
                             model_name: Optional[str] = None, images_base64: Optional[List[str]] = None,
                             **kwargs) -> AsyncIterable[Tuple[str, bool]]:
        """ Asynchronous iterator to stream the answer from an LLM

            :param content: The new message content
            :type content: str

            :param chat_history: A list of past messages formatted accordingly by the model manager
            :type chat_history: List

            :param model_config: Additional configuration given as a string through CLI or Telegram `App Settings`
                                 and retrieved from the database
            :type model_config: Optional[str]

            :param model_name: Name of the model as listed by `cls.get_model_list`
            :type model_name: Optional[str]

            :param images_base64: A list of images formatted as base64 strings
            :type images_base64: Optional[List[str]]

            :param kwargs: Additional streaming arguments
            :type kwargs: kwargs

            :return: An asynchronous iterator giving a tuple containing the new chunk and a boolean
                     indicating whether the model is done or not
            :rtype: AsyncIterable[Tuple[str, bool]]
        """
        # Add the current message to the list of past messages and finish building the payload
        chat_history.append(dict(role=UserType.user.value, content=content,
                                 images=[] if images_base64 is None else images_base64))

        # Build the payload for the Ollama aiohttp request
        payload = {
            "model": model_name,
            "messages": chat_history,
            "stream": True,
        }

        # Build the Ollama API URL from the config
        quackamollie_settings: QuackamollieSettings = get_settings_from_context()
        ollama_api_url = f"{quackamollie_settings.ollama_base_url}/api/chat"

        # Iterate through the stream of answers given by Ollama from the API
        async with aiohttp.ClientSession() as session:
            async with session.post(ollama_api_url, json=payload) as response:
                async for chunk in response.content:
                    if chunk:
                        decoded_chunk = chunk.decode()
                        if decoded_chunk.strip():
                            # Parse the JSON of the chunk to get the new content and `is_done`
                            stream_data = json.loads(decoded_chunk)
                            msg = stream_data.get("message", None)
                            new_chunk = "" if msg is None else msg.get("content", "")
                            is_done = stream_data.get("done", False)
                            yield new_chunk, is_done
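

# --- Hedged usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how a caller might consume `astream_answer`. It assumes an
# active Quackamollie settings context, so `get_settings_from_context()` resolves,
# and that a model named "llama3" (hypothetical) is available in the local Ollama.

async def _example_stream_answer() -> str:
    """ Illustrative helper only: accumulate the streamed chunks of a single answer """
    chat_history = OllamaQuackamollieModelManager.parse_chat_history(None)
    answer = ""
    async for new_chunk, is_done in OllamaQuackamollieModelManager.astream_answer(
            "Why is the sky blue?", chat_history, model_name="llama3"):
        answer += new_chunk  # Chunks arrive incrementally until `is_done` is True
        if is_done:
            break
    return answer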