bind_model()

bind_model은 NSML의 함수는 아니지만 nsml.bind를 쉽게 사용하기 위한 wrapper 함수입니다.
def bind_model(model, class_to_save, optimizer=None):
    """Register save/load/infer callbacks for ``model`` with NSML.

    Convenience wrapper around ``nsml.bind``: it builds three closures over
    ``model``, ``class_to_save`` and ``optimizer`` and hands them to
    ``nsml.bind`` as ``save``, ``load`` and ``infer``.

    Args:
        model: Network whose ``state_dict`` is checkpointed. ``infer`` calls
            it as ``model(image, None)`` and treats the second of the four
            returned values as per-class scores.
        class_to_save: Object pickled to ``class.pkl`` on save. On load the
            unpickled copy is passed to ``nsml.copy(temp_class, class_to_save)``.
        optimizer: Optional optimizer. When given, its state is saved next to
            the model state and restored on load. When ``None``, only the
            model state is checkpointed.
    """

    def load(filename, **kwargs):
        """Restore model/optimizer state and ``class_to_save`` from ``filename``."""
        state = torch.load(os.path.join(filename, 'model.pt'))
        model.load_state_dict(state['model'])
        # Checkpoints written without an optimizer carry no 'optimizer' key.
        if 'optimizer' in state and optimizer is not None:
            optimizer.load_state_dict(state['optimizer'])
        # NOTE(review): pickle.load (like torch.load) can execute arbitrary
        # code from the file -- only load checkpoints from trusted sources.
        with open(os.path.join(filename, 'class.pkl'), 'rb') as fp:
            temp_class = pickle.load(fp)
        nsml.copy(temp_class, class_to_save)
        print('Model loaded')

    def save(filename, **kwargs):
        """Write model/optimizer state and ``class_to_save`` into ``filename``."""
        state = {'model': model.state_dict()}
        # optimizer defaults to None; calling .state_dict() on it would raise
        # AttributeError, so only store it when one was supplied. load()
        # already tolerates checkpoints without this key.
        if optimizer is not None:
            state['optimizer'] = optimizer.state_dict()
        torch.save(state, os.path.join(filename, 'model.pt'))
        with open(os.path.join(filename, 'class.pkl'), 'wb') as fp:
            pickle.dump(class_to_save, fp)

    def infer(input, top_k=100):
        """Return the top-``top_k`` softmax probabilities and class indices.

        ``input`` is passed to ``preprocess(None, input)``, whose result is
        stacked into a batch and moved to the GPU (CUDA is required).
        (The parameter name shadows the builtin but is kept for callers.)

        Returns:
            A list built by zipping probabilities with class indices after
            ``squeeze()``. For a single image this is
            ``[(prob, key), ...]`` with ``min(top_k, n_classes)`` entries.
            For a batch of several images, each element pairs one image's
            probability row with its class-index row.
        """
        model.eval()
        # list of per-sample tensors -> one batched tensor
        image = torch.stack(preprocess(None, input))
        image = Variable(image.cuda())
        _, clean_state, _, _ = model(image, None)
        _, all_cls = clean_state.size()
        # clean_state is 2-D, for which the (deprecated) implicit softmax
        # dimension resolves to 1; passing it explicitly keeps the result and
        # silences the warning.
        prediction = F.softmax(clean_state, dim=1).topk(min(top_k, all_cls))
        probs = prediction[0].data.cpu().squeeze().tolist()
        keys = prediction[1].data.cpu().squeeze().tolist()
        # One image with top_k == 1 squeezes to a 0-d tensor, whose tolist()
        # is a bare number; wrap it so zip() still yields [(prob, key)].
        if not isinstance(probs, list):
            probs, keys = [probs], [keys]
        return list(zip(probs, keys))

    # 'nsml.bind' function must be called at the end.
    nsml.bind(save=save, load=load, infer=infer)