import requests
import json
from typing import Union, Dict, Generator
import time

class ChatCompletionTester:
    """Manual smoke test for an OpenAI-style ``/chat/completions`` HTTP endpoint.

    ``run_all_tests`` first checks that the server answers at all, then sends one
    non-streaming and one streaming request built by ``create_test_payload`` and
    prints pass/fail lines plus the elapsed time to stdout.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        timeout: Union[float, None] = 300.0,
    ):
        """Configure the target server.

        Args:
            base_url: Server root URL. A trailing slash is tolerated when the
                endpoint URL is built.
            timeout: Passed as ``timeout=`` to every ``requests`` call (seconds to
                wait for the connection and between received bytes). ``None``
                waits forever. The default is generous because a non-streaming
                completion sends nothing until the whole answer is generated.
        """
        self.base_url = base_url
        # rstrip avoids "//chat/completions" when base_url ends with "/".
        self.endpoint = f"{base_url.rstrip('/')}/chat/completions"
        self.timeout = timeout

    def create_test_payload(self, stream: bool = False) -> Dict:
        """Create a sample chat-completion request body.

        Args:
            stream: Value of the ``"stream"`` field, i.e. whether the server is
                asked to answer with Server-Sent Events.

        Returns:
            A fresh JSON-serialisable dict (system + user message, fixed model,
            temperature and max_tokens).
        """
        return {
            "model": "mistralai/Mixtral-8x22B-Instruct-v0.1",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "What is the capital of France?"}
            ],
            "temperature": 0.7,
            "max_tokens": 4096,
            "stream": stream
        }

    @staticmethod
    def _iter_stream_content(lines) -> Generator[str, None, None]:
        """Yield the text deltas carried by decoded Server-Sent-Event lines.

        Args:
            lines: Iterable of ``str`` lines from the response body. Relevant
                lines look like ``data: {json chunk}``. The stream is terminated
                by ``data: [DONE]``.

        Yields:
            Each non-empty ``choices[0].delta.content`` string, in order.

        Non-``data`` lines (blank separators, ``:`` comments, ``event:`` fields),
        malformed JSON, chunks with no ``choices`` (e.g. a usage-only chunk) and
        chunks without text content are skipped. Iteration stops at ``[DONE]``.
        """
        for line in lines:
            if not line.startswith("data:"):
                continue
            payload = line[len("data:"):]
            # The SSE format allows "data:value" and "data: value"; only a
            # single leading space is part of the field syntax.
            if payload.startswith(" "):
                payload = payload[1:]
            # "[DONE]" is not valid JSON, so it must be checked before parsing.
            if payload.strip() == "[DONE]":
                return
            try:
                chunk = json.loads(payload)
            except json.JSONDecodeError:
                continue
            if not isinstance(chunk, dict):
                continue
            choices = chunk.get("choices") or []
            if not choices or not isinstance(choices[0], dict):
                continue
            delta = choices[0].get("delta") or {}
            content = delta.get("content")
            if content:
                yield content

    def test_non_streaming(self) -> Union[Dict, None]:
        """Send a non-streaming request and print the assistant's reply.

        Returns:
            The decoded JSON response on HTTP 200, otherwise ``None``. Any
            exception (network error, timeout, unexpected response shape) is
            printed and also yields ``None``.
        """
        print("\n=== Testing Non-Streaming Response ===")
        try:
            payload = self.create_test_payload(stream=False)
            print("Sending request...")

            response = requests.post(
                self.endpoint,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=self.timeout,
            )

            if response.status_code != 200:
                print(f"Error: Status code {response.status_code}")
                print(f"Response: {response.text}")
                return None

            result = response.json()
            content = result['choices'][0]['message']['content']
            print("\nResponse received successfully!")
            print(f"Content: {content}")
            return result

        except Exception as e:
            # Top-level boundary of a smoke test: report and mark the test failed.
            print(f"Error during non-streaming test: {str(e)}")
            return None

    def test_streaming(self) -> Union[str, None]:
        """Send a streaming request, echo deltas as they arrive, and collect them.

        Returns:
            The concatenated streamed text on HTTP 200 (possibly ``""``),
            otherwise ``None``. Any exception is printed and yields ``None``.
        """
        print("\n=== Testing Streaming Response ===")
        try:
            payload = self.create_test_payload(stream=True)
            print("Sending request...")

            # ``with`` guarantees the streamed connection is released even if
            # iteration is cut short by [DONE] or an exception.
            with requests.post(
                self.endpoint,
                json=payload,
                headers={"Content-Type": "application/json"},
                stream=True,
                timeout=self.timeout,
            ) as response:
                if response.status_code != 200:
                    print(f"Error: Status code {response.status_code}")
                    print(f"Response: {response.text}")
                    return None

                print("\nReceiving streaming response:")
                # SSE bodies are UTF-8. Decode bytes explicitly because requests
                # falls back to ISO-8859-1 for text/* without a charset.
                lines = (
                    raw.decode("utf-8", errors="replace")
                    for raw in response.iter_lines()
                )
                parts = []
                for content in self._iter_stream_content(lines):
                    print(content, end="", flush=True)
                    parts.append(content)

            print("\n\nStreaming completed!")
            return "".join(parts)

        except Exception as e:
            # Top-level boundary of a smoke test: report and mark the test failed.
            print(f"Error during streaming test: {str(e)}")
            return None

    def run_all_tests(self):
        """Check server reachability, then run both tests and report timing."""
        print("Starting API endpoint tests...")

        # Any HTTP response (even an error status) proves the server is up; only
        # transport-level failures (refused, timeout, bad URL) count as down.
        try:
            requests.get(self.base_url, timeout=self.timeout)
            print("✓ Server is accessible")
        except requests.exceptions.RequestException:
            print("✗ Server is not accessible. Please ensure the FastAPI server is running.")
            return

        # perf_counter is monotonic, so the elapsed time ignores wall-clock jumps.
        start_time = time.perf_counter()

        non_streaming_result = self.test_non_streaming()
        if non_streaming_result:
            print("✓ Non-streaming test passed")
        else:
            print("✗ Non-streaming test failed")

        streaming_result = self.test_streaming()
        if streaming_result:
            print("✓ Streaming test passed")
        else:
            print("✗ Streaming test failed")

        end_time = time.perf_counter()
        print(f"\nAll tests completed in {end_time - start_time:.2f} seconds")

def main():
    """Run the full endpoint test suite against the default local server."""
    ChatCompletionTester().run_all_tests()


if __name__ == "__main__":
    # Execute the suite only when run as a script, not when imported.
    main()