-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
retrievethenread.py
173 lines (156 loc) · 7.5 KB
/
retrievethenread.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from typing import Any, Optional
from azure.search.documents.aio import SearchClient
from azure.search.documents.models import VectorQuery
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from openai_messages_token_helper import build_messages, get_token_limit
from approaches.approach import Approach, ThoughtStep
from core.authentication import AuthenticationHelper
class RetrieveThenReadApproach(Approach):
    """
    Simple retrieve-then-read implementation, using the AI Search and OpenAI APIs directly. It first retrieves
    top documents from search, then constructs a prompt with them, and then uses OpenAI to generate a completion
    (answer) with that prompt.
    """

    # Default system prompt; callers can replace it via overrides["prompt_template"].
    system_chat_template = (
        "You are an intelligent assistant helping Contoso Inc employees with their healthcare plan questions and employee handbook questions. "
        + "Use 'you' to refer to the individual asking the questions even if they ask with 'I'. "
        + "Answer the following question using only the data provided in the sources below. "
        + "For tabular information return it as an html table. Do not return markdown format. "
        + "Each source has a name followed by colon and the actual information, always include the source name for each fact you use in the response. "
        + "If you cannot answer using the sources below, say you don't know. Use below example to answer"
    )

    # shots/sample conversation: one example user turn (question + sources) and the expected
    # assistant answer, passed to build_messages as few-shot examples.
    question = """
'What is the deductible for the employee plan for a visit to Overlake in Bellevue?'
Sources:
info1.txt: deductibles depend on whether you are in-network or out-of-network. In-network deductibles are $500 for employee and $1000 for family. Out-of-network deductibles are $1000 for employee and $2000 for family.
info2.pdf: Overlake is in-network for the employee plan.
info3.pdf: Overlake is the name of the area that includes a park and ride near Bellevue.
info4.pdf: In-network institutions include Overlake, Swedish and others in the region
"""
    answer = "In-network deductibles are $500 for employee and $1000 for family [info1.txt] and Overlake is in-network for the employee plan [info2.pdf][info4.pdf]."

    def __init__(
        self,
        *,
        search_client: SearchClient,
        auth_helper: AuthenticationHelper,
        openai_client: AsyncOpenAI,
        chatgpt_model: str,
        chatgpt_deployment: Optional[str],  # Not needed for non-Azure OpenAI
        embedding_model: str,
        embedding_deployment: Optional[str],  # Not needed for non-Azure OpenAI or for retrieval_mode="text"
        embedding_dimensions: int,
        sourcepage_field: str,
        content_field: str,
        query_language: str,
        query_speller: str,
    ):
        """
        Store the clients and configuration used by `run`.

        All arguments are keyword-only and are stored unchanged as attributes of the same name.
        `chatgpt_token_limit` is derived from `chatgpt_model` via `get_token_limit`.
        """
        self.search_client = search_client
        self.openai_client = openai_client
        self.auth_helper = auth_helper
        self.chatgpt_model = chatgpt_model
        self.chatgpt_deployment = chatgpt_deployment
        self.embedding_model = embedding_model
        self.embedding_dimensions = embedding_dimensions
        self.embedding_deployment = embedding_deployment
        self.sourcepage_field = sourcepage_field
        self.content_field = content_field
        self.query_language = query_language
        self.query_speller = query_speller
        self.chatgpt_token_limit = get_token_limit(chatgpt_model)

    async def run(
        self,
        messages: list[ChatCompletionMessageParam],
        session_state: Any = None,
        context: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """
        Answer the most recent message using documents retrieved from the search index.

        :param messages: Chat history. Only the content of the last message is used, as the search
            query and as the question put to the model.
        :param session_state: Opaque value returned unchanged under "session_state".
        :param context: Optional dict that may hold "overrides" (retrieval/generation settings) and
            "auth_claims" (passed to `build_filter`). Treated as empty when omitted.
        :return: A dict with "message" (the first choice's message from the dumped completion),
            "context" (data points plus thought steps describing the search and prompt), and
            "session_state".
        :raises ValueError: If `messages` is empty or the last message's content is not a string.
        """
        if not messages:
            raise ValueError("At least one message is required.")
        q = messages[-1]["content"]
        if not isinstance(q, str):
            raise ValueError("The most recent message content must be a string.")

        # None default (instead of `{}`) avoids a single dict object shared across every call.
        if context is None:
            context = {}
        overrides = context.get("overrides", {})
        auth_claims = context.get("auth_claims", {})

        retrieval_mode = overrides.get("retrieval_mode")
        # An unset retrieval mode (None) enables both text and vector search, same as "hybrid".
        use_text_search = retrieval_mode in ["text", "hybrid", None]
        use_vector_search = retrieval_mode in ["vectors", "hybrid", None]
        use_semantic_ranker = bool(overrides.get("semantic_ranker"))
        use_semantic_captions = bool(overrides.get("semantic_captions"))
        top = overrides.get("top", 3)
        minimum_search_score = overrides.get("minimum_search_score", 0.0)
        minimum_reranker_score = overrides.get("minimum_reranker_score", 0.0)
        search_filter = self.build_filter(overrides, auth_claims)

        # If retrieval mode includes vectors, compute an embedding for the query
        vectors: list[VectorQuery] = []
        if use_vector_search:
            vectors.append(await self.compute_text_embedding(q))

        results = await self.search(
            top,
            q,
            search_filter,
            vectors,
            use_text_search,
            use_vector_search,
            use_semantic_ranker,
            use_semantic_captions,
            minimum_search_score,
            minimum_reranker_score,
        )

        # Process results
        sources_content = self.get_sources_content(results, use_semantic_captions, use_image_citation=False)

        # Append the retrieved sources to the user's question to form the new user message.
        content = "\n".join(sources_content)
        user_content = q + "\n" + f"Sources:\n {content}"

        # Tokens reserved for the answer; build_messages gets the rest of the model's limit as its budget.
        response_token_limit = 1024
        updated_messages = build_messages(
            model=self.chatgpt_model,
            system_prompt=overrides.get("prompt_template", self.system_chat_template),
            few_shots=[{"role": "user", "content": self.question}, {"role": "assistant", "content": self.answer}],
            new_user_content=user_content,
            max_tokens=self.chatgpt_token_limit - response_token_limit,
        )

        chat_completion = (
            await self.openai_client.chat.completions.create(
                # Azure OpenAI takes the deployment name as the model name
                model=self.chatgpt_deployment if self.chatgpt_deployment else self.chatgpt_model,
                messages=updated_messages,
                temperature=overrides.get("temperature", 0.3),
                max_tokens=response_token_limit,
                n=1,
            )
        ).model_dump()

        data_points = {"text": sources_content}
        extra_info = {
            "data_points": data_points,
            "thoughts": [
                ThoughtStep(
                    "Search using user query",
                    q,
                    {
                        "use_semantic_captions": use_semantic_captions,
                        "use_semantic_ranker": use_semantic_ranker,
                        "top": top,
                        "filter": search_filter,
                        "use_vector_search": use_vector_search,
                        "use_text_search": use_text_search,
                    },
                ),
                ThoughtStep(
                    "Search results",
                    [result.serialize_for_results() for result in results],
                ),
                ThoughtStep(
                    "Prompt to generate answer",
                    [str(message) for message in updated_messages],
                    (
                        {"model": self.chatgpt_model, "deployment": self.chatgpt_deployment}
                        if self.chatgpt_deployment
                        else {"model": self.chatgpt_model}
                    ),
                ),
            ],
        }

        return {
            "message": chat_completion["choices"][0]["message"],
            "context": extra_info,
            "session_state": session_state,
        }