把 Llama 迁到 MindSpore:一份带坑的实战笔记

昇思MindSpore 2025-11-24 11:21:11

昇思MindSpore2024年技术帖分享大会圆满结束!全年收获80+高质量技术帖, 2025年全新升级,推出“2025年昇思干货小卖部,你投我就收!”,活动继续每月征集技术帖。本期技术文章由社区开发者breeze输出并投稿。如果您对活动感兴趣,欢迎在昇思论坛投稿。

这篇文章记录了把Llama 7B从PyTorch/HF生态迁到 MindSpore的过程。不是广告,不是评测,也不是哲学讨论,就是扎扎实实的技术活,踩过的坑都摊开说。

01 背景和目标

  • 目标:在Ascend上用MindSpore跑通Llama(推理 + 微调),尽量少魔改,支持KV Cache、RoPE、混合精度和断点恢复。

  • 限制:不依赖奇怪分支;只用公开可得的接口(MindSpore 基座 + 常见组件)。

  • 策略:能复用的就复用(Tokenizer、权重),不能复用的就写一个薄转换层。不追求一步到位,但要“能打”。

02 环境要点

MindSpore 两种模式:GRAPH_MODE(编译图)和 PYNATIVE_MODE(动态图)。在Ascend上尽量用 GRAPH,性能差一大截不是开玩笑的。

import mindspore as ms
ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
# 可选:减少首次编译抖动
ms.set_context(jit_config={"jit_level": "O2"})  # 视版本而定

混合精度推荐O2,配合loss scale(训练阶段):

from mindspore.amp import auto_mixed_precision, StaticLossScaler

net = build_llama()              # 你自己的 Llama Cell
auto_mixed_precision(net, "O2")  # 权重/计算多落到 fp16/bf16
loss_scaler = StaticLossScaler(2**12)

踩坑 1:MindSpore对Ascend的算子融合比较激进,图模式下某些自定义Python控制流容易被“优化没了”。遇到莫名其妙的数值波动,先关掉你新加的“聪明”控制流。

03 Tokenizer 与 RoPE:别在细节上翻车

  • Tokenizer:我直接复用 HF 的 tokenizer.json和 tokenizer.model,在数据前处理阶段完成编码解码。训练/推理时只给 MindSpore 喂 input_ids和 attention_mask(注意 mask 的 dtype 和 shape)。

  • RoPE(Rotary Embedding):MindSpore 里实现 RoPE 时,位置索引的广播维度和角度表(cos/sin)缓存要提前考虑到 prefill+decode两阶段。
    简化做法:预缓存最大max_seq_len的cos/sin;decode阶段按 pos_offset索引切片。

def precompute_rope(theta_base, head_dim, max_len, dtype=ms.float16):
    inv_freq = 1.0 / (theta_base ** (ms.numpy.arange(0, head_dim, 2, dtype=ms.float32) / head_dim))
    t = ms.numpy.arange(max_len, dtype=ms.float32)
    freqs = ms.numpy.einsum('n,d->nd', t, inv_freq)
    cos = ms.numpy.cos(freqs).astype(dtype)
    sin = ms.numpy.sin(freqs).astype(dtype)
    return cos, sin

def apply_rope(q, k, cos, sin, pos):  # q/k: [bs, n_head, seq, head_dim]
    cos_t = cos[pos]  # [seq, head_dim/2]
    sin_t = sin[pos]
    # 扩维到 [bs, n_head, seq, head_dim/2]
    for _ in range(2):  # 简单粗暴两次 expand
        cos_t = ms.ops.expand_dims(cos_t, 0)
        sin_t = ms.ops.expand_dims(sin_t, 0)
    cos_t = ms.ops.expand_dims(cos_t, 0)
    sin_t = ms.ops.expand_dims(sin_t, 0)
    
    q1, q2 = q[..., ::2], q[..., 1::2]
    k1, k2 = k[..., ::2], k[..., 1::2]
    q_rot = ms.ops.stack([q1 * cos_t - q2 * sin_t, q1 * sin_t + q2 * cos_t], axis=-1).reshape(q.shape)
    k_rot = ms.ops.stack([k1 * cos_t - k2 * sin_t, k1 * sin_t + k2 * cos_t], axis=-1).reshape(k.shape)
    return q_rot, k_rot

踩坑 2:有的实现把cos/sin的 layout 写反了;decode 阶段pos要累加(pos_offset += 1),别反复从0开始。

04 权重转换:从 HuggingFace → MindSpore .ckpt

HuggingFace 的 Llama 权重是多个pytorch_model-*.bin。思路:用torch.load拿 state_dict,做键名映射,再 mindspore.save_checkpoint。

1、键名映射表(示例)

HuggingFace(常见) → MindSpore(示例命名):


