Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions __tests__/SqlUtils.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,65 @@ describe('SqlUtils tests', () => {

expect(error).toBeDefined();
expect(error!.message).toMatch('Failed to add firewall rule. Unable to detect client IP Address.');
expect(mssqlSpy).toHaveBeenCalledTimes(2);
});

it('detectIPAddress should retry connection with DB if master connection fails', async () => {
const mssqlSpy = jest.spyOn(mssql, 'connect').mockImplementationOnce((config) => {
// First call, call the original to get login failure
return mssql.connect(config);
}).mockImplementationOnce((config) => {
// Second call, mock return successful connection
return new mssql.ConnectionPool('');
});

const ipAddress = await SqlUtils.detectIPAddress(getConnectionConfig());

expect(mssqlSpy).toHaveBeenCalledTimes(2);
expect(ipAddress).toBe('');
});

it('detectIPAddress should fail fast if initial connection fails with unknown error', async () => {
const mssqlSpy = jest.spyOn(mssql, 'connect').mockImplementationOnce((config) => {
if (config['database'] === 'master') {
throw new Error('This is an unknown error.');
}
});

let error: Error | undefined;
try {
await SqlUtils.detectIPAddress(getConnectionConfig());
}
catch (e) {
error = e;
}

expect(error).toBeDefined();
expect(error!.message).toMatch('This is an unknown error.');
expect(mssqlSpy).toHaveBeenCalledTimes(1);
});

it('detectIPAddress should fail if retry fails again', async () => {
const errorSpy = jest.spyOn(core, 'error');
const mssqlSpy = jest.spyOn(mssql, 'connect').mockImplementation((config) => {
throw new mssql.ConnectionError(new Error('Custom connection error message.'));
})

let error: Error | undefined;
try {
await SqlUtils.detectIPAddress(getConnectionConfig());
}
catch (e) {
error = e;
}

expect(error).toBeDefined();
expect(error!.message).toMatch('Failed to add firewall rule. Unable to detect client IP Address.');
expect(mssqlSpy).toHaveBeenCalledTimes(2);
expect(errorSpy).toHaveBeenCalledTimes(1);
expect(errorSpy).toHaveBeenCalledWith('Custom connection error message.');
});

it('should report single MSSQLError', async () => {
const errorSpy = jest.spyOn(core, 'error');
const error = new mssql.RequestError(new Error('Fake error'));
Expand Down
2 changes: 1 addition & 1 deletion lib/main.js

Large diffs are not rendered by default.

130 changes: 96 additions & 34 deletions src/SqlUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,53 +4,115 @@ import * as mssql from "mssql";
import Constants from "./Constants";
import SqlConnectionConfig from "./SqlConnectionConfig";

export interface ConnectionResult {
/** True if connection succeeds, false otherwise */
success: boolean,

/** The connection object on success */
connection?: mssql.ConnectionPool,

/** Connection error on failure */
error?: mssql.ConnectionError,

/** Client IP address if connection fails due to firewall rule */
ipAddress?: string
}

export default class SqlUtils {

/**
* Tries connection to server to determine if client IP address is restricted by the firewall.
* First tries with master connection, and then with user DB if first one fails.
* @param SqlConnectionConfig The connection configuration to try.
* @returns The client IP address if firewall restriction is present, or an empty string if connection succeeds. Throws otherwise.
*/
static async detectIPAddress(connectionConfig: SqlConnectionConfig): Promise<string> {
let ipAddress = '';
// First try connection to master
let result = await this.tryConnection(connectionConfig, true);
if (result.success) {
result.connection?.close();
return '';
}
else if (result.ipAddress) {
return result.ipAddress;
}

// Retry connection with user DB
result = await this.tryConnection(connectionConfig, false);
if (result.success) {
result.connection?.close();
return '';
}
else if (result.ipAddress) {
return result.ipAddress;
}
else {
this.reportMSSQLError(result.error!);
throw new Error(`Failed to add firewall rule. Unable to detect client IP Address.`);
}
}

/**
* Tries connection with the specified configuration.
* @param config Configuration for the connection.
* @param useMaster If true, uses "master" instead of the database specified in @param config. Every other config remains the same.
* @returns A ConnectionResult object indicating success/failure, the connection on success, or the error on failure.
*/
private static async tryConnection(config: SqlConnectionConfig, useMaster?: boolean): Promise<ConnectionResult> {
// Clone the connection config so we can change the database without modifying the original
const configClone = JSON.parse(JSON.stringify(connectionConfig.Config)) as mssql.config;
configClone.database = "master";
const connectionConfig = JSON.parse(JSON.stringify(config.Config)) as mssql.config;
if (useMaster) {
connectionConfig.database = "master";
}

try {
core.debug(`Validating if client has access to SQL Server '${configClone.server}'.`);
const pool = await mssql.connect(configClone);
pool.close();
core.debug(`Validating if client has access to '${connectionConfig.database}' on '${connectionConfig.server}'.`);
const pool = await mssql.connect(connectionConfig);
return {
success: true,
connection: pool
};
}
catch (error) {
if (error instanceof mssql.ConnectionError) {
return {
success: false,
error: error,
ipAddress: this.parseErrorForIpAddress(error)
};
}
else {
throw error; // Unknown error
}
}
catch (connectionError) {
if (connectionError instanceof mssql.ConnectionError) {
if (connectionError.originalError instanceof AggregateError) {
// The IP address error can be anywhere inside the AggregateError
for (const err of connectionError.originalError.errors) {
core.debug(err.message);
const ipAddresses = err.message.match(Constants.ipv4MatchPattern);
if (!!ipAddresses) {
ipAddress = ipAddresses[0];
break;
}
}
}
else {
core.debug(connectionError.originalError!.message);
const ipAddresses = connectionError.originalError!.message.match(Constants.ipv4MatchPattern);
if (!!ipAddresses) {
ipAddress = ipAddresses[0];
}
}
}

// There are errors that are not because of missing IP firewall rule
if (!ipAddress) {
this.reportMSSQLError(connectionError);
throw new Error(`Failed to add firewall rule. Unable to detect client IP Address.`);
/**
* Parse a ConnectionError to see if its message contains an IP address.
* Returns the IP address if found, otherwise undefined.
*/
private static parseErrorForIpAddress(connectionError: mssql.ConnectionError): string | undefined {
let ipAddress: string | undefined;

if (connectionError.originalError instanceof AggregateError) {
// The IP address error can be anywhere inside the AggregateError
for (const err of connectionError.originalError.errors) {
core.debug(err.message);
const ipAddresses = err.message.match(Constants.ipv4MatchPattern);
if (!!ipAddresses) {
ipAddress = ipAddresses[0];
break;
}
}
else {
// Unknown error
throw connectionError;
}
else {
core.debug(connectionError.originalError!.message);
const ipAddresses = connectionError.originalError!.message.match(Constants.ipv4MatchPattern);
if (!!ipAddresses) {
ipAddress = ipAddresses[0];
}
}

//ipAddress will be an empty string if client has access to SQL server
return ipAddress;
}

Expand Down