bind_model()
------------
`bind_model()` is not part of the NSML API; it is a convenience wrapper that makes `nsml.bind` easier to use.
def bind_model(model, class_to_save, optimizer=None):
    """Register save/load/infer callbacks for ``model`` with NSML.

    This is not an NSML function itself; it is a convenience wrapper that
    builds the three callbacks as closures over ``model`` and ``optimizer``
    and hands them to ``nsml.bind``.

    Args:
        model: the torch module whose ``state_dict`` is saved/restored and
            which is called during ``infer``.
        class_to_save: accepted for interface compatibility; not used here.
        optimizer: optional optimizer whose state is saved/restored together
            with the model. When ``None``, only the model state is handled.
    """

    def _to_list(tensor):
        """Move ``tensor`` to CPU, squeeze it and return a Python list.

        ``squeeze().tolist()`` yields a bare Python number when every
        dimension is 1 (one input with ``top_k == 1``); wrap that case in a
        list so the caller can always iterate over the result.
        """
        values = tensor.data.cpu().squeeze().tolist()
        return values if isinstance(values, list) else [values]

    def load(dir_path, **kwargs):
        """Restore model (and optimizer, if saved and given) from ``dir_path``."""
        # The checkpoint lives in the directory NSML passes to this callback.
        state = torch.load(os.path.join(dir_path, 'model.pt'))
        model.load_state_dict(state['model'])
        if 'optimizer' in state and optimizer is not None:
            optimizer.load_state_dict(state['optimizer'])
        print('Model loaded')

    def save(dir_path, **kwargs):
        """Write model (and optimizer, if any) state to ``dir_path/model.pt``."""
        state = {'model': model.state_dict()}
        # ``optimizer`` defaults to None, so only persist it when one was
        # supplied; ``load`` already copes with a missing 'optimizer' entry.
        if optimizer is not None:
            state['optimizer'] = optimizer.state_dict()
        torch.save(state, os.path.join(dir_path, 'model.pt'))

    def infer(input, top_k=100):
        """Return the top-``top_k`` class probabilities and indices for ``input``.

        ``preprocess(None, input)`` is expected to return a sequence of
        tensors that ``torch.stack`` can batch. The model is called as
        ``model(image, None)`` and must return a 4-tuple whose second element
        is a 2-D score tensor (the size unpacking below requires 2-D).

        NOTE(review): the output is a list of ``(prob, key)`` pairs for a
        single input. For a batch of several inputs, ``squeeze()`` keeps the
        batch axis and each pair holds per-input lists instead. Confirm which
        format the NSML consumer expects; the original inline comment
        describes ``[[(prob, key), ...], ...]``.
        """
        model.eval()
        # Batch the preprocessed inputs into one tensor on the GPU.
        image = torch.stack(preprocess(None, input))
        image = Variable(image.cuda())
        _, clean_state, _, _ = model(image, None)
        _, all_cls = clean_state.size()
        # dim=1 is what softmax picks implicitly for a 2-D input; stating it
        # avoids the deprecation warning. topk cannot exceed the class count.
        prediction = F.softmax(clean_state, dim=1).topk(min(top_k, all_cls))
        probs, keys = prediction
        return list(zip(_to_list(probs), _to_list(keys)))

    # 'nsml.bind' function must be called at the end.
    nsml.bind(save=save, load=load, infer=infer)