"""
Note:
    Using with glm-4-9b-chat-hf will require `transformers>=4.46.0".
"""
import argparse
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread

# Model identifier passed to `from_pretrained` for both the tokenizer and the model.
MODEL_PATH = "THUDM/glm-4-9b-chat-hf"


# Fixed token ids wrapped around every random prompt.
# NOTE(review): presumably GLM-4 chat-template special tokens — confirm against
# the tokenizer before reusing this script with a different model.
_PROMPT_PREFIX_IDS = [151331, 151333, 151336, 198]
_PROMPT_SUFFIX_IDS = [151337]
# Generation stops as soon as any of these ids is produced.
_EOS_TOKEN_IDS = [151329, 151336, 151338]
# Shortest prompt that still holds the fixed prefix and suffix.
_MIN_TOKEN_LEN = len(_PROMPT_PREFIX_IDS) + len(_PROMPT_SUFFIX_IDS)
# Prompt length used for the single warm-up generation.
_WARMUP_TOKEN_LEN = 20


def _random_prompt_ids(token_len, vocab_size, device):
    """Return a ``(1, token_len)`` long tensor on *device*: prefix + random ids + suffix.

    The ``token_len - _MIN_TOKEN_LEN`` random ids are drawn uniformly from
    ``[3, vocab_size - 200)``.
    """
    num_random = token_len - _MIN_TOKEN_LEN
    random_ids = torch.randint(3, vocab_size - 200, (num_random,), dtype=torch.long)
    ids = _PROMPT_PREFIX_IDS + random_ids.tolist() + _PROMPT_SUFFIX_IDS
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)


def _start_generation(model, generate_kwargs):
    """Run ``model.generate(**generate_kwargs)`` on a new, already-started thread.

    Returns ``(thread, result)``. After the thread finishes, ``result`` holds the
    generated ids under ``"output"`` or the raised exception under ``"error"``.
    On failure the streamer in ``generate_kwargs`` (if any) is ended. That way the
    consumer's ``for ... in streamer`` loop stops at once instead of blocking
    until the streamer timeout.
    """
    result = {}

    def _run():
        try:
            result["output"] = model.generate(**generate_kwargs)
        except Exception as exc:  # re-raised on the caller's thread after join()
            result["error"] = exc
            streamer = generate_kwargs.get("streamer")
            if streamer is not None:
                streamer.end()

    thread = Thread(target=_run)
    thread.start()
    return thread, result


