import streamlit as st
from huggingface_hub import InferenceClient

# Hugging Face API 토큰을 Hugging Face Secrets에서 불러옴
HF_TOKEN = st.secrets["HF_TOKEN"]

# Inference Client 설정 (GRIN-MoE 모델 사용)
client = InferenceClient(token=HF_TOKEN)

# Streamlit 페이지 설정
st.set_page_config(page_title="GRIN-MoE AI Chat", page_icon="🤖")
st.title("GRIN-MoE 모델과 대화해보세요!")

# 채팅 기록을 세션에 저장
if 'messages' not in st.session_state:
    st.session_state.messages = []

# 사용자 입력 받기
user_input = st.text_input("질문을 입력하세요:")

# 스트리밍 응답 함수
def generate_streaming_response(prompt):
    response_text = ""
    for message in client.chat_completion(
        model="microsoft/GRIN-MoE",  # 모델 이름을 명시적으로 전달
        messages=[{"role": "user", "content": prompt}],
        max_tokens=500,
        stream=True
    ):
        delta = message.choices[0].delta.content
        response_text += delta
        yield delta

# 대화 처리
if user_input:
    st.session_state.messages.append({"role": "user", "content": user_input})

    # AI 응답을 스트리밍 방식으로 보여줌
    with st.spinner('AI가 응답하는 중...'):
        response_text = ""
        for delta in generate_streaming_response(user_input):
            response_text += delta
            st.write(response_text)

        st.session_state.messages.append({"role": "assistant", "content": response_text})

# 이전 대화 기록 출력
if st.session_state.messages:
    for msg in st.session_state.messages:
        if msg["role"] == "user":
            st.write(f"**사용자:** {msg['content']}")
        else:
            st.write(f"**AI:** {msg['content']}")