
(2022.07.04) Faster R-CNN Fruit Object Detection (PyTorch)

앨런튜링_ 2022. 7. 4. 20:09
import pandas as pd
import numpy as np
import os
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches 
import cv2
import torch
import torchvision
from torchvision import transforms, datasets
from torchvision.models.detection import *
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import Dataset

# engine.py and utils.py are the torchvision detection reference scripts
# (https://github.com/pytorch/vision/tree/main/references/detection)
from engine import train_one_epoch, evaluate
import utils


import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from xml.etree import ElementTree as et

import warnings
warnings.filterwarnings('ignore')


# Define the training and testing data directories
files_dir = './practic_torch/data/fruit-images-for-object-detection/train_zip/train'
test_dir = './practic_torch/data/fruit-images-for-object-detection/test_zip/test'


# Faster R-CNN

class FruitImageDataset(Dataset):
    def __init__(self, files_dir, width, height, transforms=None):
        self.files_dir = files_dir
        self.width = width
        self.height = height
        self.transforms = transforms

        self.classes_ = ['_', 'apple', 'orange', 'banana']  # index 0 ('_') is the background class
        
        self.images = [img for img in sorted(os.listdir(files_dir)) if img[-4:]=='.jpg']

    def __len__(self):
        return len(self.images)


    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.files_dir, img_name)

        # Reading the image

        img = cv2.imread(img_path)

        wt = img.shape[1]
        ht = img.shape[0]

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
        img = cv2.resize(img, (self.width, self.height), interpolation=cv2.INTER_AREA)
        img /= 255.0

        annot_name = img_name[:-4] + '.xml'
        annot_path = os.path.join(self.files_dir, annot_name)

        # Lists to hold the bbox coordinates and class labels
        boxes, labels = [], []

        tree = et.parse(annot_path)
        root = tree.getroot()
        

        for member in root.findall('object'):
            labels.append(self.classes_.index(member.find('name').text))

            xmin = float(member.find('bndbox').find('xmin').text)
            xmax = float(member.find('bndbox').find('xmax').text)
            ymin = float(member.find('bndbox').find('ymin').text)
            ymax = float(member.find('bndbox').find('ymax').text)

            # Scale the box coordinates from the original image size to the resized image
            xmin = (xmin / wt) * self.width
            xmax = (xmax / wt) * self.width
            ymin = (ymin / ht) * self.height
            ymax = (ymax / ht) * self.height

            boxes.append([xmin, ymin, xmax, ymax])


        boxes = torch.as_tensor(boxes, dtype = torch.float32)
        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])  

        iscrowd = torch.zeros((boxes.shape[0],), dtype=torch.int64)

        labels = torch.as_tensor(labels, dtype=torch.int64)
        image_id = torch.tensor([idx])
        
        target = {'boxes':boxes, 'area':area, 'labels':labels,
                    'iscrowd': iscrowd, 'image_id': image_id}

        if self.transforms:
            sample = self.transforms(image = img, bboxes = target['boxes'], labels = labels)
            img = sample['image']
            target['boxes'] = torch.Tensor(sample['bboxes'])

        return img, target

dataset = FruitImageDataset(files_dir, 224, 224)
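
# Quick sanity check: draw the ground-truth boxes of one sample on top of its image,
# using the matplotlib.patches import above.  plot_img_bbox is just an ad-hoc helper
# name used for this check; without transforms the dataset returns an HWC float image
# in [0, 1], which imshow can display directly.
def plot_img_bbox(img, target):
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    ax.imshow(img)
    for box in target['boxes']:
        xmin, ymin, xmax, ymax = box.tolist()
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                 linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
    plt.show()

sample_img, sample_target = dataset[0]
plot_img_bbox(sample_img, sample_target)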

# Model Development
def get_model(num_classes, modelName):
    if modelName == 'fastcnn':
        model = fasterrcnn_resnet50_fpn(pretrained=True)
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        return model

    elif modelName == 'maskcnn':
        model = maskrcnn_resnet50_fpn(pretrained=True)
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        return model       
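
# Note: both branches only swap the box-predictor head; the pretrained backbone, FPN and
# RPN are kept as-is.  FastRCNNPredictor(in_features, num_classes) attaches a fresh
# classification + box-regression layer sized for our 4 classes (background + 3 fruits),
# so after get_model(4, 'fastcnn') the new head's cls_score layer has out_features=4 and
# bbox_pred has out_features=16 (4 box coordinates per class).  The 'maskcnn' branch
# leaves the mask head untouched, so it is only used here for box prediction.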

# Data Augmentation
def get_transform(train=True):
    if train:
        return A.Compose([
            A.HorizontalFlip(p=0.5),   # flip half of the training images
            ToTensorV2(p=1.0),         # HWC numpy image -> CHW float tensor
        ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']))
    else:
        return A.Compose([
            ToTensorV2(p=1.0),
        ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']))
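
# Optional check of what the Albumentations pipeline returns (a rough sketch, assuming the
# data paths above are valid): ToTensorV2 turns the HWC float image into a CHW torch
# tensor, while the boxes stay in pascal_voc (xmin, ymin, xmax, ymax) pixel coordinates,
# which is the box format torchvision's Faster R-CNN expects.
# sample_ds = FruitImageDataset(files_dir, 480, 480, transforms=get_transform(train=True))
# aug_img, aug_target = sample_ds[0]
# print(aug_img.shape)         # torch.Size([3, 480, 480])
# print(aug_target['boxes'])   # tensor of shape [num_objects, 4]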


test_split = 0.2

dataset_train = FruitImageDataset(files_dir, 480, 480, transforms=get_transform(train=True))
dataset_test = FruitImageDataset(files_dir, 480, 480, transforms=get_transform(train=False))

torch.manual_seed(1)
indices = torch.randperm(len(dataset)).tolist()


tsize = int(len(dataset) * test_split) 
dataset_train = torch.utils.data.Subset(dataset_train, indices[:-tsize])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-tsize:])

dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=8, shuffle=True,
                                              num_workers=4, collate_fn=utils.collate_fn)  # collate_fn imported from the helper library (utils.py)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=8, shuffle=False,
                                              num_workers=4, collate_fn=utils.collate_fn)
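
# Note: utils.collate_fn comes from the torchvision detection reference scripts
# (references/detection/utils.py).  Detection batches cannot be stacked by the default
# collate because each image carries a different number of boxes, so the helper simply
# zips the batch into a tuple of images and a tuple of targets.  If utils.py is not on
# the path, an equivalent minimal fallback would be:
#
#     def collate_fn(batch):
#         return tuple(zip(*batch))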



device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

num_classes = 4 
num_epochs = 9

def start_training(modelName, num_epochs, num_classes):
    model = get_model(num_classes, modelName)
    model.to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.005)
    

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    
    for epoch in range(num_epochs):
        train_one_epoch(model, optimizer, dataloader_train, device, epoch, print_freq=5)
        lr_scheduler.step()
        evaluate(model, dataloader_test, device=device)
    return model


fast_rcnn = start_training('fastcnn', num_epochs, num_classes)
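
# A minimal inference sketch on one test batch (assumes the training run above finished;
# the 0.5 score threshold and the output filename are arbitrary illustrative choices).
fast_rcnn.eval()
with torch.no_grad():
    images, _ = next(iter(dataloader_test))
    images = [img.to(device) for img in images]
    preds = fast_rcnn(images)

# Keep only reasonably confident detections for the first image in the batch
keep = preds[0]['scores'] > 0.5
print(preds[0]['boxes'][keep])
print([dataset.classes_[i] for i in preds[0]['labels'][keep].tolist()])

# Save the fine-tuned weights for later reuse
torch.save(fast_rcnn.state_dict(), 'fruit_fasterrcnn.pth')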