首先利用 libtorch 库封装了一个libgotorch库,已支持最新的 libtorch2.0.1
问题一:cgo 中返回的 tensor 对象在栈上,直接使用可能会有内存安全问题
我做了一层简单的封装来使其创建到堆上,但其引发的问题是需要手动管理内存,因此我编写了 mmgr 包在每一个 tensor 对象创建的时候自动加入 mmgr 的 storage 当中,最后在每一轮训练完毕后通过 GC 方法释放堆上的 tensor 对象
问题二:windows 下的 libtorch 库通过 msvc 编译,提供的是 C++接口,无法在 mingw 中无法正常链接
解决方案是通过在封装一个动态链接库并暴露 C 语言接口,在 mingw 中即可正常链接
通过解决以上两个问题,已可以在 go 语言中使用 libtorch 库并实现自己的模型了
下面进入正题,我在 tnn 库中实现了一个小型的 GPT 模型来实现对对联:couplet,下面让我们来看一下最终效果
$ go run main.go evaluate --model model7M 晚风摇树树还挺
load embedding...
model loaded
inputs: [472 3 462 148 148 342 1516]
map[4.278747:[醉] 5.084207:[润] 8.868446:[晨]]
map[3.8447263:[花] 4.750472:[润] 8.635651:[露]]
map[5.46043:[花] 6.7003703:[露] 10.768249:[润]]
map[4.3850584:[露] 4.875666:[润] 9.896332:[花]]
map[3.6241615:[红] 5.611262:[润] 10.782802:[花]]
map[4.3855276:[花] 5.48069:[红] 9.480111:[更]]
map[3.7904112:[心] 4.269902:[花] 10.3220415:[红]]
晨露润花花更红
$ go run main.go evaluate --model model7M 投石向天跟命斗
load embedding...
model loaded
inputs: [1233 190 383 11 2623 620 490]
map[5.7068815:[门] 5.7826476:[问] 9.79136:[闭]]
map[3.0136497:[问] 3.1092193:[人] 8.903796:[门]]
map[3.021591:[还] 3.448888:[歌] 8.96453:[问]]
map[4.9368696:[地] 5.7390223:[时] 9.438878:[卷]]
map[3.5542138:[话] 3.858942:[时] 8.253393:[与]]
map[3.025545:[与] 3.2461479:[卷] 9.06726:[时]]
map[4.250452:[时] 4.712057:[舟] 10.401218:[争]]
闭门问卷与时争
注意:该模型仅训练了开源数据集couplet-dataset中的前 1 万个样本
模型的参数结构如下:
+------------------------+---------+
| NAME | COUNT |
+------------------------+---------+
| transformer0_attention | 1872 |
| transformer0_dense | 1256640 |
| transformer0_output | 1254960 |
| transformer1_attention | 1872 |
| transformer1_dense | 1256640 |
| transformer1_output | 1254960 |
| output | 2488596 |
| total | 7515540 |
+------------------------+---------+
train 200, cost=2h15m7.877395694s, loss=3.665343e-02
整个模型共有 751 万个参数,模型包含 2 个 transformer 模块,由于在训练时只使用了 8 个 float32 来对每一个字进行表征,因此 attention 层的参数量较少,其他参数配置如下:
const embeddingDim = 8 // 8 个 float32 表示一个字向量
const paddingSize = 70 // 最长为 34*2 ,因此 padding 长度必须大于 68
const heads = 4
const batchSize = 128
const epoch = 200
const lr = 0.001
const transformerSize = 2
最后让我们来看看模型的泛化能力如何
$ go run main.go evaluate --model model7M 我是谁
load embedding...
model loaded
inputs: [85 62 191]
map[4.3809786:[雨] 4.9436274:[染] 7.105626:[绿]]
map[3.8163047:[水] 4.013789:[东] 4.088595:[得]]
map[4.872726:[唱] 5.4107614:[兰] 6.3983927:[发]]
绿得发
$ go run main.go evaluate --model ./model7M 我在哪
load embedding...
model loaded
inputs: [85 99 1151]
map[1.480957:[思] 2.002811:[得] 4.0260763:[寻]]
map[3.4100764:[女] 3.868993:[对] 4.448501:[得]]
map[2.2672489:[年] 2.3772364:[历] 4.946753:[谁]]
寻得谁
效果不是很理想,可能还是跟训练的样本数量太少有关
另外还有一些示例可在 example 目录下找到,如使用 RNN 来学习如何画 sin 曲线等
最后是项目地址:
这是一个专为移动设备优化的页面(即为了让你能够在 Google 搜索结果里秒开这个页面),如果你希望参与 V2EX 社区的讨论,你可以继续到 V2EX 上打开本讨论主题的完整版本。
V2EX 是创意工作者们的社区,是一个分享自己正在做的有趣事物、交流想法,可以遇见新朋友甚至新机会的地方。
V2EX is a community of developers, designers and creative people.