Skip to content

Commit f3f87b4

Browse files
authored
Merge pull request #6 from xqiu/codex/fix-and-add-unit-tests-for-new-api
Update tests for new Python embed API
2 parents 3426262 + 7ed4bcd commit f3f87b4

2 files changed

Lines changed: 74 additions & 6 deletions

File tree

src/DotNetPythonEmbed/PythonEmbedManager.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ private static string QuoteArgument(string argument)
448448
return trimmed;
449449
}
450450

451-
private async Task<string?> DetectCudaTag(Action<string> onOutput, Action<string> onError)
451+
protected virtual async Task<string?> DetectCudaTag(Action<string> onOutput, Action<string> onError)
452452
{
453453
try
454454
{

tests/DotNetPythonEmbed.Tests/PythonEmbedManagerTests.cs

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Diagnostics;
34
using System.IO;
45
using System.IO.Compression;
56
using System.Linq;
@@ -58,6 +59,35 @@ public async Task InstallRequirement_ThrowsWhenRequirementMissing()
5859
await Assert.ThrowsAsync<FileNotFoundException>(() => manager.InstallRequirement(Path.Combine(pythonDir, "requirements.txt"), _ => { }, _ => { }));
5960
}
6061

62+
[Fact]
63+
public async Task InstallPackagesAsync_ThrowsWhenPackagesNull()
64+
{
65+
var manager = new RecordingPythonEmbedManager(CreateTempDirectory());
66+
await Assert.ThrowsAsync<ArgumentNullException>(() => manager.InstallPackagesAsync(null!, null, _ => { }, _ => { }));
67+
}
68+
69+
[Fact]
70+
public async Task InstallPackagesAsync_ThrowsWhenPackagesEmpty()
71+
{
72+
var manager = new RecordingPythonEmbedManager(CreateTempDirectory());
73+
await Assert.ThrowsAsync<ArgumentException>(() => manager.InstallPackagesAsync(new[] { " ", string.Empty }, null, _ => { }, _ => { }));
74+
}
75+
76+
[Fact]
77+
public async Task InstallPackagesAsync_InvokesPipWithPackagesAndIndex()
78+
{
79+
var manager = new RecordingPythonEmbedManager(CreateTempDirectory());
80+
81+
await manager.InstallPackagesAsync(new[] { "torch", "custom package" }, "https://example.com/simple", _ => { }, _ => { });
82+
83+
var call = Assert.Single(manager.RunProcessCalls);
84+
Assert.Contains("-m pip install", call.Arguments, StringComparison.OrdinalIgnoreCase);
85+
Assert.Contains("torch", call.Arguments, StringComparison.Ordinal);
86+
Assert.Contains("\"custom package\"", call.Arguments, StringComparison.Ordinal);
87+
Assert.Contains("--index-url", call.Arguments, StringComparison.Ordinal);
88+
Assert.Contains("https://example.com/simple", call.Arguments, StringComparison.Ordinal);
89+
}
90+
6191
[Fact]
6292
public async Task InstallRequirement_InvokesPipInVirtualEnvironment()
6393
{
@@ -72,14 +102,16 @@ public async Task InstallRequirement_InvokesPipInVirtualEnvironment()
72102
Assert.Contains("venv", call.FileName);
73103
Assert.Contains("python", call.FileName, StringComparison.OrdinalIgnoreCase);
74104
Assert.Contains("pip install", call.Arguments, StringComparison.OrdinalIgnoreCase);
105+
Assert.NotNull(call.EnvironmentVariables);
106+
Assert.Equal(Path.Combine(manager.GetPythonDir(), "venv"), Assert.Contains("VIRTUAL_ENV", call.EnvironmentVariables!));
75107
}
76108

77109
[Fact]
78110
public async Task RunPython_ThrowsWhenScriptMissing()
79111
{
80112
var manager = new PythonEmbedManager(CreateTempDirectory());
81113
var pythonDir = CreateTempDirectory();
82-
await Assert.ThrowsAsync<FileNotFoundException>(() => manager.RunPython(Path.Combine(pythonDir, "script.py"), null!, null!, _ => { }, _ => { }));
114+
await Assert.ThrowsAsync<FileNotFoundException>(() => manager.RunPython(Path.Combine(pythonDir, "script.py"), null!, null!, _ => { }, _ => { }, _ => { }));
83115
}
84116

85117
[Fact]
@@ -92,13 +124,43 @@ public async Task RunPython_ExecutesScriptThroughVirtualEnvironment()
92124
var scriptPath = Path.Combine(scriptDirectory, "script.py");
93125
File.WriteAllText(scriptPath, "print('test')");
94126

95-
await manager.RunPython(scriptPath, "--flag value", null, _ => { }, _ => { });
127+
await manager.RunPython(scriptPath, "--flag value", null, _ => { }, _ => { }, _ => { });
96128

97129
var call = Assert.Single(manager.RunProcessCalls);
98130
Assert.Contains("venv", call.FileName);
99131
Assert.Contains("python", call.FileName, StringComparison.OrdinalIgnoreCase);
100132
Assert.Contains("script.py", call.Arguments, StringComparison.OrdinalIgnoreCase);
101133
Assert.Contains("--flag value", call.Arguments, StringComparison.OrdinalIgnoreCase);
134+
Assert.NotNull(call.EnvironmentVariables);
135+
Assert.Equal(Path.Combine(manager.GetPythonDir(), "venv"), Assert.Contains("VIRTUAL_ENV", call.EnvironmentVariables!));
136+
}
137+
138+
[Fact]
139+
public async Task InstallTorchWithCudaAsync_UsesProvidedCudaOverride()
140+
{
141+
var manager = new RecordingPythonEmbedManager(CreateTempDirectory());
142+
143+
var result = await manager.InstallTorchWithCudaAsync("2.5.1", "cu126", _ => { }, _ => { });
144+
145+
Assert.Equal(0, result);
146+
147+
var call = Assert.Single(manager.RunProcessCalls);
148+
Assert.Contains("--index-url", call.Arguments, StringComparison.Ordinal);
149+
Assert.Contains("download.pytorch.org/whl/cu126", call.Arguments, StringComparison.Ordinal);
150+
Assert.Contains("torch==2.5.1+cu126", call.Arguments, StringComparison.Ordinal);
151+
}
152+
153+
[Fact]
154+
public async Task InstallTorchWithCudaAsync_ReturnsErrorWhenCudaCannotBeDetected()
155+
{
156+
var manager = new RecordingPythonEmbedManager(CreateTempDirectory());
157+
var errors = new List<string>();
158+
159+
var result = await manager.InstallTorchWithCudaAsync(null, null, _ => { }, errors.Add);
160+
161+
Assert.Equal(-1, result);
162+
Assert.NotEmpty(errors);
163+
Assert.Empty(manager.RunProcessCalls);
102164
}
103165

