설명 없음
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

supportgpt_mistral.py 4.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import langchain
  2. from langchain.document_loaders import PyPDFLoader, DirectoryLoader
  3. from langchain.prompts import PromptTemplate
  4. from langchain.embeddings import HuggingFaceEmbeddings
  5. from langchain.vectorstores import FAISS
  6. from langchain.llms import CTransformers
  7. from langchain.chains import RetrievalQA
  8. from flask import Flask, request, render_template
  9. from datetime import datetime
  10. from flask import Flask, render_template, request, jsonify, session
  11. app = Flask(__name__)
  12. @app.route('/')
  13. def home():
  14. return render_template('index.html')
# Path of the pre-built FAISS vector store loaded by qa_bot().
DB_FAISS_PATH = 'vectorstore/db_faiss'

# Prompt fed to the LLM by the retrieval QA chain; the chain fills the
# {context} and {question} slots (see set_custom_prompt below).
custom_prompt_template = """Given the following context and a question, generate an answer based on this context only.
In the answer try to provide as much text as possible from "response" section in the source document context without making much changes.
If the answer is not found in the context, kindly state "I don't know." Don't try to make up an answer.
CONTEXT: {context}
QUESTION: {question}
"""
  22. def set_custom_prompt():
  23. """
  24. Prompt template for QA retrieval for each vectorstore
  25. """
  26. prompt = PromptTemplate(template=custom_prompt_template,
  27. input_variables=['context', 'question'])
  28. return prompt
  29. #Retrieval QA Chain
  30. def retrieval_qa_chain(llm, prompt, db):
  31. qa_chain = RetrievalQA.from_chain_type(llm=llm,
  32. chain_type='stuff',
  33. retriever=db.as_retriever(search_kwargs={'k': 2}),
  34. return_source_documents=True,
  35. chain_type_kwargs={'prompt': prompt}
  36. )
  37. return qa_chain
  38. #Loading the model
  39. def load_llm():
  40. # Load the locally downloaded model here
  41. # llm = CTransformers(
  42. # model = r"C:\Aiproject\Llama-2-7B-Chat-GGML\Llama-2-7B-Chat-GGML\llama-30b.ggmlv3.q8_0.bin",
  43. # model_type="llama",
  44. # max_new_tokens = 512,
  45. # temperature = 0.5
  46. # )
  47. llm = CTransformers(model=r"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",gpu_layers=100,config={'max_new_tokens': 128, 'temperature': 0.01})
  48. return llm
  49. #QA Model Function
  50. def qa_bot():
  51. embeddings = HuggingFaceEmbeddings(model_name=r"sentence-transformers/all-MiniLM-L6-v2",
  52. model_kwargs={'device': 'cpu'})
  53. db = FAISS.load_local(DB_FAISS_PATH, embeddings)
  54. llm = load_llm()
  55. qa_prompt = set_custom_prompt()
  56. qa = retrieval_qa_chain(llm, qa_prompt, db)
  57. return qa
  58. #output function
  59. def final_result(query):
  60. qa_result = qa_bot()
  61. response = qa_result({'query': query})
  62. return response
  63. # Streamlit application
  64. # st.title("Medical Bot")
  65. # if "query" not in st.session_state:
  66. # st.text("Hi, Welcome to Medical Bot. What is your query?")
  67. # query = st.text_input("Enter your question")
  68. # st.button("Ask")
  69. # if query:
  70. # st.text("Searching for relevant documents...")
  71. # response = final_result(query)
  72. # st.text(f"Sources: {response['source_documents']}")
  73. # st.text(f"Answer: {response['result']}")
  74. # else:
  75. # st.text("No query")
  76. @app.route('/user_input', methods=['GET', 'POST'])
  77. def index():
  78. if request.method == 'POST':
  79. query = request.form.get('user_input')
  80. session_id= request.form.get('session_id')
  81. #query = request.form['query']
  82. start=datetime.now()
  83. frmtstart=start.strftime('%Y-%m-%d %H:%M:%S')
  84. print('started',frmtstart)
  85. user_input=query.lower()
  86. print(user_input)
  87. if query:
  88. #result = chain(question)['result']
  89. response = final_result(user_input)
  90. result=response['result']
  91. print(result)
  92. end=datetime.now()
  93. endtime=end.strftime('%Y-%m-%d %H:%M:%S')
  94. print('Ended',endtime)
  95. else:
  96. result = None
  97. #return render_template('index.html', output=result)
  98. return jsonify({'output': result})
  99. return render_template('index.html', result=None)
# Run the development server, reachable on all interfaces on port 8000.
if __name__ == '__main__':
    app.run(host="0.0.0.0", port=8000, debug=False)