In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPProcessor, CLIPModel
import mlflow
import mlflow.pytorch
from PIL import Image
import os

In [None]:
# === Dataset ===
class PlantDocDataset(Dataset):
    """Image/caption pairs read from a ``root/<label>/<image>`` directory tree.

    The name of each sub-directory of ``root`` is used verbatim as the text
    caption for every file found inside it.

    Args:
        root: Directory whose immediate sub-directories are the class labels.
        processor: Callable accepting ``text=``, ``images=``,
            ``return_tensors=``, ``padding=`` and ``truncation=`` keyword
            arguments and returning a mapping of tensors with a leading batch
            dimension (the notebook passes a ``CLIPProcessor``).
    """

    def __init__(self, root, processor):
        self.root = root
        self.samples = []  # (img_path, label)
        self.processor = processor
        # Sorted listings make the sample order reproducible across runs and
        # file systems (os.listdir order is arbitrary).
        for label in sorted(os.listdir(root)):
            label_path = os.path.join(root, label)
            # Stray files next to the class folders (README, .DS_Store, ...)
            # would make os.listdir(label_path) raise NotADirectoryError.
            if label.startswith(".") or not os.path.isdir(label_path):
                continue
            for img_name in sorted(os.listdir(label_path)):
                img_path = os.path.join(label_path, img_name)
                # Skip hidden entries and nested directories: neither can be
                # opened as an image later in __getitem__.
                if img_name.startswith(".") or not os.path.isfile(img_path):
                    continue
                self.samples.append((img_path, label))

    def __len__(self):
        """Return the number of (image, label) pairs collected at construction."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and run it through the processor.

        Returns:
            dict mapping each key produced by the processor to its tensor with
            the leading batch dimension removed.
        """
        img_path, label = self.samples[idx]
        # The context manager closes the underlying file as soon as the
        # converted RGB copy exists, so handles do not pile up over an epoch.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        # padding=True on a single caption pads to that caption's own length,
        # so captions with different token counts yield tensors the default
        # DataLoader collate cannot stack. Padding every caption to the
        # tokenizer's maximum length gives all samples the same shape.
        encoding = self.processor(
            text=label,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
        )
        return {k: v.squeeze(0) for k, v in encoding.items()}

In [None]:
# === Model & Training ===
# Checkpoint shared by the model and its processor; keeping it in one place
# prevents the two from drifting apart.
MODEL_NAME = "openai/clip-vit-base-patch32"

# Keep the original preference for the second GPU (index 1), but fall back to
# the first one on single-GPU hosts, where "cuda:1" is an invalid device.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"

model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)

# Each sample is an image paired with its class-folder name as the caption.
train_dataset = PlantDocDataset("PlantDoc/train", processor)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# All model parameters are handed to the optimizer (nothing is frozen here).
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)

In [None]:
# === MLflow tracking ===
# Fail early with a clear message instead of a ZeroDivisionError when the
# epoch average is computed, and before an MLflow run is opened.
if len(train_loader) == 0:
    raise ValueError("train_loader is empty: no training batches to iterate over")

mlflow.set_experiment("plantdoc_clip")
with mlflow.start_run():
    # from_pretrained returns the model in evaluation mode; switch to
    # training mode before optimizing.
    model.train()
    for epoch in range(3):  # e.g. 3 epochs
        total_loss = 0.0
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            optimizer.zero_grad()
            # CLIPModel only computes the contrastive loss when asked to;
            # without return_loss=True, outputs.loss is None and
            # loss.backward() raises AttributeError.
            outputs = model(**batch, return_loss=True)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Mean of the per-batch losses for this epoch.
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch} Loss: {avg_loss:.4f}")
        mlflow.log_metric("loss", avg_loss, step=epoch)

    # Save the fine-tuned model as an artifact of this run.
    mlflow.pytorch.log_model(model, "clip-model")