1. pth to ONNX 변환
TensorRT로 최적화하기 위해서 pth파일을 ONNX로 변환해주기
def convert_to_onnx(args):
device = torch.device('cpu')
print(f"Using device: {device}")
model = load_deeplabv3_mobilenet(
args.weight_path,
args.num_classes,
device,
output_stride=args.output_stride,
separable_conv=args.separable_conv
)
# ONNX로 변환 전에 반드시 eval 모드로 설정
model.eval()
model.to(device)
print()
print(f"Train Mode? {model.training}") # False여야 함
print(f"Instance of nn.Module? {isinstance(model, nn.Module)}")
print()
dummy_input = torch.randn(1, 3, 1024, 1024, device=device)
dynamic_axes = {
'input': {2: 'height', 3: 'width'},
'output': {2: 'height', 3: 'width'}
}
output_onnx_path = 'onnx_converted_model.onnx'
torch.onnx.export(
model.module, # DataParallel 래핑된 경우
dummy_input,
output_onnx_path,
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes=dynamic_axes
)
print(f"ONNX convert finished: {output_onnx_path}")
print()
# ===== 검증 =====
ONNX_MODEL_PATH = output_onnx_path
INPUT_H = 1024
INPUT_W = 1024
device = torch.device("cpu")
# model.module 사용 (ONNX 변환 시와 동일)
model.module.eval()
model.module.to(device)
# ONNX Runtime 세션
so = ort.SessionOptions()
ort_session = ort.InferenceSession(
ONNX_MODEL_PATH,
providers=['CPUExecutionProvider']
)
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name
# 동일한 입력
np.random.seed(42)
dummy_input_np = np.random.randn(1, 3, INPUT_H, INPUT_W).astype(np.float32)
dummy_input_torch = torch.from_numpy(dummy_input_np).to(device)
# PyTorch 추론
with torch.no_grad():
pytorch_output = model.module(dummy_input_torch)
pytorch_output_np = pytorch_output.cpu().numpy()
# ONNX Runtime 추론
ort_inputs = {input_name: dummy_input_np}
ort_output = ort_session.run([output_name], ort_inputs)[0]
# 출력 비교
print(f"PyTorch output shape: {pytorch_output_np.shape}")
print(f"ONNX output shape: {ort_output.shape}")
print(f"PyTorch output range: [{pytorch_output_np.min():.6f}, {pytorch_output_np.max():.6f}]")
print(f"ONNX output range: [{ort_output.min():.6f}, {ort_output.max():.6f}]")
try:
np.testing.assert_allclose(
pytorch_output_np,
ort_output,
rtol=1e-02, # 1% 상대 오차
atol=1e-04 # 절대 오차
)
print("✅ 검증 성공: PyTorch와 ONNX의 출력이 거의 일치합니다.")
except AssertixxonError as e:
print("❌ 검증 실패: PyTorch와 ONNX의 출력이 다릅니다.")
print(e)
2. TensorRT를 사용하여 모델 최적화
TensorRT 설치 방법
1. pip install tensorrt
학교 방화벽에 의한 문제인것 같음
2. 공식 사이트에서 설치
tensorrt 설치 중에 버전 충돌 등의 문제가 생겨 해결중입니다.