summaryrefslogtreecommitdiff
path: root/src/saml2/assertion.py
blob: b51ec54c8fe2babad6bb6d4be9c91374c4d80120 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2010-2011 UmeƄ University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#            http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import sys
import xmlenc

from saml2 import saml

from saml2.time_util import instant, in_a_while
from saml2.attribute_converter import from_local

from saml2.s_utils import sid, MissingValue
from saml2.s_utils import factory
from saml2.s_utils import assertion_factory


def _filter_values(vals, vlist=None, must=False):
    """ Removes values from *vals* that does not appear in vlist
    
    :param vals: The values that are to be filtered
    :param vlist: required or optional value
    :param must: Whether the allowed values must appear
    :return: The set of values after filtering
    """
    
    if not vlist: # No value specified equals any value
        return vals
    
    if isinstance(vlist, basestring):
        vlist = [vlist]
        
    res = []
    
    for val in vlist:
        if val in vals:
            res.append(val)
    
    if must:
        if res:
            return res
        else:
            raise MissingValue("Required attribute value missing")
    else:
        return res


def filter_on_attributes(ava, required=None, optional=None):
    """ Filter
    
    :param ava: An attribute value assertion as a dictionary
    :param required: list of RequestedAttribute instances defined to be 
        required
    :param optional: list of     RequestedAttribute instances defined to be 
        optional
    :return: The modified attribute value assertion
    """
    res = {}
    
    if required is None:
        required = []
        
    for attr in required:
        if attr.friendly_name in ava:
            values = [av.text for av in attr.attribute_value]
            res[attr.friendly_name] = _filter_values(ava[attr.friendly_name], values, True)
        elif attr.name in ava:
            values = [av.text for av in attr.attribute_value]
            res[attr.name] = _filter_values(ava[attr.name], values, True)
        else:
            print >> sys.stderr, ava.keys()
            raise MissingValue("Required attribute missing: '%s'" % (attr.friendly_name,))

    if optional is None:
        optional = []
        
    for attr in optional:
        if attr.friendly_name in ava:
            values = [av.text for av in attr.attribute_value]
            try:
                res[attr.friendly_name].extend(_filter_values(ava[attr.friendly_name], values))
            except KeyError:
                res[attr.friendly_name] = _filter_values(ava[attr.friendly_name], values)
        elif attr.name in ava:
            values = [av.text for av in attr.attribute_value]
            try:
                res[attr.name].extend(_filter_values(ava[attr.name], values))
            except KeyError:
                res[attr.name] = _filter_values(ava[attr.name], values)
    
    return res

def filter_on_demands(ava, required=None, optional=None):
    """ Never return more than is needed. Filters out everything
    the server is prepared to return but the receiver doesn't ask for
    
    :param ava: Attribute value assertion as a dictionary
    :param required: Required attributes
    :param optional: Optional attributes
    :return: The possibly reduced assertion
    """
    
    # Is all what's required there:
    if required is None:
        required = {}
        
    for attr, vals in required.items():
        if attr in ava:
            if vals:
                for val in vals:
                    if val not in ava[attr]:
                        raise MissingValue(
                            "Required attribute value missing: %s,%s" % (attr,
                                                                        val))
        else:
            raise MissingValue("Required attribute missing: %s" % (attr,))

    if optional is None:
        optional = {}
        
    # OK, so I can imaging releasing values that are not absolutely necessary
    # but not attributes
    for attr, vals in ava.items():
        if attr not in required and attr not in optional:
            del ava[attr]
    
    return ava

def filter_on_wire_representation(ava, acs, required=None, optional=None):
    """
    :param ava: A dictionary with attributes and values
    :param required: A list of saml.Attributes
    :param optional: A list of saml.Attributes
    :return: Dictionary of expected/wanted attributes and values
    """
    acsdic = dict([(ac.name_format, ac) for ac in acs])

    if required is None:
        required = []
    if optional is None:
        optional = []

    res = {}
    for attr, val in ava.items():
        done = False
        for req in required:
            try:
                _name = acsdic[req.name_format]._to[attr]
                if _name == req.name:
                    res[attr] = val
                    done = True
            except KeyError:
                pass
        if done:
            continue
        for opt in optional:
            try:
                _name = acsdic[opt.name_format]._to[attr]
                if _name == opt.name:
                    res[attr] = val
                    break
            except KeyError:
                pass

    return res

def filter_attribute_value_assertions(ava, attribute_restrictions=None):
    """ Will weed out attribute values and values according to the
    rules defined in the attribute restrictions. If filtering results in
    an attribute without values, then the attribute is removed from the
    assertion.
    
    :param ava: The incoming attribute value assertion (dictionary)
    :param attribute_restrictions: The rules that govern which attributes
        and values that are allowed. (dictionary)
    :return: The modified attribute value assertion
    """
    if not attribute_restrictions:
        return ava
    
    for attr, vals in ava.items():
        if attr in attribute_restrictions:
            if attribute_restrictions[attr]:
                rvals = []
                for restr in attribute_restrictions[attr]:
                    for val in vals:
                        if restr.match(val):
                            rvals.append(val)
                
                if rvals:
                    ava[attr] = list(set(rvals))
                else:
                    del ava[attr]
        else:
            del ava[attr]
    return ava

