MerlenMaven commited on
Commit
e11b864
·
verified ·
1 Parent(s): d7694f6

Upload 2 files

Browse files
Files changed (2) hide show
  1. fer.rar +3 -0
  2. fer2013.py +75 -0
fer.rar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60b0dc4d512f9d2d810a75dd4bcc15100dfa68a38760f3ec478c8a0dadc7aaaf
3
+ size 1482977
fer2013.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import pathlib
3
+ from typing import Any, Callable, Optional, Tuple
4
+
5
+ import torch
6
+ from PIL import Image
7
+
8
+ from .utils import check_integrity, verify_str_arg
9
+ from .vision import VisionDataset
10
+
11
+
12
+ class FER2013(VisionDataset):
13
+ """`FER2013
14
+ <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.
15
+
16
+ Args:
17
+ root (string): Root directory of dataset where directory
18
+ ``root/fer2013`` exists.
19
+ split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
20
+ transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
21
+ version. E.g, ``transforms.RandomCrop``
22
+ target_transform (callable, optional): A function/transform that takes in the target and transforms it.
23
+ """
24
+
25
+ _RESOURCES = {
26
+ "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
27
+ "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
28
+ }
29
+
30
+ def __init__(
31
+ self,
32
+ root: str,
33
+ split: str = "train",
34
+ transform: Optional[Callable] = None,
35
+ target_transform: Optional[Callable] = None,
36
+ ) -> None:
37
+ self._split = verify_str_arg(split, "split", self._RESOURCES.keys())
38
+ super().__init__(root, transform=transform, target_transform=target_transform)
39
+
40
+ base_folder = pathlib.Path(self.root) / "fer2013"
41
+ file_name, md5 = self._RESOURCES[self._split]
42
+ data_file = base_folder / file_name
43
+ if not check_integrity(str(data_file), md5=md5):
44
+ raise RuntimeError(
45
+ f"{file_name} not found in {base_folder} or corrupted. "
46
+ f"You can download it from "
47
+ f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
48
+ )
49
+
50
+ with open(data_file, "r", newline="") as file:
51
+ self._samples = [
52
+ (
53
+ torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48),
54
+ int(row["emotion"]) if "emotion" in row else None,
55
+ )
56
+ for row in csv.DictReader(file)
57
+ ]
58
+
59
+ def __len__(self) -> int:
60
+ return len(self._samples)
61
+
62
+ def __getitem__(self, idx: int) -> Tuple[Any, Any]:
63
+ image_tensor, target = self._samples[idx]
64
+ image = Image.fromarray(image_tensor.numpy())
65
+
66
+ if self.transform is not None:
67
+ image = self.transform(image)
68
+
69
+ if self.target_transform is not None:
70
+ target = self.target_transform(target)
71
+
72
+ return image, target
73
+
74
+ def extra_repr(self) -> str:
75
+ return f"split={self._split}"