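Contents
- 1 Background
- 2 Setting Up the Environment
  - 2.1 Hardware Configuration
  - 2.2 Setting Up the Virtual Environment
    - 2.2.1 Create the Virtual Environment
    - 2.2.2 Install the Required Libraries
- 3 Preparation
  - 3.1 Download the GitHub Code
  - 3.2 Download the Model
  - 3.3 Data Processing
    - 3.3.1 Download the Data
    - 3.3.2 Pretokenize the Dataset
- 4 Training
  - 4.1 Modify the Configuration
  - 4.2 Start Training
  - 4.3 Multi-Node, Multi-GPU Training
- 5 Model Inference
  - 5.1 Compilation
    - 5.1.1 Install gcc
    - 5.1.2 Modify run.c
    - 5.1.3 Build the Executable
  - 5.2 Model Inference
- 6 Closing Remarks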
1 Background
Learning large models from scratch: LLaMa2-7B.
2 Setting Up the Environment
For installing Anaconda, refer to these tutorials:
Installing conda on Linux
Installing PyTorch offline with conda
2.1 Hardware Configuration
OS: Windows 11
RAM: 48 GB
GPU: RTX 4070, 12 GB
CPU: i5-13600KF
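2.2 Setting Up the Virtual Environment
2.2.1 Create the Virtual Environment
conda create --name llama2 python=3.10
Activate it before installing anything into it:
conda activate llama2
2.2.2 Install the Required Libraries
Install torch:
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
Install the other libraries:
pip install numpy==1.23.5 pytest Requests sentencepiece tqdm wandb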
3 Preparation
3.1 Download the GitHub Code
git clone https://github.com/karpathy/llama2.c.git
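3.2 Download the Model
Download the Llama-2-7b-chat-hf model from the Hugging Face website.
From mainland China this requires a proxy; you also need to register an account and request access, and once access is granted you can download the model.
Because both the model weights and the training data come from Hugging Face, you may run into various network problems while downloading them. One workaround is to go through hf-mirror, a domestic mirror of Hugging Face, to speed up downloads, and to fetch model files and data with huggingface-cli.
For details, see other resources online; only some of the commands are given here. The example downloads NousResearch/Meta-Llama-3-8B; replace the repo ID and --local-dir with the model you actually need.
# install huggingface_hub
pip install -U huggingface_hub
# set the environment variable (bash syntax; in PowerShell use: $env:HF_ENDPOINT = "https://hf-mirror.com")
export HF_ENDPOINT=https://hf-mirror.com
# download the model, specifying where to save it
huggingface-cli download --resume-download NousResearch/Meta-Llama-3-8B --local-dir ./model/Meta-Llama-3-8B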
3.3 Data Processing
3.3.1 Download the Data
python tinystories.py download
3.3.2 Pretokenize the Dataset
python tinystories.py pretokenize
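pretokenize turns every story into token IDs and saves them next to the downloaded shards. My understanding of tinystories.py is that each story is prefixed with the BOS token and the IDs are stored as raw uint16 values (the 32000-token Llama 2 vocabulary fits in 16 bits); treat both points as assumptions. During training, a window of max_seq_len + 1 tokens yields one input sequence plus its next-token targets. A minimal C sketch of that slicing, where make_example is a hypothetical name:

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: cut one next-token-prediction example out of a
 * pretokenized uint16 token stream.
 *   x[i] = tokens[start + i]      (model input)
 *   y[i] = tokens[start + i + 1]  (token that should follow x[i])
 * Returns 0 on success, -1 if the window runs past the end of the stream. */
int make_example(const uint16_t *tokens, size_t n_tokens, size_t start,
                 size_t seq_len, int64_t *x, int64_t *y) {
    /* seq_len inputs need seq_len + 1 tokens: one extra for the last target */
    if (start + seq_len + 1 > n_tokens) {
        return -1;
    }
    for (size_t i = 0; i < seq_len; i++) {
        x[i] = tokens[start + i];
        y[i] = tokens[start + i + 1];
    }
    return 0;
}
```

The targets are the inputs shifted by one position, so a single read of seq_len + 1 tokens supplies both tensors.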
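4 Training
4.1 Modify the Configuration
Several parameters in train.py need to be changed:
batch_size: make it smaller, otherwise you will get a 'CUDA out of memory' error;
dtype: change it to "float16", otherwise you will get the error 'Current CUDA Device does not support bfloat16';
compile: set it to False, otherwise you will get an error saying the CUDA Capability is too low or that complex64 is not supported.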
batch_size = 64
dtype = "float16"
compile = False
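Optional parameters:
max_iters: the total number of training iterations; you can reduce it.
warmup_iters: the number of warm-up iterations, during which the learning rate is ramped up gradually to its peak value to keep early training stable; if GPU time is limited, you can reduce it.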
max_iters = 100000
warmup_iters = 1000
4.2 Start Training
python train.py
The parameters I used:
batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
max_iters = 100000 # total number of training iterations
warmup_iters = 1000 # how many steps to warm up for
dtype = "float16" # float32|bfloat16|float16
compile = False # use PyTorch 2.0 to compile the model to be faster
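How much data one iteration consumes depends on more than batch_size. Assuming train.py computes it as gradient_accumulation_steps × DDP world size × batch_size × max_seq_len (my reading of the script; the defaults gradient_accumulation_steps = 4 and max_seq_len = 256 are also assumptions), the arithmetic looks like this:

```c
/* Sketch of the assumed formula: tokens consumed by one optimizer step.
 *   grad_accum  - micro-batches accumulated per process before each step
 *   world_size  - number of DDP processes (1 for single-GPU training)
 *   batch_size  - sequences per micro-batch (batch_size in train.py)
 *   max_seq_len - tokens per sequence */
long long tokens_per_iter(long long grad_accum, long long world_size,
                          long long batch_size, long long max_seq_len) {
    return grad_accum * world_size * batch_size * max_seq_len;
}

/* Total training tokens over max_iters optimizer steps. */
long long total_train_tokens(long long per_iter, long long max_iters) {
    return per_iter * max_iters;
}
```

With the values above (batch_size = 128, max_iters = 100000, one GPU) and the assumed defaults, that is 131,072 tokens per iteration and about 13.1 billion tokens in total. Halving batch_size to 64 while doubling gradient_accumulation_steps to 8 keeps the same 131,072 tokens per step, but only half as many sequences sit in GPU memory at once, which is the usual way around 'CUDA out of memory' without changing the effective batch.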
Loss:
GPU memory usage:
After training finishes, the checkpoint (ckpt) is saved in the out folder:
4.3 Multi-Node, Multi-GPU Training
Training parameters can also be set on the command line (single-GPU training):
python train.py --compile=False --eval_iters=100 --batch_size=64
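Each override has the form --key=value. In llama2.c these are, as far as I can tell, applied by configurator.py, which replaces the value of an existing global variable in train.py; check that file for the exact rules. The C sketch below only illustrates how one such argument splits into a key and a value; parse_override is a hypothetical helper and performs no type conversion.

```c
#include <stddef.h>
#include <string.h>

/* Hypothetical parser for a single "--key=value" override.
 * Copies the key (without the leading "--") into key_out and points
 * *value_out at the text after '='. Returns 1 on success, 0 if the
 * argument is malformed or the key does not fit in key_cap bytes. */
int parse_override(const char *arg, char *key_out, size_t key_cap,
                   const char **value_out) {
    if (strncmp(arg, "--", 2) != 0) {
        return 0; /* must start with two dashes */
    }
    const char *key = arg + 2;
    const char *eq = strchr(key, '=');
    if (eq == NULL || eq == key) {
        return 0; /* need "key=value" with a non-empty key */
    }
    size_t key_len = (size_t)(eq - key);
    if (key_len + 1 > key_cap) {
        return 0; /* no room for the key plus its terminator */
    }
    memcpy(key_out, key, key_len);
    key_out[key_len] = '\0';
    *value_out = eq + 1; /* value text, e.g. "64" or "False" */
    return 1;
}
```

For --batch_size=64 this yields the key batch_size and the string value 64; converting the string to the right type is left to whatever consumes it.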
For multiple GPUs you can use distributed training, for example DDP on 1 node with 2 GPUs:
torchrun --standalone --nproc_per_node=2 train.py
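DDP on 2 nodes with 8 GPUs in total (4 per node). Run the command below on the master node (node_rank=0), and run the same command with --node_rank=1 on the second node; 123.456.123.456 is a placeholder for the master node's IP address: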
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
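The launch flags determine each process's place in the job. Here is a small C sketch of that bookkeeping, assuming the static launch used above (explicit --node_rank and --master_addr), where global ranks are assigned node by node. The function names are hypothetical, and the divisibility rule in per_rank_grad_accum reflects my understanding that train.py divides gradient_accumulation_steps by the world size; verify that against the script.

```c
/* Hypothetical helpers for a static torchrun launch
 * (--nnodes, --nproc_per_node, --node_rank, --master_addr). */

/* Total number of training processes (one per GPU). */
int world_size(int nnodes, int nproc_per_node) {
    return nnodes * nproc_per_node;
}

/* Global rank of the process with local rank `local_rank` on node
 * `node_rank`: nodes are filled one after another. */
int global_rank(int node_rank, int nproc_per_node, int local_rank) {
    return node_rank * nproc_per_node + local_rank;
}

/* Split the configured gradient_accumulation_steps across all ranks so
 * the tokens per optimizer step match a single-GPU run.
 * Returns -1 if the split is not exact (assumed to be rejected). */
int per_rank_grad_accum(int grad_accum, int world) {
    if (world <= 0 || grad_accum % world != 0) {
        return -1;
    }
    return grad_accum / world;
}
```

In the 2-node example, node 0 hosts ranks 0 to 3 and node 1 hosts ranks 4 to 7. Under the assumed rule, gradient_accumulation_steps has to be a multiple of the total GPU count, so a value of 4 would split cleanly over 2 GPUs but not over 8.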
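5 Model Inference
Note: this works on Linux, but on Windows the program produces no output; I have not solved the Windows problem yet.
5.1 Compilation
On Linux, just run make, or compile directly: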
gcc run.c -o run -lm
On Windows, you first need to install gcc, then modify run.c, and finally build the executable.
5.1.1 Install gcc
Download a prebuilt MinGW-w64 gcc archive and extract it locally.
Configure the environment variables:
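- After extracting, add the MinGW-w64 bin directory to the system PATH environment variable.
- Open "Control Panel", select "System and Security", then select "System".
- Click "Advanced system settings" on the left, then click "Environment Variables" in the dialog that opens.
- In the "System variables" section, find the Path variable, select it, and click "Edit".
- Click "New" and add the path of the MinGW-w64 bin directory, for example C:\Program Files\mingw-w64\mingw64\bin.
- Confirm all dialogs and close them.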
Test: open a new terminal (so the updated PATH takes effect) and enter
gcc --version
If it prints the version information without errors, the installation succeeded.
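5.1.2 Modify run.c
On Windows, mmap and munmap are not available because they are POSIX functions rather than part of the Win32 API, so the corresponding Win32 functions have to provide the same functionality: CreateFileMapping and MapViewOfFile replace mmap, and UnmapViewOfFile replaces munmap.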
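To see that substitution in isolation, here is a minimal sketch (not the author's code) of a read-only file mapping that compiles on both platforms. The _WIN32 branch uses the Win32 calls CreateFileA, GetFileSizeEx, CreateFileMappingA, MapViewOfFile and UnmapViewOfFile; the other branch uses POSIX open, fstat, mmap and munmap. MappedFile, map_file_readonly and unmap_file are hypothetical names.

```c
/* Read-only file mapping for Windows (Win32) and POSIX systems.
 * Sketch only: names are hypothetical, error handling is minimal. */
#if !defined(_WIN32)
#define _POSIX_C_SOURCE 200809L /* expose open/fstat/mmap in strict C modes */
#endif

#include <stddef.h>
#include <stdio.h>

#if defined(_WIN32)
#include <windows.h>
#else
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#endif

typedef struct {
    void *data;  /* first byte of the mapped file, NULL if not mapped */
    size_t size; /* length of the mapping in bytes */
} MappedFile;

/* Map the whole file at `path` read-only. Returns 0 on success, -1 on error. */
int map_file_readonly(const char *path, MappedFile *m) {
    m->data = NULL;
    m->size = 0;
#if defined(_WIN32)
    HANDLE file = CreateFileA(path, GENERIC_READ, FILE_SHARE_READ, NULL,
                              OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
    if (file == INVALID_HANDLE_VALUE) return -1;
    LARGE_INTEGER sz;
    if (!GetFileSizeEx(file, &sz) || sz.QuadPart == 0) {
        CloseHandle(file);
        return -1; /* size query failed, or empty file (cannot be mapped) */
    }
    /* maximum size 0,0 means "map the entire file" */
    HANDLE mapping = CreateFileMappingA(file, NULL, PAGE_READONLY, 0, 0, NULL);
    if (mapping == NULL) {
        CloseHandle(file);
        return -1;
    }
    void *view = MapViewOfFile(mapping, FILE_MAP_READ, 0, 0, 0);
    /* the view holds its own reference, so both handles can be closed now */
    CloseHandle(mapping);
    CloseHandle(file);
    if (view == NULL) return -1;
    m->data = view;
    m->size = (size_t)sz.QuadPart;
#else
    int fd = open(path, O_RDONLY);
    if (fd == -1) return -1;
    struct stat st;
    if (fstat(fd, &st) == -1 || st.st_size == 0) {
        close(fd);
        return -1; /* stat failed, or empty file (mmap of length 0 fails) */
    }
    void *p = mmap(NULL, (size_t)st.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
    close(fd); /* the mapping stays valid after the descriptor is closed */
    if (p == MAP_FAILED) return -1;
    m->data = p;
    m->size = (size_t)st.st_size;
#endif
    return 0;
}

/* Release a mapping created by map_file_readonly (safe to call twice). */
void unmap_file(MappedFile *m) {
    if (m->data == NULL) return;
#if defined(_WIN32)
    UnmapViewOfFile(m->data);
#else
    munmap(m->data, m->size);
#endif
    m->data = NULL;
    m->size = 0;
}
```

A caller in run.c would map the checkpoint and then take (float*)m.data + sizeof(Config)/sizeof(float) as the start of the weights; the cast to float* has to come before the offset, because arithmetic on a void* is not a float-sized step.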
Replace the contents of the original run.c with the following code.
/* Inference for Llama-2 Transformer model in pure C */

#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
#include <time.h>
#include <math.h>
#include <string.h>
#include <fcntl.h>
#if defined _WIN32
    #include "win.h"
#else
    #include <unistd.h>
    #include <sys/mman.h>
#endif
// ----------------------------------------------------------------------------
// Transformer model

typedef struct {
    int dim; // transformer dimension
    int hidden_dim; // for ffn layers
    int n_layers; // number of layers
    int n_heads; // number of query heads
    int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
    int vocab_size; // vocabulary size, usually 256 (byte-level)
    int seq_len; // max sequence length
} Config;

typedef struct {
    // token embedding table
    float* token_embedding_table;    // (vocab_size, dim)
    // weights for rmsnorms
    float* rms_att_weight; // (layer, dim) rmsnorm weights
    float* rms_ffn_weight; // (layer, dim)
    // weights for matmuls. note dim == n_heads * head_size
    float* wq; // (layer, dim, n_heads * head_size)
    float* wk; // (layer, dim, n_kv_heads * head_size)
    float* wv; // (layer, dim, n_kv_heads * head_size)
    float* wo; // (layer, n_heads * head_size, dim)
    // weights for ffn
    float* w1; // (layer, hidden_dim, dim)
    float* w2; // (layer, dim, hidden_dim)
    float* w3; // (layer, hidden_dim, dim)
    // final rmsnorm
    float* rms_final_weight; // (dim,)
    // (optional) classifier weights for the logits, on the last layer
    float* wcls;
} TransformerWeights;

typedef struct {
    // current wave of activations
    float *x; // activation at current time stamp (dim,)
    float *xb; // same, but inside a residual branch (dim,)
    float *xb2; // an additional buffer just for convenience (dim,)
    float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
    float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
    float *q; // query (dim,)
    float *k; // key (dim,)
    float *v; // value (dim,)
    float *att; // buffer for scores/attention values (n_heads, seq_len)
    float *logits; // output logits
    // kv cache
    float* key_cache;   // (layer, seq_len, dim)
    float* value_cache; // (layer, seq_len, dim)
} RunState;

typedef struct {
    Config config; // the hyperparameters of the architecture (the blueprint)
    TransformerWeights weights; // the weights of the model
    RunState state; // buffers for the "wave" of activations in the forward pass
    // some more state needed to properly clean up the memory mapping (sigh)
    HANDLE fileHandle; // file handle for memory mapping
    void* data; // memory mapped data pointer
    ssize_t file_size; // size of the checkpoint file in bytes
} Transformer;

void malloc_run_state(RunState* s, Config* p) {
    // we calloc instead of malloc to keep valgrind happy
    int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
    s->x = calloc(p->dim, sizeof(float));
    s->xb = calloc(p->dim, sizeof(float));
    s->xb2 = calloc(p->dim, sizeof(float));
    s->hb = calloc(p->hidden_dim, sizeof(float));
    s->hb2 = calloc(p->hidden_dim, sizeof(float));
    s->q = calloc(p->dim, sizeof(float));
    s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
    s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
    s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
    s->logits = calloc(p->vocab_size, sizeof(float));
    // ensure all mallocs went fine
    if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
     || !s->key_cache || !s->value_cache || !s->att || !s->logits) {
        fprintf(stderr, "malloc failed!\n");
        exit(EXIT_FAILURE);
    }
}

void free_run_state(RunState* s) {
    free(s->x);
    free(s->xb);
    free(s->xb2);
    free(s->hb);
    free(s->hb2);
    free(s->q);
    free(s->att);
    free(s->logits);
    free(s->key_cache);
    free(s->value_cache);
}

void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
    int head_size = p->dim / p->n_heads;
    // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
    unsigned long long n_layers = p->n_layers;
    w->token_embedding_table = ptr;
    ptr += p->vocab_size * p->dim;
    w->rms_att_weight = ptr;
    ptr += n_layers * p->dim;
    w->wq = ptr;
    ptr += n_layers * p->dim * (p->n_heads * head_size);
    w->wk = ptr;
    ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
    w->wv = ptr;
    ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
    w->wo = ptr;
    ptr += n_layers * (p->n_heads * head_size) * p->dim;
    w->rms_ffn_weight = ptr;
    ptr += n_layers * p->dim;
    w->w1 = ptr;
    ptr += n_layers * p->dim * p->hidden_dim;
    w->w2 = ptr;
    ptr += n_layers * p->hidden_dim * p->dim;
    w->w3 = ptr;
    ptr += n_layers * p->dim * p->hidden_dim;
    w->rms_final_weight = ptr;
    ptr += p->dim;
    ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
    ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_imag (for RoPE)
    w->wcls = shared_weights ? w->token_embedding_table : ptr;
}

void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
                     HANDLE* fileHandle, void** data, ssize_t* file_size) {
    // read in the config header
    FILE *file = fopen(checkpoint, "rb");
    if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
    if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
    // negative vocab size is hacky way of signaling unshared weights. bit yikes.
    int shared_weights = config->vocab_size > 0 ? 1 : 0;
    config->vocab_size = abs(config->vocab_size);
    // figure out the file size
    fseek(file, 0, SEEK_END); // move file pointer to end of file
    *file_size = ftell(file); // get the file size, in bytes
    fclose(file);
    // open the file (CreateFileA takes a char* path even if UNICODE is defined)
    *fileHandle = CreateFileA(checkpoint, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
    if (*fileHandle == INVALID_HANDLE_VALUE) {
        fprintf(stderr, "open failed!\n");
        exit(EXIT_FAILURE);
    }
    // create a file mapping object
    HANDLE hMapFile = CreateFileMapping(*fileHandle, NULL, PAGE_READONLY, 0, 0, NULL);
    if (hMapFile == NULL) {
        fprintf(stderr, "CreateFileMapping failed!\n");
        CloseHandle(*fileHandle);
        exit(EXIT_FAILURE);
    }
    // map the view of the file
    *data = MapViewOfFile(hMapFile, FILE_MAP_READ, 0, 0, 0);
    if (*data == NULL) {
        fprintf(stderr, "MapViewOfFile failed!\n");
        CloseHandle(hMapFile);
        CloseHandle(*fileHandle);
        exit(EXIT_FAILURE);
    }
    // the mapped view keeps the mapping object alive, so its handle can be closed now
    CloseHandle(hMapFile);
    // cast to float* BEFORE adding the offset: *data is void*, and with gcc
    // "void* + n" advances n bytes, not n floats, which would misalign every weight
    float* weights_ptr = (float*)*data + sizeof(Config)/sizeof(float);
    memory_map_weights(weights, config, weights_ptr, shared_weights);
}

void build_transformer(Transformer *t, char* checkpoint_path) {
    // read in the Config and the Weights from the checkpoint
    read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fileHandle, &t->data, &t->file_size);
    // allocate the RunState buffers
    malloc_run_state(&t->state, &t->config);
}

void free_transformer(Transformer* t) {
    // close the memory mapping
    if (t->data != NULL) { UnmapViewOfFile(t->data); }
    if (t->fileHandle != INVALID_HANDLE_VALUE) { CloseHandle(t->fileHandle); }
    // free the RunState buffers
    free_run_state(&t->state);
}

// ----------------------------------------------------------------------------
// neural net blocks; the dynamics of the Transformer

void rmsnorm(float* o, float* x, float* weight, int size) {
    // calculate sum of squares
    float ss = 0.0f;
    for (int j = 0; j < size; j++) {
        ss += x[j] * x[j];
    }
    ss /= size;
    ss += 1e-5f;
    ss = 1.0f / sqrtf(ss);
    // normalize and scale
    for (int j = 0; j < size; j++) {
        o[j] = weight[j] * (ss * x[j]);
    }
}

void softmax(float* x, int size) {
    // find max value (for numerical stability)
    float max_val = x[0];
    for (int i = 1; i < size; i++) {
        if (x[i] > max_val) {
            max_val = x[i];
        }
    }
    // exp and sum
    float sum = 0.0f;
    for (int i = 0; i < size; i++) {
        x[i] = expf(x[i] - max_val);
        sum += x[i];
    }
    // normalize
    for (int i = 0; i < size; i++) {
        x[i] /= sum;
    }
}

void matmul(float* xout, float* x, float* w, int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    // by far the most amount of time is spent inside this little function
    int i;
    #pragma omp parallel for private(i)
    for (i = 0; i < d; i++) {
        float val = 0.0f;
        for (int j = 0; j < n; j++) {
            val += w[i * n + j] * x[j];
        }
        xout[i] = val;
    }
}

float* forward(Transformer* transformer, int token, int pos) {

    // a few convenience variables
    Config* p = &transformer->config;
    TransformerWeights* w = &transformer->weights;
    RunState* s = &transformer->state;
    float *x = s->x;
    int dim = p->dim;
    int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
    int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
    int hidden_dim = p->hidden_dim;
    int head_size = dim / p->n_heads;

    // copy the token embedding into x
    float* content_row = w->token_embedding_table + token * dim;
    memcpy(x, content_row, dim*sizeof(*x));

    // forward all the layers
    for(unsigned long long l = 0; l < p->n_layers; l++) {

        // attention rmsnorm
        rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);

        // key and value point to the kv cache
        int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
        s->k = s->key_cache + loff + pos * kv_dim;
        s->v = s->value_cache + loff + pos * kv_dim;

        // qkv matmuls for this position
        matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
        matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
        matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);

        // RoPE relative positional encoding: complex-valued rotate q and k in each head
        for (int i = 0; i < dim; i+=2) {
            int head_dim = i % head_size;
            float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
            float val = pos * freq;
            float fcr = cosf(val);
            float fci = sinf(val);
            int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
            for (int v = 0; v < rotn; v++) {
                float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
                float v0 = vec[i];
                float v1 = vec[i+1];
                vec[i]   = v0 * fcr - v1 * fci;
                vec[i+1] = v0 * fci + v1 * fcr;
            }
        }

        // multihead attention. iterate over all heads
        int h;
        #pragma omp parallel for private(h)
        for (h = 0; h < p->n_heads; h++) {
            // get the query vector for this head
            float* q = s->q + h * head_size;
            // attention scores for this head
            float* att = s->att + h * p->seq_len;
            // iterate over all timesteps, including the current one
            for (int t = 0; t <= pos; t++) {
                // get the key vector for this head and at this timestep
                float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
                // calculate the attention score as the dot product of q and k
                float score = 0.0f;
                for (int i = 0; i < head_size; i++) {
                    score += q[i] * k[i];
                }
                score /= sqrtf(head_size);
                // save the score to the attention buffer
                att[t] = score;
            }

            // softmax the scores to get attention weights, from 0..pos inclusively
            softmax(att, pos + 1);

            // weighted sum of the values, store back into xb
            float* xb = s->xb + h * head_size;
            memset(xb, 0, head_size * sizeof(float));
            for (int t = 0; t <= pos; t++) {
                // get the value vector for this head and at this timestep
                float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
                // get the attention weight for this timestep
                float a = att[t];
                // accumulate the weighted value into xb
                for (int i = 0; i < head_size; i++) {
                    xb[i] += a * v[i];
                }
            }
        }

        // final matmul to get the output of the attention
        matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);

        // residual connection back into x
        for (int i = 0; i < dim; i++) {
            x[i] += s->xb2[i];
        }

        // ffn rmsnorm
        rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);

        // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
        // first calculate self.w1(x) and self.w3(x)
        matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
        matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);

        // SwiGLU non-linearity
        for (int i = 0; i < hidden_dim; i++) {
            float val = s->hb[i];
            // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
            val *= (1.0f / (1.0f + expf(-val)));
            // elementwise multiply with w3(x)
            val *= s->hb2[i];
            s->hb[i] = val;
        }

        // final matmul to get the output of the ffn
        matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);

        // residual connection
        for (int i = 0; i < dim; i++) {
            x[i] += s->xb[i];
        }
    }

    // final rmsnorm
    rmsnorm(x, x, w->rms_final_weight, dim);

    // classifier into logits
    matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
    return s->logits;
}

// ----------------------------------------------------------------------------
// The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens

typedef struct {
    char *str;
    int id;
} TokenIndex;

typedef struct {
    char** vocab;
    float* vocab_scores;
    TokenIndex *sorted_vocab;
    int vocab_size;
    unsigned int max_token_length;
    unsigned char byte_pieces[512]; // stores all single-byte strings
} Tokenizer;

int compare_tokens(const void *a, const void *b) {
    return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}

void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
    // i should have written the vocab_size into the tokenizer file... sigh
    t->vocab_size = vocab_size;
    // malloc space to hold the scores and the strings
    t->vocab = (char**)malloc(vocab_size * sizeof(char*));
    t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
    t->sorted_vocab = NULL; // initialized lazily
    for (int i = 0; i < 256; i++) {
        t->byte_pieces[i * 2] = (unsigned char)i;
        t->byte_pieces[i * 2 + 1] = '\0';
    }
    // read in the file
    FILE *file = fopen(tokenizer_path, "rb");
    if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
    if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
    int len;
    for (int i = 0; i < vocab_size; i++) {
        if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
        if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
        t->vocab[i] = (char *)malloc(len + 1);
        if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
        t->vocab[i][len] = '\0'; // add the string terminating token
    }
    fclose(file);
}

void free_tokenizer(Tokenizer* t) {
    for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
    free(t->vocab);
    free(t->vocab_scores);
    free(t->sorted_vocab);
}

char* decode(Tokenizer* t, int prev_token, int token) {
    char *piece = t->vocab[token];
    // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
    if (prev_token == 1 && piece[0] == ' ') { piece++; }
    // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
    // parse this and convert and return the actual byte
    unsigned char byte_val;
    if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
        piece = (char*)t->byte_pieces + byte_val * 2;
    }
    return piece;
}

void safe_printf(char *piece) {
    // piece might be a raw byte token, and we only want to print printable chars or whitespace
    // because some of the other bytes can be various control codes, backspace, etc.
    if (piece == NULL) { return; }
    if (piece[0] == '\0') { return; }
    if (piece[1] == '\0') {
        unsigned char byte_val = piece[0];
        if (!(isprint(byte_val) || isspace(byte_val))) {
            return; // bad byte, don't print it
        }
    }
    printf("%s", piece);
}

int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
    // efficiently find the perfect match for str in vocab, return its index or -1 if not found
    TokenIndex tok = { .str = str }; // acts as the key to search for
    TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
    return res != NULL ? res->id : -1;
}

void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
    // encode the string text (input) into an upper-bound preallocated tokens[] array
    // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
    if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }

    if (t->sorted_vocab == NULL) {
        // lazily malloc and sort the vocabulary
        t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
        for (int i = 0; i < t->vocab_size; i++) {
            t->sorted_vocab[i].str = t->vocab[i];
            t->sorted_vocab[i].id = i;
        }
        qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
    }

    // create a temporary buffer that will store merge candidates of always two consecutive tokens
    // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
    char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
    size_t str_len = 0;

    // start at 0 tokens
    *n_tokens = 0;

    // add optional BOS (=1) token, if desired
    if (bos) tokens[(*n_tokens)++] = 1;

    // add_dummy_prefix is true by default
    // so prepend a dummy prefix token to the input string, but only if text != ""
    // TODO: pretty sure this isn't correct in the general case but I don't have the
    // energy to read more of the sentencepiece code to figure out what it's doing
    if (text[0] != '\0') {
        int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
        tokens[(*n_tokens)++] = dummy_prefix;
    }

    // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
    // Code point ↔ UTF-8 conversion
    // First code point  Last code point  Byte 1    Byte 2    Byte 3    Byte 4
    // U+0000            U+007F           0xxxxxxx
    // U+0080            U+07FF           110xxxxx  10xxxxxx
    // U+0800            U+FFFF           1110xxxx  10xxxxxx  10xxxxxx
    // U+10000           U+10FFFF         11110xxx  10xxxxxx  10xxxxxx  10xxxxxx

    // process the raw (UTF-8) byte sequence of the input string
    for (char *c = text; *c != '\0'; c++) {

        // reset buffer if the current byte is ASCII or a leading byte
        // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
        // 0x80 is 10000000
        // in UTF-8, all continuation bytes start with "10" in first two bits
        // so in English this is: "if this byte is not a continuation byte"
        if ((*c & 0xC0) != 0x80) {
            // this byte must be either a leading byte (11...) or an ASCII char (0x...)
            // => reset our location, as we're starting a new UTF-8 codepoint
            str_len = 0;
        }

        // append the current byte to the buffer
        str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
        str_buffer[str_len] = '\0';

        // while the next character is a continuation byte, continue appending
        // but if there are too many of them, just stop to avoid overruning str_buffer size.
        if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
            continue;
        }

        // ok c+1 is not a continuation byte, so we've read in a full codepoint
        int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);

        if (id != -1) {
            // we found this codepoint in vocab, add it as a token
            tokens[(*n_tokens)++] = id;
        } else {
            // byte_fallback encoding: just encode each byte as a token
            // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
            // so the individual bytes only start at index 3
            for (int i=0; i < str_len; i++) {
                tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
            }
        }
        str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
    }

    // merge the best consecutive pair each iteration, according the scores in vocab_scores
    while (1) {
        float best_score = -1e10;
        int best_id = -1;
        int best_idx = -1;

        for (int i=0; i < (*n_tokens-1); i++) {
            // check if we can merge the pair (tokens[i], tokens[i+1])
            sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
            int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
            if (id != -1 && t->vocab_scores[id] > best_score) {
                // this merge pair exists in vocab! record its score and position
                best_score = t->vocab_scores[id];
                best_id = id;
                best_idx = i;
            }
        }

        if (best_idx == -1) {
            break; // we couldn't find any more pairs to merge, so we're done
        }

        // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
        tokens[best_idx] = best_id;
        // delete token at position best_idx+1, shift the entire sequence back 1
        for (int i = best_idx+1; i < (*n_tokens-1); i++) {
            tokens[i] = tokens[i+1];
        }
        (*n_tokens)--; // token length decreased
    }

    // add optional EOS (=2) token, if desired
    if (eos) tokens[(*n_tokens)++] = 2;

    free(str_buffer);
}

// ----------------------------------------------------------------------------
// The Sampler, which takes logits and returns a sampled token
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling

typedef struct {
    float prob;
    int index;
} ProbIndex; // struct used when sorting probabilities during top-p sampling

typedef struct {
    int vocab_size;
    ProbIndex* probindex; // buffer used in top-p sampling
    float temperature;
    float topp;
    unsigned long long rng_state;
} Sampler;

int sample_argmax(float* probabilities, int n) {
    // return the index that has the highest probability
    int max_i = 0;
    float max_p = probabilities[0];
    for (int i = 1; i < n; i++) {
        if (probabilities[i] > max_p) {
            max_i = i;
            max_p = probabilities[i];
        }
    }
    return max_i;
}

int sample_mult(float* probabilities, int n, float coin) {
    // sample index from probabilities (they must sum to 1!)
    // coin is a random number in [0, 1), usually from random_f32()
    float cdf = 0.0f;
    for (int i = 0; i < n; i++) {
        cdf += probabilities[i];
        if (coin < cdf) {
            return i;
        }
    }
    return n - 1; // in case of rounding errors
}

int compare(const void* a, const void* b) {
    ProbIndex* a_ = (ProbIndex*) a;
    ProbIndex* b_ = (ProbIndex*) b;
    if (a_->prob > b_->prob) return -1;
    if (a_->prob < b_->prob) return 1;
    return 0;
}

int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
    // top-p sampling (or "nucleus sampling") samples from the smallest set of
    // tokens that exceed probability topp. This way we never sample tokens that
    // have very low probabilities and are less likely to go "off the rails".
    // coin is a random number in [0, 1), usually from random_f32()

    int n0 = 0;
    // quicksort indices in descending order of probabilities
    // values smaller than (1 - topp) / (n - 1) cannot be part of the result
    // so for efficiency we crop these out as candidates before sorting
    const float cutoff = (1.0f - topp) / (n - 1);
    for (int i = 0; i < n; i++) {
        if (probabilities[i] >= cutoff) {
            probindex[n0].index = i;
            probindex[n0].prob = probabilities[i];
            n0++;
        }
    }
    qsort(probindex, n0, sizeof(ProbIndex), compare);

    // truncate the list where cumulative probability exceeds topp
    float cumulative_prob = 0.0f;
    int last_idx = n0 - 1; // in case of rounding errors consider all elements
    for (int i = 0; i < n0; i++) {
        cumulative_prob += probindex[i].prob;
        if (cumulative_prob > topp) {
            last_idx = i;
            break; // we've exceeded topp by including last_idx
        }
    }

    // sample from the truncated list
    float r = coin * cumulative_prob;
    float cdf = 0.0f;
    for (int i = 0; i <= last_idx; i++) {
        cdf += probindex[i].prob;
        if (r < cdf) {
            return probindex[i].index;
        }
    }
    return probindex[last_idx].index; // in case of rounding errors
}

void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
    sampler->vocab_size = vocab_size;
    sampler->temperature = temperature;
    sampler->topp = topp;
    sampler->rng_state = rng_seed;
    // buffer only used with nucleus sampling; may not need but it's ~small
    sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
}

void free_sampler(Sampler* sampler) {
    free(sampler->probindex);
}

unsigned int random_u32(unsigned long long *state) {
    // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
    *state ^= *state >> 12;
    *state ^= *state << 25;
    *state ^= *state >> 27;
    return (*state * 0x2545F4914F6CDD1Dull) >> 32;
}
float random_f32(unsigned long long *state) { // random float32 in [0,1)
    return (random_u32(state) >> 8) / 16777216.0f;
}

int sample(Sampler* sampler, float* logits) {
    // sample the token given the logits and some hyperparameters
    int next;
    if (sampler->temperature == 0.0f) {
        // greedy argmax sampling: take the token with the highest probability
        next = sample_argmax(logits, sampler->vocab_size);
    } else {
        // apply the temperature to the logits
        for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
        // apply softmax to the logits to get the probabilities for next token
        softmax(logits, sampler->vocab_size);
        // flip a (float) coin (this is our source of entropy for sampling)
        float coin = random_f32(&sampler->rng_state);
        // we sample from this distribution to get the next token
        if (sampler->topp <= 0 || sampler->topp >= 1) {
            // simply sample from the predicted probability distribution
            next = sample_mult(logits, sampler->vocab_size, coin);
        } else {
            // top-p (nucleus) sampling, clamping the least likely tokens to zero
            next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
        }
    }
    return next;
}

// ----------------------------------------------------------------------------
// utilities: time

long time_in_ms() {
    // return time in milliseconds, for benchmarking the model speed
    struct timespec time;
    clock_gettime(CLOCK_REALTIME, &time);
    return time.tv_sec * 1000 + time.tv_nsec / 1000000;
}

// ----------------------------------------------------------------------------
// generation loop

void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
    char *empty_prompt = "";
    if (prompt == NULL) { prompt = empty_prompt; }

    // encode the (string) prompt into tokens sequence
    int num_prompt_tokens = 0;
    int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
    encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
    if (num_prompt_tokens < 1) {
        fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
        exit(EXIT_FAILURE);
    }

    // start the main loop
    long start = 0;  // used to time our code, only initialized after first iteration
    int next;        // will store the next token in the sequence
    int token = prompt_tokens[0]; // kick off with the first token in the prompt
    int pos = 0;     // position in the sequence
    while (pos < steps) {

        // forward the transformer to get logits for the next token
        float* logits = forward(transformer, token, pos);

        // advance the state machine
        if (pos < num_prompt_tokens - 1) {
            // if we are still processing the input prompt, force the next prompt token
            next = prompt_tokens[pos + 1];
        } else {
            // otherwise sample the next token from the logits
            next = sample(sampler, logits);
        }
        pos++;

        // data-dependent terminating condition: the BOS (=1) token delimits sequences
        if (next == 1) { break; }

        // print the token as string, decode it with the Tokenizer object
        char* piece = decode(tokenizer, token, next);
        safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
        fflush(stdout);
        token = next;

        // init the timer here because the first iteration can be slower
        if (start == 0) { start = time_in_ms(); }
    }
    printf("\n");

    // report achieved tok/s (pos-1 because the timer starts after first iteration)
    if (pos > 1) {
        long end = time_in_ms();
        fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
    }

    free(prompt_tokens);
}

void read_stdin(const char* guide, char* buffer, size_t bufsize) {
    // read a line from stdin, up to but not including \n
    printf("%s", guide);
    if (fgets(buffer, bufsize, stdin) != NULL) {
        size_t len = strlen(buffer);
        if (len > 0 && buffer[len - 1] == '\n') {
            buffer[len - 1] = '\0'; // strip newline
        }
    }
}

// ----------------------------------------------------------------------------
// chat loop
// I manually inspected the tokens for a few chat conversations compared to
// python reference and that seemed ok, but this was not thoroughly tested and
// is not safely implemented, it's more a proof of concept atm.

void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
          char *cli_user_prompt, char *cli_system_prompt, int steps) {

    // buffers for reading the system prompt and user prompt from stdin
    // you'll notice they are soomewhat haphazardly and unsafely set atm
    char system_prompt[512];
    char user_prompt[512];
    char rendered_prompt[1152];
    int num_prompt_tokens = 0;
    int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
    int user_idx;

    // start the main loop
    int8_t user_turn = 1; // user starts
    int next;        // will store the next token in the sequence
    int token;       // stores the current token to feed into the transformer
    int prev_token;
    int pos = 0;     // position in the sequence
    while (pos < steps) {

        // when it is the user's turn to contribute tokens to the dialog...
        if (user_turn) {
            // get the (optional) system prompt at position 0
            if (pos == 0) {
                // at position 0, the user can also contribute a system prompt
                if (cli_system_prompt == NULL) {
                    // system prompt was not passed in, attempt to get it from stdin
                    read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
                } else {
                    // system prompt was passed in, use it
                    strcpy(system_prompt, cli_system_prompt);
                }
            }
            // get the user prompt
            if (pos == 0 && cli_user_prompt != NULL) {
                // user prompt for position 0 was passed in, use it
                strcpy(user_prompt, cli_user_prompt);
            } else {
                // otherwise get user prompt from stdin
                read_stdin("User: ", user_prompt, sizeof(user_prompt));
            }
            // render user/system prompts into the Llama 2 Chat schema
            if (pos == 0 && system_prompt[0] != '\0') {
                char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
                sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
            } else {
                char user_template[] = "[INST] %s [/INST]";
                sprintf(rendered_prompt, user_template, user_prompt);
            }
            // encode the rendered prompt into tokens
            encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
            user_idx = 0; // reset the user index
            user_turn = 0;
            printf("Assistant: ");
        }

        // determine the token to pass into the transformer next
        if (user_idx < num_prompt_tokens) {
            // if we are still processing the input prompt, force the next prompt token
            token = prompt_tokens[user_idx++];
        } else {
            // otherwise use the next token sampled from previous turn
            token = next;
        }
        // EOS (=2) token ends the Assistant turn
        if (token == 2) { user_turn = 1; }

        // forward the transformer to get logits for the next token
        float* logits = forward(transformer, token, pos);
        next = sample(sampler, logits);
        pos++;

        if (user_idx >= num_prompt_tokens && next != 2) {
            // the Assistant is responding, so print its output
            char* piece = decode(tokenizer, token, next);
            safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
            fflush(stdout);
        }
        if (next == 2) { printf("\n"); }
    }
    printf("\n");
    free(prompt_tokens);
}// ----------------------------------------------------------------------------
// CLI, include only if not testing
#ifndef TESTING
void error_usage() {
    fprintf(stderr, "Usage: run <checkpoint> [options]\n");
    fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
    fprintf(stderr, "Options:\n");
    fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
    fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
    fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
    fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
    fprintf(stderr, " -i <string> input prompt\n");
    fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
    fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
    fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
    exit(EXIT_FAILURE);
}

int main(int argc, char *argv[]) {

    // default parameters
    char *checkpoint_path = NULL;  // e.g. out/model.bin
    char *tokenizer_path = "tokenizer.bin";
    float temperature = 1.0f;   // 0.0 = greedy deterministic. 1.0 = original. don't set higher
    float topp = 0.9f;          // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
    int steps = 256;            // number of steps to run for
    char *prompt = NULL;        // prompt string
    unsigned long long rng_seed = 0; // seed rng with time by default
    char *mode = "generate";    // generate|chat
    char *system_prompt = NULL; // the (optional) system prompt to use in chat mode

    // poor man's C argparse so we can override the defaults above from the command line
    if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
    for (int i = 2; i < argc; i+=2) {
        // do some basic validation
        if (i + 1 >= argc) { error_usage(); } // must have arg after flag
        if (argv[i][0] != '-') { error_usage(); } // must start with dash
        if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
        // read in the args
        if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
        else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
        else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
        else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
        else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
        else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
        else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
        else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
        else { error_usage(); }
    }

    // parameter validation/overrides
    if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
    if (temperature < 0.0) temperature = 0.0;
    if (topp < 0.0 || 1.0 < topp) topp = 0.9;
    if (steps < 0) steps = 0;

    // build the Transformer via the model .bin file
    Transformer transformer;
    build_transformer(&transformer, checkpoint_path);
    if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; // override to ~max length

    // build the Tokenizer via the tokenizer .bin file
    Tokenizer tokenizer;
    build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);

    // build the Sampler
    Sampler sampler;
    build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);

    // run!
    if (strcmp(mode, "generate") == 0) {
        generate(&transformer, &tokenizer, &sampler, prompt, steps);
    } else if (strcmp(mode, "chat") == 0) {
        chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
    } else {
        fprintf(stderr, "unknown mode: %s\n", mode);
        error_usage();
    }

    // memory and file handles cleanup
    free_sampler(&sampler);
    free_tokenizer(&tokenizer);
    free_transformer(&transformer);
    return 0;
}
#endif
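chat() above uses strcpy to copy the command-line prompts into 512-byte arrays. It then renders them with sprintf into a fixed 1152-byte buffer, and its own comment admits the buffers are set "unsafely". A -i or -y argument longer than the array can therefore overflow it. Below is a minimal, hypothetical sketch of a bounds-checked way to build the same Llama 2 chat prompt. The helper render_llama2_prompt() is not part of llama2.c. It reproduces only the two templates visible in chat() and uses snprintf so that output is truncated rather than overflowed.

```c
#include <assert.h>
#include <stdio.h>
#include <string.h>

// Hypothetical helper (not in llama2.c): renders one user turn into the
// Llama 2 chat schema used by chat(). The <<SYS>> block is emitted only on
// the first turn and only for a non-empty system prompt, mirroring the
// `pos == 0 && system_prompt[0] != '\0'` check in chat().
// snprintf never writes more than out_size bytes (including the '\0'); its
// return value is the length the full prompt would have had, so a result
// >= out_size means the prompt was truncated.
int render_llama2_prompt(char *out, size_t out_size, const char *system_prompt,
                         const char *user_prompt, int first_turn) {
    assert(out != NULL && out_size > 0 && user_prompt != NULL);
    if (first_turn && system_prompt != NULL && system_prompt[0] != '\0') {
        return snprintf(out, out_size,
                        "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
                        system_prompt, user_prompt);
    }
    return snprintf(out, out_size, "[INST] %s [/INST]", user_prompt);
}
```

A caller would pass sizeof(rendered_prompt) as out_size and treat a return value of at least that size as "prompt too long". It could then reject the input instead of encoding a silently truncated prompt.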
5.1.3 Build the executable
gcc -o run run.c
This produces run.exe in the repository root (on Windows, gcc appends the .exe extension automatically).
5.2 Model inference
./run out/model.bin
The output looks like this:
Once upon a time, there was a little girl named Lily. She loved to play in the garden and watch the insects crawl around. One day, while she was playing, she saw a butterfly flying around. “Hello butterfly, what are you doing?” Lily said, feeling grateful. “I am waiting for my mom to come back to check on her,” the butterfly replied. Lily watched as the butterfly flew away, feeling proud that she had made it happen. From that day on, she loved looking at insects and watching them grow and become happy.
achieved tok/s: 24.576110
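The "achieved tok/s" line is a throughput figure: the number of generated positions divided by the wall-clock time spent producing them. In llama2.c's generate(), the timer starts only after the first step because that step can be slower, so the count is reduced by one. The following is a standalone sketch of that arithmetic, not a copy of run.c. The function names now_ms() and tokens_per_second() are hypothetical. It also uses C11 timespec_get, whereas llama2.c's own timer uses POSIX clock_gettime.

```c
#include <assert.h>
#include <time.h>

// Wall-clock time in milliseconds using C11 timespec_get.
long long now_ms(void) {
    struct timespec ts;
    timespec_get(&ts, TIME_UTC);
    return (long long)ts.tv_sec * 1000 + ts.tv_nsec / 1000000;
}

// Tokens per second for `tokens` processed positions. The timer is assumed
// to have started after the first position, so only tokens - 1 positions
// fall inside the measured window [start_ms, end_ms].
double tokens_per_second(int tokens, long long start_ms, long long end_ms) {
    assert(tokens >= 0);
    if (tokens <= 1 || end_ms <= start_ms) return 0.0;
    return (tokens - 1) / (double)(end_ms - start_ms) * 1000.0;
}
```

For example, 257 positions with 10 000 ms on the clock after the first one gives (257 - 1) / 10 000 * 1000, about 25.6 tok/s, which is the same order as the number reported above.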
You can also pass options on the command line, for example a starting prompt with -i:
./run out/model.bin -i "One day, Lily met a Shoggoth"
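Besides -i, the -t and -p options listed in error_usage() control sampling, and -p sets the top-p (nucleus) threshold. The sketch below is a simplified, standalone illustration of nucleus sampling, not llama2.c's own sampler, which additionally drops very low-probability tokens before sorting. It makes three assumptions. First, the logits have already been turned into probabilities (softmax after dividing by the -t temperature). Second, coin is a uniform random number in [0, 1) supplied by the caller. Third, the names TokenProb and sample_topp_sketch() are hypothetical.

```c
#include <assert.h>
#include <stdlib.h>

// Hypothetical helper type: a probability paired with its token id.
typedef struct { float prob; int index; } TokenProb;

// qsort comparator: sort by probability, largest first.
static int by_prob_desc(const void *a, const void *b) {
    float pa = ((const TokenProb *)a)->prob;
    float pb = ((const TokenProb *)b)->prob;
    return (pa < pb) - (pa > pb);
}

// Simplified nucleus (top-p) sampling over n probabilities summing to ~1.
// Keeps the smallest prefix of most-likely tokens whose cumulative mass
// exceeds topp, then samples inside that prefix. work is caller-provided
// scratch space with room for n entries.
int sample_topp_sketch(const float *probs, int n, float topp, float coin,
                       TokenProb *work) {
    assert(n > 0 && topp > 0.0f && topp <= 1.0f);
    assert(coin >= 0.0f && coin < 1.0f);
    for (int i = 0; i < n; i++) {
        work[i].prob = probs[i];
        work[i].index = i;
    }
    qsort(work, (size_t)n, sizeof(TokenProb), by_prob_desc);

    // find the last index of the nucleus and the nucleus' total mass
    float mass = 0.0f;
    int last = n - 1;
    for (int i = 0; i < n; i++) {
        mass += work[i].prob;
        if (mass > topp) { last = i; break; }
    }

    // sample from the truncated distribution by scaling coin to its mass
    float r = coin * mass;
    float cdf = 0.0f;
    for (int i = 0; i <= last; i++) {
        cdf += work[i].prob;
        if (r < cdf) return work[i].index;
    }
    return work[last].index; // fallback for floating-point rounding
}
```

With probabilities {0.1, 0.6, 0.25, 0.05} and topp = 0.5, only token 1 survives. With topp = 0.8, tokens 1 and 2 survive, because 0.6 + 0.25 exceeds 0.8. This is why a lower -p value makes the output more conservative.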
You can also try other pretrained models downloaded from Hugging Face:
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin
./run stories15M.bin
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin
./run stories42M.bin
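To see how stories15M.bin, stories42M.bin and your own out/model.bin differ, you can read the small header at the start of each file. The sketch below makes several assumptions. It targets the legacy (unversioned) llama2.c checkpoint layout, in which the file starts with the seven int fields of run.c's Config struct in the order dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len. It also assumes those fields were written in the host's byte order with 4-byte ints. Newer versioned export formats use a different header and would not parse this way. The names CheckpointHeader and read_checkpoint_header() are hypothetical.

```c
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

// Assumed layout of a legacy llama2.c checkpoint header: seven ints in the
// same order as the Config struct in run.c, in host byte order.
typedef struct {
    int dim;        // transformer width
    int hidden_dim; // feed-forward hidden size
    int n_layers;
    int n_heads;
    int n_kv_heads;
    int vocab_size; // negative => classifier weights are not shared
    int seq_len;    // maximum sequence length
} CheckpointHeader;

// Hypothetical helper: reads only the header of a .bin checkpoint.
// Returns 0 on success and -1 if the file cannot be opened or is too short.
int read_checkpoint_header(const char *path, CheckpointHeader *h,
                           int *shared_weights) {
    assert(path != NULL && h != NULL && shared_weights != NULL);
    FILE *f = fopen(path, "rb");
    if (f == NULL) return -1;
    size_t got = fread(h, sizeof(*h), 1, f);
    fclose(f);
    if (got != 1) return -1;
    // run.c uses the sign of vocab_size as the weight-sharing flag
    *shared_weights = h->vocab_size > 0;
    h->vocab_size = abs(h->vocab_size);
    return 0;
}
```

The seq_len field is also the value that main() falls back to when you pass -n 0 or a -n larger than the model supports.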
Note: Tsinghua University Open Source Software Mirror
6 Conclusion
Done! A single RTX 4070 is just enough to get all of this running.
Enjoy it!