model.embed_tokens.weight                        → tok_embeddings.embedding_table
model.layers.{i}.self_attn.q_proj.weight         → blocks.{i}.attn.wq.weight
model.layers.{i}.self_attn.k_proj.weight         → blocks.{i}.attn.wk.weight
model.layers.{i}.self_attn.v_proj.weight         → blocks.{i}.attn.wv.weight
model.layers.{i}.self_attn.o_proj.weight         → blocks.{i}.attn.wo.weight
model.layers.{i}.mlp.gate_proj.weight            → blocks.{i}.mlp.w1.weight
model.layers.{i}.mlp.up_proj.weight              → blocks.{i}.mlp.w3.weight
model.layers.{i}.mlp.down_proj.weight            → blocks.{i}.mlp.w2.weight
model.layers.{i}.input_layernorm.weight          → blocks.{i}.ln1.gamma
model.layers.{i}.post_attention_layernorm.weight → blocks.{i}.ln2.gamma
lm_head.weight                                   → lm_head.weight
model.norm.weight                                → final_norm.gamma

2、转换脚本(最小可用)

import os, torch, mindspore as ms
from mindspore import save_checkpoint, Tensor

def map_key(hf_key: str):
    key = hf_key
    key = key.replace("model.embed_tokens.weight", "tok_embeddings.embedding_table")
    key = key.replace("model.norm.weight", "final_norm.gamma")
    key = key.replace("lm_head.weight", "lm_head.weight")
    
    key = key.replace("model.layers.", "blocks.")
    key = key.replace(".self_attn.q_proj.", ".attn.wq.")
    key = key.replace(".self_attn.k_proj.", ".attn.wk.")
    key = key.replace(".self_attn.v_proj.", ".attn.wv.")
    key = key.replace(".self_attn.o_proj.", ".attn.wo.")
   
    key = key.replace(".mlp.gate_proj.", ".mlp.w1.")
    key = key.replace(".mlp.down_proj.", ".mlp.w2.")
    key = key.replace(".mlp.up_proj.", ".mlp.w3.")
    
    key = key.replace(".input_layernorm.weight", ".ln1.gamma")
    key = key.replace(".post_attention_layernorm.weight", ".ln2.gamma")
    return key

def torch_to_mindspore_ckpt(hf_dir, ms_ckpt_path, dtype=ms.float16):
    # 1) 收集所有 shard
    sd = {}
    for name in sorted(os.listdir(hf_dir)):
        if name.startswith("pytorch_model-") and name.endswith(".bin"):
            part = torch.load(os.path.join(hf_dir, name), map_location="cpu")
            sd.update(part)
        elif name == "pytorch_model.bin":
            sd.update(torch.load(os.path.join(hf_dir, name), map_location="cpu"))
    
    # 2) 键名映射 + 类型转换
    ms_params = []
    for k, v in sd.items():
        ms_k = map_key(k)
        if "rope.freqs" in ms_k:
            continue
        np_v = v.numpy()
        ms_params.append({"name": ms_k, "data": Tensor(np_v).astype(dtype)})
   
    save_checkpoint(ms_params, ms_ckpt_path)
    print(f"Saved MindSpore ckpt to: {ms_ckpt_path}")

# 用法:
# torch_to_mindspore_ckpt("/path/to/llama-hf", "llama7b_ms.ckpt", dtype=ms.float16)

踩坑 3:LayerNorm 在 Llama 是无 bias,MindSpore 里如果你 LayerNorm 定义带 beta,要么删掉,要么初始为 0 并在图里不使用;否则数值会“飘”。

05 Llama 前向与 KV Cache(prefill + decode)

1、Attention mask 语义

  • 训练:通常是 [bs, 1, seq, seq]或 [bs, seq]的下三角 + padding mask。

  • 推理:prefill 阶段 mask 仍按下三角;decode 阶段仅对新 token 做与历史的点积,mask 形状变小。

建议统一为 floatmask,填充不可见位置为 -1e4(或和你 softmax 实现一致的 -inf),避免 dtype 乱战。

2、简化版 KV Cache


class KvCache:
    def __init__(self, n_layer, n_head, max_batch, max_len, head_dim, dtype=ms.float16):
        self.k = [ms.numpy.zeros((max_batch, n_head, max_len, head_dim), dtype=dtype) for _ in range(n_layer)]
        self.v = [ms.numpy.zeros((max_batch, n_head, max_len, head_dim), dtype=dtype) for _ in range(n_layer)]
        self.pos = 0  # 当前 decode 写入位置
    
    def update(self, layer_idx, k_new, v_new):  # [bs, head, 1, dim]
        p = self.pos
        self.k[layer_idx][:, :, p:p+1, :] = k_new
        self.v[layer_idx][:, :, p:p+1, :] = v_new
    
    def step(self):
        self.pos += 1

