# Origin: rag/llm/cohere_generator.py, added as a new file in commit
# d487ef8b04cc7f5ac1491f0638f902fe2abe5ac5 ("Wip refactor",
# Gustaf Rydholm, Mon, 8 Apr 2024 22:28:47 +0200).
"""Answer generator that streams chat events from the Cohere client."""

import os
from dataclasses import asdict
from typing import Any, Generator

import cohere
from loguru import logger as log

# The package may be imported either as ``rag.llm`` or with ``rag`` itself on
# the import path, so fall back to the shorter module path.
try:
    from rag.llm.ollama_generator import Prompt
except ModuleNotFoundError:
    from llm.ollama_generator import Prompt


class CohereGenerator:
    """Generate answers by streaming a chat request through ``cohere.Client``."""

    def __init__(self) -> None:
        """Create the Cohere client.

        Raises:
            KeyError: If the ``COHERE_API_KEY`` environment variable is unset.
        """
        api_key = os.environ["COHERE_API_KEY"]
        self.client = cohere.Client(api_key)

    def generate(self, prompt: Prompt) -> Generator[Any, Any, Any]:
        """Stream the payloads of the chat events produced for *prompt*.

        ``prompt.query`` is sent as the chat message and every entry of
        ``prompt.documents`` is converted with ``dataclasses.asdict`` and
        passed along as a document.

        Yields:
            ``event.text`` for "text-generation" events, ``event.citations``
            for "citation-generation" events and ``event.finish_reason`` for
            the "stream-end" event. Events of any other type are skipped.
        """
        # (event type, attribute holding that event's payload)
        payload_attrs = (
            ("text-generation", "text"),
            ("citation-generation", "citations"),
            ("stream-end", "finish_reason"),
        )

        log.debug("Generating answer from cohere")
        stream = self.client.chat_stream(
            message=prompt.query,
            documents=[asdict(doc) for doc in prompt.documents],
            prompt_truncation="AUTO",
        )
        for event in stream:
            kind = event.event_type
            # Compare with ``==`` (not a dict lookup) so matching does not
            # depend on how the event type hashes.
            for name, attr in payload_attrs:
                if kind == name:
                    yield getattr(event, attr)
                    break