
Building a 17M-Parameter GPT Pretrained Model from Scratch

Artificial Intelligence
Today we will pretrain a model on open-source Chinese data. Follow the steps below to build your own pretrained model from scratch.

Hi everyone, I'm 寫代碼的中年人 (the middle-aged coder)!


All the code and data resources for this article are available at:

https://github.com/ColinAIAPP/MoiraiLM

01. What Is a Pretrained Model

A pretrained model is a model that has already been trained on a huge amount of data. It has learned the basic patterns, structure, and semantics of language, and can then be applied to all kinds of downstream tasks: writing, translation, question answering, classification, code generation, and so on.

So what does "pretraining" actually learn? Taking a language model (LLM) as the example: the pretraining task is usually to predict the next token.
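
Concretely, given a sequence of token IDs, the model is trained to assign high probability to each token that follows. A minimal sketch of that objective, using illustrative toy tensors (not the training script we build later):

import torch
import torch.nn.functional as F

# A toy "sentence" of 5 token IDs, batch size 1.
tokens = torch.tensor([[11, 42, 7, 99, 3]])
inputs, targets = tokens[:, :-1], tokens[:, 1:]  # targets = inputs shifted by one

vocab_size = 128
logits = torch.randn(1, inputs.size(1), vocab_size)  # stand-in for model(inputs)

# Pretraining loss: cross-entropy between each position's predicted
# next-token distribution and the token that actually follows.
loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
print(loss.item())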

Next, let's build a 17M-parameter pretrained model step by step.

02. Preparing the Data

The first requirement for building a language model is a high-quality data source. For a Chinese task, the open Chinese Wikipedia dump is an ideal starting point. It contains millions of encyclopedia entries covering history, culture, science, and more, amounting to several GB of plain text. It is open and free, and the latest XML dump can be downloaded from Wikipedia's official dumps page.
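
For example, a dated dump can be fetched directly from the official dumps site (a hypothetical example; browse https://dumps.wikimedia.org/zhwiki/ for the dates actually available):

wget https://dumps.wikimedia.org/zhwiki/20250920/zhwiki-20250920-pages-articles-multistream.xml.bz2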

To extract text from this dump, we use the wikiextractor tool.

Install it with:

pip install wikiextractor

Extraction command:

python -m wikiextractor.WikiExtractor -b 1G -o extracted_wiki_zh zhwiki-20250920-pages-articles-multistream.xml.bz2 --json
Here zhwiki-20250920-pages-articles-multistream.xml.bz2 is the dump file name.

INFO: Preprocessing 'zhwiki-20250920-pages-articles-multistream.xml.bz2' to collect template definitions: this may take some time.
INFO: Preprocessed 100000 pages
INFO: Preprocessed 200000 pages
INFO: Preprocessed 300000 pages
INFO: Preprocessed 400000 pages
INFO: Preprocessed 500000 pages
INFO: Preprocessed 600000 pages
INFO: Preprocessed 700000 pages
INFO: Preprocessed 800000 pages
INFO: Preprocessed 900000 pages
INFO: Preprocessed 1000000 pages
INFO: Preprocessed 1100000 pages
INFO: Preprocessed 1200000 pages
INFO: Preprocessed 1300000 pages
INFO: Preprocessed 1400000 pages
INFO: Preprocessed 1500000 pages
INFO: Preprocessed 1600000 pages
INFO: Preprocessed 1700000 pages
INFO: Preprocessed 1800000 pages
INFO: Preprocessed 1900000 pages
INFO: Preprocessed 2000000 pages
INFO: Preprocessed 2100000 pages
INFO: Preprocessed 2200000 pages
INFO: Preprocessed 2300000 pages
INFO: Preprocessed 2400000 pages
INFO: Preprocessed 2500000 pages
INFO: Preprocessed 2600000 pages
INFO: Preprocessed 2700000 pages
INFO: Preprocessed 2800000 pages
INFO: Preprocessed 2900000 pages
INFO: Preprocessed 3000000 pages
INFO: Preprocessed 3100000 pages
INFO: Preprocessed 3200000 pages
INFO: Preprocessed 3300000 pages
INFO: Preprocessed 3400000 pages
INFO: Preprocessed 3500000 pages
INFO: Preprocessed 3600000 pages
INFO: Preprocessed 3700000 pages
INFO: Preprocessed 3800000 pages
INFO: Preprocessed 3900000 pages
INFO: Preprocessed 4000000 pages
INFO: Preprocessed 4100000 pages
INFO: Preprocessed 4200000 pages
INFO: Preprocessed 4300000 pages
INFO: Preprocessed 4400000 pages
INFO: Preprocessed 4500000 pages
INFO: Preprocessed 4600000 pages
INFO: Preprocessed 4700000 pages
INFO: Loaded 1036734 templates in 704.2s
INFO: Starting page extraction from zhwiki-20250920-pages-articles-multistream.xml.bz2.
INFO: Using 127 extract processes.
INFO: Extracted 100000 articles (1209.6 art/s)
INFO: Extracted 200000 articles (1947.8 art/s)
INFO: Extracted 300000 articles (2325.1 art/s)
INFO: Extracted 400000 articles (3471.3 art/s)
INFO: Extracted 500000 articles (2551.1 art/s)
INFO: Extracted 600000 articles (2239.4 art/s)
INFO: Extracted 700000 articles (2299.3 art/s)
INFO: Extracted 800000 articles (1525.2 art/s)
INFO: Extracted 900000 articles (3256.1 art/s)
INFO: Extracted 1000000 articles (3485.9 art/s)
INFO: Extracted 1100000 articles (3495.0 art/s)
INFO: Extracted 1200000 articles (3330.4 art/s)
INFO: Extracted 1300000 articles (3555.6 art/s)
INFO: Extracted 1400000 articles (3456.3 art/s)
INFO: Extracted 1500000 articles (2476.1 art/s)
INFO: Extracted 1600000 articles (2268.6 art/s)
INFO: Extracted 1700000 articles (2473.5 art/s)
INFO: Extracted 1800000 articles (2305.9 art/s)
INFO: Extracted 1900000 articles (2263.9 art/s)
INFO: Extracted 2000000 articles (2136.4 art/s)
INFO: Extracted 2100000 articles (2363.0 art/s)
INFO: Extracted 2200000 articles (2601.9 art/s)
INFO: Extracted 2300000 articles (3709.0 art/s)
INFO: Extracted 2400000 articles (2723.9 art/s)
INFO: Extracted 2500000 articles (2487.1 art/s)
INFO: Extracted 2600000 articles (2621.3 art/s)
INFO: Extracted 2700000 articles (2525.4 art/s)
INFO: Extracted 2800000 articles (2666.4 art/s)
INFO: Finished 127-process extraction of 2893023 articles in 1156.5s (2501.5 art/s)

03. Cleaning the Data

The extraction produces a directory of JSON files; next we clean the text out of them.

Note:

This step produces the file data/cleaned_wiki_full.txt

import os
import json
import logging
import argparse
import re
from tqdm import tqdm


# Configure logging
logging.basicConfig(level=logging.INFO, 
                    format='%(asctime)s - %(levelname)s - %(message)s')


# python scripts/clean_wiki_text.py data/extracted_wiki_zh data/cleaned_wiki_full.txt --min_line_length 20 --min_article_length 300




def clean_text(text: str) -> str:
    """
    對(duì)文本進(jìn)行深度清洗。
    移除維基百科特有的格式標(biāo)記、參考文獻(xiàn)、HTML標(biāo)簽、日期和數(shù)字等。
    """
    # Remove wiki links: [[link|display]] or [[link]]
    text = re.sub(r'\[\[([^\]|]+\|)?([^\]]+)\]\]', r'\2', text)


    # Remove reference marks: [1], [2], [ref], etc.
    text = re.sub(r'\[\d+\]|\[ref\]|\[/ref\]|\[citation needed\]', '', text)


    # Remove HTML tags
    text = re.sub(r'<[^>]+>', '', text)


    # Remove date formats (yyyy-mm-dd, yyyy/mm/dd, mm/dd/yyyy, etc.)
    text = re.sub(r'\d{1,4}[-/]\d{1,2}[-/]\d{1,4}', '', text)


    # Remove years (1000-2999)
    text = re.sub(r'\b[12]\d{3}\b', '', text)


    # Remove bare numbers (including decimals)
    text = re.sub(r'\b\d+\.?\d*\b', '', text)


    # Collapse repeated spaces (keep single spaces)
    text = re.sub(r' +', ' ', text)


    # Trim leading/trailing whitespace
    text = text.strip()


    return text




def process_extracted_wiki(extracted_dir: str, 
                          output_file: str, 
                          min_line_length: int = 20, 
                          min_article_length: int = 200):
    """
    讀取WikiExtractor輸出的JSON文件,提取、清洗文本并保存到單個(gè)文件中。


    參數(shù):
        extracted_dir: WikiExtractor輸出的目錄路徑
        output_file: 最終合并的純文本文件路徑
        min_line_length: 單行文本最小長(zhǎng)度,用于過(guò)濾噪音(默認(rèn): 20)
        min_article_length: 文章最小長(zhǎng)度,用于過(guò)濾短文章(默認(rèn): 200)
    """
    if not os.path.isdir(extracted_dir):
        logging.error(f"輸入的目錄不存在: {extracted_dir}")
        return


    total_articles = 0
    skipped_articles = 0


    # First pass: collect the list of files to process
    file_list = []
    for root, dirs, files in os.walk(extracted_dir):
        for file_name in files:
            # only process the 'wiki_*' files produced by WikiExtractor
            if file_name.startswith('wiki_'):
                file_list.append(os.path.join(root, file_name))


    total_files = len(file_list)
    logging.info(f"找到 {total_files} 個(gè)文件等待處理。")


    if total_files == 0:
        logging.warning(f"目錄 {extracted_dir} 中未找到任何 'wiki_' 文件。請(qǐng)檢查路徑。")
        return


    # Second pass: process the files and write the output
    with open(output_file, 'w', encoding='utf-8') as f_out:
        # wrap the file list with tqdm to show progress
        for file_path in tqdm(file_list, desc="Extracting wiki text"):
            try:
                with open(file_path, 'r', encoding='utf-8') as f_in:
                    for line_num, line in enumerate(f_in, 1):
                        try:
                            article = json.loads(line)
                            text_content = article.get('text', '').strip()


                            # --- text cleaning and filtering ---


                            # 1. Drop very short articles; they are usually noise or redirect pages
                            if len(text_content) < min_article_length:
                                skipped_articles += 1
                                continue


                            # 2. Process the text line by line, dropping short lines and extra whitespace,
                            # keeping the line structure instead of joining everything into one long sentence
                            cleaned_lines = []
                            for text_line in text_content.split('\n'):
                                text_line = clean_text(text_line)
                                # keep only sufficiently long lines
                                if len(text_line) >= min_line_length:
                                    cleaned_lines.append(text_line)


                            # join with newlines to preserve paragraph structure
                            final_text = '\n'.join(cleaned_lines)


                            # final check: make sure the cleaned text is still long enough
                            if final_text and len(final_text) >= min_article_length:
                                # separate articles with two newlines
                                f_out.write(final_text + '\n\n')
                                total_articles += 1
                            else:
                                skipped_articles += 1


                        except json.JSONDecodeError:
                            logging.warning(f"無(wú)法解析 JSON,文件: {file_path},行號(hào): {line_num}")
                        except Exception as e:
                            logging.error(f"處理文件 {file_path} 第 {line_num} 行時(shí)出錯(cuò): {e}")


            except Exception as e:
                logging.error(f"打開(kāi)文件 {file_path} 時(shí)出錯(cuò): {e}")


    logging.info(f"  所有維基百科文本已成功提取并清洗。")
    logging.info(f"   總文章數(shù): {total_articles}")
    logging.info(f"   跳過(guò)文章數(shù): {skipped_articles}")
    logging.info(f"   文件已保存到: {output_file}")




def main():
    parser = argparse.ArgumentParser(
        description="Extract and clean plain text from WikiExtractor JSON output.",
        formatter_class=argparse.RawTextHelpFormatter
    )


    # positional argument 1: input directory
    parser.add_argument(
        "extracted_directory",
        type=str,
        help="WikiExtractor 輸出的目錄路徑 (e.g., extracted_wiki_zh)"
    )


    # positional argument 2: output file
    parser.add_argument(
        "output_filename",
        type=str,
        help="最終合并的純文本文件路徑 (e.g., cleaned_wiki.txt)"
    )


    # optional: minimum line length
    parser.add_argument(
        "--min_line_length",
        type=int,
        default=20,
        help="文章中單行文本必須達(dá)到的最小長(zhǎng)度,用于過(guò)濾噪音。默認(rèn)值: 20"
    )


    # optional: minimum article length
    parser.add_argument(
        "--min_article_length",
        type=int,
        default=200,
        help="文章最小長(zhǎng)度,用于過(guò)濾短文章和重定向頁(yè)。默認(rèn)值: 200"
    )


    args = parser.parse_args()


    process_extracted_wiki(
        args.extracted_directory, 
        args.output_filename, 
        args.min_line_length,
        args.min_article_length
    )




if __name__ == "__main__":
    main()

2025-10-01 11:10:58,772 - INFO - Found 5 files to process.
Extracting wiki text: 100%|██████████| 5/5 [00:33<00:00,  6.78s/it]
2025-10-01 11:11:32,681 - INFO - All Wikipedia text extracted. Total articles: 628093. Saved to data/cleaned_wiki_full.txt

04. Training the Tokenizer

We use SentencePiece to train the tokenizer. Here we train a 16k vocabulary; you could also train a 32k one. The code and process are as follows:

Note:

This step produces the files

workdir/spm_wiki_16k.model

workdir/spm_wiki_16k.vocab

import sys
import sentencepiece as spm
import argparse
import os
from tqdm import tqdm


# python scripts/train_tokenizer.py data/cleaned_wiki_full.txt workdir/spm_wiki 32000




def get_corpus_size(input_file: str) -> int:
    """計(jì)算語(yǔ)料的總行數(shù)和文件大小"""
    try:
        file_size_bytes = os.path.getsize(input_file)
        file_size_mb = file_size_bytes / (1024 * 1024)
        print(f"語(yǔ)料文件大小: {file_size_mb:.2f} MB")


        # count lines and total characters
        line_count = 0
        total_chars = 0
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Scanning corpus"):
                line_count += 1
                total_chars += len(line)


        print(f"語(yǔ)料總行數(shù) (文章數(shù)): {line_count}")
        print(f"總字符數(shù): {total_chars:,}")
        print(f"平均每行字符數(shù): {total_chars / line_count:.1f}")
        return file_size_bytes
    except Exception as e:
        print(f"警告:無(wú)法計(jì)算文件大小或行數(shù):{e}")
        return 0




def train_spm_model(input_file: str, 
                    model_prefix: str, 
                    vocab_size: int,
                    model_type: str = 'bpe',
                    character_coverage: float = 0.9995):
    """
    訓(xùn)練一個(gè)SentencePiece分詞器模型。


    參數(shù):
        input_file: 訓(xùn)練語(yǔ)料文件路徑
        model_prefix: 輸出模型文件的前綴
        vocab_size: 詞匯表大小
        model_type: 分詞算法類型 ('bpe', 'unigram', 'char', 'word')
        character_coverage: 字符覆蓋率 (0-1,通常 0.995-0.9995)
    """
    if not os.path.exists(input_file):
        print(f"錯(cuò)誤:輸入語(yǔ)料文件未找到:{input_file}")
        sys.exit(1)


    # make sure the output directory exists
    output_dir = os.path.dirname(model_prefix)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"已創(chuàng)建輸出目錄: {output_dir}")


    # print corpus statistics
    print("\n=== Corpus analysis ===")
    get_corpus_size(input_file)


    # build the training parameters
    # for a ~1.5 GB corpus, enabling train_extremely_large_corpus=True is recommended
    train_params = {
        'input': input_file,
        'model_prefix': model_prefix,
        'vocab_size': vocab_size,
        'model_type': model_type,
        'character_coverage': character_coverage,
        'num_threads': 32,        # use 32 threads to maximize CPU utilization
        'bos_id': 0,
        'eos_id': 1,
        'unk_id': 2,
        'pad_id': -1,
        'normalization_rule_name': 'identity',
        'input_sentence_size': 2000000,          # sentence sampling size (could be raised to 5M)
        'train_extremely_large_corpus': True,   # required for very large corpora
        'shuffle_input_sentence': True,
        'seed_sentencepiece_size': 2000000,     # seed sentence pool size
        'hard_vocab_limit': False,              # soft limit: final vocab may deviate from the target for better quality
    }


    print("\n=== SentencePiece 訓(xùn)練參數(shù) ===")
    for key, value in train_params.items():
        print(f"  {key}: {value}")
    print("=" * 35)


    print("\n正在訓(xùn)練 SentencePiece 模型...")
    print("   (請(qǐng)稍候,進(jìn)度由 SentencePiece 輸出)\n")


    try:
        # run the training
        spm.SentencePieceTrainer.train(**train_params)


        print("\n分詞器模型訓(xùn)練完成!")
        print(f"   模型文件: {model_prefix}.model")
        print(f"   詞匯表文件: {model_prefix}.vocab")


        # verify the model files were created
        if os.path.exists(f"{model_prefix}.model") and os.path.exists(f"{model_prefix}.vocab"):
            model_size_kb = os.path.getsize(f"{model_prefix}.model") / 1024
            print(f"\n模型文件大小: {model_size_kb:.2f} KB")


            # load the model for a quick smoke test
            print("\nRunning a quick test...")
            sp = spm.SentencePieceProcessor(model_file=f"{model_prefix}.model")


            test_text = "這是一個(gè)分詞測(cè)試句子。"
            tokens = sp.encode(test_text, out_type=str)
            ids = sp.encode(test_text, out_type=int)


            print(f" 測(cè)試文本: {test_text}")
            print(f" 分詞結(jié)果: {tokens}")
            print(f" Token IDs: {ids}")
        else:
            print("\n警告:模型文件生成失敗,請(qǐng)檢查輸入數(shù)據(jù)或參數(shù)")


    except Exception as e:
        print(f"\n訓(xùn)練過(guò)程出錯(cuò): {e}")
        sys.exit(1)




def main():
    parser = argparse.ArgumentParser(
        description="Train a SentencePiece tokenizer model.",
        formatter_class=argparse.RawTextHelpFormatter
    )


    parser.add_argument(
        "input_file",
        type=str,
        help="訓(xùn)練語(yǔ)料的路徑 (e.g., data/cleaned_wiki_full.txt)"
    )


    parser.add_argument(
        "model_prefix",
        type=str,
        help="訓(xùn)練模型文件的輸出前綴 (e.g., workdir/spm_wiki)"
    )


    parser.add_argument(
        "vocab_size",
        type=int,
        help="詞匯表大小 (e.g., 32000)"
    )


    parser.add_argument(
        "--model_type",
        type=str,
        default='bpe',
        choices=['bpe', 'unigram', 'char', 'word'],
        help="分詞算法類型 (默認(rèn): bpe)"
    )


    parser.add_argument(
        "--character_coverage",
        type=float,
        default=0.9995,
        help="字符覆蓋率,范圍 [0-1]。對(duì)于小詞表(8K),建議用0.99或更小"
    )


    args = parser.parse_args()


    print("\n" + "="*50)
    print("SentencePiece 分詞器訓(xùn)練程序")
    print("="*50)
    print(f"輸入語(yǔ)料: {args.input_file}")
    print(f"輸出模型前綴: {args.model_prefix}")
    print(f"詞匯表大小: {args.vocab_size}")
    print(f"分詞算法: {args.model_type}")
    print(f"字符覆蓋率: {args.character_coverage}")
    print("="*50 + "\n")


    train_spm_model(
        args.input_file, 
        args.model_prefix, 
        args.vocab_size,
        args.model_type,
        args.character_coverage
    )




if __name__ == "__main__":
    main()

Training the SentencePiece tokenizer...
   Input corpus: data/cleaned_wiki_full.txt
   Output model prefix: workdir/spm_wiki_16k
   Vocabulary size: 16000
   Corpus file size: 1697.54 MB
Counting lines: 1256186it [00:05, 230354.42it/s]
   Total corpus lines (articles): 1256186


--- SentencePiece training parameters ---
--input=data/cleaned_wiki_full.txt
--model_prefix=workdir/spm_wiki_16k
--vocab_size=16000
--model_type=bpe
--character_coverage=0.9995
--num_threads=16
--bos_id=0
--eos_id=1
--unk_id=2
--pad_id=-1 
------------------------------


Starting training... watch SentencePiece's own progress output.
sentencepiece_trainer.cc(178) LOG(INFO) Running command: --input=data/cleaned_wiki_full.txt --model_prefix=workdir/spm_colinai_16000 --vocab_size=16000 --model_type=bpe --character_coverage=0.9995 --num_threads=16 --bos_id=0 --eos_id=1 --unk_id=2 --pad_id=-1 
sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: data/cleaned_wiki_full.txt
  input_format: 
  model_prefix: workdir/spm_colinai_16000
  model_type: BPE
  vocab_size: 16000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 2
  bos_id: 0
  eos_id: 1
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ? 
  enable_differential_privacy: 0
  differential_privacy_noise_level: 0
  differential_privacy_clipping_threshold: 0
}
normalizer_spec {
  name: nmt_nfkc
  add_dummy_prefix: 1
  remove_extra_whitespaces: 1
  escape_whitespaces: 1
  normalization_rule_tsv: 
}
denormalizer_spec {}
trainer_interface.cc(355) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.
trainer_interface.cc(186) LOG(INFO) Loading corpus: data/cleaned_wiki_full.txt
trainer_interface.cc(382) LOG(WARNING) Found too long line (18615 > 4192).
trainer_interface.cc(384) LOG(WARNING) Too long lines are skipped in the training.
trainer_interface.cc(385) LOG(WARNING) The maximum length can be changed with --max_sentence_length=<size> flag.
trainer_interface.cc(411) LOG(INFO) Loaded all 528882 sentences
trainer_interface.cc(418) LOG(INFO) Skipped 99211 too long sentences.
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <s>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: </s>
trainer_interface.cc(427) LOG(INFO) Adding meta_piece: <unk>
trainer_interface.cc(432) LOG(INFO) Normalizing sentences...
trainer_interface.cc(541) LOG(INFO) all chars count=281809036
trainer_interface.cc(552) LOG(INFO) Done: 99.95% characters are covered.
trainer_interface.cc(562) LOG(INFO) Alphabet size=8686
trainer_interface.cc(563) LOG(INFO) Final character coverage=0.9995
trainer_interface.cc(594) LOG(INFO) Done! preprocessed 528882 sentences.
trainer_interface.cc(600) LOG(INFO) Tokenizing input sentences with whitespace: 528882
trainer_interface.cc(611) LOG(INFO) Done! 3885388
.....
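
With training done, the tokenizer can be sanity-checked on its own before moving on; a minimal sketch, assuming the output paths from this step:

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="workdir/spm_wiki_16k.model")
print(sp.get_piece_size())                        # vocabulary size, ~16000
pieces = sp.encode("今天天氣不錯(cuò)", out_type=str)   # subword pieces
ids = sp.encode("今天天氣不錯(cuò)", out_type=int)     # token IDs
print(pieces, ids)
print(sp.decode(ids))                             # round-trips back to the text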

05. Converting Raw Text to Token ID Sequences

When preparing to train a large language model, the massive text corpus must be converted into a numeric format the model can consume. This step encodes the raw text into a sequence of integer token IDs. To avoid loading a large file into memory all at once, the script reads the corpus in chunks of a configurable size. All token IDs are then collected into an efficient torch.int32 PyTorch tensor and stored directly as a .pt file. This makes the data easy for a PyTorch DataLoader to read later, and the script also reports key statistics and performs integrity checks, giving us a stable, high-throughput preprocessing step for building an LLM dataset.

import sys
import torch
import sentencepiece as spm
import argparse
from tqdm import tqdm
import os
import numpy as np


# python scripts/preprocess_data.py workdir/spm_wiki.model data/cleaned_wiki_full.txt workdir/wiki_tokens.pt




def preprocess(sp_model_path: str, 
               corpus_path: str, 
               output_path: str,
               chunk_size_mb: int = 50):
    """
    分塊讀取語(yǔ)料,編碼為 Token ID,并保存為 PyTorch 文件。


    參數(shù):
        sp_model_path: SentencePiece 模型文件路徑
        corpus_path: 輸入語(yǔ)料文件路徑
        output_path: 輸出 .pt 文件路徑
        chunk_size_mb: 每次處理的文本大小(MB),默認(rèn) 50MB
    """
    # verify the input files exist
    if not os.path.exists(sp_model_path):
        print(f"錯(cuò)誤:分詞器模型文件未找到: {sp_model_path}")
        sys.exit(1)


    if not os.path.exists(corpus_path):
        print(f"錯(cuò)誤:語(yǔ)料文件未找到: {corpus_path}")
        sys.exit(1)


    # load the tokenizer
    try:
        sp = spm.SentencePieceProcessor(model_file=sp_model_path)
        vocab_size = sp.get_piece_size()
        print(f"   分詞器加載成功")
        print(f"   詞匯表大小: {vocab_size}")
        print(f"   特殊 Token: BOS={sp.bos_id()}, EOS={sp.eos_id()}, UNK={sp.unk_id()}, PAD={sp.pad_id()}")
    except Exception as e:
        print(f"加載分詞器失敗: {e}")
        sys.exit(1)


    # make sure the output directory exists
    output_dir = os.path.dirname(output_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)


    print(f"\n 開(kāi)始處理語(yǔ)料...")
    print(f"   輸入文件: {corpus_path}")
    print(f"   輸出文件: {output_path}")
    print(f"   塊大小: {chunk_size_mb} MB\n")


    # total size, for the progress bar
    total_bytes = os.path.getsize(corpus_path)
    chunk_size_bytes = chunk_size_mb * 1024 * 1024


    token_ids = []
    tokens_processed = 0
    chunks_processed = 0


    try:
        with open(corpus_path, 'r', encoding='utf-8') as f:
            with tqdm(total=total_bytes, unit='B', unit_scale=True, desc="Encoding corpus") as pbar:


                while True:
                    chunk = f.read(chunk_size_bytes)
                    if not chunk:
                        break


                    # encode directly (cleaned_wiki_full.txt is already cleaned)
                    ids = sp.encode(chunk, out_type=int)
                    token_ids.extend(ids)


                    # update the progress bar
                    bytes_read = len(chunk.encode('utf-8'))
                    pbar.update(bytes_read)


                    tokens_processed += len(ids)
                    chunks_processed += 1


                    # periodically refresh the progress info
                    if chunks_processed % 10 == 0:
                        pbar.set_postfix({
                            'chunks': chunks_processed,
                            'tokens': f'{tokens_processed:,}'
                        })


        print(f"\n 編碼完成")
        print(f"   處理塊數(shù): {chunks_processed}")
        print(f"   總 Token 數(shù): {tokens_processed:,}")


        # convert to a PyTorch tensor
        print(f"\nConverting to a tensor and saving...")
        final_tensor = torch.tensor(token_ids, dtype=torch.int32)


        print(f"   張量形狀: {final_tensor.shape}")
        print(f"   張量大小: {final_tensor.numel():,}")
        print(f"   數(shù)據(jù)類型: {final_tensor.dtype}")
        print(f"   占用內(nèi)存: {final_tensor.numel() * 4 / (1024**3):.2f} GB")


        # validate the token ID range
        min_id = final_tensor.min().item()
        max_id = final_tensor.max().item()
        print(f"   Token ID 范圍: [{min_id}, {max_id}]")


        if max_id >= vocab_size or min_id < 0:
            print(f"   警告: 檢測(cè)到越界 Token ID!")
            print(f"   詞匯表大小: {vocab_size}")
            print(f"   最大 ID: {max_id}")


        # save the tensor
        torch.save(final_tensor, output_path)
        file_size_mb = os.path.getsize(output_path) / (1024 ** 2)
        print(f"\nToken ID 已保存到 {output_path}")
        print(f"   文件大小: {file_size_mb:.2f} MB")


        # verify the saved file
        print(f"\nVerifying the saved file...")
        loaded_tensor = torch.load(output_path)
        print(f"   加載成功,形狀: {loaded_tensor.shape}")
        print(f"   是否相同: {torch.equal(final_tensor, loaded_tensor)}")


        print(f"\n? 預(yù)處理完成!")


    except Exception as e:
        print(f"\n處理過(guò)程中出錯(cuò): {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)




def main():
    parser = argparse.ArgumentParser(
        description="Convert a cleaned text corpus into a binary token ID file.",
        formatter_class=argparse.RawTextHelpFormatter
    )


    parser.add_argument(
        "model_path",
        type=str,
        help="SentencePiece 模型文件路徑 (e.g., workdir/spm_wiki.model)"
    )


    parser.add_argument(
        "corpus_path",
        type=str,
        help="輸入語(yǔ)料文件路徑 (e.g., data/cleaned_wiki_full.txt)"
    )


    parser.add_argument(
        "output_path",
        type=str,
        help="輸出 Token ID 文件路徑 (e.g., workdir/wiki_tokens.pt)"
    )


    parser.add_argument(
        "--chunk_size",
        type=int,
        default=50,
        help="每次處理的文本大小(MB),默認(rèn) 50MB。更大的塊更快,但占用更多內(nèi)存。"
    )


    args = parser.parse_args()


    print("\n" + "="*60)
    print("數(shù)據(jù)預(yù)處理程序 - 文本到 Token ID")
    print("="*60)
    print(f"SentencePiece 模型: {args.model_path}")
    print(f"輸入語(yǔ)料: {args.corpus_path}")
    print(f"輸出文件: {args.output_path}")
    print(f"塊大小: {args.chunk_size} MB")
    print("="*60 + "\n")


    preprocess(
        args.model_path,
        args.corpus_path,
        args.output_path,
        args.chunk_size
    )




if __name__ == "__main__":
    main()

06. Pretraining the Model

"""
GPT 高性能訓(xùn)練腳本
"""


from __future__ import annotations
import sys
import os
import math
import json
from datetime import datetime
from typing import Optional


import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm
from tqdm import tqdm


# ==================== Configuration ====================
class Config:
    BLOCK_SIZE = 512 #256
    BATCH_SIZE = 32 #64
    GRAD_ACCUM_STEPS = 4 #1
    MODEL_DIM = 384 #256
    N_LAYERS = 5 #2
    NUM_HEADS = 6 #4
    HEAD_DIM = MODEL_DIM // NUM_HEADS
    FFN_DIM = MODEL_DIM * 4
    VOCAB_SIZE = None


    EPOCHS = 1
    MAX_STEPS = 10000 # set according to your hardware and time budget
    WARMUP_STEPS = 500
    LR = 1e-4
    MIN_LR = 1e-5
    WEIGHT_DECAY = 0.01
    GRAD_CLIP = 1.0
    DROPOUT = 0.1


    CHECKPOINT_EVERY = 5000
    LOG_EVERY = 100


    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    CHECKPOINT_DIR = "./checkpoints"
    LATEST_CHECKPOINT = "latest_checkpoint.pth"


    NUM_WORKERS = 8
    SEED = 42


    # prefer bfloat16 on modern GPUs
    DTYPE = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16


CFG = Config()


if CFG.DEVICE == 'cuda':
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.cuda.empty_cache()
    # report which mixed-precision dtype is in use
    if CFG.DTYPE == torch.bfloat16:
        print("Using bfloat16 mixed precision (recommended)")
    else:
        print("Using float16 mixed precision")


# ==================== Utilities ====================
def print_gpu_memory():
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / (1024**3)
        reserved = torch.cuda.memory_reserved() / (1024**3)
        print(f"GPU顯存: {allocated:.2f}GB / {reserved:.2f}GB")


def set_seed(seed: int):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(CFG.SEED)


# ==================== Dataset ====================
class TextDataset(Dataset):
    def __init__(self, token_ids: torch.Tensor, block_size: int):
        self.ids = token_ids.long()
        self.block_size = block_size


    def __len__(self):
        return max(0, self.ids.size(0) - self.block_size)


    def __getitem__(self, idx):
        x = self.ids[idx: idx + self.block_size]
        y = self.ids[idx + 1: idx + 1 + self.block_size]
        return x, y


# ==================== RoPE positional embedding ====================
class RotaryPositionalEmbedding(nn.Module):
    """RoPE implementation"""
    def __init__(self, head_dim: int, max_seq_len: int = 2048):
        super().__init__()
        self.head_dim = head_dim
        assert head_dim % 2 == 0, "head_dim must be even"


        # base frequencies: theta_i = 10000^(-2i/d)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)


        self.max_seq_len = max_seq_len
        self._seq_len_cached = max_seq_len
        self._cos_cached = None
        self._sin_cached = None
        self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)


    def _update_cos_sin_cache(self, seq_len: int, device: torch.device):
        if seq_len == self._seq_len_cached and self._cos_cached is not None:
            return


        # m: (seq_len,), theta_i: (head_dim//2,)
        m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", m, self.inv_freq)  # (seq_len, head_dim//2)


        # build the full rotation table (each complex pair repeated)
        emb = torch.cat([freqs, freqs], dim=-1)  # (seq_len, head_dim)


        cos = emb.cos()[None, None, :, :]  # (1, 1, seq_len, head_dim)
        sin = emb.sin()[None, None, :, :]  # (1, 1, seq_len, head_dim)


        self._cos_cached = cos
        self._sin_cached = sin
        self._seq_len_cached = seq_len


    def forward(self, seq_len: int, device: Optional[torch.device] = None):
        if device is None:
            device = self.inv_freq.device
        self._update_cos_sin_cache(seq_len, device=device)
        return self._cos_cached.to(device), self._sin_cached.to(device)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """應(yīng)用RoPE旋轉(zhuǎn)"""
    # x: (B, H, T, D), cos/sin: (1, 1, T, D)
    # uses (x, y) -> (x*cos - y*sin, x*sin + y*cos)
    return (x * cos) + (_rotate_half(x) * sin)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """將向量旋轉(zhuǎn)90度"""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# ==================== Flash Attention ====================
class FlashAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5


        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.rope = RotaryPositionalEmbedding(self.head_dim)


    def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, C = x.shape
        assert T <= self.rope.max_seq_len, f"Seq len {T} exceeds max {self.rope.max_seq_len}"


        qkv = self.qkv(x)
        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q = q.permute(0, 2, 1, 3)  # (B, H, T, D)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)


        # apply RoPE
        cos, sin = self.rope(T, device=x.device)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)


        # attention computation
        # note: torch.nn.functional.scaled_dot_product_attention with torch.compile would be faster here
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if causal_mask is not None:
            scores = scores.masked_fill(causal_mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        attn = self.attn_dropout(attn)
        out = torch.matmul(attn, v)
        out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)
        return self.out_proj(out)


# ==================== Feed-forward network ====================
class GLU(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim * 2)


    def forward(self, x):
        x, gates = self.linear(x).chunk(2, dim=-1)
        return x * torch.nn.functional.silu(gates)


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            GLU(dim, hidden_dim),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )


    def forward(self, x):
        return self.net(x)


# ==================== Transformer Block ====================
class TransformerBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, ffn_dim, dropout)


    def forward(self, x, causal_mask=None):
        x = x + self.attn(self.ln1(x), causal_mask)
        x = x + self.ff(self.ln2(x))
        return x


# ==================== GPT model (pos_emb removed) ====================
class GPTModel(nn.Module):
    def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,
                 num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,
                 ffn_dim: int = CFG.FFN_DIM, dropout: float = CFG.DROPOUT,
                 tie_weights: bool = True):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        # self.pos_emb = nn.Embedding(block_size, dim) # removed: conflicts with RoPE
        self.dropout = nn.Dropout(dropout)


        self.blocks = nn.ModuleList([
            TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)
        ])


        self.ln_final = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)


        if tie_weights:
            self.lm_head.weight = self.token_emb.weight


        self.block_size = block_size
        self.apply(self._init_weights)


        n_params = sum(p.numel() for p in self.parameters())
        print(f"模型參數(shù): {n_params/1e6:.2f}M")


    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)


    def forward(self, idx):
        B, T = idx.shape
        assert T <= self.block_size, f"Seq len {T} exceeds block_size {self.block_size}"


        token_emb = self.token_emb(idx)


       
        x = self.dropout(token_emb) # token embedding


        causal_mask = torch.tril(torch.ones(T, T, device=idx.device, dtype=torch.bool))[None, None, :, :]
        for block in self.blocks:
            x = block(x, causal_mask)
        x = self.ln_final(x)
        logits = self.lm_head(x)
        return logits


# ==================== Checkpoint management ====================
def save_checkpoint(model, optimizer, scaler, lr_scheduler, step: int, loss: float, config_dict: dict):
    os.makedirs(CFG.CHECKPOINT_DIR, exist_ok=True)
    checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)
    state = {
        'step': step,
        'loss': loss,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'config': config_dict,
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
    }


    if scaler is not None and hasattr(scaler, "state_dict"):
        state['scaler_state_dict'] = scaler.state_dict()


    if lr_scheduler is not None:
        state['lr_scheduler_state_dict'] = {
            'current_step': lr_scheduler.current_step,
            'warmup_steps': lr_scheduler.warmup_steps,
            'total_steps': lr_scheduler.total_steps,
            'base_lr': lr_scheduler.base_lr,
            'min_lr': lr_scheduler.min_lr,
        }


    torch.save(state, checkpoint_path)


    try:
        with open(os.path.join(CFG.CHECKPOINT_DIR, "config.json"), "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=2)
    except Exception:
        pass


    print(f" 檢查點(diǎn)已保存: {checkpoint_path} (step {step}, loss {loss:.4f})")


def load_checkpoint(checkpoint_path: str, model, optimizer, scaler, lr_scheduler):
    if not os.path.exists(checkpoint_path):
        return None


    checkpoint = torch.load(checkpoint_path, map_location=CFG.DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


    if checkpoint.get('scaler_state_dict') is not None and scaler is not None:
        try:
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        except Exception as e:
            print(f"無(wú)法恢復(fù)scaler: {e}")


    if checkpoint.get('lr_scheduler_state_dict') is not None and lr_scheduler is not None:
        try:
            sched_state = checkpoint['lr_scheduler_state_dict']
            lr_scheduler.current_step = sched_state['current_step']
            lr_scheduler.warmup_steps = sched_state['warmup_steps']
            lr_scheduler.total_steps = sched_state['total_steps']
            lr_scheduler.base_lr = sched_state['base_lr']
            lr_scheduler.min_lr = sched_state['min_lr']
        except Exception as e:
            print(f"無(wú)法恢復(fù)lr_scheduler: {e}")


    torch.set_rng_state(checkpoint['torch_rng_state'])
    if torch.cuda.is_available() and checkpoint.get('cuda_rng_state') is not None:
        torch.cuda.set_rng_state(checkpoint['cuda_rng_state'])


    print(f"檢查點(diǎn)已加載: {checkpoint_path}")
    print(f"    Step: {checkpoint['step']}, Loss: {checkpoint['loss']:.4f}")
    return checkpoint['step']


# ==================== Learning-rate scheduler ====================
class WarmupCosineScheduler:
    def __init__(self, optimizer, warmup_steps: int, total_steps: int, base_lr: float, min_lr: float):
        self.optimizer = optimizer
        self.warmup_steps = max(0, int(warmup_steps))
        self.total_steps = max(1, int(total_steps))
        self.base_lr = base_lr
        self.min_lr = min_lr
        self.current_step = 0


    def get_lr(self, step: int = None) -> float:
        """計(jì)算給定step的學(xué)習(xí)率(不修改optimizer)"""
        if step is None:
            step = self.current_step


        if step < self.warmup_steps and self.warmup_steps > 0:
            return self.base_lr * (step / float(self.warmup_steps))
        else:
            denom = max(1, (self.total_steps - self.warmup_steps))
            progress = (step - self.warmup_steps) / denom
            progress = min(1.0, max(0.0, progress))
            return self.min_lr + (self.base_lr - self.min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


    def step(self):
        """執(zhí)行一次步長(zhǎng)更新"""
        lr = self.get_lr(self.current_step)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.current_step += 1
        return lr


# ==================== Training loop ====================
def train(model: nn.Module, train_loader: DataLoader, epochs: int = CFG.EPOCHS, resume: bool = False):
    # detect fused AdamW support
    fused = False
    try:
        fused = torch.cuda.is_available() and ("fused" in torch.optim.AdamW.__init__.__code__.co_varnames)
    except Exception:
        fused = False


    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=CFG.LR,
        betas=(0.9, 0.95),
        weight_decay=CFG.WEIGHT_DECAY,
        fused=fused
    )


    # use the DTYPE from the config
    scaler = torch.cuda.amp.GradScaler(enabled=(CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16))
    loss_fn = nn.CrossEntropyLoss()


    total_steps = CFG.MAX_STEPS if CFG.MAX_STEPS else len(train_loader) * epochs
    lr_scheduler = WarmupCosineScheduler(optimizer, CFG.WARMUP_STEPS, total_steps, CFG.LR, CFG.MIN_LR)


    model.train()
    start_step = 0
    best_loss = float('inf')


    checkpoint_path = os.path.join(CFG.CHECKPOINT_DIR, CFG.LATEST_CHECKPOINT)
    if resume and os.path.exists(checkpoint_path):
        loaded_step = load_checkpoint(checkpoint_path, model, optimizer, scaler, lr_scheduler)
        if loaded_step is not None:
            start_step = loaded_step


    global_step = start_step
    grad_accum_counter = 0
    accumulated_loss = 0.0


    print("\n" + "="*60)
    print("開(kāi)始訓(xùn)練...")
    print("="*60)
    print_gpu_memory()
    print()


    # decide whether scaler.scale() is needed
    use_scaler = (CFG.DEVICE == "cuda") and (CFG.DTYPE == torch.float16)


    for epoch in range(epochs):
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", initial=global_step % len(train_loader) if epoch == 0 else 0)
        num_batches = 0
        last_lr = None


        for batch_idx, (xb, yb) in enumerate(pbar):
            # skip already-trained batches (when resuming mid-epoch)
            if global_step > start_step and batch_idx < (start_step % len(train_loader)):
                 continue


            xb = xb.to(CFG.DEVICE, non_blocking=True)
            yb = yb.to(CFG.DEVICE, non_blocking=True)


            with torch.cuda.amp.autocast(enabled=(CFG.DEVICE == "cuda"), dtype=CFG.DTYPE):
                logits = model(xb)
                loss = loss_fn(logits.view(-1, logits.size(-1)), yb.view(-1))
                loss_item = loss.item()
                loss = loss / CFG.GRAD_ACCUM_STEPS


            if use_scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()


            grad_accum_counter += 1
            accumulated_loss += loss_item
            num_batches += 1
            # note: logging and checkpointing below key off global_step,
            # which is incremented only when an optimizer step runs (see below)


            # gradient accumulation: run an optimizer step once the threshold is reached
            if grad_accum_counter >= CFG.GRAD_ACCUM_STEPS:


                # optimizer step (this is where global_step actually advances)
                lr_scheduler.step() # update the LR first
                global_step += 1 # one optimizer step counts as one global_step


                if use_scaler:
                    scaler.unscale_(optimizer)


                # gradient clipping (after unscale, or in non-AMP mode)
                torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)


                if use_scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()


                optimizer.zero_grad()
                grad_accum_counter = 0
                last_lr = lr_scheduler.get_lr(global_step) # LR for the current step


            # logging
            if global_step % CFG.LOG_EVERY == 0 or (global_step == 1):
                # accumulated_loss is the summed raw loss over num_batches batches
                avg_loss = accumulated_loss / num_batches if num_batches > 0 else 0.0
                pbar.set_postfix({
                    'step': global_step,
                    'loss': f'{avg_loss:.4f}',
                    'lr': f'{last_lr:.2e}' if last_lr is not None else 'N/A'
                })
                # reset accumulators for the next LOG_EVERY window
                accumulated_loss = 0.0
                num_batches = 0




            # save a checkpoint
            if global_step > start_step and global_step % CFG.CHECKPOINT_EVERY == 0:
                # use the average loss accumulated since the last log point
                current_avg_loss = accumulated_loss / num_batches if num_batches > 0 else loss_item


                config_dict = {
                    'vocab_size': CFG.VOCAB_SIZE,
                    'block_size': CFG.BLOCK_SIZE,
                    'model_dim': CFG.MODEL_DIM,
                    'n_layers': CFG.N_LAYERS,
                    'num_heads': CFG.NUM_HEADS,
                    'created_at': datetime.now().isoformat()
                }
                save_checkpoint(model, optimizer, scaler, lr_scheduler, global_step, current_avg_loss, config_dict)
                torch.cuda.empty_cache()


            if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:
                break


        # flush any remaining accumulated gradients at the end of the epoch (if grad_accum_counter > 0)
        if grad_accum_counter > 0:
            if use_scaler:
                scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.GRAD_CLIP)


            if use_scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()


            optimizer.zero_grad()
            lr_scheduler.step()
            global_step += 1
            grad_accum_counter = 0




        # compute the epoch's trailing average loss
        if num_batches > 0:
             final_avg_loss = accumulated_loss / num_batches
        else:
             final_avg_loss = float('inf')




        if final_avg_loss < best_loss:
            best_loss = final_avg_loss
            best_path = os.path.join(CFG.CHECKPOINT_DIR, "best_model.pth")
            torch.save(model.state_dict(), best_path)
            print(f"最佳模型已保存 (loss: {best_loss:.4f})")


        print(f"\n[Epoch {epoch+1}] Avg Loss: {final_avg_loss:.4f}")


        if CFG.MAX_STEPS and global_step >= CFG.MAX_STEPS:
            break


    print("\n訓(xùn)練完成!")


# ==================== Main ====================
def main():
    if len(sys.argv) < 4:
        print("用法: python train_20251012_v1.py workdir/spm_wiki_16k.model workdir/wiki_tokens_16k.pt models/gpt_wiki.pth [--resume]")
        sys.exit(1)


    sp_model_path, token_file_path, out_path = sys.argv[1:4]
    resume = "--resume" in sys.argv


    if not os.path.exists(token_file_path):
        print(f" Token文件不存在: {token_file_path}")
        sys.exit(1)


    # fall back if CFG.DTYPE is bfloat16 but the device does not support it
    if CFG.DTYPE == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        print("Warning: bfloat16 is not supported by this CUDA device; falling back to float16.")
        CFG.DTYPE = torch.float16


    sp = spm.SentencePieceProcessor(model_file=sp_model_path)
    CFG.VOCAB_SIZE = sp.get_piece_size()


    print("="*60)
    print("GPT 語(yǔ)言模型訓(xùn)練")
    print("="*60)
    print(f"分詞器: {sp_model_path}")
    print(f"Token文件: {token_file_path}")
    print(f"輸出模型: {out_path}")
    print(f"設(shè)備: {CFG.DEVICE}")
    print(f"\n模型配置:")
    print(f"    - VOCAB_SIZE: {CFG.VOCAB_SIZE}")
    print(f"    - BLOCK_SIZE: {CFG.BLOCK_SIZE}")
    print(f"    - MODEL_DIM: {CFG.MODEL_DIM}")
    print(f"    - N_LAYERS: {CFG.N_LAYERS}")
    print(f"    - NUM_HEADS: {CFG.NUM_HEADS}")
    print(f"\n訓(xùn)練配置:")
    print(f"    - BATCH_SIZE: {CFG.BATCH_SIZE}")
    print(f"    - GRAD_ACCUM_STEPS: {CFG.GRAD_ACCUM_STEPS}")
    print(f"    - 有效BATCH_SIZE: {CFG.BATCH_SIZE * CFG.GRAD_ACCUM_STEPS}")
    print(f"    - LR: {CFG.LR}, WARMUP_STEPS: {CFG.WARMUP_STEPS}")
    print("="*60)


    print(f"\n加載Token文件: {token_file_path}")
    ids = torch.load(token_file_path)
    print(f"已加載 {ids.numel():,} tokens ({ids.numel() * ids.element_size() / (1024**3):.2f} GB)")


    dataset = TextDataset(ids, CFG.BLOCK_SIZE)
    del ids
    torch.cuda.empty_cache()


    # enable shuffle=True for pretraining
    num_workers = CFG.NUM_WORKERS
    try:
        train_loader = DataLoader(
            dataset,
            batch_size=CFG.BATCH_SIZE,
            shuffle=True, # enable shuffling
            pin_memory=(CFG.DEVICE == "cuda"),
            num_workers=num_workers,
            persistent_workers=True if num_workers > 0 else False
        )
    except Exception as e:
        print(f"DataLoader錯(cuò)誤: {e}, 改用num_workers=0")
        train_loader = DataLoader(
            dataset,
            batch_size=CFG.BATCH_SIZE,
            shuffle=True,
            pin_memory=(CFG.DEVICE == "cuda"),
            num_workers=0
        )


    model = GPTModel(
        CFG.VOCAB_SIZE,
        CFG.BLOCK_SIZE,
        dim=CFG.MODEL_DIM,
        num_layers=CFG.N_LAYERS,
        num_heads=CFG.NUM_HEADS,
        ffn_dim=CFG.FFN_DIM,
        dropout=CFG.DROPOUT
    ).to(CFG.DEVICE)


    # try to compile (fault-tolerant)
    try:
        model = torch.compile(model, mode='reduce-overhead')
        print("已啟用 torch.compile() 加速")
    except Exception as e:
        print(f"跳過(guò) torch.compile(): {e}")


    train(model, train_loader, epochs=CFG.EPOCHS, resume=resume)


    torch.save(model.state_dict(), out_path)
    print(f"\n最終模型已保存到 {out_path}")
    print_gpu_memory()


if __name__ == "__main__":
    main()
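
As a sanity check on the parameter count in the title, the total can be derived by hand from the config above; a sketch assuming the final vocabulary comes out at exactly 16,000 (the script prints the real number at startup, and weight tying means the embedding is counted once):

# Rough parameter count for MODEL_DIM=384, N_LAYERS=5, FFN_DIM=1536, vocab=16k.
V, D, L, F = 16000, 384, 5, 1536

emb = V * D                                  # token embedding (tied with lm_head)
attn = D * (3 * D) + D * D                   # qkv + out_proj, both bias-free
ffn = D * (2 * F) + 2 * F + F * D + D        # GLU linear (+bias) and down projection (+bias)
ln = 2 * (2 * D)                             # two LayerNorms per block (weight + bias)
per_block = attn + ffn + ln

total = emb + L * per_block + 2 * D          # plus the final LayerNorm
print(f"{total / 1e6:.2f}M")                 # ≈ 17.97M, i.e. the ~17M in the title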

07. Testing the Model with Inference

import torch
from torch import nn
import sentencepiece as spm
from typing import Optional


# ==================== Configuration (must match training) ====================
# use exactly the same configuration as the training script
class Config:
    BLOCK_SIZE = 512
    # model size parameters (must match training)
    MODEL_DIM = 384
    N_LAYERS = 5
    NUM_HEADS = 6
    HEAD_DIM = MODEL_DIM // NUM_HEADS
    FFN_DIM = MODEL_DIM * 4 
    VOCAB_SIZE = None


    # inference settings
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # float32 generally gives the best compatibility and accuracy for inference
    DTYPE = torch.float32 


CFG = Config()


# ==================== RoPE positional embedding (identical to the training script) ====================
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, head_dim: int, max_seq_len: int = 2048):
        super().__init__()
        self.head_dim = head_dim
        assert head_dim % 2 == 0, "head_dim must be even"
        inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len = max_seq_len
        self._seq_len_cached = max_seq_len
        self._cos_cached = None
        self._sin_cached = None
        self._update_cos_sin_cache(max_seq_len, device=self.inv_freq.device)


    def _update_cos_sin_cache(self, seq_len: int, device: torch.device):
        if seq_len == self._seq_len_cached and self._cos_cached is not None:
            return
        m = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", m, self.inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]
        self._cos_cached = cos
        self._sin_cached = sin
        self._seq_len_cached = seq_len


    def forward(self, seq_len: int, device: Optional[torch.device] = None):
        if device is None:
            device = self.inv_freq.device
        self._update_cos_sin_cache(seq_len, device=device)
        return self._cos_cached.to(device), self._sin_cached.to(device)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    return (x * cos) + (_rotate_half(x) * sin)


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# ==================== Attention, FFN, Block, Model (identical to the training script) ====================
class FlashAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, attn_dropout: float = 0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        # dropout is unused at inference time, but the module must stay so the structure matches
        self.attn_dropout = nn.Dropout(attn_dropout) 
        self.rope = RotaryPositionalEmbedding(self.head_dim)


    def forward(self, x: torch.Tensor, causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.qkv(x)
        qkv = qkv.view(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)
        q = q.permute(0, 2, 1, 3)  # (B, H, T, D)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)


        cos, sin = self.rope(T, device=x.device)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)


        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # note: real deployments would use a KV cache; this is simplified to a full forward pass
        if T > 1: # apply the mask only when the sequence length is greater than 1
            causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))[None, None, :, :]
            scores = scores.masked_fill(causal_mask == 0, float('-inf'))


        attn = torch.softmax(scores, dim=-1)
        # dropout disabled at inference time
        # attn = self.attn_dropout(attn)
        out = torch.matmul(attn, v)
        out = out.permute(0, 2, 1, 3).contiguous().view(B, T, C)
        return self.out_proj(out)


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        # must keep exactly the same nn.Sequential structure as the training script
        self.net = nn.Sequential(
            GLU(dim, hidden_dim),
            nn.Dropout(dropout),   # net.1: Dropout (kept as a placeholder)
            nn.Linear(hidden_dim, dim), # net.2: Linear (matches training)
            nn.Dropout(dropout),   # net.3: Dropout (kept as a placeholder)
        )


    def forward(self, x):
        # model.eval() automatically disables all nn.Dropout layers; the structure is unchanged
        return self.net(x)


# GLU must be defined exactly as in training:
class GLU(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # GLU contains a single nn.Linear
        self.linear = nn.Linear(in_dim, out_dim * 2)


    def forward(self, x):
        x, gates = self.linear(x).chunk(2, dim=-1)
        return x * torch.nn.functional.silu(gates)


class TransformerBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, ffn_dim: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = FlashAttention(dim, num_heads, attn_dropout=dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, ffn_dim, dropout)


    def forward(self, x, causal_mask=None):
        x = x + self.attn(self.ln1(x), causal_mask)
        x = x + self.ff(self.ln2(x))
        return x


class GPTModel(nn.Module):
    def __init__(self, vocab_size: int, block_size: int, dim: int = CFG.MODEL_DIM,
                 num_layers: int = CFG.N_LAYERS, num_heads: int = CFG.NUM_HEADS,
                 ffn_dim: int = CFG.FFN_DIM, dropout: float = 0.0, # dropout=0 for inference
                 tie_weights: bool = True):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.dropout = nn.Dropout(dropout)


        self.blocks = nn.ModuleList([
            TransformerBlock(dim, num_heads, ffn_dim, dropout) for _ in range(num_layers)
        ])


        self.ln_final = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)


        if tie_weights:
            self.lm_head.weight = self.token_emb.weight


        self.block_size = block_size


    def forward(self, idx):
        B, T = idx.shape
        token_emb = self.token_emb(idx)
        x = token_emb # no dropout at inference time


        causal_mask = None # the attention module builds the causal mask internally
        for block in self.blocks:
            x = block(x, causal_mask)
        x = self.ln_final(x)
        logits = self.lm_head(x)
        return logits




# ==================== Inference and generation ====================


@torch.no_grad()
def generate_text(model: GPTModel, sp: spm.SentencePieceProcessor, 
                  prompt: str, max_new_tokens: int, temperature: float = 0.8, 
                  top_k: int = 50):


    model.eval()
    device = CFG.DEVICE


    # 1. encode the input
    input_ids = sp.encode_as_ids(prompt)
    if not input_ids:
        return "無(wú)法編碼輸入。"


    # shape the input as the model expects: (B, T)
    x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)


    # 2. 循環(huán)生成
    for _ in range(max_new_tokens):
        # 裁剪輸入以適應(yīng)模型的 BLOCK_SIZE
        # 在實(shí)際部署中,這里應(yīng)該使用 KV Cache,但此處簡(jiǎn)化為完整前向傳播
        idx_cond = x if x.size(1) <= CFG.BLOCK_SIZE else x[:, -CFG.BLOCK_SIZE:]


        # 獲取 logits
        logits = model(idx_cond)


        # 只取最后一個(gè)時(shí)間步的 logits
        logits = logits[:, -1, :] 


        # 應(yīng)用溫度縮放
        logits = logits / temperature


        # 3. Top-K 采樣
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')


        # 計(jì)算概率并采樣
        probs = torch.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)


        # 4. 停止條件
        # 檢查是否生成了 EOS token (假設(shè) </s> 是 ID 3, 請(qǐng)根據(jù)您的分詞器調(diào)整)
        # 默認(rèn)使用 SentencePiece 的 <eos> ID
        if idx_next.item() == sp.eos_id():
            break


        # 將新生成的 token 添加到序列中
        x = torch.cat((x, idx_next), dim=1)


        # 檢查是否達(dá)到最大序列長(zhǎng)度 (防止溢出)
        if x.size(1) >= CFG.BLOCK_SIZE + max_new_tokens:
            break


    # 5. 解碼輸出
    output_ids = x[0].tolist()
    # 查找輸入 prompt 的長(zhǎng)度,只解碼新生成的 token
    start_index = len(input_ids)


    return sp.decode_ids(output_ids[start_index:])
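A minimal usage sketch for generate_text, assuming model and sp have already been loaded as in main_infer below; lower temperature and smaller top_k make sampling more deterministic, higher values make it more diverse:

# Assumes `model` and `sp` are already loaded (see main_infer below)
reply = generate_text(model, sp, prompt="今天天氣", max_new_tokens=50,
                      temperature=0.7, top_k=40)
print(reply)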




# ==================== Main entry point ====================


def main_infer(sp_model_path: str, model_weights_path: str):
    print("="*50)
    print("GPT model inference mode")
    print(f"Device: {CFG.DEVICE}, DTYPE: {CFG.DTYPE}")
    print("="*50)


    # 1. Load the tokenizer
    try:
        sp = spm.SentencePieceProcessor(model_file=sp_model_path)
        CFG.VOCAB_SIZE = sp.get_piece_size()
        print(f"Tokenizer loaded, VOCAB_SIZE: {CFG.VOCAB_SIZE}")
    except Exception as e:
        print(f"Failed to load tokenizer model {sp_model_path}: {e}")
        return


    # 2. Instantiate the model
    model = GPTModel(
        vocab_size=CFG.VOCAB_SIZE,
        block_size=CFG.BLOCK_SIZE,
        dim=CFG.MODEL_DIM,
        num_layers=CFG.N_LAYERS,
        num_heads=CFG.NUM_HEADS,
        ffn_dim=CFG.FFN_DIM,
        dropout=0.0  # dropout = 0 at inference time
    ).to(CFG.DEVICE).to(CFG.DTYPE)


    # 3. Load the weights
    try:
        weights = torch.load(model_weights_path, map_location=CFG.DEVICE)


        # If the checkpoint came from a DDP- or torch.compile-wrapped model, strip the wrapper prefix
        if any(k.startswith('_orig_mod.') for k in weights.keys()):
            weights = {k.replace('_orig_mod.', ''): v for k, v in weights.items()}


        model.load_state_dict(weights, strict=True)
        print(f"Model weights loaded successfully: {model_weights_path}")
    except Exception as e:
        print(f"Failed to load or match model weights: {e}")
        # On failure, printing the expected vs. loaded keys helps debugging:
        # print("\n--- Expected model keys (first few) ---")
        # print(list(model.state_dict().keys())[:5])
        # print("\n--- Loaded weight keys (first few) ---")
        # print(list(weights.keys())[:5])
        return


    # 4. Enter the interactive loop
    print("\n--- Interactive mode ---")
    print("Type 'exit' or 'quit' to leave.")
    print("Type 'config' to view and change the current generation parameters.")
    print("----------------------")


    max_tokens = 100
    temperature = 0.8
    top_k = 50


    while True:
        try:
            prompt = input(">>> Enter a prompt: ")


            if prompt.lower() in ['exit', 'quit']:
                break


            if prompt.lower() == 'config':
                print(f"  Max Tokens: {max_tokens}, Temp: {temperature}, Top K: {top_k}")
                new_max = input("  Set Max Tokens (Enter to skip): ")
                new_temp = input("  Set Temperature (Enter to skip): ")
                new_k = input("  Set Top K (Enter to skip): ")


                if new_max: max_tokens = int(new_max)
                if new_temp: temperature = float(new_temp)
                if new_k: top_k = int(new_k)
                continue


            if not prompt.strip():
                continue


            print("Generating...")


            # Run generation
            output = generate_text(model, sp, prompt, max_tokens, temperature, top_k)


            print(f"--- Model reply ---\n{output.strip()}")
            print("----------------")


        except KeyboardInterrupt:
            print("\nExiting...")
            break
        except Exception as e:
            print(f"Error: {e}")




if __name__ == "__main__":
    import sys


    if len(sys.argv) != 3:
        print("Usage: python infer.py <spm model path> <model weights path>")
        # Example (adjust to your actual file paths):
        # python infer.py tokenizer.model final_model.pth
        sys.exit(1)


    sp_model_path = sys.argv[1]
    model_weights_path = sys.argv[2]


    main_infer(sp_model_path, model_weights_path)

Running the script, we can see that the model has roughly learned to predict the next token of the input, but because both the parameter count and the number of training steps are small, the output is still largely incoherent.

Summary

In this post we covered data preparation, data cleaning, tokenizer training, model training, and inference. Work through the steps and run the code, and you will end up with a small 17M-parameter pretrained model. In upcoming posts we will scale up the parameter count, train further, and then move on to supervised fine-tuning.