class Policy(object):
    """ handles restrictions on assertions """
    
    def __init__(self, restrictions=None):
        if restrictions:
            self.compile(restrictions)
        else:
            self._restrictions = None
    
    def compile(self, restrictions):
        """ This is only for IdPs or AAs, and it's about limiting what
        is returned to the SP.
        In the configuration file, restrictions on which values that
        can be returned are specified with the help of regular expressions.
        This function goes through and pre-compiles the regular expressions.
        
        :param restrictions:
        :return: The assertion with the string specification replaced with
            a compiled regular expression.
        """
        
        self._restrictions = restrictions.copy()
        
        for _, spec in self._restrictions.items():
            if spec is None:
                continue
            
            try:
                restr = spec["attribute_restrictions"]
            except KeyError:
                continue
            
            if restr is None:
                continue
            
            for key, values in restr.items():
                if not values:
                    spec["attribute_restrictions"][key] = None
                    continue
                
                spec["attribute_restrictions"][key] = \
                        [re.compile(value) for value in values]
        
        return self._restrictions
    
    def get_nameid_format(self, sp_entity_id):
        """ Get the NameIDFormat to used for the entity id 
        :param: The SP entity ID
        :retur: The format
        """
        try:
            form = self._restrictions[sp_entity_id]["nameid_format"]
        except KeyError:
            try:
                form = self._restrictions["default"]["nameid_format"]
            except KeyError:
                form = saml.NAMEID_FORMAT_TRANSIENT
        
        return form
    
    def get_name_form(self, sp_entity_id):
        """ Get the NameFormat to used for the entity id 
        :param: The SP entity ID
        :retur: The format
        """
        form = ""
        
        try:
            form = self._restrictions[sp_entity_id]["name_form"]
        except TypeError:
            pass
        except KeyError:
            try:
                form = self._restrictions["default"]["name_form"]
            except KeyError:
                pass
        
        return form
    
    def get_lifetime(self, sp_entity_id):
        """ The lifetime of the assertion 
        :param sp_entity_id: The SP entity ID
        :param: lifetime as a dictionary 
        """
        # default is a hour
        spec = {"hours":1}
        if not self._restrictions:
            return spec
        
        try:
            spec = self._restrictions[sp_entity_id]["lifetime"]
        except KeyError:
            try:
                spec = self._restrictions["default"]["lifetime"]
            except KeyError:
                pass
        
        return spec
    
    def get_attribute_restriction(self, sp_entity_id):
        """ Return the attribute restriction for SP that want the information
        
        :param sp_entity_id: The SP entity ID
        :return: The restrictions
        """
        
        if not self._restrictions:
            return None
        
        try:
            try:
                restrictions = self._restrictions[sp_entity_id][
                                                "attribute_restrictions"]
            except KeyError:
                try:
                    restrictions = self._restrictions["default"][
                                                "attribute_restrictions"]
                except KeyError:
                    restrictions = None
        except KeyError:
            restrictions = None
        
        return restrictions
    
    def not_on_or_after(self, sp_entity_id):
        """ When the assertion stops being valid, should not be
        used after this time.
        
        :param sp_entity_id: The SP entity ID
        :return: String representation of the time
        """
        
        return in_a_while(**self.get_lifetime(sp_entity_id))
    
    def filter(self, ava, sp_entity_id, required=None, optional=None):
        """ What attribute and attribute values returns depends on what
        the SP has said it wants in the request or in the metadata file and
        what the IdP/AA wants to release. An assumption is that what the SP
        asks for overrides whatever is in the metadata. But of course the
        IdP never releases anything it doesn't want to.
        
        :param ava: The information about the subject as a dictionary
        :param sp_entity_id: The entity ID of the SP
        :param required: Attributes that the SP requires in the assertion
        :param optional: Attributes that the SP regards as optional
        :return: A possibly modified AVA
        """
                                
        
        ava = filter_attribute_value_assertions(ava,
                                self.get_attribute_restriction(sp_entity_id))
        
        if required or optional:
            ava = filter_on_attributes(ava, required, optional)
        
        return ava
    
    def restrict(self, ava, sp_entity_id, metadata=None):
        """ Identity attribute names are expected to be expressed in
        the local lingo (== friendlyName)
        
        :return: A filtered ava according to the IdPs/AAs rules and
            the list of required/optional attributes according to the SP.
            If the requirements can't be met an exception is raised.
        """
        if metadata:
            (required, optional) = metadata.attribute_consumer(sp_entity_id)
            #(required, optional) = metadata.wants(sp_entity_id)
        else:
            required = optional = None
        
        return self.filter(ava, sp_entity_id, required, optional)
    
    def conditions(self, sp_entity_id):
        """ Return a saml.Condition instance
        
        :param sp_entity_id: The SP entity ID
        :return: A saml.Condition instance
        """
        return factory( saml.Conditions,
                        not_before=instant(),
                        # How long might depend on who's getting it
                        not_on_or_after=self.not_on_or_after(sp_entity_id),
                        audience_restriction=[factory( saml.AudienceRestriction,
                                audience=factory(saml.Audience, 
                                                text=sp_entity_id))])

