Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 54 additions & 15 deletions cherry_picker/cherry_picker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
CHECKING_OUT_DEFAULT_BRANCH
CHECKED_OUT_DEFAULT_BRANCH

CHECKING_OUT_PREVIOUS_BRANCH
CHECKED_OUT_PREVIOUS_BRANCH

PUSHING_TO_REMOTE
PUSHED_TO_REMOTE
PUSHING_TO_REMOTE_FAILED
Expand Down Expand Up @@ -134,6 +137,11 @@ def set_paused_state(self):
save_cfg_vals_to_git_cfg(config_path=self.chosen_config_path)
set_state(WORKFLOW_STATES.BACKPORT_PAUSED)

def remember_previous_branch(self):
"""Save the current branch into Git config to be able to get back to it later."""
current_branch = get_current_branch()
save_cfg_vals_to_git_cfg(previous_branch=current_branch)

@property
def upstream(self):
"""Get the remote name to use for upstream branches
Expand Down Expand Up @@ -180,24 +188,29 @@ def run_cmd(self, cmd):
output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
return output.decode("utf-8")

def checkout_branch(self, branch_name):
"""git checkout -b <branch_name>"""
cmd = [
"git",
"checkout",
"-b",
self.get_cherry_pick_branch(branch_name),
f"{self.upstream}/{branch_name}",
]
def checkout_branch(self, branch_name, *, create_branch=False):
"""git checkout [-b] <branch_name>"""
if create_branch:
checked_out_branch = self.get_cherry_pick_branch(branch_name)
cmd = [
"git",
"checkout",
"-b",
checked_out_branch,
f"{self.upstream}/{branch_name}",
]
else:
checked_out_branch = branch_name
cmd = ["git", "checkout", branch_name]
try:
self.run_cmd(cmd)
except subprocess.CalledProcessError as err:
click.echo(
f"Error checking out the branch {self.get_cherry_pick_branch(branch_name)}."
f"Error checking out the branch {branch_name}."
)
click.echo(err.output)
raise BranchCheckoutException(
f"Error checking out the branch {self.get_cherry_pick_branch(branch_name)}."
f"Error checking out the branch {branch_name}."
)

def get_commit_message(self, commit_sha):
Expand All @@ -221,11 +234,23 @@ def checkout_default_branch(self):
"""git checkout default branch"""
set_state(WORKFLOW_STATES.CHECKING_OUT_DEFAULT_BRANCH)

cmd = "git", "checkout", self.config["default_branch"]
self.run_cmd(cmd)
self.checkout_branch(self.config["default_branch"])

set_state(WORKFLOW_STATES.CHECKED_OUT_DEFAULT_BRANCH)

def checkout_previous_branch(self):
"""git checkout previous branch"""
set_state(WORKFLOW_STATES.CHECKING_OUT_PREVIOUS_BRANCH)

previous_branch = load_val_from_git_cfg("previous_branch")
if previous_branch is None:
self.checkout_default_branch()
return

self.checkout_branch(previous_branch)

set_state(WORKFLOW_STATES.CHECKED_OUT_PREVIOUS_BRANCH)

def status(self):
"""
git status
Expand Down Expand Up @@ -357,7 +382,12 @@ def cleanup_branch(self, branch):
Switch to the default branch before that.
"""
set_state(WORKFLOW_STATES.REMOVING_BACKPORT_BRANCH)
self.checkout_default_branch()
try:
self.checkout_previous_branch()
except BranchCheckoutException:
click.echo(f"branch {branch} NOT deleted.")
set_state(WORKFLOW_STATES.REMOVING_BACKPORT_BRANCH_FAILED)
return
try:
self.delete_branch(branch)
except subprocess.CalledProcessError:
Expand All @@ -372,14 +402,15 @@ def backport(self):
raise click.UsageError("At least one branch must be specified.")
set_state(WORKFLOW_STATES.BACKPORT_STARTING)
self.fetch_upstream()
self.remember_previous_branch()

set_state(WORKFLOW_STATES.BACKPORT_LOOPING)
for maint_branch in self.sorted_branches:
set_state(WORKFLOW_STATES.BACKPORT_LOOP_START)
click.echo(f"Now backporting '{self.commit_sha1}' into '{maint_branch}'")

cherry_pick_branch = self.get_cherry_pick_branch(maint_branch)
self.checkout_branch(maint_branch)
self.checkout_branch(maint_branch, create_branch=True)
commit_message = ""
try:
self.cherry_pick()
Expand Down Expand Up @@ -413,6 +444,7 @@ def backport(self):
self.set_paused_state()
return # to preserve the correct state
set_state(WORKFLOW_STATES.BACKPORT_LOOP_END)
reset_stored_previous_branch()
reset_state()

def abort_cherry_pick(self):
Expand All @@ -434,6 +466,7 @@ def abort_cherry_pick(self):
if get_current_branch().startswith("backport-"):
self.cleanup_branch(get_current_branch())

reset_stored_previous_branch()
reset_stored_config_ref()
reset_state()

Expand Down Expand Up @@ -493,6 +526,7 @@ def continue_cherry_pick(self):
)
set_state(WORKFLOW_STATES.CONTINUATION_FAILED)

reset_stored_previous_branch()
reset_stored_config_ref()
reset_state()

Expand Down Expand Up @@ -822,6 +856,11 @@ def reset_stored_config_ref():
"""Config file pointer is not stored in Git config."""


def reset_stored_previous_branch():
"""Remove the previous branch information from Git config."""
wipe_cfg_vals_from_git_cfg("previous_branch")


def reset_state():
"""Remove the progress state from Git config."""
wipe_cfg_vals_from_git_cfg("state")
Expand Down
48 changes: 46 additions & 2 deletions cherry_picker/test_cherry_picker.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ def git_commit():
)


@pytest.fixture
def git_worktree():
git_worktree_cmd = "git", "worktree"
return lambda *extra_args: (
subprocess.run(git_worktree_cmd + extra_args, check=True)
)


@pytest.fixture
def git_cherry_pick():
git_cherry_pick_cmd = "git", "cherry-pick"
Expand All @@ -100,12 +108,13 @@ def git_config():

@pytest.fixture
def tmp_git_repo_dir(tmpdir, cd, git_init, git_commit, git_config):
cd(tmpdir)
repo_dir = tmpdir.mkdir("tmp-git-repo")
cd(repo_dir)
git_init()
git_config("--local", "user.name", "Monty Python")
git_config("--local", "user.email", "[email protected]")
git_commit("Initial commit", "--allow-empty")
yield tmpdir
yield repo_dir


@mock.patch("subprocess.check_output")
Expand Down Expand Up @@ -545,13 +554,19 @@ def test_paused_flow(tmp_git_repo_dir, git_add, git_commit):
WORKFLOW_STATES.CHECKING_OUT_DEFAULT_BRANCH,
WORKFLOW_STATES.CHECKED_OUT_DEFAULT_BRANCH,
),
(
"checkout_previous_branch",
WORKFLOW_STATES.CHECKING_OUT_PREVIOUS_BRANCH,
WORKFLOW_STATES.CHECKED_OUT_PREVIOUS_BRANCH,
),
),
)
def test_start_end_states(method_name, start_state, end_state, tmp_git_repo_dir):
assert get_state() == WORKFLOW_STATES.UNSET

with mock.patch("cherry_picker.cherry_picker.validate_sha", return_value=True):
cherry_picker = CherryPicker("origin", "xxx", [])
cherry_picker.remember_previous_branch()
assert get_state() == WORKFLOW_STATES.UNSET

def _fetch(cmd):
Expand All @@ -572,6 +587,22 @@ def test_cleanup_branch(tmp_git_repo_dir, git_checkout):
git_checkout("-b", "some_branch")
cherry_picker.cleanup_branch("some_branch")
assert get_state() == WORKFLOW_STATES.REMOVED_BACKPORT_BRANCH
assert get_current_branch() == "main"


def test_cleanup_branch_checkout_previous_branch(tmp_git_repo_dir, git_checkout, git_worktree):
assert get_state() == WORKFLOW_STATES.UNSET

with mock.patch("cherry_picker.cherry_picker.validate_sha", return_value=True):
cherry_picker = CherryPicker("origin", "xxx", [])
assert get_state() == WORKFLOW_STATES.UNSET

git_checkout("-b", "previous_branch")
cherry_picker.remember_previous_branch()
git_checkout("-b", "some_branch")
cherry_picker.cleanup_branch("some_branch")
assert get_state() == WORKFLOW_STATES.REMOVED_BACKPORT_BRANCH
assert get_current_branch() == "previous_branch"


def test_cleanup_branch_fail(tmp_git_repo_dir):
Expand All @@ -585,6 +616,19 @@ def test_cleanup_branch_fail(tmp_git_repo_dir):
assert get_state() == WORKFLOW_STATES.REMOVING_BACKPORT_BRANCH_FAILED


def test_cleanup_branch_checkout_fail(tmp_git_repo_dir, tmpdir, git_checkout, git_worktree):
assert get_state() == WORKFLOW_STATES.UNSET

with mock.patch("cherry_picker.cherry_picker.validate_sha", return_value=True):
cherry_picker = CherryPicker("origin", "xxx", [])
assert get_state() == WORKFLOW_STATES.UNSET

git_checkout("-b", "some_branch")
git_worktree("add", str(tmpdir.mkdir("test-worktree")), "main")
cherry_picker.cleanup_branch("some_branch")
assert get_state() == WORKFLOW_STATES.REMOVING_BACKPORT_BRANCH_FAILED


def test_cherry_pick(tmp_git_repo_dir, git_add, git_branch, git_commit, git_checkout):
cherry_pick_target_branches = ("3.8",)
pr_remote = "origin"
Expand Down