MedQA:在 AMD ROCm 上微调临床 AI,无需 CUDA
MedQA: Fine-Tuning a Clinical AI on AMD ROCm — No CUDA Required
Harikrishna Sivanand Iyer和Srijan Sivaram A为lablab.ai AMD Developer Hackathon构建MedQA,在AMD MI300X/ROCm 6.1上用LoRA fine-tuning Qwen3-1.7B。项目使用MedMCQA 2,000样本、fp16、PEFT/Transformers,训练约2.2M参数,耗时约5分钟,并发布HuggingFace模型、Spaces demo和GitHub代码。
](https://huggingface.co/HK2184)
一篇完整 walkthrough,介绍如何在 AMD MI300X 上使用 LoRA 对 MedMCQA 上的 Qwen3-1.7B 进行 fine-tuning,为 lablab.ai 上的 AMD Developer Hackathon 构建。
想法
医学问答是那类风险真正很高的任务。一个在临床 MCQ 上自信地选错答案的模型,不只是错误,而是危险。与此同时,大多数开源医学 AI 工作都默认你有 NVIDIA GPU。CUDA 是默认选项,其他方案往往只是事后补充。
这个项目挑战了这一假设。
MedQA 是一个完全在 AMD hardware 上使用 ROCm 构建的 LoRA fine-tuned 临床问答模型。它接收一道医学多选题,并返回正确答案字母以及对推理过程的临床解释。整个训练 pipeline——从数据加载到 adapter 导出——都在 AMD Instinct MI300X 上运行,没有任何 CUDA 依赖。
- 🤗 HuggingFace Hub 上的模型:HK2184/medqa-qwen3-lora
- 🚀 Live Demo:HuggingFace Spaces
- 💻 GitHub:MedQA-Medical-AI-on-AMD-ROCm
为什么选择 AMD ROCm?
AMD Instinct MI300X 是一块很强的 hardware:单设备内置 192 GB HBM3 memory。对 LLM fine-tuning 来说,VRAM 往往是限制条件——它决定了 batch size、sequence length,以及你是否必须进行 quantize。有了 192 GB 可用显存,我们以完整 fp16 训练了带 LoRA 的 Qwen3-1.7B,没有使用任何 4-bit 或 8-bit quantization 技巧。
更重要的是,这个目标是证明 HuggingFace 生态——Transformers、PEFT、TRL、Accelerate——可以在 ROCm 上顺畅工作。事实确实如此。同一份可在 CUDA 上运行的训练代码,只需设置三个环境变量就能在 ROCm 上运行:
os.environ["ROCR_VISIBLE_DEVICES"] = "0"
os.environ["HIP_VISIBLE_DEVICES"] = "0"
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.4.2"
就这样。无需改代码。无需 custom kernels。无需 CUDA compatibility shims。
数据集:MedMCQA
MedMCQA 是一个大规模多选题数据集,来源于印度医学入学考试(AIIMS、USMLE-style)。每个样本包含:
- 一个临床问题
- 四个答案选项(A–D)
- 正确答案索引
- 一个可选的自由文本解释(
exp字段)
本项目使用了 2,000 个训练样本——这是一个有意选择的小切片,用来展示有意义的 fine-tuning 可以快速完成。在 MI300X 上训练大约耗时 5 分钟。
模型:Qwen3-1.7B
基础模型是 Qwen/Qwen3-1.7B——Alibaba 最新的小规模 language model。它有 1.7 billion parameters,足够紧凑,可以低成本 fine-tune,同时也足够有能力生成连贯的临床推理。它支持 trust_remote_code=True,并且可以用 HuggingFace Transformers 顺利加载。
Prompt 格式
prompt 格式的一致性对 instruction fine-tuning 至关重要。每个训练样本和每次 inference 调用都使用相同的模板:
### Question:
{question}
### Options:
A) {opa}
B) {opb}
C) {opc}
D) {opd}
### Answer:
{answer_letter}) {answer_text}
### Explanation:
{explanation}
训练期间,模型会看到包含答案和解释的完整序列。inference 期间,我们提供直到 ### Answer:\n 为止的所有内容,让模型从这里继续补全。
使用 LoRA 训练
我们没有 fine-tune 全部 1.5 billion parameters,而是通过 PEFT library 使用 LoRA (Low-Rank Adaptation)。LoRA 会向 attention layers 中注入小型可训练的 rank-decomposition matrices,同时保持基础权重冻结。
LoRA 配置
from peft import LoraConfig, get_peft_model, TaskType
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8,
lora_alpha=16,
lora_dropout=0.05,
target_modules=["q_proj", "v_proj"],
bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 2,228,224 || all params: 1,543,901,184 || trainable%: 0.1443
模型 1.5 billion parameters 中只有 约 2.2 million 被训练。这让 memory usage 保持较低,训练也更快。
Training Arguments
from transformers import TrainingArguments
args = TrainingArguments(
output_dir="./outputs",
num_train_epochs=2,
per_device_train_batch_size=4,
gradient_accumulation_steps=4, # effective batch size = 16
learning_rate=2e-4,
fp16=True,
bf16=False,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
gradient_checkpointing=True,
optim="adamw_torch",
warmup_ratio=0.05,
lr_scheduler_type="cosine",
report_to="none",
)
有几点值得注意:
fp16=True, bf16=False—— 我们使用标准 fp16。在早期 bfloat16 实验中遇到了 NaN loss;切换到 fp16 后问题完全解决。gradient_checkpointing=True—— 用计算换 memory。考虑到 MI300X 有 192 GB VRAM,这并非严格必要,但对在较小 GPU 上复现是一个好习惯。gradient_accumulation_steps=4—— 物理 batch 为 4,有效 batch size 为 16。- 带 warmup 的 Cosine LR schedule —— 对短训练运行来说,比 flat schedule 收敛更平滑。
完整训练循环
from transformers import DataCollatorForSeq2Seq, Trainer
collator = DataCollatorForSeq2Seq(
tokenizer,
model=model,
padding=True,
pad_to_multiple_of=8,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_ds,
eval_dataset=val_ds,
data_collator=collator,
)
trainer.train()
# Save adapter + tokenizer
model.save_pretrained("./outputs")
tokenizer.save_pretrained("./outputs")
训练完成后,./outputs 包含 LoRA adapter 权重——只有几 MB 文件,而不是完整的多 GB model checkpoint。
Inference
inference 时,我们加载基础模型,挂载 LoRA adapter,并可选择 merge 权重:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
tokenizer = AutoTokenizer.from_pretrained("./outputs", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
base_model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-1.7B",
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)
model = PeftModel.from_pretrained(base_model, "./outputs")
model.eval()
生成使用 greedy decoding(do_sample=False),并配合 repetition penalty 防止模型循环输出:
def generate(prompt, model, tokenizer):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=200,
do_sample=False,
temperature=1.0,
repetition_penalty=1.1,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id,
)
new_tokens = output[0][inputs["input_ids"].shape[-1]:]
return tokenizer.decode(new_tokens, skip_special_tokens=True)
示例输出
问题: 下列哪一项是高血压急症的一线治疗?
A) Oral amlodipine
B) IV labetalol or IV nitroprusside
C) Sublingual nifedipine
D) IM hydralazine
模型输出:
B) IV labetalol or IV nitroprusside
Explanation:
Intravenous labetalol (beta-blocker) or nitroprusside rapidly reduces blood
pressure in emergency settings. Oral agents act too slowly for hypertensive
emergencies requiring immediate BP control to prevent end-organ damage.
模型不只是输出一个字母——它会解释_为什么_,这正是它具备临床实用性的地方。
从 HuggingFace Hub 加载
fine-tuned adapter 已公开提供。你可以直接加载它,无需 clone repo:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
tokenizer = AutoTokenizer.from_pretrained(
"Qwen/Qwen3-1.7B", trust_remote_code=True
)
base = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-1.7B",
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, "HK2184/medqa-qwen3-lora")
model = model.merge_and_unload()
model.eval()
挑战与修复
没有战斗故事的 AMD ROCm 项目是不完整的。以下是我们遇到的问题:
| 挑战 | 根因 | 修复 |
|---|---|---|
| NaN loss | Mixed precision 不稳定 | 从 bfloat16 切换到 fp16 |
| GPU 未被检测到 | 缺少 ROCm env variables | 设置 ROCR_VISIBLE_DEVICES、HIP_VISIBLE_DEVICES、HSA_OVERRIDE_GFX_VERSION |
| bitsandbytes 不受支持 | 没有 ROCm build 的 bitsandbytes | 完全放弃 quantization——MI300X 有足够的 VRAM |
| inference 输出混乱 | Tokenizer padding 配置错误 | 设置 pad_token = eos_token 并修正 padding_side |
| Trainer eval errors | Transformers version 不匹配 | 固定 transformers>=4.40.0 |
bitsandbytes 问题值得单独说明:在 NVIDIA hardware 上,4-bit quantization 通常是把模型装进 memory 的_必要条件_。但在拥有 192 GB HBM3 的 MI300X 上,它根本不需要。这是实实在在的 hardware 优势——训练更干净,也没有 quantization artifacts。
结果
| Metric | Value |
|---|---|
| Trainable parameters | ~2.2M(总量的 0.15%) |
| MI300X 上的训练时间 | ~5 分钟 |
| 使用的数据集规模 | 2,000 samples |
| Baseline MedMCQA accuracy | ~45% |
| Framework | PyTorch + ROCm 6.1 |
自己试试
没有 GPU?没问题。 live Gradio demo 运行在 HuggingFace Spaces 上(CPU inference):
👉 HuggingFace Spaces 上的 Live Demo
有 AMD hardware? Clone repo 并原生运行:
git clone https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm.git
cd MedQA-Medical-AI-on-AMD-ROCm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
pip install transformers datasets peft accelerate trl gradio
python train.py # ~5 minutes
python infer.py # run sample questions
python app.py # launch Gradio UI
下一步
这个项目证明了 pipeline 可行。接下来要做的是扩展规模并提升稳健性:
- 更大的数据集 —— 在完整 MedMCQA 语料(约 180k 道题)上训练,并加入 PubMedQA
- Confidence scoring —— 在答案旁加入校准后的置信度估计
- RAG integration —— 通过实时医学文献检索为答案提供依据
- Evaluation harness —— 在训练划分之外进行规范的 held-out accuracy benchmarking
结论
MedQA 表明,在开源 AMD hardware 上构建一个有能力、可解释的医学 AI 不仅可行,而且相当直接。HuggingFace 生态的 ROCm 兼容性确实不错。MI300X 的 memory headroom 消除了整类工程问题。LoRA 则让 fine-tuning 一个 1.7B 模型变成了 5 分钟的工作。
如果你正在 AMD ROCm 上构建并遇到阻碍,上面的修复方法应该能为你节省数小时。如果你在构建医学 AI,那么相比单纯准确率,更重视解释这一点值得认真对待。
为 lablab.ai 上的 AMD Developer Hackathon 构建 · 由 AMD ROCm + HuggingFace 生态提供支持
*— Harikrishna Sivanand Iyer and Srijan Sivaram A
