11using System ;
22using System . Collections . Generic ;
3+ using System . Diagnostics ;
34using System . IO ;
45using System . IO . Compression ;
56using System . Linq ;
@@ -58,6 +59,35 @@ public async Task InstallRequirement_ThrowsWhenRequirementMissing()
5859 await Assert . ThrowsAsync < FileNotFoundException > ( ( ) => manager . InstallRequirement ( Path . Combine ( pythonDir , "requirements.txt" ) , _ => { } , _ => { } ) ) ;
5960 }
6061
62+ [ Fact ]
63+ public async Task InstallPackagesAsync_ThrowsWhenPackagesNull ( )
64+ {
65+ var manager = new RecordingPythonEmbedManager ( CreateTempDirectory ( ) ) ;
66+ await Assert . ThrowsAsync < ArgumentNullException > ( ( ) => manager . InstallPackagesAsync ( null ! , null , _ => { } , _ => { } ) ) ;
67+ }
68+
69+ [ Fact ]
70+ public async Task InstallPackagesAsync_ThrowsWhenPackagesEmpty ( )
71+ {
72+ var manager = new RecordingPythonEmbedManager ( CreateTempDirectory ( ) ) ;
73+ await Assert . ThrowsAsync < ArgumentException > ( ( ) => manager . InstallPackagesAsync ( new [ ] { " " , string . Empty } , null , _ => { } , _ => { } ) ) ;
74+ }
75+
76+ [ Fact ]
77+ public async Task InstallPackagesAsync_InvokesPipWithPackagesAndIndex ( )
78+ {
79+ var manager = new RecordingPythonEmbedManager ( CreateTempDirectory ( ) ) ;
80+
81+ await manager . InstallPackagesAsync ( new [ ] { "torch" , "custom package" } , "https://example.com/simple" , _ => { } , _ => { } ) ;
82+
83+ var call = Assert . Single ( manager . RunProcessCalls ) ;
84+ Assert . Contains ( "-m pip install" , call . Arguments , StringComparison . OrdinalIgnoreCase ) ;
85+ Assert . Contains ( "torch" , call . Arguments , StringComparison . Ordinal ) ;
86+ Assert . Contains ( "\" custom package\" " , call . Arguments , StringComparison . Ordinal ) ;
87+ Assert . Contains ( "--index-url" , call . Arguments , StringComparison . Ordinal ) ;
88+ Assert . Contains ( "https://example.com/simple" , call . Arguments , StringComparison . Ordinal ) ;
89+ }
90+
6191 [ Fact ]
6292 public async Task InstallRequirement_InvokesPipInVirtualEnvironment ( )
6393 {
@@ -72,14 +102,16 @@ public async Task InstallRequirement_InvokesPipInVirtualEnvironment()
72102 Assert . Contains ( "venv" , call . FileName ) ;
73103 Assert . Contains ( "python" , call . FileName , StringComparison . OrdinalIgnoreCase ) ;
74104 Assert . Contains ( "pip install" , call . Arguments , StringComparison . OrdinalIgnoreCase ) ;
105+ Assert . NotNull ( call . EnvironmentVariables ) ;
106+ Assert . Equal ( Path . Combine ( manager . GetPythonDir ( ) , "venv" ) , Assert . Contains ( "VIRTUAL_ENV" , call . EnvironmentVariables ! ) ) ;
75107 }
76108
77109 [ Fact ]
78110 public async Task RunPython_ThrowsWhenScriptMissing ( )
79111 {
80112 var manager = new PythonEmbedManager ( CreateTempDirectory ( ) ) ;
81113 var pythonDir = CreateTempDirectory ( ) ;
82- await Assert . ThrowsAsync < FileNotFoundException > ( ( ) => manager . RunPython ( Path . Combine ( pythonDir , "script.py" ) , null ! , null ! , _ => { } , _ => { } ) ) ;
114+ await Assert . ThrowsAsync < FileNotFoundException > ( ( ) => manager . RunPython ( Path . Combine ( pythonDir , "script.py" ) , null ! , null ! , _ => { } , _ => { } , _ => { } ) ) ;
83115 }
84116
85117 [ Fact ]
@@ -92,13 +124,43 @@ public async Task RunPython_ExecutesScriptThroughVirtualEnvironment()
92124 var scriptPath = Path . Combine ( scriptDirectory , "script.py" ) ;
93125 File . WriteAllText ( scriptPath , "print('test')" ) ;
94126
95- await manager . RunPython ( scriptPath , "--flag value" , null , _ => { } , _ => { } ) ;
127+ await manager . RunPython ( scriptPath , "--flag value" , null , _ => { } , _ => { } , _ => { } ) ;
96128
97129 var call = Assert . Single ( manager . RunProcessCalls ) ;
98130 Assert . Contains ( "venv" , call . FileName ) ;
99131 Assert . Contains ( "python" , call . FileName , StringComparison . OrdinalIgnoreCase ) ;
100132 Assert . Contains ( "script.py" , call . Arguments , StringComparison . OrdinalIgnoreCase ) ;
101133 Assert . Contains ( "--flag value" , call . Arguments , StringComparison . OrdinalIgnoreCase ) ;
134+ Assert . NotNull ( call . EnvironmentVariables ) ;
135+ Assert . Equal ( Path . Combine ( manager . GetPythonDir ( ) , "venv" ) , Assert . Contains ( "VIRTUAL_ENV" , call . EnvironmentVariables ! ) ) ;
136+ }
137+
138+ [ Fact ]
139+ public async Task InstallTorchWithCudaAsync_UsesProvidedCudaOverride ( )
140+ {
141+ var manager = new RecordingPythonEmbedManager ( CreateTempDirectory ( ) ) ;
142+
143+ var result = await manager . InstallTorchWithCudaAsync ( "2.5.1" , "cu126" , _ => { } , _ => { } ) ;
144+
145+ Assert . Equal ( 0 , result ) ;
146+
147+ var call = Assert . Single ( manager . RunProcessCalls ) ;
148+ Assert . Contains ( "--index-url" , call . Arguments , StringComparison . Ordinal ) ;
149+ Assert . Contains ( "download.pytorch.org/whl/cu126" , call . Arguments , StringComparison . Ordinal ) ;
150+ Assert . Contains ( "torch==2.5.1+cu126" , call . Arguments , StringComparison . Ordinal ) ;
151+ }
152+
153+ [ Fact ]
154+ public async Task InstallTorchWithCudaAsync_ReturnsErrorWhenCudaCannotBeDetected ( )
155+ {
156+ var manager = new RecordingPythonEmbedManager ( CreateTempDirectory ( ) ) ;
157+ var errors = new List < string > ( ) ;
158+
159+ var result = await manager . InstallTorchWithCudaAsync ( null , null , _ => { } , errors . Add ) ;
160+
161+ Assert . Equal ( - 1 , result ) ;
162+ Assert . NotEmpty ( errors ) ;
163+ Assert . Empty ( manager . RunProcessCalls ) ;
102164 }
103165
104166 private string CreateTempDirectory ( )
@@ -132,7 +194,7 @@ private sealed class RecordingPythonEmbedManager : PythonEmbedManager
132194 public bool DownloadFileCalled { get ; private set ; }
133195 public bool ExtractZipCalled { get ; private set ; }
134196 public List < ( string Url , string Destination ) > DownloadFileCalls { get ; } = new ( ) ;
135- public List < ( string FileName , string Arguments , string ? WorkingDirectory ) > RunProcessCalls { get ; } = new ( ) ;
197+ public List < ( string FileName , string Arguments , string ? WorkingDirectory , Dictionary < string , string > ? EnvironmentVariables ) > RunProcessCalls { get ; } = new ( ) ;
136198
137199 public RecordingPythonEmbedManager ( string pythonDir ) : base ( pythonDir )
138200 {
@@ -168,10 +230,16 @@ protected override string GetVirtualEnvironmentPythonExecutable()
168230 return path ;
169231 }
170232
171- protected override async Task < int > RunProcess ( string fileName , string arguments , string ? workingDirectory , Dictionary < string , string > ? environmentVariables , Action < string > onOutput , Action < string > onError )
233+ protected override async Task < int > RunProcess ( string fileName , string arguments , string ? workingDirectory , Dictionary < string , string > ? environmentVariables , Action < string > onOutput , Action < string > onError , Action < Process > onProcessStarted )
172234 {
173- RunProcessCalls . Add ( ( fileName , arguments , workingDirectory ) ) ;
235+ RunProcessCalls . Add ( ( fileName , arguments , workingDirectory , environmentVariables is null ? null : new Dictionary < string , string > ( environmentVariables ) ) ) ;
236+ onProcessStarted ? . Invoke ( new Process ( ) ) ;
174237 return await Task . FromResult ( 0 ) ;
175238 }
239+
240+ protected override async Task < string ? > DetectCudaTag ( Action < string > onOutput , Action < string > onError )
241+ {
242+ return await Task . FromResult < string ? > ( null ) ;
243+ }
176244 }
177245}
0 commit comments