Files
OwinOAuthProviders/src/Owin.Security.Providers.Salesforce/SalesforceAuthenticationHandler.cs
Tommy Parnell b8055739ff rm semicolon
2017-12-29 01:57:41 -05:00

327 lines
14 KiB
C#

using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Security.Claims;
using System.Threading.Tasks;
using System.Web;
using Microsoft.Owin.Infrastructure;
using Microsoft.Owin.Logging;
using Microsoft.Owin.Security;
using Microsoft.Owin.Security.Infrastructure;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
namespace Owin.Security.Providers.Salesforce
{
public class SalesforceAuthenticationHandler : AuthenticationHandler<SalesforceAuthenticationOptions>
{
    private const string XmlSchemaString = "http://www.w3.org/2001/XMLSchema#string";
    private const string ProductionHost = "https://login.salesforce.com";
    private const string SandboxHost = "https://test.salesforce.com";
    private const string AuthorizationEndpoint = "/services/oauth2/authorize";
    private const string TokenEndpoint = "/services/oauth2/token";

    private readonly ILogger _logger;
    private readonly HttpClient _httpClient;

    /// <summary>
    /// Creates a handler that uses <paramref name="httpClient"/> for back-channel calls to Salesforce.
    /// </summary>
    /// <param name="httpClient">Client used for the token exchange and the identity (user info) request.</param>
    /// <param name="logger">Logger for authentication warnings and errors.</param>
    public SalesforceAuthenticationHandler(HttpClient httpClient, ILogger logger)
    {
        _httpClient = httpClient;
        _logger = logger;
    }

    /// <summary>
    /// Handles the OAuth2 callback. It validates the protected state and correlation cookie,
    /// exchanges the authorization code for tokens, loads the user from the identity URL
    /// returned in the token response, and builds the claims identity.
    /// </summary>
    /// <returns>
    /// <c>null</c> when the state cannot be unprotected; a ticket with a <c>null</c> identity when
    /// authentication fails for any other reason; otherwise a ticket for the authenticated user.
    /// </returns>
    protected override async Task<AuthenticationTicket> AuthenticateCoreAsync()
    {
        var properties = new AuthenticationProperties();
        try
        {
            var query = Request.Query;
            var code = SingleValueOrNull(query.GetValues("code"));
            var state = SingleValueOrNull(query.GetValues("state"));

            properties = Options.StateDataFormat.Unprotect(state);
            if (properties == null)
            {
                return null;
            }

            // OAuth2 10.12 CSRF
            if (!ValidateCorrelationId(properties, _logger))
            {
                return new AuthenticationTicket(null, properties);
            }

            // No code means the authorization server returned an OAuth2 error response
            // (for example access_denied). Exchanging a missing code would only fail remotely.
            if (string.IsNullOrEmpty(code))
            {
                _logger.WriteWarning("Salesforce callback did not contain an authorization code.");
                return new AuthenticationTicket(null, properties);
            }

            // Must match the redirect_uri sent in ApplyResponseChallengeAsync.
            var redirectUri = Request.Scheme + Uri.SchemeDelimiter + Request.Host + Request.PathBase + Options.CallbackPath;

            var body = new List<KeyValuePair<string, string>>
            {
                new KeyValuePair<string, string>("code", code),
                new KeyValuePair<string, string>("redirect_uri", redirectUri),
                new KeyValuePair<string, string>("client_id", Options.ClientId),
                new KeyValuePair<string, string>("client_secret", Options.ClientSecret),
                new KeyValuePair<string, string>("grant_type", "authorization_code")
            };

            // Exchange the authorization code for tokens.
            string tokenText;
            using (var tokenRequest = new HttpRequestMessage(HttpMethod.Post, ComposeTokenEndpoint(properties)))
            {
                tokenRequest.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
                tokenRequest.Content = new FormUrlEncodedContent(body);
                using (var tokenResponse = await _httpClient.SendAsync(tokenRequest, Request.CallCancelled))
                {
                    tokenResponse.EnsureSuccessStatusCode();
                    tokenText = await tokenResponse.Content.ReadAsStringAsync();
                }
            }

            var tokenJson = JObject.Parse(tokenText);
            var accessToken = tokenJson.Value<string>("access_token");
            var refreshToken = tokenJson.Value<string>("refresh_token");
            var instanceUrl = tokenJson.Value<string>("instance_url");
            // The "id" field of the token response is the identity URL used to fetch the user.
            var identityUrl = tokenJson.Value<string>("id");

            if (string.IsNullOrEmpty(accessToken) || string.IsNullOrEmpty(identityUrl))
            {
                _logger.WriteWarning("Salesforce token response did not contain an access token and identity URL.");
                return new AuthenticationTicket(null, properties);
            }

            // Fetch the user. The token travels in the Authorization header rather than the
            // query string, so it does not end up in URL logs.
            JObject user;
            using (var userRequest = new HttpRequestMessage(HttpMethod.Get, identityUrl))
            {
                userRequest.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
                userRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken);
                using (var userResponse = await _httpClient.SendAsync(userRequest, Request.CallCancelled))
                {
                    userResponse.EnsureSuccessStatusCode();
                    user = JObject.Parse(await userResponse.Content.ReadAsStringAsync());
                }
            }

            var context = new SalesforceAuthenticatedContext(Context, user, accessToken, refreshToken, instanceUrl)
            {
                Identity = new ClaimsIdentity(
                    Options.AuthenticationType,
                    ClaimsIdentity.DefaultNameClaimType,
                    ClaimsIdentity.DefaultRoleClaimType)
            };

            AddClaimIfPresent(context.Identity, ClaimTypes.NameIdentifier, context.UserId);
            AddClaimIfPresent(context.Identity, ClaimsIdentity.DefaultNameClaimType, context.UserName);
            AddClaimIfPresent(context.Identity, ClaimTypes.Email, context.Email);
            AddClaimIfPresent(context.Identity, ClaimTypes.GivenName, context.FirstName);
            AddClaimIfPresent(context.Identity, ClaimTypes.Surname, context.LastName);
            AddClaimIfPresent(context.Identity, "urn:Salesforce:name", context.DisplayName);
            AddClaimIfPresent(context.Identity, "urn:Salesforce:organization_id", context.OrganizationId);

            context.Properties = properties;

            await Options.Provider.Authenticated(context);

            return new AuthenticationTicket(context.Identity, context.Properties);
        }
        catch (Exception ex)
        {
            _logger.WriteError(ex.Message, ex);
        }
        return new AuthenticationTicket(null, properties);
    }

