Image Classification

Created: 2025-04-20
Updated: 2025-04-20

Image classification (with the pipeline API)

Model download 1: google/vit-base-patch16-224 · Hugging Face

Model download 2: google/vit-base-patch16-224 · HF Mirror

from transformers import pipeline

vision_classifier = pipeline(task="image-classification", model="google/vit-base-patch16-224")
preds = vision_classifier(
    images="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
)
preds = [{"score": round(pred["score"], 4), "label": pred["label"]} for pred in preds]

print(preds)

Output:

[{'score': 0.4335, 'label': 'lynx, catamount'}, {'score': 0.0348, 'label': 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor'}, {'score': 0.0324, 'label': 'snow leopard, ounce, Panthera uncia'}, {'score': 0.0239, 'label': 'Egyptian cat'}, {'score': 0.0229, 'label': 'tiger cat'}]
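The pipeline returns a plain list of dicts, sorted by score, where each label is an ImageNet class name that packs several comma-separated synonyms into one string. A minimal sketch of consuming that result, using the output above copied verbatim as sample data; the helper names `top_prediction` and `coverage` are hypothetical, not part of transformers:

```python
# Sample data: the pipeline output shown above, copied verbatim.
preds = [
    {"score": 0.4335, "label": "lynx, catamount"},
    {"score": 0.0348, "label": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor"},
    {"score": 0.0324, "label": "snow leopard, ounce, Panthera uncia"},
    {"score": 0.0239, "label": "Egyptian cat"},
    {"score": 0.0229, "label": "tiger cat"},
]


def top_prediction(preds):
    """Hypothetical helper: return (primary_name, synonyms, score) of the best entry."""
    best = max(preds, key=lambda p: p["score"])
    # ImageNet-style labels store several synonyms in one comma-separated string.
    names = [name.strip() for name in best["label"].split(",")]
    return names[0], names[1:], best["score"]


def coverage(preds):
    """Hypothetical helper: total probability mass covered by the returned entries."""
    return round(sum(p["score"] for p in preds), 4)


name, synonyms, score = top_prediction(preds)
print(name, synonyms, score)  # lynx ['catamount'] 0.4335
print(coverage(preds))        # 0.5475
```

Because the scores are softmax probabilities over all 1000 classes, the five returned entries here cover only about 55% of the probability mass, so the top label is a best guess rather than a confident answer.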

Image classification (without the pipeline API)

Model download 1: google/vit-base-patch16-224 · Hugging Face

Model download 2: google/vit-base-patch16-224 · HF Mirror
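The model name itself encodes its input geometry: images are 224×224 and are cut into 16×16 patches, giving (224 / 16)² = 196 patches (ViT then prepends one [CLS] token, for 197 tokens). The preprocessing step in the code below also rescales and normalizes each pixel. A rough sketch of that per-pixel arithmetic, assuming a rescale factor of 1/255 and mean = std = 0.5 per channel; these values are an assumption here and should be checked against the checkpoint's own `preprocessor_config.json`:

```python
# Assumed preprocessing constants (verify against preprocessor_config.json).
MEAN = 0.5
STD = 0.5
IMAGE_SIZE = 224  # from the model name "...-224"
PATCH_SIZE = 16   # from the model name "patch16"


def normalize_pixel(value):
    """Map an 8-bit channel value (0..255) to [0, 1], then normalize with MEAN/STD."""
    return (value / 255.0 - MEAN) / STD


def num_patches(image_size=IMAGE_SIZE, patch_size=PATCH_SIZE):
    """Count the non-overlapping square patches a ViT splits the image into."""
    per_side = image_size // patch_size
    return per_side * per_side


print(normalize_pixel(0), normalize_pixel(255))  # -1.0 1.0
print(num_patches())                             # 196
```

Under these assumptions every pixel ends up in [-1, 1], and the processed input tensor has shape (batch, 3, 224, 224).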

import requests
import torch
from io import BytesIO
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Load the pretrained model and its image processor
# (AutoImageProcessor is the current name for what older code loads via AutoFeatureExtractor)
model_id = "google/vit-base-patch16-224"
image_processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id)

# Load the image; fail early on network or HTTP errors
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
response = requests.get(image_url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")

# Preprocess the image (resize, rescale, normalize) into a pixel_values tensor
inputs = image_processor(images=image, return_tensors="pt")

# Run inference without tracking gradients
with torch.no_grad():
    logits = model(**inputs).logits

# Convert logits to probabilities over all classes (batch size is 1)
scores = torch.nn.functional.softmax(logits, dim=-1)[0]

# Keep the 5 most likely classes, sorted by score, like the pipeline's default top_k=5
top_scores, top_ids = torch.topk(scores, k=5)
preds = [
    {"score": round(score.item(), 4), "label": model.config.id2label[idx.item()]}
    for score, idx in zip(top_scores, top_ids)
]

print(preds)

Output:

[{'score': 0.4335, 'label': 'lynx, catamount'}, {'score': 0.0348, 'label': 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor'}, {'score': 0.0324, 'label': 'snow leopard, ounce, Panthera uncia'}, {'score': 0.0239, 'label': 'Egyptian cat'}, {'score': 0.0229, 'label': 'tiger cat'}]
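Both versions produce the same list because they apply the same post-processing: a softmax over the logits followed by taking the k highest probabilities. A dependency-free sketch of that step, using hypothetical 4-class logits and labels purely for illustration (the real model has 1000 classes in `model.config.id2label`):

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, id2label, k=5):
    """Return the k most likely classes as [{'score', 'label'}], highest first."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [{"score": round(probs[i], 4), "label": id2label[i]} for i in ranked]


# Hypothetical logits and labels, for illustration only.
id2label = {0: "tabby cat", 1: "tiger cat", 2: "lynx, catamount", 3: "Egyptian cat"}
logits = [1.0, 0.5, 3.0, 0.0]
print(top_k(logits, id2label, k=2))
# [{'score': 0.7891, 'label': 'lynx, catamount'}, {'score': 0.1068, 'label': 'tabby cat'}]
```

Ranking and then slicing to k, rather than filtering by a fixed confidence threshold, guarantees exactly k results even when the model is unsure and every probability is small.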

About

Development experience notes from a programmer in a small third-tier city.