import urllib.request, urllib.error, urllib.parse
import json
import pandas as pd
import ssl
import torch
import re
from pprint import pprint
from captum.attr import visualization

REST_URL = "http://data.bioontology.org"
API_KEY = "604a90bc-ef14-4c26-a347-f4928fa086ea"

ssl._create_default_https_context = ssl._create_unverified_context
class PyTMinMaxScalerVectorized(object):
    """
    From https://discuss.pytorch.org/t/using-scikit-learns-scalers-for-torchvision/53455
    Transforms each channel to the range [0, 1].
    """
    def __call__(self, tensor):
        scale = 1.0 / (tensor.max(dim=0, keepdim=True)[0] - tensor.min(dim=0, keepdim=True)[0])
        tensor.mul_(scale).sub_(tensor.min(dim=0, keepdim=True)[0])
        return tensor
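
# A minimal usage sketch of the scaler (hypothetical values): each column is
# rescaled independently to [0, 1], in place.
#
#   scaler = PyTMinMaxScalerVectorized()
#   x = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
#   scaler(x)
#   # tensor([[0.0, 0.0],
#   #         [0.5, 0.5],
#   #         [1.0, 1.0]])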
def find_end(text):
    """Find the end of the report."""
    ends = [len(text)]
    patterns = [
        re.compile(r'BY ELECTRONICALLY SIGNING THIS REPORT', re.I),
        re.compile(r'\n {3,}DR.', re.I),
        re.compile(r'[ ]{1,}RADLINE ', re.I),
        re.compile(r'.*electronically signed on', re.I),
        re.compile(r'M\[0KM\[0KM')
    ]
    for pattern in patterns:
        matchobj = pattern.search(text)
        if matchobj:
            ends.append(matchobj.start())
    return min(ends)
def pattern_repl(matchobj):
    """
    Return a replacement string of spaces with the same length as the matched text.
    """
    return ' '.rjust(len(matchobj.group(0)))
def clean_text(text):
    """
    Clean text
    """
    # Replace [**Patterns**] with spaces.
    text = re.sub(r'\[\*\*.*?\*\*\]', pattern_repl, text)
    # Replace `_` with spaces.
    text = re.sub(r'_', ' ', text)
    start = 0
    end = find_end(text)
    new_text = ''
    if start > 0:
        new_text += ' ' * start
    new_text += text[start:end]
    # make sure the new text has the same length as the old text.
    if len(text) - end > 0:
        new_text += ' ' * (len(text) - end)
    return new_text
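
# A hedged example (hypothetical note text): the de-identification placeholder
# is replaced by an equal number of spaces, so the cleaned string keeps the
# same length and character offsets still line up with the original note.
#
#   note = "Admitted on [**2150-1-1**] with chest pain."
#   cleaned = clean_text(note)
#   # the bracketed placeholder becomes 14 spaces; len(cleaned) == len(note)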
def get_drg_link(drg_code):
    return f'https://www.aapc.com/codes/icd9-codes/{drg_code}'
def prettify(dict_list, k):
    li = [di[k] for di in dict_list]
    result = "\n".join(l for l in li)
    return result
def get_json(text_to_annotate):
    url = REST_URL + "/annotator?text=" + urllib.parse.quote(text_to_annotate) + "&ontologies=ICD9CM" +\
        "&longest_only=false" + "&exclude_numbers=false" + "&whole_word_only=true" + '&exclude_synonyms=false'
    opener = urllib.request.build_opener()
    opener.addheaders = [('Authorization', 'apikey token=' + API_KEY)]
    try:
        return json.loads(opener.open(url).read())
    except Exception:
        # network or parsing failure: fall back to no annotations
        return []
def parse_results(results):
    if len(results) == 0:
        return []
    rlist = []
    for result in results:
        annotations = result['annotations']
        for annotation in annotations:
            start = annotation['from'] - 1
            end = annotation['to'] - 1
            text = annotation['text']
            rlist.append({
                'start': start,
                'end': end,
                'text': text,
                'link': result['annotatedClass']['@id']
            })
    return rlist
def get_icd_annotations(text):
    response = get_json(text)
    annotation_list = parse_results(response)
    return annotation_list
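
# Hedged usage sketch: get_icd_annotations() calls the NCBO BioPortal annotator
# over HTTP, so it needs network access and a valid API key; parse_results()
# flattens the response into dicts with character offsets into the input text.
#
#   annotations = get_icd_annotations("patient admitted with acute renal failure")
#   # -> list of dicts like {'start': ..., 'end': ..., 'text': ..., 'link': ...}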
def subfinder(mylist, pattern):
    mylist = mylist.tolist()
    pattern = pattern.tolist()
    return list(filter(lambda x: x in pattern, mylist))
def tokenize_icds(tokenizer, annotations, token_ids):
    icd_tokens = torch.zeros(token_ids.shape)
    for annotation in annotations:
        icd = annotation['text']
        icd_token_ids = tokenizer(icd, add_special_tokens=False, return_tensors='pt').input_ids[0]
        # find index of the beginning icd token
        starting_indices = (token_ids == icd_token_ids[0]).nonzero(as_tuple=False)
        num_icd_tokens = icd_token_ids.shape[0]
        # if there's more than 1 icd token for the given annotation
        if num_icd_tokens > 1:
            # if there's only one starting index
            if starting_indices.shape[0] == 1:
                starting_index = starting_indices.item()
                icd_tokens[starting_index: starting_index + num_icd_tokens] = 1
            # if there's more than 1 starting index, determine which one is appropriate
            else:
                for starting_index in starting_indices:
                    starting_index = starting_index.item()
                    candidate = token_ids[starting_index: starting_index + num_icd_tokens]
                    # only mark the span when the full icd token sequence matches
                    if candidate.shape[0] == num_icd_tokens and (candidate == icd_token_ids).all():
                        icd_tokens[starting_index: starting_index + num_icd_tokens] = 1
        # otherwise, set the corresponding index to a value of 1
        else:
            icd_tokens[starting_indices] = 1
    return icd_tokens
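
# Hedged sketch (hypothetical tokenizer and note): the result is a 0/1 mask
# aligned with token_ids, with 1s over the sub-token spans of annotated terms.
#
#   inputs = tokenizer(note, return_tensors='pt')
#   mask = tokenize_icds(tokenizer, get_icd_annotations(note), inputs.input_ids[0])
#   # mask.shape == inputs.input_ids[0].shape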
def get_attribution(text, tokenizer, model_outputs, inputs, k=7):
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    padding_idx = tokens.index('[PAD]')
    # drop padding, then the [CLS] and [SEP] special tokens
    tokens = tokens[:padding_idx][1:-1]
    attn = model_outputs[-1][0]
    agg_attn, final_text = reconstruct_text(tokenizer=tokenizer, tokens=tokens, attn=attn)
    return agg_attn, final_text
def reconstruct_text(tokenizer, tokens, attn):
    """
    find a word -> token_id mapping that allows you to
    perform an aggregation on the sub-tokens' attention
    values
    """
    reconstructed_text = tokenizer.convert_tokens_to_string(tokens)
    num_subtokens = len([t for t in tokens if t.startswith('#')])
    aggregated_attn = torch.zeros(len(tokens) - num_subtokens)
    token_indices = [0]
    token_idx = 0
    reconstructed_tokens = []
    for i, token in enumerate(tokens[1:], start=1):
        # case when a token is a subtoken
        if token.startswith('#'):
            token_indices.append(i)
        else:
            # reconstruct the tokens to make sure you're doing this correctly
            reconstructed_token = ''.join(tokens[j].replace('#', '') for j in token_indices)
            reconstructed_tokens.append(reconstructed_token)
            # find the corresponding attention vectors
            aggregated_attn[token_idx] = torch.mean(attn[token_indices])
            # create new index list
            token_indices = [i]
            token_idx += 1
    # reconstruct the tokens to make sure you're doing this correctly
    reconstructed_token = ''.join(tokens[j].replace('#', '') for j in token_indices)
    reconstructed_tokens.append(reconstructed_token)
    # find the corresponding attention vectors
    aggregated_attn[token_idx] = torch.mean(attn[token_indices])
    # final representation of text
    final_text = ' '.join(reconstructed_tokens).replace(' .', '.')
    final_text = final_text.replace(' ,', ',')
    assert final_text == reconstructed_text
    return aggregated_attn, reconstructed_tokens
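
# A hedged example of the sub-token aggregation (hypothetical attention values,
# assuming a BERT-style WordPiece tokenizer): pieces such as 'pneu' + '##monia'
# are merged back into 'pneumonia' and their attention scores are averaged.
#
#   tokens = ['acute', 'pneu', '##monia']
#   attn = torch.tensor([0.2, 0.4, 0.6])
#   agg, words = reconstruct_text(tokenizer, tokens, attn)
#   # words == ['acute', 'pneumonia']; agg == tensor([0.2, 0.5])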
def load_rule(path):
    rule_df = pd.read_csv(path)
    # remove MDC 15 (neonates) and a couple of other codes related to postcare
    if 'MS' in path:
        msk = (rule_df['MDC'] != '15') & (~rule_df['MS-DRG'].isin([945, 946, 949, 950, 998, 999]))
        space = sorted(rule_df[msk]['DRG_CODE'].unique())
    elif 'APR' in path:
        msk = (rule_df['MDC'] != '15') & (~rule_df['APR-DRG'].isin([860, 863]))
        space = sorted(rule_df[msk]['DRG_CODE'].unique())
    drg2idx = {}
    for d in space:
        drg2idx[d] = len(drg2idx)
    i2d = {v: k for k, v in drg2idx.items()}
    d2mdc, d2w = {}, {}
    for _, r in rule_df.iterrows():
        drg = r['DRG_CODE']
        mdc = r['MDC']
        w = r['WEIGHT']
        d2mdc[drg] = mdc
        d2w[drg] = w
    return rule_df, drg2idx, i2d, d2mdc, d2w
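
# Hedged note on the expected CSV schema (inferred from the columns used above):
# the rule file needs 'DRG_CODE', 'MDC', 'WEIGHT', and either an 'MS-DRG' or an
# 'APR-DRG' column, and the path itself must contain 'MS' or 'APR' so that one
# of the two branches above defines `space`.
#
#   rule_df, drg2idx, i2d, d2mdc, d2w = load_rule('MS_DRG_rules.csv')  # hypothetical filename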
def visualize_attn(model_results):
    class_id = model_results['class_dsc']
    prob = model_results['prob']
    attn = model_results['attn']
    tokens = model_results['tokens']
    scaler = PyTMinMaxScalerVectorized()
    normalized_attn = scaler(attn)
    viz_record = visualization.VisualizationDataRecord(
        word_attributions=normalized_attn,
        pred_prob=prob,
        pred_class=class_id,
        true_class=class_id,
        attr_class=0,
        attr_score=1,
        raw_input_ids=tokens,
        convergence_score=1
    )
    return visualize_text(viz_record)
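
# Hedged usage note: model_results is assumed to be a dict with keys
# 'class_dsc' (predicted DRG), 'prob', 'attn' (1-D attention tensor), and
# 'tokens' (word-level tokens, e.g. from reconstruct_text); the attention is
# min-max scaled to [0, 1] and rendered as a raw HTML string.
#
#   html = visualize_attn({'class_dsc': drg, 'prob': p, 'attn': agg, 'tokens': words})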
def modify_attn_html(attn_html):
    attn_split = attn_html.split('<mark')
    htmls = [attn_split[0]]
    for html in attn_split[1:]:
        # wrap each highlighted token in an href tag (placeholder link)
        href_html = f'<a href="espn.com"><mark{html}</a>'
        htmls.append(href_html)
    return "".join(htmls)
# copied out of captum because we need raw html instead of a jupyter widget
def visualize_text(datarecord):
    dom = ["<table width: 100%>"]
    rows = [
        "<tr>"
        "<th style='text-align: left'>Predicted DRG</th>"
        "<th style='text-align: left'>Word Importance</th>"
        "</tr>"
    ]
    pred_class_html = visualization.format_classname(datarecord.pred_class)
    word_attn_html = visualization.format_word_importances(
        datarecord.raw_input_ids, datarecord.word_attributions
    )
    word_attn_html = modify_attn_html(word_attn_html)
    rows.append(
        "".join(
            [
                "<tr>",
                pred_class_html,
                word_attn_html,
                "</tr>",
            ]
        )
    )
    dom.append("".join(rows))
    dom.append("</table>")
    html = "".join(dom)
    return html