인공지능 공부/pytorch
(2022.07.04) R-CNN Fruit Object Detection (PyTorch)
앨런튜링_
2022. 7. 4. 20:09
import pandas as pd
import numpy as np
import os
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
import torch
import torchvision
from torchvision import transforms, datasets
from torchvision.models.detection import *
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import Dataset
# from engine import train_one_epoch, evaluate
import utils
# import transforms as T
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from xml.etree import ElementTree as et
import warnings
# Directory with the training images (.jpg) and their annotations (.xml),
# followed by the directory with the test images.
files_dir = './practic_torch/data/fruit-images-for-object-detection/train_zip/train'
test_dir = './practic_torch/data/fruit-images-for-object-detection/test_zip/test'
# Suppress every warning for the rest of the run (this also hides library
# deprecation notices).
warnings.filterwarnings('ignore')
# Faster R-CNN
class FruitImageDataset(Dataset):
    """Fruit detection dataset backed by Pascal-VOC style ``.xml`` annotations.

    Every ``<name>.jpg`` in ``files_dir`` is paired with ``<name>.xml`` in the
    same directory. Images are converted to RGB float32 in [0, 1] and resized
    to ``(width, height)``. Bounding boxes are rescaled by the same factors so
    they stay aligned with the resized image.

    Args:
        files_dir: directory containing the ``.jpg`` images and ``.xml`` files.
        width: image width in pixels after resizing.
        height: image height in pixels after resizing.
        transforms: optional albumentations pipeline, called as
            ``transforms(image=..., bboxes=..., labels=...)``.
    """

    def __init__(self, files_dir, width, height, transforms=None):
        self.files_dir = files_dir
        self.width = width
        self.height = height
        # Optional albumentations pipeline applied in __getitem__.
        self.transforms = transforms
        # Class names indexed by label id; index 0 ('_') stands for background.
        self.classes_ = ['_', 'apple', 'orange', 'banana']
        self.images = [name for name in sorted(os.listdir(files_dir))
                       if name.endswith('.jpg')]

    def __len__(self):
        return len(self.images)

    def _parse_annotation(self, annot_path, scale_x, scale_y):
        """Read one VOC ``.xml`` file and return ``(boxes, labels)`` as lists.

        Box corners are multiplied by ``scale_x`` / ``scale_y`` (original size
        to resized size). They are then clamped to ``[0, width]`` /
        ``[0, height]`` so every box lies inside the resized image.

        Raises:
            ValueError: if an object's ``<name>`` is not in ``classes_``.
        """
        max_x, max_y = float(self.width), float(self.height)
        boxes, labels = [], []
        root = et.parse(annot_path).getroot()
        for member in root.findall('object'):
            labels.append(self.classes_.index(member.find('name').text))
            bndbox = member.find('bndbox')
            xmin = float(bndbox.find('xmin').text) * scale_x
            xmax = float(bndbox.find('xmax').text) * scale_x
            ymin = float(bndbox.find('ymin').text) * scale_y
            ymax = float(bndbox.find('ymax').text) * scale_y
            boxes.append([
                min(max(xmin, 0.0), max_x),
                min(max(ymin, 0.0), max_y),
                min(max(xmax, 0.0), max_x),
                min(max(ymax, 0.0), max_y),
            ])
        return boxes, labels

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        Without ``transforms``, ``image`` is a ``(height, width, 3)`` float32
        RGB numpy array in [0, 1]. With ``transforms``, it is whatever the
        pipeline returns in ``sample['image']``.

        ``target`` is a dict with the following keys:

        - ``boxes``: float32 tensor of shape ``(N, 4)``, holding
          ``[x_min, y_min, x_max, y_max]`` in resized-image pixels.
        - ``labels``: int64 tensor.
        - ``area``: box areas.
        - ``iscrowd``: all zeros.
        - ``image_id``: ``tensor([idx])``.

        Raises:
            FileNotFoundError: if the image cannot be read by OpenCV.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.files_dir, img_name)

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f'Could not read image: {img_path}')
        orig_h, orig_w = img.shape[:2]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
        # interpolation must be a keyword: the third positional parameter of
        # cv2.resize is the output buffer `dst`.
        img = cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_AREA)
        img /= 255.0

        annot_path = os.path.join(self.files_dir, img_name[:-4] + '.xml')
        # The XML coordinates refer to the original image size, so scale them
        # by the same factors that were used to resize the image.
        boxes, labels = self._parse_annotation(
            annot_path, self.width / orig_w, self.height / orig_h)

        if self.transforms:
            sample = self.transforms(image=img, bboxes=boxes, labels=labels)
            img = sample['image']
            # Take the labels from the pipeline as well, so they stay paired
            # with exactly the boxes it returns.
            boxes, labels = sample['bboxes'], sample['labels']

        # reshape keeps shape (0, 4) for images without objects, so the column
        # indexing below still works.
        boxes = torch.as_tensor(np.asarray(boxes, dtype=np.float32)).reshape(-1, 4)
        labels = torch.as_tensor(np.asarray(labels, dtype=np.int64)).reshape(-1)
        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)
        target = {'boxes': boxes, 'area': area, 'labels': labels,
                  'iscrowd': iscrowd, 'image_id': torch.tensor([idx])}
        return img, target
