3.3 线性回归的简洁实现

1. 创建数据集

数据集的手工创建和上一节一样,人为设置true_w,true_b,以及num_examples(样本的总数量),调用synthetic_data()函数来创建。上一节中我们已经用#@save将这个函数保存在了d2l包中,这里我们直接调用就可以了:

2. 读取数据集

load_array()这个函数接受数据集的features,labels以及batch_size作为参数,返回一个数据加载器DataLoader,参数data_arrays就是features以及labels构成的元组(tuple)。我们用data_iter作为名字接住返回的这个DataLoader,并且由于load_array的定义包含了batch_size,所以data_iter能够按batch_size从数据集中加载数据。

data.TensorDataset(*data_arrays)用于对tensor进行打包,包装成dataset,dataset = data.TensorDataset(*data_arrays)也就生成了数据集。DataLoader()函数中,要以参数的形式指明要加载的数据集、batch_size,以及是否随机训练。这里我们使用iter构造Python迭代器,并使用next从迭代器中获取第一项。

 

3. 定义模型

我们可以使用pytorch中预定义好的层来定义模型:

 nn是network的缩写; Sequential可以理解为一个list of layers,里面是按顺序的一个一个的层。Linear(2,1)接受了两个参数,2和1,第一个参数2表示输入神经元的个数,即输入的features的特征个数是2,第二个参数1表示输出的神经元的个数,即输出的labels有一个特征。 Sequential类将多个层串联在一起。 当给定输入数据时,Sequential实例将数据传入到第一层, 然后将第一层的输出作为第二层的输入,以此类推。

 

4. 模型参数初始化

深度学习框架通常有预定义的方法来初始化参数。

 在这里,我们通过net[0]访问神经网络的第一层,分别通过net[0].weight以及net[0].bias访问这一层的权重和偏置,再通过.data访问这些数据,对它们进行初始化,注意,使用框架初始化参数,normal_和fill_后面有个下划线。

 

5. 定义损失函数

我们在线性模型中使用的是平均平方损失函数,它在nn中的定义是nn.MSELoss():

 默认情况下,它返回所有样本的损失的平均值。

 

6. 定义优化算法

优化器的定义要对torch.optim.SGD()传入两个参数:网络的参数net.parameters()以及学习率lr:

 注意这个trainer是有网络的参数的信息的,所以在训练的时候,梯度清零可以直接用trainer.zero_grad(),更新参数可以直接用trainer.step()。

 

7. 训练

 注意l = loss(net(X),y),loss()函数是nn.MSELoss(),默认情况下它返回的是所有样本的损失的平均值,因此反向传播的时候是l.backward(),而不是l.sum().backward()。

 

 

本节完整代码:

import torch
import random
from d2l import torch as d2l
from torch.utils import data

true_w = torch.tensor([2,-3.4])
true_b = 4.2
features,labels = d2l.synthetic_data(true_w,true_b,1000)

def load_array(data_arrays,batch_size,is_train=True):
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset,batch_size,is_train)

net = nn.Sequential(nn.Linear(2,1))
loss = nn.MSELoss()
trainer = torch.optim.SGD(net.parameters(),lr=0.03)
net[0].weight.data.normal_(0,0.01)
net[0].bias.data.fill_(0)

num_epochs = 3
batch_size = 10
data_iter = load_array((features,labels),batch_size)
for epoch in range(num_epochs):
    for X,y in data_iter:
        l = loss(net(X),y)
        trainer.zero_grad()
        l.backward()
        trainer.step()
    train_l = loss(net(features),labels)
    print(f'epoch {epoch+1}  loss {train_l:f}')

 

本文转载于网络 如有侵权请联系删除

