Przeglądaj źródła

Upload files to 'Supportgpt'

SadhulaSaiKumar 1 rok temu
rodzic
commit
9fa948a598
1 zmienionych plików z 132 dodań i 0 usunięć
  1. 132
    0
      Supportgpt/supportgpt_mistral.py

+ 132
- 0
Supportgpt/supportgpt_mistral.py Wyświetl plik

@@ -0,0 +1,132 @@
1
# Standard library
from datetime import datetime
from functools import lru_cache

# Third-party
import langchain
from flask import Flask, jsonify, render_template, request, session
from langchain.chains import RetrievalQA
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import CTransformers
from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
11
+
12
+
13
app = Flask(__name__)


@app.route('/')
def home():
    """Serve the chat UI landing page."""
    return render_template('index.html')
20
+
21
+
22
+
23
# Path to the prebuilt FAISS index on disk (built by a separate ingestion step;
# not created by this script).
DB_FAISS_PATH = 'vectorstore/db_faiss'

# Prompt handed to the RetrievalQA chain; LangChain substitutes {context} and
# {question} at query time. The wording instructs the model to quote the
# retrieved "response" sections and to answer "I don't know." when the
# context is insufficient.
custom_prompt_template = """Given the following context and a question, generate an answer based on this context only.
    In the answer try to provide as much text as possible from "response" section in the source document context without making much changes.
    If the answer is not found in the context, kindly state "I don't know." Don't try to make up an answer.

    CONTEXT: {context}

    QUESTION: {question}


    """
35
+
36
def set_custom_prompt():
    """
    Build the PromptTemplate used by the QA retrieval chain.

    The template exposes two input variables, ``context`` and ``question``,
    which LangChain fills in per query.
    """
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=['context', 'question'],
    )
43
+
44
# Retrieval QA Chain
def retrieval_qa_chain(llm, prompt, db):
    """
    Wire an LLM, a prompt, and a FAISS store into a 'stuff' RetrievalQA chain.

    The retriever returns the top-2 most similar chunks, and the chain is
    configured to hand back the source documents alongside the answer.
    """
    retriever = db.as_retriever(search_kwargs={'k': 2})
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type='stuff',
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
53
+  
54
# Loading the model
def load_llm(model_path=r"D:\Aiproject\models\mistral-7b-instruct-v0.1.Q4_K_M.gguf"):
    """
    Load a local GGUF model through CTransformers.

    Parameters
    ----------
    model_path : str
        Filesystem path to the GGUF model file. Defaults to the local
        Mistral-7B-Instruct download, so existing zero-argument callers
        are unaffected.

    Returns
    -------
    CTransformers
        LLM configured for short answers (128 new tokens) with
        near-greedy decoding (temperature 0.01). ``gpu_layers=100``
        offloads as many layers as the hardware allows to the GPU.
    """
    return CTransformers(
        model=model_path,
        gpu_layers=100,
        config={'max_new_tokens': 128, 'temperature': 0.01},
    )
65
+
66
# QA Model Function
@lru_cache(maxsize=1)
def qa_bot():
    """
    Assemble — and memoize — the full retrieval-QA pipeline.

    Loading the sentence-transformer embeddings, the FAISS index, and the
    LLM is expensive, and the original code rebuilt all three on every
    request (``final_result`` calls this per query). ``lru_cache(maxsize=1)``
    turns that into a one-time cost; the signature and return value are
    unchanged.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name=r"D:\Aiproject\models\sentence_tranformer\all-MiniLM-L6-v2",
        model_kwargs={'device': 'cpu'})
    db = FAISS.load_local(DB_FAISS_PATH, embeddings)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    return retrieval_qa_chain(llm, qa_prompt, db)
76
+
77
# output function
def final_result(query):
    """Run *query* through the QA pipeline and return the raw chain response."""
    chain = qa_bot()
    return chain({'query': query})
82
+
83
+# Streamlit application
84
+# st.title("Medical Bot")
85
+
86
+# if "query" not in st.session_state:
87
+#     st.text("Hi, Welcome to Medical Bot. What is your query?")
88
+#     query = st.text_input("Enter your question")
89
+#     st.button("Ask")
90
+
91
+#     if query:
92
+#         st.text("Searching for relevant documents...")
93
+#         response = final_result(query)
94
+#         st.text(f"Sources: {response['source_documents']}")
95
+#         st.text(f"Answer: {response['result']}")
96
+# else:
97
+#     st.text("No query")
98
+
99
+
100
+
101
+
102
@app.route('/user_input', methods=['GET', 'POST'])
def index():
    """
    Handle chat requests from the UI.

    POST: expects form fields ``user_input`` (the question) and
    ``session_id``; responds with JSON ``{"output": <answer or null>}``.
    GET: renders the chat page.
    """
    if request.method == 'POST':
        query = request.form.get('user_input')
        # session_id is read but not yet used; kept for the client contract.
        session_id = request.form.get('session_id')
        start = datetime.now()
        print('started', start.strftime('%Y-%m-%d %H:%M:%S'))
        # Bug fix: the original called query.lower() before the `if query:`
        # guard, so a POST without 'user_input' crashed with AttributeError.
        if query:
            user_input = query.lower()
            print(user_input)
            response = final_result(user_input)
            result = response['result']
            print(result)
            end = datetime.now()
            print('Ended', end.strftime('%Y-%m-%d %H:%M:%S'))
        else:
            result = None
        return jsonify({'output': result})
    return render_template('index.html', result=None)
128
+
129
+
130
+
131
if __name__ == '__main__':
    # Bind to all interfaces so the bot is reachable from other machines;
    # debug mode disabled for deployment.
    app.run(host="0.0.0.0",port=8000,debug=False)

Ładowanie…
Anuluj
Zapisz