class Assertion(dict):
    """ Handles assertions about subjects """
    
    def __init__(self, dic=None):
        dict.__init__(self, dic)
    
    def _authn_context_decl_ref(self, authn_class):
        # authn_class: saml.AUTHN_PASSWORD
        return factory(saml.AuthnContext, 
                        authn_context_decl_ref=factory(
                                saml.AuthnContextDeclRef, text=authn_class))

    def _authn_context_class_ref(self, authn_class, authn_auth=None):
        # authn_class: saml.AUTHN_PASSWORD
        cntx_class = factory(saml.AuthnContextClassRef, text=authn_class)
        if authn_auth:
            return factory(saml.AuthnContext, 
                            authn_context_class_ref=cntx_class,
                            authenticating_authority=factory(
                                                saml.AuthenticatingAuthority,
                                                text=authn_auth))
        else:
            return factory(saml.AuthnContext, 
                            authn_context_class_ref=cntx_class)
        
    def _authn_statement(self, authn_class=None, authn_auth=None, 
                            authn_decl=None):
        if authn_class:
            return factory(saml.AuthnStatement, 
                        authn_instant=instant(), 
                        session_index=sid(),
                        authn_context=self._authn_context_class_ref(
                                                    authn_class, authn_auth))
        elif authn_decl:
            return factory(saml.AuthnStatement, 
                        authn_instant=instant(), 
                        session_index=sid(),
                        authn_context=self._authn_context_decl_ref(authn_decl))
        else:
            return factory(saml.AuthnStatement,
                        authn_instant=instant(), 
                        session_index=sid())

    def construct(self, sp_entity_id, in_response_to, consumer_url,
                    name_id, attrconvs, policy, issuer, authn_class=None, 
                    authn_auth=None, authn_decl=None, encrypt=None,
                    sec_context=None):
        """ Construct the Assertion 
        
        :param sp_entity_id: The entityid of the SP
        :param in_response_to: An identifier of the message, this message is 
            a response to
        :param consumer_url: The intended consumer of the assertion
        :param name_id: An NameID instance
        :param attrconvs: AttributeConverters
        :param policy: The policy that should be adhered to when replying
        :param issuer: Who is issuing the statement
        :param authn_class: The authentication class
        :param authn_auth: The authentication instance
        :param authn_decl:
        :param encrypt: Whether to encrypt parts or all of the Assertion
        :param sec_context: The security context used when encrypting
        :return: An Assertion instance
        """
        attr_statement = saml.AttributeStatement(attribute=from_local(
                                attrconvs, self, 
                                policy.get_name_form(sp_entity_id)))

        if encrypt == "attributes":
            for attr in attr_statement.attribute:
                enc = sec_context.encrypt(text="%s" % attr)

                encd = xmlenc.encrypted_data_from_string(enc)
                encattr = saml.EncryptedAttribute(encrypted_data=encd)
                attr_statement.encrypted_attribute.append(encattr)

            attr_statement.attribute = []

        # start using now and for some time
        conds = policy.conditions(sp_entity_id)
        
        return assertion_factory(
            issuer=issuer,
            attribute_statement = attr_statement,
            authn_statement = self._authn_statement(authn_class, authn_auth, 
                                                    authn_decl),
            conditions = conds,
            subject=factory( saml.Subject,
                name_id=name_id,
                subject_confirmation=factory( saml.SubjectConfirmation,
                                method=saml.SUBJECT_CONFIRMATION_METHOD_BEARER,
                                subject_confirmation_data=factory(
                                    saml.SubjectConfirmationData,
                                    in_response_to=in_response_to,
                                    recipient=consumer_url,
                                    not_on_or_after=policy.not_on_or_after(
                                                            sp_entity_id)))),
            )
    
    def apply_policy(self, sp_entity_id, policy, metadata=None):
        """ Apply policy to the assertion I'm representing 
        
        :param sp_entity_id: The SP entity ID
        :param policy: The policy
        :param metadata: Metadata to use
        :return: The resulting AVA after the policy is applied
        """
        return policy.restrict(self, sp_entity_id, metadata)