12,675
社区成员
发帖
与我相关
我的任务
分享昇思MindSpore2024年技术帖分享大会圆满结束!全年收获80+高质量技术帖, 2025年全新升级,推出“2025年昇思干货小卖部,你投我就收!”,活动继续每月征集技术帖。本期技术文章由社区开发者breeze输出并投稿。如果您对活动感兴趣,欢迎在昇思论坛投稿。
这篇文章记录了把Llama 7B从PyTorch/HF生态迁到 MindSpore的过程。不是广告,不是评测,也不是哲学讨论,就是扎扎实实的技术活,踩过的坑都摊开说。
目标:在Ascend上用MindSpore跑通Llama(推理 + 微调),尽量少魔改,支持KV Cache、RoPE、混合精度和断点恢复。
限制:不依赖奇怪分支;只用公开可得的接口(MindSpore 基座 + 常见组件)。
策略:能复用的就复用(Tokenizer、权重),不能复用的就写一个薄转换层。不追求一步到位,但要“能打”。
MindSpore 两种模式:GRAPH_MODE(编译图)和 PYNATIVE_MODE(动态图)。在Ascend上尽量用 GRAPH,性能差一大截不是开玩笑的。
import mindspore as ms
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
# 可选:减少首次编译抖动
ms.set_context(jit_config={"jit_level": "O2"}) # 视版本而定
混合精度推荐O2,配合loss scale(训练阶段):
from mindspore.amp import auto_mixed_precision, StaticLossScaler
net = build_llama() # 你自己的 Llama Cell
auto_mixed_precision(net, "O2") # 权重/计算多落到 fp16/bf16
loss_scaler = StaticLossScaler(2**12)
踩坑 1:MindSpore对Ascend的算子融合比较激进,图模式下某些自定义Python控制流容易被“优化没了”。遇到莫名其妙的数值波动,先关掉你新加的“聪明”控制流。
Tokenizer:我直接复用 HF 的 tokenizer.json和 tokenizer.model,在数据前处理阶段完成编码解码。训练/推理时只给 MindSpore 喂 input_ids和 attention_mask(注意 mask 的 dtype 和 shape)。
RoPE(Rotary Embedding):MindSpore 里实现 RoPE 时,位置索引的广播维度和角度表(cos/sin)缓存要提前考虑到 prefill+decode两阶段。
简化做法:预缓存最大max_seq_len的cos/sin;decode阶段按 pos_offset索引切片。
def precompute_rope(theta_base, head_dim, max_len, dtype=ms.float16):
inv_freq = 1.0 / (theta_base ** (ms.numpy.arange(0, head_dim, 2, dtype=ms.float32) / head_dim))
t = ms.numpy.arange(max_len, dtype=ms.float32)
freqs = ms.numpy.einsum('n,d->nd', t, inv_freq)
cos = ms.numpy.cos(freqs).astype(dtype)
sin = ms.numpy.sin(freqs).astype(dtype)
return cos, sin
def apply_rope(q, k, cos, sin, pos): # q/k: [bs, n_head, seq, head_dim]
cos_t = cos[pos] # [seq, head_dim/2]
sin_t = sin[pos]
# 扩维到 [bs, n_head, seq, head_dim/2]
for _ in range(2): # 简单粗暴两次 expand
cos_t = ms.ops.expand_dims(cos_t, 0)
sin_t = ms.ops.expand_dims(sin_t, 0)
cos_t = ms.ops.expand_dims(cos_t, 0)
sin_t = ms.ops.expand_dims(sin_t, 0)
q1, q2 = q[..., ::2], q[..., 1::2]
k1, k2 = k[..., ::2], k[..., 1::2]
q_rot = ms.ops.stack([q1 * cos_t - q2 * sin_t, q1 * sin_t + q2 * cos_t], axis=-1).reshape(q.shape)
k_rot = ms.ops.stack([k1 * cos_t - k2 * sin_t, k1 * sin_t + k2 * cos_t], axis=-1).reshape(k.shape)
return q_rot, k_rot
踩坑 2:有的实现把cos/sin的 layout 写反了;decode 阶段pos要累加(pos_offset += 1),别反复从0开始。
HuggingFace 的 Llama 权重是多个pytorch_model-*.bin。思路:用torch.load拿 state_dict,做键名映射,再 mindspore.save_checkpoint。
1、键名映射表(示例)
HuggingFace(常见) → MindSpore(示例命名):
model.embed_tokens.weight → tok_embeddings.embedding_table
model.layers.{i}.self_attn.q_proj.weight → blocks.{i}.attn.wq.weight
model.layers.{i}.self_attn.k_proj.weight → blocks.{i}.attn.wk.weight
model.layers.{i}.self_attn.v_proj.weight → blocks.{i}.attn.wv.weight
model.layers.{i}.self_attn.o_proj.weight → blocks.{i}.attn.wo.weight
model.layers.{i}.mlp.gate_proj.weight → blocks.{i}.mlp.w1.weight
model.layers.{i}.mlp.up_proj.weight → blocks.{i}.mlp.w3.weight
model.layers.{i}.mlp.down_proj.weight → blocks.{i}.mlp.w2.weight
model.layers.{i}.input_layernorm.weight → blocks.{i}.ln1.gamma
model.layers.{i}.post_attention_layernorm.weight → blocks.{i}.ln2.gamma
lm_head.weight → lm_head.weight
model.norm.weight → final_norm.gamma
2、转换脚本(最小可用)
import os, torch, mindspore as ms
from mindspore import save_checkpoint, Tensor
def map_key(hf_key: str):
key = hf_key
key = key.replace("model.embed_tokens.weight", "tok_embeddings.embedding_table")
key = key.replace("model.norm.weight", "final_norm.gamma")
key = key.replace("lm_head.weight", "lm_head.weight")
key = key.replace("model.layers.", "blocks.")
key = key.replace(".self_attn.q_proj.", ".attn.wq.")
key = key.replace(".self_attn.k_proj.", ".attn.wk.")
key = key.replace(".self_attn.v_proj.", ".attn.wv.")
key = key.replace(".self_attn.o_proj.", ".attn.wo.")
key = key.replace(".mlp.gate_proj.", ".mlp.w1.")
key = key.replace(".mlp.down_proj.", ".mlp.w2.")
key = key.replace(".mlp.up_proj.", ".mlp.w3.")
key = key.replace(".input_layernorm.weight", ".ln1.gamma")
key = key.replace(".post_attention_layernorm.weight", ".ln2.gamma")
return key
def torch_to_mindspore_ckpt(hf_dir, ms_ckpt_path, dtype=ms.float16):
# 1) 收集所有 shard
sd = {}
for name in sorted(os.listdir(hf_dir)):
if name.startswith("pytorch_model-") and name.endswith(".bin"):
part = torch.load(os.path.join(hf_dir, name), map_location="cpu")
sd.update(part)
elif name == "pytorch_model.bin":
sd.update(torch.load(os.path.join(hf_dir, name), map_location="cpu"))
# 2) 键名映射 + 类型转换
ms_params = []
for k, v in sd.items():
ms_k = map_key(k)
if "rope.freqs" in ms_k:
continue
np_v = v.numpy()
ms_params.append({"name": ms_k, "data": Tensor(np_v).astype(dtype)})
save_checkpoint(ms_params, ms_ckpt_path)
print(f"Saved MindSpore ckpt to: {ms_ckpt_path}")
# 用法:
# torch_to_mindspore_ckpt("/path/to/llama-hf", "llama7b_ms.ckpt", dtype=ms.float16)
踩坑 3:LayerNorm 在 Llama 是无 bias,MindSpore 里如果你 LayerNorm 定义带 beta,要么删掉,要么初始为 0 并在图里不使用;否则数值会“飘”。
1、Attention mask 语义
训练:通常是 [bs, 1, seq, seq]或 [bs, seq]的下三角 + padding mask。
推理:prefill 阶段 mask 仍按下三角;decode 阶段仅对新 token 做与历史的点积,mask 形状变小。
建议统一为 floatmask,填充不可见位置为 -1e4(或和你 softmax 实现一致的 -inf),避免 dtype 乱战。
2、简化版 KV Cache
class KvCache:
def __init__(self, n_layer, n_head, max_batch, max_len, head_dim, dtype=ms.float16):
self.k = [ms.numpy.zeros((max_batch, n_head, max_len, head_dim), dtype=dtype) for _ in range(n_layer)]
self.v = [ms.numpy.zeros((max_batch, n_head, max_len, head_dim), dtype=dtype) for _ in range(n_layer)]
self.pos = 0 # 当前 decode 写入位置
def update(self, layer_idx, k_new, v_new): # [bs, head, 1, dim]
p = self.pos
self.k[layer_idx][:, :, p:p+1, :] = k_new
self.v[layer_idx][:, :, p:p+1, :] = v_new
def step(self):
self.pos += 1
踩坑 4:别在 decode 阶段每步都 concat,就地写入slice,Ascend 的内存移动不白嫖。
06 训练与微调(LoRA/全参)
LoRA在 MindSpore 的一个常见实现:给线性层包一个 A/B 低秩旁路,前向时加上 x @ A @ B * alpha/r。
建议把 LoRA 的参数单独分组,禁用 weight decay;并只在 target 模块(q_proj, v_proj, o_proj, w1/w3)上挂。
def wrap_lora(linear, r=16, alpha=32):
in_f, out_f = linear.in_channels, linear.out_channels
A = ms.Parameter(ms.ops.zeros((in_f, r), ms.float16))
B = ms.Parameter(ms.ops.zeros((r, out_f), ms.float16))
scale = alpha / r
def forward(x):
base = linear(x)
lora = ms.ops.matmul(ms.ops.matmul(x, A), B) * scale
return base + lora
linear.forward = forward
return linear
踩坑 5:MindSpore Graph 下如果你“猴子补丁”forward,要确保图能稳定跟住;更稳的做法是写一个 LoraLinear(Cell)包起来。
GRAPH_MODE + O2 混合精度:不解释。
大 batch prefill:把多条输入拼长些,prefill 吞吐会好不少(当然别 OOM)。
KV Cache 扁平化:把 [bs, head, t, dim]按设备最友好的内存布局摆放(这块我没深挖,简单就地 slice 已经够用)。
避免 Python 回环:decode loop 尽量把张量操作留在图里,减少 host 参与。
检查算子降级:图编译日志里搜 “fallback/host” 之类关键词,别让关键算子跑到 CPU 端。
import mindspore as ms
from mindspore import Tensor
import numpy as np
net = build_llama_from_ckpt("llama7b_ms.ckpt") # 你的加载逻辑
net.set_train(False)
# 假装我们已经有 tokenizer
prompt_ids = np.array([[1, 42, 123, 456]]) # <s> ...
attn_mask = np.ones_like(prompt_ids)
# prefill
logits, cache, pos = net(Tensor(prompt_ids, ms.int32), Tensor(attn_mask, ms.float32), cache=None, pos=0)
# decode N 步
generated = []
x = np.array([[50256]]) # 假设上一步采样出的 token
for _ in range(32):
l, cache, pos = net(Tensor(x, ms.int32), Tensor(np.ones_like(x)), cache=cache, pos=pos)
next_id = int(ms.ops.argmax(l[0, -1], axis=-1).asnumpy())
generated.append(next_id)
x = np.array([[next_id]])
踩坑 6:很多人把 pos写死,导致 RoPE 永远用到第 0 行,性能和数值全飞。prefill 后 pos 应等于上下文长度,decode 逐步 +1。
Shape 不一致:尤其 attention_mask,MindSpore 的广播规则和你在 PyTorch 的“侥幸成功”未必一致,显式 reshape保命。
LayerNorm gamma/beta:权重名映射遗漏,或 beta 多出来。
溢出:fp16 的 matmul 穿了,loss scale 或者切到 bf16。
图编译卡慢:第一次长一些正常,第二次还慢,看看是否每次都在重建图(输入 shape 乱飘)。
迁 Llama 到 MindSpore 没有想象中那么可怕,难点集中在键名映射、RoPE 位移、KV Cache 写法三件事。
一旦跑通,Ascend 上的吞吐和能效都挺能打。别追求一步封神,先上一个“能打”的版本,再迭代优化。
最后,再次提醒自己:少写骚代码,别给图编译添堵。有时候“朴素写法”反而更快更稳(这点我已经被现实教育过两次,脸疼)。