sNetForImageClassification >>> from PIL import Image >>> import jax >>> import requests >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") >>> model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50") >>> inputs = image_processor(images=image, return_tensors="np") >>> outputs = model(**inputs) >>> logits = outputs.logits >>> # model predicts one of the 1000 ImageNet classes >>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1) >>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()]) ``` )8Ú functoolsr