def train(model, optimizer, train_loader, val_loader, scheduler, device):
model.to(device)
criterion = nn.CrossEntropyLoss().to(device)
best_score = 0
best_model = None
os.makedirs("./output", exist_ok=True)
best_model_save_path = "./output/best_model.pth"
plt.ion()
print("Start training")
x_arr=[]
rec_valid = [[],[]]
fig, ax = plt.subplots(figsize=(8, 4))
start_time = time.time()
for epoch in range(1, CFG['EPOCHS'] + 1):
model.train()
train_loss = []
model_save_path = f"./output/model_epoch{epoch}.pth"
for imgs, labels in tqdm(iter(train_loader), desc=f"Epoch {epoch}"):
imgs = imgs.float().to(device)
labels = labels.to(device).long()
optimizer.zero_grad()
output = model(imgs)
loss = criterion(output, labels)
loss.backward()
optimizer.step()
train_loss.append(loss.item())
_val_loss, _val_score = validation(model, criterion, val_loader, device)
_train_loss = np.mean(train_loss)
# epoch 마다 model.pth 저장
print(f'Epoch [{epoch}], Train Loss: {_train_loss:.5f}, Val Loss: {_val_loss:.5f}, Val Macro F1: {_val_score:.5f}')
torch.save(model.state_dict(), model_save_path)
rec_valid[0].append(_train_loss)
rec_valid[1].append(_val_loss)
if scheduler is not None:
scheduler.step(_val_score)
if best_score < _val_score:
best_score = _val_score
best_model = model
# 모델 가중치 저장
torch.save(model.state_dict(), best_model_save_path)
print(f"Best model saved (epoch {epoch}, F1={_val_score:.4f}) → {best_model_save_path}")
to_numpy_valid = np.array(rec_valid)
# 길이 자동 추정
x_arr = np.arange(len(rec_valid[0]))
# 실시간 그래프 업데이트
ax.clear()
# 손실 그래프
ax.plot(x_arr, to_numpy_valid[0], '-', label='Train val', marker='o')
ax.plot(x_arr, to_numpy_valid[1], '--', label='Valid val', marker='o')
ax.legend(fontsize=15)
ax.set_title('Loss')
ax.set_xlabel('Epoch', size=15)
ax.set_ylabel('Loss', size=15)
# 그래프 갱신
plt.draw() # 그래프 업데이트
plt.pause(0.1) # 0.5초 대기 (실제 학습 환경 시뮬레이션)
plt.ioff() # 인터랙티브 모드 종료
# 그래프를 파일로 저장 (PNG 형식)
plt.savefig("graph.png")
plt.show()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Training time {total_time_str}")
return best_model
첫댓글 best_model_save_path 뭔가요?
클래스별로 정확도 계산이 따로 가능한가요?
50회정도 돌려서 그래프 볼것
best_model_save_path는 best_model을 .pth로 저장해줄때 사용되는 경로입니다. best_model.pth 저장하는 코드는 기존 코드에 작성되어 있는 코드입니다.
기능추가사항
1. Train macro F1 출력할것
2. 그래프에 훈련/검증 정확도도 추가할것
3. GPU 사양 정리해줄 것
4. Train data와 validation data를 어떻게 분할한건지 코드에서 설명할것