diff --git a/shenzhen_solitaire/board.py b/shenzhen_solitaire/board.py index d9f403d..2283498 100644 --- a/shenzhen_solitaire/board.py +++ b/shenzhen_solitaire/board.py @@ -53,9 +53,10 @@ class Board: # Starting max row is 5, if the last one is a `1`, we can put a `2` - `9` on top of it, resulting in 13 cards MAX_ROW_SIZE = 13 + MAX_COLUMN_SIZE = 8 def __init__(self) -> None: - self.field: List[List[Card]] = [[]] * 8 + self.field: List[List[Card]] = [[]] * Board.MAX_COLUMN_SIZE self.bunker: List[Union[Tuple[SpecialCard, int], Optional[Card]]] = [None] * 3 self.goal: Dict[NumberCard.Suit, int] = { NumberCard.Suit.Red: 0, diff --git a/shenzhen_solitaire/card_detection/adjustment.py b/shenzhen_solitaire/card_detection/adjustment.py index 8a5c083..0fd3596 100644 --- a/shenzhen_solitaire/card_detection/adjustment.py +++ b/shenzhen_solitaire/card_detection/adjustment.py @@ -10,31 +10,36 @@ import cv2 @dataclass class Adjustment: """Configuration for a grid""" - x: int - y: int - w: int - h: int - dx: int - dy: int + + x: int = 0 + y: int = 0 + w: int = 0 + h: int = 0 + dx: int = 0 + dy: int = 0 -def get_square(adjustment: Adjustment, index_x: int = 0, - index_y: int = 0) -> Tuple[int, int, int, int]: +def get_square( + adjustment: Adjustment, index_x: int = 0, index_y: int = 0 +) -> Tuple[int, int, int, int]: """Get one square from index and adjustment""" - return (adjustment.x + adjustment.dx * index_x, - adjustment.y + adjustment.dy * index_y, - adjustment.x + adjustment.w + adjustment.dx * index_x, - adjustment.y + adjustment.h + adjustment.dy * index_y) + return ( + adjustment.x + adjustment.dx * index_x, + adjustment.y + adjustment.dy * index_y, + adjustment.x + adjustment.w + adjustment.dx * index_x, + adjustment.y + adjustment.h + adjustment.dy * index_y, + ) def adjust_squares( - image: numpy.ndarray, - count_x: int, - count_y: int, - adjustment: Optional[Adjustment] = None) -> Adjustment: + image: numpy.ndarray, + count_x: int, + count_y: int, + adjustment: Optional[Adjustment] = None, +) -> Adjustment: if not adjustment: - adjustment = Adjustment(0, 0, 0, 0, 0, 0) + adjustment = Adjustment(w=10, h=10) def _adjustment_step(keycode: int) -> None: assert adjustment is not None @@ -59,21 +64,19 @@ def adjust_squares( while True: working_image = image.copy() - for index_x, index_y in itertools.product( - range(count_x), range(count_y)): + for index_x, index_y in itertools.product(range(count_x), range(count_y)): square = get_square(adjustment, index_x, index_y) - cv2.rectangle(working_image, - (square[0], square[1]), - (square[2], square[3]), - (0, 0, 0)) - cv2.imshow('Window', working_image) + cv2.rectangle( + working_image, (square[0], square[1]), (square[2], square[3]), (0, 0, 0) + ) + cv2.imshow("Window", working_image) keycode = cv2.waitKey(0) print(keycode) if keycode == 27: break _adjustment_step(keycode) - cv2.destroyWindow('Window') + cv2.destroyWindow("Window") return adjustment diff --git a/shenzhen_solitaire/card_detection/board_parser.py b/shenzhen_solitaire/card_detection/board_parser.py index 401605e..e2774d5 100644 --- a/shenzhen_solitaire/card_detection/board_parser.py +++ b/shenzhen_solitaire/card_detection/board_parser.py @@ -2,7 +2,7 @@ import numpy as np from .configuration import Configuration -from ..board import Board, NumberCard, SpecialCard +from ..board import Board, NumberCard, SpecialCard, Card from . import card_finder import cv2 from typing import Iterable, Any, List, Tuple, Union @@ -37,12 +37,13 @@ def get_square_iterator( return zip(grouped_squares, grouped_border_squares) -def match_template(template: np.ndarray, search_image: np.ndarray) -> int: +def match_template(template: np.ndarray, search_image: np.ndarray) -> float: """Return matchiness for the template on the search image""" + res = cv2.matchTemplate(search_image, template, cv2.TM_CCOEFF_NORMED) min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res) - assert isinstance(max_val, int) - return max_val + assert isinstance(max_val, (int, float)) + return float(max_val) def parse_square( @@ -70,13 +71,13 @@ def parse_square( return (best_name, row_finished) -def parse_board(image: np.ndarray, conf: Configuration) -> Board: +def parse_field(image: np.ndarray, conf: Configuration) -> List[List[Card]]: """Parse a screenshot of the game, using a given configuration""" square_iterator = get_square_iterator( - image, conf, row_count=Board.MAX_ROW_SIZE, column_count=8 + image, conf, row_count=Board.MAX_ROW_SIZE, column_count=Board.MAX_COLUMN_SIZE ) - result = Board() - for group_index, (square_group, border_group) in enumerate(square_iterator): + result = [] + for square_group, border_group in square_iterator: group_field = [] for index, (square, border_square) in enumerate( zip(square_group, border_group) @@ -86,6 +87,12 @@ def parse_board(image: np.ndarray, conf: Configuration) -> Board: if row_finished: break - result.field[group_index] = group_field + result.append(group_field) return result + + +def parse_board(image: np.ndarray, conf: Configuration) -> Board: + result = Board() + result.field = parse_field(image, conf) + return result diff --git a/shenzhen_solitaire/card_detection/configuration.py b/shenzhen_solitaire/card_detection/configuration.py index b973c8b..1cb4da7 100644 --- a/shenzhen_solitaire/card_detection/configuration.py +++ b/shenzhen_solitaire/card_detection/configuration.py @@ -14,8 +14,12 @@ from . import card_finder from .. import board ADJUSTMENT_FILE_NAME = "adjustment.json" + FIELD_ADJUSTMENT_KEY = "field" BORDER_ADJUSTMENT_KEY = "border" +GOAL_ADJUSTMENT_KEY = "goal" +BUNKER_ADJUSTMENT_KEY = "bunker" +HUA_ADJUSTMENT_KEY = "hua" TEMPLATES_DIRECTORY = "templates" CARD_BORDER_DIRECTORY = "borders" @@ -28,11 +32,26 @@ PICTURE_EXTENSION = "png" class Configuration: """Configuration for solitaire cv""" - field_adjustment: adjustment.Adjustment - border_adjustment: adjustment.Adjustment - catalogue: List[Tuple[np.ndarray, Union[board.SpecialCard, board.NumberCard]]] - card_border: List[np.ndarray] - empty_card: List[np.ndarray] + field_adjustment: adjustment.Adjustment = dataclasses.field( + default_factory=adjustment.Adjustment + ) + border_adjustment: adjustment.Adjustment = dataclasses.field( + default_factory=adjustment.Adjustment + ) + goal_adjustment: adjustment.Adjustment = dataclasses.field( + default_factory=adjustment.Adjustment + ) + bunker_adjustment: adjustment.Adjustment = dataclasses.field( + default_factory=adjustment.Adjustment + ) + hua_adjustment: adjustment.Adjustment = dataclasses.field( + default_factory=adjustment.Adjustment + ) + catalogue: List[ + Tuple[np.ndarray, Union[board.SpecialCard, board.NumberCard]] + ] = dataclasses.field(default_factory=list) + card_border: List[np.ndarray] = dataclasses.field(default_factory=list) + empty_card: List[np.ndarray] = dataclasses.field(default_factory=list) meta: Dict[str, str] = dataclasses.field(default_factory=dict) @@ -55,10 +74,9 @@ def _save_catalogue( zip_file.write( myfile, arcname=f"{TEMPLATES_DIRECTORY}/{file_name}.{PICTURE_EXTENSION}" ) - -def _save_adjustments( - zip_file: zipfile.ZipFile, conf: Configuration -) -> None: + + +def _save_adjustments(zip_file: zipfile.ZipFile, conf: Configuration) -> None: adjustments = {} adjustments[FIELD_ADJUSTMENT_KEY] = dataclasses.asdict(conf.field_adjustment) adjustments[BORDER_ADJUSTMENT_KEY] = dataclasses.asdict(conf.border_adjustment) @@ -75,7 +93,7 @@ def save(conf: Configuration, filename: str) -> None: with zipfile.ZipFile(zip_stream, "w") as zip_file: _save_adjustments(zip_file, conf) _save_catalogue(zip_file, conf.catalogue) - + # TODO: Save card_borders and emtpy_card with open(filename, "wb") as zip_archive: zip_archive.write(zip_stream.getvalue()) @@ -98,7 +116,9 @@ def _load_catalogue(zip_file: zipfile.ZipFile,) -> List[Tuple[np.ndarray, board. mydir = tempfile.mkdtemp() for template_filename in ( - x for x in zip_file.namelist() if x.startswith(TEMPLATES_DIRECTORY + "/") + x + for x in zip_file.namelist() + if x.startswith(TEMPLATES_DIRECTORY + "/") and x != TEMPLATES_DIRECTORY + "/" ): myfile = zip_file.extract(template_filename, path=mydir) catalogue.append((cv2.imread(myfile), _parse_file_name(template_filename),)) @@ -111,7 +131,9 @@ def _load_dir(zip_file: zipfile.ZipFile, dirname: str) -> List[np.ndarray]: image_filenames = [ image_filename for image_filename in ( - x for x in zip_file.namelist() if x.startswith(dirname + "/") + x + for x in zip_file.namelist() + if x.startswith(dirname + "/") and x != dirname + "/" ) ] images = [ @@ -127,18 +149,28 @@ def load(filename: str) -> Configuration: with zipfile.ZipFile(filename, "r") as zip_file: adjustment_dict = json.loads(zip_file.read(ADJUSTMENT_FILE_NAME)) - return Configuration( + result = Configuration( field_adjustment=adjustment.Adjustment( - **adjustment_dict[FIELD_ADJUSTMENT_KEY] + **adjustment_dict.get(FIELD_ADJUSTMENT_KEY, {}) ), border_adjustment=adjustment.Adjustment( - **adjustment_dict[BORDER_ADJUSTMENT_KEY] + **adjustment_dict.get(BORDER_ADJUSTMENT_KEY, {}) + ), + goal_adjustment=adjustment.Adjustment( + **adjustment_dict.get(GOAL_ADJUSTMENT_KEY, {}) + ), + bunker_adjustment=adjustment.Adjustment( + **adjustment_dict.get(BUNKER_ADJUSTMENT_KEY, {}) + ), + hua_adjustment=adjustment.Adjustment( + **adjustment_dict.get(HUA_ADJUSTMENT_KEY, {}) ), catalogue=_load_catalogue(zip_file), card_border=_load_dir(zip_file, CARD_BORDER_DIRECTORY), empty_card=_load_dir(zip_file, EMPTY_CARD_DIRECTORY), meta={}, ) + return result def generate(image: np.ndarray) -> Configuration: diff --git a/test/test_cv.py b/test/test_cv.py index 7cccd86..b85307f 100644 --- a/test/test_cv.py +++ b/test/test_cv.py @@ -7,7 +7,7 @@ import numpy as np from shenzhen_solitaire import board from shenzhen_solitaire.card_detection import adjustment, board_parser -from shenzhen_solitaire.card_detection.configuration import Configuration +import shenzhen_solitaire.card_detection.configuration as configuration from . import boards @@ -16,10 +16,8 @@ class CardDetectionTest(unittest.TestCase): """Parse a configuration and a board""" image = cv2.imread("pictures/20190809172206_1.jpg") - loaded_config = Configuration.load("test_config.zip") + loaded_config = configuration.load("test_config.zip") my_board = board_parser.parse_board(image, loaded_config) - for rows in zip(boards.B20190809172206_1.field, my_board.field): - for good_cell, test_cell in zip(*rows): - self.assertEqual(good_cell, test_cell) - + for correct_row, my_row in zip(boards.B20190809172206_1.field, my_board.field): + self.assertListEqual(correct_row, my_row) diff --git a/test_config.zip b/test_config.zip index 8e6c0bd..f3b6a3e 100644 Binary files a/test_config.zip and b/test_config.zip differ diff --git a/tools/generate_border.py b/tools/generate_border.py index 30ebced..754d573 100644 --- a/tools/generate_border.py +++ b/tools/generate_border.py @@ -1,9 +1,13 @@ -import numpy as np +import copy +import dataclasses +import json + import cv2 -from shenzhen_solitaire.card_detection.configuration import Configuration +import numpy as np + import shenzhen_solitaire.card_detection.adjustment as adjustment import shenzhen_solitaire.card_detection.card_finder as card_finder -import copy +from shenzhen_solitaire.card_detection.configuration import Configuration def main() -> None: @@ -15,9 +19,15 @@ def main() -> None: image, count_x=1, count_y=1, adjustment=copy.deepcopy(border_adjustment) ) border_square = card_finder.get_field_squares(image, border_square_pos, 1, 1) - empty_square = card_finder.get_field_squares(image, border_square_pos, 1, 1) + empty_square_pos = adjustment.adjust_squares( + image, count_x=1, count_y=1, adjustment=copy.deepcopy(border_adjustment) + ) + empty_square = card_finder.get_field_squares(image, empty_square_pos, 1, 1) + + cv2.imwrite("/tmp/border_square.png", border_square[0]) + cv2.imwrite("/tmp/empty_square.png", empty_square[0]) + print(json.dumps(dataclasses.asdict(border_adjustment))) if __name__ == "__main__": main() -