    /// <summary>
    /// On a 401 response with a matching challenge, redirects the browser to the Salesforce
    /// authorization endpoint. The redirect carries a protected state and a fresh correlation id.
    /// </summary>
    protected override Task ApplyResponseChallengeAsync()
    {
        if (Response.StatusCode != 401)
        {
            return Task.FromResult<object>(null);
        }

        var challenge = Helper.LookupChallenge(Options.AuthenticationType, Options.AuthenticationMode);
        if (challenge == null)
        {
            return Task.FromResult<object>(null);
        }

        var baseUri = Request.Scheme + Uri.SchemeDelimiter + Request.Host + Request.PathBase;
        var currentUri = baseUri + Request.Path + Request.QueryString;
        var redirectUri = baseUri + Options.CallbackPath;

        var properties = challenge.Properties;
        if (string.IsNullOrEmpty(properties.RedirectUri))
        {
            properties.RedirectUri = currentUri;
        }

        // OAuth2 10.12 CSRF
        GenerateCorrelationId(properties);

        var state = Options.StateDataFormat.Protect(properties);

        // Every caller-supplied value is escaped. "immediate" uses the lowercase literal documented
        // by Salesforce rather than bool.ToString() ("False").
        var authorizationEndpoint =
            ComposeAuthorizationEndpoint(properties) +
            "?response_type=code" +
            "&client_id=" + Uri.EscapeDataString(Options.ClientId ?? string.Empty) +
            "&redirect_uri=" + HttpUtility.UrlEncode(redirectUri) +
            "&display=page" +
            "&immediate=false" +
            "&state=" + Uri.EscapeDataString(state);

        if (Options.Scope != null && Options.Scope.Count > 0)
        {
            authorizationEndpoint += "&scope=" + Uri.EscapeDataString(string.Join(" ", Options.Scope));
        }

        if (Options.Prompt != null)
        {
            authorizationEndpoint += "&prompt=" + Uri.EscapeDataString(Options.Prompt);
        }

        Response.Redirect(authorizationEndpoint);

        return Task.FromResult<object>(null);
    }

    /// <summary>
    /// Lets the handler process the callback path.
    /// </summary>
    /// <returns><c>true</c> when the request was handled and the pipeline should stop.</returns>
    public override async Task<bool> InvokeAsync()
    {
        return await InvokeReplyPathAsync();
    }