104166
private string CreateTempDirectory()
@@ -132,7 +194,7 @@ private sealed class RecordingPythonEmbedManager : PythonEmbedManager
132194
public bool DownloadFileCalled { get; private set; }
133195
public bool ExtractZipCalled { get; private set; }
134196
public List<(string Url, string Destination)> DownloadFileCalls { get; } = new();
135-
public List<(string FileName, string Arguments, string? WorkingDirectory)> RunProcessCalls { get; } = new();
197+
public List<(string FileName, string Arguments, string? WorkingDirectory, Dictionary<string, string>? EnvironmentVariables)> RunProcessCalls { get; } = new();
136198

137199
public RecordingPythonEmbedManager(string pythonDir) : base(pythonDir)
138200
{
@@ -168,10 +230,16 @@ protected override string GetVirtualEnvironmentPythonExecutable()
168230
return path;
169231
}
170232

171-
protected override async Task<int> RunProcess(string fileName, string arguments, string? workingDirectory, Dictionary<string, string>? environmentVariables, Action<string> onOutput, Action<string> onError)
233+
protected override async Task<int> RunProcess(string fileName, string arguments, string? workingDirectory, Dictionary<string, string>? environmentVariables, Action<string> onOutput, Action<string> onError, Action<Process> onProcessStarted)
172234
{
173-
RunProcessCalls.Add((fileName, arguments, workingDirectory));
235+
RunProcessCalls.Add((fileName, arguments, workingDirectory, environmentVariables is null ? null : new Dictionary<string, string>(environmentVariables)));
236+
onProcessStarted?.Invoke(new Process());
174237
return await Task.FromResult(0);
175238
}
239+
240+
protected override async Task<string?> DetectCudaTag(Action<string> onOutput, Action<string> onError)
241+
{
242+
return await Task.FromResult<string?>(null);
243+
}
176244
}
177245
}

0 commit comments

Comments
 (0)