Implement state validation

This commit is contained in:
Jaex 2022-12-07 18:51:48 +03:00
parent 63a5d8dfd3
commit dd8baa9cb6
3 changed files with 20 additions and 6 deletions

View file

@ -39,6 +39,7 @@ public class GoogleOAuth2 : IOAuth2Loopback
public OAuth2Info AuthInfo { get; private set; } public OAuth2Info AuthInfo { get; private set; }
private Uploader GoogleUploader { get; set; } private Uploader GoogleUploader { get; set; }
public string RedirectURI { get; set; } public string RedirectURI { get; set; }
public string State { get; set; }
public string Scope { get; set; } public string Scope { get; set; }
public GoogleOAuth2(OAuth2Info oauth, Uploader uploader) public GoogleOAuth2(OAuth2Info oauth, Uploader uploader)
@ -53,6 +54,7 @@ public string GetAuthorizationURL()
args.Add("response_type", "code"); args.Add("response_type", "code");
args.Add("client_id", AuthInfo.Client_ID); args.Add("client_id", AuthInfo.Client_ID);
args.Add("redirect_uri", RedirectURI); args.Add("redirect_uri", RedirectURI);
args.Add("state", State);
args.Add("scope", Scope); args.Add("scope", Scope);
return URLHelpers.CreateQueryString(AuthorizationEndpoint, args); return URLHelpers.CreateQueryString(AuthorizationEndpoint, args);

View file

@ -28,5 +28,7 @@ namespace ShareX.UploadersLib
public interface IOAuth2Loopback : IOAuth2 public interface IOAuth2Loopback : IOAuth2
{ {
string RedirectURI { get; set; } string RedirectURI { get; set; }
string State { get; set; }
string Scope { get; set; }
} }
} }

View file

@ -36,7 +36,6 @@ namespace ShareX.UploadersLib
public class OAuthListener : IDisposable public class OAuthListener : IDisposable
{ {
public IOAuth2Loopback OAuth { get; private set; } public IOAuth2Loopback OAuth { get; private set; }
public string Code { get; private set; }
private HttpListener listener; private HttpListener listener;
@ -54,13 +53,14 @@ public void Dispose()
public async Task<bool> ConnectAsync() public async Task<bool> ConnectAsync()
{ {
Dispose(); Dispose();
Code = null;
IPAddress ip = IPAddress.Loopback; IPAddress ip = IPAddress.Loopback;
int port = URLHelpers.GetRandomUnusedPort(); int port = URLHelpers.GetRandomUnusedPort();
string redirectURI = string.Format($"http://{ip}:{port}/"); string redirectURI = string.Format($"http://{ip}:{port}/");
string state = Helpers.GetRandomAlphanumeric(32);
OAuth.RedirectURI = redirectURI; OAuth.RedirectURI = redirectURI;
OAuth.State = state;
string url = OAuth.GetAuthorizationURL(); string url = OAuth.GetAuthorizationURL();
if (!string.IsNullOrEmpty(url)) if (!string.IsNullOrEmpty(url))
@ -74,6 +74,10 @@ public async Task<bool> ConnectAsync()
return false; return false;
} }
string queryCode = null;
string queryState = null;
bool stateValidation = false;
try try
{ {
listener = new HttpListener(); listener = new HttpListener();
@ -81,13 +85,19 @@ public async Task<bool> ConnectAsync()
listener.Start(); listener.Start();
HttpListenerContext context = await listener.GetContextAsync(); HttpListenerContext context = await listener.GetContextAsync();
Code = context.Request.QueryString.Get("code"); queryCode = context.Request.QueryString.Get("code");
queryState = context.Request.QueryString.Get("state");
stateValidation = !string.IsNullOrEmpty(queryState) && queryState == state;
using (HttpListenerResponse response = context.Response) using (HttpListenerResponse response = context.Response)
{ {
string status; string status;
if (!string.IsNullOrEmpty(Code)) if (!stateValidation)
{
status = "Invalid state parameter.";
}
else if (!string.IsNullOrEmpty(queryCode))
{ {
status = "Authorization completed successfully."; status = "Authorization completed successfully.";
} }
@ -113,9 +123,9 @@ public async Task<bool> ConnectAsync()
Dispose(); Dispose();
} }
if (!string.IsNullOrEmpty(Code)) if (stateValidation && !string.IsNullOrEmpty(queryCode))
{ {
return await Task.Run(() => OAuth.GetAccessToken(Code)); return await Task.Run(() => OAuth.GetAccessToken(queryCode));
} }
return false; return false;