diff --git a/shenzhen_solitaire/card_detection/board_parser.py b/shenzhen_solitaire/card_detection/board_parser.py index 81c50dc..56bdae9 100644 --- a/shenzhen_solitaire/card_detection/board_parser.py +++ b/shenzhen_solitaire/card_detection/board_parser.py @@ -4,34 +4,49 @@ import numpy as np from .configuration import Configuration from ..board import Board from . import card_finder +import cv2 +from typing import Iterable, Any, List +import itertools def parse_board(image: np.ndarray, conf: Configuration) -> Board: """Parse a screenshot of the game, using a given configuration""" + fake_adjustments = conf.field_adjustment + fake_adjustments.x -= 5 + fake_adjustments.y -= 5 + fake_adjustments.h += 10 + fake_adjustments.w += 10 + row_count = 13 + column_count = 8 + + def grouper(iterable: Iterable[Any], groupsize: int, fillvalue: Any = None) -> Iterable[Any]: + "Collect data into fixed-length chunks or blocks" + args = [iter(iterable)] * groupsize + return itertools.zip_longest(*args, fillvalue=fillvalue) + squares = card_finder.get_field_squares( - image, conf.field_adjustment, count_x=13, count_y=8) - squares = [card_finder.simplify(square)[0] for square in squares] - square_rows = [squares[13 * i:13 * (i + 1)] for i in range(8)] - empty_square = np.full( - shape=(conf.field_adjustment.w, - conf.field_adjustment.h), - fill_value=card_finder.GREYSCALE_COLOR[card_finder.Cardcolor.Background], - dtype=np.uint8) - assert empty_square.shape == squares[0].shape - result: Board = Board() - for row_id, square_row in enumerate(square_rows): - for square in square_row: - fitting_square, _ = card_finder.find_square( - square, [empty_square] + [x[0] for x in conf.catalogue]) - if np.array_equal(fitting_square, empty_square): - print("empty") - break - for cat_square, cardtype in conf.catalogue: - if np.array_equal(fitting_square, cat_square): - print(cardtype) - result.field[row_id].append(cardtype) - break - else: - print("did not find image") + image, conf.field_adjustment, count_x=row_count, count_y=column_count + ) + grouped_squares = grouper(squares, row_count) + result = Board() + for group_index, square_group in enumerate(grouped_squares): + group_field = [] + for index, square in enumerate(square_group): + best_val = None + best_name = None + for template, name in conf.catalogue: + res = cv2.matchTemplate(square, template, cv2.TM_CCOEFF_NORMED) + min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res) + if best_val is None or max_val > best_val: + best_val = max_val + best_name = name + assert best_name is not None + group_field.append(best_name) + + # print(f"\t{best_val}: {best_name}") + # cv2.imshow("Catalogue", cv2.resize(square, (500, 500))) + # cv2.waitKey() + + result.field[group_index] = group_field return result diff --git a/shenzhen_solitaire/card_detection/card_finder.py b/shenzhen_solitaire/card_detection/card_finder.py index b7149df..08f6223 100644 --- a/shenzhen_solitaire/card_detection/card_finder.py +++ b/shenzhen_solitaire/card_detection/card_finder.py @@ -3,25 +3,23 @@ from typing import List, Tuple, Optional, Dict import enum import itertools -import numpy as np # type: ignore -import cv2 # type: ignore +import numpy as np +import cv2 from .adjustment import Adjustment, get_square from ..board import Card, NumberCard, SpecialCard -def _extract_squares(image: np.ndarray, - squares: List[Tuple[int, - int, - int, - int]]) -> List[np.ndarray]: - return [image[square[1]:square[3], square[0]:square[2]].copy() - for square in squares] +def _extract_squares( + image: np.ndarray, squares: List[Tuple[int, int, int, int]] +) -> List[np.ndarray]: + return [ + image[square[1] : square[3], square[0] : square[2]].copy() for square in squares + ] -def get_field_squares(image: np.ndarray, - adjustment: Adjustment, - count_x: int, - count_y: int) -> List[np.ndarray]: +def get_field_squares( + image: np.ndarray, adjustment: Adjustment, count_x: int, count_y: int +) -> List[np.ndarray]: """Return all squares in the field, according to the adjustment""" squares = [] for index_x, index_y in itertools.product(range(count_y), range(count_x)): @@ -31,64 +29,43 @@ def get_field_squares(image: np.ndarray, class Cardcolor(enum.Enum): """Relevant colors for different types of cards""" + Bai = (65, 65, 65) Black = (0, 0, 0) Red = (22, 48, 178) Green = (76, 111, 19) Background = (178, 194, 193) - -GREYSCALE_COLOR = { - Cardcolor.Bai: 50, - Cardcolor.Black: 100, - Cardcolor.Red: 150, - Cardcolor.Green: 200, - Cardcolor.Background: 250} - - -def simplify(image: np.ndarray) -> Tuple[np.ndarray, Dict[Cardcolor, int]]: - """Reduce given image to the colors in Cardcolor""" - result_image: np.ndarray = np.zeros( - (image.shape[0], image.shape[1]), np.uint8) - result_dict: Dict[Cardcolor, int] = {c: 0 for c in Cardcolor} - for pixel_x, pixel_y in itertools.product( - range(result_image.shape[0]), - range(result_image.shape[1])): - pixel = image[pixel_x, pixel_y] - best_color: Optional[Tuple[Cardcolor, int]] = None - for color in Cardcolor: - mse = sum((x - y) ** 2 for x, y in zip(color.value, pixel)) - if not best_color or best_color[1] > mse: #pylint: disable=E1136 - best_color = (color, mse) - assert best_color - result_image[pixel_x, pixel_y] = GREYSCALE_COLOR[best_color[0]] - result_dict[best_color[0]] += 1 - return (result_image, result_dict) - - -def _find_single_square(search_square: np.ndarray, - template_square: np.ndarray) -> Tuple[int, Tuple[int, int]]: +def _find_single_square( + search_square: np.ndarray, template_square: np.ndarray +) -> Tuple[int, Tuple[int, int]]: assert search_square.shape[0] >= template_square.shape[0] assert search_square.shape[1] >= template_square.shape[1] best_result: Optional[Tuple[int, Tuple[int, int]]] = None for margin_x, margin_y in itertools.product( - range(search_square.shape[0], template_square.shape[0] - 1, -1), - range(search_square.shape[1], template_square.shape[1] - 1, -1)): - search_region = search_square[margin_x - - template_square.shape[0]:margin_x, margin_y - - template_square.shape[1]:margin_y] + range(search_square.shape[0], template_square.shape[0] - 1, -1), + range(search_square.shape[1], template_square.shape[1] - 1, -1), + ): + search_region = search_square[ + margin_x - template_square.shape[0] : margin_x, + margin_y - template_square.shape[1] : margin_y, + ] count = cv2.countNonZero(search_region - template_square) - if not best_result or count < best_result[0]: #pylint: disable=E1136 + if not best_result or count < best_result[0]: # pylint: disable=E1136 best_result = ( count, - (margin_x - template_square.shape[0], - margin_y - template_square.shape[1])) + ( + margin_x - template_square.shape[0], + margin_y - template_square.shape[1], + ), + ) assert best_result return best_result -def find_square(search_square: np.ndarray, - squares: List[np.ndarray]) -> Tuple[np.ndarray, int]: +def find_square( + search_square: np.ndarray, squares: List[np.ndarray] +) -> Tuple[np.ndarray, int]: """Compare all squares in squares with search_square, return best matching one. Requires all squares to be simplified.""" best_set = False @@ -104,24 +81,24 @@ def find_square(search_square: np.ndarray, return (best_square, best_count) -def catalogue_cards(squares: List[np.ndarray] - ) -> List[Tuple[np.ndarray, Card]]: +def catalogue_cards(squares: List[np.ndarray]) -> List[Tuple[np.ndarray, Card]]: """Run manual cataloging for given squares""" cv2.namedWindow("Catalogue", cv2.WINDOW_NORMAL) cv2.waitKey(1) result: List[Tuple[np.ndarray, Card]] = [] - print( - "Card ID is [B]ai, [Z]hong, [F]a, [H]ua, [R]ed, [G]reen, [B]lack") + print("Card ID is [B]ai, [Z]hong, [F]a, [H]ua, [R]ed, [G]reen, [B]lack") print("Numbercard e.g. R3") special_card_map = { - 'b': SpecialCard.Bai, - 'z': SpecialCard.Zhong, - 'f': SpecialCard.Fa, - 'h': SpecialCard.Hua} + "b": SpecialCard.Bai, + "z": SpecialCard.Zhong, + "f": SpecialCard.Fa, + "h": SpecialCard.Hua, + } suit_map = { - 'r': NumberCard.Suit.Red, - 'g': NumberCard.Suit.Green, - 'b': NumberCard.Suit.Black} + "r": NumberCard.Suit.Red, + "g": NumberCard.Suit.Green, + "b": NumberCard.Suit.Black, + } for square in squares: while True: cv2.imshow("Catalogue", cv2.resize(square, (500, 500))) @@ -137,10 +114,11 @@ def catalogue_cards(squares: List[np.ndarray] continue if not card_id[1].isdigit(): continue - if card_id[1] == '0': + if card_id[1] == "0": continue - card_type = NumberCard(number=int( - card_id[1]), suit=suit_map[card_id[0]]) + card_type = NumberCard( + number=int(card_id[1]), suit=suit_map[card_id[0]] + ) else: continue assert card_type is not None diff --git a/shenzhen_solitaire/card_detection/configuration.py b/shenzhen_solitaire/card_detection/configuration.py index dbe0db8..d59fb7c 100644 --- a/shenzhen_solitaire/card_detection/configuration.py +++ b/shenzhen_solitaire/card_detection/configuration.py @@ -4,6 +4,8 @@ import json from typing import List, Tuple, Dict import io import dataclasses +import tempfile +import cv2 import numpy as np from . import adjustment @@ -13,15 +15,16 @@ from .. import board class Configuration: """Configuration for solitaire cv""" - ADJUSTMENT_FILE_NAME = 'adjustment.json' - TEMPLATES_DIRECTORY = 'templates' - def __init__(self, - adj: adjustment.Adjustment, - catalogue: List[Tuple[np.ndarray, - board.Card]], - meta: Dict[str, - str]) -> None: + ADJUSTMENT_FILE_NAME = "adjustment.json" + TEMPLATES_DIRECTORY = "templates" + + def __init__( + self, + adj: adjustment.Adjustment, + catalogue: List[Tuple[np.ndarray, board.Card]], + meta: Dict[str, str], + ) -> None: self.field_adjustment = adj self.catalogue = catalogue self.meta = meta @@ -29,67 +32,73 @@ class Configuration: def save(self, filename: str) -> None: """Save configuration to zip archive""" zip_stream = io.BytesIO() + with zipfile.ZipFile(zip_stream, "w") as zip_file: zip_file.writestr( - self.ADJUSTMENT_FILE_NAME, json.dumps( - dataclasses.asdict( - self.field_adjustment))) + self.ADJUSTMENT_FILE_NAME, + json.dumps(dataclasses.asdict(self.field_adjustment)), + ) counter = 0 + extension = ".png" for square, card in self.catalogue: counter += 1 - file_stream = io.BytesIO() - np.save( - file_stream, - card_finder.simplify(square)[0], - allow_pickle=False) + fd, myfile = tempfile.mkstemp() + cv2.imwrite(myfile + extension, square) file_name = "" if isinstance(card, board.SpecialCard): - file_name = f's{card.value}-{card.name}-{counter}.npy' + file_name = f"s{card.value}-{card.name}-{counter}{extension}" elif isinstance(card, board.NumberCard): - file_name = f'n{card.suit.value}{card.number}'\ - f'-{card.suit.name}-{counter}.npy' + file_name = ( + f"n{card.suit.value}{card.number}" + f"-{card.suit.name}-{counter}{extension}" + ) else: raise AssertionError() - zip_file.writestr( - self.TEMPLATES_DIRECTORY + f"/{file_name}", - file_stream.getvalue()) + zip_file.write(myfile + extension, arcname=f"{self.TEMPLATES_DIRECTORY}/{file_name}") - with open(filename, 'wb') as zip_archive: + with open(filename, "wb") as zip_archive: zip_archive.write(zip_stream.getvalue()) @staticmethod - def load(filename: str) -> 'Configuration': + def load(filename: str) -> "Configuration": """Load configuration from zip archive""" + def _parse_file_name(card_filename: str) -> board.Card: - assert card_filename.startswith( - Configuration.TEMPLATES_DIRECTORY + '/') - pure_name = card_filename[ - len(Configuration.TEMPLATES_DIRECTORY + '/'):] - if pure_name[0] == 's': + assert card_filename.startswith(Configuration.TEMPLATES_DIRECTORY + "/") + pure_name = card_filename[len(Configuration.TEMPLATES_DIRECTORY + "/") :] + if pure_name[0] == "s": return board.SpecialCard(int(pure_name[1])) - if pure_name[0] == 'n': + if pure_name[0] == "n": return board.NumberCard( - suit=board.NumberCard.Suit( - int(pure_name[1])), number=int(pure_name[2])) + suit=board.NumberCard.Suit(int(pure_name[1])), + number=int(pure_name[2]), + ) raise AssertionError() catalogue: List[Tuple[np.ndarray, board.Card]] = [] - with zipfile.ZipFile(filename, 'r') as zip_file: + with zipfile.ZipFile(filename, "r") as zip_file: adj = adjustment.Adjustment( - **json.loads( - zip_file.read(Configuration.ADJUSTMENT_FILE_NAME))) + **json.loads(zip_file.read(Configuration.ADJUSTMENT_FILE_NAME)) + ) + mydir=tempfile.mkdtemp() for template_filename in ( - x for x in zip_file.namelist() if - x.startswith(Configuration.TEMPLATES_DIRECTORY + '/')): + x + for x in zip_file.namelist() + if x.startswith(Configuration.TEMPLATES_DIRECTORY + "/") + ): + myfile = zip_file.extract(template_filename, path=mydir) catalogue.append( - (np.load(io.BytesIO(zip_file.read(template_filename))), - _parse_file_name(template_filename))) + ( + cv2.imread(myfile), + _parse_file_name(template_filename), + ) + ) assert catalogue[-1][0] is not None return Configuration(adj=adj, catalogue=catalogue, meta={}) @staticmethod - def generate(image: np.ndarray) -> 'Configuration': + def generate(image: np.ndarray) -> "Configuration": """Generate a configuration with user input""" adj = adjustment.adjust_field(image) squares = card_finder.get_field_squares(image, adj, 5, 8) diff --git a/test/boards.py b/test/boards.py index 528f995..e9a428f 100644 --- a/test/boards.py +++ b/test/boards.py @@ -1,6 +1,8 @@ """Contains an example board to run tests on""" from shenzhen_solitaire.board import NumberCard, SpecialCard, Board +Suit = NumberCard.Suit + TEST_BOARD = Board() TEST_BOARD.field[0] = [ SpecialCard.Fa, @@ -65,3 +67,62 @@ TEST_BOARD.field[7] = [ NumberCard(NumberCard.Suit.Black, 1), NumberCard(NumberCard.Suit.Green, 8), ] + +B20190809172206_1 = Board() +B20190809172206_1.field[0] = [ + NumberCard(Suit.Green, 6), + NumberCard(Suit.Green, 5), + NumberCard(Suit.Red, 4), + NumberCard(Suit.Green, 4), + SpecialCard.Fa, +] + +B20190809172206_1.field[1] = [ + NumberCard(Suit.Black, 8), + NumberCard(Suit.Black, 6), + SpecialCard.Zhong, + NumberCard(Suit.Black, 9), + NumberCard(Suit.Green, 7), +] + +B20190809172206_1.field[2] = [ + SpecialCard.Zhong, + NumberCard(Suit.Black, 4), + NumberCard(Suit.Green, 2), + SpecialCard.Bai, + SpecialCard.Zhong, +] +B20190809172206_1.field[3] = [ + NumberCard(Suit.Green, 1), + NumberCard(Suit.Green, 3), + NumberCard(Suit.Black, 5), + SpecialCard.Fa, + SpecialCard.Fa, +] +B20190809172206_1.field[4] = [ + NumberCard(Suit.Red, 8), + SpecialCard.Zhong, + NumberCard(Suit.Red, 7), +] +B20190809172206_1.field[5] = [ + SpecialCard.Fa, + SpecialCard.Bai, + NumberCard(Suit.Red, 2), + SpecialCard.Hua, + SpecialCard.Bai, +] +B20190809172206_1.field[6] = [ + NumberCard(Suit.Black, 2), + NumberCard(Suit.Green, 8), + NumberCard(Suit.Black, 7), + SpecialCard.Bai, + NumberCard(Suit.Red, 9), +] + +B20190809172206_1.field[7] = [ + NumberCard(Suit.Red, 3), + NumberCard(Suit.Black, 3), + NumberCard(Suit.Green, 9), + NumberCard(Suit.Red, 5), + NumberCard(Suit.Red, 6), +] diff --git a/test/test_cv.py b/test/test_cv.py index e07c982..7cccd86 100644 --- a/test/test_cv.py +++ b/test/test_cv.py @@ -8,17 +8,18 @@ import numpy as np from shenzhen_solitaire import board from shenzhen_solitaire.card_detection import adjustment, board_parser from shenzhen_solitaire.card_detection.configuration import Configuration +from . import boards class CardDetectionTest(unittest.TestCase): def test_parse(self) -> None: - """Parse a configuration""" - with open("pictures/20190809172213_1.jpg", "rb") as png_file: - img_str = png_file.read() - nparr = np.frombuffer(img_str, np.uint8) - image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) - # image = cv2.resize(image, (1000, 629)) + """Parse a configuration and a board""" + image = cv2.imread("pictures/20190809172206_1.jpg") loaded_config = Configuration.load("test_config.zip") - # loaded_config.field_adjustment = adjustment.adjust_field(image) - print(board_parser.parse_board(image, loaded_config)) + my_board = board_parser.parse_board(image, loaded_config) + + for rows in zip(boards.B20190809172206_1.field, my_board.field): + for good_cell, test_cell in zip(*rows): + self.assertEqual(good_cell, test_cell) + diff --git a/test_config.zip b/test_config.zip index 91b2d87..8e6c0bd 100644 Binary files a/test_config.zip and b/test_config.zip differ diff --git a/tools/generate_config.py b/tools/generate_config.py index 045e949..2ffe340 100644 --- a/tools/generate_config.py +++ b/tools/generate_config.py @@ -4,11 +4,11 @@ from shenzhen_solitaire.card_detection.configuration import Configuration def main() -> None: """Generate a configuration""" - with open("pictures/20190809172213_1.jpg", 'rb') as png_file: - img_str = png_file.read() - nparr = np.frombuffer(img_str, np.uint8) - image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + image = cv2.imread("pictures/20190809172213_1.jpg") generated_config = Configuration.generate(image) generated_config.save('test_config.zip') + +if __name__ == "__main__": + main() \ No newline at end of file