一声棒喝,本不立文字
偏要著録,已是二义

huggingface-blog

MedQA:在 AMD ROCm 上微调临床 AI,无需 CUDA

MedQA: Fine-Tuning a Clinical AI on AMD ROCm — No CUDA Required

二〇二六年五月八日 · 英文原文

Harikrishna Sivanand Iyer和Srijan Sivaram A为lablab.ai AMD Developer Hackathon构建MedQA,在AMD MI300X/ROCm 6.1上用LoRA fine-tuning Qwen3-1.7B。项目使用MedMCQA 2,000样本、fp16、PEFT/Transformers,训练约2.2M参数,耗时约5分钟,并发布HuggingFace模型、Spaces demo和GitHub代码。

](https://huggingface.co/HK2184)

一篇完整 walkthrough,介绍如何在 AMD MI300X 上使用 LoRA 对 MedMCQA 上的 Qwen3-1.7B 进行 fine-tuning,为 lablab.ai 上的 AMD Developer Hackathon 构建。


想法

医学问答是那类风险真正很高的任务。一个在临床 MCQ 上自信地选错答案的模型,不只是错误,而是危险。与此同时,大多数开源医学 AI 工作都默认你有 NVIDIA GPU。CUDA 是默认选项,其他方案往往只是事后补充。

这个项目挑战了这一假设。

MedQA 是一个完全在 AMD hardware 上使用 ROCm 构建的 LoRA fine-tuned 临床问答模型。它接收一道医学多选题,并返回正确答案字母以及对推理过程的临床解释。整个训练 pipeline——从数据加载到 adapter 导出——都在 AMD Instinct MI300X 上运行,没有任何 CUDA 依赖。


为什么选择 AMD ROCm?

AMD Instinct MI300X 是一块很强的 hardware:单设备内置 192 GB HBM3 memory。对 LLM fine-tuning 来说,VRAM 往往是限制条件——它决定了 batch size、sequence length,以及你是否必须进行 quantize。有了 192 GB 可用显存,我们以完整 fp16 训练了带 LoRA 的 Qwen3-1.7B,没有使用任何 4-bit 或 8-bit quantization 技巧。

更重要的是,这个目标是证明 HuggingFace 生态——Transformers、PEFT、TRL、Accelerate——可以在 ROCm 上顺畅工作。事实确实如此。同一份可在 CUDA 上运行的训练代码,只需设置三个环境变量就能在 ROCm 上运行:

os.environ["ROCR_VISIBLE_DEVICES"] = "0"
os.environ["HIP_VISIBLE_DEVICES"] = "0"
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.4.2"

就这样。无需改代码。无需 custom kernels。无需 CUDA compatibility shims。


数据集:MedMCQA

MedMCQA 是一个大规模多选题数据集,来源于印度医学入学考试(AIIMS、USMLE-style)。每个样本包含:

本项目使用了 2,000 个训练样本——这是一个有意选择的小切片,用来展示有意义的 fine-tuning 可以快速完成。在 MI300X 上训练大约耗时 5 分钟。


模型:Qwen3-1.7B

基础模型是 Qwen/Qwen3-1.7B——Alibaba 最新的小规模 language model。它有 1.7 billion parameters,足够紧凑,可以低成本 fine-tune,同时也足够有能力生成连贯的临床推理。它支持 trust_remote_code=True,并且可以用 HuggingFace Transformers 顺利加载。


Prompt 格式

prompt 格式的一致性对 instruction fine-tuning 至关重要。每个训练样本和每次 inference 调用都使用相同的模板:

### Question:
{question}

### Options:
A) {opa}
B) {opb}
C) {opc}
D) {opd}

### Answer:
{answer_letter}) {answer_text}

### Explanation:
{explanation}

训练期间,模型会看到包含答案和解释的完整序列。inference 期间,我们提供直到 ### Answer:\n 为止的所有内容,让模型从这里继续补全。


使用 LoRA 训练

我们没有 fine-tune 全部 1.5 billion parameters,而是通过 PEFT library 使用 LoRA (Low-Rank Adaptation)。LoRA 会向 attention layers 中注入小型可训练的 rank-decomposition matrices,同时保持基础权重冻结。

LoRA 配置

from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 2,228,224 || all params: 1,543,901,184 || trainable%: 0.1443

模型 1.5 billion parameters 中只有 约 2.2 million 被训练。这让 memory usage 保持较低,训练也更快。

Training Arguments

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,     # effective batch size = 16
    learning_rate=2e-4,
    fp16=True,
    bf16=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    gradient_checkpointing=True,
    optim="adamw_torch",
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    report_to="none",
)

