"""
Support code for sending a generation query to a remotely served Llama model.
"""
import torch
import argparse
from transformers import AutoTokenizer
from logit_client import LogitClient

# Base URL of the remote server; passed to LogitClient below.
host = "https://nerc.guha-anderson.com"

# Initialize client and tokenizer.
# NOTE: these statements run at import time, so importing this module prints
# the message below and constructs the client and tokenizer as a side effect.
print(f"Connecting to server at {host}...")
client = LogitClient(host)

# Tokenizer used by generate() to encode prompts and decode the server's reply.
# It is configured for left padding with tokenization-space cleanup disabled.
# NOTE(review): the tokenizer comes from a 3B instruct checkpoint while
# model_name below names a 1B model; presumably both share the same
# vocabulary -- confirm.
tokenizer = AutoTokenizer.from_pretrained(
    "unsloth/llama-3.2-3b-instruct-bnb-4bit",
    padding_side="left",
    clean_up_tokenization_spaces=False
)
# Use the end-of-sequence token as the padding token.
# Presumably the checkpoint ships without a dedicated pad token -- confirm.
tokenizer.pad_token = tokenizer.eos_token

# Model identifier sent to the server with every request in generate().
model_name = "Llama-3.2-1B"

def generate(prompt: str, max_new_tokens: int = 10) -> str:
    """
    Ask the remote model to continue ``prompt`` and return the decoded text.

    The prompt is tokenized locally with the module-level ``tokenizer`` and
    sent to the server through ``client`` together with ``model_name`` and
    ``max_new_tokens``. The first ``len(prompt_ids)`` entries of the server's
    reply are dropped before decoding (presumably the reply echoes the prompt
    tokens first -- confirm against LogitClient), and special tokens are
    skipped in the decoded string.
    """
    # Encode to a tensor and take the single row as the prompt's token ids.
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt")[0]

    # Despite the client's name, the reply is decoded as token ids, not logits.
    output_ids = client.get_generation(model_name, prompt_ids, max_new_tokens)

    # Keep only what follows the prompt-length prefix, then turn it into text.
    continuation_ids = output_ids[len(prompt_ids):]
    return tokenizer.decode(continuation_ids, skip_special_tokens=True)

def main():
    """
    Command-line entry point: generate text for a prompt and print it.

    Command-line arguments:
        --prompt: required prompt text to send to the model.
        --max_tokens: maximum number of tokens to generate (default 50).

    Returns:
        The generated text, which is also printed to stdout.
    """
    parser = argparse.ArgumentParser(description="Vanilla text generation")
    parser.add_argument("--prompt", required=True, help="Prompt text to send to the model")
    parser.add_argument("--max_tokens", type=int, default=50, help="Maximum tokens to generate")
    args = parser.parse_args()

    # Use this module's generate() wrapper; get_generation exists only as a
    # method on the client object and is not a module-level name.
    generated_text = generate(
        prompt=args.prompt,
        max_new_tokens=args.max_tokens
    )
    print(generated_text)
    return generated_text


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

