图片分类(使用流水线API)
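模型下载1:google/vit-base-patch16-224 · Hugging Face(https://huggingface.co/google/vit-base-patch16-224)
模型下载2:google/vit-base-patch16-224 · HF Mirror(https://hf-mirror.com/google/vit-base-patch16-224)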
from transformers import pipeline
# Sample image to classify, referenced by URL and handed to the pipeline as-is.
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"

# Build the pipeline from the checkpoint id alone; no task name is passed.
vision_classifier = pipeline(model="google/vit-base-patch16-224")

# Run inference; each returned item is indexed by "score" and "label" below.
raw_preds = vision_classifier(images=image_url)

# Keep only the label and the score rounded to four decimals for printing.
preds = [
    {"score": round(item["score"], 4), "label": item["label"]}
    for item in raw_preds
]
print(preds)
输出:
[{'score': 0.4335, 'label': 'lynx, catamount'}, {'score': 0.0348, 'label': 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor'}, {'score': 0.0324, 'label': 'snow leopard, ounce, Panthera uncia'}, {'score': 0.0239, 'label': 'Egyptian cat'}, {'score': 0.0229, 'label': 'tiger cat'}]
图片分类(不使用流水线API)
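模型下载1:google/vit-base-patch16-224 · Hugging Face(https://huggingface.co/google/vit-base-patch16-224)
模型下载2:google/vit-base-patch16-224 · HF Mirror(https://hf-mirror.com/google/vit-base-patch16-224)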
import torch
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import requests
from PIL import Image
from io import BytesIO
# Load the pretrained classifier and its matching preprocessor.
# NOTE(review): recent transformers releases recommend AutoImageProcessor over
# AutoFeatureExtractor for vision checkpoints; confirm against the installed
# version before switching the import.
model_id = "google/vit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id)

# Download the sample image. The timeout keeps a stalled connection from
# hanging the script, and raise_for_status() stops an HTTP error page from
# being handed to PIL as if it were image bytes.
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
response = requests.get(image_url, timeout=30)
response.raise_for_status()
# Force three channels so grayscale / RGBA / palette inputs also work.
image = Image.open(BytesIO(response.content)).convert("RGB")

# Preprocess the image into PyTorch tensors.
inputs = feature_extractor(images=image, return_tensors="pt")

# Forward pass without gradient tracking (inference only).
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits

# Index of the single best class, and per-class probabilities for the one
# image in the batch.
predicted_class_idx = logits.argmax(-1).item()
scores = torch.nn.functional.softmax(logits, dim=-1)[0]

# Fall back to a generic "LABEL_<idx>" name when the config carries no label
# mapping (or lacks an entry), so `label` is never undefined.
id2label = model.config.id2label or {}

# Take the five most probable classes, already sorted by descending score,
# then drop anything at or below the 0.01 confidence threshold.
top_scores, top_indices = torch.topk(scores, k=min(5, scores.numel()))
preds = [
    {"score": round(score, 4), "label": id2label.get(idx, f"LABEL_{idx}")}
    for score, idx in zip(top_scores.tolist(), top_indices.tolist())
    if score > 0.01
]
print(preds)
输出:
[{'score': 0.4335, 'label': 'lynx, catamount'}, {'score': 0.0348, 'label': 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor'}, {'score': 0.0324, 'label': 'snow leopard, ounce, Panthera uncia'}, {'score': 0.0239, 'label': 'Egyptian cat'}, {'score': 0.0229, 'label': 'tiger cat'}]