Skip to content

Add MPS (Apple Silicon) device support#7

Open
aki916 wants to merge 7 commits into
FujitsuResearch:mainfrom
aki916:fix/mps-device-support
Open

Add MPS (Apple Silicon) device support#7
aki916 wants to merge 7 commits into
FujitsuResearch:mainfrom
aki916:fix/mps-device-support

Conversation

@aki916
Copy link
Copy Markdown

@aki916 aki916 commented Apr 5, 2026

Summary

  • Add a device utility module ( onecomp/utils/device.py ) that auto-detects the best available device (CUDA > MPS > CPU) and provides a cross-platform empty_cache() helper.
  • Replace all hardcoded torch.cuda.empty_cache() calls with the device-agnostic empty_cache() .
  • Add safe wrappers for Cholesky operations (cholesky, cholesky_inverse, cholesky_solve) that fall back to CPU on MPS, since MPS does not support these ops.
  • Change the default device parameter from cuda:0 to auto-detection.

====================================================

GPTQ 4bit

  • MPS (a) — Full CPU fallback for run_gptq : Run the entire run_gptq on CPU; only data loading and pre/post-processing use MPS.
  • MPS (b) — Minimal CPU fallback (Cholesky only) : Run run_gptq on MPS but fall back to CPU solely for torch.linalg.cholesky / torch.cholesky_solve , which are not yet implemented on the MPS backend.

TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T

env \ calibration data 8 32 128
MPS (a) Apple M3 99.4 135.1 301.6
MPS (b) Apple M3 291.6 318.5 494.3
CPU Apple M3 207.1 668.9 2527.8
RTX A6000 151.1 159.3 156.6
A100 98.3 99.2 106.8
H100 102.4 102.6 106.9
RTX PRO 6000 Blackwell 111.2 113.0 126.6

meta-llama/Llama-3.1-8B

env \ calibration data 8 32 128
MPS (a) Apple M3 1126.4 1396.4 2643.4
CPU 2188.4 6433.6 > 7200
RTX A6000 506.1 494.2 584.0
A100 295.5 307.4 345.6
H100 291.5 300.3 345.9
RTX PRO 6000 Blackwell 319.3 325.0 388.9

@FKKimura FKKimura requested a review from aki916f April 23, 2026 13:45
@FKKimura FKKimura closed this Apr 23, 2026
@FKKimura FKKimura reopened this Apr 23, 2026
@aki916 aki916 force-pushed the fix/mps-device-support branch from bd35e24 to d7edfbe Compare April 29, 2026 08:06
@aki916
Copy link
Copy Markdown
Author

aki916 commented May 13, 2026

I checked my modification with following scripts;
We should specify the environment e.g. python xxx.py --device mps

import argparse
import traceback

import torch
from onecomp import GPTQ, CalibrationConfig, ModelConfig, QEPConfig, Runner, setup_logger
from onecomp.utils import empty_cache

MODEL_ID = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
MAX_LENGTH = 128
NUM_SAMPLES = 8
PROMPT = "Fujitsu is"
MAX_NEW_TOKENS = 32


def detect_device():
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def generate_text(model, tokenizer, device, prompt=PROMPT, max_new_tokens=MAX_NEW_TOKENS) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def generate_original(model_config: ModelConfig) -> str:
    model = model_config.load_model()
    tokenizer = model_config.load_tokenizer()
    device = next(model.parameters()).device
    text = generate_text(model, tokenizer, device)
    del model, tokenizer
    empty_cache()
    return text


def generate_quantized(runner: Runner) -> str:
    model, tokenizer = runner.create_quantized_model()
    device = runner.model_config.device
    model.to(device)
    text = generate_text(model, tokenizer, device)
    model.to("cpu")
    del model, tokenizer
    empty_cache()
    return text


def print_generation_comparison(before: str, after: str) -> bool:
    print(f"  Prompt : {PROMPT}")
    print(f"  Before : {before}")
    print(f"  After  : {after}")

    if not before.startswith(PROMPT):
        print("  [WARN] Before output does not start with the prompt")
    if not after.startswith(PROMPT):
        print("  [WARN] After output does not start with the prompt")
        return False
    if len(after.strip()) <= len(PROMPT.strip()):
        print("  [WARN] After output has no generated continuation")
        return False
    if before == after:
        print("  [INFO] Before and after outputs are identical")
    else:
        print("  [INFO] Outputs differ (expected after quantization)")
    return True


def run_with_generation_check(test_name: str, device: str, run_fn) -> bool:
    print(f"[{test_name}] (device={device})")
    model_config = ModelConfig(model_id=MODEL_ID, device=device)

    print("  Generating with original model...")
    before = generate_original(model_config)

    runner = run_fn(model_config)
    print("  Generating with quantized model...")
    after = generate_quantized(runner)

    ok = print_generation_comparison(before, after)
    status = "PASSED" if ok else "FAILED (generation check)"
    print(f"[{test_name}] {status}")
    return ok


def test_gptq_only(device: str) -> bool:
    def run(model_config):
        calibration_config = CalibrationConfig(
            max_length=MAX_LENGTH,
            num_calibration_samples=NUM_SAMPLES,
        )
        runner = Runner(
            model_config=model_config,
            quantizer=GPTQ(wbits=4),
            calibration_config=calibration_config,
            qep=False,
        )
        runner.run()
        return runner

    return run_with_generation_check("Test 1: GPTQ only", device, run)


def test_gptq_with_qep(device: str) -> bool:
    def run(model_config):
        qep_config = QEPConfig(device=device)
        calibration_config = CalibrationConfig(
            max_length=MAX_LENGTH,
            num_calibration_samples=NUM_SAMPLES,
        )
        runner = Runner(
            model_config=model_config,
            quantizer=GPTQ(wbits=4),
            calibration_config=calibration_config,
            qep=True,
            qep_config=qep_config,
        )
        runner.run()
        return runner

    return run_with_generation_check("Test 2: GPTQ + QEP", device, run)


def test_auto_run(device: str) -> bool:
    def run(model_config):
        return Runner.auto_run(
            model_id=MODEL_ID,
            wbits=4.0,
            device=device,
            qep=True,
            evaluate=False,
            save_dir=None,
        )

    return run_with_generation_check("Test 3: auto_run", device, run)


def main():
    parser = argparse.ArgumentParser(description="Verify MPS device support fixes")
    parser.add_argument("--device", default=None, help="Device to test (cuda/mps/cpu)")
    args = parser.parse_args()

    device = args.device or detect_device()
    print(f"Testing device: {device}")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    print(f"MPS available: {hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()}")

    setup_logger()

    tests = [
        ("GPTQ only", test_gptq_only),
        ("GPTQ + QEP", test_gptq_with_qep),
        ("auto_run", test_auto_run),
    ]

    results = {}
    for name, test_fn in tests:
        try:
            results[name] = test_fn(device)
        except Exception:
            traceback.print_exc()
            results[name] = False
            print(f"[{name}] FAILED")

    print("\nSummary:")
    for name, passed in results.items():
        status = "PASS" if passed else "FAIL"
        print(f"  {name:20s} : {status}")


if __name__ == "__main__":
    main()

@aki916 aki916 marked this pull request as ready for review May 14, 2026 02:13
@aki916f
Copy link
Copy Markdown
Contributor

aki916f commented May 15, 2026

Thanks for your work, we're reviewing.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants