Navigation
  • Share
  • Breadcrumb

    Deep Learning 實戰筆記 Budding

    【PyTorch 筆記】別再被 Dimension Error 搞瘋了,一次搞懂模型吃什麼格式

    Aionyx

    寫 PyTorch 最讓人崩潰的瞬間是什麼?絕對不是你的模型太難寫,而是當你興高采烈按下 Run,結果 Terminal 直接噴出一大串紅字 RuntimeError: Expected 4-dimensional input...

    當下真的只想摔鍵盤:「我就只有一張圖,為什麼還要我給 4 個維度啦?」

    其實 PyTorch 的脾氣很硬,但摸順了就很簡單。這篇簡單整理一下它的「眉角」,讓你下次 Debug 可以少死幾個腦細胞。

    1. 它是個有強迫症的傢伙:請給它 (N, C, H, W)

    不管你今天是做一般圖片分類(2D CNN),還是要跑比較高大上的 3D 醫學影像,PyTorch 卷積層對於「輸入長怎樣」是有嚴格規定的。

    這口訣請刻在腦海裡:NCHW

    • N (Batch Size):最常被遺忘的傢伙。PyTorch 強制第一維一定要是「批次大小」。 哪怕你現在只是想測試一張貓咪的照片,你也不能只給它 (3, 224, 224)。 你必須騙它說:「嘿,這是一整批照片,雖然裡面只有一張。」變成 (1, 3, 224, 224)

    • C (Channel):通道數(RGB 就是 3,灰階就是 1)。

    • H, W:高跟寬。

    小提醒: 如果你是做 3D 影像(像是 CT 斷層掃描),就是中間多插一個 D (Depth),變成 (N, C, D, H, W)

    2. 常見的「形狀不對」與解決辦法

    如果你是從 OpenCV 或是 PIL 讀圖進來,形狀通常都不會剛好符合 PyTorch 的胃口。這裡有幾個最常見的坑,以及怎麼把它們「捏」成對的形狀:

    情況一:黑白圖/灰階圖 (只有高跟寬)

    讀進來是 (H, W),什麼都沒有。 解法:包兩層皮。 你需要先幫它加上 Channel 軸,再加 Batch 軸。

    # 假設 img 是 (28, 28) 的 Tensor
    img = ... 
    
    # img shape: (28, 28)
    input_tensor = img.unsqueeze(0).unsqueeze(0) 
    # 變成 (1, 1, 28, 28) -> 搞定!
    

    情況二:一般的彩色圖 (缺 Batch)

    如果你已經轉成 Tensor 了,通常是 (C, H, W)解法:加一層皮在最前面。

    # 假設 img 是 (3, 224, 224) 的 Tensor
    img = ...
    
    # img shape: (3, 224, 224)
    input_tensor = img.unsqueeze(0)
    # 變成 (1, 3, 224, 224) -> 完美。
    

    情況三:大魔王 OpenCV 格式 (H, W, C)

    這是最多人踩的雷!OpenCV 讀進來的順序跟 PyTorch 是反的(Channel 在最後面),直接丟進去絕對報錯。 解法:先換位子,再加皮。

    # 假設 img 是 (224, 224, 3) 的 Tensor
    img = ...
    
    # img shape: (224, 224, 3)
    # 1. 先用 permute 把 Channel 搬到前面 (2, 0, 1)
    # 2. 再用 unsqueeze 加 Batch
    input_tensor = img.permute(2, 0, 1).unsqueeze(0)
    

    3. 真的不知道錯哪? Print 就對了

    說真的,我看過太多人(包含我自己)在想為什麼模型跑不動,盯著螢幕發呆。其實最快的方法就是在丟進模型前,直接把形狀印出來看。

    # 假設 x 是你的輸入資料
    x = ...
    
    # 這是你最好的朋友
    # print(x.shape) # 註解掉以避免 IDE 警告
    

    快速判斷指南:

    • 看到 [128, 128] 這種只有兩個數字的 ❌ -> 絕對掛掉,缺太多東西了。
    • 看到 [3, 128, 128] ❌ -> 雖然有 Channel,但少了 Batch,訓練一定會出事。
    • 看到 [1, 3, 128, 128] ✅ -> 舒服,這才是 PyTorch 要的。