    /// <summary>
    /// Completes sign-in on the callback path: authenticates, runs the provider's ReturnEndpoint
    /// hook, signs in with the configured sign-in type, and redirects to the original URI.
    /// </summary>
    private async Task<bool> InvokeReplyPathAsync()
    {
        if (!Options.CallbackPath.HasValue || Options.CallbackPath != Request.Path)
        {
            return false;
        }

        // TODO: error responses

        var ticket = await AuthenticateAsync();
        if (ticket == null)
        {
            _logger.WriteWarning("Invalid return state, unable to redirect.");
            Response.StatusCode = 500;
            return true;
        }

        var context = new SalesforceReturnEndpointContext(Context, ticket)
        {
            SignInAsAuthenticationType = Options.SignInAsAuthenticationType,
            RedirectUri = ticket.Properties.RedirectUri
        };

        await Options.Provider.ReturnEndpoint(context);

        if (context.SignInAsAuthenticationType != null && context.Identity != null)
        {
            var grantIdentity = context.Identity;
            if (!string.Equals(grantIdentity.AuthenticationType, context.SignInAsAuthenticationType, StringComparison.Ordinal))
            {
                grantIdentity = new ClaimsIdentity(grantIdentity.Claims, context.SignInAsAuthenticationType, grantIdentity.NameClaimType, grantIdentity.RoleClaimType);
            }
            Context.Authentication.SignIn(context.Properties, grantIdentity);
        }

        if (context.IsRequestCompleted || context.RedirectUri == null)
        {
            return context.IsRequestCompleted;
        }

        var redirectUri = context.RedirectUri;
        if (context.Identity == null)
        {
            // add a redirect hint that sign-in failed in some way
            redirectUri = WebUtilities.AddQueryString(redirectUri, "error", "access_denied");
        }
        Response.Redirect(redirectUri);
        context.RequestCompleted();

        return context.IsRequestCompleted;
    }

    /// <summary>Returns the only value of a query parameter, or <c>null</c> if it is absent or repeated.</summary>
    private static string SingleValueOrNull(IList<string> values)
    {
        return values != null && values.Count == 1 ? values[0] : null;
    }

    /// <summary>Adds a string-typed claim issued by this provider when <paramref name="value"/> is non-empty.</summary>
    private void AddClaimIfPresent(ClaimsIdentity identity, string claimType, string value)
    {
        if (!string.IsNullOrEmpty(value))
        {
            identity.AddClaim(new Claim(claimType, value, XmlSchemaString, Options.AuthenticationType));
        }
    }

    /// <summary>Resolves the authorization endpoint for this session (see <see cref="ComposeEndpoint"/>).</summary>
    private string ComposeAuthorizationEndpoint(AuthenticationProperties properties)
    {
        return ComposeEndpoint(properties, AuthorizationEndpoint, Options.Endpoints.AuthorizationEndpoint);
    }

    /// <summary>Resolves the token endpoint for this session (see <see cref="ComposeEndpoint"/>).</summary>
    private string ComposeTokenEndpoint(AuthenticationProperties properties)
    {
        return ComposeEndpoint(properties, TokenEndpoint, Options.Endpoints.TokenEndpoint);
    }

    /// <summary>
    /// Resolves an endpoint URL. Precedence, highest first:
    /// 1. the environment stored in the session's <see cref="AuthenticationProperties"/>;
    /// 2. an explicitly configured endpoint URL (<paramref name="configuredEndpoint"/>);
    /// 3. the environment configured in <c>Options.Endpoints.Environment</c>, defaulting to production.
    /// </summary>
    /// <param name="properties">Session properties, which may carry an environment override.</param>
    /// <param name="endpointPath">Path appended to the production or sandbox host.</param>
    /// <param name="configuredEndpoint">Optional full endpoint URL from the options.</param>
    private string ComposeEndpoint(AuthenticationProperties properties, string endpointPath, string configuredEndpoint = null)
    {
        // if AuthenticationProperties for this session specifies an environment property
        // it should take precedence over the value in AuthenticationOptions
        string environmentProperty;
        if (properties.Dictionary.TryGetValue(Constants.EnvironmentAuthenticationProperty, out environmentProperty))
        {
            return GetHost(environmentProperty) + endpointPath;
        }

        if (!string.IsNullOrEmpty(configuredEndpoint))
        {
            return configuredEndpoint;
        }

        return GetHost(Options.Endpoints.Environment) + endpointPath;
    }

    /// <summary>Maps an environment name to its login host; anything other than sandbox is production.</summary>
    private static string GetHost(string environment)
    {
        return environment == Constants.SandboxEnvironment ? SandboxHost : ProductionHost;
    }
}
}