commit e29c2e280455352ec167d56a936403841612a42b Author: ailab Date: Wed May 15 16:34:04 2024 +0800 add models diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..2bf9985 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,28 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bin.* filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zstandard filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +model.safetensors filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000..c6d8fc3 --- /dev/null +++ b/README.md @@ -0,0 +1,195 @@ +# Cat-Dog Classification Model + +## Introduction + +This repository contains a Cat-Dog classification model based on the ResNet-50 architecture. The model is trained to distinguish between images of cats and dogs. + +## ResNet Model + +ResNet-50 is a deep convolutional neural network with 50 layers. 
It is designed to overcome the vanishing gradient problem, which is common in very deep networks, by using skip connections or residuals. This allows the network to be significantly deeper while still being easy to optimize. + +## Training + +The model is trained on a dataset of cat and dog images. The training process involves the following steps: + +1. **Data Preprocessing**: Images are resized, cropped, and normalized. +2. **Model Initialization**: A pre-trained ResNet-50 model is loaded and the final fully connected layer is adjusted to output two classes (cat and dog). +3. **Training Loop**: The model is trained using a standard training loop with stochastic gradient descent (SGD) and a learning rate scheduler. +4. **Model Evaluation**: The best model is selected based on validation accuracy and saved for inference. + +### Training Code + +Here is a simplified version of the training code: + +```python +import os +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, Dataset +from torchvision import models, transforms +from PIL import Image +from safetensors.torch import save_file + +class CatDogDataset(Dataset): + def __init__(self, root_dir, transform=None): + self.root_dir = root_dir + self.transform = transform + self.image_paths = [] + self.labels = [] + + for filename in os.listdir(root_dir): + if 'cat' in filename: + self.image_paths.append(os.path.join(root_dir, filename)) + self.labels.append(0) # cat + elif 'dog' in filename: + self.image_paths.append(os.path.join(root_dir, filename)) + self.labels.append(1) # dog + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + img_path = self.image_paths[idx] + image = Image.open(img_path).convert('RGB') + label = self.labels[idx] + + if self.transform: + image = self.transform(image) + + return image, label + +# Data preprocessing +data_transforms = { + 'train': transforms.Compose([ + transforms.RandomResizedCrop(224), + 
transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]), + 'val': transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]), +} + +data_dir = 'dog-cat' +image_datasets = {x: CatDogDataset(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']} +dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4) for x in ['train', 'val']} +dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} +class_names = ['cat', 'dog'] + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +# Load and modify ResNet-50 +model_ft = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) +num_ftrs = model_ft.fc.in_features +model_ft.fc = nn.Linear(num_ftrs, len(class_names)) + +model_ft = model_ft.to(device) + +criterion = nn.CrossEntropyLoss() + +# Optimizer +optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9) + +# Learning rate scheduler +exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1) + +# Training function +def train_model(model, criterion, optimizer, scheduler, num_epochs=25): + best_model_wts = model.state_dict() + best_acc = 0.0 + + for epoch in range(num_epochs): + print(f'Epoch {epoch}/{num_epochs - 1}') + print('-' * 10) + + for phase in ['train', 'val']: + if phase == 'train': + model.train() + else: + model.eval() + + running_loss = 0.0 + running_corrects = 0 + + for inputs, labels in dataloaders[phase]: + inputs = inputs.to(device) + labels = labels.to(device) + + optimizer.zero_grad() + + with torch.set_grad_enabled(phase == 'train'): + outputs = model(inputs) + _, preds = torch.max(outputs, 1) + loss = criterion(outputs, labels) + + if phase == 'train': + loss.backward() + optimizer.step() + + running_loss += loss.item() * inputs.size(0) + 
running_corrects += torch.sum(preds == labels.data) + + if phase == 'train': + scheduler.step() + + epoch_loss = running_loss / dataset_sizes[phase] + epoch_acc = running_corrects.double() / dataset_sizes[phase] + + print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}') + + if phase == 'val' and epoch_acc > best_acc: + best_acc = epoch_acc + best_model_wts = model.state_dict() + + print() + + print(f'Best val Acc: {best_acc:4f}') + model.load_state_dict(best_model_wts) + return model + +# Train and evaluate the model +model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25) + +# Save the model +torch.save(model_ft.state_dict(), 'model_cat_dog_classifier.pt') +save_file(model_ft.state_dict(), 'model_cat_dog_classifier.safetensors') +``` + +## Inference + +To perform inference, you can use the following code. It loads the model through the Hugging Face `transformers` library using the model ID `ailab/resnet-dogcat`. + +### Inference Code + +**You need to set the environment variable HF_ENDPOINT=http://10.0.101.71** + +```python +from transformers import AutoImageProcessor, ResNetForImageClassification +from PIL import Image +import requests +import torch + +# Load model +url = "http://images.cocodataset.org/val2017/000000039769.jpg" +image = Image.open(requests.get(url, stream=True).raw) +processor = AutoImageProcessor.from_pretrained("ailab/resnet-dogcat") +model = ResNetForImageClassification.from_pretrained("ailab/resnet-dogcat") + +inputs = processor(image, return_tensors="pt") + +with torch.no_grad(): + logits = model(**inputs).logits + +# model predicts one of the two classes defined in config.json: cat or dog +predicted_label = logits.argmax(-1).item() +print(model.config.id2label[predicted_label]) +``` + +## Conclusion + +This repository provides a comprehensive solution for training and performing inference on a Cat-Dog classification task using a ResNet-50 model. The training script demonstrates how to preprocess data, train the model, and save the trained model. 
The inference script shows how to use the trained model to classify new images. \ No newline at end of file diff --git a/config.json b/config.json new file mode 100644 index 0000000..d1c6916 --- /dev/null +++ b/config.json @@ -0,0 +1,7 @@ +{ + "model_type": "resnet", + "id2label": { + "0": "cat", + "1": "dog" + } +} diff --git a/dog-cat/train/cat.10796.jpg b/dog-cat/train/cat.10796.jpg new file mode 100755 index 0000000..adcfd92 Binary files /dev/null and b/dog-cat/train/cat.10796.jpg differ diff --git a/dog-cat/train/cat.10923.jpg b/dog-cat/train/cat.10923.jpg new file mode 100755 index 0000000..ff32744 Binary files /dev/null and b/dog-cat/train/cat.10923.jpg differ diff --git a/dog-cat/train/cat.11085.jpg b/dog-cat/train/cat.11085.jpg new file mode 100755 index 0000000..466f996 Binary files /dev/null and b/dog-cat/train/cat.11085.jpg differ diff --git a/dog-cat/train/cat.11283.jpg b/dog-cat/train/cat.11283.jpg new file mode 100755 index 0000000..0e1b7f1 Binary files /dev/null and b/dog-cat/train/cat.11283.jpg differ diff --git a/dog-cat/train/cat.11828.jpg b/dog-cat/train/cat.11828.jpg new file mode 100755 index 0000000..d0b78a1 Binary files /dev/null and b/dog-cat/train/cat.11828.jpg differ diff --git a/dog-cat/train/cat.1580.jpg b/dog-cat/train/cat.1580.jpg new file mode 100755 index 0000000..2e7be41 Binary files /dev/null and b/dog-cat/train/cat.1580.jpg differ diff --git a/dog-cat/train/cat.1843.jpg b/dog-cat/train/cat.1843.jpg new file mode 100755 index 0000000..d592b85 Binary files /dev/null and b/dog-cat/train/cat.1843.jpg differ diff --git a/dog-cat/train/cat.2462.jpg b/dog-cat/train/cat.2462.jpg new file mode 100755 index 0000000..eec5799 Binary files /dev/null and b/dog-cat/train/cat.2462.jpg differ diff --git a/dog-cat/train/cat.2739.jpg b/dog-cat/train/cat.2739.jpg new file mode 100755 index 0000000..2ab5818 Binary files /dev/null and b/dog-cat/train/cat.2739.jpg differ diff --git a/dog-cat/train/cat.3056.jpg b/dog-cat/train/cat.3056.jpg new file mode 
100755 index 0000000..a37dfe0 Binary files /dev/null and b/dog-cat/train/cat.3056.jpg differ diff --git a/dog-cat/train/cat.3401.jpg b/dog-cat/train/cat.3401.jpg new file mode 100755 index 0000000..5a28ce9 Binary files /dev/null and b/dog-cat/train/cat.3401.jpg differ diff --git a/dog-cat/train/cat.3795.jpg b/dog-cat/train/cat.3795.jpg new file mode 100755 index 0000000..e3dc259 Binary files /dev/null and b/dog-cat/train/cat.3795.jpg differ diff --git a/dog-cat/train/cat.4273.jpg b/dog-cat/train/cat.4273.jpg new file mode 100755 index 0000000..10a26fc Binary files /dev/null and b/dog-cat/train/cat.4273.jpg differ diff --git a/dog-cat/train/cat.4890.jpg b/dog-cat/train/cat.4890.jpg new file mode 100755 index 0000000..03047c4 Binary files /dev/null and b/dog-cat/train/cat.4890.jpg differ diff --git a/dog-cat/train/cat.5599.jpg b/dog-cat/train/cat.5599.jpg new file mode 100755 index 0000000..6f6899c Binary files /dev/null and b/dog-cat/train/cat.5599.jpg differ diff --git a/dog-cat/train/cat.5723.jpg b/dog-cat/train/cat.5723.jpg new file mode 100755 index 0000000..87f0f79 Binary files /dev/null and b/dog-cat/train/cat.5723.jpg differ diff --git a/dog-cat/train/cat.6505.jpg b/dog-cat/train/cat.6505.jpg new file mode 100755 index 0000000..136e8ad Binary files /dev/null and b/dog-cat/train/cat.6505.jpg differ diff --git a/dog-cat/train/cat.7203.jpg b/dog-cat/train/cat.7203.jpg new file mode 100755 index 0000000..dbe608e Binary files /dev/null and b/dog-cat/train/cat.7203.jpg differ diff --git a/dog-cat/train/cat.7549.jpg b/dog-cat/train/cat.7549.jpg new file mode 100755 index 0000000..47b34fe Binary files /dev/null and b/dog-cat/train/cat.7549.jpg differ diff --git a/dog-cat/train/cat.7696.jpg b/dog-cat/train/cat.7696.jpg new file mode 100755 index 0000000..febd508 Binary files /dev/null and b/dog-cat/train/cat.7696.jpg differ diff --git a/dog-cat/train/cat.7979.jpg b/dog-cat/train/cat.7979.jpg new file mode 100755 index 0000000..e5cbc1c Binary files /dev/null and 
b/dog-cat/train/cat.7979.jpg differ diff --git a/dog-cat/train/cat.8275.jpg b/dog-cat/train/cat.8275.jpg new file mode 100755 index 0000000..b94c7dc Binary files /dev/null and b/dog-cat/train/cat.8275.jpg differ diff --git a/dog-cat/train/cat.8885.jpg b/dog-cat/train/cat.8885.jpg new file mode 100755 index 0000000..14f8b1e Binary files /dev/null and b/dog-cat/train/cat.8885.jpg differ diff --git a/dog-cat/train/cat.9091.jpg b/dog-cat/train/cat.9091.jpg new file mode 100755 index 0000000..428c22b Binary files /dev/null and b/dog-cat/train/cat.9091.jpg differ diff --git a/dog-cat/train/cat.9257.jpg b/dog-cat/train/cat.9257.jpg new file mode 100755 index 0000000..fec153f Binary files /dev/null and b/dog-cat/train/cat.9257.jpg differ diff --git a/dog-cat/train/dog.10535.jpg b/dog-cat/train/dog.10535.jpg new file mode 100755 index 0000000..cf0c31d Binary files /dev/null and b/dog-cat/train/dog.10535.jpg differ diff --git a/dog-cat/train/dog.10985.jpg b/dog-cat/train/dog.10985.jpg new file mode 100755 index 0000000..fb86844 Binary files /dev/null and b/dog-cat/train/dog.10985.jpg differ diff --git a/dog-cat/train/dog.11654.jpg b/dog-cat/train/dog.11654.jpg new file mode 100755 index 0000000..c08c625 Binary files /dev/null and b/dog-cat/train/dog.11654.jpg differ diff --git a/dog-cat/train/dog.11846.jpg b/dog-cat/train/dog.11846.jpg new file mode 100755 index 0000000..5f21533 Binary files /dev/null and b/dog-cat/train/dog.11846.jpg differ diff --git a/dog-cat/train/dog.1927.jpg b/dog-cat/train/dog.1927.jpg new file mode 100755 index 0000000..6a49384 Binary files /dev/null and b/dog-cat/train/dog.1927.jpg differ diff --git a/dog-cat/train/dog.2221.jpg b/dog-cat/train/dog.2221.jpg new file mode 100755 index 0000000..a536d9d Binary files /dev/null and b/dog-cat/train/dog.2221.jpg differ diff --git a/dog-cat/train/dog.2848.jpg b/dog-cat/train/dog.2848.jpg new file mode 100755 index 0000000..d4a1a47 Binary files /dev/null and b/dog-cat/train/dog.2848.jpg differ diff --git 
a/dog-cat/train/dog.4730.jpg b/dog-cat/train/dog.4730.jpg new file mode 100755 index 0000000..a222248 Binary files /dev/null and b/dog-cat/train/dog.4730.jpg differ diff --git a/dog-cat/train/dog.5034.jpg b/dog-cat/train/dog.5034.jpg new file mode 100755 index 0000000..b4d3e54 Binary files /dev/null and b/dog-cat/train/dog.5034.jpg differ diff --git a/dog-cat/train/dog.5279.jpg b/dog-cat/train/dog.5279.jpg new file mode 100755 index 0000000..e1acb7d Binary files /dev/null and b/dog-cat/train/dog.5279.jpg differ diff --git a/dog-cat/train/dog.5415.jpg b/dog-cat/train/dog.5415.jpg new file mode 100755 index 0000000..6c9eb03 Binary files /dev/null and b/dog-cat/train/dog.5415.jpg differ diff --git a/dog-cat/train/dog.6409.jpg b/dog-cat/train/dog.6409.jpg new file mode 100755 index 0000000..263944c Binary files /dev/null and b/dog-cat/train/dog.6409.jpg differ diff --git a/dog-cat/train/dog.6583.jpg b/dog-cat/train/dog.6583.jpg new file mode 100755 index 0000000..61c8030 Binary files /dev/null and b/dog-cat/train/dog.6583.jpg differ diff --git a/dog-cat/train/dog.6624.jpg b/dog-cat/train/dog.6624.jpg new file mode 100755 index 0000000..d1b55ab Binary files /dev/null and b/dog-cat/train/dog.6624.jpg differ diff --git a/dog-cat/train/dog.6706.jpg b/dog-cat/train/dog.6706.jpg new file mode 100755 index 0000000..f6eb1cf Binary files /dev/null and b/dog-cat/train/dog.6706.jpg differ diff --git a/dog-cat/train/dog.7294.jpg b/dog-cat/train/dog.7294.jpg new file mode 100755 index 0000000..fd79597 Binary files /dev/null and b/dog-cat/train/dog.7294.jpg differ diff --git a/dog-cat/train/dog.8015.jpg b/dog-cat/train/dog.8015.jpg new file mode 100755 index 0000000..09e9219 Binary files /dev/null and b/dog-cat/train/dog.8015.jpg differ diff --git a/dog-cat/train/dog.8239.jpg b/dog-cat/train/dog.8239.jpg new file mode 100755 index 0000000..d8c1c41 Binary files /dev/null and b/dog-cat/train/dog.8239.jpg differ diff --git a/dog-cat/train/dog.8568.jpg b/dog-cat/train/dog.8568.jpg new 
file mode 100755 index 0000000..30f86fd Binary files /dev/null and b/dog-cat/train/dog.8568.jpg differ diff --git a/dog-cat/train/dog.9229.jpg b/dog-cat/train/dog.9229.jpg new file mode 100755 index 0000000..b05e47e Binary files /dev/null and b/dog-cat/train/dog.9229.jpg differ diff --git a/dog-cat/val/cat.10482.jpg b/dog-cat/val/cat.10482.jpg new file mode 100755 index 0000000..a1cd405 Binary files /dev/null and b/dog-cat/val/cat.10482.jpg differ diff --git a/dog-cat/val/cat.11092.jpg b/dog-cat/val/cat.11092.jpg new file mode 100755 index 0000000..0a9d140 Binary files /dev/null and b/dog-cat/val/cat.11092.jpg differ diff --git a/dog-cat/val/cat.11263.jpg b/dog-cat/val/cat.11263.jpg new file mode 100755 index 0000000..10b7f8b Binary files /dev/null and b/dog-cat/val/cat.11263.jpg differ diff --git a/dog-cat/val/cat.12040.jpg b/dog-cat/val/cat.12040.jpg new file mode 100755 index 0000000..706277d Binary files /dev/null and b/dog-cat/val/cat.12040.jpg differ diff --git a/dog-cat/val/cat.1674.jpg b/dog-cat/val/cat.1674.jpg new file mode 100755 index 0000000..1cb9a55 Binary files /dev/null and b/dog-cat/val/cat.1674.jpg differ diff --git a/dog-cat/val/cat.2493.jpg b/dog-cat/val/cat.2493.jpg new file mode 100755 index 0000000..d34f619 Binary files /dev/null and b/dog-cat/val/cat.2493.jpg differ diff --git a/dog-cat/val/cat.3148.jpg b/dog-cat/val/cat.3148.jpg new file mode 100755 index 0000000..4b8bc82 Binary files /dev/null and b/dog-cat/val/cat.3148.jpg differ diff --git a/dog-cat/val/cat.502.jpg b/dog-cat/val/cat.502.jpg new file mode 100755 index 0000000..87ba491 Binary files /dev/null and b/dog-cat/val/cat.502.jpg differ diff --git a/dog-cat/val/cat.6232.jpg b/dog-cat/val/cat.6232.jpg new file mode 100755 index 0000000..94a9311 Binary files /dev/null and b/dog-cat/val/cat.6232.jpg differ diff --git a/dog-cat/val/cat.6839.jpg b/dog-cat/val/cat.6839.jpg new file mode 100755 index 0000000..ccc3c36 Binary files /dev/null and b/dog-cat/val/cat.6839.jpg differ diff --git 
a/dog-cat/val/cat.7014.jpg b/dog-cat/val/cat.7014.jpg new file mode 100755 index 0000000..210a33e Binary files /dev/null and b/dog-cat/val/cat.7014.jpg differ diff --git a/dog-cat/val/cat.7991.jpg b/dog-cat/val/cat.7991.jpg new file mode 100755 index 0000000..a214d01 Binary files /dev/null and b/dog-cat/val/cat.7991.jpg differ diff --git a/dog-cat/val/cat.8115.jpg b/dog-cat/val/cat.8115.jpg new file mode 100755 index 0000000..f382557 Binary files /dev/null and b/dog-cat/val/cat.8115.jpg differ diff --git a/dog-cat/val/dog.10425.jpg b/dog-cat/val/dog.10425.jpg new file mode 100755 index 0000000..2cc9a80 Binary files /dev/null and b/dog-cat/val/dog.10425.jpg differ diff --git a/dog-cat/val/dog.10974.jpg b/dog-cat/val/dog.10974.jpg new file mode 100755 index 0000000..14e8abe Binary files /dev/null and b/dog-cat/val/dog.10974.jpg differ diff --git a/dog-cat/val/dog.1215.jpg b/dog-cat/val/dog.1215.jpg new file mode 100755 index 0000000..b304de5 Binary files /dev/null and b/dog-cat/val/dog.1215.jpg differ diff --git a/dog-cat/val/dog.12157.jpg b/dog-cat/val/dog.12157.jpg new file mode 100755 index 0000000..7f5c4c3 Binary files /dev/null and b/dog-cat/val/dog.12157.jpg differ diff --git a/dog-cat/val/dog.1496.jpg b/dog-cat/val/dog.1496.jpg new file mode 100755 index 0000000..d919f38 Binary files /dev/null and b/dog-cat/val/dog.1496.jpg differ diff --git a/dog-cat/val/dog.2294.jpg b/dog-cat/val/dog.2294.jpg new file mode 100755 index 0000000..9282833 Binary files /dev/null and b/dog-cat/val/dog.2294.jpg differ diff --git a/dog-cat/val/dog.2321.jpg b/dog-cat/val/dog.2321.jpg new file mode 100755 index 0000000..36f1133 Binary files /dev/null and b/dog-cat/val/dog.2321.jpg differ diff --git a/dog-cat/val/dog.3423.jpg b/dog-cat/val/dog.3423.jpg new file mode 100755 index 0000000..96cfb87 Binary files /dev/null and b/dog-cat/val/dog.3423.jpg differ diff --git a/dog-cat/val/dog.4633.jpg b/dog-cat/val/dog.4633.jpg new file mode 100755 index 0000000..e330bd2 Binary files /dev/null 
and b/dog-cat/val/dog.4633.jpg differ diff --git a/dog-cat/val/dog.4905.jpg b/dog-cat/val/dog.4905.jpg new file mode 100755 index 0000000..8c2a2f3 Binary files /dev/null and b/dog-cat/val/dog.4905.jpg differ diff --git a/dog-cat/val/dog.5145.jpg b/dog-cat/val/dog.5145.jpg new file mode 100755 index 0000000..677e3c1 Binary files /dev/null and b/dog-cat/val/dog.5145.jpg differ diff --git a/dog-cat/val/dog.5699.jpg b/dog-cat/val/dog.5699.jpg new file mode 100755 index 0000000..aa1c9f9 Binary files /dev/null and b/dog-cat/val/dog.5699.jpg differ diff --git a/dog-cat/val/dog.5885.jpg b/dog-cat/val/dog.5885.jpg new file mode 100755 index 0000000..3a90ef2 Binary files /dev/null and b/dog-cat/val/dog.5885.jpg differ diff --git a/dog-cat/val/dog.6102.jpg b/dog-cat/val/dog.6102.jpg new file mode 100755 index 0000000..16f91a1 Binary files /dev/null and b/dog-cat/val/dog.6102.jpg differ diff --git a/dog-cat/val/dog.7253.jpg b/dog-cat/val/dog.7253.jpg new file mode 100755 index 0000000..36bcbf2 Binary files /dev/null and b/dog-cat/val/dog.7253.jpg differ diff --git a/dog-cat/val/dog.8250.jpg b/dog-cat/val/dog.8250.jpg new file mode 100755 index 0000000..d933387 Binary files /dev/null and b/dog-cat/val/dog.8250.jpg differ diff --git a/dog-cat/val/dog.9690.jpg b/dog-cat/val/dog.9690.jpg new file mode 100755 index 0000000..e0f591f Binary files /dev/null and b/dog-cat/val/dog.9690.jpg differ diff --git a/model.safetensors b/model.safetensors new file mode 100644 index 0000000..698fe6f --- /dev/null +++ b/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7f442ab7f5642353dcb77df82e5c19d62e89794e66876602225169ea6a4a3f1f +size 94289968 diff --git a/preprocessor_config.json b/preprocessor_config.json new file mode 100644 index 0000000..9a46cca --- /dev/null +++ b/preprocessor_config.json @@ -0,0 +1,18 @@ +{ + "crop_pct": 0.875, + "do_normalize": true, + "do_resize": true, + "feature_extractor_type": "ConvNextFeatureExtractor", + "image_mean": [ + 
0.485, + 0.456, + 0.406 + ], + "image_std": [ + 0.229, + 0.224, + 0.225 + ], + "resample": 3, + "size": 224 +} diff --git a/pytorch_model.bin b/pytorch_model.bin new file mode 100644 index 0000000..0932780 --- /dev/null +++ b/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06f247acac344695de9bb9d400fc0822ec67166be5217204d1b0a970ca4b59b3 +size 94368074 diff --git a/train.py b/train.py new file mode 100644 index 0000000..418321d --- /dev/null +++ b/train.py @@ -0,0 +1,150 @@ +import os +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, Dataset +from torchvision import models, transforms +from PIL import Image +from safetensors.torch import save_file + +class CatDogDataset(Dataset): + def __init__(self, root_dir, transform=None): + self.root_dir = root_dir + self.transform = transform + self.image_paths = [] + self.labels = [] + + for filename in os.listdir(root_dir): + if 'cat' in filename: + self.image_paths.append(os.path.join(root_dir, filename)) + self.labels.append(0) # cat + elif 'dog' in filename: + self.image_paths.append(os.path.join(root_dir, filename)) + self.labels.append(1) # dog + + def __len__(self): + return len(self.image_paths) + + def __getitem__(self, idx): + img_path = self.image_paths[idx] + image = Image.open(img_path).convert('RGB') + label = self.labels[idx] + + if self.transform: + image = self.transform(image) + + return image, label + +# 数据预处理 +data_transforms = { + 'train': transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]), + 'val': transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]), +} + +data_dir = 'dog-cat' +image_datasets = {x: CatDogDataset(os.path.join(data_dir, x), 
data_transforms[x]) for x in ['train', 'val']} +dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4) for x in ['train', 'val']} +dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} +class_names = ['cat', 'dog'] + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +# 加载预训练的 ResNet-50 模型并进行微调 +model_ft = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) +num_ftrs = model_ft.fc.in_features +model_ft.fc = nn.Linear(num_ftrs, len(class_names)) + +model_ft = model_ft.to(device) + +criterion = nn.CrossEntropyLoss() + +# 优化器 +optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9) + +# 学习率调度器 +exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1) + +# 训练模型 +def train_model(model, criterion, optimizer, scheduler, num_epochs=25): + best_model_wts = model.state_dict() + best_acc = 0.0 + + for epoch in range(num_epochs): + print(f'Epoch {epoch}/{num_epochs - 1}') + print('-' * 10) + + # 每个epoch都有训练和验证阶段 + for phase in ['train', 'val']: + if phase == 'train': + model.train() # 设置模型为训练模式 + else: + model.eval() # 设置模型为评估模式 + + running_loss = 0.0 + running_corrects = 0 + + # 遍历数据 + for inputs, labels in dataloaders[phase]: + inputs = inputs.to(device) + labels = labels.to(device) + + # 清零参数梯度 + optimizer.zero_grad() + + # 前向传播 + with torch.set_grad_enabled(phase == 'train'): + outputs = model(inputs) + _, preds = torch.max(outputs, 1) + loss = criterion(outputs, labels) + + # 只有在训练阶段才反向传播+优化 + if phase == 'train': + loss.backward() + optimizer.step() + + # 统计 + running_loss += loss.item() * inputs.size(0) + running_corrects += torch.sum(preds == labels.data) + + if phase == 'train': + scheduler.step() + + epoch_loss = running_loss / dataset_sizes[phase] + epoch_acc = running_corrects.double() / dataset_sizes[phase] + + print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}') + + # 深拷贝模型 + if phase == 'val' and epoch_acc > best_acc: + best_acc = 
epoch_acc + best_model_wts = model.state_dict() + + print() + + print(f'Best val Acc: {best_acc:4f}') + + # 加载最佳模型权重 + model.load_state_dict(best_model_wts) + return model + +# 训练和评估模型 +model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25) + +# 保存模型为 .pt 文件 +torch.save(model_ft.state_dict(), 'pytorch_model.bin') + +metadata = {"format": "pt"} + +# 保存模型为 .safetensors 文件 +save_file(model_ft.state_dict(), 'model.safetensors', metadata=metadata) +