Francisco Massa fmassa Facebook AI Research

fmassa/optimize-net 280

OptNet - Reducing memory usage in torch neural nets

datumbox/dapi-model-versioning 2

RFC for Model Versioning across all PyTorch Domain libraries

fmassa/SlowFast 2

PySlowFast: video understanding codebase from FAIR for reproducing state-of-the-art video models.

fmassa/cifar.torch 1

92.45% on CIFAR-10 in Torch

fmassa/deit 1

Official DeiT repository

fmassa/PyTorch_classes 1

PyTorch classes

fmassa/audio 0

Data manipulation and transformation for audio signal processing, powered by PyTorch

Pull request review comment pytorch/vision

Refactor Segmentation models

 def forward(self, x: torch.Tensor) -> torch.Tensor:
             _res.append(conv(x))
         res = torch.cat(_res, dim=1)
         return self.project(res)
+
+
+def _deeplabv3_resnet(
+    backbone: resnet.ResNet,
+    num_classes: int,
+    aux: Optional[bool],
+) -> DeepLabV3:
+    return_layers = {"layer4": "out"}
+    if aux:
+        return_layers["layer3"] = "aux"
+    backbone = create_feature_extractor(backbone, return_layers)
+
+    aux_classifier = FCNHead(1024, num_classes) if aux else None
+    classifier = DeepLabHead(2048, num_classes)
+    return DeepLabV3(backbone, classifier, aux_classifier)
+
+
+def _deeplabv3_mobilenetv3(
+    backbone: mobilenetv3.MobileNetV3,
+    num_classes: int,
+    aux: Optional[bool],
+) -> DeepLabV3:
+    backbone = backbone.features
+    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
+    out_pos = stage_indices[-1]  # use C5 which has output_stride = 16
+    out_inplanes = backbone[out_pos].out_channels
+    aux_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
+    aux_inplanes = backbone[aux_pos].out_channels
+    return_layers = {str(out_pos): "out"}
+    if aux:
+        return_layers[str(aux_pos)] = "aux"
+    backbone = create_feature_extractor(backbone, return_layers)
+
+    aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None
+    classifier = DeepLabHead(out_inplanes, num_classes)
+    return DeepLabV3(backbone, classifier, aux_classifier)
+
+
+def deeplabv3_resnet50(
+    pretrained: bool = False,
+    progress: bool = True,
+    num_classes: int = 21,
+    aux_loss: Optional[bool] = None,
+    pretrained_backbone: bool = True,
+) -> DeepLabV3:
+    """Constructs a DeepLabV3 model with a ResNet-50 backbone.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
+            contains the same classes as Pascal VOC
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+        aux_loss (bool, optional): If True, it uses an auxiliary loss
+        pretrained_backbone (bool): If True, the backbone will be pre-trained.
+    """
+    if pretrained:
+        aux_loss = True
+        pretrained_backbone = False
+
+    backbone = resnet.resnet50(pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True])
+    model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
+
+    if pretrained:
+        arch = "deeplabv3_resnet50_coco"
+        _load_weights(arch, model, model_urls.get(arch, None), progress)

nit: I suppose you've left this here (instead of putting it in _deeplabv3_resnet) because it will be more aligned with your changes to the new weights?

Same for the backbone retrieval code

datumbox

comment created 10 hours ago


Pull request review comment pytorch/vision

add prototype imagenet dataset

+import io
+import pathlib
+import re
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+from torchdata.datapipes.iter import IterDataPipe, LineReader, KeyZipper, Mapper, TarArchiveReader, Filter, Shuffler
+from torchvision.prototype.datasets.utils import (
+    Dataset,
+    DatasetConfig,
+    DatasetInfo,
+    HttpResource,
+    OnlineResource,
+    DatasetType,
+)
+from torchvision.prototype.datasets.utils._internal import (
+    create_categories_file,
+    INFINITE_BUFFER_SIZE,
+    path_comparator,
+    Enumerator,
+    getitem,
+)
+
+HERE = pathlib.Path(__file__).parent
+
+
+class ImageNet(Dataset):
+    @property
+    def info(self) -> DatasetInfo:
+        return DatasetInfo(
+            "imagenet",
+            type=DatasetType.IMAGE,
+            categories=HERE / "imagenet.categories",
+            homepage="https://www.image-net.org/",
+            valid_options=dict(split=("train", "val")),
+        )
+
+    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+        if config.split == "train":
+            images = HttpResource(
+                "ILSVRC2012_img_train.tar",
+                sha256="b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
+            )
+        else:  # config.split == "val"
+            images = HttpResource(
+                "ILSVRC2012_img_val.tar",
+                sha256="c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
+            )
+
+        devkit = HttpResource(
+            "ILSVRC2012_devkit_t12.tar.gz",
+            sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953",
+        )
+
+        return [images, devkit]
+
+    _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<category>n\d{8})_\d+[.]JPEG")
+
+    def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[None, int], Tuple[str, io.IOBase]]:
+        path = pathlib.Path(data[0])
+        category = self._TRAIN_IMAGE_NAME_PATTERN.match(path.name).group("category")  # type: ignore[union-attr]
+        return (None, self.categories.index(category)), data
+
+    _VAL_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_val_(?P<id>\d{8})[.]JPEG")
+
+    def _val_image_key(self, data: Tuple[str, Any]) -> int:
+        path = pathlib.Path(data[0])
+        return int(self._VAL_IMAGE_NAME_PATTERN.match(path.name).group("id"))  # type: ignore[union-attr]
+
+    def _collate_and_decode_sample(
+        self,
+        data: Tuple[Tuple[Optional[int], int], Tuple[str, io.IOBase]],
+        *,
+        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    ) -> Dict[str, Any]:
+        label_data, image_data = data
+        _, label = label_data
+        path, buffer = image_data
+
+        category = self.categories[label]
+        label = torch.tensor(label)

nit for the future: we might want to revisit whether we want to store numbers as a 0d tensor or as a raw number instead. It's generally much faster and smaller to rely on the raw Python number, which might be a good thing if we want to minimize transfer / storage.
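A minimal standalone sketch (not from the PR) of the size difference this nit refers to; exact byte counts vary by Python and PyTorch version:

import pickle

import torch

# A 0d tensor pickles with the whole Tensor object overhead; the raw
# Python int serializes to a handful of bytes.
label_as_tensor = torch.tensor(3)
label_as_int = 3

print(len(pickle.dumps(label_as_tensor)))  # typically several hundred bytes
print(len(pickle.dumps(label_as_int)))     # typically ~5 bytes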

pmeier

comment created a day ago


Pull request review comment pytorch/vision

add prototype imagenet dataset

+n01440764

We can provide both the synset names and their corresponding human-readable representations.
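For illustration only (a single, well-known entry; the full mapping ships with the ImageNet devkit): the synset id is the stable key, while the human-readable name is friendlier for display.

# Hypothetical illustration of the two representations mentioned above.
categories = {
    "n01440764": "tench, Tinca tinca",  # first ImageNet-1k synset
}

wnid = "n01440764"
print(wnid, "->", categories[wnid])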

pmeier

comment created a day ago


Pull request review comment pytorch/vision

Multi-pretrained weight support - Quantized ResNet50

+import warnings
+from functools import partial
+from typing import Any, List, Optional, Type, Union
+
+from ....models.quantization.resnet import (
+    QuantizableBasicBlock,
+    QuantizableBottleneck,
+    QuantizableResNet,
+    _replace_relu,
+    quantize_model,
+)
+from ...transforms.presets import ImageNetEval
+from .._api import Weights, WeightEntry
+from .._meta import _IMAGENET_CATEGORIES
+from ..resnet import ResNet50Weights
+
+
+__all__ = ["QuantizableResNet", "QuantizedResNet50Weights", "resnet50"]
+
+
+def _resnet(
+    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
+    layers: List[int],
+    weights: Optional[Weights],
+    progress: bool,
+    quantize: bool,
+    **kwargs: Any,
+) -> QuantizableResNet:
+    if weights is not None:
+        kwargs["num_classes"] = len(weights.meta["categories"])
+        if "backend" in weights.meta:
+            kwargs["backend"] = weights.meta["backend"]
+    backend = kwargs.pop("backend", "fbgemm")
+
+    model = QuantizableResNet(block, layers, **kwargs)
+    _replace_relu(model)
+    if quantize:
+        quantize_model(model, backend)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+
+    return model
+
+
+_common_meta = {
+    "size": (224, 224),
+    "categories": _IMAGENET_CATEGORIES,
+    "backend": "fbgemm",
+}
+
+
+class QuantizedResNet50Weights(Weights):

Throwing an idea out into the wild: as of today (and I think this will be the case for all models), all quantized weights originate from an unquantized weight. Do we want to keep this link somehow in the Weights structure? Do we want the quantized weights to be magically picked from the ResNet50Weights if we pass the quantize flag?

I think it might be important to keep this relationship somehow.
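One possible shape for that link, as a standalone sketch with hypothetical names (a dataclass standing in for WeightEntry; this is not torchvision's API):

from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class WeightEntrySketch:
    url: str
    meta: Dict[str, Any]
    unquantized: Optional["WeightEntrySketch"] = None  # the proposed back-link

float_entry = WeightEntrySketch(url="resnet50_float.pth", meta={"acc@1": 76.130})
quantized_entry = WeightEntrySketch(
    url="resnet50_fbgemm.pth",
    meta={"acc@1": 75.920, "backend": "fbgemm"},
    unquantized=float_entry,
)

# With the link in place, a builder could "magically" resolve the float
# weights from the quantized entry (or vice versa) based on the quantize flag.
print(quantized_entry.unquantized.url)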

datumbox

comment created a day ago

Pull request review comment pytorch/vision

Multi-pretrained weight support - Quantized ResNet50

+import warnings
+from functools import partial
+from typing import Any, List, Optional, Type, Union
+
+from ....models.quantization.resnet import (
+    QuantizableBasicBlock,
+    QuantizableBottleneck,
+    QuantizableResNet,
+    _replace_relu,
+    quantize_model,
+)
+from ...transforms.presets import ImageNetEval
+from .._api import Weights, WeightEntry
+from .._meta import _IMAGENET_CATEGORIES
+from ..resnet import ResNet50Weights
+
+
+__all__ = ["QuantizableResNet", "QuantizedResNet50Weights", "resnet50"]
+
+
+def _resnet(
+    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
+    layers: List[int],
+    weights: Optional[Weights],
+    progress: bool,
+    quantize: bool,
+    **kwargs: Any,
+) -> QuantizableResNet:
+    if weights is not None:
+        kwargs["num_classes"] = len(weights.meta["categories"])
+        if "backend" in weights.meta:
+            kwargs["backend"] = weights.meta["backend"]
+    backend = kwargs.pop("backend", "fbgemm")
+
+    model = QuantizableResNet(block, layers, **kwargs)
+    _replace_relu(model)
+    if quantize:
+        quantize_model(model, backend)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+
+    return model
+
+
+_common_meta = {
+    "size": (224, 224),
+    "categories": _IMAGENET_CATEGORIES,
+    "backend": "fbgemm",
+}
+
+
+class QuantizedResNet50Weights(Weights):
+    ImageNet1K_FBGEMM_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#quantized",
+            "acc@1": 75.920,
+            "acc@5": 92.814,
+        },
+    )
+
+
+def resnet50(
+    weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableResNet:
+    if "pretrained" in kwargs:

nit: you could probably simplify the code a tiny bit by doing something like

if quantize:
    weights_table = QuantizedResNet50Weights
else:
    weights_table = ResNet50Weights
...
weights = weights_table.ImageNet1K_RefV1  # different naming convention than now
weights = weights_table.verify(weights)

In some sense, it's a bit annoying to have to carry those two Weights classes inside every model builder.
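Fleshing the suggestion out as a standalone sketch (enum stand-ins with hypothetical names; the prototype API's verify() is approximated by an isinstance check):

from enum import Enum

class ResNet50WeightsSketch(Enum):
    ImageNet1K_RefV1 = "resnet50_float.pth"

class QuantizedResNet50WeightsSketch(Enum):
    ImageNet1K_RefV1 = "resnet50_fbgemm.pth"

def resnet50(weights=None, quantize=False, pretrained=False):
    # Pick the table once; the rest of the builder is variant-agnostic.
    weights_table = QuantizedResNet50WeightsSketch if quantize else ResNet50WeightsSketch
    if pretrained and weights is None:
        weights = weights_table.ImageNet1K_RefV1  # same member name in both tables
    if weights is not None and not isinstance(weights, weights_table):
        raise ValueError(f"expected a member of {weights_table.__name__}")
    return weights

print(resnet50(pretrained=True, quantize=True))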

datumbox

comment created a day ago

Pull request review comment pytorch/vision

Multi-pretrained weight support - Quantized ResNet50

+import warnings
+from functools import partial
+from typing import Any, List, Optional, Type, Union
+
+from ....models.quantization.resnet import (
+    QuantizableBasicBlock,
+    QuantizableBottleneck,
+    QuantizableResNet,
+    _replace_relu,
+    quantize_model,
+)
+from ...transforms.presets import ImageNetEval
+from .._api import Weights, WeightEntry
+from .._meta import _IMAGENET_CATEGORIES
+from ..resnet import ResNet50Weights
+
+
+__all__ = ["QuantizableResNet", "QuantizedResNet50Weights", "resnet50"]
+
+
+def _resnet(
+    block: Type[Union[QuantizableBasicBlock, QuantizableBottleneck]],
+    layers: List[int],
+    weights: Optional[Weights],
+    progress: bool,
+    quantize: bool,
+    **kwargs: Any,
+) -> QuantizableResNet:
+    if weights is not None:
+        kwargs["num_classes"] = len(weights.meta["categories"])

Flagging again https://github.com/pytorch/vision/pull/4613#discussion_r729678718 but we can discuss in a follow-up

datumbox

comment created a day ago


push event facebookresearch/detr

Naelson Douglas C. Oliveira

commit sha 091a817eca74b8b97e35e4531c1c39f89fbe38eb

Removed a manual indexer iterator pitfall (#454)

* Removed a manual iterator indexer pitfall
* Removed a manual iterator indexer pitfall

view details

pushed a day ago

PR merged facebookresearch/detr

Removed a manual indexer iterator pitfall (label: CLA Signed)

The problem: the code was iterating over the list 'outputs' manually using

for i in range(0, len(outputs)):

and accessing the data manually as:

outputs[i]

but Python has a built-in way to deal with this using the enumerate method:

for index, element in enumerate(outputs):

which makes accessing the data easier.

The solution: changed the manual indexer to the enumerated iterator.
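A minimal standalone illustration of the before/after (toy data, not from the PR):

outputs = ["a", "b", "c"]

# Before: manual indexing
for i in range(0, len(outputs)):
    print(i, outputs[i])

# After: enumerate yields (index, element) pairs directly
for index, element in enumerate(outputs):
    print(index, element)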

+2 -2

0 comments

1 changed file

NaelsonDouglas

PR closed a day ago


issue comment pytorch/data

Router for same functional API

Note that ioPath automatically handles all cases internally once you specify the correct URI.

This means that we can just use IoPathFileListerIterDataPipe and IoPathFileLoaderIterDataPipe and get the result we want to achieve here.

IMO, if we are using ioPath as a dependency (or if we plan to make it a hard dependency), we should just use ioPath for all IO-based reading. This would simplify the codebase a lot, as duplicate implementations could be removed and the user would only need to rely on a single datapipe for everything load-related.
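A sketch of that single-datapipe flow using the two datapipes named above. The constructor arguments and the URI are assumptions, so check the torchdata documentation for the exact signatures:

from torchdata.datapipes.iter import IoPathFileListerIterDataPipe, IoPathFileLoaderIterDataPipe

# ioPath dispatches on the URI scheme, so a local path, s3://, etc. would
# all flow through the same pair of datapipes.
listing = IoPathFileListerIterDataPipe(root="s3://some-bucket/dataset/")  # hypothetical URI
loading = IoPathFileLoaderIterDataPipe(listing)

for path, stream in loading:
    print(path)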

ejguan

comment created 2 days ago

push event pytorch/vision

Dmytro

commit sha e08c9e31d1aac09caa4998220f86912f833f3179

Replaced all 'no_grad()' instances with 'inference_mode()' (#4629)

view details

pushed 2 days ago

PR merged pytorch/vision

Replace all `no_grad()` instances with `inference_mode()` in reference scripts (labels: enhancement, module: reference scripts, cla signed, ciflow/default)

Closes #4624. All instances of no_grad() seemed OK to replace with inference_mode(), since all of them were related to evaluation.

+10 -10

4 comments

10 changed files

XoMute

PR closed 2 days ago

issue closed pytorch/vision

Replace instances of `torch.no_grad()` with `torch.inference_mode()` in evaluation reference scripts

Since the last release of PyTorch, there is a new context manager / function decorator that enables faster execution than torch.no_grad(): https://pytorch.org/docs/stable/generated/torch.inference_mode.html?highlight=inference_mode#torch.inference_mode

We should replace instances of torch.no_grad() in our reference scripts whenever torch.inference_mode() is the correct candidate (for example, in evaluation code such as https://github.com/pytorch/vision/blob/3d7244b5280301792a6959c59154d4809ad1209a/references/classification/train.py#L60).
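The replacement, sketched in the shape of the reference scripts' evaluation loop (simplified; the real evaluate function takes more arguments):

import torch

def evaluate(model, data_loader):
    model.eval()
    with torch.inference_mode():  # previously: with torch.no_grad():
        for image, target in data_loader:
            output = model(image)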

cc @datumbox

closed 2 days ago

fmassa

issue opened pytorch/vision

Replace instances of `torch.no_grad()` with `torch.inference_mode()` in evaluation reference scripts

Since the last release of PyTorch, there is a new context manager / function decorator that enables faster execution than torch.no_grad(): https://pytorch.org/docs/stable/generated/torch.inference_mode.html?highlight=inference_mode#torch.inference_mode

We should replace instances of torch.no_grad() in our reference scripts whenever torch.inference_mode() is the correct candidate (for example, in evaluation code such as https://github.com/pytorch/vision/blob/3d7244b5280301792a6959c59154d4809ad1209a/references/classification/train.py#L60).

created 4 days ago

delete branch fmassa/vision-1

delete branch: scripts-release-notes

deleted 4 days ago

Pull request review comment pytorch/vision

Multi-pretrained weight support - FasterRCNN ResNet50

+import warnings
+from typing import Any, Optional
+
+from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers
+from ...transforms.presets import CocoEval
+from .._api import Weights, WeightEntry
+from .._meta import _COCO_CATEGORIES
+from ..resnet import ResNet50Weights
+from .backbone_utils import resnet_fpn_backbone
+
+
+__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]
+
+
+class FasterRCNNResNet50FPNWeights(Weights):
+    Coco_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
+        transforms=CocoEval,
+        meta={
+            "categories": _COCO_CATEGORIES,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
+            "map": 37.0,
+        },
+    )
+
+
+def fasterrcnn_resnet50_fpn(
+    weights: Optional[FasterRCNNResNet50FPNWeights] = None,
+    weights_backbone: Optional[ResNet50Weights] = None,
+    progress: bool = True,
+    num_classes: int = 91,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
+    weights = FasterRCNNResNet50FPNWeights.verify(weights)
+    if "pretrained_backbone" in kwargs:
+        warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
+        weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
+    weights_backbone = ResNet50Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = len(weights.meta["categories"])
+
+    trainable_backbone_layers = _validate_trainable_layers(
+        weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
+    )
+
+    backbone = resnet_fpn_backbone("resnet50", weights_backbone, trainable_layers=trainable_backbone_layers)
+    model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+        overwrite_eps(model, 0.0)

For when we add new model weights: this should be refactored because we will be training with the new eps

datumbox

comment created 4 days ago

Pull request review comment pytorch/vision

Multi-pretrained weight support - FasterRCNN ResNet50

+import warnings
+from typing import Any, Optional
+
+from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers
+from ...transforms.presets import CocoEval
+from .._api import Weights, WeightEntry
+from .._meta import _COCO_CATEGORIES
+from ..resnet import ResNet50Weights
+from .backbone_utils import resnet_fpn_backbone
+
+
+__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]
+
+
+class FasterRCNNResNet50FPNWeights(Weights):
+    Coco_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
+        transforms=CocoEval,
+        meta={
+            "categories": _COCO_CATEGORIES,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn",
+            "map": 37.0,
+        },
+    )
+
+
+def fasterrcnn_resnet50_fpn(
+    weights: Optional[FasterRCNNResNet50FPNWeights] = None,
+    weights_backbone: Optional[ResNet50Weights] = None,
+    progress: bool = True,
+    num_classes: int = 91,
+    trainable_backbone_layers: Optional[int] = None,
+    **kwargs: Any,
+) -> FasterRCNN:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        weights = FasterRCNNResNet50FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None
+    weights = FasterRCNNResNet50FPNWeights.verify(weights)
+    if "pretrained_backbone" in kwargs:
+        warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.")
+        weights_backbone = ResNet50Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None
+    weights_backbone = ResNet50Weights.verify(weights_backbone)
+
+    if weights is not None:
+        weights_backbone = None
+        num_classes = len(weights.meta["categories"])

We should probably raise an error / warning if the user modifies num_classes and also passes a weights argument. Otherwise they might mistakenly assume that we are doing magic inside.
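A sketch of such a guard (hypothetical helper, not the merged code); the default of 91 matches the builder's signature above:

def _check_num_classes(num_classes, weights_categories, default=91):
    # Fail loudly instead of silently overriding a user-supplied num_classes.
    if num_classes != default and num_classes != len(weights_categories):
        raise ValueError(
            f"the requested weights expect {len(weights_categories)} classes, "
            f"but num_classes={num_classes} was passed"
        )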

datumbox

comment created 4 days ago


Pull request review comment pytorch/vision

[WIP] Prototype models - ResNet50

+import warnings
+from functools import partial
+from typing import Any, List, Optional, Type, Union
+
+from ...models.resnet import BasicBlock, Bottleneck, ResNet
+from ..transforms.presets import ImageNetEval
+from ._api import Weights, WeightEntry
+
+
+__all__ = ["ResNet", "ResNet50Weights", "resnet50"]
+
+
+def _resnet(
+    block: Type[Union[BasicBlock, Bottleneck]],
+    layers: List[int],
+    weights: Optional[Weights],
+    progress: bool,
+    **kwargs: Any,
+) -> ResNet:
+    if weights is not None:
+        kwargs["num_classes"] = len(weights.meta["categories"])
+
+    model = ResNet(block, layers, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+
+    return model
+
+
+_common_meta = {
+    "size": (224, 224),
+    "categories": list(range(1000)),  # TODO: torchvision.prototype.datasets.find("ImageNet").info.categories
+}
+
+
+class ResNet50Weights(Weights):
+    ImageNet1K_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification",
+            "acc@1": 76.130,
+            "acc@5": 92.862,
+        },
+    )
+    ImageNet1K_RefV2 = WeightEntry(

Naming will be important here. ImageNet1K_RefV2 sounds good for a v1, but we should have a webpage in the docs that breaks this down nicely. Something to keep in mind: an easy way to gather this information automatically would facilitate generating the documentation.
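A sketch of what that automatic gathering could look like (enum stand-in with hypothetical names; the real WeightEntry carries the meta dict shown in the diff):

from enum import Enum

class ResNet50WeightsSketch(Enum):
    ImageNet1K_RefV1 = {"acc@1": 76.130, "recipe": "https://github.com/pytorch/vision/tree/main/references/classification"}
    ImageNet1K_RefV2 = {"acc@1": 80.352, "recipe": "https://github.com/pytorch/vision/issues/3995"}

def weights_doc_rows(weights_enum):
    # Iterate the enum to build a docs table instead of maintaining it by hand.
    for entry in weights_enum:
        yield entry.name, entry.value["acc@1"], entry.value["recipe"]

for row in weights_doc_rows(ResNet50WeightsSketch):
    print(row)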

datumbox

comment created 6 days ago

Pull request review comment pytorch/vision

[WIP] Prototype models - ResNet50

+import warnings
+from functools import partial
+from typing import Any, List, Optional, Type, Union
+
+from ...models.resnet import BasicBlock, Bottleneck, ResNet
+from ..transforms.presets import ImageNetEval
+from ._api import Weights, WeightEntry
+
+
+__all__ = ["ResNet", "ResNet50Weights", "resnet50"]
+
+
+def _resnet(
+    block: Type[Union[BasicBlock, Bottleneck]],
+    layers: List[int],
+    weights: Optional[Weights],
+    progress: bool,
+    **kwargs: Any,
+) -> ResNet:
+    if weights is not None:
+        kwargs["num_classes"] = len(weights.meta["categories"])
+
+    model = ResNet(block, layers, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+
+    return model
+
+
+_common_meta = {
+    "size": (224, 224),
+    "categories": list(range(1000)),  # TODO: torchvision.prototype.datasets.find("ImageNet").info.categories
+}
+
+
+class ResNet50Weights(Weights):
+    ImageNet1K_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification",
+            "acc@1": 76.130,
+            "acc@5": 92.862,
+        },
+    )
+    ImageNet1K_RefV2 = WeightEntry(
+        url="https://download.pytorch.org/models/resnet50-tmp.pth",

What is the plan here, to re-upload at some point in the future? Also, how do we plan on keeping the names of the checkpoint files manageable: just rely on the sha256 to differentiate them?

datumbox

comment created 6 days ago

Pull request review comment pytorch/vision

[WIP] Prototype models - ResNet50

+import warnings
+from functools import partial
+from typing import Any, List, Optional, Type, Union
+
+from ...models.resnet import BasicBlock, Bottleneck, ResNet
+from ..transforms.presets import ImageNetEval
+from ._api import Weights, WeightEntry
+
+
+__all__ = ["ResNet", "ResNet50Weights", "resnet50"]
+
+
+def _resnet(
+    block: Type[Union[BasicBlock, Bottleneck]],
+    layers: List[int],
+    weights: Optional[Weights],
+    progress: bool,
+    **kwargs: Any,
+) -> ResNet:
+    if weights is not None:
+        kwargs["num_classes"] = len(weights.meta["categories"])
+
+    model = ResNet(block, layers, **kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+
+    return model
+
+
+_common_meta = {
+    "size": (224, 224),
+    "categories": list(range(1000)),  # TODO: torchvision.prototype.datasets.find("ImageNet").info.categories
+}
+
+
+class ResNet50Weights(Weights):
+    ImageNet1K_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification",
+            "acc@1": 76.130,
+            "acc@5": 92.862,
+        },
+    )
+    ImageNet1K_RefV2 = WeightEntry(
+        url="https://download.pytorch.org/models/resnet50-tmp.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            **_common_meta,
+            "recipe": "https://github.com/pytorch/vision/issues/3995",
+            "acc@1": 80.352,

Beautiful!

datumbox

comment created 6 days ago