# Plain (no transforms) 224x224 view of the training directory.
dataset = FruitImageDataset(files_dir, width=224, height=224)
# Model Development
def get_model(num_classes, modelName):
    """Return a pretrained ResNet-50 FPN detector re-headed for ``num_classes``.

    Args:
        num_classes: number of output classes, including the background class.
        modelName: ``'fastcnn'`` for Faster R-CNN or ``'maskcnn'`` for Mask R-CNN.

    Returns:
        The model with its box predictor replaced by a ``FastRCNNPredictor``
        for ``num_classes``. For ``'maskcnn'`` the mask predictor is also
        replaced, so both heads agree on the class count.

    Raises:
        ValueError: if ``modelName`` is not a supported name (previously this
            silently returned None).
    """
    if modelName == 'fastcnn':
        model = fasterrcnn_resnet50_fpn(pretrained=True)
    elif modelName == 'maskcnn':
        # NOTE(review): FruitImageDataset targets carry no 'masks' entry, so
        # this variant cannot be trained on that dataset as-is.
        model = maskrcnn_resnet50_fpn(pretrained=True)
    else:
        raise ValueError(
            f"Unknown modelName {modelName!r}; expected 'fastcnn' or 'maskcnn'")

    # Swap the classification/box-regression head for one sized to num_classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    if modelName == 'maskcnn':
        from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
        in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
        hidden_layer = 256
        model.roi_heads.mask_predictor = MaskRCNNPredictor(
            in_features_mask, hidden_layer, num_classes)
    return model
# Data Augmentation
def get_transform(train=True):
    """Build the albumentations pipeline used with FruitImageDataset.

    Boxes are declared in ``pascal_voc`` format (absolute
    ``[x_min, y_min, x_max, y_max]``), and class ids are passed through the
    ``labels`` field so the pipeline carries them alongside the boxes. Both
    pipelines end with ``ToTensorV2``.

    Args:
        train: when True, add a random horizontal flip with probability 0.5.

    Returns:
        An ``A.Compose`` pipeline.
    """
    steps = []
    if train:
        # Pass p by keyword. In older albumentations the first positional
        # parameter of HorizontalFlip is `always_apply`, so HorizontalFlip(0.5)
        # flipped every image.
        steps.append(A.HorizontalFlip(p=0.5))
    # Tensor conversion must always run. With p=0.1, versions that honour `p`
    # for ToTensorV2 would leave most images as numpy arrays.
    steps.append(ToTensorV2(p=1.0))
    return A.Compose(
        steps,
        bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']),
    )
# Fraction of the training directory held out for evaluation.
test_split = 0.2
# Two views of the same directory: augmented for training, plain for testing.
dataset_train = FruitImageDataset(files_dir, 480, 480, transforms=get_transform(train=True))
dataset_test = FruitImageDataset(files_dir, 480, 480, transforms=get_transform(train=False))

# Shuffle once with a fixed seed so the train/test split is reproducible.
torch.manual_seed(1)
indices = torch.randperm(len(dataset_train)).tolist()
tsize = int(len(indices) * test_split)
# Split at an explicit position. With negative slicing, tsize == 0 would give
# indices[:-0] == [] (no training data) and indices[-0:] == every index.
split_at = len(indices) - tsize
dataset_train = torch.utils.data.Subset(dataset_train, indices[:split_at])
dataset_test = torch.utils.data.Subset(dataset_test, indices[split_at:])

# collate_fn comes from the helper `utils` module imported at the top.
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=8, shuffle=True,
                                               num_workers=4, collate_fn=utils.collate_fn)
# Evaluation does not need shuffling; a fixed order keeps runs comparable.
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=8, shuffle=False,
                                              num_workers=4, collate_fn=utils.collate_fn)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Three fruit classes plus the background class.
num_classes = 4
num_epochs = 9
def start_training(modelName, num_epochs, num_classes):
    """Fine-tune ``modelName`` for ``num_epochs`` epochs, then evaluate once.

    Reads the module-level ``device``, ``dataloader_train`` and
    ``dataloader_test``.

    The optimizer is SGD (lr=0.005, momentum=0.9, weight_decay=0.005) over the
    trainable parameters. The learning rate is multiplied by 0.1 every 5 epochs.

    Returns:
        The trained model.
    """
    # NOTE(review): train_one_epoch and evaluate are neither defined nor
    # imported in this file (the `from engine import ...` line at the top is
    # commented out). They presumably come from torchvision's detection
    # reference `engine.py`; confirm before running.
    net = get_model(num_classes, modelName)
    net.to(device)

    trainable = list(filter(lambda param: param.requires_grad, net.parameters()))
    sgd = torch.optim.SGD(trainable, lr=0.005, momentum=0.9, weight_decay=0.005)
    step_decay = torch.optim.lr_scheduler.StepLR(sgd, step_size=5, gamma=0.1)

    for ep in range(num_epochs):
        train_one_epoch(net, sgd, dataloader_train, device, ep, print_freq=5)
        step_decay.step()

    evaluate(net, dataloader_test, device=device)
    return net
# Train the Faster R-CNN variant with the configuration defined above.
fast_rcnn = start_training(modelName='fastcnn', num_epochs=num_epochs, num_classes=num_classes)