相关文章

  • 机器人相关学术速递[9.2]

    Update!H5支持摘要折叠,体验更佳!点击阅读原文访问arxivdaily.com,涵盖CS|物理|数学|经济|统计|金融|生物|电气领域,更有搜索、收藏等功能!cs.RO机器人相关,共计14篇【1】SolvingtheDiscreteEuler-ArnoldEquationsfortheGeneralizedRigidBodyMotion 标题:求解广义刚体运动的离散Euler-Arnold方程 链接:https://arxiv.org/abs/2109.00505 作者:JoaoR.Cardoso,PedroMiraldo 机构:InstituteforSystemsandRobotics(LARSyS),InstitutoSuperiorT´ecnico,UniversityofLisbon,Portugal. 备注:None 摘要:我们提出了求解Moser-Veselov方程的三种迭代方法,该方程产生于控制广义刚体运动的Euler-Arnold微分方程的离散化。首先,我们将问题描述为一个具有正交约束的优化问题,并证明目标函数是凸的。然后,利用黎曼流形上的优化技术,设计了三种

  • PTA 1053 住房空置率 (20 分)

    题目在不打扰居民的前提下,统计住房空置率的一种方法是根据每户用电量的连续变化规律进行判断。判断方法如下:在观察期内,若存在超过一半的日子用电量低于某给定的阈值e,则该住房为“可能空置”;若观察期超过某给定阈值D天,且满足上一个条件,则该住房为“空置”。现给定某居民区的住户用电量数据,请你统计“可能空置”的比率和“空置”比率,即以上两种状态的住房占居民区住房总套数的百分比。输入格式:输入第一行给出正整数N(≤1000),为居民区住房总套数;正实数e,即低电量阈值;正整数D,即观察期阈值。随后N行,每行按以下格式给出一套住房的用电量数据:KE1E2...EK其中K为观察的天数,Ei为第i天的用电量。输出格式:在一行中输出“可能空置”的比率和“空置”比率的百分比值,其间以一个空格分隔,保留小数点后1位。输入样例: 50.510 60.30.40.50.20.80.6 100.00.10.20.30.00.80.60.70.00.5 50.40.30.50.10.7 110.10.10.10.10.10.10.10.10.10.10.1 11222110.110.10.10.10.1 结尾无空

  • 【日志服务CLS】Python开发API接入CLS(附源码、详细步骤)

    前言日志服务(CloudLogService,CLS)是腾讯云提供的一站式日志服务平台,提供了从日志采集、日志存储到日志检索,图表分析、监控告警、日志投递等多项服务,协助用户通过日志来解决业务运维、服务监控、日志审计等场景问题。简言之就是CLS提供了日志的云化存储,并提供了查询、分析、监控,告警等功能。所以今天就抱着好奇之心,来探索一下使用python如何将本机日志写入到CLS上。环境配置官方文档提供了详细的使用步骤使用步骤,文档链接如下:https://cloud.tencent.com/document/product/614/343401.服务开通点击https://cloud.tencent.com/product/cls进入页面,可以点击立即使用开通服务;当然也可以点击活动公告来查看免费的使用额度:可以看出,免费额度流量额度是5GB/日,活动截止于2021年底。开通服务后,进入clk服务页面。可以看到,提供了多种日志的接入方案。2.创建日志集和日志主题日志集(Logset)是日志服务的项目管理单元,用于区分不同项目的日志。日志主题(Topic)是日志服务的基本管理单元,用来存

  • Latex绘制流程图

    实现效果codesUsingPackagetikzstyle定义node和箭头的属性节点node箭头创建节点画箭头\draw[arrow](decision1)--node[anchor=east]{yes}(process2a);复制解析:#属性 [arrow]:需要调用的箭头的属性 (decision1):箭头的其实位置 (process2a):箭头的末端位置 #线型 --:直线 |-:先竖线后横线 -|:向横线后竖线 #文字:如果需要在箭头上添加文字 {yes}:需要添加的文字 #文字的位置,上南下北左东右西(与地图方位不一致) [anchor=east]: [anchor=south]: [anchor=west]: [anchor=north]: [anchor=center]:复制我的博客即将同步至腾讯云+社区,邀请大家一同入驻:https://cloud.tencent.com/developer/support-plan?invite_code=3ofb9ijv4a688Previous 如何配置latexmk

  • Java 多维数组遍历

    多维数组数组是Java中的一种容器对象,它拥有多个单一类型的值。当数组被创建的时候数组长度就已经确定了。在创建之后,其长度是固定的。下面是一个长度为10的数组: 上面的代码是一维数组的例子。换句话说,数组长度只能在一个方向上增长。很多时候我们需要数组在多个维度上增长。这种数组我们称之为多维数组。为简单起见,我们将它称为2维数组。当我们需要一个矩阵或者X-Y坐标系的时候,二维数组是非常有用的。下面就是一个二维数组的例子:想象一下,一个二维数组看起来就像一个X-Y坐标系的矩阵。然而,可能让Java开发者们感到惊讶的是,Java实际上并没有二维数组。在一个真正的数组中,所有的元素在内存中都存放在连续的内存块中,但是在Java的二维数组并不是这样。Java中所有一维数组中的元素占据了相邻的内存位置,因此是一个真正的数组。在Java中,当我们定义:这意味着,在上面的例子中,二维数组是一个数组的引用,其每一个元素都是另一个int数组的引用。这张图片清楚地解释了这个概念。由于二维数组分散在存储器中,所以对性能有一些影响。为了分析这种差异,我写了一个简单的Java程序,显示遍历顺序的重要性。下面是示例

  • 161. 旋转图像交换加转置

    给定一个N×N的二维矩阵表示图像,90度顺时针旋转图像。样例 给出一个矩形[[1,2],[3,4]],90度顺时针旋转后,返回[[3,1],[4,2]]交换加转置方阵旋转九十度可以通过换行加转置来完成,刚好vector是可以用swap函数的,对于单个的元素肯定也是可以的。这样想来就没什么难得了,程序简洁:voidrotate(vector<vector<int>>&matrix){ intsz=matrix.size(); if(sz<=1) return; for(inti=0;i<sz/2;i++) { swap(matrix[i],matrix[sz-1-i]); } for(inti=0;i<sz;i++) { for(intj=i;j<sz;j++) { swap(matrix[i][j],matrix[j][i]); } } //writeyourcodehere }复制

  • 进击的耶路撒冷:英特尔旗下Mobileye自动驾驶路测,挑战圣城高能路况

    翘首栗发自凹非寺 量子位报道|公众号QbitAI△风一样 自信的人类司机,“造就”果决的自动驾驶。最近,英特尔旗下的以色列自动驾驶公司Mobileye,已经在总部所在地耶路撒冷启动了100辆自动驾驶车的路测。耶路撒冷的驾驶文化以野性奔放而闻名,据说高调的喇叭要比谦虚的车灯好用得多。开太快容易出事,开太慢后面的司机不耐烦。选择这座城市作为测试地点,大概是基于“在这儿都能开,在哪不能开?”的想法吧——△耿直的微笑 Mobileye的计划是,接下来的几个月,把车队的活动扩张到美国和其他地方,CEOAmnonShashua在一篇博客里是这样说的。团队的目标是,证明Mobileye的自动驾驶系统比人类司机要安全千倍,并且能够适应各种地理条件和交通状况。“冗余”的感知系统?△高冷系摄像头 答案是,在(几周之后就会启动的)第二阶段研发里,团队才会为自动驾驶汽车铺上一层雷达,以及激光雷达。CEO说,第一阶段只靠摄像头,这是公司的策略,为了让传感器达到真正“冗余”的状态——也就是说,感知部分最终会包含几个相互独立的系统,每个系统都能以一己之力撑起自动驾驶汽车的活动。团队希望能在2021年,让L4和L5级

  • 除了语音交互,虚拟世界中还有这些交互方式!

    从古至今,自浮士德到南柯一梦,人们总是乐于沉溺在虚幻缥缈的世界中,感受其带来的神奇魅力。如今,VR技术的出现已经使人们的妄想成为可能。然而,我们仍然需要借由他物将虚拟世界与人类自身完美地联系起来,从而使虚拟更加真实自然。之前,小编已经谈论了虚拟世界中的语音交互。今天,我们就来聊一聊其他被运用于VR中的交互方式。交互方式趣味化,烦闷无比的VR打字从此萌起来在手势识别、语音识别等技术尚不完善的当下,作为传统交互的键盘输入仍然是VR交互的可行方案之一。但是,键盘输入目前在VR交互的处境十分尴尬。戴着头显的VR用户无法看到物理键盘并快速地进行文字输入,而VR中的虚拟键盘则因其低效而颇受诟病。针对虚拟世界中的键盘输入,谷歌推出了一个非常有趣的VR打字应用。通过该应用,用户可以像敲鼓一样在键盘上打字。谷歌旨在通过打造这一鼓锤控制器,让用户更加自然和愉快地进行文字输入。同样地,NormalVR团队也推出了一个叫做“CutieKeys”的应用,以添加VR环境中虚拟键盘输入的趣味性。针对传统交互方式,开发者们的创想层出不穷。在日本,开发者们用OculusTouch自创了一套日语的输入法,将日语的五个元音

  • P1181 数列分段Section I

    题目描述对于给定的一个长度为N的正整数数列A[i],现要将其分成连续的若干段,并且每段和不超过M(可以等于M),问最少能将其分成多少段使得满足要求。输入输出格式 输入格式: 输入文件divide_a.in的第1行包含两个正整数N,M,表示了数列A[i]的长度与每段和的最大值,第2行包含N个空格隔开的非负整数A[i],如题目所述。 输出格式: 输出文件divide_a.out仅包含一个正整数,输出最少划分的段数。输入输出样例输入样例#1: 56 42451复制输出样例#1: 3复制说明对于20%的数据,有N≤10;对于40%的数据,有N≤1000;对于100%的数据,有N≤100000,M≤10^9,M大于所有数的最小值,A[i]之和不超过109。将数列如下划分:[4][24][51]第一段和为4,第2段和为6,第3段和为6均满足和不超过M=6,并可以证明3是最少划分的段数。暴力枚举只要不大于就不分!1#include<iostream> 2#include<cstdio> 3#include<cmath> 4usingnamespacestd; 5c

  • 腾讯云高性能计算平台绑定弹性伸缩组高性能计算平台API20211109

    1.接口描述接口请求域名:thpc.tencentcloudapi.com。 本接口(BindAutoScalingGroup)用于为集群队列绑定弹性伸缩组 默认接口请求频率限制:20次/秒。 APIExplorer提供了在线调用、签名验证、SDK代码生成和快速检索接口等能力。您可查看每次调用的请求内容和返回结果以及自动生成SDK调用示例。 2.输入参数以下请求参数列表仅列出了接口请求参数和部分公共参数,完整公共参数列表见公共请求参数。 参数名称 必选 类型 描述 Action 是 String 公共参数,本接口取值:BindAutoScalingGroup。 Version 是 String 公共参数,本接口取值:2021-11-09。 Region 是 String 公共参数,详见产品支持的地域列表。 ClusterId 是 String 集群ID。 LaunchConfigurationId 是 String 弹性伸缩启动配置ID。 AutoScalingGroupId 是 String 弹性伸缩组ID。 QueueName

  • Function--jdk8用法

    Lambda表达式。首先是参数部分,接着是->,可以视为产出,->之后的内容都是方法体。 当只有一个参数时,可以不需要括号(); 正常情况使用()包裹参数,为了保持一致性,也可以使用括号()包裹单个参数; 如果没有参数,则需要使用()表示空参数列表; 对于多个参数,将参数列表放在()内; 如果Lambda表达式中需要多行,那就需要将这些行放在花括号中,并且需要使用return返回产出。 示例 ()→System.out.println("Zeroparameter"); (p)→System.out.println("Oneparameterp="+p); (abc)→System.out.println("Multipleparametera="+a+"b="+b+"c="+c); (ab)->{ intsum=a+b; returnsum; }复制 publicstaticvoidmain(String[]args){ /** *Function<T,R>:接收1个输入参数,返回1个结果 */ Function<Integer,Int

  • 一些资料

    python的nltk中文使用和学习资料汇总帮你入门提高 blog.csdn.net/huyoo/article/details/12188573 PYTHON自然语言处理中文翻译NLTK中文版.pdf http://ishare.iask.sina.com.cn/f/23996193.html

  • Nginx 整合 Lua 实现动态生成缩略图

    原文地址:Nginx整合Lua实现动态生成缩略图 博客地址:http://www.extlight.com 一、前提 最近在开发一个项目,涉及到缩略图的功能,常见的生成缩略图的方案有以下几个: 人工创建 由美工PS出缩略图,然后上传到服务器上进行访问。 缺点:操作繁琐 复制 工具包创建 上传原图到后台时,后台借用工具(如:Thumbnailator)创建缩略图 缺点:无法灵活获取更多尺寸的缩略图 复制 第三方平台 如七牛云平台,在原图链接地址后加缩略图尺寸(如:http://images.xxx/abc.jpg_400x400.jpg)灵活生成缩略图 缺点:收费 复制 很明显,第三个方案是比较好的,但是由于收费,笔者便放弃该方案。 那有没有既免费又能动态生成缩略图的方案呢?答案是肯定的,且看下文。 二、实现思路 实现功能需要用到3个工具: Nginx:负责web服务器 GraphicsMagick:负责生成缩略图 Lua:负责控制缩略图尺寸以及调用GraphicsMagick 复制 大致的运行原理如下: 首先在Nginx中整合Lua,由Lua处理响应请求。 当Nginx

  • 加密_简单加密

    题目链接:https://ctf.bugku.com/challenges 题解: 打开题目,看到一串字符串 e6Z9i~]8R~U~QHE{RnY{QXg~QnQ{^XVlRXlp^XI5Q6Q6SKY8jUAA复制 以“AA”结尾,字符“A”的ASCII为65,而BASE64加密后的字符串以“=”结尾,其ASCII为61,相差4。因此猜想为凯撒密码,偏移量为4,因此将此字符串还原,写了一个C++代码进行转换,如下: #include<bits/stdc++.h> usingnamespacestd; intmain(){ stringstr; cin>>str; for(inti=0;i<str.size();i++){ printf("%c",str[i]-4); } return0; }复制 转换后得到: a2V5ezY4NzQzMDAwNjUwMTczMjMwZTRhNThlZTE1M2M2OGU4fQ==复制 进行BASE64解密,即得到flag,在线解密链接:https://base64.supfree.net/ 即fl

  • Vue-属性侦听器

    属性侦听器   watch:{x(){}}中的方法名必须跟要监听的data中的属性名一样,才代表监听指定属性   当侦听器监听的属性发生变化时,就会调用watch中对应的方法   侦听器属性,比计算属性计算效率消耗大 newVue({ el:"",//关联界面元素 data:{x:12},//vm的数据源 methods:{},//方法 filter:{},//过滤器 computed:{xx(){}},//xx就是一个计算属性 watch:{x(){}}//x就是监听了data中的x属性的一个监听器 }) 复制    本人新手小白,正在学习前端,随笔纯属自己的理解,有什么理解不到位的求各位大大指出,栓Q! 本文来自博客园,作者:前端小白银,转载请注明原文链接:https://www.cnblogs.com/forever-ljf/p/16660069.html

  • 别名命令alias,以及取消别名unalias

    alias命令的使用 alias 显示当前shell进程所有可用的命令别名 定义别名NAME,相当于执行命令VALUE,在命令行中定义的别名只在当前的shell中有效,新开的shell中不能使用。 aliasNAME='VALUE' [root@XX~]#whichwho /usr/bin/who [root@XX~]#whereiswho who:/usr/bin/who/usr/share/man/man1/who.1.gz/usr/share/man/man1p/who.1p.gz [root@XX~]#whichls aliasls='ls--color=auto' /usr/bin/ls复制   如果希望永久有效,要定义在配置文件中 对当前用户有效:~/.bashrc 家目录中.bashrc隐藏文件 对所有用户有效:/etc/bashrc 编辑配置文件给出的新配置不会立刻有效需要执行后才有效 使用命令source/path/to/config或./path/to/config   unalias命令的使用 unaliasNAME&n

  • 完整版百度地图点击列表定位到对应位置并有交互动画效果demo

    1.前言 将地图嵌入到项目中的需求很多,好吧,我一般都是用的百度地图。那么今天就主要写一个完整的demo。展示一个列表,点击列表的任一内容,在地图上定位到该位置,并有动画效果。来来来,直接上demo 2.详细流程 1.引入文件 <scripttype="text/javascript"src="http://api.map.baidu.com/api?v1.5&ak=AFb5d4d8279a19b2fc3a16d063f26772"></script><scripttypet="text/javascript"src="http://libs.baidu.com/jquery/1.9.1/jquery.min.js"></script>复制 2.本demo用到两张图片第一张是头像sxs.jpg第二张地图中定位图片positionBg.png;图片你们自行替换,但是大小不一样,图片不一样,样式要自己微调哦。 3.css样式     *{margin:0;padding:0;} .cleanfloa

  • Centos7安装Zabbix4.0步骤

    官方搭建zabbix4.0的环境要求: 1.环境搭建LAMP   前提Centos系统安装完成: 确认一下: 1 2 cat /etc/redhat-release # 查看CentOS版本  cat /proc/version         #查看存放与内核相关的文件 1.1搭建之前的操作 1.1.1升级系统组件到最新的版本 1 yum -yupdate 1.1.2关闭selinux  1 vi /etc/selinux/config    #将SELINUX=enforcing改为SELINUX=disabled设置后需要重启才能生效 1 setenforce 0    

  • 线性表顺序存储方式的C语言实现

    1/* 2编译器VC6++ 3文件名1.cpp 4代码版本号:1.0 5时间:2015年9月14日16:39:21 6*/ 7#include<stdio.h> 8#include<stdlib.h> 9 10#defineOK1 11#defineERROR0 12#defineTRUE1 13#defineFALSE0 14#defineOVERFLOW-2 15#defineLIST_INIT_SIZE10 16#defineLIST_INCREMENT10 17 18typedefintElemType; 19typedefintStatus; 20typedefstruct{ 21ElemType*base; 22intlength; 23intlistsize; 24}Sqlist; 25 26StatusinitSqlist(Sqlist*l)//初始化线性表,分配容量 27{ 28(*l).base=(ElemType*)malloc(LIST_INIT_SIZE*sizeof(ElemType)); 29 30if(!((*l).base))

  • cmake 学习笔记

    添加库cmake_minimum_required(VERSION3.9) project(answer) #添加libanswer库目标,STATIC指定为静态库 add_library(libanswerSTATICanswer.cpp) add_executable(answermain.cpp) #为answer可执行目标链接libanswer target_link_libraries(answerlibanswer)复制 放入子文件夹:然后接口说明需要连接 target_include_directories复制   add_library(libanswerSTATICanswer.cpp) #[[ message可用于打印调试信息或错误信息,除了STATUS 外还有DEBUGWARNINGSEND_ERRORFATAL_ERROR等。 #]] message(STATUS"Currentsourcedir:${CMAKE_CURRENT_SOURCE_DIR}") #[[ 给libanswer库目标添加include目录,PUBLIC使 这个in

  • 记一次Oracle数据库迁移部署

    1--20141230部署脚本(按照时间顺序从上往下) 2 34--命令行,导出要部署的数据库数据(无分号) 5--expdpRMB3/test123@orcl3SCHEMAS=RMB3directory=expdirdumpfile=20141230.dmplogfile=20141230.log 6--上句注释: 7--expdp:cmd命令(即win下的一个exe程序) 8--RMB3/test123@orcl3:用户名/密码@本地服务命名(netManager中的本地服务命名SID,不是那个全局的服务名) 9--SCHEMAS=RMB3:要导出的用户的SCHEMAS 10--directory:导出文件存放目录 11--dumpfile:生成的dmp文件名(带扩展名) 12--logfile:日志文件名 13 14 15--如果是覆盖部署需要删除user以及tablespace 16--dropuserRMBcascade; 17--droptablespaceRMBincludingcontentsanddatafiles; 18 19--createorreplacedir

相关推荐

推荐阅读