有几点值得注意:

完整训练循环

from transformers import DataCollatorForSeq2Seq, Trainer

collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    padding=True,
    pad_to_multiple_of=8,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collator,
)

trainer.train()

# Save adapter + tokenizer
model.save_pretrained("./outputs")
tokenizer.save_pretrained("./outputs")

训练完成后,./outputs 包含 LoRA adapter 权重——只有几 MB 文件,而不是完整的多 GB model checkpoint。


Inference

inference 时,我们加载基础模型,挂载 LoRA adapter,并可选择 merge 权重:

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

tokenizer = AutoTokenizer.from_pretrained("./outputs", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B",
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

model = PeftModel.from_pretrained(base_model, "./outputs")
model.eval()

生成使用 greedy decoding(do_sample=False),并配合 repetition penalty 防止模型循环输出:

def generate(prompt, model, tokenizer):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,
            temperature=1.0,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

示例输出

问题: 下列哪一项是高血压急症的一线治疗?

A) Oral amlodipine
B) IV labetalol or IV nitroprusside
C) Sublingual nifedipine
D) IM hydralazine

模型输出:

B) IV labetalol or IV nitroprusside

Explanation:
Intravenous labetalol (beta-blocker) or nitroprusside rapidly reduces blood
pressure in emergency settings. Oral agents act too slowly for hypertensive
emergencies requiring immediate BP control to prevent end-organ damage.

模型不只是输出一个字母——它会解释_为什么_,这正是它具备临床实用性的地方。


从 HuggingFace Hub 加载

fine-tuned adapter 已公开提供。你可以直接加载它,无需 clone repo:

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-1.7B", trust_remote_code=True
)

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

model = PeftModel.from_pretrained(base, "HK2184/medqa-qwen3-lora")
model = model.merge_and_unload()
model.eval()

挑战与修复

没有战斗故事的 AMD ROCm 项目是不完整的。以下是我们遇到的问题:

挑战 根因 修复
NaN loss Mixed precision 不稳定 从 bfloat16 切换到 fp16
GPU 未被检测到 缺少 ROCm env variables 设置 ROCR_VISIBLE_DEVICESHIP_VISIBLE_DEVICESHSA_OVERRIDE_GFX_VERSION
bitsandbytes 不受支持 没有 ROCm build 的 bitsandbytes 完全放弃 quantization——MI300X 有足够的 VRAM
inference 输出混乱 Tokenizer padding 配置错误 设置 pad_token = eos_token 并修正 padding_side
Trainer eval errors Transformers version 不匹配 固定 transformers>=4.40.0

bitsandbytes 问题值得单独说明:在 NVIDIA hardware 上,4-bit quantization 通常是把模型装进 memory 的_必要条件_。但在拥有 192 GB HBM3 的 MI300X 上,它根本不需要。这是实实在在的 hardware 优势——训练更干净,也没有 quantization artifacts。


结果

Metric Value
Trainable parameters ~2.2M(总量的 0.15%)
MI300X 上的训练时间 ~5 分钟
使用的数据集规模 2,000 samples
Baseline MedMCQA accuracy ~45%
Framework PyTorch + ROCm 6.1

自己试试

没有 GPU?没问题。 live Gradio demo 运行在 HuggingFace Spaces 上(CPU inference):

👉 HuggingFace Spaces 上的 Live Demo

有 AMD hardware? Clone repo 并原生运行:

git clone https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm.git
cd MedQA-Medical-AI-on-AMD-ROCm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
pip install transformers datasets peft accelerate trl gradio
python train.py   # ~5 minutes
python infer.py   # run sample questions
python app.py     # launch Gradio UI

下一步

这个项目证明了 pipeline 可行。接下来要做的是扩展规模并提升稳健性:


结论

MedQA 表明,在开源 AMD hardware 上构建一个有能力、可解释的医学 AI 不仅可行,而且相当直接。HuggingFace 生态的 ROCm 兼容性确实不错。MI300X 的 memory headroom 消除了整类工程问题。LoRA 则让 fine-tuning 一个 1.7B 模型变成了 5 分钟的工作。

如果你正在 AMD ROCm 上构建并遇到阻碍,上面的修复方法应该能为你节省数小时。如果你在构建医学 AI,那么相比单纯准确率,更重视解释这一点值得认真对待。


lablab.ai 上的 AMD Developer Hackathon 构建 · 由 AMD ROCm + HuggingFace 生态提供支持

*— Harikrishna Sivanand Iyer and Srijan Sivaram A

Image 2:2026-05-07 14-26-07 的截图

译自 huggingface-blog · 录于 二〇二六年五月八日