def stress_test(token_len, n, num_gpu):
    """Measure prefill latency and decode throughput of ``MODEL_PATH``.

    Loads the tokenizer and a bfloat16 model and runs one greedy warm-up
    generation. It then runs ``n`` streamed greedy generations (up to 512 new
    tokens each) on random prompts of ``token_len`` tokens.

    Args:
        token_len: Prompt length in tokens; must be at least ``_MIN_TOKEN_LEN``
            (the fixed prefix and suffix ids).
        n: Number of measured iterations; must be at least 1.
        num_gpu: Device selector. Uses ``cuda:{num_gpu - 1}`` when CUDA is
            available and ``num_gpu > 0``, otherwise the CPU. Only one device
            is ever used.

    Returns:
        ``(times, avg_first_token_time, decode_times, avg_decode_time)``, where
        ``times`` holds the per-iteration seconds until the first streamed
        output and ``decode_times`` the per-iteration decode throughput in
        tokens/second.

    Raises:
        ValueError: If ``token_len`` or ``n`` is too small.
        Exception: Whatever ``model.generate`` raised on the worker thread.
    """
    # Validate before loading a multi-GB model.
    if token_len < _MIN_TOKEN_LEN:
        raise ValueError(f"token_len must be at least {_MIN_TOKEN_LEN}, got {token_len}")
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")

    device = torch.device(f"cuda:{num_gpu - 1}" if torch.cuda.is_available() and num_gpu > 0 else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(device).eval()

    # To run with INT4 weights instead, replace the model load above with:
    # model = AutoModelForCausalLM.from_pretrained(
    #     MODEL_PATH,
    #     trust_remote_code=True,
    #     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    #     low_cpu_mem_usage=True,
    # ).eval()

    vocab_size = tokenizer.vocab_size
    times = []
    decode_times = []

    print("Warming up...")
    warmup_ids = _random_prompt_ids(_WARMUP_TOKEN_LEN, vocab_size, device)
    with torch.no_grad():
        model.generate(
            input_ids=warmup_ids,
            # Integer all-ones mask: a single, unpadded prompt.
            attention_mask=torch.ones_like(warmup_ids),
            max_new_tokens=2048,
            do_sample=False,
            repetition_penalty=1.0,
            eos_token_id=_EOS_TOKEN_IDS,
        )
    print("Warming up complete. Starting stress test...")

    for i in range(n):
        input_ids = _random_prompt_ids(token_len, vocab_size, device)
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            timeout=36000,
            skip_prompt=True,
            skip_special_tokens=True,
        )
        generate_kwargs = {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "max_new_tokens": 512,
            "do_sample": False,
            "repetition_penalty": 1.0,
            "eos_token_id": _EOS_TOKEN_IDS,
            "streamer": streamer,
        }

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        thread, result = _start_generation(model, generate_kwargs)
        first_token_time = None
        for _chunk in streamer:
            if first_token_time is None:
                first_token_time = time.perf_counter()
        thread.join()
        end_time = time.perf_counter()

        if "error" in result:
            raise result["error"]
        if first_token_time is None:
            # Nothing was streamed: attribute the whole run to prefill.
            first_token_time = end_time

        times.append(first_token_time - start_time)
        # Count real generated tokens (output = prompt ids + new ids) rather than
        # streamer text chunks. Exclude the first token, which arrived at the
        # start of the decode interval.
        new_tokens = result["output"].shape[-1] - input_ids.shape[-1]
        decode_elapsed = end_time - first_token_time
        tokens_per_second = (new_tokens - 1) / decode_elapsed if new_tokens > 1 and decode_elapsed > 0 else 0
        decode_times.append(tokens_per_second)
        print(
            f"Iteration {i + 1}/{n} - Prefilling Time: {times[-1]:.4f} seconds - Average Decode Time: {tokens_per_second:.4f} tokens/second")

        if device.type == "cuda":
            torch.cuda.empty_cache()

    avg_first_token_time = sum(times) / n
    avg_decode_time = sum(decode_times) / n
    print(f"\nAverage First Token Time over {n} iterations: {avg_first_token_time:.4f} seconds")
    print(f"Average Decode Time per Token over {n} iterations: {avg_decode_time:.4f} tokens/second")
    return times, avg_first_token_time, decode_times, avg_decode_time


def _int_at_least(minimum):
    """Return an argparse ``type=`` converter that accepts integers >= *minimum*."""

    def _convert(text):
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
        if value < minimum:
            raise argparse.ArgumentTypeError(f"must be at least {minimum}, got {value}")
        return value

    return _convert


def main():
    """Parse command-line options and run ``stress_test`` with them.

    Invalid values are rejected by argparse before any model is loaded.
    """
    parser = argparse.ArgumentParser(description="Stress test for model inference")
    # stress_test wraps each prompt in 4 fixed prefix ids and 1 fixed suffix id,
    # so shorter prompts cannot be built.
    parser.add_argument('--token_len', type=_int_at_least(5), default=1000, help='Number of tokens for each test')
    # At least one iteration; the reported averages divide by n.
    parser.add_argument('--n', type=_int_at_least(1), default=3, help='Number of iterations for the stress test')
    parser.add_argument(
        '--num_gpu',
        type=_int_at_least(0),
        default=1,
        help='Selects CUDA device cuda:<num_gpu - 1> (the model is placed on a single GPU); '
             '0, or no CUDA available, runs on CPU',
    )
    args = parser.parse_args()

    stress_test(args.token_len, args.n, args.num_gpu)


if __name__ == "__main__":
    main()