diff --git a/shenzhen_solitaire/board.py b/shenzhen_solitaire/board.py index 1989150..d9f403d 100644 --- a/shenzhen_solitaire/board.py +++ b/shenzhen_solitaire/board.py @@ -51,10 +51,12 @@ class Position(enum.Enum): class Board: """Solitaire board""" + # Starting max row is 5, if the last one is a `1`, we can put a `2` - `9` on top of it, resulting in 13 cards + MAX_ROW_SIZE = 13 + def __init__(self) -> None: self.field: List[List[Card]] = [[]] * 8 - self.bunker: List[Union[Tuple[SpecialCard, int], - Optional[Card]]] = [None] * 3 + self.bunker: List[Union[Tuple[SpecialCard, int], Optional[Card]]] = [None] * 3 self.goal: Dict[NumberCard.Suit, int] = { NumberCard.Suit.Red: 0, NumberCard.Suit.Green: 0, @@ -130,8 +132,9 @@ class Board: special_cards[SpecialCard.Hua] += 1 for card in itertools.chain( - self.bunker, itertools.chain.from_iterable( - stack for stack in self.field if stack), ): + self.bunker, + itertools.chain.from_iterable(stack for stack in self.field if stack), + ): if isinstance(card, tuple): special_cards[card[0]] += 4 elif isinstance(card, SpecialCard): diff --git a/shenzhen_solitaire/card_detection/adjustment.py b/shenzhen_solitaire/card_detection/adjustment.py index f41194d..8a5c083 100644 --- a/shenzhen_solitaire/card_detection/adjustment.py +++ b/shenzhen_solitaire/card_detection/adjustment.py @@ -27,7 +27,7 @@ def get_square(adjustment: Adjustment, index_x: int = 0, adjustment.y + adjustment.h + adjustment.dy * index_y) -def _adjust_squares( +def adjust_squares( image: numpy.ndarray, count_x: int, count_y: int, @@ -79,19 +79,19 @@ def _adjust_squares( def adjust_field(image: numpy.ndarray) -> Adjustment: """Open configuration grid for the field""" - return _adjust_squares(image, 8, 5, Adjustment(42, 226, 15, 15, 119, 24)) + return adjust_squares(image, 8, 13, Adjustment(42, 226, 15, 15, 119, 24)) def adjust_bunker(image: numpy.ndarray) -> Adjustment: """Open configuration grid for the bunker""" - return _adjust_squares(image, 3, 1) + return adjust_squares(image, 3, 1) def adjust_hua(image: numpy.ndarray) -> Adjustment: """Open configuration grid for the flower card""" - return _adjust_squares(image, 1, 1) + return adjust_squares(image, 1, 1) def adjust_goal(image: numpy.ndarray) -> Adjustment: """Open configuration grid for the goal""" - return _adjust_squares(image, 3, 1) + return adjust_squares(image, 3, 1) diff --git a/shenzhen_solitaire/card_detection/board_parser.py b/shenzhen_solitaire/card_detection/board_parser.py index 56bdae9..401605e 100644 --- a/shenzhen_solitaire/card_detection/board_parser.py +++ b/shenzhen_solitaire/card_detection/board_parser.py @@ -2,50 +2,89 @@ import numpy as np from .configuration import Configuration -from ..board import Board +from ..board import Board, NumberCard, SpecialCard from . import card_finder import cv2 -from typing import Iterable, Any, List +from typing import Iterable, Any, List, Tuple, Union import itertools -def parse_board(image: np.ndarray, conf: Configuration) -> Board: - """Parse a screenshot of the game, using a given configuration""" +def grouper( + iterable: Iterable[Any], groupsize: int, fillvalue: Any = None +) -> Iterable[Iterable[Any]]: + "Collect data into fixed-length chunks or blocks" + args = [iter(iterable)] * groupsize + return itertools.zip_longest(*args, fillvalue=fillvalue) + + +def get_square_iterator( + image: np.ndarray, conf: Configuration, row_count: int, column_count: int +) -> Iterable[Tuple[np.ndarray, np.ndarray]]: + """Return iterator for both the square, as well as the matching card border""" fake_adjustments = conf.field_adjustment fake_adjustments.x -= 5 fake_adjustments.y -= 5 fake_adjustments.h += 10 fake_adjustments.w += 10 - row_count = 13 - column_count = 8 - - def grouper(iterable: Iterable[Any], groupsize: int, fillvalue: Any = None) -> Iterable[Any]: - "Collect data into fixed-length chunks or blocks" - args = [iter(iterable)] * groupsize - return itertools.zip_longest(*args, fillvalue=fillvalue) - squares = card_finder.get_field_squares( - image, conf.field_adjustment, count_x=row_count, count_y=column_count + image, fake_adjustments, count_x=row_count, count_y=column_count + ) + border_squares = card_finder.get_field_squares( + image, conf.border_adjustment, count_x=row_count, count_y=column_count ) grouped_squares = grouper(squares, row_count) - result = Board() - for group_index, square_group in enumerate(grouped_squares): - group_field = [] - for index, square in enumerate(square_group): - best_val = None - best_name = None - for template, name in conf.catalogue: - res = cv2.matchTemplate(square, template, cv2.TM_CCOEFF_NORMED) - min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res) - if best_val is None or max_val > best_val: - best_val = max_val - best_name = name - assert best_name is not None - group_field.append(best_name) + grouped_border_squares = grouper(border_squares, row_count) + return zip(grouped_squares, grouped_border_squares) - # print(f"\t{best_val}: {best_name}") - # cv2.imshow("Catalogue", cv2.resize(square, (500, 500))) - # cv2.waitKey() + +def match_template(template: np.ndarray, search_image: np.ndarray) -> int: + """Return matchiness for the template on the search image""" + res = cv2.matchTemplate(search_image, template, cv2.TM_CCOEFF_NORMED) + min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res) + assert isinstance(max_val, int) + return max_val + + +def parse_square( + square: np.ndarray, border: np.ndarray, conf: Configuration +) -> Tuple[Union[NumberCard, SpecialCard], bool]: + square_fits = [ + (match_template(template, square), name) for template, name in conf.catalogue + ] + best_val, best_name = max(square_fits, key=lambda x: x[0]) + + best_border = max( + match_template(template=template, search_image=border) + for template in conf.card_border + ) + best_empty = max( + match_template(template=template, search_image=border) + for template in conf.empty_card + ) + + assert best_name is not None + assert best_empty is not None + assert best_border is not None + row_finished = best_empty > best_border + + return (best_name, row_finished) + + +def parse_board(image: np.ndarray, conf: Configuration) -> Board: + """Parse a screenshot of the game, using a given configuration""" + square_iterator = get_square_iterator( + image, conf, row_count=Board.MAX_ROW_SIZE, column_count=8 + ) + result = Board() + for group_index, (square_group, border_group) in enumerate(square_iterator): + group_field = [] + for index, (square, border_square) in enumerate( + zip(square_group, border_group) + ): + value, row_finished = parse_square(square, border_square, conf) + group_field.append(value) + if row_finished: + break result.field[group_index] = group_field diff --git a/shenzhen_solitaire/card_detection/card_finder.py b/shenzhen_solitaire/card_detection/card_finder.py index 08f6223..c664612 100644 --- a/shenzhen_solitaire/card_detection/card_finder.py +++ b/shenzhen_solitaire/card_detection/card_finder.py @@ -26,61 +26,6 @@ def get_field_squares( squares.append(get_square(adjustment, index_x, index_y)) return _extract_squares(image, squares) - -class Cardcolor(enum.Enum): - """Relevant colors for different types of cards""" - - Bai = (65, 65, 65) - Black = (0, 0, 0) - Red = (22, 48, 178) - Green = (76, 111, 19) - Background = (178, 194, 193) - -def _find_single_square( - search_square: np.ndarray, template_square: np.ndarray -) -> Tuple[int, Tuple[int, int]]: - assert search_square.shape[0] >= template_square.shape[0] - assert search_square.shape[1] >= template_square.shape[1] - best_result: Optional[Tuple[int, Tuple[int, int]]] = None - for margin_x, margin_y in itertools.product( - range(search_square.shape[0], template_square.shape[0] - 1, -1), - range(search_square.shape[1], template_square.shape[1] - 1, -1), - ): - search_region = search_square[ - margin_x - template_square.shape[0] : margin_x, - margin_y - template_square.shape[1] : margin_y, - ] - count = cv2.countNonZero(search_region - template_square) - if not best_result or count < best_result[0]: # pylint: disable=E1136 - best_result = ( - count, - ( - margin_x - template_square.shape[0], - margin_y - template_square.shape[1], - ), - ) - assert best_result - return best_result - - -def find_square( - search_square: np.ndarray, squares: List[np.ndarray] -) -> Tuple[np.ndarray, int]: - """Compare all squares in squares with search_square, return best matching one. - Requires all squares to be simplified.""" - best_set = False - best_square: Optional[np.ndarray] = None - best_count = 0 - for square in squares: - count, _ = _find_single_square(search_square, square) - if not best_set or count < best_count: - best_set = True - best_square = square - best_count = count - assert isinstance(best_square, np.ndarray) - return (best_square, best_count) - - def catalogue_cards(squares: List[np.ndarray]) -> List[Tuple[np.ndarray, Card]]: """Run manual cataloging for given squares""" cv2.namedWindow("Catalogue", cv2.WINDOW_NORMAL) @@ -88,6 +33,7 @@ def catalogue_cards(squares: List[np.ndarray]) -> List[Tuple[np.ndarray, Card]]: result: List[Tuple[np.ndarray, Card]] = [] print("Card ID is [B]ai, [Z]hong, [F]a, [H]ua, [R]ed, [G]reen, [B]lack") print("Numbercard e.g. R3") + abort_row = 'a' special_card_map = { "b": SpecialCard.Bai, "z": SpecialCard.Zhong, @@ -127,5 +73,5 @@ def catalogue_cards(squares: List[np.ndarray]) -> List[Tuple[np.ndarray, Card]]: break cv2.destroyWindow("Catalogue") - assert result is not None + assert len(result) == len(squares) return result diff --git a/shenzhen_solitaire/card_detection/configuration.py b/shenzhen_solitaire/card_detection/configuration.py index d59fb7c..b973c8b 100644 --- a/shenzhen_solitaire/card_detection/configuration.py +++ b/shenzhen_solitaire/card_detection/configuration.py @@ -1,9 +1,10 @@ """Contains configuration class""" import zipfile import json -from typing import List, Tuple, Dict +from typing import List, Tuple, Dict, Union import io import dataclasses +from dataclasses import dataclass import tempfile import cv2 @@ -12,95 +13,137 @@ from . import adjustment from . import card_finder from .. import board +ADJUSTMENT_FILE_NAME = "adjustment.json" +FIELD_ADJUSTMENT_KEY = "field" +BORDER_ADJUSTMENT_KEY = "border" +TEMPLATES_DIRECTORY = "templates" +CARD_BORDER_DIRECTORY = "borders" +EMPTY_CARD_DIRECTORY = "empty_cards" + +PICTURE_EXTENSION = "png" + + +@dataclass class Configuration: """Configuration for solitaire cv""" - ADJUSTMENT_FILE_NAME = "adjustment.json" - TEMPLATES_DIRECTORY = "templates" + field_adjustment: adjustment.Adjustment + border_adjustment: adjustment.Adjustment + catalogue: List[Tuple[np.ndarray, Union[board.SpecialCard, board.NumberCard]]] + card_border: List[np.ndarray] + empty_card: List[np.ndarray] + meta: Dict[str, str] = dataclasses.field(default_factory=dict) - def __init__( - self, - adj: adjustment.Adjustment, - catalogue: List[Tuple[np.ndarray, board.Card]], - meta: Dict[str, str], - ) -> None: - self.field_adjustment = adj - self.catalogue = catalogue - self.meta = meta - def save(self, filename: str) -> None: - """Save configuration to zip archive""" - zip_stream = io.BytesIO() +def _save_catalogue( + zip_file: zipfile.ZipFile, catalogue: List[Tuple[np.ndarray, board.Card]] +) -> None: + for counter, (square, card) in enumerate(catalogue, start=1): + fd, myfile = tempfile.mkstemp(suffix=f".{PICTURE_EXTENSION}") - with zipfile.ZipFile(zip_stream, "w") as zip_file: - zip_file.writestr( - self.ADJUSTMENT_FILE_NAME, - json.dumps(dataclasses.asdict(self.field_adjustment)), + cv2.imwrite(myfile, square) + file_name = "" + if isinstance(card, board.SpecialCard): + file_name = f"s{card.value}-{card.name}-{counter}" + elif isinstance(card, board.NumberCard): + file_name = ( + f"n{card.suit.value}{card.number}" f"-{card.suit.name}-{counter}" ) - - counter = 0 - extension = ".png" - for square, card in self.catalogue: - counter += 1 - fd, myfile = tempfile.mkstemp() - cv2.imwrite(myfile + extension, square) - file_name = "" - if isinstance(card, board.SpecialCard): - file_name = f"s{card.value}-{card.name}-{counter}{extension}" - elif isinstance(card, board.NumberCard): - file_name = ( - f"n{card.suit.value}{card.number}" - f"-{card.suit.name}-{counter}{extension}" - ) - else: - raise AssertionError() - zip_file.write(myfile + extension, arcname=f"{self.TEMPLATES_DIRECTORY}/{file_name}") - - with open(filename, "wb") as zip_archive: - zip_archive.write(zip_stream.getvalue()) - - @staticmethod - def load(filename: str) -> "Configuration": - """Load configuration from zip archive""" - - def _parse_file_name(card_filename: str) -> board.Card: - assert card_filename.startswith(Configuration.TEMPLATES_DIRECTORY + "/") - pure_name = card_filename[len(Configuration.TEMPLATES_DIRECTORY + "/") :] - if pure_name[0] == "s": - return board.SpecialCard(int(pure_name[1])) - if pure_name[0] == "n": - return board.NumberCard( - suit=board.NumberCard.Suit(int(pure_name[1])), - number=int(pure_name[2]), - ) + else: raise AssertionError() + zip_file.write( + myfile, arcname=f"{TEMPLATES_DIRECTORY}/{file_name}.{PICTURE_EXTENSION}" + ) + +def _save_adjustments( + zip_file: zipfile.ZipFile, conf: Configuration +) -> None: + adjustments = {} + adjustments[FIELD_ADJUSTMENT_KEY] = dataclasses.asdict(conf.field_adjustment) + adjustments[BORDER_ADJUSTMENT_KEY] = dataclasses.asdict(conf.border_adjustment) - catalogue: List[Tuple[np.ndarray, board.Card]] = [] - with zipfile.ZipFile(filename, "r") as zip_file: - adj = adjustment.Adjustment( - **json.loads(zip_file.read(Configuration.ADJUSTMENT_FILE_NAME)) - ) - mydir=tempfile.mkdtemp() - for template_filename in ( - x - for x in zip_file.namelist() - if x.startswith(Configuration.TEMPLATES_DIRECTORY + "/") - ): - myfile = zip_file.extract(template_filename, path=mydir) - catalogue.append( - ( - cv2.imread(myfile), - _parse_file_name(template_filename), - ) - ) - assert catalogue[-1][0] is not None - return Configuration(adj=adj, catalogue=catalogue, meta={}) + zip_file.writestr( + ADJUSTMENT_FILE_NAME, json.dumps(adjustment), + ) - @staticmethod - def generate(image: np.ndarray) -> "Configuration": - """Generate a configuration with user input""" - adj = adjustment.adjust_field(image) - squares = card_finder.get_field_squares(image, adj, 5, 8) - catalogue = card_finder.catalogue_cards(squares) - return Configuration(adj=adj, catalogue=catalogue, meta={}) + +def save(conf: Configuration, filename: str) -> None: + """Save configuration to zip archive""" + zip_stream = io.BytesIO() + + with zipfile.ZipFile(zip_stream, "w") as zip_file: + _save_adjustments(zip_file, conf) + _save_catalogue(zip_file, conf.catalogue) + + with open(filename, "wb") as zip_archive: + zip_archive.write(zip_stream.getvalue()) + + +def _parse_file_name(card_filename: str) -> board.Card: + assert card_filename.startswith(TEMPLATES_DIRECTORY + "/") + pure_name = card_filename[len(TEMPLATES_DIRECTORY + "/") :] + if pure_name[0] == "s": + return board.SpecialCard(int(pure_name[1])) + if pure_name[0] == "n": + return board.NumberCard( + suit=board.NumberCard.Suit(int(pure_name[1])), number=int(pure_name[2]), + ) + raise AssertionError("Template files need to start with either 's' or 'n'") + + +def _load_catalogue(zip_file: zipfile.ZipFile,) -> List[Tuple[np.ndarray, board.Card]]: + + catalogue: List[Tuple[np.ndarray, board.Card]] = [] + + mydir = tempfile.mkdtemp() + for template_filename in ( + x for x in zip_file.namelist() if x.startswith(TEMPLATES_DIRECTORY + "/") + ): + myfile = zip_file.extract(template_filename, path=mydir) + catalogue.append((cv2.imread(myfile), _parse_file_name(template_filename),)) + assert catalogue[-1][0] is not None + return catalogue + + +def _load_dir(zip_file: zipfile.ZipFile, dirname: str) -> List[np.ndarray]: + mydir = tempfile.mkdtemp() + image_filenames = [ + image_filename + for image_filename in ( + x for x in zip_file.namelist() if x.startswith(dirname + "/") + ) + ] + images = [ + cv2.imread(zip_file.extract(image_filename, path=mydir)) + for image_filename in image_filenames + ] + return images + + +def load(filename: str) -> Configuration: + """Load configuration from zip archive""" + + with zipfile.ZipFile(filename, "r") as zip_file: + adjustment_dict = json.loads(zip_file.read(ADJUSTMENT_FILE_NAME)) + + return Configuration( + field_adjustment=adjustment.Adjustment( + **adjustment_dict[FIELD_ADJUSTMENT_KEY] + ), + border_adjustment=adjustment.Adjustment( + **adjustment_dict[BORDER_ADJUSTMENT_KEY] + ), + catalogue=_load_catalogue(zip_file), + card_border=_load_dir(zip_file, CARD_BORDER_DIRECTORY), + empty_card=_load_dir(zip_file, EMPTY_CARD_DIRECTORY), + meta={}, + ) + + +def generate(image: np.ndarray) -> Configuration: + """Generate a configuration with user input""" + adj = adjustment.adjust_field(image) + squares = card_finder.get_field_squares(image, adj, 5, 8) + catalogue = card_finder.catalogue_cards(squares) + return Configuration(field_adjustment=adj, catalogue=catalogue, meta={}) diff --git a/tools/generate_border.py b/tools/generate_border.py new file mode 100644 index 0000000..30ebced --- /dev/null +++ b/tools/generate_border.py @@ -0,0 +1,23 @@ +import numpy as np +import cv2 +from shenzhen_solitaire.card_detection.configuration import Configuration +import shenzhen_solitaire.card_detection.adjustment as adjustment +import shenzhen_solitaire.card_detection.card_finder as card_finder +import copy + + +def main() -> None: + """Generate a configuration""" + image = cv2.imread("pictures/20190809172213_1.jpg") + + border_adjustment = adjustment.adjust_squares(image, count_x=8, count_y=13) + border_square_pos = adjustment.adjust_squares( + image, count_x=1, count_y=1, adjustment=copy.deepcopy(border_adjustment) + ) + border_square = card_finder.get_field_squares(image, border_square_pos, 1, 1) + empty_square = card_finder.get_field_squares(image, border_square_pos, 1, 1) + + +if __name__ == "__main__": + main() + diff --git a/tools/generate_config.py b/tools/generate_config.py index 2ffe340..516deb9 100644 --- a/tools/generate_config.py +++ b/tools/generate_config.py @@ -1,14 +1,16 @@ import numpy as np import cv2 -from shenzhen_solitaire.card_detection.configuration import Configuration +import shenzhen_solitaire.card_detection.configuration as configuration + def main() -> None: """Generate a configuration""" image = cv2.imread("pictures/20190809172213_1.jpg") - generated_config = Configuration.generate(image) - generated_config.save('test_config.zip') + generated_config = configuration.generate(image) + configuration.save(generated_config, "test_config.zip") + if __name__ == "__main__": main() - \ No newline at end of file +