踩坑 4:别在 decode 阶段每步都 concat,就地写入slice,Ascend 的内存移动不白嫖。

06 训练与微调(LoRA/全参)

LoRA在 MindSpore 的一个常见实现:给线性层包一个 A/B 低秩旁路,前向时加上 x @ A @ B * alpha/r。

建议把 LoRA 的参数单独分组,禁用 weight decay;并只在 target 模块(q_proj, v_proj, o_proj, w1/w3)上挂。

def wrap_lora(linear, r=16, alpha=32):
    in_f, out_f = linear.in_channels, linear.out_channels
    A = ms.Parameter(ms.ops.zeros((in_f, r), ms.float16))
    B = ms.Parameter(ms.ops.zeros((r, out_f), ms.float16))
    scale = alpha / r
    
    def forward(x):
        base = linear(x)
        lora = ms.ops.matmul(ms.ops.matmul(x, A), B) * scale
        return base + lora
    linear.forward = forward
    return linear

踩坑 5:MindSpore Graph 下如果你“猴子补丁”forward,要确保图能稳定跟住;更稳的做法是写一个 LoraLinear(Cell)包起来。

07 性能小记(不玄学)

  • GRAPH_MODE + O2 混合精度:不解释。

  • 大 batch prefill:把多条输入拼长些,prefill 吞吐会好不少(当然别 OOM)。

  • KV Cache 扁平化:把 [bs, head, t, dim]按设备最友好的内存布局摆放(这块我没深挖,简单就地 slice 已经够用)。

  • 避免 Python 回环:decode loop 尽量把张量操作留在图里,减少 host 参与。

  • 检查算子降级:图编译日志里搜 “fallback/host” 之类关键词,别让关键算子跑到 CPU 端。

08 端到端推理样例(极简)

import mindspore as ms
from mindspore import Tensor
import numpy as np

net = build_llama_from_ckpt("llama7b_ms.ckpt")  # 你的加载逻辑
net.set_train(False)

# 假装我们已经有 tokenizer
prompt_ids = np.array([[1, 42, 123, 456]])  # <s> ...
attn_mask  = np.ones_like(prompt_ids)

# prefill
logits, cache, pos = net(Tensor(prompt_ids, ms.int32), Tensor(attn_mask, ms.float32), cache=None, pos=0)

# decode N 步
generated = []
x = np.array([[50256]])  # 假设上一步采样出的 token
for _ in range(32):
    l, cache, pos = net(Tensor(x, ms.int32), Tensor(np.ones_like(x)), cache=cache, pos=pos)
    next_id = int(ms.ops.argmax(l[0, -1], axis=-1).asnumpy())
    generated.append(next_id)
    x = np.array([[next_id]])

踩坑 6:很多人把 pos写死,导致 RoPE 永远用到第 0 行,性能和数值全飞。prefill 后 pos 应等于上下文长度,decode 逐步 +1。

09 常见报错对照(以防手忙脚乱)

  • Shape 不一致:尤其 attention_mask,MindSpore 的广播规则和你在 PyTorch 的“侥幸成功”未必一致,显式 reshape保命。

  • LayerNorm gamma/beta:权重名映射遗漏,或 beta 多出来。

  • 溢出:fp16 的 matmul 穿了,loss scale 或者切到 bf16。

  • 图编译卡慢:第一次长一些正常,第二次还慢,看看是否每次都在重建图(输入 shape 乱飘)。

10 小结

迁 Llama 到 MindSpore 没有想象中那么可怕,难点集中在键名映射、RoPE 位移、KV Cache 写法三件事。
一旦跑通,Ascend 上的吞吐和能效都挺能打。别追求一步封神,先上一个“能打”的版本,再迭代优化。
最后,再次提醒自己:少写骚代码,别给图编译添堵。有时候“朴素写法”反而更快更稳(这点我已经被现实教育过两次,脸疼)。

...全文
31 回复 打赏 收藏 转发到动态 举报
写回复
用AI写文章
回复
切换为时间正序
请发表友善的回复…
发表回复

12,675

社区成员

发帖
与我相关
我的任务
社区描述
昇思MindSpore是一款开源的AI框架,旨在实现易开发、高效执行、全场景覆盖三大目标,这里是昇思MindSpore官方CSDN社区,可了解最新进展,也欢迎大家体验并分享经验!
深度学习人工智能机器学习 企业社区 广东省·深圳市
社区管理员
  • 昇思MindSpore
  • skytier
加入社区
  • 近7日
  • 近30日
  • 至今
社区公告

欢迎来到昇思MindSpore社区!

在这里您可以获取昇思MindSpore的技术分享和最新消息,也非常欢迎各位分享个人使用经验

无论是AI小白还是领域专家,我们都欢迎加入社区!一起成长!


【更多渠道】

试试用AI创作助手写篇文章吧