
gpt推荐的照片分类模型中,从ResNet18入手

在nas上建立dataset文件夹,将自己的照片分类,train中放用于训练的照片,val放用于验证结果的照片。
由于作者的训练集都存在nas上,所以为了节省本地空间,直接在wsl上挂载nas目录,通过局域网调取照片进行训练。

smbclient -L //192.168.31.xxx/smb/ -U
smbcredentials文件中写用户名密码,避免特殊字符导致的问题
sudo mount -t cifs //192.168.31.xxx/ /mnt/smb_share -o credentials=/home/sry/Projects/smbcredentials,vers=3.0
dmesg | tail
查看挂载错误
如果网络有问题,可以手动下载模型,然后放到wsl的路径中
https://download.pytorch.org/models/resnet18-f37072fd.pth

这样就开始第一轮训练了。

nvidia-smi来查看显卡的负载,不过可能由于数据集在nas上,显卡负载很低。

watch -n 1 nvidia-smi
可以这样来实时观察显卡负载,可以看到使用了和功耗在波动,有时能够达到峰值。




在nas中也可以看到上传一直是满载的,但是我的内网只有千兆,虽然有固态作为高速缓存,网络还是成了瓶颈。如果有条件的话建议把照片数据集放到训练电脑本地。




脚本代码
import torch
from torchvision import transforms, models
from PIL import Image
import os
# 设定类别名(和训练时的顺序一致)
class_names = ['cat', 'dog', 'food', 'landscape', 'person', 'screenshot']
# 使用 GPU 或 CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 加载 ResNet18 结构
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names)) # 替换分类头
model.load_state_dict(torch.load("checkpoints/resnet18_custom.pth", map_location=device))
model.to(device)
model.eval()
# 图像预处理,必须和训练时保持一致
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], # Imagenet 预训练均值
std=[0.229, 0.224, 0.225])
])
# 输入你想预测的图片路径
image_path = "test_images/test1.jpg" # <<< 修改为你的图片路径
assert os.path.exists(image_path), f"图片不存在: {image_path}"
# 加载图片
image = Image.open(image_path).convert('RGB')
input_tensor = transform(image).unsqueeze(0).to(device) # 增加 batch 维度
# 推理
with torch.no_grad():
outputs = model(input_tensor)
_, predicted = torch.max(outputs, 1)
confidence = torch.nn.functional.softmax(outputs, dim=1)[0][predicted].item()
print(f"预测类别:{class_names[predicted]},置信度:{confidence:.2f}")

测试人物,狗,美食和风景的照片类型成功,测试老虎照片时,由于没有进行训练,被识别成风景照,而且置信度较低。
脚本代码
import torch
from torchvision import transforms, models
from PIL import Image
import os
import sys
# 类别名(保持和训练一致)
class_names = ['cat', 'dog', 'food', 'landscape', 'person', 'screenshot']
# 使用 GPU 或 CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 加载模型
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, len(class_names))
model.load_state_dict(torch.load("checkpoints/resnet18_custom.pth", map_location=device))
model.to(device)
model.eval()
# 图像预处理
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 获取图片路径(支持命令行参数或运行时输入)
if len(sys.argv) > 1:
image_path = sys.argv[1]
else:
image_path = input("请输入图片路径:").strip()
if not os.path.exists(image_path):
print(f"❌ 图片不存在:{image_path}")
sys.exit(1)
# 加载并预处理图片
image = Image.open(image_path).convert('RGB')
input_tensor = transform(image).unsqueeze(0).to(device)
# 推理
with torch.no_grad():
outputs = model(input_tensor)
_, predicted = torch.max(outputs, 1)
confidence = torch.nn.functional.softmax(outputs, dim=1)[0][predicted].item()
print(f"✅ 预测类别:{class_names[predicted]},置信度:{confidence:.2f}")
发布于2025/07/13