From aac821b148c6c0d35b940609dc9b0ddcb053b28e Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 19 Jun 2024 02:07:06 +0200 Subject: Still wip on rewrite --- rag/generator/cohere.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'rag/generator/cohere.py') diff --git a/rag/generator/cohere.py b/rag/generator/cohere.py index f30fe69..575452f 100644 --- a/rag/generator/cohere.py +++ b/rag/generator/cohere.py @@ -5,10 +5,10 @@ from typing import Any, Generator, List import cohere from loguru import logger as log -from rag.rag import Message +from rag.message import Message +from rag.retriever.vector import Document from .abstract import AbstractGenerator -from .prompt import Prompt class Cohere(metaclass=AbstractGenerator): @@ -16,13 +16,13 @@ class Cohere(metaclass=AbstractGenerator): self.client = cohere.Client(os.environ["COHERE_API_KEY"]) def generate( - self, prompt: Prompt, messages: List[Message] + self, messages: List[Message], documents: List[Document] ) -> Generator[Any, Any, Any]: log.debug("Generating answer from cohere...") for event in self.client.chat_stream( - message=prompt.to_str(), - documents=[asdict(d) for d in prompt.documents], - chat_history=[m.as_dict() for m in messages], + message=messages[-1].content, + documents=[asdict(d) for d in documents], + chat_history=[m.as_dict() for m in messages[:-1]], prompt_truncation="AUTO", ): if event.event_type == "text-generation": -- cgit v1.2.3-70-g09d2