
Commit 0c51162

Merge pull request #12 from stratosphereips/harpo_unsloth_scripts
add utils scripts for hf transformers usage
2 parents 64b5410 + cf6d5ec commit 0c51162

2 files changed: 150 additions & 0 deletions


unsloth-scripts/chat_model.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
import os

os.environ['HF_HOME'] = '/media/data/hf/'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '0'

import argparse
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextStreamer,
    BitsAndBytesConfig,
)


def load_model(model_name, device_str, quantization):
    """
    Load the model and tokenizer with optional quantization and device selection.
    """
    print(f"Loading model: {model_name}")

    # Determine device
    if device_str == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device_str)

    kwargs = {}

    # Use a safe dtype on CPU; half precision and automatic placement on GPU
    if device.type == "cpu":
        kwargs["torch_dtype"] = torch.float32
    else:
        kwargs["torch_dtype"] = torch.float16
        kwargs["device_map"] = "auto"

    if quantization == "4bit":
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",  # or "fp4" if supported
        )
        kwargs["quantization_config"] = quant_config
    elif quantization == "8bit":
        # Route 8-bit loading through BitsAndBytesConfig as well; passing
        # load_in_8bit directly to from_pretrained is deprecated in recent
        # transformers releases.
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)

    if device.type == "cpu":
        model.to(device)

    print(f"Model loaded on {device} with quantization: {quantization or 'none'}")
    return tokenizer, model, device


def chat_stream(tokenizer, model, device, max_length=512):
    """
    Interactive chat loop with streaming and optional chat template.
    """
    history = []
    print("\n>>> Interactive chat started. Type 'exit' to quit.\n")

    while True:
        user_input = input("You: ")
        if user_input.lower() in ["exit", "quit"]:
            break

        history.append({"role": "user", "content": user_input})

        # Format with the model's chat template if it defines one
        if getattr(tokenizer, "chat_template", None):
            prompt = tokenizer.apply_chat_template(
                history,
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            prompt = "\n".join([f"User: {msg['content']}" if msg["role"] == "user"
                                else f"Assistant: {msg['content']}" for msg in history])
            prompt += "\nAssistant:"

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        print("Assistant:", end=" ", flush=True)

        # max_length caps prompt and generated tokens combined
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            pad_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        )
        print()  # newline

        # Decode only the newly generated tokens so the stored reply
        # excludes the prompt
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        assistant_reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        history.append({"role": "assistant", "content": assistant_reply})


def main():
    parser = argparse.ArgumentParser(description="Chat with a Hugging Face model with optional quantization and device control.")
    parser.add_argument("model_name", type=str, help="Model ID from Hugging Face hub (e.g., mistralai/Mistral-7B-Instruct-v0.2)")
    parser.add_argument("--max_length", type=int, default=512, help="Maximum total length in tokens (prompt plus generated text)")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "cuda"], help="Device to run the model on")
    parser.add_argument("--quant", type=str, choices=["4bit", "8bit"], help="Optional quantization: 4bit or 8bit")

    args = parser.parse_args()

    tokenizer, model, device = load_model(args.model_name, args.device, args.quant)
    chat_stream(tokenizer, model, device, args.max_length)


if __name__ == "__main__":
    main()
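
Example invocation (a minimal sketch; the model ID below is the one suggested in the script's own help text, and the flags map directly onto the argparse options above):

    python chat_model.py mistralai/Mistral-7B-Instruct-v0.2 --device auto --quant 4bit --max_length 512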

unsloth-scripts/download_model.py

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
import os
import argparse

os.environ['HF_HOME'] = '/media/data/hf/'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '0'


from transformers import AutoTokenizer, AutoModel


def download_model(model_name: str, save_dir: str = "./downloaded_model"):
    """
    Downloads a model and its tokenizer from Hugging Face and saves it locally.

    Args:
        model_name (str): The model name or path from Hugging Face Hub (e.g., "bert-base-uncased").
        save_dir (str): The directory where the model and tokenizer will be saved.
    """
    os.makedirs(save_dir, exist_ok=True)

    print(f"Downloading model and tokenizer: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # AutoModel loads the bare backbone without any task head; for weights
    # meant for generation, AutoModelForCausalLM would also keep the LM head.
    model = AutoModel.from_pretrained(model_name)

    tokenizer.save_pretrained(save_dir)
    model.save_pretrained(save_dir)

    print(f"Model and tokenizer saved to: {save_dir}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download a Hugging Face transformer model and tokenizer.")
    parser.add_argument("model_name", type=str, help="The model name or path (e.g., 'bert-base-uncased')")
    parser.add_argument("--save_dir", type=str, default="./downloaded_model", help="Directory to save the model and tokenizer")
    args = parser.parse_args()

    download_model(args.model_name, args.save_dir)
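
Example invocation (a minimal sketch, reusing the model name from the script's docstring):

    python download_model.py bert-base-uncased --save_dir ./downloaded_model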
