Source code for app

"""
Functions for Streamlit applications
"""

import os
import yaml
from typing import Generator

import google.generativeai as genai
import streamlit as st

from store import VectorStore


def make_rag_prompt(user_query: str, passages: list[str]) -> str:
    """Generate the RAG prompt.

    Parameters
    ----------
    user_query : str
        The user query.
    passages : list[str]
        List of relevant passages obtained from the index.

    Returns
    -------
    prompt : str
        The prompt to pass to the LLM API.
    """
    escaped_passages = "\n\n".join(
        passage.replace("\r", " ").replace("\n", " ") for passage in passages
    )
    prompt = (
        "You are a helpful and informative bot that answers questions using text from "
        "the reference passages included below. Be sure to respond in a complete sentence, "
        "being comprehensive, including all relevant background information. However, you "
        "are talking to a non-technical audience, so be sure to break down complicated "
        "concepts and strike a friendly and conversational tone. If the passages are "
        "irrelevant to the answer, you may ignore them.\n\n"
        f"QUESTION: {user_query}\n\n"
        f"PASSAGES: {escaped_passages}\n\n"
        "ANSWER:\n"
    )
    return prompt
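
# Illustrative usage (not part of the app flow; the question and passages below
# are made up):
#
#   prompt = make_rag_prompt(
#       "Who does Alice meet at the tea party?",
#       ["The March Hare and the Hatter were having tea...", "..."],
#   )
#
# The returned string embeds the question under "QUESTION:" and the
# newline-stripped passages under "PASSAGES:", and ends with "ANSWER:\n".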


def call_llm(user_query: str) -> Generator[str, None, None]:
    """Retrieve the context for the user query, build the prompt, and call the
    LLM API to generate the response.

    Parameters
    ----------
    user_query : str
        The user query.

    Returns
    -------
    llm_response : Generator[str, None, None]
        A generator yielding the chunks of the live LLM response.
    """
    passages = st.session_state["vs"].query(user_query)
    prompt = make_rag_prompt(user_query, passages)
    response = st.session_state["llm"].generate_content(prompt, stream=True)
    for chunk in response:
        yield chunk.text
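
# Note: call_llm() relies on st.session_state["vs"] (the vector store) and
# st.session_state["llm"] (the selected Gemini model), both of which are
# initialised by the UI code below; its chunks are streamed to the page via
# st.write_stream.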


# UI is heavily inspired from
# https://docs.streamlit.io/develop/tutorials/llms/build-conversational-apps
st.session_state["prompt_deactivated"] = False
if not (GEMINI_API_TOKEN := os.getenv("GEMINI_API_TOKEN")):
    st.error(
        "GEMINI_API_TOKEN environment variable must be provided when running the container.",
        icon="🚨",
    )
    st.session_state["prompt_deactivated"] = True
genai.configure(api_key=GEMINI_API_TOKEN)

hide_decoration_bar_style = """
    <style>
        header {visibility: hidden;}
    </style>
"""
st.markdown(hide_decoration_bar_style, unsafe_allow_html=True)

st.title("Alice RAG")

# setup the session state
DEFAULT_MODEL = "Gemini 1.5 Flash"
if "messages" not in st.session_state:
    st.session_state.messages = []
if "list_of_models" not in st.session_state:
    with open(os.path.join(os.path.dirname(__file__), "config.yaml"), "r") as file:
        st.session_state["list_of_models"] = yaml.safe_load(file)
if "model" not in st.session_state:
    st.session_state["model"] = DEFAULT_MODEL
if "vs" not in st.session_state:
    st.session_state["vs"] = VectorStore()
    st.session_state["vs"].load()

with st.chat_message("assistant"):
    st.markdown(
        "Hello there! I recently read *Alice's Adventures in Wonderland*, what a great book! "
        "I can answer any questions you might have about it. "
        "Please enter your question in the prompt area below."
    )

for message in st.session_state["messages"]:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input(
    "Prompt",
    disabled=st.session_state["prompt_deactivated"],
):
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state["messages"].append({"role": "user", "content": prompt})

    with st.chat_message("assistant"):
        stream = call_llm(prompt)
        response = st.write_stream(stream)
    st.session_state["messages"].append({"role": "assistant", "content": response})

with st.sidebar:
    models = [k for k in st.session_state["list_of_models"]]
    st.session_state["model"] = st.radio(
        "Available LLMs:",
        models,
        index=models.index(DEFAULT_MODEL),
        captions=[
            m["description"] for m in st.session_state["list_of_models"].values()
        ],
    )
    # Initialize the selected model
    name = st.session_state["list_of_models"][st.session_state["model"]]["model_name"]
    st.session_state["llm"] = genai.GenerativeModel(model_name=name)

    if st.button("Clear Chat"):
        st.session_state["messages"] = []
        st.rerun()
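
# Expected shape of config.yaml, inferred from the lookups above (the
# description/model_name values shown here are hypothetical):
#
#   Gemini 1.5 Flash:
#     description: "Fast, general-purpose model"
#     model_name: "gemini-1.5-flash"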