Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,4 +164,5 @@ temp_image.jpg

#
.DS_Store
.python-version
.python-version
tmp.jpg
2 changes: 1 addition & 1 deletion capybara/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
from .utils import *
from .vision import *

__version__ = '0.11.0'
__version__ = "0.11.0"
36 changes: 12 additions & 24 deletions capybara/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
import numpy as np
from dacite import from_dict

from .structures import Box, Boxes, Polygon, Polygons
from .structures import Box, Boxes, Keypoints, KeypointsList, Polygon, Polygons

__all__ = [
'EnumCheckMixin', 'DataclassCopyMixin', 'DataclassToJsonMixin',
'dict_to_jsonable',
"EnumCheckMixin",
"DataclassCopyMixin",
"DataclassToJsonMixin",
"dict_to_jsonable",
]


Expand All @@ -27,18 +29,14 @@ def dict_to_jsonable(
out[k] = jsonable_func[k](v)
else:
if isinstance(v, (Box, Boxes)):
out[k] = v.convert('XYXY').numpy().astype(
float).round().tolist()
elif isinstance(v, (Polygon, Polygons)):
out[k] = v.convert("XYXY").numpy().astype(float).round().tolist()
elif isinstance(v, (Keypoints, KeypointsList, Polygon, Polygons)):
out[k] = v.numpy().astype(float).round().tolist()
elif isinstance(v, (np.ndarray, np.generic)):
# include array and scalar, if you want jsonable image please use jsonable_func
out[k] = v.tolist()
elif isinstance(v, (list, tuple)):
out[k] = [
dict_to_jsonable(x, jsonable_func) if isinstance(
x, dict) else x
for x in v
]
out[k] = [dict_to_jsonable(x, jsonable_func) if isinstance(x, dict) else x for x in v]
elif isinstance(v, Enum):
out[k] = v.name
elif isinstance(v, Mapping):
Expand All @@ -55,7 +53,6 @@ def dict_to_jsonable(


class EnumCheckMixin:

@classmethod
def obj_to_enum(cls: Enum, obj: Any):
if isinstance(obj, str):
Expand All @@ -75,26 +72,17 @@ def obj_to_enum(cls: Enum, obj: Any):


class DataclassCopyMixin:

def __copy__(self):
return self.__class__(**{
field: getattr(self, field)
for field in self.__dataclass_fields__
})
return self.__class__(**{field: getattr(self, field) for field in self.__dataclass_fields__})

def __deepcopy__(self, memo):
out = asdict(self, dict_factory=OrderedDict)
return from_dict(data_class=self.__class__, data=out)


class DataclassToJsonMixin:

def __init__(self):
self.jsonable_func = None
jsonable_func = None

def be_jsonable(self, dict_factory=OrderedDict):
d = asdict(self, dict_factory=dict_factory)
return dict_to_jsonable(d, getattr(self, 'jsonable_func', None), dict_factory)

def regist_jsonable_func(self, jsonable_func: Optional[Dict[str, Callable]] = None):
self.jsonable_func = jsonable_func
return dict_to_jsonable(d, jsonable_func=self.jsonable_func, dict_factory=dict_factory)
81 changes: 40 additions & 41 deletions capybara/structures/keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import numpy as np

from ..typing import _Number
from .boxes import Box, Boxes

__all__ = ['Keypoints', 'KeypointsList']
__all__ = ["Keypoints", "KeypointsList"]


_Keypoints = Union[
Expand All @@ -25,17 +24,17 @@


class Keypoints:
'''
"""
This structure has shape (K, 3) or (K, 2) where K is the number of keypoints.
The visibility flag follows the COCO format and must be one of three integers:
* v=0: not labeled (in which case x=y=0)
* v=1: labeled but not visible
* v=2: labeled and visible
'''
"""

def __init__(self, array: _Keypoints, cmap='rainbow', is_normalized: bool = False):
def __init__(self, array: _Keypoints, cmap="rainbow", is_normalized: bool = False):
self._array = self._check_valid_array(array)
steps = np.linspace(0., 1., self._array.shape[-2])
steps = np.linspace(0.0, 1.0, self._array.shape[-2])
color_map = matplotlib.colormaps[cmap]
self._point_colors = np.array(color_map(steps, bytes=True))[..., :3].tolist()
self._is_normalized = is_normalized
Expand All @@ -62,16 +61,16 @@ def _check_valid_array(self, array: Any) -> np.ndarray:
if cond3:
array = array.numpy()
else:
array = np.array(array, dtype='float32')
array = np.array(array, dtype="float32")

if not array.ndim == 2:
raise ValueError(f"Input array ndim = {array.ndim} is not 2, which is invalid.")

if not array.shape[-1] in [2, 3]:
if array.shape[-1] not in [2, 3]:
raise ValueError(f"Input array's shape[-1] = {array.shape[-1]} is not in [2, 3], which is invalid.")

if array.shape[-1] == 3 and not ((array[..., 2] <= 2).all() and (array[..., 2] >= 0).all()):
raise ValueError('Given array is invalid because of its labels. (array[..., 2])')
raise ValueError("Given array is invalid because of its labels. (array[..., 2])")
return array.copy()

def numpy(self) -> np.ndarray:
Expand All @@ -92,7 +91,7 @@ def scale(self, fx: float, fy: float) -> "Keypoints":

def normalize(self, w: float, h: float) -> "Keypoints":
if self.is_normalized:
warn(f'Normalized keypoints are forced to do normalization.')
warn("Normalized keypoints are forced to do normalization.")
arr = self._array.copy()
arr[..., :2] = arr[..., :2] / (w, h)
kpts = self.__class__(arr)
Expand All @@ -101,37 +100,33 @@ def normalize(self, w: float, h: float) -> "Keypoints":

def denormalize(self, w: float, h: float) -> "Keypoints":
if not self.is_normalized:
warn(f'Non-normalized keypoints is forced to do denormalization.')
warn("Non-normalized keypoints is forced to do denormalization.")
arr = self._array.copy()
arr[..., :2] = arr[..., :2] * (w, h)
kpts = self.__class__(arr)
kpts._is_normalized = False
return kpts

@ property
@property
def is_normalized(self) -> bool:
return self._is_normalized

@ property
@property
def point_colors(self) -> List[Tuple[int, int, int]]:
return [
tuple([int(x) for x in cs])
for cs in self._point_colors
]
return [tuple([int(x) for x in cs]) for cs in self._point_colors]

@ point_colors.setter
@point_colors.setter
def set_point_colors(self, cmap: str):
steps = np.linspace(0., 1., self._array.shape[-2])
steps = np.linspace(0.0, 1.0, self._array.shape[-2])
self._point_colors = matplotlib.colormaps[cmap](steps, bytes=True)


class KeypointsList:

def __init__(self, array: _KeypointsList, cmap='rainbow', is_normalized: bool = False) -> None:
def __init__(self, array: _KeypointsList, cmap="rainbow", is_normalized: bool = False) -> None:
self._array = self._check_valid_array(array).copy()
self._is_normalized = is_normalized
if len(self._array):
steps = np.linspace(0., 1., self._array.shape[-2])
steps = np.linspace(0.0, 1.0, self._array.shape[-2])
self._point_colors = matplotlib.colormaps[cmap](steps, bytes=True)
else:
self._point_colors = None
Expand All @@ -146,7 +141,7 @@ def __getitem__(self, item) -> Any:

def __setitem__(self, item, value):
if not isinstance(value, (Keypoints, KeypointsList)):
raise TypeError(f'Input value is not a keypoint or keypoints')
raise TypeError("Input value is not a keypoint or keypoints")

if isinstance(item, (int, np.ndarray, list, slice)):
self._array[item] = value._array
Expand All @@ -166,9 +161,13 @@ def __eq__(self, value: object) -> bool:
def _check_valid_array(self, array: Any) -> np.ndarray:
cond1 = isinstance(array, np.ndarray)
cond2 = isinstance(array, list) and len(array) == 0
cond3 = isinstance(array, list) and \
all(isinstance(x, (np.ndarray, Keypoints)) for x in array) or \
all(isinstance(y, tuple) for x in array for y in x)
cond3 = (
isinstance(array, list)
and (
all(isinstance(x, (np.ndarray, Keypoints)) for x in array)
or all(isinstance(y, tuple) for x in array for y in x)
)
)
cond4 = isinstance(array, self.__class__)

if not (cond1 or cond2 or cond3 or cond4):
Expand All @@ -177,9 +176,9 @@ def _check_valid_array(self, array: Any) -> np.ndarray:
if cond4:
array = array.numpy()
elif len(array) and isinstance(array[0], Keypoints):
array = np.array([x.numpy() for x in array], dtype='float32')
array = np.array([x.numpy() for x in array], dtype="float32")
else:
array = np.array(array, dtype='float32')
array = np.array(array, dtype="float32")

if len(array) == 0:
return array
Expand All @@ -191,7 +190,7 @@ def _check_valid_array(self, array: Any) -> np.ndarray:
raise ValueError(f"Input array's shape[-1] = {array.shape[-1]} is not 2 or 3, which is invalid.")

if array.shape[-1] == 3 and not ((array[..., 2] <= 2).all() and (array[..., 2] >= 0).all()):
raise ValueError('Given array is invalid because of its labels. (array[..., 2])')
raise ValueError("Given array is invalid because of its labels. (array[..., 2])")

return array

Expand All @@ -213,7 +212,7 @@ def scale(self, fx: float, fy: float) -> Any:

def normalize(self, w: float, h: float) -> "KeypointsList":
if self.is_normalized:
warn(f'Normalized keypoints_list is forced to do normalization.')
warn("Normalized keypoints_list is forced to do normalization.")
arr = self._array.copy()
arr[..., :2] = arr[..., :2] / (w, h)
kpts_list = self.__class__(arr)
Expand All @@ -222,29 +221,29 @@ def normalize(self, w: float, h: float) -> "KeypointsList":

def denormalize(self, w: float, h: float) -> "KeypointsList":
if not self.is_normalized:
warn(f'Non-normalized box is forced to do denormalization.')
warn("Non-normalized box is forced to do denormalization.")
arr = self._array.copy()
arr[..., :2] = arr[..., :2] * (w, h)
kpts_list = self.__class__(arr)
kpts_list._is_normalized = False
return kpts_list

@ property
@property
def is_normalized(self) -> bool:
return self._is_normalized

@ property
@property
def point_colors(self):
return [tuple(c) for c in self._point_colors[..., :3].tolist()]

@ point_colors.setter
@point_colors.setter
def set_point_colors(self, cmap: str):
steps = np.linspace(0., 1., self._array.shape[-2])
steps = np.linspace(0.0, 1.0, self._array.shape[-2])
self._point_colors = matplotlib.colormaps[cmap](steps, bytes=True)

@ classmethod
@classmethod
def cat(cls, keypoints_lists: List["KeypointsList"]) -> "KeypointsList":
'''
"""
Concatenates a list of KeypointsList into a single KeypointsList

Raises:
Expand All @@ -254,14 +253,14 @@ def cat(cls, keypoints_lists: List["KeypointsList"]) -> "KeypointsList":

Returns:
Keypoints: the concatenated Keypoints
'''
"""
if not isinstance(keypoints_lists, list):
raise TypeError('Given keypoints_list should be a list.')
raise TypeError("Given keypoints_list should be a list.")

if len(keypoints_lists) == 0:
raise ValueError('Given keypoints_list is empty.')
raise ValueError("Given keypoints_list is empty.")

if not all(isinstance(keypoints_list, KeypointsList) for keypoints_list in keypoints_lists):
raise TypeError('All type of elements in keypoints_lists must be KeypointsList.')
raise TypeError("All type of elements in keypoints_lists must be KeypointsList.")

return cls(np.concatenate([keypoints_list.numpy() for keypoints_list in keypoints_lists], axis=0))
Loading
Loading