import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.profiler
import torch.utils.data
import torchvision.models
import torchvision.transforms as T
from torchvision.datasets.vision import VisionDataset
import numpy as np
from PIL import Image
# Example model
class Net(nn.Module):
    """Toy CNN: eight 3x3 convolutions, each followed by ReLU and 2x2 max-pooling.

    The convolutions use ``padding=1`` and therefore keep the spatial size,
    while every pooling step halves it. A 256x256 input is reduced to 1x1, so
    the flattened output then holds 10 values per sample (the channel count of
    the last convolution).
    """

    def __init__(self):
        super().__init__()
        # Channel progression: 3 -> 8 -> 12 -> 16 -> 20 -> 24 -> 28 -> 32 -> 10.
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.conv2 = nn.Conv2d(8, 12, 3, padding=1)
        self.conv3 = nn.Conv2d(12, 16, 3, padding=1)
        self.conv4 = nn.Conv2d(16, 20, 3, padding=1)
        self.conv5 = nn.Conv2d(20, 24, 3, padding=1)
        self.conv6 = nn.Conv2d(24, 28, 3, padding=1)
        self.conv7 = nn.Conv2d(28, 32, 3, padding=1)
        self.conv8 = nn.Conv2d(32, 10, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Run ``x`` through all eight conv/ReLU/pool stages and flatten the result.

        Args:
            x: image batch with 3 channels in dimension 1.

        Returns:
            Tensor with every dimension except the batch dimension flattened.
        """
        stages = (
            self.conv1, self.conv2, self.conv3, self.conv4,
            self.conv5, self.conv6, self.conv7, self.conv8,
        )
        for conv in stages:
            x = self.pool(F.relu(conv(x)))
        # Flatten all dimensions except batch.
        return torch.flatten(x, 1)
def log_softmax(x):
    """Return the log-softmax of ``x`` along its last dimension.

    Computed as ``x - logsumexp(x)``. ``torch.logsumexp`` is used instead of a
    literal ``exp().sum().log()`` chain because the latter overflows to
    ``inf`` for large logits and turns the whole row into ``-inf``/``nan``.

    Args:
        x: tensor of unnormalised scores; classes are on the last dimension.

    Returns:
        Tensor of the same shape as ``x`` holding log-probabilities.
    """
    return x - torch.logsumexp(x, dim=-1, keepdim=True)
def weighted_nll(pred, target, weight):
    """Weighted negative log-likelihood, normalised by the weights of the targets.

    Args:
        pred: per-class log-probabilities, indexed as ``pred[i, target[i]]``.
        target: integer class index for each row of ``pred``.
        weight: 1-D tensor with one weight per class.

    Returns:
        Scalar tensor ``sum_i(-w[t_i] * pred[i, t_i]) / sum_i(w[t_i])``.

    Raises:
        AssertionError: if any target index is not below the number of
            classes in ``weight``.
    """
    # The bound comes from `weight` so the check follows the actual class count.
    # Evaluating a tensor's truth value in `assert` needs its value on the host.
    assert target.max() < weight.shape[0]
    nll = -pred[range(target.shape[0]), target]
    # Gather the per-sample weights once; they are used as scale and normaliser.
    sample_weight = weight[target]
    nll = nll * sample_weight
    nll = nll / sample_weight.sum()
    return nll.sum()
# Custom loss definition
class CrossEntropyLoss(nn.Module):
    """Baseline weighted cross-entropy built from ``log_softmax`` and ``weighted_nll``.

    All ten classes get the same weight of 0.1.
    """

    def forward(self, input, target):
        """Return the weighted NLL of ``log_softmax(input)`` against ``target``.

        Args:
            input: unnormalised class scores, classes on the last dimension.
            target: integer class indices.

        Returns:
            Scalar loss tensor.
        """
        pred = log_softmax(input)
        # Baseline behaviour: the weight tensor is built on the host and moved
        # to the device on every call. It is placed on `input`'s device rather
        # than on whichever CUDA device happens to be current.
        weight = torch.Tensor([0.1] * 10).to(input.device)
        return weighted_nll(pred, target, weight)
# Dataset of random images that mimics the properties of CIFAR10
class FakeCIFAR(VisionDataset):
    """In-memory dataset of random 32x32x3 uint8 images with random labels in [0, 10).

    Args:
        transform: callable applied to each PIL image in ``__getitem__``;
            may be ``None``.
        num_samples: number of random samples to generate (default 10000).
    """

    def __init__(self, transform, num_samples=10000):
        super().__init__(root=None, transform=transform)
        self.data = np.random.randint(
            low=0, high=256, size=(num_samples, 32, 32, 3), dtype=np.uint8
        )
        # .tolist() stores the labels as plain Python ints.
        self.targets = np.random.randint(
            low=0, high=10, size=num_samples, dtype=np.uint8
        ).tolist()

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, with ``transform`` applied to the image."""
        img = Image.fromarray(self.data[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, self.targets[index]

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.data)
# Input pipeline: resize each image with T.Resize(256) and convert the PIL
# image to a tensor (float conversion and scaling happen in the training loop).
transform = T.Compose([T.Resize(256), T.PILToTensor()])
train_set = FakeCIFAR(transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=1024,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

# Model and loss are placed on the first CUDA device; SGD with momentum
# optimizes the model parameters.
device = torch.device("cuda:0")
model = Net().to(device)
criterion = CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.train()
# Training loop wrapped with a profiler object
with torch.profiler.profile(
    schedule=torch.profiler.schedule(wait=1, warmup=4, active=3, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/example'),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    for step, data in enumerate(train_loader):
        # Stop after one schedule cycle: (wait + warmup + active) * repeat steps.
        # Checked first so no batch is copied to the device only to be dropped.
        if step >= (1 + 4 + 3) * 1:
            break
        inputs = data[0].to(device=device, non_blocking=True)
        labels = data[1].to(device=device, non_blocking=True)
        # Convert to float32 and map the 0..255 value range to -1..1.
        inputs = (inputs.to(torch.float32) / 255. - 0.5) / 0.5
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        # Tell the profiler that one training step has finished.
        prof.step()
# Figure caption: performance overview of the baseline model
# Figure caption: trace view of the baseline model
# Custom loss definition
class CrossEntropyLoss(nn.Module):
    """Baseline loss with each stage wrapped in a labelled profiler range.

    The ranges ``log_softmax``, ``define_weights`` and ``weighted_nll`` show up
    under those names in the profiler trace.
    """

    def forward(self, input, target):
        """Return the weighted NLL of ``log_softmax(input)`` against ``target``.

        Args:
            input: unnormalised class scores, classes on the last dimension.
            target: integer class indices.

        Returns:
            Scalar loss tensor.
        """
        with torch.profiler.record_function('log_softmax'):
            pred = log_softmax(input)
        with torch.profiler.record_function('define_weights'):
            # Built per call (baseline behaviour) and placed on `input`'s device.
            weights = torch.Tensor([0.1] * 10).to(input.device)
        with torch.profiler.record_function('weighted_nll'):
            # Reuse `weights` so tensor creation is timed only under
            # 'define_weights' and not a second time inside this range.
            loss = weighted_nll(pred, target, weights)
        return loss
# Figure caption: performance issue of the weight definition, as seen in the trace view
class CrossEntropyLoss(nn.Module):
    """Weighted cross-entropy whose class weights are created once, at construction.

    The ten uniform weights (0.1 each) are registered as a non-persistent
    buffer: moving the module with ``.to()``/``.cuda()`` moves the weights with
    it, while ``state_dict()`` stays free of a ``weight`` entry.
    """

    def __init__(self):
        super().__init__()
        # Start on the current CUDA device when one exists; fall back to CPU
        # instead of failing on hosts without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        weight = torch.full((10,), 0.1, dtype=torch.float32, device=device)
        self.register_buffer("weight", weight, persistent=False)

    def forward(self, input, target):
        """Return the weighted NLL of ``log_softmax(input)`` against ``target``.

        Args:
            input: unnormalised class scores, classes on the last dimension.
            target: integer class indices.

        Returns:
            Scalar loss tensor.
        """
        with torch.profiler.record_function('log_softmax'):
            pred = log_softmax(input)
        with torch.profiler.record_function('weighted_nll'):
            loss = weighted_nll(pred, target, self.weight)
        return loss
# Figure caption: performance overview after optimization 1
# Figure caption: trace view after optimization 1
def weighted_nll(pred, target, weight):
    """Weighted NLL with every step wrapped in a labelled profiler range.

    The ranges ``assert``, ``range``, ``index`` and ``nll_calc`` show up under
    those names in the profiler trace.

    Args:
        pred: per-class log-probabilities, indexed as ``pred[i, target[i]]``.
        target: integer class index for each row of ``pred``.
        weight: 1-D tensor with one weight per class.

    Returns:
        Scalar tensor ``sum_i(-w[t_i] * pred[i, t_i]) / sum_i(w[t_i])``.

    Raises:
        AssertionError: if any target index is not below the number of
            classes in ``weight``.
    """
    with torch.profiler.record_function('assert'):
        # Bound taken from `weight` so the check follows the actual class count.
        assert target.max() < weight.shape[0]
    with torch.profiler.record_function('range'):
        r = range(target.shape[0])
    with torch.profiler.record_function('index'):
        nll = -pred[r, target]
    with torch.profiler.record_function('nll_calc'):
        # Gather the per-sample weights once; used as scale and normaliser.
        sample_weight = weight[target]
        nll = nll * sample_weight
        nll = nll / sample_weight.sum()
        sum_nll = nll.sum()
    return sum_nll
# Figure caption: performance overview after optimization 2
# Figure caption: trace view after optimization 2
def weighted_nll(pred, target, weight):
    """Weighted NLL that builds its row index as a tensor and has no range check.

    The ranges ``range``, ``index`` and ``nll_calc`` show up under those names
    in the profiler trace.

    Args:
        pred: per-class log-probabilities, indexed as ``pred[i, target[i]]``.
        target: integer class index for each row of ``pred``.
        weight: 1-D tensor with one weight per class.

    Returns:
        Scalar tensor ``sum_i(-w[t_i] * pred[i, t_i]) / sum_i(w[t_i])``.
    """
    with torch.profiler.record_function('range'):
        # Create the row indices on the same device as `target` instead of a
        # fixed "cuda:0", so the function works wherever its inputs live.
        r = torch.arange(target.shape[0], device=target.device)
    with torch.profiler.record_function('index'):
        nll = -pred[r, target]
    with torch.profiler.record_function('nll_calc'):
        # Gather the per-sample weights once; used as scale and normaliser.
        sample_weight = weight[target]
        nll = nll * sample_weight
        nll = nll / sample_weight.sum()
        sum_nll = nll.sum()
    return sum_nll
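# Figure caption: performance overview after optimization 3
# Figure caption: trace view after optimization 3
# Figure caption: trace view of the weighted_nll function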
class CrossEntropyLoss(nn.Module):
    """Weighted cross-entropy that delegates the NLL step to ``torch.nn.NLLLoss``.

    The ten uniform weights (0.1 each) are registered as a non-persistent
    buffer: moving the module with ``.to()``/``.cuda()`` moves the weights with
    it, while ``state_dict()`` stays free of a ``weight`` entry.
    """

    def __init__(self):
        super().__init__()
        # Start on the current CUDA device when one exists; fall back to CPU
        # instead of failing on hosts without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        weight = torch.full((10,), 0.1, dtype=torch.float32, device=device)
        self.register_buffer("weight", weight, persistent=False)

    def forward(self, input, target):
        """Return ``NLLLoss(weight)(log_softmax(input), target)``.

        Args:
            input: unnormalised class scores, classes on the last dimension.
            target: integer class indices.

        Returns:
            Scalar loss tensor.
        """
        pred = log_softmax(input)
        # In this variant the NLLLoss module is still constructed on every call.
        nll = torch.nn.NLLLoss(self.weight)
        return nll(pred, target)
# Figure caption: performance overview after optimization 4
# Figure caption: trace view after optimization 4
class CrossEntropyLoss(nn.Module):
    """Weighted cross-entropy with both the weights and the ``NLLLoss`` module built once.

    ``self.weight`` is registered as a non-persistent buffer so it follows the
    module across ``.to()``/``.cuda()`` without adding a ``state_dict`` entry;
    ``self.nll`` keeps its own registered copy of the weights.
    """

    def __init__(self):
        super().__init__()
        # Start on the current CUDA device when one exists; fall back to CPU
        # instead of failing on hosts without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        weight = torch.full((10,), 0.1, dtype=torch.float32, device=device)
        self.register_buffer("weight", weight, persistent=False)
        self.nll = torch.nn.NLLLoss(weight)

    def forward(self, input, target):
        """Return ``self.nll(log_softmax(input), target)``.

        Args:
            input: unnormalised class scores, classes on the last dimension.
            target: integer class indices.

        Returns:
            Scalar loss tensor.
        """
        pred = log_softmax(input)
        return self.nll(pred, target)
# Replace the hand-written loss with PyTorch's built-in cross-entropy module.
criterion = nn.CrossEntropyLoss().to(device)
# Further variant: the same built-in loss wrapped in torch.compile. Being the
# later assignment, this is the value `criterion` ends up with.
criterion = torch.compile(nn.CrossEntropyLoss().to(device))
# Figure caption: results of the optimization experiments