• BattleShip Part 4: Game Cli Test

    One thing you usually want to do is add unit tests. Otherwise, you have to keep running your program and manually verifying a bunch of scenarios, so unit tests save you time in the long run. The other benefit is design feedback: if a class is really hard to unit test, that can be a sign it’s too complex, which helps drive better design.

    This is just a rule of thumb, though. Sometimes, especially when you’re working on something new, you’re not sure yet how everything will be organized. Oftentimes, I’ll get something basic working and then add unit tests later.

    In this case, I feel pretty good about how I want the CLI portion to work, so I’ll go ahead and write a test for it.

    Verify Tests Run

    The first thing I do is set up a dummy test that fails on purpose. I’ve worked on a project where I added a test and it turned out it wasn’t being executed at all. Seeing the failure confirms that the test runner is actually picking the test up.

    import unittest

    class TestGameCli(unittest.TestCase):

        def test_something(self):
            self.assertTrue(False)

    Then figure out how to run the test. This is also good info to include in the README.

    $ python3 -m unittest test/test*
    F
    ======================================================================
    FAIL: test_something (test.test_game_cli.TestGameCli)
    ----------------------------------------------------------------------
    Traceback (most recent call last):
      File "/Users/yujincho/code/BattleShipPython/test/test_game_cli.py", line 6, in test_something
        self.assertTrue(False)
    AssertionError: False is not true
    
    ----------------------------------------------------------------------
    Ran 1 test in 0.000s
    
    FAILED (failures=1)

    Extract parsing function

    Right now the GameCli class doesn’t have much to test. It takes in user inputs and just prints some placeholder lines. Later we’ll want to verify that it actually takes some actions. For now, though, the main thing the class does is take user input and turn it into something meaningful. So let’s break that logic out into a standalone function we can test on its own.

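    I’ll move the _parse_user_input method out of the class and into a module-level parse_user_input function, along with the show and valid_attack_pattern values it uses.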

    import re

    # UserActionType is the enum defined alongside GameCli in Part 3.

    def parse_user_input(raw_user_input):
        # user term to display the game board
        show = 'show'
        # starts with a letter and ends with a number
        valid_attack_pattern = r'(^[a-z]{1})([0-9]$)'

        user_input = raw_user_input.lower().strip()

        if user_input == show:
            return (UserActionType.SHOW, None)

        match = re.search(valid_attack_pattern, user_input)
        if match and match.lastindex == 2:
            col = match.group(1)
            row = match.group(2)
            return (UserActionType.ATTACK, (col, row))

        return (UserActionType.INVALID, None)

    Now the GameCli class can just use this instead.

    class GameCli:
        '''Manages interactions between the user and the game.'''

        def __init__(self, user_inputs):
            self.user_inputs = user_inputs

        def run(self):
            '''Parses and applies user provided commands to the game'''
            for user_input in self.user_inputs:
                (action_type, action_info) = parse_user_input(user_input)

                if action_type == UserActionType.SHOW:
                    self._display_board_layout()
                elif action_type == UserActionType.ATTACK:
                    (col, row) = action_info
                    self._attack(col, row)
                else:
                    self._display_invalid_user_input_message()

        def _display_board_layout(self):
            print('TODO: display board')

        def _attack(self, col, row):
            print('TODO: attack')

        def _display_invalid_user_input_message(self):
            print('TODO: display invalid message')
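    With the parsing pulled out, GameCli.run itself is also easy to exercise. Because it accepts any iterable of strings, a test can hand it a plain list and capture the printed output with contextlib.redirect_stdout. The sketch below inlines the module so it runs on its own. Note that run_and_capture is a hypothetical helper, and asserting on the TODO messages only makes sense until the real actions exist.

```python
import contextlib
import io
import re
from enum import Enum


class UserActionType(Enum):
    '''Different types of actions a user can perform.'''
    INVALID = 1
    SHOW = 2
    ATTACK = 3


def parse_user_input(raw_user_input):
    '''Inlined copy of parse_user_input so this block runs on its own.'''
    user_input = raw_user_input.lower().strip()
    if user_input == 'show':
        return (UserActionType.SHOW, None)
    match = re.search(r'(^[a-z]{1})([0-9]$)', user_input)
    if match and match.lastindex == 2:
        return (UserActionType.ATTACK, (match.group(1), match.group(2)))
    return (UserActionType.INVALID, None)


class GameCli:
    '''Copy of GameCli from above.'''

    def __init__(self, user_inputs):
        self.user_inputs = user_inputs

    def run(self):
        for user_input in self.user_inputs:
            (action_type, action_info) = parse_user_input(user_input)
            if action_type == UserActionType.SHOW:
                self._display_board_layout()
            elif action_type == UserActionType.ATTACK:
                (col, row) = action_info
                self._attack(col, row)
            else:
                self._display_invalid_user_input_message()

    def _display_board_layout(self):
        print('TODO: display board')

    def _attack(self, col, row):
        print('TODO: attack')

    def _display_invalid_user_input_message(self):
        print('TODO: display invalid message')


def run_and_capture(user_inputs):
    '''Hypothetical test helper: runs GameCli on canned inputs, returns printed lines.'''
    buffer = io.StringIO()
    # Temporarily send print() output to the buffer instead of the terminal.
    with contextlib.redirect_stdout(buffer):
        GameCli(user_inputs).run()
    return buffer.getvalue().splitlines()


if __name__ == '__main__':
    # ['TODO: display invalid message', 'TODO: attack', 'TODO: display board']
    print(run_and_capture(['a', 'a3', 'show']))
```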

    Parse Tests

    Now let’s add a real test. I like to start with the simplest case, which here is “show”. If we pass it in as input, the action should be the SHOW type with no extra info.

    import unittest

    from src.game_cli import parse_user_input, UserActionType

    class TestParseUserInput(unittest.TestCase):

        def test_show_returns_show_action(self):
            (action_type, action_info) = parse_user_input('show')
            self.assertEqual(action_type, UserActionType.SHOW)
            self.assertIsNone(action_info)

    We can run and verify it works.

    $ python3 -m unittest test/test* -v
    test_show_returns_show_action (test.test_game_cli.TestParseUserInput) ... ok
    
    ----------------------------------------------------------------------
    Ran 1 test in 0.001s
    
    OK

    Now let’s add tests for the other two action types we expect.

    class TestParseUserInput(unittest.TestCase):

        def test_show_returns_show_action(self):
            (action_type, action_info) = parse_user_input('show')
            self.assertEqual(action_type, UserActionType.SHOW)
            self.assertIsNone(action_info)

        def test_coordinate_returns_attack_action_and_coordinate(self):
            (action_type, action_info) = parse_user_input('a8')
            self.assertEqual(action_type, UserActionType.ATTACK)
            expected_coordinate = ('a', '8')
            self.assertEqual(action_info, expected_coordinate)

        def test_invalid_input_returns_invalid_action(self):
            (action_type, action_info) = parse_user_input('abcde')
            self.assertEqual(action_type, UserActionType.INVALID)
            self.assertIsNone(action_info)

    Great, it looks like our tests are running and passing.

    $ python3 -m unittest test/test* -v
    test_coordinate_returns_attack_action_and_coordinate (test.test_game_cli.TestParseUserInput) ... ok
    test_invalid_input_returns_invalid_action (test.test_game_cli.TestParseUserInput) ... ok
    test_show_returns_show_action (test.test_game_cli.TestParseUserInput) ... ok
    
    ----------------------------------------------------------------------
    Ran 3 tests in 0.001s
    
    OK

    Having a test suite is great. It saves you time now, and when you make changes in the future, it helps verify that things still work as expected.
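    For example, it’s cheap to pin down behavior the parser already has but no test covers yet. Input is lower-cased and stripped, and the regex only accepts a single letter followed by a single digit. The sketch below shows table-driven tests using unittest’s subTest. The test names and cases are illustrative choices, and the parser is inlined so the block runs on its own.

```python
import re
import unittest
from enum import Enum


class UserActionType(Enum):
    '''Different types of actions a user can perform.'''
    INVALID = 1
    SHOW = 2
    ATTACK = 3


def parse_user_input(raw_user_input):
    '''Inlined copy of parse_user_input so this block is self-contained.'''
    user_input = raw_user_input.lower().strip()
    if user_input == 'show':
        return (UserActionType.SHOW, None)
    match = re.search(r'(^[a-z]{1})([0-9]$)', user_input)
    if match and match.lastindex == 2:
        return (UserActionType.ATTACK, (match.group(1), match.group(2)))
    return (UserActionType.INVALID, None)


class TestParseUserInputEdgeCases(unittest.TestCase):

    def test_case_and_whitespace_are_ignored(self):
        cases = [
            ('SHOW', (UserActionType.SHOW, None)),
            ('  A8 ', (UserActionType.ATTACK, ('a', '8'))),
            ('b0\n', (UserActionType.ATTACK, ('b', '0'))),
        ]
        for raw, expected in cases:
            # subTest reports each failing input separately instead of
            # stopping at the first one.
            with self.subTest(raw=raw):
                self.assertEqual(parse_user_input(raw), expected)

    def test_malformed_coordinates_are_invalid(self):
        # Two-digit rows, reversed order, and incomplete input are all rejected.
        for raw in ['a10', '8a', 'a', '']:
            with self.subTest(raw=raw):
                self.assertEqual(parse_user_input(raw), (UserActionType.INVALID, None))
```

    In the project, these methods would go in test/test_game_cli.py, which would import parse_user_input from src.game_cli instead of inlining it. You’d run them with the same python3 -m unittest command.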

    Project changeset.

  • BattleShip Part 3: App start and User Interface

    When working on a project, one of the first things I like to do is get something simple working that I can iterate on.

    I’ll start by building the component that accepts user input and is in charge of interacting with the Battleship game.

    class GameCli:
        '''Manages interactions between the user and the game.'''
    
        def __init__(self, user_inputs):
            self.user_inputs = user_inputs

    I’ll initialize the GameCli object with user_inputs. My plan is to pass in a generator that collects user input interactively but, to the GameCli class, looks like any other iterable, such as a list. This will also make testing easier later, since a test can just pass in a list of inputs.
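    This works because a generator does nothing until something iterates over it, and each iteration asks for exactly one more line. In the sketch below, the read_line parameter is an assumption added so that a canned reader can stand in for input(). The main.py in this post calls input directly, and scripted_reader is a hypothetical test helper.

```python
import itertools


def get_user_inputs(read_line=input):
    '''Yields user moves forever, prompting only when the next one is requested.'''
    while True:
        yield read_line('Your Move > ')


def scripted_reader(lines):
    '''Hypothetical stand-in for input(): replays canned lines, records prompts.'''
    remaining = iter(lines)
    prompts = []

    def read_line(prompt):
        prompts.append(prompt)
        try:
            return next(remaining)
        except StopIteration:
            # Out of canned lines: behave like input() at end of input.
            raise EOFError from None

    return read_line, prompts


if __name__ == '__main__':
    reader, prompts = scripted_reader(['a3', 'show', 'zz'])
    user_inputs = get_user_inputs(reader)
    # Creating the generator runs none of its body, so nothing is read yet.
    print(len(prompts))  # 0
    # A consumer simply iterates, exactly as it would over a list.
    print(list(itertools.islice(user_inputs, 2)))  # ['a3', 'show']
    print(len(prompts))  # 2
```

    Because GameCli only ever loops over user_inputs, a test can skip the generator entirely and pass a list such as ['a3', 'show'].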

    The next part will be converting raw user input into valid game actions. I’ll use an enum to represent the different types of actions a user can perform.

    from enum import Enum

    class UserActionType(Enum):
        '''Different types of actions a user can perform.'''
        INVALID = 1
        SHOW = 2
        ATTACK = 3

    Next I’ll define a simple parse function to map user inputs to the user action types.

        def _parse_user_input(self, raw_user_input):
            user_input = raw_user_input.lower().strip()
    
            if user_input == self.show:
                return (UserActionType.SHOW, None)
    
            match = re.search(self.valid_attack_pattern, user_input)
            if match and match.lastindex == 2:
                col = match.group(1)
                row = match.group(2)
                return (UserActionType.ATTACK, (col, row))
    
            return (UserActionType.INVALID, None)

    Another option would be to define a class that encompasses each user action and its relevant data (such as the coordinates). Given that the current interaction is simple, though, I’m just using a tuple.
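    For comparison, the class-based option could be a small frozen dataclass. This is only a sketch. UserAction and its coordinate field are hypothetical names, not part of the project.

```python
import re
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple


class UserActionType(Enum):
    '''Different types of actions a user can perform.'''
    INVALID = 1
    SHOW = 2
    ATTACK = 3


@dataclass(frozen=True)
class UserAction:
    '''Hypothetical replacement for the (action_type, action_info) tuple.'''
    action_type: UserActionType
    # (col, row) such as ('a', '8'); only set for ATTACK actions.
    coordinate: Optional[Tuple[str, str]] = None


def parse_user_input(raw_user_input):
    '''Same rules as _parse_user_input, but returns a UserAction.'''
    user_input = raw_user_input.lower().strip()
    if user_input == 'show':
        return UserAction(UserActionType.SHOW)
    match = re.search(r'(^[a-z]{1})([0-9]$)', user_input)
    if match:
        return UserAction(UserActionType.ATTACK, (match.group(1), match.group(2)))
    return UserAction(UserActionType.INVALID)


if __name__ == '__main__':
    action = parse_user_input('A8')
    # Fields are named, so callers don't have to unpack by position.
    print(action.action_type, action.coordinate)  # UserActionType.ATTACK ('a', '8')
```

    The trade-off is one more type to maintain in exchange for named fields and a single place to document what each action carries. With only three action types, the tuple is a reasonable call.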

    The current implementation is incomplete, but it gives us something to work with.

    import re
    from enum import Enum
    
    
    class UserActionType(Enum):
        '''Different types of actions a user can perform.'''
        INVALID = 1
        SHOW = 2
        ATTACK = 3
    
    class GameCli:
        '''Manages interactions between the user and the game.'''
    
        # user term to display the game board
        show = 'show'
        # starts with a letter and ends with a number
        valid_attack_pattern = r'(^[a-z]{1})([0-9]$)'
    
        def __init__(self, user_inputs):
            self.user_inputs = user_inputs
    
        def run(self):
            '''Parses and applies user provided commands to the game'''
            for user_input in self.user_inputs:
                (action_type, action_info) = self._parse_user_input(user_input)
    
                if action_type == UserActionType.SHOW:
                    self._display_board_layout()
                elif action_type == UserActionType.ATTACK:
                    (col, row) = action_info
                    self._attack(col, row)
                else:
                    self._display_invalid_user_input_message()
    
        def _parse_user_input(self, raw_user_input):
            user_input = raw_user_input.lower().strip()
    
            if user_input == self.show:
                return (UserActionType.SHOW, None)
    
            match = re.search(self.valid_attack_pattern, user_input)
            if match and match.lastindex == 2:
                col = match.group(1)
                row = match.group(2)
                return (UserActionType.ATTACK, (col, row))
    
            return (UserActionType.INVALID, None)
    
    
        def _display_board_layout(self):
            print('TODO: display board')
    
        def _attack(self, col, row):
            print('TODO: attack')
    
        def _display_invalid_user_input_message(self):
            print('TODO: display invalid message')

    Now I can write a simple main.py to run the program.

    from game import GameCli

    def get_user_inputs():
        while True:
            yield input('Your Move > ')

    if __name__ == '__main__':
        cli = GameCli(get_user_inputs())
        cli.run()

    This gives me something that is now runnable.

    $ python3 main.py 
    Your Move > a
    TODO: display invalid message
    Your Move > a3
    TODO: attack
    Your Move > show
    TODO: display board
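    One gap in this entry point is that input() raises EOFError when input ends, for example when you press Ctrl-D in a Unix terminal or pipe in a file. When that happens, the program currently exits with a traceback. The sketch below shows one way to handle this, assuming end of input should simply end the game. The generator returns, so the for loop in GameCli.run finishes normally. The read_line parameter and the replay helper are hypothetical hooks for testing.

```python
def get_user_inputs(read_line=input):
    '''Yields user moves until input runs out.'''
    while True:
        try:
            line = read_line('Your Move > ')
        except EOFError:
            # End of input: returning stops the generator, so any for loop
            # over it (like the one in GameCli.run) exits cleanly.
            return
        yield line


def replay(lines):
    '''Hypothetical stand-in for input() that raises EOFError when out of lines.'''
    remaining = iter(lines)

    def read_line(prompt):
        try:
            return next(remaining)
        except StopIteration:
            raise EOFError from None

    return read_line


if __name__ == '__main__':
    # The loop ends after the last canned move instead of crashing.
    for move in get_user_inputs(replay(['a3', 'show'])):
        print(move)
```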

    Project changeset

  • BattleShip Part 2: Breaking out different components

    The next thing I’d want to do is start thinking about breaking this project up into different components. Right now everything lives in a single class.

    import random
    #from os import system, name
    
    class BS:
        
        #GameValues
        height = 10
        width = 10
        gap = 1
        direction = ['up','down','left','right']
        gameOver = False
        
        #Ships
        ship = [5,4,3,3,2]
        shipChar = ['A','B','C','D','E'] #Need to be unique
        shipHealth = ship.copy()
        shipPieces = 0
        
        #GridCharacters
        emptyCell = "O"
        sunkShip = "X"
        gridHor = ['-','A','B','C','D','E','F','G','H','I','J']
        gridVert = ['0','1','2','3','4','5','6','7','8','9']
        
        #Game Grid
        grid = []
        
        def __init__(self):
            #Creates Grid
            for i in range(self.height):
                w = []
                for _ in range(self.width):
                    w.append(self.emptyCell)
                self.grid.append(w)
                
            self.generateLevel(self.ship, self.shipChar)
                
        def printGrid(self,hide = True):
            #Prints the Grid
            #Print the Horizontal Line Label
            for _ in self.gridHor: print(_,end="")
            print("")
            #Print each Horizontal Line + Vertical Label
            for _ in range(self.height):
                #Vertical Label
                print(self.gridVert[_],end="")
                for i in self.grid[_]:
                    if i != self.emptyCell and i != self.sunkShip and hide:
                        print(self.emptyCell,end="") #Makes the ships hidden!
                    else:
                        print(i,end="")
                print("")
                
        def addShip(self,pos,size,char,direction):
            #Adds the ship based on pos, size, character and direction given
            if direction == "up":
                for y in range(pos[1],pos[1]-(size-1) - 1,-1):
                    self.grid[y][pos[0]] = char
            elif direction == "down":
                for y in range(pos[1],pos[1]+(size-1) + 1):
                    self.grid[y][pos[0]] = char
            elif direction == "left":
                for x in range(pos[0],pos[0]-(size-1) - 1,-1):
                    self.grid[pos[1]][x] = char #Row (y) index first, like the other directions
            elif direction == "right":
                for x in range(pos[0],pos[0]+(size-1) + 1):
                    self.grid[pos[1]][x] = char
                
        def clamp(self,value,minimum,maximum):
            if value <= minimum: return minimum
            elif value >= maximum: return maximum
            else: return value
                
        def generateLevel(self,ships,chars):
            #Generates the Game Level
            for index in range(len(ships)):
                
                #Counts Turns
                turns = 0
                while True:
                    
                    #Iterate the turns
                    turns += 1
                    #Stops itself if it runs out of points to randomly choose
                    if turns > self.height * self.width : break
                    
                    #Shuffles the direction list
                    random.shuffle(self.direction)
                    
                    #Generates each ship by one by one
                    x = random.randint(0,self.width-1)
                    y = random.randint(0,self.height-1)
                    
                    #Checks if existing ship is in that position
                    if self.grid[y][x] != self.emptyCell : continue
                
                    chosenDir = None
                
                    #Size Check
                    for d in self.direction:
                        if d == "up":
                            if y-(ships[index]-1) >= 0 : chosenDir = d; break
                        if d == "down":
                            if y+(ships[index]-1) < self.height : chosenDir = d; break
                        if d == "left":
                            if x-(ships[index]-1) >= 0 : chosenDir = d; break
                        if d == "right":
                            if x+(ships[index]-1) < self.width : chosenDir = d; break
                        
                    if chosenDir == None: continue
            
                    #Gap/Valid Point Check
                    if chosenDir == "up":
                        for pointX in range(self.clamp(x-self.gap,0,self.width-1),
                                            self.clamp(x+self.gap,0,self.width-1)+1):
                            for pointY in range(self.clamp(y-(ships[index]-1)-self.gap,0,self.height-1),
                                                self.clamp(y+self.gap,0,self.height-1)+1):
                                if self.grid[pointY][pointX] != self.emptyCell:
                                    chosenDir = None
                                    break
                            if chosenDir == None: break
                        
                    elif chosenDir == "down":
                        for pointX in range(self.clamp(x-self.gap,0,self.width-1),
                                            self.clamp(x+self.gap,0,self.width-1)+1):
                            for pointY in range(self.clamp(y-self.gap,0,self.height-1),
                                                self.clamp(y+self.gap+(ships[index]-1),0,self.height-1)+1):
                                if self.grid[pointY][pointX] != self.emptyCell:
                                    chosenDir = None
                                    break
                            if chosenDir == None: break
                        
                    elif chosenDir == "left":
                        for pointY in range(self.clamp(y-self.gap,0,self.height-1),
                                            self.clamp(y+self.gap,0,self.height-1)+1):
                            for pointX in range(self.clamp(x-(ships[index]-1),0,self.width-1),
                                                self.clamp(x+self.gap,0,self.width-1)+1):
                                if self.grid[pointY][pointX] != self.emptyCell:
                                    chosenDir = None
                                    break
                            if chosenDir == None: break
                        
                    elif chosenDir == "right":
                        for pointY in range(self.clamp(y-self.gap,0,self.height-1),
                                            self.clamp(y+self.gap,0,self.height-1)+1):
                            for pointX in range(self.clamp(x-self.gap,0,self.width-1),
                                                self.clamp(x+self.gap+(ships[index]-1),0,self.width-1)+1):
                                if self.grid[pointY][pointX] != self.emptyCell:
                                    print(self.grid[pointY][pointX],x,y,chosenDir,ships[index])
                                    chosenDir = None
                                    break
                            if chosenDir == None: break
                    
                    if chosenDir != None : 
                        self.shipPieces += ships[index]
                        self.addShip([x,y], ships[index], chars[index], chosenDir)
                        break
                    
        def attackShip(self,pos):
            #Converts Labels to number position (Takes horizontal first and then vertical)
            position = [self.gridHor.index(pos[0])-1, self.gridVert.index(pos[1])]
            if self.grid[position[1]][position[0]] != self.emptyCell and self.grid[position[1]][position[0]] != self.sunkShip:
                self.shipHealth[self.shipChar.index(self.grid[position[1]][position[0]])] = self.shipHealth[self.shipChar.index(self.grid[position[1]][position[0]])] - 1
                self.grid[position[1]][position[0]] = self.sunkShip
                self.shipPieces -= 1
                if self.shipPieces <= 0 : self.gameOver = True
                
                return True
            else:
                return False
                
        def shipCount(self):
            count = 0
            for ship in self.shipHealth:
                if ship != 0: 
                    count += 1
            return count

    In this section, I will think about the different parts of this program and how they should be organized.

    Battle Ship Game

    I think some class that represents the game makes sense. A game has a board. It has some sort of status (i.e. game is in progress or not). A user can specify some kind of move (attack based on position). The game also lets you know how many ships are left and what the board looks like.

    Board

    A game will have a board. The board consists of a grid. A board has ships on it. You can attack a coordinate and the board will let you know if you hit a ship or not.

    Ship

    A ship can have different sizes. It will also keep track of which coordinates have been attacked.

    CLI / User Interface

    The user needs something to interact with. In this case, it’s a command line game. The CLI will translate user input into valid actions accepted by the game. It will also be in charge of taking data from the game and turning it into a view for the user.

    Sketch of components

    Next, I’ll consider the main functionality that each of these components needs to fulfill. This might change later.

    • Game
      • initialize with a board and ships
      • attack coordinate
      • get layout
      • get remaining ships
      • game is over
    • Board
      • initialized with grid sizes
      • place ship
      • coordinate has ship?
    • Ship
      • is destroyed
      • set damage
    • Cli
      • accept user input
      • display result
      • display board
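    To make the brainstorm a bit more concrete, here is a minimal sketch of the Game, Board and Ship responsibilities listed above. The class names follow the bullets, but the method names and the coordinate format (zero-based (x, y) tuples) are my own assumptions for illustration, not the project’s API. "Get layout" and the CLI are left out, since they would only format data for display.

```python
class Ship:
    """A ship occupying a fixed set of (x, y) coordinates."""

    def __init__(self, coordinates):
        self.coordinates = set(coordinates)
        self.damaged = set()

    def set_damage(self, coordinate):
        # Record a hit only if the coordinate belongs to this ship
        if coordinate in self.coordinates:
            self.damaged.add(coordinate)

    def is_destroyed(self):
        return self.damaged == self.coordinates


class Board:
    """A width x height grid that ships can be placed on."""

    def __init__(self, width, height):
        self.width = width
        self.height = height
        self.ships = []

    def place_ship(self, ship):
        # Reject ships that fall off the grid or overlap an existing ship
        for x, y in ship.coordinates:
            if not (0 <= x < self.width and 0 <= y < self.height):
                raise ValueError(f"({x}, {y}) is off the board")
            if self.ship_at((x, y)) is not None:
                raise ValueError(f"({x}, {y}) is already occupied")
        self.ships.append(ship)

    def ship_at(self, coordinate):
        # Answers "coordinate has ship?" by returning the ship or None
        for ship in self.ships:
            if coordinate in ship.coordinates:
                return ship
        return None


class Game:
    """Ties a board to the player's attacks and reports game status."""

    def __init__(self, board):
        self.board = board

    def attack(self, coordinate):
        ship = self.board.ship_at(coordinate)
        if ship is None:
            return False
        ship.set_damage(coordinate)
        return True

    def remaining_ships(self):
        return sum(1 for ship in self.board.ships if not ship.is_destroyed())

    def is_over(self):
        return self.remaining_ships() == 0
```

    Each piece can be tested on its own: a Board can be checked for overlaps without a Game, and a Game can be driven with a small hand-built board instead of a random layout.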

    One thing to note here is that this is just a brainstorm. Often as you implement a program, you find that some components don’t make sense or that others are needed. This will very likely change, but it’s a rough starting point.

  • BattleShip Part 1: Initial Review / Functional Requirements

    Some initial observations.

    It’s on GitHub, which is nice, and there’s a README that explains some basics. The README probably doesn’t need to mention all the internal variables of BS. BS is also not a great name; it’s not descriptive.

    Before digging into the code, it’s important to understand how the project is supposed to work (its functional requirements). I’m already familiar with the game of Battleship.

    I tried running Python and creating the game object. It’s not clear what the initial output means.

    >>> import BattleShipGame
    >>> game = BattleShipGame.BS()
    B 2 5 right 3

    The grid is a little hard to make out.

    >>> game.printGrid(True)
    -ABCDEFGHIJ
    0OOOOOOOOOO
    1OOOOOOOOOO
    2OOOOOOOOOO
    3OOOOOOOOOO
    4OOOOOOOOOO
    5OOOOOOOOOO
    6OOOOOOOOOO
    7OOOOOOOOOO
    8OOOOOOOOOO
    9OOOOOOOOOO

    It’s not very easy to actually use the game as-is, but luckily there is some sample code for how to play the game. This should ideally just be part of the project so people can play the game directly. Anyway, we’ll copy that and create a main.py file.

    from os import system, name

    from BattleShipGame import BS

    if __name__ == '__main__':
        def clear():
            # for windows
            if name == 'nt':
                _ = system('cls')

            # for mac and linux
            else:
                _ = system('clear')

        game = BS()
        game.printGrid()
        while game.gameOver == False:
            pos = input("Input Position (eg: A0): ")
            if pos.lower() == "show":
                clear()
                print("Ships Left:", game.shipCount())
                game.printGrid(False)
            else:

                clear()
                print([pos[0].capitalize(),pos[1]])
                if game.attackShip([pos[0].capitalize(),pos[1]]):
                    print("HIT!")
                else:
                    print("Miss!")
                print("Ships Left:", game.shipCount())
                game.printGrid()

        print("Game Over")

    The game seems to work now. One observation is that it crashes if you enter an empty string or leave out the number. Ideally, the game should handle this gracefully.

    Traceback (most recent call last):
      File "main.py", line 26, in <module>
        print([pos[0].capitalize(),pos[1]])
    IndexError: string index out of range
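    One way to handle this gracefully would be to validate the input before calling attackShip. Below is a minimal sketch; parse_position is a hypothetical helper, and it assumes the column letters A to J and row digits 0 to 9 shown in the printed grid.

```python
COLUMNS = "ABCDEFGHIJ"  # column labels shown in the printed grid
ROWS = "0123456789"     # row labels shown in the printed grid


def parse_position(text):
    """Return [column, row] for input like 'a0', or None if it is invalid."""
    text = text.strip().upper()
    # Expect exactly one column letter followed by one row digit
    if len(text) != 2:
        return None
    column, row = text[0], text[1]
    if column not in COLUMNS or row not in ROWS:
        return None
    return [column, row]
```

    In main.py, the "show" check would stay first; for everything else, the loop could call parse_position, print a hint like "Try something like A0" and continue when it returns None, and otherwise pass the result straight to game.attackShip.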

    I’m still not 100% sure this all works correctly, and I don’t want to manually enter a bunch of coordinates. I’ll create a helper script to sanity-check that the game at least completes.

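    from BattleShipGame import BS

    if __name__ == '__main__':
        game = BS()
        game.printGrid()

        # Attack every cell from A0 to J9 (chr(65) is 'A', chr(74) is 'J')
        for c in range(65, 75):
            for i in range(0, 10):
                pos = [chr(c), str(i)]
                print(''.join(pos))
                game.attackShip(pos)

                if game.gameOver:
                    break
            # Also stop the outer loop once every ship is sunk
            if game.gameOver:
                break
        print('Game Over')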

    This just iterates through all the possible positions and checks if the game is over.

    It looks like the game does end successfully, so I think we’re at a good starting point.

    Link to change set.

  • BattleShip Project Review

    I came across this project on a code review subreddit and thought it could be a good case study on how I would try to improve this code.

    Here is a series of posts on things I would do.

  • Why you should avoid creating objects

    This isn’t always a problem, and obviously you can’t avoid creating objects entirely.

    The issue is that the code that directly creates the object becomes coupled to that object.

    For example, let’s look at this simple function.

    def process(todo):                                                                                   
        todo_manager = TodoManager()                                                                     
        todo_manager.manage(todo) 

    The only way to verify that process works is by using the actual TodoManager class. What if manage makes API calls or database queries? You’d have to actually call those endpoints or have a database to run a query against.

    Writing a unit test would be very hard. Unit tests are great because they let you check that at least some aspects of your code work as you expect.

    This process function is simple enough that we can decouple it from the todo manager by passing the manager in as an argument.

    def process(todo, todo_manager):                                                                     
        todo_manager.manage(todo) 

    If you have a bunch of related functions that all need todo_manager, you could also create a class and pass the todo_manager into the constructor.
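    Here is a minimal sketch of both options, along with unit tests that swap in a Mock for the todo manager. TodoProcessor is a hypothetical class name; the tests only rely on the todo manager having a manage method, as in the example above. The tests can be run with python3 -m unittest.

```python
import unittest
from unittest.mock import Mock


def process(todo, todo_manager):
    # The caller decides which todo manager to use
    todo_manager.manage(todo)


class TodoProcessor:
    """Hypothetical class grouping functions that all need todo_manager."""

    def __init__(self, todo_manager):
        # The dependency is provided once, through the constructor
        self.todo_manager = todo_manager

    def process(self, todo):
        self.todo_manager.manage(todo)


class TestProcess(unittest.TestCase):

    def test_process_calls_manage(self):
        todo_manager = Mock()
        process('buy milk', todo_manager)
        todo_manager.manage.assert_called_once_with('buy milk')

    def test_processor_calls_manage(self):
        todo_manager = Mock()
        TodoProcessor(todo_manager).process('buy milk')
        todo_manager.manage.assert_called_once_with('buy milk')
```

    Neither test needs the real TodoManager, so no API or database is involved.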

  • Why you should write code that is testable

    This doesn’t mean you have to write tests for all the code you write (although in a professional setting, most of your code probably should have tests).

    Code that is testable tends to be simpler, with less going on. If you can’t write a simple unit test for some functionality, it’s probably complicated, which makes it harder for people to understand. It also makes modifications risky, because it’s not easy to verify the code still works correctly.

    Here are some ways to make code more testable.

    Avoiding side effects

    Functions that don’t change anything outside themselves tend to be easier to test. If a function accepts some parameters and deterministically returns the same value every time, it’s easy to verify that it works correctly.

    E.g. let’s say we have a todo app and we want a function that counts recent (non-expired) todos.

    def recent_todo_count(todos):                                                                        
        return len([todo for todo in todos if not todo.expired])
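    Because recent_todo_count has no side effects, testing it only requires building some inputs and checking the return value. This sketch assumes a todo only needs an expired attribute, so a small hypothetical Todo dataclass stands in for whatever the real todo object is.

```python
import unittest
from dataclasses import dataclass


@dataclass
class Todo:
    # Hypothetical stand-in; only the expired attribute matters here
    description: str
    expired: bool


def recent_todo_count(todos):
    return len([todo for todo in todos if not todo.expired])


class TestRecentTodoCount(unittest.TestCase):

    def test_counts_only_non_expired_todos(self):
        todos = [Todo('a', expired=False), Todo('b', expired=True), Todo('c', expired=False)]
        self.assertEqual(recent_todo_count(todos), 2)

    def test_empty_list(self):
        self.assertEqual(recent_todo_count([]), 0)
```

    There is nothing to set up or clean up: the same input always gives the same count.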

    Controlling your code’s dependencies

    If the functionality you are testing depends on something that’s not core to what you want to test, you can inject that functionality.

    For example, let’s say you have a function that syncs a user’s Google Sheet with our todo app.

    def sync_user_todos():                                                                                                                                                                                            
        google_sheet_client = get_google_sheet_client()                                                                                               
        rows = google_sheet_client.get_rows()                                                                                                         
        todos = transform_rows(rows)                                                                                                                  
        database.create(todos) 

    The current function has three hard-coded dependencies: get_google_sheet_client, transform_rows, and database.

    If get_google_sheet_client returns a real Google Sheets client, you have to connect to an actual sheet to verify things are working. Similarly, if database is a hard-coded connection to a database, you need a real database running.

    If you provide the google_sheet_client and the database to your function, you can control how those work.

    def sync_user_todos(google_sheet_client, database):                                                  
        rows = google_sheet_client.get_rows()                                                            
        todos = transform_rows(rows)                                                                     
        database.create(todos)    

    For example, in your test you can define a fake Google Sheets client object (or a mock) that has a get_rows method and returns some specified data. Similarly, for database we only care that the create method was called with the expected input.
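    As an alternative to mocks, the fakes can be small hand-written classes. In this sketch, FakeSheetClient and FakeDatabase are hypothetical test doubles, and transform_rows is a placeholder that simply joins each row’s cells, since the real transformation isn’t shown.

```python
def transform_rows(rows):
    # Placeholder transformation: join each row's cells into one string
    return [' '.join(row) for row in rows]


def sync_user_todos(google_sheet_client, database):
    rows = google_sheet_client.get_rows()
    todos = transform_rows(rows)
    database.create(todos)


class FakeSheetClient:
    """Returns canned rows instead of calling the Google Sheets API."""

    def __init__(self, rows):
        self.rows = rows

    def get_rows(self):
        return self.rows


class FakeDatabase:
    """Records what would have been written instead of using a real database."""

    def __init__(self):
        self.created = []

    def create(self, todos):
        self.created.append(todos)


client = FakeSheetClient([['buy', 'milk'], ['walk', 'dog']])
database = FakeDatabase()
sync_user_todos(client, database)
print(database.created)  # [['buy milk', 'walk dog']]
```

    Hand-written fakes take a few more lines than a Mock, but the test then checks the data that reached the database rather than which methods were called.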

    Finally, if transform_rows is very complex and the main thing we want to test is that we got rows from the client, transformed them, and passed the result to the database, we can inject the transform function as well.

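    import unittest
    from unittest.mock import Mock


    # transform_rows is now injected too, so the test controls it
    def sync_user_todos(google_sheet_client, database, transform_rows):
        rows = google_sheet_client.get_rows()
        todos = transform_rows(rows)
        database.create(todos)


    class TestSync(unittest.TestCase):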
                                                                                                                                                                                                                
        def test_sync_user_todos(self):                                                                                                                                                                         
            google_sheet_client = Mock()                                                                                                                                                                        
            rows_from_client = ['rows']                                                                                                                                                                         
            google_sheet_client.get_rows.return_value = rows_from_client                                                                                                                                        
            database = Mock()                                                                                                                                                                                   
            transformed_rows_result = ['transformed']                                                                                                                                                           
            transform_rows = Mock(return_value=transformed_rows_result)                                                                                                                                         
                                                                                                                                                                                                                
            sync_user_todos(google_sheet_client, database, transform_rows)                                                                                                                                      
                                                                                                                                                                                                                
            transform_rows.assert_called_with(rows_from_client)                                                                                                                                                 
            database.create.assert_called_with(transformed_rows_result)                                                                                                                                         
                                                                                                                                                                                                                
    if __name__ == '__main__':                                                                                                                                                                                  
        unittest.main()   

  • How to organize code into different parts

    When I first learned to code, I wrote all my code in a single file. I didn’t even use functions.

    For example, let’s say we had a simple application that let the user enter some todos and wrote them to a file when they were done.

    My code would look something like this.

    import os
    import datetime
    
    filename = 'todos1.csv'
    
    todos = []
    
    while True:
        todo = input('What is your todo?')
        if todo == 'done':
            break
    
        todos.append([todo, int(datetime.datetime.now().timestamp())])
    
    with open(filename, 'a') as f:
        for todo, timestamp in todos:
            f.write(f'{todo},{timestamp}\n')

    This works but it’s not great.

    There are a number of issues:

    • As we add more functionality, this becomes larger and harder to understand
    • The only way to make sure it works correctly is to manually run it and make sure the file was properly created
    • The high-level functionality is not obvious

    How I would improve this today

    Instead of having a single block of code, I would break out the different functionality.

    The application needs to accept user input, format the input, then write it to a file.

    The first thing I would do is isolate the part that accepts user input. Here I’m breaking out a separate function get_user_input which collects user input and returns it in a list.

    from datetime import datetime


    def get_user_input(prompt, exit_command='done'):
        user_inputs = []
        while True:
            user_input = input(prompt)
            # Stop collecting once the user types the exit command
            if user_input == exit_command:
                return user_inputs

            current_timestamp = int(datetime.now().timestamp())
            user_inputs.append((user_input, current_timestamp))

    Next, I would define a function run that controls the main flow of the application.

    One tip I try to follow is to make the code read like English.

    def run(filename, user_inputs, todos_to_csv, write_rows):                                            
        file_contents = todos_to_csv(user_inputs)                                                        
        write_rows(filename, file_contents)                                                              

    One thing you may have noticed is that I provided some functions (i.e. todos_to_csv and write_rows) as arguments to run.

    The reason I do this is that it isolates the mechanics of run from how those other functions actually work. For example, if run called write_rows directly, then to verify run works I would have to check that a file was actually created.

    By providing those functions as arguments, run doesn’t need to care about how those actually work. All it needs to know is that the inputs were transformed and we called the function to write it.

    Writing a unit test to verify the functionality becomes pretty straightforward, for example:

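    import unittest
    from unittest.mock import Mock

    # Assumes run and todos_to_csv are defined in, or imported into, this module

    class TestTodo(unittest.TestCase):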
    
        def test_run(self):
            filename = 'test.csv'
            user_inputs = [('a', 1), ('b', 2)]
            csv_contents = 'contents'
            todos_to_csv = Mock(return_value=csv_contents)
            write_rows = Mock()
    
            run(filename, user_inputs, todos_to_csv, write_rows)
    
            todos_to_csv.assert_called_with(user_inputs)
            write_rows.assert_called_with(filename, csv_contents)
    
        def test_todos_to_csv(self):
            user_inputs = [('a', 1), ('b', 2)]
            expected_contents = 'a,1\nb,2\n'
    
            contents = todos_to_csv(user_inputs)
            self.assertEqual(contents, expected_contents)
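    test_todos_to_csv calls a real todos_to_csv, and run expects a write_rows function, but neither implementation appears in this post. Here is a minimal sketch of both that would satisfy the expected contents in that test; the exact implementations are my own assumption, and write_rows appends to the file like the original script’s 'a' mode.

```python
def todos_to_csv(todos):
    # Turn (description, timestamp) pairs into CSV text, one todo per line
    return ''.join(f'{todo},{timestamp}\n' for todo, timestamp in todos)


def write_rows(filename, contents):
    # Append so existing todos are kept, matching the original script
    with open(filename, 'a') as f:
        f.write(contents)
```

    With get_user_input from earlier, the entry point could then be run('todos.csv', get_user_input('What is your todo? '), todos_to_csv, write_rows), where todos.csv is just an example file name. Like the original script, this doesn’t escape commas inside a todo; Python’s csv module would handle that if it matters.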

    Breaking out your code into different bits of functionality has a lot of benefits. Things become more organized and easier to understand. Simple functions are easier to test and less likely to have bugs. By controlling what your code depends on, you can isolate different bits of functionality.

    link to project code

  • How to test functionality that uses a database

    Overview

    When I first started learning how to use a database, I wrote code that was hard to work with.

    To test that functionality worked correctly, I did a lot of local set up that required manual actions like creating tables, adding entries, deleting entries, etc. Sometimes I would deploy my code to see if everything worked properly against the real database.

    In this post, I’ll go over how you could organize your code to isolate functionality and make working with a database easier.

    Sample Todo App

    Here’s a contrived todo application that checks if any todos exist and creates one if there aren’t any.

    import os
    import psycopg2
    
    db = psycopg2.connect(**{
        'dbname': 'todo_app',
        'user': os.environ.get('DB_USER'),
        'password': os.environ.get('DB_PASSWORD'),
        'port': '54320',
        'host': 'localhost'
    })
    
    def generate_todo():
        # get todos from database
        with db.cursor() as cur:
            cur.execute('SELECT todo_id, description FROM todo;')
            todos = cur.fetchall()
    
        # if there are no todos, create a todo to add more todos
        if not todos:
            with db.cursor() as cur:
                cur.execute('INSERT INTO todo (description) VALUES (%s)', ('Add todos',))
            db.commit()

    The problem here is that you can’t verify the functionality without connecting to a real database. You have to run the code and then connect to the database to see if a row was added. You also have to delete the row you added.

    The other issue with this code is that it can be difficult to follow. We have database operations and application logic all in one place.

    How can we organize this better?

    Create a class to simplify database operations and provide it as a dependency

    We could define a class that is in charge of all database operations. It provides a simplified interface that can be used by our application.

    class DatabaseManager:
        def __init__(self, db):
            self.db = db
    
        def get_todos(self):
            with self.db.cursor() as cur:
                cur.execute('SELECT todo_id, description FROM todo;')
                results = cur.fetchall()
            return results
    
        def create_todo(self, description):
            with self.db.cursor() as cur:
                cur.execute('INSERT INTO todo (description) VALUES (%s)', (description,))
                self.db.commit()

    Our generate_todo function can now be greatly simplified, and it’s a lot clearer what’s happening. Another benefit is that if we want to change how the database manager works (e.g. use a different database, change the queries, etc.), we don’t have to make any changes to generate_todo.

    def generate_todo(db_manager):
        todos = db_manager.get_todos()
        if not todos:
            print('Inserting todo')
            db_manager.create_todo('Add todos')

    It’s also easy to unit test our new functionality. We don’t need to connect to an actual database; we only need an object that provides the same methods as DatabaseManager. Here is a simple unit test that uses a Mock in place of the db manager.

    import unittest
    from unittest.mock import Mock
    from todo_v2 import generate_todo
    
    class TestGenerateTodo(unittest.TestCase):
    
        def test_if_no_todos_creates_todo(self):
            # the Mock stands in for DatabaseManager, so no database is needed
            db_manager = Mock()
            db_manager.get_todos.return_value = []
    
            generate_todo(db_manager)
    
            db_manager.get_todos.assert_called()
            # an empty table should trigger the placeholder todo
            db_manager.create_todo.assert_called_with('Add todos')
    
        def test_if_todos_does_not_create_todo(self):
            db_manager = Mock()
            db_manager.get_todos.return_value = [(1, 'todo')]
    
            generate_todo(db_manager)
    
            db_manager.get_todos.assert_called()
            db_manager.create_todo.assert_not_called()
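    A plain Mock() accepts any attribute name, so a misspelled method in a test silently creates a new child mock instead of failing; combined with assert_not_called, that can make a test pass for the wrong reason. One way to guard against this is unittest.mock.create_autospec, which only allows attributes that exist on the real class. The sketch below defines a minimal DatabaseManager stand-in (its method bodies are placeholders) so that it runs on its own.

```python
import unittest
from unittest.mock import create_autospec


class DatabaseManager:
    """Minimal stand-in with the same method names as the real class."""

    def __init__(self, db):
        self.db = db

    def get_todos(self):
        raise NotImplementedError

    def create_todo(self, description):
        raise NotImplementedError


def generate_todo(db_manager):
    todos = db_manager.get_todos()
    if not todos:
        db_manager.create_todo('Add todos')


class TestGenerateTodoAutospec(unittest.TestCase):

    def test_if_no_todos_creates_todo(self):
        # instance=True mocks an instance, so calls are checked without `self`
        db_manager = create_autospec(DatabaseManager, instance=True)
        db_manager.get_todos.return_value = []

        generate_todo(db_manager)

        db_manager.create_todo.assert_called_once_with('Add todos')

    def test_misspelled_method_raises(self):
        db_manager = create_autospec(DatabaseManager, instance=True)
        # insert_todo does not exist on DatabaseManager, so this fails loudly
        with self.assertRaises(AttributeError):
            db_manager.insert_todo.assert_not_called()
```

    Autospecced methods also check call signatures, so calling create_todo with the wrong number of arguments raises a TypeError instead of being recorded silently.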

    Testing actual database operations

    Sometimes we still want to test the real database queries. Because all database operations now live in a dedicated class, that functionality is much more focused and can be tested in isolation, i.e., without worrying about the application logic.

    import unittest
    from todo_v2 import DatabaseManager
    import psycopg2
    
    class TestDatabaseManager(unittest.TestCase):
    
        @classmethod
        def setUpClass(cls):
            # one connection shared by every test in this class
            cls.db = psycopg2.connect(**{
                'dbname': 'todo_app',
                'user': 'postgres',
                'password': 'password',
                'port': '54320',
                'host': 'localhost'
            })
            cls.db_manager = DatabaseManager(cls.db)
    
        @classmethod
        def tearDownClass(cls):
            cls.db.close()
    
        def setUp(self):
            # start every test with an empty todo table
            with self.db.cursor() as cur:
                cur.execute('TRUNCATE todo')
            self.db.commit()
    
        def test_create_and_get_todos(self):
            self.db_manager.create_todo('todo')
            todos = self.db_manager.get_todos()
            self.assertEqual(len(todos), 1)
            self.assertEqual(todos[0][1], 'todo')
    
    if __name__ == '__main__':
        unittest.main()

    In this case you still need an actual database to run the tests against. Be careful that different tests don’t interfere with each other (notice that setUp clears out the table before each test). Running tests in parallel is a further problem: two tests sharing the same table could truncate it or insert rows while the other is still running.
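    One way to sidestep both problems is to give each test its own throwaway database. The sketch below swaps PostgreSQL for Python’s built-in sqlite3 with an in-memory database, purely as an illustration. SqliteDatabaseManager is a hypothetical variant of DatabaseManager, needed because sqlite3 uses ? placeholders and its cursors are not context managers. It will not catch PostgreSQL-specific behavior, so it complements rather than replaces tests against the real database.

```python
import sqlite3
import unittest


class SqliteDatabaseManager:
    """Hypothetical sqlite3 variant of DatabaseManager, for illustration only."""

    def __init__(self, db):
        self.db = db

    def get_todos(self):
        # sqlite3 cursors are not context managers, so use db.execute directly
        return self.db.execute('SELECT todo_id, description FROM todo;').fetchall()

    def create_todo(self, description):
        # sqlite3 uses ? placeholders instead of psycopg2's %s
        self.db.execute('INSERT INTO todo (description) VALUES (?)', (description,))
        self.db.commit()


class TestSqliteDatabaseManager(unittest.TestCase):

    def setUp(self):
        # A fresh in-memory database per test: no table is shared between tests
        self.db = sqlite3.connect(':memory:')
        self.db.execute(
            'CREATE TABLE todo (todo_id INTEGER PRIMARY KEY, description TEXT)'
        )
        self.db_manager = SqliteDatabaseManager(self.db)

    def tearDown(self):
        # Closing the connection discards the in-memory database
        self.db.close()

    def test_create_and_get_todos(self):
        self.db_manager.create_todo('todo')
        self.assertEqual(self.db_manager.get_todos(), [(1, 'todo')])

    def test_starts_empty(self):
        self.assertEqual(self.db_manager.get_todos(), [])
```

    The trade-off is fidelity for isolation: each test builds its own schema, so there is nothing to truncate and nothing for concurrent tests to collide on.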

    Conclusion

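    Passing a DatabaseManager into generate_todo, instead of having generate_todo talk to the database itself, is an example of dependency injection, a common way of applying the Dependency Inversion principle: generate_todo depends only on the methods it calls, not on a concrete database. It’s a technique I’ve found very helpful in managing my code.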

    Link to project code: https://github.com/levelupSE/todoapp