From c94d26d618855df195e8654b1c04e9e334b6e43d Mon Sep 17 00:00:00 2001 From: ailab Date: Wed, 15 May 2024 16:40:10 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20README.md?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 139 +----------------------------------------------------- 1 file changed, 1 insertion(+), 138 deletions(-) diff --git a/README.md b/README.md index c6d8fc3..d918159 100644 --- a/README.md +++ b/README.md @@ -21,144 +21,7 @@ The model is trained on a dataset of cat and dog images. The training process in Here is a simplified version of the training code: -```python -import os -import torch -import torch.nn as nn -import torch.optim as optim -from torch.utils.data import DataLoader, Dataset -from torchvision import models, transforms -from PIL import Image -from safetensors.torch import save_file - -class CatDogDataset(Dataset): - def __init__(self, root_dir, transform=None): - self.root_dir = root_dir - self.transform = transform - self.image_paths = [] - self.labels = [] - - for filename in os.listdir(root_dir): - if 'cat' in filename: - self.image_paths.append(os.path.join(root_dir, filename)) - self.labels.append(0) # cat - elif 'dog' in filename: - self.image_paths.append(os.path.join(root_dir, filename)) - self.labels.append(1) # dog - - def __len__(self): - return len(self.image_paths) - - def __getitem__(self, idx): - img_path = self.image_paths[idx] - image = Image.open(img_path).convert('RGB') - label = self.labels[idx] - - if self.transform: - image = self.transform(image) - - return image, label - -# Data preprocessing -data_transforms = { - 'train': transforms.Compose([ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ]), - 'val': transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ]), -} - -data_dir = 'dog-cat' -image_datasets = {x: CatDogDataset(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']} -dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4) for x in ['train', 'val']} -dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} -class_names = ['cat', 'dog'] - -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - -# Load and modify ResNet-50 -model_ft = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) -num_ftrs = model_ft.fc.in_features -model_ft.fc = nn.Linear(num_ftrs, len(class_names)) - -model_ft = model_ft.to(device) - -criterion = nn.CrossEntropyLoss() - -# Optimizer -optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9) - -# Learning rate scheduler -exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1) - -# Training function -def train_model(model, criterion, optimizer, scheduler, num_epochs=25): - best_model_wts = model.state_dict() - best_acc = 0.0 - - for epoch in range(num_epochs): - print(f'Epoch {epoch}/{num_epochs - 1}') - print('-' * 10) - - for phase in ['train', 'val']: - if phase == 'train': - model.train() - else: - model.eval() - - running_loss = 0.0 - running_corrects = 0 - - for inputs, labels in dataloaders[phase]: - inputs = inputs.to(device) - labels = labels.to(device) - - optimizer.zero_grad() - - with torch.set_grad_enabled(phase == 'train'): - outputs = model(inputs) - _, preds = torch.max(outputs, 1) - loss = criterion(outputs, labels) - - if phase == 'train': - loss.backward() - optimizer.step() - - running_loss += loss.item() * inputs.size(0) - running_corrects += torch.sum(preds == labels.data) - - if phase == 'train': - scheduler.step() - - epoch_loss = running_loss / dataset_sizes[phase] - epoch_acc = running_corrects.double() / dataset_sizes[phase] - - print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}') - - if phase == 'val' and epoch_acc > best_acc: - best_acc = epoch_acc - best_model_wts = model.state_dict() - - print() - - print(f'Best val Acc: {best_acc:4f}') - model.load_state_dict(best_model_wts) - return model - -# Train and evaluate the model -model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25) - -# Save the model -torch.save(model_ft.state_dict(), 'model_cat_dog_classifier.pt') -save_file(model_ft.state_dict(), 'model_cat_dog_classifier.safetensors') -``` +[train.py](./train.py) ## Inference