diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9cccefe --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +test_config/ \ No newline at end of file diff --git a/shenzhen_solitaire/card_detection/board_parser.py b/shenzhen_solitaire/card_detection/board_parser.py index e2774d5..729f55c 100644 --- a/shenzhen_solitaire/card_detection/board_parser.py +++ b/shenzhen_solitaire/card_detection/board_parser.py @@ -1,12 +1,14 @@ """Contains parse_board function""" -import numpy as np -from .configuration import Configuration -from ..board import Board, NumberCard, SpecialCard, Card -from . import card_finder -import cv2 -from typing import Iterable, Any, List, Tuple, Union import itertools +from typing import Any, Iterable, List, Optional, Tuple, Union, Dict + +import cv2 +import numpy as np + +from ..board import Board, Card, NumberCard, SpecialCard +from . import card_finder +from .configuration import Configuration def grouper( @@ -17,7 +19,7 @@ def grouper( return itertools.zip_longest(*args, fillvalue=fillvalue) -def get_square_iterator( +def get_field_square_iterator( image: np.ndarray, conf: Configuration, row_count: int, column_count: int ) -> Iterable[Tuple[np.ndarray, np.ndarray]]: """Return iterator for both the square, as well as the matching card border""" @@ -46,7 +48,7 @@ def match_template(template: np.ndarray, search_image: np.ndarray) -> float: return float(max_val) -def parse_square( +def parse_field_square( square: np.ndarray, border: np.ndarray, conf: Configuration ) -> Tuple[Union[NumberCard, SpecialCard], bool]: square_fits = [ @@ -73,7 +75,7 @@ def parse_square( def parse_field(image: np.ndarray, conf: Configuration) -> List[List[Card]]: """Parse a screenshot of the game, using a given configuration""" - square_iterator = get_square_iterator( + square_iterator = get_field_square_iterator( image, conf, row_count=Board.MAX_ROW_SIZE, column_count=Board.MAX_COLUMN_SIZE ) result = [] @@ -82,7 +84,7 @@ def parse_field(image: np.ndarray, conf: Configuration) -> List[List[Card]]: for index, (square, border_square) in enumerate( zip(square_group, border_group) ): - value, row_finished = parse_square(square, border_square, conf) + value, row_finished = parse_field_square(square, border_square, conf) group_field.append(value) if row_finished: break @@ -92,7 +94,25 @@ def parse_field(image: np.ndarray, conf: Configuration) -> List[List[Card]]: return result +def parse_hua(image: np.ndarray, conf: Configuration) -> bool: + """Return true if hua is in the hua spot, false if hua spot is empty""" + raise NotImplementedError() + + +def parse_bunker( + image: np.ndarray, conf: Configuration +) -> List[Union[Tuple[SpecialCard, int], Optional[Card]]]: + raise NotImplementedError() + + +def parse_goal(image: np.ndarray, conf: Configuration) -> Dict[NumberCard.Suit, int]: + raise NotImplementedError() + + def parse_board(image: np.ndarray, conf: Configuration) -> Board: result = Board() result.field = parse_field(image, conf) + # result.flower_gone = parse_hua(image, conf) + # result.bunker = parse_bunker(image, conf) + # result.goal = parse_goal(image, conf) return result diff --git a/shenzhen_solitaire/card_detection/configuration.py b/shenzhen_solitaire/card_detection/configuration.py index 585b63b..b0c9400 100644 --- a/shenzhen_solitaire/card_detection/configuration.py +++ b/shenzhen_solitaire/card_detection/configuration.py @@ -117,13 +117,12 @@ def save(conf: Configuration, filename: str) -> None: def _parse_file_name(card_filename: str) -> board.Card: - assert card_filename.startswith(TEMPLATES_DIRECTORY + "/") - pure_name = card_filename[len(TEMPLATES_DIRECTORY + "/") :] - if pure_name[0] == "s": - return board.SpecialCard(int(pure_name[1])) - if pure_name[0] == "n": + if card_filename[0] == "s": + return board.SpecialCard(int(card_filename[1])) + if card_filename[0] == "n": return board.NumberCard( - suit=board.NumberCard.Suit(int(pure_name[1])), number=int(pure_name[2]), + suit=board.NumberCard.Suit(int(card_filename[1])), + number=int(card_filename[2]), ) raise AssertionError("Template files need to start with either 's' or 'n'") diff --git a/test_config.zip b/test_config.zip index f3b6a3e..b01da28 100644 Binary files a/test_config.zip and b/test_config.zip differ diff --git a/tools/generate_bunker.py b/tools/generate_bunker.py new file mode 100644 index 0000000..9a76314 --- /dev/null +++ b/tools/generate_bunker.py @@ -0,0 +1,43 @@ +import copy +import dataclasses +import json + +import cv2 +import numpy as np + +import shenzhen_solitaire.card_detection.adjustment as adjustment +import shenzhen_solitaire.card_detection.card_finder as card_finder +from shenzhen_solitaire.card_detection.configuration import Configuration + + +def main() -> None: + """Generate a configuration""" + image = cv2.imread("pictures/specific/BunkerCards.jpg") + + bunker_adjustment = adjustment.adjust_squares( + image, + count_x=3, + count_y=1, + adjustment=adjustment.Adjustment( + **{"x": 730, "y": 310, "w": 19, "h": 21, "dx": 152, "dy": 0} + ), + ) + print(json.dumps(dataclasses.asdict(bunker_adjustment))) + + back_image = cv2.imread("pictures/specific/BaiShiny.jpg") + back_squares = card_finder.get_field_squares( + back_image, count_x=1, count_y=3, adjustment=copy.deepcopy(bunker_adjustment) + ) + + green_image = cv2.imread("pictures/20190809172213_1.jpg") + green_squares = card_finder.get_field_squares( + green_image, count_x=1, count_y=3, adjustment=copy.deepcopy(bunker_adjustment) + ) + + cv2.imwrite("/tmp/bunker_green_1.png", green_squares[0]) + cv2.imwrite("/tmp/bunker_green_2.png", green_squares[1]) + cv2.imwrite("/tmp/bunker_green_3.png", green_squares[2]) + + +if __name__ == "__main__": + main() diff --git a/tools/generate_goal.py b/tools/generate_goal.py new file mode 100644 index 0000000..e0e90af --- /dev/null +++ b/tools/generate_goal.py @@ -0,0 +1,33 @@ +import copy +import dataclasses +import json + +import cv2 +import numpy as np + +import shenzhen_solitaire.card_detection.adjustment as adjustment +import shenzhen_solitaire.card_detection.card_finder as card_finder +from shenzhen_solitaire.card_detection.configuration import Configuration + + +def main() -> None: + """Generate a configuration""" + image = cv2.imread("pictures/specific/BaiShiny.jpg") + + goal_adjustment = adjustment.adjust_squares( + image, count_x=3, count_y=1, adjustment=adjustment.Adjustment(**{"x": 1490, "y": 310, "w": 19, "h": 21, "dx": 152, "dy": 0}) + ) + print(json.dumps(dataclasses.asdict(goal_adjustment))) + + green_image = cv2.imread("pictures/20190809172213_1.jpg") + green_squares = card_finder.get_field_squares( + green_image, count_x=1, count_y=3, adjustment=copy.deepcopy(goal_adjustment) + ) + + cv2.imwrite("/tmp/goal_green_1.png", green_squares[0]) + cv2.imwrite("/tmp/goal_green_2.png", green_squares[1]) + cv2.imwrite("/tmp/goal_green_3.png", green_squares[2]) + + +if __name__ == "__main__": + main() diff --git a/tools/generate_hua.py b/tools/generate_hua.py new file mode 100644 index 0000000..861bf9e --- /dev/null +++ b/tools/generate_hua.py @@ -0,0 +1,34 @@ +import copy +import dataclasses +import json + +import cv2 +import numpy as np + +import shenzhen_solitaire.card_detection.adjustment as adjustment +import shenzhen_solitaire.card_detection.card_finder as card_finder +from shenzhen_solitaire.card_detection.configuration import Configuration + + +def main() -> None: + """Generate a configuration""" + image = cv2.imread("pictures/specific/BunkerCards.jpg") + + hua_adjustment = adjustment.adjust_squares( + image, + count_x=1, + count_y=1, + adjustment=adjustment.Adjustment( + **{"x": 1299, "y": 314, "w": 19, "h": 21, "dx": 0, "dy": 0} + ), + ) + print(json.dumps(dataclasses.asdict(hua_adjustment))) + green_image = cv2.imread("pictures/specific/ZhongShiny.jpg") + hua_green = card_finder.get_field_squares( + green_image, hua_adjustment, count_x=1, count_y=1 + ) + cv2.imwrite("/tmp/hua_green.png", hua_green[0]) + + +if __name__ == "__main__": + main() diff --git a/tools/generate_special_buttons.py b/tools/generate_special_buttons.py new file mode 100644 index 0000000..14f8b68 --- /dev/null +++ b/tools/generate_special_buttons.py @@ -0,0 +1,59 @@ +import copy +import dataclasses +import json +import tempfile +from pathlib import Path + +import cv2 +import numpy as np + +import shenzhen_solitaire.card_detection.adjustment as adjustment +import shenzhen_solitaire.card_detection.card_finder as card_finder +from shenzhen_solitaire.card_detection.configuration import Configuration + + +def main() -> None: + """Generate a configuration""" + normal_image = cv2.imread("pictures/specific/BunkerCards.jpg") + + picture_dir = Path(tempfile.mkdtemp(prefix="shenzhen-special-buttons-")) + print(picture_dir) + button_adjustment = adjustment.adjust_squares(normal_image, count_x=1, count_y=3) + normal_squares = card_finder.get_field_squares( + normal_image, button_adjustment, 3, 1 + ) + cv2.imwrite(str(picture_dir / "nz.png"), normal_squares[0]) + cv2.imwrite(str(picture_dir / "nf.png"), normal_squares[1]) + cv2.imwrite(str(picture_dir / "nb.png"), normal_squares[2]) + + fa_shiny_image = cv2.imread("pictures/specific/FaShiny.jpg") + fa_shiny_squares = card_finder.get_field_squares( + fa_shiny_image, button_adjustment, 3, 1 + ) + cv2.imwrite(str(picture_dir / "sf.png"), fa_shiny_squares[1]) + + zhong_shiny_image = cv2.imread("pictures/specific/ZhongShiny.jpg") + zhong_shiny_squares = card_finder.get_field_squares( + zhong_shiny_image, button_adjustment, 3, 1 + ) + cv2.imwrite(str(picture_dir / "sz.png"), zhong_shiny_squares[0]) + + bai_shiny_image = cv2.imread("pictures/specific/BaiShiny.jpg") + bai_shiny_squares = card_finder.get_field_squares( + bai_shiny_image, button_adjustment, 3, 1 + ) + cv2.imwrite(str(picture_dir / "sb.png"), bai_shiny_squares[2]) + cv2.imwrite(str(picture_dir / "gz.png"), bai_shiny_squares[0]) + cv2.imwrite(str(picture_dir / "gf.png"), bai_shiny_squares[1]) + + bai_black_image = cv2.imread("pictures/specific/BaiBlack.jpg") + bai_black_squares = card_finder.get_field_squares( + bai_black_image, button_adjustment, 3, 1 + ) + cv2.imwrite(str(picture_dir / "gb.png"), bai_black_squares[2]) + print(picture_dir) + print(json.dumps(dataclasses.asdict(button_adjustment))) + + +if __name__ == "__main__": + main()